Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions packages/scratch-core/src/container_models/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections.abc import Sequence
from functools import partial
from typing import Annotated, TypeAlias

from numpy import array, bool_, floating, number, uint8
from functools import cached_property
from numpy import array, bool_, floating, float64, number, uint8
from numpy.typing import DTypeLike, NDArray
from pydantic import (
AfterValidator,
Expand Down Expand Up @@ -51,7 +51,7 @@ def validate_shape(n_dims: int, value: NDArray) -> NDArray:
]
FloatArray: TypeAlias = Annotated[
NDArray[floating],
BeforeValidator(partial(coerce_to_array, floating)),
BeforeValidator(partial(coerce_to_array, float64)),
PlainSerializer(serialize_ndarray),
]
BoolArray: TypeAlias = Annotated[
Expand Down Expand Up @@ -98,4 +98,22 @@ class ConfigBaseModel(BaseModel):
extra="forbid",
arbitrary_types_allowed=True,
regex_engine="rust-regex",
revalidate_instances="always",
)

def model_copy(self, *, update=None, deep=False):
copy = super().model_copy(update=update, deep=deep)
if update:
# Invalidate cached properties when any field changes
self._clear_cached_properties(copy)
# Validate model after updating
copy = self.model_validate(copy, by_alias=True, by_name=True)
return copy

@staticmethod
def _clear_cached_properties(instance: BaseModel):
"""Dynamically find and clear all cached_property values from instance."""
for name in dir(type(instance)):
attr = getattr(type(instance), name, None)
if isinstance(attr, cached_property):
instance.__dict__.pop(name, None)
11 changes: 0 additions & 11 deletions packages/scratch-core/src/container_models/scan_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,3 @@ def valid_data(self) -> FloatArray1D:
valid_data = self.data[self.valid_mask]
valid_data.setflags(write=False)
return valid_data

def model_copy(self, *, update=None, deep=False):
copy = super().model_copy(update=update, deep=deep)
# Invalidate cached properties when any field changes
if update:
# Dynamically find and clear all cached_property attributes
for name in dir(type(copy)):
attr = getattr(type(copy), name, None)
if isinstance(attr, cached_property):
copy.__dict__.pop(name, None)
return copy
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def update_mark_scan_image(
:param center: New center, or None to recompute from data.
:return: New Mark instance with updated scan image.
"""
return mark.model_copy(update={"scan_image": scan_image, "_center": center})
return mark.model_copy(update={"scan_image": scan_image, "center_": center})


Point2D = tuple[float, float]
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from pydantic import ValidationError
from container_models.base import ConfigBaseModel
from container_models.scan_image import ScanImage
import numpy as np


class TestModel(ConfigBaseModel):
string_field: str
float_field: float


@pytest.fixture(scope="module")
def test_model() -> TestModel:
return TestModel(string_field="some_string", float_field=100.0)


def test_model_copy_updates_correct_fields(test_model: TestModel):
updated = test_model.model_copy(update={"float_field": 150.0})
assert updated.float_field == 150.0


def test_model_copy_raises_on_incorrect_fields(test_model: TestModel):
with pytest.raises(ValidationError):
test_model.model_copy(update={"float_field": "invalid_string"})


def test_model_copy_raises_on_extra_field(test_model: TestModel):
with pytest.raises(ValidationError):
test_model.model_copy(update={"extra_field": "some_value"})


def test_scan_image_model_copy_converts_updated_fields(
scan_image_with_nans: ScanImage,
):
new_data = [[1, 2], [3, 4]]
updated = scan_image_with_nans.model_copy(update={"data": new_data})
assert np.array_equal(updated.data, new_data)


def test_scan_image_model_copy_validates_updated_fields(
scan_image_with_nans: ScanImage,
):
with pytest.raises(
ValidationError,
match="Array shape mismatch, expected 2 dimension\\(s\\), but got 1",
):
scan_image_with_nans.model_copy(
update={"data": scan_image_with_nans.data.flatten()}
)
with pytest.raises(
ValidationError,
match="Array shape mismatch, expected 2 dimension\\(s\\), but got 0",
):
scan_image_with_nans.model_copy(update={"data": "1"})
with pytest.raises(ValidationError, match="Input should be an instance of ndarray"):
scan_image_with_nans.model_copy(update={"data": {0: 1}})