Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: Pytest

permissions:
contents: read

on:
pull_request:

jobs:
pytest:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]

steps:
- name: Check out repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
run: |
python -m pip install --upgrade pip
pip install poetry

- name: Install dependencies
run: poetry install --with dev

- name: Run pytest
run: poetry run pytest tests
10 changes: 10 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Agent Guidelines

- Never use `getattr` or `setattr`;
- use type hints
- write clean code. if you're writing many if statements you're probably doing it wrong. Less code is better to understand the logic and to reduce bugs.
- modularize behavior into the most relevant class/module; do not centralize unrelated functionality in one class.
- avoid keyword-only `*` in method/function signatures unless explicitly requested.
- Before you commit, run pre-commit ruff format. commit and push the changes (use a dedicated branch for each session). If the pre-commit returns errors, fix them. For the pre-commit to work you have to cd into the current project and activate the environment.
- Ensure git hooks can resolve `python`: run commit/pre-commit commands with the project venv first on `PATH`, e.g. `PATH="$(poetry env info -p)/bin:$PATH" poetry run pre-commit run ruff-format --files <files>` and `PATH="$(poetry env info -p)/bin:$PATH" git commit -m "<message>"`.
- run relevant pytests
76 changes: 49 additions & 27 deletions bitcoin_safe_lib/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from base64 import urlsafe_b64encode as b64e
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any, Protocol, Self, TypeVar
from typing import Any, TypeAlias, TypeGuard, TypeVar

import bdkpython as bdk
from cryptography.fernet import Fernet
Expand All @@ -53,7 +53,13 @@

from .util import fast_version

T = TypeVar("T")
T = TypeVar("T", bound="BaseSaveableClass")
ClassArgs: TypeAlias = dict[str, Any] # noqa: UP040
ClassKwargs: TypeAlias = dict[str, ClassArgs] # noqa: UP040
SaveableClass: TypeAlias = type["BaseSaveableClass"] # noqa: UP040
EnumClass: TypeAlias = type[enum.Enum] # noqa: UP040
KnownClass: TypeAlias = SaveableClass | EnumClass # noqa: UP040
KnownClasses: TypeAlias = dict[str, KnownClass] # noqa: UP040

logger = logging.getLogger(__name__)

Expand All @@ -68,11 +74,7 @@ def filtered_dict(d: dict, allowed_keys: Iterable[str]) -> dict:
return {k: v for k, v in d.items() if k in allowed_keys}


class SupportsInit(Protocol):
def __init__(self, *args, **kwargs: Any) -> None: ...


def filtered_for_init(d: dict, cls: type[SupportsInit]) -> dict:
def filtered_for_init(d: dict, cls: type[Any]) -> dict:
"""Filtered for init."""
return filtered_dict(d, varnames(cls.__init__))

Expand Down Expand Up @@ -162,22 +164,42 @@ def load(self, filename: str, password: str | None = None) -> str:


class ClassSerializer:
@staticmethod
def _is_saveable_class(obj_cls: KnownClass) -> TypeGuard[SaveableClass]:
return issubclass(obj_cls, BaseSaveableClass)

@staticmethod
def _is_enum_class(obj_cls: KnownClass) -> TypeGuard[EnumClass]:
return issubclass(obj_cls, enum.Enum)

@staticmethod
def _merge_class_kwargs(dct: dict[str, Any], cls_string: str, extra_kwargs: ClassArgs) -> dict[str, Any]:
duplicate_keys = sorted(set(dct).intersection(extra_kwargs))
if duplicate_keys:
logger.error(
f"Duplicate deserialization keys for {cls_string=}; keeping values from dct. {duplicate_keys=}"
)

merged_dct = extra_kwargs.copy()
merged_dct.update(dct)
return merged_dct

@classmethod
def general_deserializer(cls, known_classes, class_kwargs) -> Callable:
def general_deserializer(
cls, known_classes: KnownClasses, class_kwargs: ClassKwargs
) -> Callable[[dict[str, Any]], Any]:
"""General deserializer."""

def deserializer(dct: dict) -> dict:
def deserializer(dct: dict[str, Any]) -> Any:
"""Deserializer."""
cls_string = dct.get("__class__") # e.g. KeyStore
if cls_string:
if cls_string in known_classes:
obj_cls = known_classes.get(cls_string)
if hasattr(obj_cls, "from_dump"): # is there KeyStore.from_dump ?
if class_kwargs.get(cls_string): # apply additional arguments to the class from_dump
dct.update(class_kwargs.get(cls_string))
return obj_cls.from_dump(
dct, class_kwargs=class_kwargs
) # do: KeyStore.from_dump(**dct)
if obj_cls and cls._is_saveable_class(obj_cls):
if extra_class_kwargs := class_kwargs.get(cls_string):
dct = cls._merge_class_kwargs(dct, cls_string, extra_class_kwargs)
return obj_cls.from_dump(dct, class_kwargs=class_kwargs)
else:
raise Exception(f"{obj_cls} doesnt have a from_dump classmethod.")
else:
Expand All @@ -199,8 +221,8 @@ def deserializer(dct: dict) -> dict:
)
elif dct.get("__enum__"):
obj_cls = known_classes.get(dct["name"])
if obj_cls and hasattr(obj_cls, dct["value"]):
return getattr(obj_cls, dct["value"])
if obj_cls and cls._is_enum_class(obj_cls) and dct["value"] in obj_cls.__members__:
return obj_cls[dct["value"]]
else:
logger.exception(f"Could not deserialize {obj_cls}({dct.get('value')}).")

