Skip to content

Commit b034ef9

Browse files
Improve storage deserializer typing
1 parent 6aca98c commit b034ef9

3 files changed

Lines changed: 86 additions & 28 deletions

File tree

bitcoin_safe_lib/storage.py

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from base64 import urlsafe_b64encode as b64e
4242
from collections.abc import Callable, Iterable
4343
from pathlib import Path
44-
from typing import Any, Protocol, Self, TypeVar
44+
from typing import Any, Self, TypeAlias, TypeGuard, TypeVar
4545

4646
import bdkpython as bdk
4747
from cryptography.fernet import Fernet
@@ -54,6 +54,12 @@
5454
from .util import fast_version
5555

5656
T = TypeVar("T")
57+
ClassArgs: TypeAlias = dict[str, Any] # noqa: UP040
58+
ClassKwargs: TypeAlias = dict[str, ClassArgs] # noqa: UP040
59+
SaveableClass: TypeAlias = type["BaseSaveableClass"] # noqa: UP040
60+
EnumClass: TypeAlias = type[enum.Enum] # noqa: UP040
61+
KnownClass: TypeAlias = SaveableClass | EnumClass # noqa: UP040
62+
KnownClasses: TypeAlias = dict[str, KnownClass] # noqa: UP040
5763

5864
logger = logging.getLogger(__name__)
5965

@@ -68,11 +74,7 @@ def filtered_dict(d: dict, allowed_keys: Iterable[str]) -> dict:
6874
return {k: v for k, v in d.items() if k in allowed_keys}
6975

7076

71-
class SupportsInit(Protocol):
72-
def __init__(self, *args, **kwargs: Any) -> None: ...
73-
74-
75-
def filtered_for_init(d: dict, cls: type[SupportsInit]) -> dict:
77+
def filtered_for_init(d: dict, cls: type[Any]) -> dict:
7678
"""Filtered for init."""
7779
return filtered_dict(d, varnames(cls.__init__))
7880

@@ -162,22 +164,49 @@ def load(self, filename: str, password: str | None = None) -> str:
162164

163165

164166
class ClassSerializer:
167+
@staticmethod
168+
def _is_saveable_class(obj_cls: KnownClass) -> TypeGuard[SaveableClass]:
169+
return issubclass(obj_cls, BaseSaveableClass)
170+
171+
@staticmethod
172+
def _is_enum_class(obj_cls: KnownClass) -> TypeGuard[EnumClass]:
173+
return issubclass(obj_cls, enum.Enum)
174+
175+
@staticmethod
176+
def _merge_class_kwargs(dct: dict[str, Any], cls_string: str, extra_kwargs: ClassArgs) -> dict[str, Any]:
177+
duplicate_keys = sorted(set(dct).intersection(extra_kwargs))
178+
if duplicate_keys:
179+
logger.error(
180+
"Duplicate deserialization keys for %s; keeping values from dct. "
181+
"duplicate_keys=%s dct_values=%s class_kwargs_values=%s",
182+
cls_string,
183+
duplicate_keys,
184+
{key: dct[key] for key in duplicate_keys},
185+
{key: extra_kwargs[key] for key in duplicate_keys},
186+
)
187+
188+
merged_dct = extra_kwargs.copy()
189+
merged_dct.update(dct)
190+
return merged_dct
191+
165192
@classmethod
166-
def general_deserializer(cls, known_classes, class_kwargs) -> Callable:
193+
def general_deserializer(
194+
cls, known_classes: KnownClasses, class_kwargs: ClassKwargs
195+
) -> Callable[[dict[str, Any]], Any]:
167196
"""General deserializer."""
168197

169-
def deserializer(dct: dict) -> dict:
198+
def deserializer(dct: dict[str, Any]) -> Any:
170199
"""Deserializer."""
171200
cls_string = dct.get("__class__") # e.g. KeyStore
172201
if cls_string:
173202
if cls_string in known_classes:
174203
obj_cls = known_classes.get(cls_string)
175-
if hasattr(obj_cls, "from_dump"): # is there KeyStore.from_dump ?
176-
if class_kwargs.get(cls_string): # apply additional arguments to the class from_dump
177-
dct.update(class_kwargs.get(cls_string))
178-
return obj_cls.from_dump(
179-
dct, class_kwargs=class_kwargs
180-
) # do: KeyStore.from_dump(**dct)
204+
if obj_cls and cls._is_saveable_class(obj_cls):
205+
if extra_class_kwargs := class_kwargs.get(
206+
cls_string
207+
): # apply additional arguments to the class from_dump
208+
dct = cls._merge_class_kwargs(dct, cls_string, extra_class_kwargs)
209+
return obj_cls.from_dump(dct, class_kwargs=class_kwargs)
181210
else:
182211
raise Exception(f"{obj_cls} doesnt have a from_dump classmethod.")
183212
else:
@@ -199,8 +228,8 @@ def deserializer(dct: dict) -> dict:
199228
)
200229
elif dct.get("__enum__"):
201230
obj_cls = known_classes.get(dct["name"])
202-
if obj_cls and hasattr(obj_cls, dct["value"]):
203-
return getattr(obj_cls, dct["value"])
231+
if obj_cls and cls._is_enum_class(obj_cls) and dct["value"] in obj_cls.__members__:
232+
return obj_cls[dct["value"]]
204233
else:
205234
logger.exception(f"Could not deserialize {obj_cls}({dct.get('value')}).")
206235

@@ -222,12 +251,12 @@ def general_serializer(cls, obj):
222251

223252

224253
class BaseSaveableClass:
225-
known_classes: dict[str, Any] = {"Network": bdk.Network}
254+
known_classes: KnownClasses = {"Network": bdk.Network}
226255
VERSION = "0.0.0"
227256
_version_from_dump: str | None = None
228257

