Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
repos:
# --- isort: import sorter ---
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black", "--line-length", "120"]
# --- autoflake: remove unused imports/vars ---
- repo: https://github.com/myint/autoflake
rev: v2.3.1
hooks:
- id: autoflake
args: [
"--in-place",
"--remove-unused-variables",
"--remove-all-unused-imports",
]
exclude: "__init__.py"
# --- Black: Python code formatter ---
- repo: https://github.com/psf/black
rev: 24.8.0 # use a stable release
hooks:
- id: black
language_version: python3
args: ["--line-length=120"]
# --- YAML lint/format ---
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-yaml

3 changes: 3 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[MESSAGES CONTROL]
disable=missing-docstring,broad-exception-caught,too-many-return-statements,too-many-arguments,too-many-positional-arguments

7 changes: 2 additions & 5 deletions demos/headers_manager_demo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import logging
import time
import unittest

from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao
from ec_tools.database import CipherKvDao, SqliteClient, SqliteKvDao
from ec_tools.tools.cipher import AesCipherGenerator
from ec_tools.tools.key_manager import KeyManager

from ec_tools.tools.headers_manager import HeadersManager
from ec_tools.tools.key_manager import KeyManager

logging.basicConfig(level=logging.DEBUG)

Expand Down
4 changes: 1 addition & 3 deletions demos/key_manager_demo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import logging
import time
import unittest

from ec_tools.database import SqliteKvDao, SqliteClient, CipherKvDao
from ec_tools.database import CipherKvDao, SqliteClient, SqliteKvDao
from ec_tools.tools.cipher import AesCipherGenerator
from ec_tools.tools.key_manager import KeyManager

Expand Down
2 changes: 1 addition & 1 deletion ec_tools/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .data_object import CustomizedJsonEncoder, DataObject, Formatter
from .json_type import JsonType
from .data_object import DataObject, Formatter
62 changes: 40 additions & 22 deletions ec_tools/data/data_object.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import dataclasses
import enum
import json
from typing import Any, Dict, List, Callable, get_origin, get_args, Set
from typing import Any, Callable, Dict, List, Set, get_args, get_origin

from ec_tools.utils.io_utils import CustomizedJsonEncoder

class CustomizedJsonEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, set):
return list(o)
if isinstance(o, enum.Enum):
return o.name
return json.JSONEncoder.default(self, o)


@dataclasses.dataclass
Expand Down Expand Up @@ -48,14 +55,10 @@ def _format_by_field(self, field_type, value: Any):
return {self._format(get_args(field_type)[0], each) for each in value}
if get_origin(field_type) in [dict, Dict]:
return {
self._format(get_args(field_type)[0], k): self._format(
get_args(field_type)[1], v
)
self._format(get_args(field_type)[0], k): self._format(get_args(field_type)[1], v)
for k, v in value.items()
}
raise FormatErrorException(
self.class_name, self.field, f"unknown field type {field_type} to format"
)
raise FormatErrorException(self.class_name, self.field, f"unknown field type {field_type} to format")

def _format_by_class(self, clazz: type, value: Any):
if clazz is int:
Expand All @@ -82,16 +85,14 @@ def _format_by_class(self, clazz: type, value: Any):
self.field,
f"unknown value {value} found in enum {clazz}",
)
raise FormatErrorException(
self.class_name, self.field, f"unknown type {clazz} to format"
)
raise FormatErrorException(self.class_name, self.field, f"unknown type {clazz} to format")

def _get_value(self, value: Any):
if value is not None:
return value
if self.field and self.field.default_factory != dataclasses.MISSING:
return self.field.default_factory()
elif self.field and self.field.default != dataclasses.MISSING:
if self.field and self.field.default != dataclasses.MISSING:
return self.field.default
return None

Expand All @@ -118,24 +119,29 @@ def field_names(cls) -> List[str]:

@classmethod
def from_json(cls, json_obj: Dict[str, Any]):
function_mapping = cls._customized_mapping_function("_load__")
function_mapping = cls._customized_load_function("_load__")
return cls(
**{
field.name: function_mapping[field.name](json_obj.get(field.name, None))
for field in cls.fields()
}
**{field.name: function_mapping[field.name](json_obj.get(field.name, None)) for field in cls.fields()}
)

def to_json(self) -> Dict[str, Any]:
return json.loads(self.to_json_str())
function_mapping = self._customized_dump_function("_dump__")
return {
field.name: (
function_mapping[field.name](self.__dict__[field.name])
if function_mapping[field.name]
else json.loads(
json.dumps(dataclasses._asdict_inner(self.__dict__[field.name], dict), cls=CustomizedJsonEncoder)
)
)
for field in self.fields()
}

