From ec6715e3bc089f998622eb77e009a3f1d55dae19 Mon Sep 17 00:00:00 2001 From: Christian Hettlage Date: Tue, 21 Oct 2025 09:30:53 +0200 Subject: [PATCH 1/5] Add an annotation for defining Pydantic types based on AstroPy Quantity --- src/aeonlib/types.py | 87 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/src/aeonlib/types.py b/src/aeonlib/types.py index 5a649d3..aaf70e1 100644 --- a/src/aeonlib/types.py +++ b/src/aeonlib/types.py @@ -1,12 +1,13 @@ # pyright: reportUnknownVariableType=false # pyright: reportUnknownMemberType=false +import dataclasses import logging from datetime import datetime from typing import Annotated, Any, cast import astropy.coordinates import astropy.time -from astropy.units import Quantity +from astropy.units import Quantity, UnitBase from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler from pydantic.json_schema import JsonSchemaValue from pydantic_core import core_schema @@ -245,6 +246,90 @@ def __get_pydantic_json_schema__( } +@dataclasses.dataclass +class AstropyQuantityTypeAnnotation: + """ + Annotation for defining custom Pydantic types based on a `astropy.units.Quantity`. + + To define such a custom type, instantiate `AstropyQuantityTypeAnnotation` with + the default unit `d` and pass it as a type annotation. Pydantic fields with this + type can be instantiated wth a `float` or a `astropy.units.Quantity` with units + that are compatible with `d`. If a `float` is used, it is assumed to be given + with `d` as the unit. The field is stored as a `astropy.units.Quantity` with unit + `d`. + + For example, you can define a ProperMotion type as follows: + + ``` + from typing import Annotated, Union + from astropy import units as u + from astropy.units import Quantity + from aeonlib.salt.models.types import AstropyQuantityTypeAnnotation + + ProperMotion = Annotated[Union[Quantity, float], AstropyQuantityTypeAnnotation(u.arcsec / u.year)] + ``` + + This type can then be used in a Pydantic model: + + ``` + from pydantic import BaseModel + + class MovingObject(BaseModel): + proper_motion: ProperMotion + + # Create the same object in three different ways. + # Note: 1 year = 8766 hours + object1 = MovingObject(proper_motion=8766) # 3 arcsec per year + object2 = MovingObject(proper_motion=8766 * u.arcsec / u.year) + object3 = MovingObject(proper_motion=1 * u.arcsec / u.hour) + """ + + # Based on + # https://docs.pydantic.dev/latest/concepts/types/#handling-third-party-types + + default_unit: UnitBase + + def __get_pydantic_core_schema__( + self, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + def validate_from_float(value: float) -> Quantity: + return Quantity(value, unit=self.default_unit) + + def validate_from_quantity(value: Quantity) -> Quantity: + return value.to(self.default_unit) + + from_float_schema = core_schema.chain_schema( + [ + core_schema.float_schema(), + core_schema.no_info_plain_validator_function(validate_from_float), + ] + ) + + from_quantity_schema = core_schema.chain_schema( + [ + core_schema.is_instance_schema(Quantity), + core_schema.no_info_plain_validator_function(validate_from_quantity), + ] + ) + + return core_schema.json_or_python_schema( + json_schema=from_float_schema, + python_schema=core_schema.union_schema( + [from_quantity_schema, from_float_schema] + ), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.to(self.default_unit).value + ), + ) + + def __get_pydantic_json_schema( + self, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return handler(core_schema.float_schema()) + + Time = Annotated[astropy.time.Time | datetime, _AstropyTimeType] TimeMJD = Annotated[astropy.time.Time | datetime | float, _AstropyTimeMJDType] Angle = Annotated[ From d2f892ee828cd48fa0386576d1452332447846fa Mon Sep 17 00:00:00 2001 From: Christian Hettlage Date: Tue, 21 Oct 2025 09:52:08 +0200 Subject: [PATCH 2/5] Add tests for AstropyQuantityTypeAnnotation --- tests/module/test_types.py | 93 +++++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/tests/module/test_types.py b/tests/module/test_types.py index 93a125d..14b312d 100644 --- a/tests/module/test_types.py +++ b/tests/module/test_types.py @@ -1,12 +1,14 @@ import json import math from datetime import datetime +from typing import Annotated, Union import pytest from astropy.coordinates import Angle from astropy.time import Time from astropy import units as u -from pydantic import BaseModel +from astropy.units import Quantity +from pydantic import BaseModel, ValidationError import aeonlib.types @@ -21,6 +23,22 @@ class Window(BaseModel): end: aeonlib.types.Time +Wavelength = Annotated[ + Union[Quantity, float], aeonlib.types.AstropyQuantityTypeAnnotation(u.Angstrom) +] + + +ProperMotion = Annotated[ + Union[Quantity, float], + aeonlib.types.AstropyQuantityTypeAnnotation(u.arcsec / u.year), +] + + +class CelestialObject(BaseModel): + peak_wavelength: Wavelength + proper_motion: ProperMotion + + class TestAstropyTime: def test_with_astropy_time(self): """ @@ -208,3 +226,76 @@ def test_from_datetime(self): assert isinstance(obj.time, Time) assert obj.time.scale == "tt" assert obj.time.mjd == 60800.0 + + +class TestAstropyQuantityTypeAnnotation: + @pytest.mark.parametrize( + "peak_wavelength, proper_motion", + [ + # 1 year = 8766 hours + (4107 * u.Angstrom, 4383 * u.arcsec / u.year), + (410.7 * u.nm, 0.5 * u.arcsec / u.hour), + ], + ) + def test_from_quantity(self, peak_wavelength, proper_motion): + """ + Test objects constructed from astropy Quantity objects dump to json as floats + """ + asteroid = CelestialObject( + peak_wavelength=peak_wavelength, proper_motion=proper_motion + ) + dumped = asteroid.model_dump_json() + assert pytest.approx(json.loads(dumped)) == { + "peak_wavelength": 4107.0, + "proper_motion": 4383.0, + } + + def test_from_float(self): + """Test objects constructed from floats dump to json as floats""" + t = CelestialObject(peak_wavelength=7567.6, proper_motion=0.98) + dumped = t.model_dump_json() + assert pytest.approx(json.loads(dumped)) == { + "peak_wavelength": 7567.6, + "proper_motion": 0.98, + } + + @pytest.mark.parametrize( + "peak_wavelength, proper_motion", + [ + # 1 year = 8766 hours + (4107 * u.Angstrom, 4383 * u.arcsec / u.year), + (410.7 * u.nm, 0.5 * u.arcsec / u.hour), + ], + ) + def test_quantity_attributes(self, peak_wavelength, proper_motion): + """Test quantities are accessible on the model""" + asteroid = CelestialObject( + peak_wavelength=peak_wavelength, proper_motion=proper_motion + ) + assert isinstance(asteroid.peak_wavelength, Quantity) + assert pytest.approx(asteroid.peak_wavelength.value) == 4107 + assert asteroid.peak_wavelength.unit == u.Angstrom + assert pytest.approx(asteroid.proper_motion.value) == 4383.0 + assert asteroid.proper_motion.unit == u.arcsec / u.year + + def test_from_json(self): + """Test models can be constructed from json""" + target_json = json.dumps( + { + "peak_wavelength": "5516.89", + "proper_motion": "0.076", + } + ) + target = CelestialObject.model_validate_json(target_json) + assert isinstance(target.peak_wavelength, Quantity) + assert pytest.approx(target.peak_wavelength.value) == 5516.89 + assert target.peak_wavelength.unit == u.Angstrom + assert isinstance(target.proper_motion, Quantity) + assert pytest.approx(target.proper_motion.value) == 0.076 + assert target.proper_motion.unit == u.arcsec / u.year + + def test_rejects_incorrect_unit(self): + with pytest.raises(ValidationError, match="peak_wavelength"): + CelestialObject( + peak_wavelength=5000 * u.hour, proper_motion=1.0 * u.arcsec / u.year + ) From 5782abf7aac118cb027c80e8d65bc8049e8f9171 Mon Sep 17 00:00:00 2001 From: Christian Hettlage Date: Tue, 21 Oct 2025 12:12:16 +0200 Subject: [PATCH 3/5] Update the documentation on defining data models --- CONTRIBUTING.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index eba807d..77792e0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -24,6 +24,44 @@ Facilities should attempt to avoid making excessive use of plain Python dictiona AEONlib also provides several [data types](https://github.com/AEONplus/AEONlib/blob/main/src/aeonlib/types.py) that improve Pydantic models: 1. `aeonlib.types.Time` using this type in a Pydantic model allows consumers to pass in `astropy.time.Time` objects as well as `datetime` objects. The facility can then decide how the time is serialized to match whatever specific format is required. 2. `aeonlib.types.Angle` similarly to the `Time` type, this allows consumers to pass in `astropy.coordinates.Angle` types as well as floats in decimal degrees, the facility can then decide how to serialize the type. +3. `aeonlib.types.AstropyQuantityTypeAnnotation` can be used to define Pydantic models based on `astropy.units.Quantity`, which allow consumers to pass in values with units or as floats. + +The following example illustrates how these types can be used to define a model. + +```python +from typing import Annotated, Union + +from astropy import coordinates +from astropy import time +from astropy import units as u +from astropy.units import Quantity +from pydantic import BaseModel + +from aeonlib.types import Angle, AstropyQuantityTypeAnnotation, Time + +Wavelength = Annotated[ + Union[Quantity, float], AstropyQuantityTypeAnnotation(u.Angstrom) +] + +class Observation(BaseModel): + start_time: Time + grating_angle: Angle + articulation_amgle: Angle + wavelength_range: tuple[Wavelength, Wavelength] + +observation = Observation( + start_time=time.Time(60775.0, scale="utc", format="mjd"), + grating_angle=22.5, + articulation_amgle=45 * u.deg, + wavelength_range=(5000, 600 * u.nm), # 5000 Å to 6000 Å +) + +assert type(observation.start_time) == time.Time +assert type(observation.grating_angle) == coordinates.Angle +assert type(observation.articulation_amgle) == coordinates.Angle +assert type(observation.wavelength_range[0]) == Quantity +assert type(observation.wavelength_range[1]) == Quantity +``` These types eliminate the need for the facility user to need to remember which exact format a facility requires (time in hms? Or ISO UTC?) and simply pass in higher level objects instead. From 4f2cbf081fcc68ad35e14ffbe594f523e96c5da2 Mon Sep 17 00:00:00 2001 From: Christian Hettlage Date: Wed, 22 Oct 2025 10:35:34 +0200 Subject: [PATCH 4/5] Add validators for checking greater/less than relations The validators are intended for cases where you cannot use the corresponding validators from the `annotated_types` library (for example, when dealing with AstroPy `Quantity` objects). --- src/aeonlib/validators.py | 164 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 src/aeonlib/validators.py diff --git a/src/aeonlib/validators.py b/src/aeonlib/validators.py new file mode 100644 index 0000000..8ff08e8 --- /dev/null +++ b/src/aeonlib/validators.py @@ -0,0 +1,164 @@ +""" +This module defines some Pydantic validators. + +The validators are +""" + +from typing import Any + +import astropy.coordinates +from astropy import units as u +from pydantic import AfterValidator + + +def _check_gt(a: Any, b: Any) -> Any: + if a <= b: + raise ValueError(f"{a} is not greater than to {b}.") + return a + + +def _check_ge(a: Any, b: Any) -> Any: + if a < b: + raise ValueError(f"{a} is not greater than or equal to {b}.") + return a + + +def _check_lt(a: Any, b: Any) -> None: + if a >= b: + raise ValueError(f"{a} is not less than to {b}.") + return a + + +def _check_le(a: Any, b: Any) -> None: + if a > b: + raise ValueError(f"{a} is not less than or equal to {b}.") + return a + + +def Gt(value: Any): + """ + Return a Pydantic validator for checking a greater than relation. + + The returned validator can be used in a type annotation:: + + import pydantic + + class DummyModel(pydantic.BaseModel): + duration: Annotated[float, Gt(4)] + + Pydantic will first perform its own internal validation and then check whether + the field value is greater than the argument passed to `Gt` (4 in the example + above). + + It is up to the user to ensure that the field value and the argument of `Gt` can + be compared. + + Parameters + ---------- + value + Value against which to compare. + + Returns + ------- + A validator for checking a greater than relation. + """ + return AfterValidator(lambda v: _check_gt(v, value)) + + +def Ge(value: Any): + """ + Return a Pydantic validator for checking a greater than or equal to relation. + + The returned validator can be used in a type annotation:: + + import pydantic + + class DummyModel(pydantic.BaseModel): + duration: Annotated[float, Ge(4)] + + Pydantic will first perform its own internal validation and then check whether + the field value is greater than or equal to the argument passed to `Ge` (4 in the + example above). + + It is up to the user to ensure that the field value and the argument of `Ge` can + be compared. + + Parameters + ---------- + value + Value against which to compare. + + Returns + ------- + A validator for checking a greater than or equal to relation. + """ + return AfterValidator(lambda v: _check_ge(v, value)) + + +def Lt(value: Any): + """ + Return a Pydantic validator for checking a less than relation. + + The returned validator can be used in a type annotation:: + + import pydantic + + class DummyModel(pydantic.BaseModel): + height: Annotated[float, Lt(4)] + + Pydantic will first perform its own internal validation and then check whether + the field value is less than or equal to the argument passed to `Lt` (4 in the + example above). + + It is up to the user to ensure that the field value and the argument of `Lt` can + be compared. + + Parameters + ---------- + value + Value against which to compare. + + Returns + ------- + A validator for checking a less than relation. + """ + return AfterValidator(lambda v: _check_lt(v, value)) + + +def Le(value: Any): + """ + Return a Pydantic validator for checking a less than or equal to relation. + + The returned validator can be used in a type annotation:: + + import pydantic + + class DummyModel(pydantic.BaseModel): + height: Annotated[float, Le(4)] + + Pydantic will first perform its own internal validation and then check whether + the field value is less than or equal to the argument passed to `Le` (4 in the + example above). + + It is up to the user to ensure that the field value and the argument of `Le` can + be compared. + + Parameters + ---------- + value + Value against which to compare. + + Returns + ------- + A validator for checking a less than or equal to relation. + """ + return AfterValidator(lambda v: _check_le(v, value)) + + +def check_in_visibility_range( + dec: astropy.coordinates.Angle, +) -> astropy.coordinates.Angle: + if dec < -76 * u.deg or dec > 11 * u.deg: + raise ValueError("Not in SALT's visibility range (between -76 and 11 degrees).") + + return dec From 69b5132a09968dae01ff1df4d6883bbc88dd23a5 Mon Sep 17 00:00:00 2001 From: Christian Hettlage Date: Wed, 22 Oct 2025 11:09:53 +0200 Subject: [PATCH 5/5] Add validators for checking greater/less than relations The validators are intended for cases where you cannot use the corresponding validators from the `annotated_types` library (for example, when dealing with AstroPy `Quantity` objects). --- CONTRIBUTING.md | 2 + tests/module/test_validators.py | 85 +++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 tests/module/test_validators.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 77792e0..95f3a0e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -63,6 +63,8 @@ assert type(observation.wavelength_range[0]) == Quantity assert type(observation.wavelength_range[1]) == Quantity ``` +As you cannot use the validators of the `annotated_types` library with `Quantity` objects, the `aeonlib.validators` module provides some alternative validators, which you can use instead. + These types eliminate the need for the facility user to need to remember which exact format a facility requires (time in hms? Or ISO UTC?) and simply pass in higher level objects instead. Aeonlib also defines a few high level models: `aeonlib.models.Window` and several target types. If a facility can translate these models or even use them directly it should. This means a consumer of Aeonlib can define these high level models once, for example from a data alert stream, and use them with every facility for which they want to perform observations. diff --git a/tests/module/test_validators.py b/tests/module/test_validators.py new file mode 100644 index 0000000..ea05d73 --- /dev/null +++ b/tests/module/test_validators.py @@ -0,0 +1,85 @@ +from contextlib import nullcontext +from typing import Annotated + +import pytest +from pydantic import BaseModel, ValidationError + +from aeonlib.validators import Ge, Gt, Le, Lt + + +class GtModel(BaseModel): + a: Annotated[int, Gt(4)] + + +class GeModel(BaseModel): + a: Annotated[int, Ge(4)] + + +class LtModel(BaseModel): + a: Annotated[int, Lt(4)] + + +class LeModel(BaseModel): + a: Annotated[int, Le(4)] + + +class TestValidators: + @pytest.mark.parametrize( + "a, expectation", + [ + (3, pytest.raises(ValidationError)), + (4, pytest.raises(ValidationError)), + (5, nullcontext()), + ], + ) + def test_greater_than(self, a, expectation): + """Test that the Gt validator validates correctly.""" + with expectation: + GtModel(a=a) + + def test_greater_than_does_not_change_field_value(self): + """Test that the field value is not changed by the Gt validator.""" + assert GtModel(a=7).a == 7 + + @pytest.mark.parametrize( + "a, expectation", + [(3, pytest.raises(ValidationError)), (4, nullcontext()), (5, nullcontext())], + ) + def test_greater_equal(self, a, expectation): + """Test that the Ge validator validates correctly.""" + with expectation: + GeModel(a=a) + + def test_greater_equal_does_not_change_field_value(self): + """Test that the field value is not changed by the Ge validator.""" + assert GeModel(a=7).a == 7 + + @pytest.mark.parametrize( + "a, expectation", + [ + (3, nullcontext()), + (4, pytest.raises(ValidationError)), + (5, pytest.raises(ValidationError)), + ], + ) + def test_less_than(self, a, expectation): + """Test that the Lt validator validates correctly.""" + with expectation: + LtModel(a=a) + + def test_less_than_does_not_change_field_value(self): + """Test that the field value is not changed by the Lt validator.""" + assert LtModel(a=2).a == 2 + + @pytest.mark.parametrize( + "a, expectation", + [(3, nullcontext()), (4, nullcontext()), (5, pytest.raises(ValidationError))], + ) + def test_less_equal(self, a, expectation): + """Test that the Le validator validates correctly.""" + with expectation: + LeModel(a=a) + + def test_less_equal_does_not_change_field_value(self): + """Test that the field value is not changed by the Le validator.""" + assert LeModel(a=2).a == 2