diff --git a/packages/scratch-core/src/container_models/base.py b/packages/scratch-core/src/container_models/base.py index df2bb659..5d23c157 100644 --- a/packages/scratch-core/src/container_models/base.py +++ b/packages/scratch-core/src/container_models/base.py @@ -1,8 +1,8 @@ from collections.abc import Sequence from functools import partial from typing import Annotated, TypeAlias - -from numpy import array, bool_, floating, number, uint8 +from functools import cached_property +from numpy import array, bool_, floating, float64, number, uint8 from numpy.typing import DTypeLike, NDArray from pydantic import ( AfterValidator, @@ -51,7 +51,7 @@ def validate_shape(n_dims: int, value: NDArray) -> NDArray: ] FloatArray: TypeAlias = Annotated[ NDArray[floating], - BeforeValidator(partial(coerce_to_array, floating)), + BeforeValidator(partial(coerce_to_array, float64)), PlainSerializer(serialize_ndarray), ] BoolArray: TypeAlias = Annotated[ @@ -98,4 +98,22 @@ class ConfigBaseModel(BaseModel): extra="forbid", arbitrary_types_allowed=True, regex_engine="rust-regex", + revalidate_instances="always", ) + + def model_copy(self, *, update=None, deep=False): + copy = super().model_copy(update=update, deep=deep) + if update: + # Invalidate cached properties when any field changes + self._clear_cached_properties(copy) + # Validate model after updating + copy = self.model_validate(copy, by_alias=True, by_name=True) + return copy + + @staticmethod + def _clear_cached_properties(instance: BaseModel): + """Dynamically find and clear all cached_property values from instance.""" + for name in dir(type(instance)): + attr = getattr(type(instance), name, None) + if isinstance(attr, cached_property): + instance.__dict__.pop(name, None) diff --git a/packages/scratch-core/src/container_models/scan_image.py b/packages/scratch-core/src/container_models/scan_image.py index 5791f385..4b2b2fa1 100644 --- a/packages/scratch-core/src/container_models/scan_image.py +++ b/packages/scratch-core/src/container_models/scan_image.py @@ -41,14 +41,3 @@ def valid_data(self) -> FloatArray1D: valid_data = self.data[self.valid_mask] valid_data.setflags(write=False) return valid_data - - def model_copy(self, *, update=None, deep=False): - copy = super().model_copy(update=update, deep=deep) - # Invalidate cached properties when any field changes - if update: - # Dynamically find and clear all cached_property attributes - for name in dir(type(copy)): - attr = getattr(type(copy), name, None) - if isinstance(attr, cached_property): - copy.__dict__.pop(name, None) - return copy diff --git a/packages/scratch-core/src/conversion/preprocess_impression/utils.py b/packages/scratch-core/src/conversion/preprocess_impression/utils.py index 1330067b..85af69a7 100644 --- a/packages/scratch-core/src/conversion/preprocess_impression/utils.py +++ b/packages/scratch-core/src/conversion/preprocess_impression/utils.py @@ -29,7 +29,7 @@ def update_mark_scan_image( :param center: New center, or None to recompute from data. :return: New Mark instance with updated scan image. """ - return mark.model_copy(update={"scan_image": scan_image, "_center": center}) + return mark.model_copy(update={"scan_image": scan_image, "center_": center}) Point2D = tuple[float, float] diff --git a/packages/scratch-core/tests/container_models/test_validate_on_update.py b/packages/scratch-core/tests/container_models/test_validate_on_update.py new file mode 100644 index 00000000..1faed57d --- /dev/null +++ b/packages/scratch-core/tests/container_models/test_validate_on_update.py @@ -0,0 +1,57 @@ +import pytest +from pydantic import ValidationError +from container_models.base import ConfigBaseModel +from container_models.scan_image import ScanImage +import numpy as np + + +class TestModel(ConfigBaseModel): + string_field: str + float_field: float + + +@pytest.fixture(scope="module") +def test_model() -> TestModel: + return TestModel(string_field="some_string", float_field=100.0) + + +def test_model_copy_updates_correct_fields(test_model: TestModel): + updated = test_model.model_copy(update={"float_field": 150.0}) + assert updated.float_field == 150.0 + + +def test_model_copy_raises_on_incorrect_fields(test_model: TestModel): + with pytest.raises(ValidationError): + test_model.model_copy(update={"float_field": "invalid_string"}) + + +def test_model_copy_raises_on_extra_field(test_model: TestModel): + with pytest.raises(ValidationError): + test_model.model_copy(update={"extra_field": "some_value"}) + + +def test_scan_image_model_copy_converts_updated_fields( + scan_image_with_nans: ScanImage, +): + new_data = [[1, 2], [3, 4]] + updated = scan_image_with_nans.model_copy(update={"data": new_data}) + assert np.array_equal(updated.data, new_data) + + +def test_scan_image_model_copy_validates_updated_fields( + scan_image_with_nans: ScanImage, +): + with pytest.raises( + ValidationError, + match="Array shape mismatch, expected 2 dimension\\(s\\), but got 1", + ): + scan_image_with_nans.model_copy( + update={"data": scan_image_with_nans.data.flatten()} + ) + with pytest.raises( + ValidationError, + match="Array shape mismatch, expected 2 dimension\\(s\\), but got 0", + ): + scan_image_with_nans.model_copy(update={"data": "1"}) + with pytest.raises(ValidationError, match="Input should be an instance of ndarray"): + scan_image_with_nans.model_copy(update={"data": {0: 1}})