def to_json_str(self) -> str:
return json.dumps(dataclasses.asdict(self), cls=CustomizedJsonEncoder)
return json.dumps(self.to_json(), ensure_ascii=False, cls=CustomizedJsonEncoder)

@classmethod
def _customized_mapping_function(
cls, prefix: str
) -> Dict[str, Callable[[Any], Any]]:
def _customized_load_function(cls, prefix: str) -> Dict[str, Callable[[Any], Any]]:
all_functions = {
item: getattr(cls, item)
for item in dir(cls)
Expand All @@ -148,3 +154,15 @@ def _customized_mapping_function(
)
for field in cls.fields()
}

@classmethod
def _customized_dump_function(
cls,
prefix: str,
) -> Dict[str, Callable[[Any], Any]]:
all_functions = {
item: getattr(cls, item)
for item in dir(cls)
if isinstance(getattr(cls, item), Callable) and item.startswith(prefix)
}
return {field.name: all_functions.get(prefix + field.name) for field in cls.fields()}
6 changes: 2 additions & 4 deletions ec_tools/data/json_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union, List, Dict, TypeAlias
from typing import Dict, List, TypeAlias, Union

JsonType: TypeAlias = Union[
None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]
]
JsonType: TypeAlias = Union[None, int, str, bool, float, List["JsonType"], Dict[str, "JsonType"]]
12 changes: 5 additions & 7 deletions ec_tools/database/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from .sqlite_client.sqlite_query import SqliteQuery
from .kv_dao.cipher_kv_dao import CipherKvDao
from .kv_dao.kv_dao import ONE_THOUSAND_YEAR, KvDao
from .kv_dao.kv_data import KvData
from .kv_dao.sqlite_kv_dao import SqliteKvDao
from .sqlite_client.sqlite_client import SqliteClient

from .sqlite_client.sqlite_query import SqliteQuery
from .sqlite_dao.sqlite_dao import SqliteDao
from .sqlite_dao.sqlite_data_object import SqliteDataObject
from .sqlite_dao.sqlite_query_generator import SqliteQueryGenerator

from .kv_dao.kv_dao import KvDao, ONE_THOUSAND_YEAR
from .kv_dao.kv_data import KvData
from .kv_dao.sqlite_kv_dao import SqliteKvDao
from .kv_dao.cipher_kv_dao import CipherKvDao
38 changes: 18 additions & 20 deletions ec_tools/database/kv_dao/cipher_kv_dao.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import dataclasses
from typing import Optional
from typing import List, Optional

from ec_tools.database.kv_dao.kv_dao import KvDao
from ec_tools.database.kv_dao.sqlite_kv_dao import SqliteKvDao
from ec_tools.database.sqlite_client.sqlite_client import SqliteClient
from ec_tools.tools.cipher import CipherGenerator, AesCipherGenerator, Cipher, AesMode
from ec_tools.tools.cipher import AesCipherGenerator, AesMode, Cipher, CipherGenerator


@dataclasses.dataclass
Expand All @@ -15,37 +16,31 @@ class CipherKvDao:

@classmethod
def create_sqlite_dao(
cls, db_path: str, encoding: str = "utf-8", mode: AesMode = AesMode.AES_256_CBC
):
cls,
db_path: str,
encoding: str = "utf-8",
mode: AesMode = AesMode.AES_256_CBC,
table_name: Optional[str] = None,
) -> "CipherKvDao":
return CipherKvDao(
SqliteKvDao(sqlite_client=SqliteClient(db_path)),
SqliteKvDao(sqlite_client=SqliteClient(db_path), table_name=table_name),
AesCipherGenerator(encoding, mode),
)

def get(
self, key: str, password: str, default: Optional[str] = None
) -> Optional[str]:
value = self.get_bytes(
key, password, default.encode(self.encoding) if default else None
)
def get(self, key: str, password: str, default: Optional[str] = None) -> Optional[str]:
value = self.get_bytes(key, password, default.encode(self.encoding) if default else None)
return value.decode(self.encoding) if value else None

def get_bytes(
self, key: str, password: str, default: Optional[bytes] = None
) -> Optional[bytes]:
def get_bytes(self, key: str, password: str, default: Optional[bytes] = None) -> Optional[bytes]:
value = self.kv_dao.get(key)
if value:
return self.cipher_generator.decrypt(
password.encode(self.encoding), Cipher.loads(value)
)
return self.cipher_generator.decrypt(password.encode(self.encoding), Cipher.loads(value))
return default

