diff --git a/pyproject.toml b/pyproject.toml index 1b4ed1d..9086d0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dataglasses" -version = "0.6.0" +version = "0.7.0" dependencies = [] requires-python = ">=3.10" authors = [{ name = "Uri Granta", email = "uri.granta+python@gmail.com" }] diff --git a/src/dataglasses/core.py b/src/dataglasses/core.py index 5ff573a..780b569 100644 --- a/src/dataglasses/core.py +++ b/src/dataglasses/core.py @@ -105,7 +105,11 @@ def _from_dict( datacls, ) - origin = get_origin(cls) + origin = cast(type, get_origin(cls)) + + if origin in transform and not transformed: + input_type, fn = transform[origin] + return fn(_from_dict(input_type, value, datacls, transformed=True)) if origin in (collections.abc.Sequence, list): if not isinstance(value, Sequence): @@ -274,7 +278,11 @@ def _json_schema( evaluated_type = cast(type, ref._evaluate(_globals, _locals, frozenset())) return _json_schema(evaluated_type, datacls) - origin = get_origin(cls) + origin = cast(type, get_origin(cls)) + + if origin in transform and not transformed: + input_type, _ = transform[origin] + return _json_schema(input_type, datacls, transformed=True) if origin in (collections.abc.Sequence, list): sequence_type = get_args(cls)[0] diff --git a/tests/test_dataklasses.py b/tests/test_dataklasses.py index 2be1f51..30b16c0 100644 --- a/tests/test_dataklasses.py +++ b/tests/test_dataklasses.py @@ -446,6 +446,25 @@ def test_transform(transform: TransformRules, output: DataclassTransform) -> Non validate(value, schema) +@dataclass +class DataclassTransformGeneric: + a: set[str] + b: set[int] + c: set[float] + + +def test_transform_generic() -> None: + value = {"a": ["a", "b"], "b": [1, 2], "c": [0.5, 0.7]} + transform: TransformRules = { + set: (list[str | int], set), + set[float]: (list[float], lambda lst: set(lst) | {0.0}), + } + data = from_dict(DataclassTransformGeneric, value, transform=transform) + assert data == DataclassTransformGeneric({"a", "b"}, {1, 2}, {0.0, 0.5, 0.7}) + schema = to_json_schema(DataclassTransformGeneric, transform=transform) + validate(value, schema) + + # ================= # UNSUPPORTED TYPES # =================