diff --git a/accesser/__init__.py b/accesser/__init__.py
index 7e96d14..da72686 100644
--- a/accesser/__init__.py
+++ b/accesser/__init__.py
@@ -30,6 +30,8 @@
from packaging.version import Version
from tld import get_tld, is_tld
import dns, dns.asyncresolver, dns.nameserver
+import platform
+from pathlib import Path
from .utils import certmanager as cm
from .utils import importca
@@ -51,12 +53,12 @@ async def update_cert(server_name):
async with cert_lock:
if not server_name in cert_store:
cm.create_certificate(server_name)
- context.load_cert_chain(os.path.join(cm.certpath, "{}.crt".format(server_name)))
+ context.load_cert_chain(setting.certpath.joinpath("{}.crt".format(server_name)))
cert_store.add(server_name)
async def send_pac(writer: asyncio.StreamWriter):
- with open('pac' if os.path.exists('pac') else os.path.join(basepath, 'pac'), 'rb') as f:
- pac = f.read().replace(b'{{port}}', str(setting.config['server']['port']).encode('iso-8859-1')).replace(b'{{host}}', setting.config['server'].get('pac_host', '127.0.0.1').encode('iso-8859-1'))
+ pac_file = Path('pac') if Path('pac').exists() else Path(basepath).joinpath('pac')
+ pac = pac_file.read_bytes().replace(b'{{port}}', str(setting.config['server']['port']).encode('iso-8859-1')).replace(b'{{host}}', setting.config['server'].get('pac_host', '127.0.0.1').encode('iso-8859-1'))
writer.write(f'HTTP/1.1 200 OK\r\nContent-Type: application/x-ns-proxy-autoconfig\r\nContent-Length: {len(pac)}\r\n\r\n'.encode('iso-8859-1'))
writer.write(pac)
await writer.drain()
@@ -64,8 +66,7 @@ async def send_pac(writer: asyncio.StreamWriter):
await writer.wait_closed()
async def send_crt(writer: asyncio.StreamWriter, path: str):
- with open(os.path.join(cm.certpath, path.rsplit(sep = '/',maxsplit = 1)[-1]), 'rb') as f:
- crt = f.read()
+ crt = setting.certpath.joinpath(path.rsplit(sep = '/',maxsplit = 1)[-1]).read_bytes()
writer.write(f'HTTP/1.1 200 OK\r\nContent-Type: application/x-x509-ca-cert\r\nContent-Length: {len(crt)}\r\n\r\n'.encode('iso-8859-1'))
writer.write(crt)
await writer.drain()
@@ -209,7 +210,11 @@ async def main():
global context, cert_store, cert_lock, DNSresolver
print(f"Accesser v{__version__} Copyright (C) 2018-2024 URenko")
setting.parse_args()
-
+ if platform.system() == "Linux" or platform.system() == "FreeBSD":
+ if os.geteuid() == 0:
+ logger.warning(
+ "Running Accesser as the root user carries certain risks. Do not use it in production."
+ )
if setting.rules_update_case in ('old', 'missing'):
logger.warning("Updated rules.toml because it is %s.", setting.rules_update_case)
elif setting.rules_update_case == 'modified':
diff --git a/accesser/utils/certmanager.py b/accesser/utils/certmanager.py
index 37df450..29f57d6 100644
--- a/accesser/utils/certmanager.py
+++ b/accesser/utils/certmanager.py
@@ -16,9 +16,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import os, platform
import datetime
-from pathlib import Path
from cryptography import x509
from cryptography.x509.oid import NameOID
@@ -34,53 +32,9 @@
from cryptography.x509.oid import ExtendedKeyUsageOID
from .log import logger
+
logger = logger.getChild("certmanager")
from . import setting
-from .setting import basepath
-
-
-def decide_state_path_legacy():
- if setting.config["importca"]:
- return Path(basepath)
- else:
- return Path()
-
-
-def decide_state_path_unix_like():
- if os.geteuid() == 0:
- logger.warn("Running Accesser as the root user carries certain risks; see pull #245")
- return Path("/var/lib") / "accesser"
-
- state_path = os.getenv("XDG_STATE_HOME", None)
- if state_path is not None:
- state_path = Path(state_path) / "accesser"
- else:
- state_path = Path.home() / ".local/state" / "accesser"
- return state_path
-
-
-def decide_certpath():
- certpath = None
- # 人为指定最优先
- #if setting.config["state_dir"]:
- #return Path(setting.config["state_dir"]) / "cert"
- match platform.system():
- case 'Linux' | 'FreeBSD':
- deprecated_path = decide_state_path_legacy() / "CERT"
- # 暂仅在 *nix 上视为已废弃
- if deprecated_path.exists():
- logger.warn("deprecated path, see pull #245")
- return deprecated_path
- certpath = decide_state_path_unix_like() / "cert"
- case _:
- # windows,mac,android ...
- certpath = decide_state_path_legacy() / "CERT"
- return certpath
-
-
-certpath = decide_certpath()
-if not certpath.exists():
- os.makedirs(certpath, exist_ok=True)
def create_root_ca():
@@ -131,11 +85,11 @@ def create_root_ca():
.sign(key, hashes.SHA256())
)
- (certpath / "root.crt").write_bytes(
+ setting.certpath.joinpath("root.crt").write_bytes(
cert.public_bytes(serialization.Encoding.PEM)
)
- (certpath / "root.key").write_bytes(
+ setting.certpath.joinpath("root.key").write_bytes(
key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
@@ -143,7 +97,7 @@ def create_root_ca():
)
)
- (certpath / "root.pfx").write_bytes(
+ setting.certpath.joinpath("root.pfx").write_bytes(
serialization.pkcs12.serialize_key_and_certificates(
b"Accesser", key, cert, None, serialization.NoEncryption()
)
@@ -151,8 +105,8 @@ def create_root_ca():
def create_certificate(server_name):
- rootpem = (certpath / "root.crt").read_bytes()
- rootkey = (certpath / "root.key").read_bytes()
+ rootpem = setting.certpath.joinpath("root.crt").read_bytes()
+ rootkey = setting.certpath.joinpath("root.key").read_bytes()
ca_cert = x509.load_pem_x509_certificate(rootpem)
pkey = serialization.load_pem_private_key(rootkey, password=None)
@@ -219,7 +173,7 @@ def create_certificate(server_name):
.sign(pkey, hashes.SHA256())
)
- (certpath / f"{server_name}.crt").write_bytes(
+ setting.certpath.joinpath(f"{server_name}.crt").write_bytes(
cert.public_bytes(serialization.Encoding.PEM)
+ pkey.private_bytes(
encoding=serialization.Encoding.PEM,
diff --git a/accesser/utils/importca.py b/accesser/utils/importca.py
index 6b7380c..a354850 100644
--- a/accesser/utils/importca.py
+++ b/accesser/utils/importca.py
@@ -17,7 +17,6 @@
# along with this program. If not, see .
import os, sys
-from pathlib import Path
import subprocess
import locale
@@ -29,7 +28,6 @@
from .log import logger
logger = logger.getChild('importca')
-certpath = cm.certpath
def logandrun(cmd):
if hasattr(subprocess, 'STARTUPINFO'):
@@ -42,33 +40,32 @@ def logandrun(cmd):
def import_windows_ca():
try:
- logandrun('certutil -f -user -p "" -exportPFX My Accesser '+os.path.join(certpath, 'root.pfx'))
+ logandrun('certutil -f -user -p "" -exportPFX My Accesser '+str(setting.certpath.joinpath('root.pfx')))
except subprocess.CalledProcessError:
logger.debug("Export Failed")
- if not os.path.exists(os.path.join(certpath ,"root.pfx")):
+ if not setting.certpath.joinpath("root.pfx").exists():
cm.create_root_ca()
try:
logger.info("Importing new certificate")
- logandrun('CertUtil -f -user -p "" -importPFX My '+os.path.join(certpath, 'root.pfx'))
+ logandrun('CertUtil -f -user -p "" -importPFX My '+str(setting.certpath.joinpath("root.pfx")))
except subprocess.CalledProcessError:
logger.error("Import Failed")
logandrun('CertUtil -user -delstore My Accesser')
- # os.remove(os.path.join(certpath ,"root.pfx"))
- # os.remove(os.path.join(certpath ,"root.crt"))
- # os.remove(os.path.join(certpath ,"root.key"))
+ # os.remove(os.path.join(setting.certpath ,"root.pfx"))
+ # os.remove(os.path.join(setting.certpath ,"root.crt"))
+ # os.remove(os.path.join(setting.certpath ,"root.key"))
# sys.exit(5)
logger.warning('Try to manually import the certificate')
else:
- with open(os.path.join(certpath ,"root.pfx"), 'rb') as pfxfile:
- private_key, certificate, _ = pkcs12.load_key_and_certificates(pfxfile.read(), password=None)
- with open(os.path.join(certpath ,"root.crt"), "wb") as certfile:
- certfile.write(certificate.public_bytes(serialization.Encoding.PEM))
- with open(os.path.join(certpath ,"root.key"), "wb") as pkeyfile:
- pkeyfile.write(private_key.private_bytes(
+ private_key, certificate, _ = pkcs12.load_key_and_certificates(setting.certpath.joinpath("root.pfx").read_bytes(), password=None)
+ setting.certpath.joinpath("root.crt").write_bytes(certificate.public_bytes(serialization.Encoding.PEM))
+ setting.certpath.joinpath("root.key").write_bytes(
+ private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
- ))
+ )
+ )
def import_mac_ca():
ca_hash = CertUtil.ca_thumbprint.replace(':', '')
@@ -99,7 +96,7 @@ def get_exist_ca_sha1():
os.system(cmd)
def import_ca():
- if not(os.path.exists(os.path.join(certpath ,"root.crt")) and os.path.exists(os.path.join(certpath ,"root.key"))):
+ if not(setting.certpath.joinpath("root.crt").exists() and setting.certpath.joinpath("root.key").exists()):
if setting.config['importca']:
if sys.platform.startswith('win'):
import_windows_ca()
diff --git a/accesser/utils/setting.py b/accesser/utils/setting.py
index 6b871ff..91b3af0 100644
--- a/accesser/utils/setting.py
+++ b/accesser/utils/setting.py
@@ -6,8 +6,11 @@
except ModuleNotFoundError:
import tomli as tomllib
import argparse
+import platform
+import logging
basepath = Path(__file__).parent.parent
+certpath = None
def deep_merge(config_a: dict, config_b: dict):
@@ -59,7 +62,44 @@ def deep_merge(config_a: dict, config_b: dict):
config = deep_merge(_config, _rules)
+def decide_state_path_legacy():
+ if config["importca"]:
+ return Path(basepath)
+ return Path()
+
+
+def decide_state_path_unix_like():
+ if os.geteuid() == 0:
+ return Path("/var/lib") / "accesser"
+
+ state_path = os.getenv("XDG_STATE_HOME", None)
+ if state_path is not None:
+ state_path = Path(state_path) / "accesser"
+ else:
+ state_path = Path.home() / ".local/state" / "accesser"
+ return state_path
+
+
+def decide_certpath():
+ # 人为指定最优先
+ if "state_dir" in config and config["state_dir"] is not None:
+ return Path(config["state_dir"]) / "CERT"
+ match platform.system():
+ case "Linux" | "FreeBSD":
+ deprecated_path = decide_state_path_legacy() / "CERT"
+ if deprecated_path.exists():
+ logging.warning("cert path %s is deprecated.", str(deprecated_path))
+ logging.warning("Please check https://github.com/URenko/Accesser/pull/245 for migration.")
+ return deprecated_path
+ return decide_state_path_unix_like() / "CERT"
+ case _:
+ # windows,mac,android ...
+ return decide_state_path_legacy() / "CERT"
+ return
+
+
def parse_args():
+ global certpath
parser = argparse.ArgumentParser()
parser.add_argument(
"--notsetproxy",
@@ -80,9 +120,11 @@ def parse_args():
args = parser.parse_args()
if args.notsetproxy:
config["setproxy"] = False
- return
- # FIXME Wrong initialization sequence
- # see pull #245
if args.notimportca:
config["importca"] = False
- config["state_dir"] = args.state_dir
+ if args.state_dir is not None:
+ config["state_dir"] = args.state_dir
+ certpath = decide_certpath()
+ if not certpath.exists() or certpath.is_file():
+ certpath.mkdir(parents=True)
+ return