def set(self, key: str, password: str, value: str, duration: float = None) -> None:
return self.set_bytes(key, password, value.encode(self.encoding), duration)

def set_bytes(
self, key: str, password: str, value: bytes, duration: float = None
) -> None:
def set_bytes(self, key: str, password: str, value: bytes, duration: float = None) -> None:
cipher = self.cipher_generator.encrypt(password.encode(self.encoding), value)
return self.kv_dao.set(key, cipher.dumps(), duration)

Expand All @@ -54,3 +49,6 @@ def delete(self, key: str) -> None:

def clear(self) -> None:
return self.kv_dao.clear()

def keys(self) -> List[str]:
return self.kv_dao.keys()
4 changes: 3 additions & 1 deletion ec_tools/database/kv_dao/kv_dao.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Any
from typing import Any, List

ONE_THOUSAND_YEAR = 86400 * 365 * 1000

Expand All @@ -16,6 +16,8 @@ def get(self, key: str, default: Any = None) -> Any:
def set(self, key: str, value: str, duration: float = None) -> None:
self._set(key, value, duration or self._default_duration)

def keys(self) -> List[str]: ...

@abc.abstractmethod
def delete(self, key: str) -> None: ...

Expand Down
22 changes: 16 additions & 6 deletions ec_tools/database/kv_dao/sqlite_kv_dao.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from typing import Any
from typing import Any, List, Optional

from ec_tools.database.kv_dao.kv_dao import KvDao, ONE_THOUSAND_YEAR
from ec_tools.database.kv_dao.kv_dao import ONE_THOUSAND_YEAR, KvDao
from ec_tools.database.kv_dao.kv_data import KvData
from ec_tools.database.sqlite_client.sqlite_client import SqliteClient
from ec_tools.database.sqlite_client.sqlite_query import SqliteQuery
Expand All @@ -13,8 +13,9 @@ def __init__(
self,
sqlite_client: SqliteClient,
default_duration: float = ONE_THOUSAND_YEAR,
table_name: Optional[str] = None,
):
SqliteDao.__init__(self, sqlite_client=sqlite_client, data_type=KvData)
SqliteDao.__init__(self, sqlite_client=sqlite_client, data_type=KvData, table_name=table_name)
KvDao.__init__(self, default_duration=default_duration)

def _get(self, key: str) -> Any:
Expand All @@ -29,9 +30,18 @@ def _get(self, key: str) -> Any:
return rows[0]["value"] if rows else None

def _set(self, key: str, value: str, duration: float) -> None:
return self.insert_or_replace(
[KvData(key=key, value=value, expired_at=time.time() + duration)]
)
return self.insert_or_replace([KvData(key=key, value=value, expired_at=time.time() + duration)])

def keys(self) -> List[str]:
rows = self.execute(
SqliteQuery(
f"SELECT key FROM {self._table_name}",
"WHERE expired_at > ?",
args=[time.time()],
),
commit=False,
)[0]
return [row["key"] for row in rows]

def delete(self, key: str) -> None:
return self.delete_by_values(key=key)
Expand Down
9 changes: 2 additions & 7 deletions ec_tools/database/sqlite_client/sqlite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ class SqliteClient:
def __init__(self, db_path: str):
self._db_path = db_path
self._conn = sqlite3.connect(db_path, check_same_thread=False)
self._conn.row_factory = lambda cursor, row: {
col[0]: row[idx] for idx, col in enumerate(cursor.description)
}
self._conn.row_factory = lambda cursor, row: {col[0]: row[idx] for idx, col in enumerate(cursor.description)}
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
self._cursor.execute("PRAGMA foreign_keys = ON")
Expand All @@ -32,10 +30,7 @@ def __del__(self):
def execute(self, *queries: SqliteQuery, commit=True):
_ = [self._logger.debug("SQL: %s", query) for query in queries]
with self._lock:
results = [
self._cursor.execute(query.sql, query.args).fetchall()
for query in queries
]
results = [self._cursor.execute(query.sql, query.args).fetchall() for query in queries]
if commit:
self._conn.commit()
return results
Expand Down
2 changes: 1 addition & 1 deletion ec_tools/database/sqlite_client/sqlite_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import List, Any
from typing import Any, List


@dataclasses.dataclass
Expand Down
Loading