Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion marble_api/versions/v1/data_request/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
from collections.abc import Sized
from datetime import timezone
from typing import Required, TypedDict
from typing import Required, Self, TypedDict

from bson import ObjectId
from pydantic import (
Expand All @@ -14,6 +15,7 @@
ValidationInfo,
field_serializer,
field_validator,
model_validator,
)
from pydantic.functional_validators import BeforeValidator
from pydantic.json_schema import SkipJsonSchema
Expand Down Expand Up @@ -55,6 +57,7 @@ class DataRequest(BaseModel):
authors: list[Author]
geometry: GeoJSON | None
temporal: Temporal
tz_offset: SkipJsonSchema[list[float] | None] = Field(default=None, exclude=True)
links: Links
path: str
contact: EmailStr
Expand All @@ -78,6 +81,21 @@ def validate_geometries(cls, value: GeoJSON | None) -> dict | None:
validate_collapsible(value)
return value

@model_validator(mode="after")
def get_tz_offset(self) -> Self:
"""Store the timezone offset for the temporal data."""
if self.temporal is not None:
self.tz_offset = [datetime.datetime.utcoffset(t).total_seconds() for t in self.temporal]
return self

@field_serializer("temporal")
def convert_from_utc(self, value: Temporal, info: FieldSerializationInfo) -> list[str]:
"""Apply the timezone offset to convert this from UTC to a date in the correct timezone."""
return [
t.astimezone(datetime.timezone(datetime.timedelta(seconds=self.tz_offset[i]))).isoformat()
for i, t in enumerate(value)
]

@field_serializer("user")
def require_user_set(self, value: str, info: FieldSerializationInfo) -> str:
"""Require that the user_name is set when the model is serialized."""
Expand Down
1 change: 0 additions & 1 deletion marble_api/versions/v1/data_request/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ async def patch_data_request(
data_request.user = user
selector = {"_id": _data_request_id(request_id)}
if updated_fields:
updated_fields.update(data_request.model_dump(include="stac_item"))
result = await client.db["data-request"].find_one_and_update(
selector, {"$set": updated_fields}, return_document=ReturnDocument.AFTER
)
Expand Down
14 changes: 6 additions & 8 deletions test/faker_providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import datetime

import bson
import pytest
from faker import Faker
Expand Down Expand Up @@ -192,22 +190,22 @@ def author(self):
author_["email"] = self.generator.email()
return author_

def utc_date_time_seconds_precision(self):
return self.generator.date_time(tzinfo=datetime.timezone.utc).replace(microsecond=0)
def tz_aware_date_time_seconds_precision(self):
return self.generator.date_time(tzinfo=self.generator.pytimezone()).replace(microsecond=0)

def temporal(self):
opt = self.generator.random.random()
if opt < 1 / 3:
return sorted(
[
self.utc_date_time_seconds_precision(),
self.utc_date_time_seconds_precision(),
self.tz_aware_date_time_seconds_precision(),
self.tz_aware_date_time_seconds_precision(),
]
)
elif opt < 2 / 3:
return [self.utc_date_time_seconds_precision()] * 2
return [self.tz_aware_date_time_seconds_precision()] * 2
else:
return [self.utc_date_time_seconds_precision()]
return [self.tz_aware_date_time_seconds_precision()]

def link(self):
return {"href": self.generator.uri(), "rel": self.generator.word(), "type": self.generator.mime_type()}
Expand Down
9 changes: 4 additions & 5 deletions test/integration/versions/v1/data_request/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from stac_pydantic import Item

from marble_api.database import client
from marble_api.versions.v1.data_request.models import DataRequest, DataRequestPublic
from marble_api.versions.v1.data_request.models import DataRequestPublic
from marble_api.versions.v1.data_request.routes import get_data_requests

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -226,11 +226,10 @@ async def test_valid(self, fake, async_client, collection_route, data_requests):
data = fake.data_request().model_dump_json(exclude=["user"])
response = await async_client.post(collection_route, json=json.loads(data))
assert response.status_code == 200
assert (id_ := response.json().get("id"))
response_data = response.json()
assert (id_ := response_data.pop("id", None))
bson.ObjectId(id_) # check that the id is a valid object id
assert {"user": data_requests[0]["user"], **json.loads(data)} == json.loads(
DataRequest(**response.json()).model_dump_json()
)
assert {"user": data_requests[0]["user"], **json.loads(data)} == response_data

async def test_invalid_authors(self, fake, async_client, collection_route):
data = json.loads(fake.data_request().model_dump_json())
Expand Down
2 changes: 1 addition & 1 deletion test/unit/versions/v1/data_request/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_temporal_to_utc(self, fake_class):
offset = datetime.timezone(datetime.timedelta(hours=3))
temporal = [now, now + datetime.timedelta(hours=1)]
temporal_offset = [t.astimezone(offset) for t in temporal]
request = fake_class(temporal=temporal_offset)
request = fake_class(temporal=[t.isoformat() for t in temporal_offset])
assert (
request.stac_item["properties"]["datetime"]
== temporal[0].isoformat()
Expand Down