4141from base64 import urlsafe_b64encode as b64e
4242from collections .abc import Callable , Iterable
4343from pathlib import Path
44- from typing import Any , Protocol , Self , TypeVar
44+ from typing import Any , Self , TypeAlias , TypeGuard , TypeVar
4545
4646import bdkpython as bdk
4747from cryptography .fernet import Fernet
5454from .util import fast_version
5555
5656T = TypeVar ("T" )
57+ ClassArgs : TypeAlias = dict [str , Any ] # noqa: UP040
58+ ClassKwargs : TypeAlias = dict [str , ClassArgs ] # noqa: UP040
59+ SaveableClass : TypeAlias = type ["BaseSaveableClass" ] # noqa: UP040
60+ EnumClass : TypeAlias = type [enum .Enum ] # noqa: UP040
61+ KnownClass : TypeAlias = SaveableClass | EnumClass # noqa: UP040
62+ KnownClasses : TypeAlias = dict [str , KnownClass ] # noqa: UP040
5763
5864logger = logging .getLogger (__name__ )
5965
@@ -68,11 +74,7 @@ def filtered_dict(d: dict, allowed_keys: Iterable[str]) -> dict:
6874 return {k : v for k , v in d .items () if k in allowed_keys }
6975
7076
71- class SupportsInit (Protocol ):
72- def __init__ (self , * args , ** kwargs : Any ) -> None : ...
73-
74-
75- def filtered_for_init (d : dict , cls : type [SupportsInit ]) -> dict :
77+ def filtered_for_init (d : dict , cls : type [Any ]) -> dict :
7678 """Filtered for init."""
7779 return filtered_dict (d , varnames (cls .__init__ ))
7880
@@ -162,22 +164,49 @@ def load(self, filename: str, password: str | None = None) -> str:
162164
163165
164166class ClassSerializer :
167+ @staticmethod
168+ def _is_saveable_class (obj_cls : KnownClass ) -> TypeGuard [SaveableClass ]:
169+ return issubclass (obj_cls , BaseSaveableClass )
170+
171+ @staticmethod
172+ def _is_enum_class (obj_cls : KnownClass ) -> TypeGuard [EnumClass ]:
173+ return issubclass (obj_cls , enum .Enum )
174+
175+ @staticmethod
176+ def _merge_class_kwargs (dct : dict [str , Any ], cls_string : str , extra_kwargs : ClassArgs ) -> dict [str , Any ]:
177+ duplicate_keys = sorted (set (dct ).intersection (extra_kwargs ))
178+ if duplicate_keys :
179+ logger .error (
180+ "Duplicate deserialization keys for %s; keeping values from dct. "
181+ "duplicate_keys=%s dct_values=%s class_kwargs_values=%s" ,
182+ cls_string ,
183+ duplicate_keys ,
184+ {key : dct [key ] for key in duplicate_keys },
185+ {key : extra_kwargs [key ] for key in duplicate_keys },
186+ )
187+
188+ merged_dct = extra_kwargs .copy ()
189+ merged_dct .update (dct )
190+ return merged_dct
191+
165192 @classmethod
166- def general_deserializer (cls , known_classes , class_kwargs ) -> Callable :
193+ def general_deserializer (
194+ cls , known_classes : KnownClasses , class_kwargs : ClassKwargs
195+ ) -> Callable [[dict [str , Any ]], Any ]:
167196 """General deserializer."""
168197
169- def deserializer (dct : dict ) -> dict :
198+ def deserializer (dct : dict [ str , Any ] ) -> Any :
170199 """Deserializer."""
171200 cls_string = dct .get ("__class__" ) # e.g. KeyStore
172201 if cls_string :
173202 if cls_string in known_classes :
174203 obj_cls = known_classes .get (cls_string )
175- if hasattr ( obj_cls , "from_dump" ): # is there KeyStore.from_dump ?
176- if class_kwargs .get (cls_string ): # apply additional arguments to the class from_dump
177- dct . update ( class_kwargs . get ( cls_string ))
178- return obj_cls . from_dump (
179- dct , class_kwargs = class_kwargs
180- ) # do: KeyStore .from_dump(** dct)
204+ if obj_cls and cls . _is_saveable_class ( obj_cls ):
205+ if extra_class_kwargs := class_kwargs .get (
206+ cls_string
207+ ): # apply additional arguments to the class from_dump
208+ dct = cls . _merge_class_kwargs ( dct , cls_string , extra_class_kwargs )
209+ return obj_cls .from_dump (dct , class_kwargs = class_kwargs )
181210 else :
182211 raise Exception (f"{ obj_cls } doesnt have a from_dump classmethod." )
183212 else :
@@ -199,8 +228,8 @@ def deserializer(dct: dict) -> dict:
199228 )
200229 elif dct .get ("__enum__" ):
201230 obj_cls = known_classes .get (dct ["name" ])
202- if obj_cls and hasattr (obj_cls , dct ["value" ]) :
203- return getattr ( obj_cls , dct ["value" ])
231+ if obj_cls and cls . _is_enum_class (obj_cls ) and dct ["value" ] in obj_cls . __members__ :
232+ return obj_cls [ dct ["value" ]]
204233 else :
205234 logger .exception (f"Could not deserialize { obj_cls } ({ dct .get ('value' )} )." )
206235
@@ -222,12 +251,12 @@ def general_serializer(cls, obj):
222251
223252
224253class BaseSaveableClass :
225- known_classes : dict [ str , Any ] = {"Network" : bdk .Network }
254+ known_classes : KnownClasses = {"Network" : bdk .Network }
226255 VERSION = "0.0.0"
227256 _version_from_dump : str | None = None
228257
229258 @staticmethod
230- def cls_kwargs (* args , ** kwargs ):
259+ def cls_kwargs (* args , ** kwargs ) -> ClassArgs :
231260 return {}
232261
233262 @abstractmethod
@@ -254,7 +283,7 @@ def from_dump_downgrade_migration(cls, dct: dict[str, Any]):
254283 return dct
255284
256285 @classmethod
257- def _from_dump (cls , dct : dict [str , Any ], class_kwargs : dict | None = None ):
286+ def _from_dump (cls , dct : dict [str , Any ], class_kwargs : ClassArgs | None = None ):
258287 """From dump."""
259288 assert dct .get ("__class__" ) == cls .__name__
260289 del dct ["__class__" ]
@@ -273,11 +302,11 @@ def _from_dump(cls, dct: dict[str, Any], class_kwargs: dict | None = None):
273302
274303 @classmethod
275304 @abstractmethod
276- def from_dump (cls : type [ SupportsInit ] , dct : dict [str , Any ], class_kwargs : dict | None = None ):
305+ def from_dump (cls , dct : dict [str , Any ], class_kwargs : ClassKwargs | None = None ):
277306 """From dump."""
278307 raise NotImplementedError ()
279308
280- def clone (self , class_kwargs : dict | None = None ) -> Self :
309+ def clone (self , class_kwargs : ClassKwargs | None = None ) -> Self :
281310 """Clone."""
282311 return self ._from_dumps (self .dumps (), class_kwargs = class_kwargs )
283312
@@ -314,7 +343,7 @@ def dumps(self, indent=None) -> str:
314343 return self .dumps_object (self , indent = indent )
315344
316345 @staticmethod
317- def _flatten_known_classes (known_classes : dict [ str , Any ] ) -> dict [ str , Any ] :
346+ def _flatten_known_classes (known_classes : KnownClasses ) -> KnownClasses :
318347 "Recursively extends the dict to includes all known_classes of known_classes"
319348 known_classes = known_classes .copy ()
320349 for known_class in list (known_classes .values ()):
@@ -323,13 +352,13 @@ def _flatten_known_classes(known_classes: dict[str, Any]) -> dict[str, Any]:
323352 return known_classes
324353
325354 @classmethod
326- def get_known_classes (cls ) -> dict [ str , Any ] :
355+ def get_known_classes (cls ) -> KnownClasses :
327356 "Gets a flattened list of known classes that a json deserializer needs to interpet all objects"
328357 return BaseSaveableClass ._flatten_known_classes ({cls .__name__ : cls })
329358
330359 @classmethod
331360 @time_logger
332- def _from_dumps (cls , json_string : str , class_kwargs : dict | None = None ):
361+ def _from_dumps (cls , json_string : str , class_kwargs : ClassKwargs | None = None ):
333362 return json .loads (
334363 json_string ,
335364 object_hook = ClassSerializer .general_deserializer (
@@ -339,7 +368,7 @@ def _from_dumps(cls, json_string: str, class_kwargs: dict | None = None):
339368
340369 @classmethod
341370 @time_logger
342- def _from_file (cls , filename : str , password : str | None = None , class_kwargs : dict | None = None ):
371+ def _from_file (cls , filename : str , password : str | None = None , class_kwargs : ClassKwargs | None = None ):
343372 """Loads the class from a file. This offers the option of add class_kwargs args.
344373
345374 Args:
@@ -371,7 +400,7 @@ def dump(self):
371400 return d
372401
373402 @classmethod
374- def from_dump (cls , dct : dict , class_kwargs : dict | None = None ):
403+ def from_dump (cls , dct : dict , class_kwargs : ClassKwargs | None = None ):
375404 """From dump."""
376405 super ()._from_dump (dct , class_kwargs = class_kwargs )
377406 return cls (** filtered_for_init (dct , cls ))
0 commit comments