diff --git a/src/mars_patcher/cli.py b/src/mars_patcher/cli.py index 61b029a..7642918 100644 --- a/src/mars_patcher/cli.py +++ b/src/mars_patcher/cli.py @@ -1,9 +1,6 @@ import argparse -import copy import json -import typing -from mars_patcher.auto_generated_types import MarsSchema from mars_patcher.patcher import patch, validate_patch_data @@ -18,11 +15,9 @@ def main() -> None: with open(args.patch_data_path, encoding="utf-8") as f: patch_data = json.load(f) - validate_patch_data(patch_data) - patch( args.rom_path, args.out_path, - typing.cast(MarsSchema, copy.copy(patch_data)), + validate_patch_data(patch_data), lambda message, progress: print(message), ) diff --git a/src/mars_patcher/patcher.py b/src/mars_patcher/patcher.py index a18b42f..31e389b 100644 --- a/src/mars_patcher/patcher.py +++ b/src/mars_patcher/patcher.py @@ -1,5 +1,7 @@ import json -from typing import Callable +import typing +from os import PathLike +from typing import Callable, Union from jsonschema import validate @@ -33,7 +35,7 @@ from mars_patcher.text import write_seed_hash -def validate_patch_data(patch_data: dict) -> None: +def validate_patch_data(patch_data: dict) -> MarsSchema: """ Validates whether the specified patch_data satisfies the schema for it. @@ -43,11 +45,12 @@ def validate_patch_data(patch_data: dict) -> None: with open(get_data_path("schema.json")) as f: schema = json.load(f) validate(patch_data, schema) + return typing.cast(MarsSchema, patch_data) def patch( - input_path: str, - output_path: str, + input_path: Union[str, PathLike[str]], + output_path: Union[str, PathLike[str]], patch_data: MarsSchema, status_update: Callable[[str, float], None], ) -> None: diff --git a/src/mars_patcher/rom.py b/src/mars_patcher/rom.py index 69fc16f..f20ef60 100644 --- a/src/mars_patcher/rom.py +++ b/src/mars_patcher/rom.py @@ -82,7 +82,7 @@ class Rom: }, } - def __init__(self, path: str): + def __init__(self, path: Union[str, PathLike[str]]): # Read file with open(path, "rb") as f: self.data = bytearray(f.read())