diff --git a/python/submit/pyproject.toml b/python/submit/pyproject.toml index 9050fd2..6b107e0 100644 --- a/python/submit/pyproject.toml +++ b/python/submit/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "submit" -version = "0.2.0" +version = "0.3.0" description = "Extract code block which should be submitted to regulatory agency." readme = "README.md" requires-python = ">=3.10" diff --git a/python/submit/submit.py b/python/submit/submit.py index 579fbaf..dbc591c 100644 --- a/python/submit/submit.py +++ b/python/submit/submit.py @@ -22,6 +22,10 @@ COMMENT_NOT_SUBMIT_NEGIN: str = rf"\/\*{SYLBOMS}*NOT\s*SUBMIT\s*BEGIN{SYLBOMS}*\*\/" COMMENT_NOT_SUBMIT_END: str = rf"\/\*{SYLBOMS}*NOT\s*SUBMIT\s*END{SYLBOMS}*\*\/" +# 宏变量 +# 仅支持一层宏变量引用(例如:&id),不支持嵌套宏变量引用(例如:&&id、&&&id) +MACRO_VAR = r"(? list[ConvertMode]: def copy_file( - sas_file: str, txt_file: str, convert_mode: ConvertMode = ConvertMode.BOTH, encoding: str | None = None + sas_file: str, + txt_file: str, + convert_mode: ConvertMode = ConvertMode.BOTH, + macro_subs: dict[str, str] | None = None, + encoding: str | None = None, ) -> None: """将 SAS 代码复制到 txt 文件中,并移除指定标记之间的内容。 @@ -61,6 +69,7 @@ def copy_file( sas_file (str): SAS 文件路径。 txt_file (str): TXT 文件路径。 convert_mode (ConvertMode, optional): 转换模式,默认值为 ConvertMode.BOTH。 + macro_subs (dict[str, str] | None, optional): 一个字典,其键为 SAS 代码中的宏变量名称,值为替代的字符串,默认值为 None。 encoding (str | None, optional): 字符编码,默认值为 None,将自动检测编码。 """ @@ -69,35 +78,36 @@ def copy_file( encoding = detect(f.read())["encoding"] with open(sas_file, "r", encoding=encoding) as f: - sas_code = f.read() + code = f.read() + # 提取代码片段 if convert_mode & ConvertMode.NEGATIVE: # 移除不需要递交的代码片段 - sas_code = re.sub( - rf"{COMMENT_NOT_SUBMIT_NEGIN}.*?{COMMENT_NOT_SUBMIT_END}", - "", - sas_code, - flags=re.I | re.S, - ) + code = re.sub(rf"{COMMENT_NOT_SUBMIT_NEGIN}.*?{COMMENT_NOT_SUBMIT_END}", "", code, flags=re.I | re.S) if convert_mode & ConvertMode.POSITIVE: # 提取需要递交的代码片段 - sas_code = re.findall(rf"{COMMENT_SUBMIT_BEGIN}(.*?){COMMENT_SUBMIT_END}", sas_code, re.I | re.S) - sas_code = "".join(sas_code) - - txt_code = sas_code - - txt_code_dir = os.path.dirname(txt_file) - if not os.path.exists(txt_code_dir): - os.makedirs(txt_code_dir) + code = re.findall(rf"{COMMENT_SUBMIT_BEGIN}(.*?){COMMENT_SUBMIT_END}", code, re.I | re.S) + code = "".join(code) + + # 替换宏变量 + if macro_subs is not None: + for key, value in macro_subs.items(): + regex_macro = re.compile(rf"(? dict[str, str]: + """解析字典字符串。 + + Args: + arg (str): 字典字符串。 + + Returns: + dict[str, str]: 字典。 + """ + + arg = arg.strip("{}") + try: + return dict([ele.strip("\"'") for ele in item.split("=")] for item in arg.split(",")) + except ValueError: + raise argparse.ArgumentTypeError("无效的字典字符串") def main() -> None: @@ -155,6 +183,9 @@ def main() -> None: default="both", help="转换模式(默认 both)", ) + parent_parser.add_argument( + "--macro-subs", type=parse_dict, help="宏变量替换,格式为 {key1=value1,key2=value2}(默认无)" + ) parent_parser.add_argument("--encoding", default=None, help="编码格式(默认自动检测)") # 子命令 copyfile @@ -176,6 +207,7 @@ def main() -> None: sas_file=args.sas_file, txt_file=args.txt_file, convert_mode=args.convert_mode, + macro_subs=args.macro_subs, encoding=args.encoding, ) elif args.command == "copydir": @@ -183,6 +215,7 @@ def main() -> None: sas_dir=args.sas_dir, txt_dir=args.txt_dir, convert_mode=args.convert_mode, + macro_subs=args.macro_subs, exclude_files=args.exclude_files, exclude_dirs=args.exclude_dirs, encoding=args.encoding, diff --git a/python/submit/tests/conftest.py b/python/submit/tests/conftest.py index 1d5bcb5..ae351cd 100644 --- a/python/submit/tests/conftest.py +++ b/python/submit/tests/conftest.py @@ -59,9 +59,11 @@ def shared_test_directory(tmp_path_factory: pytest.TempPathFactory) -> Path: proc datasets library = work memtype = data kill noprint; quit; + %let id = %str(); + /*====SUBMIT BEGIN====*/ proc sql; - create table t2 as select * from adam.adae; + create table t2 as select * from adam.adeff&id; quit; /*====SUBMIT END====*/ @@ -175,7 +177,7 @@ def shared_validate_directory(tmp_path_factory: pytest.TempPathFactory) -> Path: """) (dir_tfl / "t2.txt").write_text(""" proc sql; - create table t2 as select * from adam.adae; + create table t2 as select * from adam.adeff; quit; """) (dir_tfl / "t3.txt").write_text(""" diff --git a/python/submit/tests/test_submit.py b/python/submit/tests/test_submit.py index d0a87a9..305e620 100644 --- a/python/submit/tests/test_submit.py +++ b/python/submit/tests/test_submit.py @@ -20,7 +20,9 @@ def test_copy_file(self, shared_test_directory: Path, shared_validate_directory: assert re.sub(r"\s*", "", tmp_code) == re.sub(r"\s*", "", validate_code) def test_copy_directory(self, shared_test_directory: Path, shared_validate_directory: Path, tmp_path: Path): - copy_directory(shared_test_directory, tmp_path, exclude_dirs=["other"], exclude_files=["fcmp.sas"]) + copy_directory( + shared_test_directory, tmp_path, exclude_dirs=["other"], exclude_files=["fcmp.sas"], macro_subs={"id": ""} + ) copy_directory(shared_test_directory / "macro", tmp_path / "macro", convert_mode=ConvertMode.NEGATIVE) for validate_file in shared_validate_directory.rglob("*.txt"): diff --git a/python/submit/uv.lock b/python/submit/uv.lock index 13c0683..82156e7 100644 --- a/python/submit/uv.lock +++ b/python/submit/uv.lock @@ -222,7 +222,7 @@ wheels = [ [[package]] name = "submit" -version = "0.1.0" +version = "0.3.0" source = { editable = "." } dependencies = [ { name = "chardet" },