229258
@staticmethod
230-
def cls_kwargs(*args, **kwargs):
259+
def cls_kwargs(*args, **kwargs) -> ClassArgs:
231260
return {}
232261

233262
@abstractmethod
@@ -254,7 +283,7 @@ def from_dump_downgrade_migration(cls, dct: dict[str, Any]):
254283
return dct
255284

256285
@classmethod
257-
def _from_dump(cls, dct: dict[str, Any], class_kwargs: dict | None = None):
286+
def _from_dump(cls, dct: dict[str, Any], class_kwargs: ClassArgs | None = None):
258287
"""From dump."""
259288
assert dct.get("__class__") == cls.__name__
260289
del dct["__class__"]
@@ -273,11 +302,11 @@ def _from_dump(cls, dct: dict[str, Any], class_kwargs: dict | None = None):
273302

274303
@classmethod
275304
@abstractmethod
276-
def from_dump(cls: type[SupportsInit], dct: dict[str, Any], class_kwargs: dict | None = None):
305+
def from_dump(cls, dct: dict[str, Any], class_kwargs: ClassKwargs | None = None):
277306
"""From dump."""
278307
raise NotImplementedError()
279308

280-
def clone(self, class_kwargs: dict | None = None) -> Self:
309+
def clone(self, class_kwargs: ClassKwargs | None = None) -> Self:
281310
"""Clone."""
282311
return self._from_dumps(self.dumps(), class_kwargs=class_kwargs)
283312

@@ -314,7 +343,7 @@ def dumps(self, indent=None) -> str:
314343
return self.dumps_object(self, indent=indent)
315344

316345
@staticmethod
317-
def _flatten_known_classes(known_classes: dict[str, Any]) -> dict[str, Any]:
346+
def _flatten_known_classes(known_classes: KnownClasses) -> KnownClasses:
318347
"Recursively extends the dict to includes all known_classes of known_classes"
319348
known_classes = known_classes.copy()
320349
for known_class in list(known_classes.values()):
@@ -323,13 +352,13 @@ def _flatten_known_classes(known_classes: dict[str, Any]) -> dict[str, Any]:
323352
return known_classes
324353

325354
@classmethod
326-
def get_known_classes(cls) -> dict[str, Any]:
355+
def get_known_classes(cls) -> KnownClasses:
327356
"Gets a flattened list of known classes that a json deserializer needs to interpet all objects"
328357
return BaseSaveableClass._flatten_known_classes({cls.__name__: cls})
329358

330359
@classmethod
331360
@time_logger
332-
def _from_dumps(cls, json_string: str, class_kwargs: dict | None = None):
361+
def _from_dumps(cls, json_string: str, class_kwargs: ClassKwargs | None = None):
333362
return json.loads(
334363
json_string,
335364
object_hook=ClassSerializer.general_deserializer(
@@ -339,7 +368,7 @@ def _from_dumps(cls, json_string: str, class_kwargs: dict | None = None):
339368

340369
@classmethod
341370
@time_logger
342-
def _from_file(cls, filename: str, password: str | None = None, class_kwargs: dict | None = None):
371+
def _from_file(cls, filename: str, password: str | None = None, class_kwargs: ClassKwargs | None = None):
343372
"""Loads the class from a file. This offers the option of add class_kwargs args.
344373
345374
Args:
@@ -371,7 +400,7 @@ def dump(self):
371400
return d
372401

373402
@classmethod
374-
def from_dump(cls, dct: dict, class_kwargs: dict | None = None):
403+
def from_dump(cls, dct: dict, class_kwargs: ClassKwargs | None = None):
375404
"""From dump."""
376405
super()._from_dump(dct, class_kwargs=class_kwargs)
377406
return cls(**filtered_for_init(dct, cls))

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ line-length = 110
55

66
[tool.poetry]
77
name = "bitcoin-safe-lib"
8-
version = "2.1.0"
8+
version = "2.1.1"
99
authors = ["andreasgriffin <andreasgriffin@proton.me>"]
1010
license = "GPL-3.0"
1111
readme = "README.md"
@@ -53,4 +53,3 @@ known-first-party = ["bitcoin_safe_lib"]
5353
[tool.ruff.format]
5454
# Ruff formatter is Black-compatible; keep defaults or tweak here.
5555

56-

tests/test_storage.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import logging
2+
3+
from bitcoin_safe_lib.storage import SaveAllClass
4+
5+
6+
class ExampleSaveable(SaveAllClass):
7+
VERSION = "1.0.0"
8+
9+
def __init__(self, value: str, optional: str | None = None) -> None:
10+
self.value = value
11+
self.optional = optional
12+
13+
14+
def test_from_dumps_prefers_dct_values_over_class_kwargs(caplog) -> None:
15+
json_string = (
16+
'{"__class__":"ExampleSaveable","VERSION":"1.0.0","value":"from_dct","optional":"from_json"}'
17+
)
18+
19+
with caplog.at_level(logging.ERROR, logger="bitcoin_safe_lib.storage"):
20+
obj = ExampleSaveable._from_dumps(
21+
json_string,
22+
class_kwargs={"ExampleSaveable": {"value": "from_kwargs", "extra": "unused"}},
23+
)
24+
25+
assert obj.value == "from_dct"
26+
assert obj.optional == "from_json"
27+
assert "Duplicate deserialization keys for ExampleSaveable" in caplog.text
28+
assert "value" in caplog.text
29+
assert "from_dct" in caplog.text
30+
assert "from_kwargs" in caplog.text

0 commit comments

Comments
 (0)