diff --git a/README.md b/README.md old mode 100755 new mode 100644 diff --git a/demos/headers_manager_demo.py b/demos/headers_manager_demo.py new file mode 100644 index 0000000..ed85048 --- /dev/null +++ b/demos/headers_manager_demo.py @@ -0,0 +1,35 @@ +import logging +import time +import unittest + +from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao +from ec_tools.tools.cipher import AesCipherGenerator +from ec_tools.tools.key_manager import KeyManager + +from ec_tools.tools.headers_manager import HeadersManager + +logging.basicConfig(level=logging.DEBUG) + + +class HeadersManagerTest: + sqlite_client = SqliteClient(":memory:") + kv_dao = SqliteKvDao(sqlite_client) + cipher_generator = AesCipherGenerator() + cipher_kv_dao = CipherKvDao(kv_dao, cipher_generator) + manager = HeadersManager( + KeyManager(cipher_kv_dao, "12345678", hint_function=lambda: print("hi")), + ["cookies", "auth"], + {"a": "c"}, + ) + + def test(self): + print(self.manager.get()) + print(self.manager.get()) + + +def test(): + HeadersManagerTest().test() + + +if __name__ == "__main__": + test() diff --git a/demos/key_manager_demo.py b/demos/key_manager_demo.py new file mode 100644 index 0000000..6b57689 --- /dev/null +++ b/demos/key_manager_demo.py @@ -0,0 +1,30 @@ +import logging +import time +import unittest + +from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao +from ec_tools.tools.cipher import AesCipherGenerator +from ec_tools.tools.key_manager import KeyManager + +logging.basicConfig(level=logging.DEBUG) + + +class KeyManagerTest: + sqlite_client = SqliteClient(":memory:") + kv_dao = SqliteKvDao(sqlite_client) + cipher_generator = AesCipherGenerator() + cipher_kv_dao = CipherKvDao(kv_dao, cipher_generator) + manager = KeyManager(cipher_kv_dao, "12345678", hint_function=lambda: print("hi")) + + def test(self): + print(self.manager.get_keys(["a", "b", "c"])) + print(self.manager.get_key("a")) + print(self.manager.get_key("d")) + + +def test(): + KeyManagerTest().test() + + +if __name__ == "__main__": + test() diff --git a/ec_tools/database/__init__.py b/ec_tools/database/__init__.py index 3120137..b0f4221 100755 --- a/ec_tools/database/__init__.py +++ b/ec_tools/database/__init__.py @@ -8,3 +8,4 @@ from .kv_dao.kv_dao import KvDao, ONE_THOUSAND_YEAR from .kv_dao.kv_data import KvData from .kv_dao.sqlite_kv_dao import SqliteKvDao +from .kv_dao.cipher_kv_dao import CipherKvDao diff --git a/ec_tools/database/kv_dao/cipher_kv_dao.py b/ec_tools/database/kv_dao/cipher_kv_dao.py new file mode 100644 index 0000000..e2387b9 --- /dev/null +++ b/ec_tools/database/kv_dao/cipher_kv_dao.py @@ -0,0 +1,56 @@ +import dataclasses +from typing import Optional +from ec_tools.database.kv_dao.kv_dao import KvDao +from ec_tools.database.kv_dao.sqlite_kv_dao import SqliteKvDao +from ec_tools.database.sqlite_client.sqlite_client import SqliteClient +from ec_tools.tools.cipher import CipherGenerator, AesCipherGenerator, Cipher, AesMode + + +@dataclasses.dataclass +class CipherKvDao: + kv_dao: KvDao + cipher_generator: CipherGenerator + + encoding: str = "utf-8" + + @classmethod + def create_sqlite_dao( + cls, db_path: str, encoding: str = "utf-8", mode: AesMode = AesMode.AES_256_CBC + ): + return CipherKvDao( + SqliteKvDao(sqlite_client=SqliteClient(db_path)), + AesCipherGenerator(encoding, mode), + ) + + def get( + self, key: str, password: str, default: Optional[str] = None + ) -> Optional[str]: + value = self.get_bytes( + key, password, default.encode(self.encoding) if default else None + ) + return value.decode(self.encoding) if value else None + + def get_bytes( + self, key: str, password: str, default: Optional[bytes] = None + ) -> Optional[bytes]: + value = self.kv_dao.get(key) + if value: + return self.cipher_generator.decrypt( + password.encode(self.encoding), Cipher.loads(value) + ) + return default + + def set(self, key: str, password: str, value: str, duration: float = None) -> None: + return self.set_bytes(key, password, value.encode(self.encoding), duration) + + def set_bytes( + self, key: str, password: str, value: bytes, duration: float = None + ) -> None: + cipher = self.cipher_generator.encrypt(password.encode(self.encoding), value) + return self.kv_dao.set(key, cipher.dumps(), duration) + + def delete(self, key: str) -> None: + return self.kv_dao.delete(key) + + def clear(self) -> None: + return self.kv_dao.clear() diff --git a/ec_tools/database/kv_dao/kv_dao.py b/ec_tools/database/kv_dao/kv_dao.py index 89a3db2..f63c7e1 100755 --- a/ec_tools/database/kv_dao/kv_dao.py +++ b/ec_tools/database/kv_dao/kv_dao.py @@ -17,20 +17,16 @@ def set(self, key: str, value: str, duration: float = None) -> None: self._set(key, value, duration or self._default_duration) @abc.abstractmethod - def delete(self, key: str) -> None: - ... + def delete(self, key: str) -> None: ... @abc.abstractmethod - def clear(self) -> None: - ... + def clear(self) -> None: ... @abc.abstractmethod - def _get(self, key: str) -> Any: - ... + def _get(self, key: str) -> Any: ... @abc.abstractmethod - def _set(self, key: str, value: str, duration: float) -> None: - ... + def _set(self, key: str, value: str, duration: float) -> None: ... def __getitem__(self, key: str) -> Any: return self._get(key) diff --git a/ec_tools/database/sqlite_dao/sqlite_data_object.py b/ec_tools/database/sqlite_dao/sqlite_data_object.py index 4e019ec..b740e00 100755 --- a/ec_tools/database/sqlite_dao/sqlite_data_object.py +++ b/ec_tools/database/sqlite_dao/sqlite_data_object.py @@ -21,8 +21,7 @@ class SqliteDataObject(abc.ABC, DataObject): @classmethod @abc.abstractmethod - def primary_keys(cls) -> List[str]: - ... + def primary_keys(cls) -> List[str]: ... @classmethod def extra_indexes(cls) -> List[List[str]]: diff --git a/ec_tools/tools/cipher/__init__.py b/ec_tools/tools/cipher/__init__.py new file mode 100644 index 0000000..2decb6e --- /dev/null +++ b/ec_tools/tools/cipher/__init__.py @@ -0,0 +1,2 @@ +from .aes_cipher_generator import AesCipherGenerator, AesMode +from .cipher_generator import Cipher, CipherGenerator diff --git a/ec_tools/tools/cipher/aes_cipher_generator.py b/ec_tools/tools/cipher/aes_cipher_generator.py new file mode 100644 index 0000000..c5b25ab --- /dev/null +++ b/ec_tools/tools/cipher/aes_cipher_generator.py @@ -0,0 +1,75 @@ +import json +import os +import hashlib +import dataclasses +import enum +from Crypto.Cipher import AES + +from ec_tools.utils.hash_utils import hmac_sha256 +from ec_tools.tools.cipher.cipher_generator import Cipher, CipherGenerator, SecrectKey + + +@dataclasses.dataclass +class AesConfig: + key_size: int + iv_size: int + mode: int + + +class AesMode(enum.Enum): + AES_128_CBC = AesConfig(16, 16, AES.MODE_CBC) + AES_192_CBC = AesConfig(24, 16, AES.MODE_CBC) + AES_256_CBC = AesConfig(32, 16, AES.MODE_CBC) + + +@dataclasses.dataclass +class AesCipherGenerator(CipherGenerator): + mode: AesMode = AesMode.AES_256_CBC + pbkdf2_iterations: int = 10000 + _DIVIDER = b"\0" + + def decrypt(self, password: bytes, cipher: Cipher) -> bytes: + salt = bytes.fromhex(cipher.salt) + secrect_key = self._generate_key(password, salt) + aes = AES.new(secrect_key.key, self.mode.value.mode, iv=secrect_key.iv) + augmented_text = aes.decrypt(bytes.fromhex(cipher.cipher_text)) + decoded = bytes.fromhex(augmented_text.hex()[::2]) + divider_index = decoded.find(self._DIVIDER) + assert divider_index != -1, "invalid cipher: divider not found" + text_length = int(decoded[:divider_index].decode(self.encoding)) + data = decoded[divider_index + len(self._DIVIDER) :] + return data[:text_length] + + def encrypt(self, password: bytes, plain_text: bytes) -> Cipher: + secrect_key = self._generate_key(password, os.urandom(self.mode.value.key_size)) + text_length = len(plain_text) + plain_text = str(text_length).encode(self.encoding) + self._DIVIDER + plain_text + augmented_text = self._augment_bytes(plain_text, self.mode.value.key_size) + aes = AES.new(secrect_key.key, self.mode.value.mode, iv=secrect_key.iv) + cipher_text = aes.encrypt(augmented_text) + return Cipher( + cipher_text=cipher_text.hex(), + salt=secrect_key.salt.hex(), + mode=self.mode.name, + ) + + @classmethod + def _augment_bytes(self, data: bytes, padding_size: int) -> bytes: + padded_text = (data + os.urandom(padding_size - len(data) % padding_size)).hex() + random_bytes = os.urandom(len(padded_text)).hex() + mixture = bytes.fromhex("".join(map("".join, zip(padded_text, random_bytes)))) + return mixture + + def _generate_key(self, password: bytes, salt: bytes): + hsh = hashlib.pbkdf2_hmac( + "sha512", + password, + salt, + self.pbkdf2_iterations, + dklen=self.mode.value.key_size + self.mode.value.iv_size, + ) + return SecrectKey( + key=hsh[: self.mode.value.key_size], + iv=hsh[self.mode.value.key_size :], + salt=salt, + ) diff --git a/ec_tools/tools/cipher/cipher_generator.py b/ec_tools/tools/cipher/cipher_generator.py new file mode 100644 index 0000000..a4fd449 --- /dev/null +++ b/ec_tools/tools/cipher/cipher_generator.py @@ -0,0 +1,55 @@ +import json +import os +import hashlib +import dataclasses +import abc +from Crypto.Cipher import AES + +from ec_tools.utils.hash_utils import hmac_sha256 + + +@dataclasses.dataclass +class Cipher: + cipher_text: str + mode: str + salt: str + + def dumps(self) -> str: + return json.dumps(dataclasses.asdict(self)) + + @classmethod + def loads(self, text: str) -> "Cipher": + return Cipher(**json.loads(text)) + + +@dataclasses.dataclass +class SecrectKey: + key: bytes + iv: bytes + salt: bytes + + def __str__(self) -> str: + return f"SecretKey(key={len(self.key)},iv={len(self.iv)},salt={len(self.salt)})" + + def __repr__(self) -> str: + return str(self) + + +@dataclasses.dataclass +class CipherGenerator(abc.ABC): + encoding: str = "utf-8" + + @abc.abstractmethod + def decrypt(self, password: bytes, cipher: Cipher) -> bytes: ... + + @abc.abstractmethod + def encrypt(self, password: bytes, plain_text: bytes) -> Cipher: ... + + def decrypt_str(self, password: str, cipher: Cipher) -> str: + return self.decrypt(password.encode(self.encoding), cipher).decode("utf-8") + + def encrypt_str(self, password: str, text: str) -> Cipher: + return self.encrypt( + password=password.encode(self.encoding), + plain_text=text.encode(self.encoding), + ) diff --git a/ec_tools/tools/headers_manager.py b/ec_tools/tools/headers_manager.py new file mode 100644 index 0000000..56a3d89 --- /dev/null +++ b/ec_tools/tools/headers_manager.py @@ -0,0 +1,20 @@ +import dataclasses +import copy +from typing import Dict, List +from ec_tools.tools.key_manager import KeyManager + + +@dataclasses.dataclass +class HeadersManager: + key_manager: KeyManager + stored_keys: List[str] + + default_headers_template: Dict[str, str] = dataclasses.field(default_factory=dict) + + def get(self): + headers = copy.deepcopy(self.default_headers_template) + headers.update(self.key_manager.get_keys(self.stored_keys)) + return headers + + def refresh(self): + return self.key_manager.refresh_keys(self.stored_keys) diff --git a/ec_tools/tools/key_manager.py b/ec_tools/tools/key_manager.py new file mode 100644 index 0000000..9ecbb51 --- /dev/null +++ b/ec_tools/tools/key_manager.py @@ -0,0 +1,58 @@ +import dataclasses +import logging +import os +import threading +import time +from copy import deepcopy +from typing import Dict, Optional, Callable, List + +from ec_tools.database import CipherKvDao + + +def _collect_user_input(field: str, retries: int = 10) -> str: + for i in range(retries): + try: + value = input(f"Enter field {field}: ") + value = value.strip() + if not value: + continue + return value + except Exception as e: + continue + raise RuntimeError(f"Failed to input {field}") + + +@dataclasses.dataclass +class KeyManager: + cipher_kv_dao: CipherKvDao + password: str + + lock = threading.Lock() + expiration_time: float = 1000 * 365 * 86400 + hint_function: Optional[Callable] = None + + def get_key(self, key: str) -> str: + with self.lock: + return self.cipher_kv_dao.get(key, self.password) + + def get_keys(self, keys: List[str]) -> Dict[str, str]: + result = self._get_keys(keys) + missing_keys = [key for key in keys if result.get(key, None) is None] + if missing_keys: + if self.hint_function: + self.hint_function() + self.refresh_keys(keys) + return self._get_keys(keys) + + def refresh_keys(self, keys: List[str]): + for key in keys: + self._collect(key) + + def _collect(self, field: str) -> str: + with self.lock: + value = _collect_user_input(field) + self.cipher_kv_dao.set(field, self.password, value, self.expiration_time) + return value + + def _get_keys(self, keys: List[str]) -> Dict[str, Optional[str]]: + return {k: self.cipher_kv_dao.get(k, self.password) for k in keys} diff --git a/ec_tools/utils/hash_utils.py b/ec_tools/utils/hash_utils.py new file mode 100644 index 0000000..d09b5da --- /dev/null +++ b/ec_tools/utils/hash_utils.py @@ -0,0 +1,17 @@ +import hashlib +import hmac + + +def hmac_sha256(key: bytes, value: bytes) -> bytes: + sha256 = hmac.new(key, value, hashlib.sha256) + return sha256.digest() + + +def hmac_sha256_text(key: str, value: str, encoding="utf-8") -> str: + return hmac_sha256(key.encode(encoding), value.encode(encoding)).hex() + + +def hmac_md5_text(key: str, value: str, encoding="utf-8") -> str: + return hmac.new( + key.encode(encoding=encoding), value.encode(encoding=encoding), hashlib.md5 + ).hexdigest() diff --git a/local.sh b/local.sh new file mode 100755 index 0000000..2215e98 --- /dev/null +++ b/local.sh @@ -0,0 +1,6 @@ +pytest +black . +rm -rf dist/ build/ *.egg-info .pytest_cache/ +find . | grep -E "(/__pycache__$|/\.DS_Store$)" | xargs rm -rf +python3 -m build +pip3 install . diff --git a/requirements.txt b/requirements.txt old mode 100755 new mode 100644 index 23e605a..b20af35 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -setuptools==68.0.0 +pycryptodome==3.20.0 +setuptools==69.5.1 diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index a6fd9dc..7076ff7 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setuptools.setup( name="ec_tools", - version="2.5", + version="2.6", description="EC Tools", packages=setuptools.find_packages(exclude=["tests", "tests.*"]), install_requires=["dataclasses", "typing"], diff --git a/tests/database/cipher_kv_dao_test.py b/tests/database/cipher_kv_dao_test.py new file mode 100755 index 0000000..1a51dc5 --- /dev/null +++ b/tests/database/cipher_kv_dao_test.py @@ -0,0 +1,41 @@ +import logging +import time +import unittest + +from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao +from ec_tools.tools.cipher import AesCipherGenerator + +logging.basicConfig(level=logging.DEBUG) + + +class CipherKvDaoTest(unittest.TestCase): + sqlite_client = SqliteClient(":memory:") + kv_dao = SqliteKvDao(sqlite_client) + cipher_generator = AesCipherGenerator() + cipher_kv_dao = CipherKvDao(kv_dao, cipher_generator) + + def test(self): + password = "abcdefg12345678" + + logging.info("dao: %s", self.kv_dao) + self.kv_dao.drop_table() + self.kv_dao.create_table() + + # tests default get + self.assertEqual(self.cipher_kv_dao.get("hello", password, "??"), "??") + self.assertEqual(self.cipher_kv_dao.get_bytes("hello", password, b"??"), b"??") + self.assertEqual(self.cipher_kv_dao.get("hello", password, None), None) + self.assertEqual(self.cipher_kv_dao.get_bytes("hello", password, None), None) + + # tests set and get + self.cipher_kv_dao.set("hi", password, "how are you") + self.cipher_kv_dao.set_bytes("hihi", password, b"how are you") + self.cipher_kv_dao.set("hello", password, "world", 1) + self.assertEqual(self.cipher_kv_dao.get("hi", password), "how are you") + self.assertEqual(self.cipher_kv_dao.get_bytes("hihi", password), b"how are you") + self.assertEqual(self.cipher_kv_dao.get("hello", password), "world") + time.sleep(1) + + self.assertEqual(self.cipher_kv_dao.get("hi", password), "how are you") + self.cipher_kv_dao.delete("hi") + self.assertEqual(self.cipher_kv_dao.get("hi", password), None) diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tools/cipher/__init__.py b/tests/tools/cipher/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tools/cipher/cipher_generator_test.py b/tests/tools/cipher/cipher_generator_test.py new file mode 100755 index 0000000..d9d4e1b --- /dev/null +++ b/tests/tools/cipher/cipher_generator_test.py @@ -0,0 +1,20 @@ +import logging +import time +import os +import unittest + +from ec_tools.tools.cipher import AesCipherGenerator, Cipher, CipherGenerator, AesMode + +logging.basicConfig(level=logging.DEBUG) + + +class CipherGeneratorTest(unittest.TestCase): + def test(self): + for mode in AesMode: + cipher_generator: CipherGenerator = AesCipherGenerator(mode=mode) + for length in list(range(32)) + [127, 256]: + password = os.urandom(length) + plain_text = os.urandom(length) + cipher = cipher_generator.encrypt(password, plain_text) + self.assertTrue(cipher.mode.startswith("AES_")) + self.assertEqual(plain_text, cipher_generator.decrypt(password, cipher))