From 264274ace9e577020353e7b0baa89e96e487cfca Mon Sep 17 00:00:00 2001 From: mishaschwartz <4380924+mishaschwartz@users.noreply.github.com> Date: Fri, 21 Nov 2025 10:00:38 -0500 Subject: [PATCH] handle multiple timezones --- marble_api/versions/v1/data_request/models.py | 20 ++++++++++++++++++- marble_api/versions/v1/data_request/routes.py | 1 - test/faker_providers.py | 14 ++++++------- .../versions/v1/data_request/test_routes.py | 9 ++++----- .../versions/v1/data_request/test_models.py | 2 +- 5 files changed, 30 insertions(+), 16 deletions(-) diff --git a/marble_api/versions/v1/data_request/models.py b/marble_api/versions/v1/data_request/models.py index 38e8c82..872a4c4 100644 --- a/marble_api/versions/v1/data_request/models.py +++ b/marble_api/versions/v1/data_request/models.py @@ -1,6 +1,7 @@ +import datetime from collections.abc import Sized from datetime import timezone -from typing import Required, TypedDict +from typing import Required, Self, TypedDict from bson import ObjectId from pydantic import ( @@ -14,6 +15,7 @@ ValidationInfo, field_serializer, field_validator, + model_validator, ) from pydantic.functional_validators import BeforeValidator from pydantic.json_schema import SkipJsonSchema @@ -55,6 +57,7 @@ class DataRequest(BaseModel): authors: list[Author] geometry: GeoJSON | None temporal: Temporal + tz_offset: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True) links: Links path: str contact: EmailStr @@ -78,6 +81,21 @@ def validate_geometries(cls, value: GeoJSON | None) -> dict | None: validate_collapsible(value) return value + @model_validator(mode="after") + def get_tz_offset(self) -> Self: + """Store the timezone offset for the temporal data.""" + if self.temporal is not None: + self.tz_offset = [datetime.datetime.utcoffset(t).total_seconds() for t in self.temporal] + return self + + @field_serializer("temporal") + def convert_from_utc(self, value: Temporal, info: FieldSerializationInfo) -> list[str]: + """Apply the timezone offset to convert this from UTC to a date in the correct timezone.""" + return [ + t.astimezone(datetime.timezone(datetime.timedelta(seconds=self.tz_offset[i]))).isoformat() + for i, t in enumerate(value) + ] + @field_serializer("user") def require_user_set(self, value: str, info: FieldSerializationInfo) -> str: """Require that the user_name is set when the model is serialized.""" diff --git a/marble_api/versions/v1/data_request/routes.py b/marble_api/versions/v1/data_request/routes.py index 72f60b8..e7f6c31 100644 --- a/marble_api/versions/v1/data_request/routes.py +++ b/marble_api/versions/v1/data_request/routes.py @@ -64,7 +64,6 @@ async def patch_data_request( data_request.user = user selector = {"_id": _data_request_id(request_id)} if updated_fields: - updated_fields.update(data_request.model_dump(include="stac_item")) result = await client.db["data-request"].find_one_and_update( selector, {"$set": updated_fields}, return_document=ReturnDocument.AFTER ) diff --git a/test/faker_providers.py b/test/faker_providers.py index 8396a33..69e2646 100644 --- a/test/faker_providers.py +++ b/test/faker_providers.py @@ -1,5 +1,3 @@ -import datetime - import bson import pytest from faker import Faker @@ -192,22 +190,22 @@ def author(self): author_["email"] = self.generator.email() return author_ - def utc_date_time_seconds_precision(self): - return self.generator.date_time(tzinfo=datetime.timezone.utc).replace(microsecond=0) + def tz_aware_date_time_seconds_precision(self): + return self.generator.date_time(tzinfo=self.generator.pytimezone()).replace(microsecond=0) def temporal(self): opt = self.generator.random.random() if opt < 1 / 3: return sorted( [ - self.utc_date_time_seconds_precision(), - self.utc_date_time_seconds_precision(), + self.tz_aware_date_time_seconds_precision(), + self.tz_aware_date_time_seconds_precision(), ] ) elif opt < 2 / 3: - return [self.utc_date_time_seconds_precision()] * 2 + return [self.tz_aware_date_time_seconds_precision()] * 2 else: - return [self.utc_date_time_seconds_precision()] + return [self.tz_aware_date_time_seconds_precision()] def link(self): return {"href": self.generator.uri(), "rel": self.generator.word(), "type": self.generator.mime_type()} diff --git a/test/integration/versions/v1/data_request/test_routes.py b/test/integration/versions/v1/data_request/test_routes.py index d11b212..93301ff 100644 --- a/test/integration/versions/v1/data_request/test_routes.py +++ b/test/integration/versions/v1/data_request/test_routes.py @@ -7,7 +7,7 @@ from stac_pydantic import Item from marble_api.database import client -from marble_api.versions.v1.data_request.models import DataRequest, DataRequestPublic +from marble_api.versions.v1.data_request.models import DataRequestPublic from marble_api.versions.v1.data_request.routes import get_data_requests pytestmark = pytest.mark.anyio @@ -226,11 +226,10 @@ async def test_valid(self, fake, async_client, collection_route, data_requests): data = fake.data_request().model_dump_json(exclude=["user"]) response = await async_client.post(collection_route, json=json.loads(data)) assert response.status_code == 200 - assert (id_ := response.json().get("id")) + response_data = response.json() + assert (id_ := response_data.pop("id", None)) bson.ObjectId(id_) # check that the id is a valid object id - assert {"user": data_requests[0]["user"], **json.loads(data)} == json.loads( - DataRequest(**response.json()).model_dump_json() - ) + assert {"user": data_requests[0]["user"], **json.loads(data)} == response_data async def test_invalid_authors(self, fake, async_client, collection_route): data = json.loads(fake.data_request().model_dump_json()) diff --git a/test/unit/versions/v1/data_request/test_models.py b/test/unit/versions/v1/data_request/test_models.py index 3bec7db..191a33a 100644 --- a/test/unit/versions/v1/data_request/test_models.py +++ b/test/unit/versions/v1/data_request/test_models.py @@ -137,7 +137,7 @@ def test_temporal_to_utc(self, fake_class): offset = datetime.timezone(datetime.timedelta(hours=3)) temporal = [now, now + datetime.timedelta(hours=1)] temporal_offset = [t.astimezone(offset) for t in temporal] - request = fake_class(temporal=temporal_offset) + request = fake_class(temporal=[t.isoformat() for t in temporal_offset]) assert ( request.stac_item["properties"]["datetime"] == temporal[0].isoformat()