From d60da26ec44a49ea75a4cb979e641aa29a066fce Mon Sep 17 00:00:00 2001 From: raspberry Date: Sun, 4 May 2025 23:49:25 +0800 Subject: [PATCH 1/3] upload script --- upload.sh | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100755 upload.sh diff --git a/upload.sh b/upload.sh new file mode 100755 index 0000000..af6c915 --- /dev/null +++ b/upload.sh @@ -0,0 +1,7 @@ +pytest +black . +rm -rf dist/ build/ *.egg-info .pytest_cache/ +find . | grep -E "(/__pycache__$|/\.DS_Store$)" | xargs rm -rf +python3 -m build +twine upload --repository ec-tools dist/* + From 9aa388a9d29ad917589904187779acfffb2efee6 Mon Sep 17 00:00:00 2001 From: raspberry Date: Fri, 13 Jun 2025 10:48:33 +0800 Subject: [PATCH 2/3] update cipher --- .pylintrc | 3 + demos/headers_manager_demo.py | 2 - demos/key_manager_demo.py | 2 - ec_tools/data/__init__.py | 2 +- ec_tools/data/data_object.py | 42 ++++++++-- ec_tools/data/json_type.py | 2 +- ec_tools/tools/cipher/__init__.py | 7 +- ec_tools/tools/cipher/aes_cipher.py | 80 +++++++++++++++++++ ec_tools/tools/cipher/aes_config.py | 17 ++++ ec_tools/tools/cipher/cipher.py | 29 +++++++ .../tools/cipher/cipher_generator/__init__.py | 0 .../aes_cipher_generator.py | 22 +---- .../cipher_generator.py | 34 +------- ec_tools/tools/key_manager.py | 8 +- ec_tools/tools/thread_pool.py | 24 ++---- ec_tools/utils/hash_utils.py | 8 ++ ec_tools/utils/io_utils.py | 53 ++++++++---- scripts/lint.sh | 4 + local.sh => scripts/local.sh | 4 +- setup.py | 2 +- tests/database/data_object_enum_test.py | 1 - tests/tools/cipher/cipher_generator_test.py | 3 +- tests/utils/io_test.py | 28 +++++++ tests/utils/test.json | 3 + upload.sh | 7 -- 25 files changed, 271 insertions(+), 116 deletions(-) create mode 100644 .pylintrc create mode 100644 ec_tools/tools/cipher/aes_cipher.py create mode 100644 ec_tools/tools/cipher/aes_config.py create mode 100644 ec_tools/tools/cipher/cipher.py create mode 100644 ec_tools/tools/cipher/cipher_generator/__init__.py rename ec_tools/tools/cipher/{ => cipher_generator}/aes_cipher_generator.py (81%) rename ec_tools/tools/cipher/{ => cipher_generator}/cipher_generator.py (50%) create mode 100755 scripts/lint.sh rename local.sh => scripts/local.sh (51%) create mode 100644 tests/utils/io_test.py create mode 100644 tests/utils/test.json delete mode 100755 upload.sh diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..74f82bb --- /dev/null +++ b/.pylintrc @@ -0,0 +1,3 @@ +[MESSAGES CONTROL] +disable=missing-docstring,broad-exception-caught,too-many-return-statements,too-many-arguments,too-many-positional-arguments + diff --git a/demos/headers_manager_demo.py b/demos/headers_manager_demo.py index ed85048..95affb3 100644 --- a/demos/headers_manager_demo.py +++ b/demos/headers_manager_demo.py @@ -1,6 +1,4 @@ import logging -import time -import unittest from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao from ec_tools.tools.cipher import AesCipherGenerator diff --git a/demos/key_manager_demo.py b/demos/key_manager_demo.py index 6b57689..a987936 100644 --- a/demos/key_manager_demo.py +++ b/demos/key_manager_demo.py @@ -1,6 +1,4 @@ import logging -import time -import unittest from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao from ec_tools.tools.cipher import AesCipherGenerator diff --git a/ec_tools/data/__init__.py b/ec_tools/data/__init__.py index a81b7fe..c8761cb 100755 --- a/ec_tools/data/__init__.py +++ b/ec_tools/data/__init__.py @@ -1,2 +1,2 @@ from .json_type import JsonType -from .data_object import DataObject, Formatter +from .data_object import DataObject, Formatter, CustomizedJsonEncoder diff --git a/ec_tools/data/data_object.py b/ec_tools/data/data_object.py index 29fc474..6ab4f33 100755 --- a/ec_tools/data/data_object.py +++ b/ec_tools/data/data_object.py @@ -3,7 +3,14 @@ import json from typing import Any, Dict, List, Callable, get_origin, get_args, Set -from ec_tools.utils.io_utils import CustomizedJsonEncoder + +class CustomizedJsonEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, set): + return list(o) + if isinstance(o, enum.Enum): + return o.name + return json.JSONEncoder.default(self, o) @dataclasses.dataclass @@ -91,7 +98,7 @@ def _get_value(self, value: Any): return value if self.field and self.field.default_factory != dataclasses.MISSING: return self.field.default_factory() - elif self.field and self.field.default != dataclasses.MISSING: + if self.field and self.field.default != dataclasses.MISSING: return self.field.default return None @@ -118,7 +125,7 @@ def field_names(cls) -> List[str]: @classmethod def from_json(cls, json_obj: Dict[str, Any]): - function_mapping = cls._customized_mapping_function("_load__") + function_mapping = cls._customized_load_function("_load__") return cls( **{ field.name: function_mapping[field.name](json_obj.get(field.name, None)) @@ -127,15 +134,22 @@ def from_json(cls, json_obj: Dict[str, Any]): ) def to_json(self) -> Dict[str, Any]: - return json.loads(self.to_json_str()) + default_json = json.loads(self.to_json_str()) + function_mapping = self._customized_dump_function("_dump__") + return { + field.name: ( + function_mapping[field.name](self.__dict__[field.name]) + if function_mapping[field.name] + else default_json.get(field.name) + ) + for field in self.fields() + } def to_json_str(self) -> str: return json.dumps(dataclasses.asdict(self), cls=CustomizedJsonEncoder) @classmethod - def _customized_mapping_function( - cls, prefix: str - ) -> Dict[str, Callable[[Any], Any]]: + def _customized_load_function(cls, prefix: str) -> Dict[str, Callable[[Any], Any]]: all_functions = { item: getattr(cls, item) for item in dir(cls) @@ -148,3 +162,17 @@ def _customized_mapping_function( ) for field in cls.fields() } + + @classmethod + def _customized_dump_function( + cls, + prefix: str, + ) -> Dict[str, Callable[[Any], Any]]: + all_functions = { + item: getattr(cls, item) + for item in dir(cls) + if isinstance(getattr(cls, item), Callable) and item.startswith(prefix) + } + return { + field.name: all_functions.get(prefix + field.name) for field in cls.fields() + } diff --git a/ec_tools/data/json_type.py b/ec_tools/data/json_type.py index 0fce8a9..87378fe 100755 --- a/ec_tools/data/json_type.py +++ b/ec_tools/data/json_type.py @@ -1,5 +1,5 @@ from typing import Union, List, Dict, TypeAlias JsonType: TypeAlias = Union[ - None, int, str, bool, List["JsonType"], Dict[str, "JsonType"] + None, int, str, bool, float, List["JsonType"], Dict[str, "JsonType"] ] diff --git a/ec_tools/tools/cipher/__init__.py b/ec_tools/tools/cipher/__init__.py index 2decb6e..efa4baa 100644 --- a/ec_tools/tools/cipher/__init__.py +++ b/ec_tools/tools/cipher/__init__.py @@ -1,2 +1,5 @@ -from .aes_cipher_generator import AesCipherGenerator, AesMode -from .cipher_generator import Cipher, CipherGenerator +from .aes_cipher import AesCipher +from .aes_config import AesConfig, AesMode +from .cipher import Cipher, SecrectKey +from .cipher_generator.cipher_generator import CipherGenerator +from .cipher_generator.aes_cipher_generator import AesCipherGenerator diff --git a/ec_tools/tools/cipher/aes_cipher.py b/ec_tools/tools/cipher/aes_cipher.py new file mode 100644 index 0000000..6f500bf --- /dev/null +++ b/ec_tools/tools/cipher/aes_cipher.py @@ -0,0 +1,80 @@ +import hashlib +import os +import dataclasses + +from Crypto.Cipher import AES +from ec_tools.tools.cipher.cipher import SecrectKey +from ec_tools.tools.cipher.aes_config import AesMode + + +@dataclasses.dataclass +class AesCipher: + password: bytes + aes_mode: AesMode + + def __init__( + self, + password: str, + salt: str, + iterations: int = 10000, + aes_mode: AesMode = AesMode.AES_256_CBC, + ): + self.aes_mode = aes_mode + self.password = self.generate_password(password, salt, iterations) + + def encrypt(self, plain_text: bytes) -> bytes: + size = len(plain_text) + data = self._augment_bytes(plain_text) + secret_key = self.generate_key(self.password, os.urandom(32), 10) + aes = AES.new(secret_key.key, self.aes_mode.value.mode, iv=secret_key.iv) + cipher_text = aes.encrypt(data) + return str(size).encode("utf-8") + b"\0" + secret_key.salt + cipher_text + + def decrypt(self, cipher_text: bytes) -> bytes: + zero = cipher_text.index(b"\0") + size = int(cipher_text[:zero].decode("utf-8")) + salt = cipher_text[zero + 1 : zero + 1 + 32] + data = cipher_text[zero + 1 + 32 :] + secret_key = self.generate_key(self.password, salt, 10) + aes = AES.new(secret_key.key, self.aes_mode.value.mode, iv=secret_key.iv) + decrypted_data = aes.decrypt(data) + return self._recover_bytes(decrypted_data, size) + + def _augment_bytes(self, data: bytes) -> bytes: + return b"".join( + [ + os.urandom(self.aes_mode.value.key_size), + data, + os.urandom(self.aes_mode.value.key_size), + os.urandom( + self.aes_mode.value.key_size + - len(data) % self.aes_mode.value.key_size + ), + ] + ) + + def _recover_bytes(self, data: bytes, size: int) -> bytes: + return data[self.aes_mode.value.key_size : self.aes_mode.value.key_size + size] + + @classmethod + def generate_password(cls, password: str, salt: str, iterations: int) -> bytes: + return hashlib.pbkdf2_hmac( + "sha512", + password.encode("utf-8"), + bytes.fromhex(salt), + iterations, + ) + + def generate_key(self, password: bytes, salt: bytes, iterations: int): + hsh = hashlib.pbkdf2_hmac( + "sha512", + password, + salt, + iterations, + dklen=self.aes_mode.value.key_size + self.aes_mode.value.iv_size, + ) + return SecrectKey( + key=hsh[: self.aes_mode.value.key_size], + iv=hsh[self.aes_mode.value.key_size :], + salt=salt, + ) diff --git a/ec_tools/tools/cipher/aes_config.py b/ec_tools/tools/cipher/aes_config.py new file mode 100644 index 0000000..3d7f7bd --- /dev/null +++ b/ec_tools/tools/cipher/aes_config.py @@ -0,0 +1,17 @@ +import dataclasses +import enum + +from Crypto.Cipher import AES + + +@dataclasses.dataclass +class AesConfig: + key_size: int + iv_size: int + mode: int + + +class AesMode(enum.Enum): + AES_128_CBC = AesConfig(16, 16, AES.MODE_CBC) + AES_192_CBC = AesConfig(24, 16, AES.MODE_CBC) + AES_256_CBC = AesConfig(32, 16, AES.MODE_CBC) diff --git a/ec_tools/tools/cipher/cipher.py b/ec_tools/tools/cipher/cipher.py new file mode 100644 index 0000000..0d1aed5 --- /dev/null +++ b/ec_tools/tools/cipher/cipher.py @@ -0,0 +1,29 @@ +import json +import dataclasses + + +@dataclasses.dataclass +class Cipher: + cipher_text: str + mode: str + salt: str + + def dumps(self) -> str: + return json.dumps(dataclasses.asdict(self)) + + @classmethod + def loads(cls, text: str) -> "Cipher": + return Cipher(**json.loads(text)) + + +@dataclasses.dataclass +class SecrectKey: + key: bytes + iv: bytes + salt: bytes + + def __str__(self) -> str: + return f"SecretKey(key={len(self.key)},iv={len(self.iv)},salt={len(self.salt)})" + + def __repr__(self) -> str: + return str(self) diff --git a/ec_tools/tools/cipher/cipher_generator/__init__.py b/ec_tools/tools/cipher/cipher_generator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ec_tools/tools/cipher/aes_cipher_generator.py b/ec_tools/tools/cipher/cipher_generator/aes_cipher_generator.py similarity index 81% rename from ec_tools/tools/cipher/aes_cipher_generator.py rename to ec_tools/tools/cipher/cipher_generator/aes_cipher_generator.py index c5b25ab..39a6a94 100644 --- a/ec_tools/tools/cipher/aes_cipher_generator.py +++ b/ec_tools/tools/cipher/cipher_generator/aes_cipher_generator.py @@ -1,25 +1,11 @@ -import json import os import hashlib import dataclasses -import enum from Crypto.Cipher import AES -from ec_tools.utils.hash_utils import hmac_sha256 -from ec_tools.tools.cipher.cipher_generator import Cipher, CipherGenerator, SecrectKey - - -@dataclasses.dataclass -class AesConfig: - key_size: int - iv_size: int - mode: int - - -class AesMode(enum.Enum): - AES_128_CBC = AesConfig(16, 16, AES.MODE_CBC) - AES_192_CBC = AesConfig(24, 16, AES.MODE_CBC) - AES_256_CBC = AesConfig(32, 16, AES.MODE_CBC) +from ec_tools.tools.cipher.cipher import Cipher, SecrectKey +from ec_tools.tools.cipher.cipher_generator.cipher_generator import CipherGenerator +from ec_tools.tools.cipher.aes_config import AesMode @dataclasses.dataclass @@ -54,7 +40,7 @@ def encrypt(self, password: bytes, plain_text: bytes) -> Cipher: ) @classmethod - def _augment_bytes(self, data: bytes, padding_size: int) -> bytes: + def _augment_bytes(cls, data: bytes, padding_size: int) -> bytes: padded_text = (data + os.urandom(padding_size - len(data) % padding_size)).hex() random_bytes = os.urandom(len(padded_text)).hex() mixture = bytes.fromhex("".join(map("".join, zip(padded_text, random_bytes)))) diff --git a/ec_tools/tools/cipher/cipher_generator.py b/ec_tools/tools/cipher/cipher_generator/cipher_generator.py similarity index 50% rename from ec_tools/tools/cipher/cipher_generator.py rename to ec_tools/tools/cipher/cipher_generator/cipher_generator.py index a4fd449..ba065df 100644 --- a/ec_tools/tools/cipher/cipher_generator.py +++ b/ec_tools/tools/cipher/cipher_generator/cipher_generator.py @@ -1,38 +1,6 @@ -import json -import os -import hashlib import dataclasses import abc -from Crypto.Cipher import AES - -from ec_tools.utils.hash_utils import hmac_sha256 - - -@dataclasses.dataclass -class Cipher: - cipher_text: str - mode: str - salt: str - - def dumps(self) -> str: - return json.dumps(dataclasses.asdict(self)) - - @classmethod - def loads(self, text: str) -> "Cipher": - return Cipher(**json.loads(text)) - - -@dataclasses.dataclass -class SecrectKey: - key: bytes - iv: bytes - salt: bytes - - def __str__(self) -> str: - return f"SecretKey(key={len(self.key)},iv={len(self.iv)},salt={len(self.salt)})" - - def __repr__(self) -> str: - return str(self) +from ec_tools.tools.cipher.cipher import Cipher @dataclasses.dataclass diff --git a/ec_tools/tools/key_manager.py b/ec_tools/tools/key_manager.py index 9ecbb51..bbe1413 100644 --- a/ec_tools/tools/key_manager.py +++ b/ec_tools/tools/key_manager.py @@ -1,23 +1,19 @@ import dataclasses -import logging -import os import threading -import time -from copy import deepcopy from typing import Dict, Optional, Callable, List from ec_tools.database import CipherKvDao def _collect_user_input(field: str, retries: int = 10) -> str: - for i in range(retries): + for _ in range(retries): try: value = input(f"Enter field {field}: ") value = value.strip() if not value: continue return value - except Exception as e: + except Exception: continue raise RuntimeError(f"Failed to input {field}") diff --git a/ec_tools/tools/thread_pool.py b/ec_tools/tools/thread_pool.py index 0193319..6ced028 100755 --- a/ec_tools/tools/thread_pool.py +++ b/ec_tools/tools/thread_pool.py @@ -12,28 +12,25 @@ def __init__( max_workers=None, thread_name_prefix="", initializer=None, + ignore_exception: bool = False, initargs=(), ): self.futures = [] + self.ignore_exception = ignore_exception super().__init__(max_workers, thread_name_prefix, initializer, initargs) - def submit(self, fn, ignore_exception: bool = False, **kwargs): - future = super().submit( - self.try_execute, func=fn, ignore_exception=ignore_exception, **kwargs - ) + def submit(self, fn, /, *args, **kwargs): + future = super().submit(self.try_execute, func=fn, *args, **kwargs) self.futures.append(future) return future - @classmethod - def try_execute( - cls, func: Callable, ignore_exception: bool, **kwargs - ) -> Optional[Any]: + def try_execute(self, func: Callable, *args, **kwargs) -> Optional[Any]: try: - return func(**kwargs) + return func(*args, **kwargs) except Exception as e: - if not ignore_exception: + if not self.ignore_exception: raise e - logging.error(f"[ThreadPool] try to run failed with %s", e) + logging.error("[ThreadPool] try to run failed with %s", e) return None def join(self, log_time: int = 5, clear_after_wait: bool = True) -> List[Any]: @@ -58,8 +55,3 @@ def join(self, log_time: int = 5, clear_after_wait: bool = True) -> List[Any]: self.futures.clear() logging.info("[ThreadPool] all futures complete") return results - - def __del__(self): - logging.info("[ThreadPool] stopping") - self.join() - logging.info("[ThreadPool] stopped") diff --git a/ec_tools/utils/hash_utils.py b/ec_tools/utils/hash_utils.py index d09b5da..c187729 100644 --- a/ec_tools/utils/hash_utils.py +++ b/ec_tools/utils/hash_utils.py @@ -15,3 +15,11 @@ def hmac_md5_text(key: str, value: str, encoding="utf-8") -> str: return hmac.new( key.encode(encoding=encoding), value.encode(encoding=encoding), hashlib.md5 ).hexdigest() + + +def calc_md5(file_path: str, batch_size: int = 8192) -> str: + hasher = hashlib.md5() + with open(file_path, "rb") as f: + while chunk := f.read(batch_size): + hasher.update(chunk) + return hasher.hexdigest() diff --git a/ec_tools/utils/io_utils.py b/ec_tools/utils/io_utils.py index 7c7af69..3cb3070 100755 --- a/ec_tools/utils/io_utils.py +++ b/ec_tools/utils/io_utils.py @@ -1,28 +1,49 @@ -import enum import json +import os -from ec_tools.data import JsonType +from typing import List +from ec_tools.data import JsonType, CustomizedJsonEncoder -class CustomizedJsonEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, set): - return list(obj) - if isinstance(obj, enum.Enum): - return obj.name - return json.JSONEncoder.default(self, obj) +def load_json(path: str, default: JsonType = None, encoding: str = "utf-8") -> JsonType: + if os.path.isfile(path): + with open(path, "r", encoding=encoding) as f: + return json.load(f) + if default is not None: + return default + raise IOError(f"no such file {path}") -def load_json(path: str) -> JsonType: - with open(path, "r") as f: - return json.load(f) - - -def load_file(path: str) -> str: - with open(path, "r") as f: +def load_file(path: str, encoding: str = "utf-8") -> str: + with open(path, "r", encoding=encoding) as f: return f.read() def load_binary(path: str) -> bytes: with open(path, "rb") as f: return f.read() + + +def load_file_by_lines(path: str, encoding: str = "utf-8") -> List[str]: + with open(path, "r", encoding=encoding) as f: + rows = [row.strip() for row in f.readlines()] + return [row for row in rows if row] + + +def touch_dir(path: str): + os.makedirs(os.path.dirname(path), exist_ok=True) + + +def dump_json(data: JsonType, path: str, encoding: str = "utf-8"): + with open(path, "w", encoding=encoding) as f: + json.dump(data, f, ensure_ascii=False, indent=2, cls=CustomizedJsonEncoder) + + +def dump_binary(data: bytes, path: str): + with open(path, "wb") as f: + f.write(data) + + +def dump_file(data: str, path: str, encoding: str = "utf-8"): + with open(path, "w", encoding=encoding) as f: + f.write(data) diff --git a/scripts/lint.sh b/scripts/lint.sh new file mode 100755 index 0000000..c4f7a4c --- /dev/null +++ b/scripts/lint.sh @@ -0,0 +1,4 @@ +autoflake --in-place --remove-all-unused-imports --remove-unused-variables $(find ec_tools demos tests -name "*.py" -not -path "**/__init__.py") +black ec_tools demos tests +pylint ec_tools demos tests + diff --git a/local.sh b/scripts/local.sh similarity index 51% rename from local.sh rename to scripts/local.sh index 2215e98..508ea16 100755 --- a/local.sh +++ b/scripts/local.sh @@ -1,6 +1,8 @@ -pytest +autoflake --in-place --remove-all-unused-imports --remove-unused-variables $(find ec_tools demos tests -name "*.py" -not -path "**/__init__.py") black . +pytest rm -rf dist/ build/ *.egg-info .pytest_cache/ find . | grep -E "(/__pycache__$|/\.DS_Store$)" | xargs rm -rf python3 -m build pip3 install . + diff --git a/setup.py b/setup.py index 7076ff7..89e02d5 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="ec_tools", - version="2.6", + version="2.7", description="EC Tools", packages=setuptools.find_packages(exclude=["tests", "tests.*"]), install_requires=["dataclasses", "typing"], diff --git a/tests/database/data_object_enum_test.py b/tests/database/data_object_enum_test.py index 9164972..85f4b79 100644 --- a/tests/database/data_object_enum_test.py +++ b/tests/database/data_object_enum_test.py @@ -1,6 +1,5 @@ import dataclasses import enum -import json import unittest from typing import List diff --git a/tests/tools/cipher/cipher_generator_test.py b/tests/tools/cipher/cipher_generator_test.py index d9d4e1b..a2ae174 100755 --- a/tests/tools/cipher/cipher_generator_test.py +++ b/tests/tools/cipher/cipher_generator_test.py @@ -1,9 +1,8 @@ import logging -import time import os import unittest -from ec_tools.tools.cipher import AesCipherGenerator, Cipher, CipherGenerator, AesMode +from ec_tools.tools.cipher import AesCipherGenerator, CipherGenerator, AesMode logging.basicConfig(level=logging.DEBUG) diff --git a/tests/utils/io_test.py b/tests/utils/io_test.py new file mode 100644 index 0000000..a39c363 --- /dev/null +++ b/tests/utils/io_test.py @@ -0,0 +1,28 @@ +import os +import unittest + +from ec_tools.utils import io_utils + + +class IoTest(unittest.TestCase): + def test_load_json(self): + self.assertEqual( + {"a": "b"}, + io_utils.load_json(os.path.join(os.path.dirname(__file__), "test.json")), + ) + self.assertEqual( + {}, + io_utils.load_json( + os.path.join(os.path.dirname(__file__), "test_not_exist.json"), {} + ), + ) + self.assertRaises( + IOError, + lambda: io_utils.load_json( + os.path.join(os.path.dirname(__file__), "test_not_exist.json") + ), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test.json b/tests/utils/test.json new file mode 100644 index 0000000..577a4fd --- /dev/null +++ b/tests/utils/test.json @@ -0,0 +1,3 @@ +{ + "a": "b" +} diff --git a/upload.sh b/upload.sh deleted file mode 100755 index af6c915..0000000 --- a/upload.sh +++ /dev/null @@ -1,7 +0,0 @@ -pytest -black . -rm -rf dist/ build/ *.egg-info .pytest_cache/ -find . | grep -E "(/__pycache__$|/\.DS_Store$)" | xargs rm -rf -python3 -m build -twine upload --repository ec-tools dist/* - From f13a31d6200ab6dd8b54fed00ade4bfc1b042397 Mon Sep 17 00:00:00 2001 From: raspberry Date: Sat, 23 Aug 2025 22:05:28 +0800 Subject: [PATCH 3/3] add aes tools --- .pre-commit-config.yaml | 31 +++++++++ demos/headers_manager_demo.py | 5 +- demos/key_manager_demo.py | 2 +- ec_tools/data/__init__.py | 2 +- ec_tools/data/data_object.py | 30 +++------ ec_tools/data/json_type.py | 6 +- ec_tools/database/__init__.py | 12 ++-- ec_tools/database/kv_dao/cipher_kv_dao.py | 38 +++++------ ec_tools/database/kv_dao/kv_dao.py | 4 +- ec_tools/database/kv_dao/sqlite_kv_dao.py | 22 +++++-- .../database/sqlite_client/sqlite_client.py | 9 +-- .../database/sqlite_client/sqlite_query.py | 2 +- ec_tools/database/sqlite_dao/sqlite_dao.py | 37 +++++------ .../sqlite_dao/sqlite_query_generator.py | 21 ++---- ec_tools/tools/cipher/__init__.py | 2 +- ec_tools/tools/cipher/aes_cipher.py | 29 +++------ ec_tools/tools/cipher/chunk_config.py | 37 +++++++++++ .../tools/cipher/chunk_encryption_utils.py | 65 +++++++++++++++++++ ec_tools/tools/cipher/cipher.py | 2 +- .../cipher_generator/aes_cipher_generator.py | 20 +++--- .../cipher_generator/cipher_generator.py | 3 +- .../tools/cipher/file_encryption_utils.py | 44 +++++++++++++ ec_tools/tools/cipher/password_tool.py | 27 ++++++++ ec_tools/tools/headers_manager.py | 3 +- ec_tools/tools/key_manager.py | 2 +- ec_tools/tools/thread_pool.py | 4 +- ec_tools/utils/hash_utils.py | 4 +- ec_tools/utils/io_utils.py | 13 +++- ec_tools/utils/misc.py | 13 +++- ec_tools/utils/os_utils.py | 2 +- ec_tools/utils/timer.py | 36 ++++++++++ scripts/lint.sh | 4 -- scripts/local.sh | 3 +- setup.py | 2 +- tests/database/cipher_kv_dao_test.py | 2 +- tests/database/data_object_complex_test.py | 2 +- tests/database/data_object_enum_test.py | 2 +- tests/database/data_object_union_test.py | 7 +- tests/database/kv_dao_test.py | 4 +- tests/database/sqlite_dao_test.py | 10 +-- tests/tools/cipher/cipher_generator_test.py | 2 +- tests/utils/io_test.py | 8 +-- 42 files changed, 389 insertions(+), 184 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 ec_tools/tools/cipher/chunk_config.py create mode 100644 ec_tools/tools/cipher/chunk_encryption_utils.py create mode 100644 ec_tools/tools/cipher/file_encryption_utils.py create mode 100644 ec_tools/tools/cipher/password_tool.py create mode 100644 ec_tools/utils/timer.py delete mode 100755 scripts/lint.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..492285e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + # --- isort: import sorter --- + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black", "--line-length", "120"] + # --- autoflake: remove unused imports/vars --- + - repo: https://github.com/myint/autoflake + rev: v2.3.1 + hooks: + - id: autoflake + args: [ + "--in-place", + "--remove-unused-variables", + "--remove-all-unused-imports", + ] + exclude: "__init__.py" + # --- Black: Python code formatter --- + - repo: https://github.com/psf/black + rev: 24.8.0 # use a stable release + hooks: + - id: black + language_version: python3 + args: ["--line-length=120"] + # --- YAML lint/format --- + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-yaml + diff --git a/demos/headers_manager_demo.py b/demos/headers_manager_demo.py index 95affb3..41dd63a 100644 --- a/demos/headers_manager_demo.py +++ b/demos/headers_manager_demo.py @@ -1,10 +1,9 @@ import logging -from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao +from ec_tools.database import CipherKvDao, SqliteClient, SqliteKvDao from ec_tools.tools.cipher import AesCipherGenerator -from ec_tools.tools.key_manager import KeyManager - from ec_tools.tools.headers_manager import HeadersManager +from ec_tools.tools.key_manager import KeyManager logging.basicConfig(level=logging.DEBUG) diff --git a/demos/key_manager_demo.py b/demos/key_manager_demo.py index a987936..3deaec2 100644 --- a/demos/key_manager_demo.py +++ b/demos/key_manager_demo.py @@ -1,6 +1,6 @@ import logging -from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao +from ec_tools.database import CipherKvDao, SqliteClient, SqliteKvDao from ec_tools.tools.cipher import AesCipherGenerator from ec_tools.tools.key_manager import KeyManager diff --git a/ec_tools/data/__init__.py b/ec_tools/data/__init__.py index c8761cb..087e26c 100755 --- a/ec_tools/data/__init__.py +++ b/ec_tools/data/__init__.py @@ -1,2 +1,2 @@ +from .data_object import CustomizedJsonEncoder, DataObject, Formatter from .json_type import JsonType -from .data_object import DataObject, Formatter, CustomizedJsonEncoder diff --git a/ec_tools/data/data_object.py b/ec_tools/data/data_object.py index 6ab4f33..17301e6 100755 --- a/ec_tools/data/data_object.py +++ b/ec_tools/data/data_object.py @@ -1,7 +1,7 @@ import dataclasses import enum import json -from typing import Any, Dict, List, Callable, get_origin, get_args, Set +from typing import Any, Callable, Dict, List, Set, get_args, get_origin class CustomizedJsonEncoder(json.JSONEncoder): @@ -55,14 +55,10 @@ def _format_by_field(self, field_type, value: Any): return {self._format(get_args(field_type)[0], each) for each in value} if get_origin(field_type) in [dict, Dict]: return { - self._format(get_args(field_type)[0], k): self._format( - get_args(field_type)[1], v - ) + self._format(get_args(field_type)[0], k): self._format(get_args(field_type)[1], v) for k, v in value.items() } - raise FormatErrorException( - self.class_name, self.field, f"unknown field type {field_type} to format" - ) + raise FormatErrorException(self.class_name, self.field, f"unknown field type {field_type} to format") def _format_by_class(self, clazz: type, value: Any): if clazz is int: @@ -89,9 +85,7 @@ def _format_by_class(self, clazz: type, value: Any): self.field, f"unknown value {value} found in enum {clazz}", ) - raise FormatErrorException( - self.class_name, self.field, f"unknown type {clazz} to format" - ) + raise FormatErrorException(self.class_name, self.field, f"unknown type {clazz} to format") def _get_value(self, value: Any): if value is not None: @@ -127,26 +121,24 @@ def field_names(cls) -> List[str]: def from_json(cls, json_obj: Dict[str, Any]): function_mapping = cls._customized_load_function("_load__") return cls( - **{ - field.name: function_mapping[field.name](json_obj.get(field.name, None)) - for field in cls.fields() - } + **{field.name: function_mapping[field.name](json_obj.get(field.name, None)) for field in cls.fields()} ) def to_json(self) -> Dict[str, Any]: - default_json = json.loads(self.to_json_str()) function_mapping = self._customized_dump_function("_dump__") return { field.name: ( function_mapping[field.name](self.__dict__[field.name]) if function_mapping[field.name] - else default_json.get(field.name) + else json.loads( + json.dumps(dataclasses._asdict_inner(self.__dict__[field.name], dict), cls=CustomizedJsonEncoder) + ) ) for field in self.fields() } def to_json_str(self) -> str: - return json.dumps(dataclasses.asdict(self), cls=CustomizedJsonEncoder) + return json.dumps(self.to_json(), ensure_ascii=False, cls=CustomizedJsonEncoder) @classmethod def _customized_load_function(cls, prefix: str) -> Dict[str, Callable[[Any], Any]]: @@ -173,6 +165,4 @@ def _customized_dump_function( for item in dir(cls) if isinstance(getattr(cls, item), Callable) and item.startswith(prefix) } - return { - field.name: all_functions.get(prefix + field.name) for field in cls.fields() - } + return {field.name: all_functions.get(prefix + field.name) for field in cls.fields()} diff --git a/ec_tools/data/json_type.py b/ec_tools/data/json_type.py index 87378fe..dd62857 100755 --- a/ec_tools/data/json_type.py +++ b/ec_tools/data/json_type.py @@ -1,5 +1,3 @@ -from typing import Union, List, Dict, TypeAlias +from typing import Dict, List, TypeAlias, Union -JsonType: TypeAlias = Union[ - None, int, str, bool, float, List["JsonType"], Dict[str, "JsonType"] -] +JsonType: TypeAlias = Union[None, int, str, bool, float, List["JsonType"], Dict[str, "JsonType"]] diff --git a/ec_tools/database/__init__.py b/ec_tools/database/__init__.py index b0f4221..e8d183f 100755 --- a/ec_tools/database/__init__.py +++ b/ec_tools/database/__init__.py @@ -1,11 +1,9 @@ -from .sqlite_client.sqlite_query import SqliteQuery +from .kv_dao.cipher_kv_dao import CipherKvDao +from .kv_dao.kv_dao import ONE_THOUSAND_YEAR, KvDao +from .kv_dao.kv_data import KvData +from .kv_dao.sqlite_kv_dao import SqliteKvDao from .sqlite_client.sqlite_client import SqliteClient - +from .sqlite_client.sqlite_query import SqliteQuery from .sqlite_dao.sqlite_dao import SqliteDao from .sqlite_dao.sqlite_data_object import SqliteDataObject from .sqlite_dao.sqlite_query_generator import SqliteQueryGenerator - -from .kv_dao.kv_dao import KvDao, ONE_THOUSAND_YEAR -from .kv_dao.kv_data import KvData -from .kv_dao.sqlite_kv_dao import SqliteKvDao -from .kv_dao.cipher_kv_dao import CipherKvDao diff --git a/ec_tools/database/kv_dao/cipher_kv_dao.py b/ec_tools/database/kv_dao/cipher_kv_dao.py index e2387b9..42bd2b4 100644 --- a/ec_tools/database/kv_dao/cipher_kv_dao.py +++ b/ec_tools/database/kv_dao/cipher_kv_dao.py @@ -1,9 +1,10 @@ import dataclasses -from typing import Optional +from typing import List, Optional + from ec_tools.database.kv_dao.kv_dao import KvDao from ec_tools.database.kv_dao.sqlite_kv_dao import SqliteKvDao from ec_tools.database.sqlite_client.sqlite_client import SqliteClient -from ec_tools.tools.cipher import CipherGenerator, AesCipherGenerator, Cipher, AesMode +from ec_tools.tools.cipher import AesCipherGenerator, AesMode, Cipher, CipherGenerator @dataclasses.dataclass @@ -15,37 +16,31 @@ class CipherKvDao: @classmethod def create_sqlite_dao( - cls, db_path: str, encoding: str = "utf-8", mode: AesMode = AesMode.AES_256_CBC - ): + cls, + db_path: str, + encoding: str = "utf-8", + mode: AesMode = AesMode.AES_256_CBC, + table_name: Optional[str] = None, + ) -> "CipherKvDao": return CipherKvDao( - SqliteKvDao(sqlite_client=SqliteClient(db_path)), + SqliteKvDao(sqlite_client=SqliteClient(db_path), table_name=table_name), AesCipherGenerator(encoding, mode), ) - def get( - self, key: str, password: str, default: Optional[str] = None - ) -> Optional[str]: - value = self.get_bytes( - key, password, default.encode(self.encoding) if default else None - ) + def get(self, key: str, password: str, default: Optional[str] = None) -> Optional[str]: + value = self.get_bytes(key, password, default.encode(self.encoding) if default else None) return value.decode(self.encoding) if value else None - def get_bytes( - self, key: str, password: str, default: Optional[bytes] = None - ) -> Optional[bytes]: + def get_bytes(self, key: str, password: str, default: Optional[bytes] = None) -> Optional[bytes]: value = self.kv_dao.get(key) if value: - return self.cipher_generator.decrypt( - password.encode(self.encoding), Cipher.loads(value) - ) + return self.cipher_generator.decrypt(password.encode(self.encoding), Cipher.loads(value)) return default def set(self, key: str, password: str, value: str, duration: float = None) -> None: return self.set_bytes(key, password, value.encode(self.encoding), duration) - def set_bytes( - self, key: str, password: str, value: bytes, duration: float = None - ) -> None: + def set_bytes(self, key: str, password: str, value: bytes, duration: float = None) -> None: cipher = self.cipher_generator.encrypt(password.encode(self.encoding), value) return self.kv_dao.set(key, cipher.dumps(), duration) @@ -54,3 +49,6 @@ def delete(self, key: str) -> None: def clear(self) -> None: return self.kv_dao.clear() + + def keys(self) -> List[str]: + return self.kv_dao.keys() diff --git a/ec_tools/database/kv_dao/kv_dao.py b/ec_tools/database/kv_dao/kv_dao.py index f63c7e1..470b3ef 100755 --- a/ec_tools/database/kv_dao/kv_dao.py +++ b/ec_tools/database/kv_dao/kv_dao.py @@ -1,5 +1,5 @@ import abc -from typing import Any +from typing import Any, List ONE_THOUSAND_YEAR = 86400 * 365 * 1000 @@ -16,6 +16,8 @@ def get(self, key: str, default: Any = None) -> Any: def set(self, key: str, value: str, duration: float = None) -> None: self._set(key, value, duration or self._default_duration) + def keys(self) -> List[str]: ... + @abc.abstractmethod def delete(self, key: str) -> None: ... diff --git a/ec_tools/database/kv_dao/sqlite_kv_dao.py b/ec_tools/database/kv_dao/sqlite_kv_dao.py index 1a00648..67985f0 100755 --- a/ec_tools/database/kv_dao/sqlite_kv_dao.py +++ b/ec_tools/database/kv_dao/sqlite_kv_dao.py @@ -1,7 +1,7 @@ import time -from typing import Any +from typing import Any, List, Optional -from ec_tools.database.kv_dao.kv_dao import KvDao, ONE_THOUSAND_YEAR +from ec_tools.database.kv_dao.kv_dao import ONE_THOUSAND_YEAR, KvDao from ec_tools.database.kv_dao.kv_data import KvData from ec_tools.database.sqlite_client.sqlite_client import SqliteClient from ec_tools.database.sqlite_client.sqlite_query import SqliteQuery @@ -13,8 +13,9 @@ def __init__( self, sqlite_client: SqliteClient, default_duration: float = ONE_THOUSAND_YEAR, + table_name: Optional[str] = None, ): - SqliteDao.__init__(self, sqlite_client=sqlite_client, data_type=KvData) + SqliteDao.__init__(self, sqlite_client=sqlite_client, data_type=KvData, table_name=table_name) KvDao.__init__(self, default_duration=default_duration) def _get(self, key: str) -> Any: @@ -29,9 +30,18 @@ def _get(self, key: str) -> Any: return rows[0]["value"] if rows else None def _set(self, key: str, value: str, duration: float) -> None: - return self.insert_or_replace( - [KvData(key=key, value=value, expired_at=time.time() + duration)] - ) + return self.insert_or_replace([KvData(key=key, value=value, expired_at=time.time() + duration)]) + + def keys(self) -> List[str]: + rows = self.execute( + SqliteQuery( + f"SELECT key FROM {self._table_name}", + "WHERE expired_at > ?", + args=[time.time()], + ), + commit=False, + )[0] + return [row["key"] for row in rows] def delete(self, key: str) -> None: return self.delete_by_values(key=key) diff --git a/ec_tools/database/sqlite_client/sqlite_client.py b/ec_tools/database/sqlite_client/sqlite_client.py index bef571b..0ecaa1e 100755 --- a/ec_tools/database/sqlite_client/sqlite_client.py +++ b/ec_tools/database/sqlite_client/sqlite_client.py @@ -18,9 +18,7 @@ class SqliteClient: def __init__(self, db_path: str): self._db_path = db_path self._conn = sqlite3.connect(db_path, check_same_thread=False) - self._conn.row_factory = lambda cursor, row: { - col[0]: row[idx] for idx, col in enumerate(cursor.description) - } + self._conn.row_factory = lambda cursor, row: {col[0]: row[idx] for idx, col in enumerate(cursor.description)} self._cursor = self._conn.cursor() self._lock = threading.Lock() self._cursor.execute("PRAGMA foreign_keys = ON") @@ -32,10 +30,7 @@ def __del__(self): def execute(self, *queries: SqliteQuery, commit=True): _ = [self._logger.debug("SQL: %s", query) for query in queries] with self._lock: - results = [ - self._cursor.execute(query.sql, query.args).fetchall() - for query in queries - ] + results = [self._cursor.execute(query.sql, query.args).fetchall() for query in queries] if commit: self._conn.commit() return results diff --git a/ec_tools/database/sqlite_client/sqlite_query.py b/ec_tools/database/sqlite_client/sqlite_query.py index fef93e4..3080a6e 100755 --- a/ec_tools/database/sqlite_client/sqlite_query.py +++ b/ec_tools/database/sqlite_client/sqlite_query.py @@ -1,5 +1,5 @@ import dataclasses -from typing import List, Any +from typing import Any, List @dataclasses.dataclass diff --git a/ec_tools/database/sqlite_dao/sqlite_dao.py b/ec_tools/database/sqlite_dao/sqlite_dao.py index fe87642..b90ee31 100755 --- a/ec_tools/database/sqlite_dao/sqlite_dao.py +++ b/ec_tools/database/sqlite_dao/sqlite_dao.py @@ -1,4 +1,4 @@ -from typing import Type, List, Generic, TypeVar, Dict, Any +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar from ec_tools.database.sqlite_client.sqlite_client import SqliteClient from ec_tools.database.sqlite_client.sqlite_query import SqliteQuery @@ -14,11 +14,16 @@ class SqliteDao(Generic[T]): _sql_generator: SqliteQueryGenerator _table_name: str - def __init__(self, sqlite_client: SqliteClient, data_type: Type[T]): + def __init__( + self, + sqlite_client: SqliteClient, + data_type: Type[T], + table_name: Optional[str] = None, + ): self._sqlite_client = sqlite_client self._data_type = data_type - self._table_name = self._data_type.table_name() - self._sql_generator = SqliteQueryGenerator(clz=self._data_type) + self._table_name = table_name or self._data_type.__name__ + self._sql_generator = SqliteQueryGenerator(clz=self._data_type, table_name=self._table_name) self.create_table() @property @@ -32,19 +37,13 @@ def drop_table(self): self._sqlite_client.execute(self._sql_generator.drop_table_sql) def insert(self, objs: List[T]): - return self._sqlite_client.batch_insert( - self._sql_generator.insert_sql(objs=objs) - ) + return self._sqlite_client.batch_insert(self._sql_generator.insert_sql(objs=objs)) def insert_or_replace(self, objs: List[T]): - return self._sqlite_client.batch_insert( - self._sql_generator.insert_sql(objs=objs, supplement="OR REPLACE") - ) + return self._sqlite_client.batch_insert(self._sql_generator.insert_sql(objs=objs, supplement="OR REPLACE")) def insert_or_ignore(self, objs: List[T]): - return self._sqlite_client.batch_insert( - self._sql_generator.insert_sql(objs=objs, supplement="OR IGNORE") - ) + return self._sqlite_client.batch_insert(self._sql_generator.insert_sql(objs=objs, supplement="OR IGNORE")) def delete_by_values(self, **value_map): return self._sqlite_client.execute( @@ -54,17 +53,13 @@ def delete_by_values(self, **value_map): def count_by_values(self, **value_map) -> int: return self.count_group_by_values([], **value_map)[0]["count"] - def count_group_by_values( - self, group_by: List[str], **value_map - ) -> List[Dict[str, Any]]: + def count_group_by_values(self, group_by: List[str], **value_map) -> List[Dict[str, Any]]: return self._sqlite_client.execute( self._sql_generator.count_by_values_sql(value_map, group_by or []), commit=False, )[0] - def query_by_values( - self, limit: int = 100, offset: int = 0, **value_map - ) -> List[T]: + def query_by_values(self, limit: int = 100, offset: int = 0, **value_map) -> List[T]: return list( map( self._data_type.from_json, @@ -76,9 +71,7 @@ def query_fields_by_values( self, fields: List[str], limit: int = 100, offset: int = 0, **value_map ) -> List[Dict[str, Any]]: return self._sqlite_client.execute( - self._sql_generator.query_by_values_sql( - value_map, limit=limit, offset=offset, fields=fields or [] - ), + self._sql_generator.query_by_values_sql(value_map, limit=limit, offset=offset, fields=fields or []), commit=False, )[0] diff --git a/ec_tools/database/sqlite_dao/sqlite_query_generator.py b/ec_tools/database/sqlite_dao/sqlite_query_generator.py index 0942df9..e4aaaa2 100755 --- a/ec_tools/database/sqlite_dao/sqlite_query_generator.py +++ b/ec_tools/database/sqlite_dao/sqlite_query_generator.py @@ -1,5 +1,5 @@ import dataclasses -from typing import List, Any, Type, Dict, Collection +from typing import Any, Collection, Dict, List, Type from ec_tools.database.sqlite_client.sqlite_query import SqliteQuery from ec_tools.database.sqlite_dao.sqlite_data_object import SqliteDataObject @@ -13,9 +13,9 @@ class SqliteQueryGenerator: _unique_keys: List[List[str]] _field_names: List[str] - def __init__(self, clz: Type[SqliteDataObject]): + def __init__(self, clz: Type[SqliteDataObject], table_name: str): self._data_type = clz - self._table_name = clz.table_name() + self._table_name = table_name self._primary_keys = clz.primary_keys() self._unique_keys = clz.unique_keys() self._field_names = clz.field_names() @@ -45,17 +45,13 @@ def create_table_sql(self) -> List[SqliteQuery]: def drop_table_sql(self) -> SqliteQuery: return SqliteQuery(f"DROP TABLE IF EXISTS {self._table_name}") - def insert_sql( - self, objs: Collection[SqliteDataObject], supplement: str = "" - ) -> SqliteQuery: + def insert_sql(self, objs: Collection[SqliteDataObject], supplement: str = "") -> SqliteQuery: items = [obj.to_json() for obj in objs] return SqliteQuery( f"INSERT {supplement} INTO {self._table_name}", f"({', '.join(self._field_names)})", f"VALUES ({', '.join('?' * len(self._field_names))})", - args=[ - [item[field_name] for field_name in self._field_names] for item in items - ], + args=[[item[field_name] for field_name in self._field_names] for item in items], ) def query_by_values_sql( @@ -123,15 +119,12 @@ def _build_constraint_sqls(self) -> List[str]: if not self._unique_keys: return [] return [ - f"CONSTRAINT {'__'.join(['unique', *unique])} UNIQUE ({', '.join(unique)})" - for unique in self._unique_keys + f"CONSTRAINT {'__'.join(['unique', *unique])} UNIQUE ({', '.join(unique)})" for unique in self._unique_keys ] # avoid injection or unexpected behavior def _filter_keys(self, keys: Collection[str], strict: bool = True): if strict: unexpected_keys = [key for key in keys if key not in self._field_names] - assert ( - len(unexpected_keys) == 0 - ), f"unexpected keys found: {unexpected_keys}" + assert len(unexpected_keys) == 0, f"unexpected keys found: {unexpected_keys}" return [key for key in keys if key in self._field_names] diff --git a/ec_tools/tools/cipher/__init__.py b/ec_tools/tools/cipher/__init__.py index efa4baa..1668738 100644 --- a/ec_tools/tools/cipher/__init__.py +++ b/ec_tools/tools/cipher/__init__.py @@ -1,5 +1,5 @@ from .aes_cipher import AesCipher from .aes_config import AesConfig, AesMode from .cipher import Cipher, SecrectKey -from .cipher_generator.cipher_generator import CipherGenerator from .cipher_generator.aes_cipher_generator import AesCipherGenerator +from .cipher_generator.cipher_generator import CipherGenerator diff --git a/ec_tools/tools/cipher/aes_cipher.py b/ec_tools/tools/cipher/aes_cipher.py index 6f500bf..52b88dd 100644 --- a/ec_tools/tools/cipher/aes_cipher.py +++ b/ec_tools/tools/cipher/aes_cipher.py @@ -1,10 +1,10 @@ -import hashlib -import os import dataclasses +import os from Crypto.Cipher import AES -from ec_tools.tools.cipher.cipher import SecrectKey + from ec_tools.tools.cipher.aes_config import AesMode +from ec_tools.tools.cipher.password_tool import generate_key, generate_password @dataclasses.dataclass @@ -46,10 +46,7 @@ def _augment_bytes(self, data: bytes) -> bytes: os.urandom(self.aes_mode.value.key_size), data, os.urandom(self.aes_mode.value.key_size), - os.urandom( - self.aes_mode.value.key_size - - len(data) % self.aes_mode.value.key_size - ), + os.urandom(self.aes_mode.value.key_size - len(data) % self.aes_mode.value.key_size), ] ) @@ -58,23 +55,13 @@ def _recover_bytes(self, data: bytes, size: int) -> bytes: @classmethod def generate_password(cls, password: str, salt: str, iterations: int) -> bytes: - return hashlib.pbkdf2_hmac( - "sha512", - password.encode("utf-8"), - bytes.fromhex(salt), - iterations, - ) + return generate_password(password, salt, iterations) def generate_key(self, password: bytes, salt: bytes, iterations: int): - hsh = hashlib.pbkdf2_hmac( - "sha512", + return generate_key( password, salt, + self.aes_mode.value.key_size, + self.aes_mode.value.iv_size, iterations, - dklen=self.aes_mode.value.key_size + self.aes_mode.value.iv_size, - ) - return SecrectKey( - key=hsh[: self.aes_mode.value.key_size], - iv=hsh[self.aes_mode.value.key_size :], - salt=salt, ) diff --git a/ec_tools/tools/cipher/chunk_config.py b/ec_tools/tools/cipher/chunk_config.py new file mode 100644 index 0000000..d3279ae --- /dev/null +++ b/ec_tools/tools/cipher/chunk_config.py @@ -0,0 +1,37 @@ +import dataclasses +import json +from typing import Dict + +from ec_tools.tools.cipher.aes_config import AesMode + + +@dataclasses.dataclass +class ChunkConfig: + salt: bytes + aes_mode: AesMode + iterations: int + chunk_size: int + file_size: int + + def to_json(self) -> Dict[str, str]: + return { + "salt": self.salt.hex(), + "mode": self.aes_mode.name, + "iterations": self.iterations, + "chunk_size": self.chunk_size, + "file_size": self.file_size, + } + + def to_json_bytes(self) -> bytes: + return json.dumps(self.to_json()).encode("utf-8") + + @classmethod + def from_json(cls, json_str: str) -> "ChunkConfig": + data = json.loads(json_str) + return cls( + salt=bytes.fromhex(data["salt"]), + aes_mode=AesMode[data["mode"]], + iterations=data["iterations"], + chunk_size=data["chunk_size"], + file_size=data["file_size"], + ) diff --git a/ec_tools/tools/cipher/chunk_encryption_utils.py b/ec_tools/tools/cipher/chunk_encryption_utils.py new file mode 100644 index 0000000..b0f10c6 --- /dev/null +++ b/ec_tools/tools/cipher/chunk_encryption_utils.py @@ -0,0 +1,65 @@ +import os +from typing import Generator + +from Crypto.Cipher import AES + +from ec_tools.tools.cipher import password_tool +from ec_tools.tools.cipher.aes_config import AesMode +from ec_tools.tools.cipher.chunk_config import ChunkConfig +from ec_tools.tools.cipher.cipher import SecrectKey + + +def padding_chunk(chunk: bytes, block_size: int) -> bytes: + pad_size = (block_size - (len(chunk) % block_size)) % block_size + return chunk + os.urandom(pad_size) + + +def encrypt_by_chunk( + chunk_generator: Generator[bytes, None, None], + password: str, + salt: str, + aes_mode: AesMode = AesMode.AES_256_CBC, + interations: int = 10000, +) -> Generator[bytes, None, None]: + secret_key: SecrectKey = password_tool.generate_key( + password=password.encode("utf-8"), + salt=salt, + key_size=aes_mode.value.key_size, + iv_size=aes_mode.value.iv_size, + iterations=interations, + ) + aes = AES.new(secret_key.key, aes_mode.value.mode, iv=secret_key.iv) + for chunk in chunk_generator: + chunk = padding_chunk(chunk, aes_mode.value.key_size) + chunk = os.urandom(aes_mode.value.key_size) + chunk + os.urandom(aes_mode.value.key_size) + encrypted_chunk = aes.encrypt(chunk) + yield encrypted_chunk + + +def decrypt_by_chunk( + chunk_generator: Generator[bytes, None, None], + password: str, + chunk_config: ChunkConfig, +) -> Generator[bytes, None, None]: + secret_key: SecrectKey = password_tool.generate_key( + password=password.encode("utf-8"), + salt=chunk_config.salt, + key_size=chunk_config.aes_mode.value.key_size, + iv_size=chunk_config.aes_mode.value.iv_size, + iterations=chunk_config.iterations, + ) + aes = AES.new(secret_key.key, chunk_config.aes_mode.value.mode, iv=secret_key.iv) + total_read = 0 + for chunk in chunk_generator: + if total_read == chunk_config.file_size: + raise Exception("File size mismatch, decryption may be incorrect.") + decrypted_chunk = aes.decrypt(chunk) + decrypted_chunk = decrypted_chunk[ + chunk_config.aes_mode.value.key_size : len(decrypted_chunk) - chunk_config.aes_mode.value.key_size + ] + total_read += len(decrypted_chunk) + if total_read > chunk_config.file_size: + trip_size = total_read - chunk_config.file_size + total_read -= trip_size + decrypted_chunk = decrypted_chunk[:-trip_size] + yield decrypted_chunk diff --git a/ec_tools/tools/cipher/cipher.py b/ec_tools/tools/cipher/cipher.py index 0d1aed5..72782cc 100644 --- a/ec_tools/tools/cipher/cipher.py +++ b/ec_tools/tools/cipher/cipher.py @@ -1,5 +1,5 @@ -import json import dataclasses +import json @dataclasses.dataclass diff --git a/ec_tools/tools/cipher/cipher_generator/aes_cipher_generator.py b/ec_tools/tools/cipher/cipher_generator/aes_cipher_generator.py index 39a6a94..7cda726 100644 --- a/ec_tools/tools/cipher/cipher_generator/aes_cipher_generator.py +++ b/ec_tools/tools/cipher/cipher_generator/aes_cipher_generator.py @@ -1,11 +1,12 @@ -import os -import hashlib import dataclasses +import os + from Crypto.Cipher import AES -from ec_tools.tools.cipher.cipher import Cipher, SecrectKey -from ec_tools.tools.cipher.cipher_generator.cipher_generator import CipherGenerator +from ec_tools.tools.cipher import password_tool from ec_tools.tools.cipher.aes_config import AesMode +from ec_tools.tools.cipher.cipher import Cipher +from ec_tools.tools.cipher.cipher_generator.cipher_generator import CipherGenerator @dataclasses.dataclass @@ -47,15 +48,10 @@ def _augment_bytes(cls, data: bytes, padding_size: int) -> bytes: return mixture def _generate_key(self, password: bytes, salt: bytes): - hsh = hashlib.pbkdf2_hmac( - "sha512", + return password_tool.generate_key( password, salt, + self.mode.value.key_size, + self.mode.value.iv_size, self.pbkdf2_iterations, - dklen=self.mode.value.key_size + self.mode.value.iv_size, - ) - return SecrectKey( - key=hsh[: self.mode.value.key_size], - iv=hsh[self.mode.value.key_size :], - salt=salt, ) diff --git a/ec_tools/tools/cipher/cipher_generator/cipher_generator.py b/ec_tools/tools/cipher/cipher_generator/cipher_generator.py index ba065df..462b921 100644 --- a/ec_tools/tools/cipher/cipher_generator/cipher_generator.py +++ b/ec_tools/tools/cipher/cipher_generator/cipher_generator.py @@ -1,5 +1,6 @@ -import dataclasses import abc +import dataclasses + from ec_tools.tools.cipher.cipher import Cipher diff --git a/ec_tools/tools/cipher/file_encryption_utils.py b/ec_tools/tools/cipher/file_encryption_utils.py new file mode 100644 index 0000000..2fc9b27 --- /dev/null +++ b/ec_tools/tools/cipher/file_encryption_utils.py @@ -0,0 +1,44 @@ +import os + +from ec_tools.tools.cipher.aes_config import AesMode +from ec_tools.tools.cipher.chunk_config import ChunkConfig +from ec_tools.tools.cipher.chunk_encryption_utils import decrypt_by_chunk, encrypt_by_chunk +from ec_tools.utils.io_utils import chunk_read_file + + +def encrypt_file( + input_file: str, + output_file: str, + password: str, + aes_mode: AesMode = AesMode.AES_256_CBC, + iterations: int = 10000, + chunk_size: int = 1024 * 1024, +): + assert chunk_size % aes_mode.value.key_size == 0, "Chunk size must be a multiple of the AES key size." + salt = os.urandom(32) + file_size = os.path.getsize(input_file) + outf = open(output_file, "wb") + inf = open(input_file, "rb") + chunk_config = ChunkConfig( + salt=salt, aes_mode=aes_mode, iterations=iterations, chunk_size=chunk_size, file_size=file_size + ) + outf.write(chunk_config.to_json_bytes() + b"\n") + input_generator = chunk_read_file(inf, chunk_size) + for encrypt_chunk in encrypt_by_chunk(input_generator, password, salt, aes_mode, iterations): + outf.write(encrypt_chunk) + outf.flush() + inf.close() + outf.close() + + +def decrypt_file(input_file: str, output_file: str, password: str): + inf = open(input_file, "rb") + outf = open(output_file, "wb") + header = inf.readline().decode("utf-8").strip() + chunk_config = ChunkConfig.from_json(header) + chunk_generator = chunk_read_file(inf, chunk_config.chunk_size + chunk_config.aes_mode.value.key_size * 2) + for decrypted_chunk in decrypt_by_chunk(chunk_generator, password, chunk_config): + outf.write(decrypted_chunk) + outf.flush() + outf.close() + inf.close() diff --git a/ec_tools/tools/cipher/password_tool.py b/ec_tools/tools/cipher/password_tool.py new file mode 100644 index 0000000..aa17afe --- /dev/null +++ b/ec_tools/tools/cipher/password_tool.py @@ -0,0 +1,27 @@ +import hashlib + +from ec_tools.tools.cipher.cipher import SecrectKey + + +def generate_password(password: str, salt: str, iterations: int) -> bytes: + return hashlib.pbkdf2_hmac( + "sha512", + password.encode("utf-8"), + bytes.fromhex(salt), + iterations, + ) + + +def generate_key(password: bytes, salt: bytes, key_size: int, iv_size: int, iterations: int): + hsh = hashlib.pbkdf2_hmac( + "sha512", + password, + salt, + iterations, + dklen=key_size + iv_size, + ) + return SecrectKey( + key=hsh[:key_size], + iv=hsh[key_size:], + salt=salt, + ) diff --git a/ec_tools/tools/headers_manager.py b/ec_tools/tools/headers_manager.py index 56a3d89..5a9b67e 100644 --- a/ec_tools/tools/headers_manager.py +++ b/ec_tools/tools/headers_manager.py @@ -1,6 +1,7 @@ -import dataclasses import copy +import dataclasses from typing import Dict, List + from ec_tools.tools.key_manager import KeyManager diff --git a/ec_tools/tools/key_manager.py b/ec_tools/tools/key_manager.py index bbe1413..2c44462 100644 --- a/ec_tools/tools/key_manager.py +++ b/ec_tools/tools/key_manager.py @@ -1,6 +1,6 @@ import dataclasses import threading -from typing import Dict, Optional, Callable, List +from typing import Callable, Dict, List, Optional from ec_tools.database import CipherKvDao diff --git a/ec_tools/tools/thread_pool.py b/ec_tools/tools/thread_pool.py index 6ced028..b1f0c3b 100755 --- a/ec_tools/tools/thread_pool.py +++ b/ec_tools/tools/thread_pool.py @@ -1,7 +1,7 @@ import logging import time -from concurrent.futures import ThreadPoolExecutor, Future -from typing import List, Callable, Any, Optional +from concurrent.futures import Future, ThreadPoolExecutor +from typing import Any, Callable, List, Optional class CustomThreadPoolExecutor(ThreadPoolExecutor): diff --git a/ec_tools/utils/hash_utils.py b/ec_tools/utils/hash_utils.py index c187729..61e0cb8 100644 --- a/ec_tools/utils/hash_utils.py +++ b/ec_tools/utils/hash_utils.py @@ -12,9 +12,7 @@ def hmac_sha256_text(key: str, value: str, encoding="utf-8") -> str: def hmac_md5_text(key: str, value: str, encoding="utf-8") -> str: - return hmac.new( - key.encode(encoding=encoding), value.encode(encoding=encoding), hashlib.md5 - ).hexdigest() + return hmac.new(key.encode(encoding=encoding), value.encode(encoding=encoding), hashlib.md5).hexdigest() def calc_md5(file_path: str, batch_size: int = 8192) -> str: diff --git a/ec_tools/utils/io_utils.py b/ec_tools/utils/io_utils.py index 3cb3070..ba0d251 100755 --- a/ec_tools/utils/io_utils.py +++ b/ec_tools/utils/io_utils.py @@ -1,8 +1,17 @@ import json import os +from io import TextIOWrapper +from typing import Generator, List -from typing import List -from ec_tools.data import JsonType, CustomizedJsonEncoder +from ec_tools.data import CustomizedJsonEncoder, JsonType + + +def chunk_read_file(fp: TextIOWrapper, chunk_size: int = 1024 * 1024) -> Generator[str, None, None]: + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield chunk def load_json(path: str, default: JsonType = None, encoding: str = "utf-8") -> JsonType: diff --git a/ec_tools/utils/misc.py b/ec_tools/utils/misc.py index 728aefe..4a252b9 100755 --- a/ec_tools/utils/misc.py +++ b/ec_tools/utils/misc.py @@ -1,11 +1,22 @@ import logging -from typing import Dict, Any, Optional +from typing import Any, Dict, Generator, List, Optional def remove_none_from_dict(obj: Dict[str, Any]) -> Dict[str, Any]: return {k: v for k, v in obj.items() if v is not None} +def get_batch(generator: Generator, batch_size: int) -> Generator[List, None, None]: + batch = [] + for item in generator: + batch.append(item) + if len(batch) >= batch_size: + yield batch + batch = [] + if batch: + yield batch + + def get_logger( level: int = logging.DEBUG, logger_name: Optional[str] = None, diff --git a/ec_tools/utils/os_utils.py b/ec_tools/utils/os_utils.py index 4bcb32a..5307214 100755 --- a/ec_tools/utils/os_utils.py +++ b/ec_tools/utils/os_utils.py @@ -1,5 +1,5 @@ import os -from typing import Set, Callable, List +from typing import Callable, List, Set def list_files(path: str, eligible: Callable[[str], bool]) -> List[str]: diff --git a/ec_tools/utils/timer.py b/ec_tools/utils/timer.py new file mode 100644 index 0000000..fa95b25 --- /dev/null +++ b/ec_tools/utils/timer.py @@ -0,0 +1,36 @@ +import logging +import time + +from colorama import Fore, Style + + +class Timer: + msg: str + start_time: float + muted: bool + + def __init__(self, msg: str): + self.msg = msg + self.muted = False + self.start_time = time.time() + logging.info(f"{Fore.CYAN}{msg}{Style.RESET_ALL}") + logging.info("=" * 32) + + def __enter__(self): + return self + + def mute(self): + self.muted = True + + def __exit__(self, *args): + if self.muted: + return + elapsed = time.time() - self.start_time + logging.info("=" * 32) + if elapsed < 1: + logging.info(f"{Fore.CYAN}$ # finished in {elapsed * 1000:.2f}ms{Style.RESET_ALL}") + elif elapsed < 60: + logging.info(f"{Fore.CYAN}$ # finished in {elapsed:.3f}s{Style.RESET_ALL}") + else: + logging.info(f"{Fore.CYAN}$ # finished in {elapsed / 60}m {elapsed % 60}s{Style.RESET_ALL}") + logging.info("\n") diff --git a/scripts/lint.sh b/scripts/lint.sh deleted file mode 100755 index c4f7a4c..0000000 --- a/scripts/lint.sh +++ /dev/null @@ -1,4 +0,0 @@ -autoflake --in-place --remove-all-unused-imports --remove-unused-variables $(find ec_tools demos tests -name "*.py" -not -path "**/__init__.py") -black ec_tools demos tests -pylint ec_tools demos tests - diff --git a/scripts/local.sh b/scripts/local.sh index 508ea16..45922e0 100755 --- a/scripts/local.sh +++ b/scripts/local.sh @@ -1,5 +1,4 @@ -autoflake --in-place --remove-all-unused-imports --remove-unused-variables $(find ec_tools demos tests -name "*.py" -not -path "**/__init__.py") -black . +pre-commit run --all-files pytest rm -rf dist/ build/ *.egg-info .pytest_cache/ find . | grep -E "(/__pycache__$|/\.DS_Store$)" | xargs rm -rf diff --git a/setup.py b/setup.py index 89e02d5..c949a4a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="ec_tools", - version="2.7", + version="2.8", description="EC Tools", packages=setuptools.find_packages(exclude=["tests", "tests.*"]), install_requires=["dataclasses", "typing"], diff --git a/tests/database/cipher_kv_dao_test.py b/tests/database/cipher_kv_dao_test.py index 1a51dc5..cd0c776 100755 --- a/tests/database/cipher_kv_dao_test.py +++ b/tests/database/cipher_kv_dao_test.py @@ -2,7 +2,7 @@ import time import unittest -from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao +from ec_tools.database import CipherKvDao, SqliteClient, SqliteKvDao from ec_tools.tools.cipher import AesCipherGenerator logging.basicConfig(level=logging.DEBUG) diff --git a/tests/database/data_object_complex_test.py b/tests/database/data_object_complex_test.py index 665290b..07d8e9f 100644 --- a/tests/database/data_object_complex_test.py +++ b/tests/database/data_object_complex_test.py @@ -1,6 +1,6 @@ import dataclasses import unittest -from typing import List, Dict, Optional, Set +from typing import Dict, List, Optional, Set from ec_tools.data import DataObject diff --git a/tests/database/data_object_enum_test.py b/tests/database/data_object_enum_test.py index 85f4b79..9937245 100644 --- a/tests/database/data_object_enum_test.py +++ b/tests/database/data_object_enum_test.py @@ -25,6 +25,6 @@ class DataObjectEnumTest(unittest.TestCase): def test_enums(self): a = Sea([SeaAnimalType.FISH, SeaAnimalType.SHRIMP]) a_json = a.to_json() - aa = Sea.from_json(a_json) self.assertEqual({"animals": ["FISH", "SHRIMP"]}, a_json) + aa = Sea.from_json(a_json) self.assertEqual(a, aa) diff --git a/tests/database/data_object_union_test.py b/tests/database/data_object_union_test.py index 29d93ad..9f09e2e 100644 --- a/tests/database/data_object_union_test.py +++ b/tests/database/data_object_union_test.py @@ -1,6 +1,6 @@ import dataclasses import unittest -from typing import List, Dict, Union, Any +from typing import Any, Dict, List, Union from ec_tools.data import DataObject @@ -40,10 +40,7 @@ def test(self): b_json = {"a": [{"1": {"x": 1, "y": "2"}}]} with self.assertRaises(Exception) as e: B.from_json(b_json) - self.assertTrue( - 'A#y is not serializable, try to add function "_load__y" with "@classmethod" to A' - in e - ) + self.assertTrue('A#y is not serializable, try to add function "_load__y" with "@classmethod" to A' in e) def test_load(self): d_json = {"c": [{"1": {"x": 1, "y": "2"}}]} diff --git a/tests/database/kv_dao_test.py b/tests/database/kv_dao_test.py index 9fe2c87..0939d61 100755 --- a/tests/database/kv_dao_test.py +++ b/tests/database/kv_dao_test.py @@ -2,7 +2,7 @@ import time import unittest -from ec_tools.database import SqliteKvDao, SqliteClient +from ec_tools.database import SqliteClient, SqliteKvDao logging.basicConfig(level=logging.DEBUG) @@ -36,6 +36,8 @@ def test(self): self.assertEqual(self.kv_dao.get("hi"), "how are you") self.assertEqual(self.kv_dao.get("hello"), None) + self.assertEqual(self.kv_dao.keys(), ["hi"]) # delete hi self.kv_dao.delete("hi") self.assertEqual(self.kv_dao.get("hi"), None) + self.assertEqual(self.kv_dao.keys(), []) diff --git a/tests/database/sqlite_dao_test.py b/tests/database/sqlite_dao_test.py index 1d5d237..6e173a3 100755 --- a/tests/database/sqlite_dao_test.py +++ b/tests/database/sqlite_dao_test.py @@ -4,7 +4,7 @@ import unittest from typing import List -from ec_tools.database import SqliteDataObject, SqliteClient, SqliteDao +from ec_tools.database import SqliteClient, SqliteDao, SqliteDataObject logging.basicConfig(level=logging.DEBUG) @@ -36,12 +36,8 @@ def test(self): self.sqlite_dao.create_table() result = self.sqlite_dao.insert( [ - Fruit( - name="apple", weight=1.1, price=1, is_delicious=True, remaining=1 - ), - Fruit( - name="banana", weight=0.1, price=0, is_delicious=False, remaining=0 - ), + Fruit(name="apple", weight=1.1, price=1, is_delicious=True, remaining=1), + Fruit(name="banana", weight=0.1, price=0, is_delicious=False, remaining=0), ] ) logging.info("insert result: %s", result) diff --git a/tests/tools/cipher/cipher_generator_test.py b/tests/tools/cipher/cipher_generator_test.py index a2ae174..5d9d016 100755 --- a/tests/tools/cipher/cipher_generator_test.py +++ b/tests/tools/cipher/cipher_generator_test.py @@ -2,7 +2,7 @@ import os import unittest -from ec_tools.tools.cipher import AesCipherGenerator, CipherGenerator, AesMode +from ec_tools.tools.cipher import AesCipherGenerator, AesMode, CipherGenerator logging.basicConfig(level=logging.DEBUG) diff --git a/tests/utils/io_test.py b/tests/utils/io_test.py index a39c363..df95074 100644 --- a/tests/utils/io_test.py +++ b/tests/utils/io_test.py @@ -12,15 +12,11 @@ def test_load_json(self): ) self.assertEqual( {}, - io_utils.load_json( - os.path.join(os.path.dirname(__file__), "test_not_exist.json"), {} - ), + io_utils.load_json(os.path.join(os.path.dirname(__file__), "test_not_exist.json"), {}), ) self.assertRaises( IOError, - lambda: io_utils.load_json( - os.path.join(os.path.dirname(__file__), "test_not_exist.json") - ), + lambda: io_utils.load_json(os.path.join(os.path.dirname(__file__), "test_not_exist.json")), )