Expand All @@ -222,12 +244,12 @@ def general_serializer(cls, obj):


class BaseSaveableClass:
known_classes: dict[str, Any] = {"Network": bdk.Network}
known_classes: KnownClasses = {"Network": bdk.Network}
VERSION = "0.0.0"
_version_from_dump: str | None = None

@staticmethod
def cls_kwargs(*args, **kwargs):
def cls_kwargs(*args, **kwargs) -> ClassArgs:
return {}

@abstractmethod
Expand All @@ -254,7 +276,7 @@ def from_dump_downgrade_migration(cls, dct: dict[str, Any]):
return dct

@classmethod
def _from_dump(cls, dct: dict[str, Any], class_kwargs: dict | None = None):
def _from_dump(cls, dct: dict[str, Any], class_kwargs: ClassArgs | None = None):
"""From dump."""
assert dct.get("__class__") == cls.__name__
del dct["__class__"]
Expand All @@ -273,11 +295,11 @@ def _from_dump(cls, dct: dict[str, Any], class_kwargs: dict | None = None):

@classmethod
@abstractmethod
def from_dump(cls: type[SupportsInit], dct: dict[str, Any], class_kwargs: dict | None = None):
def from_dump(cls, dct: dict[str, Any], class_kwargs: ClassKwargs | None = None):
"""From dump."""
raise NotImplementedError()

def clone(self, class_kwargs: dict | None = None) -> Self:
def clone(self: T, class_kwargs: ClassKwargs | None = None) -> T:
"""Clone."""
return self._from_dumps(self.dumps(), class_kwargs=class_kwargs)

Expand Down Expand Up @@ -314,7 +336,7 @@ def dumps(self, indent=None) -> str:
return self.dumps_object(self, indent=indent)

@staticmethod
def _flatten_known_classes(known_classes: dict[str, Any]) -> dict[str, Any]:
def _flatten_known_classes(known_classes: KnownClasses) -> KnownClasses:
"Recursively extends the dict to includes all known_classes of known_classes"
known_classes = known_classes.copy()
for known_class in list(known_classes.values()):
Expand All @@ -323,13 +345,13 @@ def _flatten_known_classes(known_classes: dict[str, Any]) -> dict[str, Any]:
return known_classes

@classmethod
def get_known_classes(cls) -> dict[str, Any]:
def get_known_classes(cls) -> KnownClasses:
"Gets a flattened list of known classes that a json deserializer needs to interpet all objects"
return BaseSaveableClass._flatten_known_classes({cls.__name__: cls})

@classmethod
@time_logger
def _from_dumps(cls, json_string: str, class_kwargs: dict | None = None):
def _from_dumps(cls, json_string: str, class_kwargs: ClassKwargs | None = None):
return json.loads(
json_string,
object_hook=ClassSerializer.general_deserializer(
Expand All @@ -339,7 +361,7 @@ def _from_dumps(cls, json_string: str, class_kwargs: dict | None = None):

@classmethod
@time_logger
def _from_file(cls, filename: str, password: str | None = None, class_kwargs: dict | None = None):
def _from_file(cls, filename: str, password: str | None = None, class_kwargs: ClassKwargs | None = None):
"""Loads the class from a file. This offers the option of add class_kwargs args.

Args:
Expand Down Expand Up @@ -371,7 +393,7 @@ def dump(self):
return d

@classmethod
def from_dump(cls, dct: dict, class_kwargs: dict | None = None):
def from_dump(cls, dct: dict, class_kwargs: ClassKwargs | None = None):
"""From dump."""
super()._from_dump(dct, class_kwargs=class_kwargs)
return cls(**filtered_for_init(dct, cls))
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ line-length = 110

[tool.poetry]
name = "bitcoin-safe-lib"
version = "2.1.0"
version = "2.1.1"
authors = ["andreasgriffin <andreasgriffin@proton.me>"]
license = "GPL-3.0"
readme = "README.md"
Expand Down Expand Up @@ -53,4 +53,3 @@ known-first-party = ["bitcoin_safe_lib"]
[tool.ruff.format]
# Ruff formatter is Black-compatible; keep defaults or tweak here.


29 changes: 29 additions & 0 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import logging

from bitcoin_safe_lib.storage import SaveAllClass


class ExampleSaveable(SaveAllClass):
VERSION = "1.0.0"

def __init__(self, value: str, optional: str | None = None) -> None:
self.value = value
self.optional = optional


def test_from_dumps_prefers_dct_values_over_class_kwargs(caplog) -> None:
json_string = (
'{"__class__":"ExampleSaveable","VERSION":"1.0.0","value":"from_dct","optional":"from_json"}'
)

with caplog.at_level(logging.ERROR, logger="bitcoin_safe_lib.storage"):
obj = ExampleSaveable._from_dumps(
json_string,
class_kwargs={"ExampleSaveable": {"value": "from_kwargs", "extra": "unused"}},
)

assert obj.value == "from_dct"
assert obj.optional == "from_json"
assert "Duplicate deserialization keys" in caplog.text
assert "ExampleSaveable" in caplog.text
assert "value" in caplog.text
Loading