From d5043994a12ff8833b57201e667b2330ca5ae044 Mon Sep 17 00:00:00 2001 From: mishaschwartz <4380924+mishaschwartz@users.noreply.github.com> Date: Mon, 10 Nov 2025 20:50:40 -0500 Subject: [PATCH] introduce users --- README.md | 16 + marble_api/versions/v1/app.py | 6 +- marble_api/versions/v1/data_request/models.py | 13 +- marble_api/versions/v1/data_request/routes.py | 91 ++++-- .../versions/v1/data_request/test_routes.py | 279 +++++++++++++----- .../versions/v1/data_request/test_models.py | 11 +- 6 files changed, 318 insertions(+), 98 deletions(-) diff --git a/README.md b/README.md index 969e914..95b31b6 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,22 @@ An API for the Marble platform. - MongoDB server +## Authentication and Authorization + +Marble API does not do any authentication or authorization (authn/z). That is left to other +applications (such as [Magpie](https://github.com/ouranosinc/magpie)). + +Marble API assumes that only users with administrator access should be able to access all routes +prefixed with `/vX/admin/` (where `X` is a version number). + +Marble API also assumes that only users with a given user name or id `Y` should be able to access +all routes prefixed with `/vX/users/Y/` (where `X` is a version number). + +When integrating Marble API with the [birdhouse](https://github.com/bird-house/birdhouse-deploy/) platform we +recommend enabling it with the +[Marble API component](https://github.com/DACCS-Climate/marble-config/tree/main/components/marble-api). +This enables the basic authn/z rules described above through [Magpie](https://github.com/ouranosinc/magpie). + ## Developing To start a development server: diff --git a/marble_api/versions/v1/app.py b/marble_api/versions/v1/app.py index 2d44d95..f973eb8 100644 --- a/marble_api/versions/v1/app.py +++ b/marble_api/versions/v1/app.py @@ -1,7 +1,9 @@ from fastapi import FastAPI -from marble_api.versions.v1.data_request.routes import router as data_request_router +from marble_api.versions.v1.data_request.routes import admin_router as data_request_admin_router +from marble_api.versions.v1.data_request.routes import user_router as data_request_user_router app = FastAPI(version="1") -app.include_router(data_request_router) +app.include_router(data_request_user_router) +app.include_router(data_request_admin_router) diff --git a/marble_api/versions/v1/data_request/models.py b/marble_api/versions/v1/data_request/models.py index 262fbf7..e39bfaf 100644 --- a/marble_api/versions/v1/data_request/models.py +++ b/marble_api/versions/v1/data_request/models.py @@ -9,7 +9,9 @@ ConfigDict, EmailStr, Field, + FieldSerializationInfo, ValidationInfo, + field_serializer, field_validator, ) from pydantic.functional_validators import BeforeValidator @@ -46,7 +48,7 @@ class DataRequest(BaseModel): """ id: SkipJsonSchema[PyObjectId | None] = Field(default=None, validation_alias="_id", exclude=True) - user: str + user: SkipJsonSchema[str | None] = None # user is set by the route after the model is first validated title: str description: str | None = None authors: list[Author] @@ -60,7 +62,7 @@ class DataRequest(BaseModel): extra_properties: dict[str, str] = {} model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - @field_validator("user", "title", "description", "authors", "path", "contact") + @field_validator("title", "description", "authors", "path", "contact") @classmethod def min_length_if_set(cls, value: Sized | None, info: ValidationInfo) -> Sized | None: """Raise an error if the value is not None and is empty.""" @@ -75,6 +77,12 @@ def validate_geometries(cls, value: GeoJSON | None) -> dict | None: validate_collapsible(value) return value + @field_serializer("user") + def require_user_set(self, value: str, info: FieldSerializationInfo) -> str: + """Require that the user_name is set when the model is serialized.""" + assert value, f"{info.field_name} must be set and non-empty" + return value + @partial_model class DataRequestUpdate(DataRequest): @@ -96,6 +104,7 @@ class DataRequestPublic(DataRequest): """ id: Annotated[str, BeforeValidator(str)] = Field(..., validation_alias="_id") + user: str # user is required to be set in the database model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True, extra="allow") @property diff --git a/marble_api/versions/v1/data_request/routes.py b/marble_api/versions/v1/data_request/routes.py index efccf6c..72f60b8 100644 --- a/marble_api/versions/v1/data_request/routes.py +++ b/marble_api/versions/v1/data_request/routes.py @@ -1,8 +1,10 @@ +from collections.abc import AsyncGenerator from typing import Annotated import pymongo from bson import ObjectId -from fastapi import APIRouter, HTTPException, Query, Request, Response, status +from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status +from pydantic_core import PydanticSerializationError from pymongo import ReturnDocument from marble_api.database import client @@ -14,26 +16,52 @@ DataRequestUpdate, ) -router = APIRouter(prefix="/data-requests") + +async def _handle_serialization_error() -> AsyncGenerator[None]: + try: + yield + except PydanticSerializationError as e: + raise HTTPException(status_code=422, detail=str(e)) from e + + +user_router = APIRouter(prefix="/users/{user}/data-requests", tags=["User"]) +admin_router = APIRouter( + prefix="/admin/data-requests", tags=["Admin"], dependencies=[Depends(_handle_serialization_error)] +) def _data_request_id(id_: str) -> ObjectId: return object_id(id_, HTTPException(status_code=404, detail=f"data publish request with id={id_} not found")) -@router.post("/") -async def post_data_request(data_request: DataRequest) -> DataRequestPublic: +def _is_router_scope(request: Request, router: APIRouter) -> bool: + return request.scope.get("route").path.startswith(f"{router.prefix}/") + + +@user_router.post("/") +@admin_router.post("/") +async def post_data_request_user(user: str, data_request: DataRequest) -> DataRequestPublic: """Create a new data request and return the newly created data request.""" + data_request.user = user new_data_request = data_request.model_dump(by_alias=True) result = await client.db["data-request"].insert_one(new_data_request) new_data_request["id"] = str(result.inserted_id) return new_data_request -@router.patch("/{request_id}") -async def patch_data_request(request_id: str, data_request: DataRequestUpdate) -> DataRequestPublic: +@user_router.patch("/{request_id}") +@admin_router.patch("/{request_id}") +async def patch_data_request( + request_id: str, data_request: DataRequestUpdate, request: Request, user: str | None = None +) -> DataRequestPublic: """Update fields of data request and return the updated data request.""" updated_fields = data_request.model_dump(exclude_unset=True, by_alias=True) + updated_user = updated_fields.get("user") + if updated_user and _is_router_scope(request, user_router) and user != updated_user: + # Users cannot change the data request so that it belongs to a different user + raise HTTPException(status_code=403, detail="Forbidden") + if user: + data_request.user = user selector = {"_id": _data_request_id(request_id)} if updated_fields: updated_fields.update(data_request.model_dump(include="stac_item")) @@ -49,10 +77,16 @@ async def patch_data_request(request_id: str, data_request: DataRequestUpdate) - raise HTTPException(status_code=404, detail="data publish request not found") -@router.get("/{request_id}", response_model_by_alias=False) -async def get_data_request(request_id: str, stac: bool = False) -> DataRequestPublic: +@user_router.get("/{request_id}", response_model_by_alias=False) +@admin_router.get("/{request_id}", response_model_by_alias=False) +async def get_data_request( + request_id: str, request: Request, stac: bool = False, user: str | None = None +) -> DataRequestPublic: """Get a data request with the given request_id.""" - if (result := await client.db["data-request"].find_one({"_id": _data_request_id(request_id)})) is not None: + selector = {"_id": _data_request_id(request_id)} + if _is_router_scope(request, user_router): + selector["user"] = user + if (result := await client.db["data-request"].find_one(selector)) is not None: if stac: try: result["stac_item"] = DataRequestPublic(**result).stac_item @@ -63,20 +97,26 @@ async def get_data_request(request_id: str, stac: bool = False) -> DataRequestPu raise HTTPException(status_code=404, detail="data publish request not found") -@router.delete("/{request_id}") -async def delete_data_request(request_id: str) -> Response: +@user_router.delete("/{request_id}") +@admin_router.delete("/{request_id}") +async def delete_data_request(request_id: str, request: Request, user: str | None = None) -> Response: """Delete a data request with the given request_id.""" - result = await client.db["data-request"].delete_one({"_id": _data_request_id(request_id)}) + selector = {"_id": _data_request_id(request_id)} + if _is_router_scope(request, user_router): + selector["user"] = user + result = await client.db["data-request"].delete_one(selector) if result.deleted_count == 1: return Response(status_code=status.HTTP_204_NO_CONTENT) raise HTTPException(status_code=404, detail="data publish request not found") -@router.get("/") +@user_router.get("/") +@admin_router.get("/") async def get_data_requests( request: Request, + user: str | None = None, after: str | None = None, before: str | None = None, limit: Annotated[int, Query(le=100, gt=0)] = 10, @@ -89,21 +129,27 @@ async def get_data_requests( Use the offset and limit parameters to select specific ranges of data requests. """ reverse_it = False + selector = {} + if _is_router_scope(request, user_router): + selector["user"] = user if after: db_request = ( - client.db["data-request"].find({"_id": {"$gte": _data_request_id(after)}}).sort("_id", pymongo.ASCENDING) + client.db["data-request"] + .find({**selector, "_id": {"$gt": _data_request_id(after)}}) + .sort("_id", pymongo.ASCENDING) ) elif before: db_request = ( - client.db["data-request"].find({"_id": {"$lte": _data_request_id(before)}}).sort("_id", pymongo.DESCENDING) + client.db["data-request"] + .find({**selector, "_id": {"$lt": _data_request_id(before)}}) + .sort("_id", pymongo.DESCENDING) ) reverse_it = True # put the eventual result back in ascending order for consistency else: - db_request = client.db["data-request"].find({}).sort("_id", pymongo.ASCENDING) - + db_request = client.db["data-request"].find(selector).sort("_id", pymongo.ASCENDING) data_requests = await db_request.limit(limit + 1).to_list() if reverse_it: - data_requests = reversed(data_requests) + data_requests = list(reversed(data_requests)) query_params = {} @@ -112,14 +158,17 @@ async def get_data_requests( if data_requests: if after: if over_limit: + data_requests.pop() query_params["after"] = data_requests[-1]["_id"] - query_params["before"] = data_requests.pop(0)["_id"] + query_params["before"] = data_requests[0]["_id"] elif before: if over_limit: + data_requests.pop(0) query_params["before"] = data_requests[0]["_id"] - query_params["after"] = data_requests.pop()["_id"] + query_params["after"] = data_requests[-1]["_id"] elif over_limit: - query_params["after"] = data_requests.pop()["_id"] + data_requests.pop() + query_params["after"] = data_requests[-1]["_id"] links = [] diff --git a/test/integration/versions/v1/data_request/test_routes.py b/test/integration/versions/v1/data_request/test_routes.py index cf15ec9..d11b212 100644 --- a/test/integration/versions/v1/data_request/test_routes.py +++ b/test/integration/versions/v1/data_request/test_routes.py @@ -13,13 +13,35 @@ pytestmark = pytest.mark.anyio +class _TestUser: + @pytest.fixture + def member_route(self, data_requests): + return f"/v1/users/{data_requests[0]['user']}/data-requests/{data_requests[0]['_id']}" + + @pytest.fixture + def collection_route(self, data_requests): + return f"/v1/users/{data_requests[0]['user']}/data-requests/" + + +class _TestAdmin: + @pytest.fixture + def member_route(self, data_requests): + return f"/v1/admin/data-requests/{data_requests[0]['_id']}" + + @pytest.fixture + def collection_route(self): + return "/v1/admin/data-requests/" + + class _TestGet: n_data_requests = 2 @pytest.fixture(scope="class", autouse=True) @classmethod async def load_data(cls, fake): - data = [fake.data_request().model_dump() for _ in range(cls.n_data_requests)] + data = [fake.data_request(user="user1").model_dump() for _ in range(cls.n_data_requests // 2)] + [ + fake.data_request(user="user2").model_dump() for _ in range(cls.n_data_requests - cls.n_data_requests // 2) + ] await client.db.get_collection("data-request").insert_many(data) @pytest.fixture(scope="class", autouse=True) @@ -36,63 +58,75 @@ async def data_requests(cls): yield await client.db.get_collection("data-request").find({}).to_list() -@pytest.mark.no_db_cleanup -class TestGetOne(_TestGet): - async def test_get(self, async_client, data_requests): - resp = await async_client.get(f"/v1/data-requests/{data_requests[0]['_id']}") +class _TestGetOne(_TestGet): + async def test_get(self, async_client, data_requests, member_route): + resp = await async_client.get(member_route) assert resp.status_code == 200 assert DataRequestPublic(**data_requests[0]) == DataRequestPublic(**resp.json()) - async def test_get_stac(self, async_client, data_requests): - resp = await async_client.get(f"/v1/data-requests/{data_requests[0]['_id']}?stac=true") + async def test_get_stac(self, async_client, member_route): + resp = await async_client.get(f"{member_route}?stac=true") assert resp.status_code == 200 assert (item := resp.json().get("stac_item")) Item(**item) - async def test_bad_id(self, async_client): - resp = await async_client.get("/v1/data-requests/id-does-not-exist") + async def test_bad_id(self, async_client, member_route): + invalid_route = "/".join(member_route.split("/")[:-1] + ["some-bad-id"]) + resp = await async_client.get(invalid_route) + assert resp.status_code == 404 + + +@pytest.mark.no_db_cleanup +class TestGetOneUser(_TestGetOne, _TestUser): + async def test_bad_user(self, async_client, data_requests): + invalid_route = f"/v1/users/{data_requests[0]['user'] + '-bad-user'}/data-requests/{data_requests[0]['_id']}" + resp = await async_client.get(invalid_route) assert resp.status_code == 404 @pytest.mark.no_db_cleanup -class TestGetMany(_TestGet): +class TestGetOneAdmin(_TestGetOne, _TestAdmin): ... + + +class _TestGetMany(_TestGet): default_link_limit = inspect.signature(get_data_requests).parameters["limit"].default - n_data_requests = default_link_limit + 2 + n_data_requests = default_link_limit * 2 + 2 + n_data_requests_return_count: int - async def test_get(self, async_client, data_requests): - response = await async_client.get("/v1/data-requests/") + async def test_get(self, async_client, data_requests, collection_route): + response = await async_client.get(collection_route) models = {str(req["_id"]): DataRequestPublic(**req) for req in data_requests} for req in response.json()["data_requests"]: assert DataRequestPublic(**req) == models[req["id"]] - async def test_get_stac(self, async_client): - resp = await async_client.get("/v1/data-requests/?stac=true") + async def test_get_stac(self, async_client, collection_route): + resp = await async_client.get(f"{collection_route}?stac=true") for req in resp.json()["data_requests"]: assert (item := req.get("stac_item")) Item(**item) - async def test_get_limit_default(self, async_client): - response = await async_client.get("/v1/data-requests/") + async def test_get_limit_default(self, async_client, collection_route): + response = await async_client.get(collection_route) assert len(response.json()["data_requests"]) == self.default_link_limit - async def test_get_limit_non_default(self, async_client): - response = await async_client.get("/v1/data-requests/?limit=5") + async def test_get_limit_non_default(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit=5") assert len(response.json()["data_requests"]) == 5 - async def test_get_limit_more(self, async_client): - response = await async_client.get(f"/v1/data-requests/?limit={self.n_data_requests + 1}") - assert len(response.json()["data_requests"]) == self.n_data_requests + async def test_get_limit_more(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit={self.n_data_requests + 1}") + assert len(response.json()["data_requests"]) == self.n_data_requests_return_count - async def test_get_limit_none(self, async_client): - response = await async_client.get("/v1/data-requests/?limit=0") + async def test_get_limit_none(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit=0") assert response.status_code == 422 - async def test_get_limit_over_max(self, async_client): - response = await async_client.get("/v1/data-requests/?limit=200") + async def test_get_limit_over_max(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit=200") assert response.status_code == 422 - async def test_get_first_page_links(self, async_client): - response = await async_client.get("/v1/data-requests/") + async def test_get_first_page_links(self, async_client, collection_route): + response = await async_client.get(collection_route) links = response.json()["links"] assert len(links) == 1 link = links[0] @@ -102,8 +136,8 @@ async def test_get_first_page_links(self, async_client): assert (after_id := parse_qs(urlparse(link["href"]).query).get("after")) assert after_id not in [r["id"] for r in response.json()["data_requests"]] - async def test_get_last_page_links(self, async_client): - response = await async_client.get("/v1/data-requests/") + async def test_get_last_page_links(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit={self.n_data_requests_return_count - 3}") next_link = next(link for link in response.json()["links"] if link["rel"] == "next") response2 = await async_client.get(next_link["href"]) links = response2.json()["links"] @@ -115,8 +149,8 @@ async def test_get_last_page_links(self, async_client): assert (before_id := parse_qs(urlparse(link["href"]).query).get("before")) assert before_id not in [r["id"] for r in response.json()["data_requests"]] - async def test_get_mid_page_links(self, async_client): - response = await async_client.get("/v1/data-requests/?limit=4") + async def test_get_mid_page_links(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit=4") next_link = next(link for link in response.json()["links"] if link["rel"] == "next") response2 = await async_client.get(next_link["href"]) links = response2.json()["links"] @@ -132,31 +166,96 @@ async def test_get_mid_page_links(self, async_client): assert (id_ := parse_qs(urlparse(link["href"]).query).get("after")) assert id_ not in [r["id"] for r in response.json()["data_requests"]] + async def test_next_prev_is_consistent(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit=4") + # page0 -> page1 + next_link = next(link for link in response.json()["links"] if link["rel"] == "next") + next_response = await async_client.get(next_link["href"]) + # page0 -> page1 -> page2 + next_next_link = next(link for link in next_response.json()["links"] if link["rel"] == "next") + next_next_response = await async_client.get(next_next_link["href"]) + # page0 -> page1 -> page0 + next_prev_link = next(link for link in next_response.json()["links"] if link["rel"] == "prev") + next_prev_response = await async_client.get(next_prev_link["href"]) + # page0 -> page1 -> page2 -> page1 + next_next_prev_link = next(link for link in next_next_response.json()["links"] if link["rel"] == "prev") + next_next_prev_response = await async_client.get(next_next_prev_link["href"]) + # page0 -> page1 -> page2 -> page1 -> page0 + next_next_prev_prev_link = next( + link for link in next_next_prev_response.json()["links"] if link["rel"] == "prev" + ) + next_next_prev_prev_response = await async_client.get(next_next_prev_prev_link["href"]) + assert response.json() == next_prev_response.json() == next_next_prev_prev_response.json() + assert next_response.json() == next_next_prev_response.json() + + async def test_get_all_same_as_paging_next(self, async_client, collection_route): + all_response = await async_client.get(f"{collection_route}?limit={self.n_data_requests}") + next_link = [f"{collection_route}?limit=4"] + data_requests = [] + while next_link: + response = await async_client.get(next_link[0]) + data_requests.extend(response.json()["data_requests"]) + next_link = [link["href"] for link in response.json()["links"] if link["rel"] == "next"] + assert all_response.json()["data_requests"] == data_requests + + +@pytest.mark.no_db_cleanup +class TestGetManyUser(_TestGetMany, _TestUser): + n_data_requests_return_count = _TestGetMany.n_data_requests // 2 + + async def test_get_all(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit={self.n_data_requests}") + assert len(response.json()["data_requests"]) == self.n_data_requests // 2 + + +@pytest.mark.no_db_cleanup +class TestGetManyAdmin(_TestGetMany, _TestAdmin): + n_data_requests_return_count = _TestGetMany.n_data_requests + + async def test_get_all(self, async_client, collection_route): + response = await async_client.get(f"{collection_route}?limit={self.n_data_requests}") + assert len(response.json()["data_requests"]) == self.n_data_requests + + +class _TestPost: + @pytest.fixture + def data_requests(self): + return [{"user": "user1"}] -class TestPost: - async def test_valid(self, fake, async_client): - data = fake.data_request().model_dump_json() - response = await async_client.post("/v1/data-requests/", json=json.loads(data)) + async def test_valid(self, fake, async_client, collection_route, data_requests): + data = fake.data_request().model_dump_json(exclude=["user"]) + response = await async_client.post(collection_route, json=json.loads(data)) assert response.status_code == 200 assert (id_ := response.json().get("id")) bson.ObjectId(id_) # check that the id is a valid object id - assert json.loads(data) == json.loads(DataRequest(**response.json()).model_dump_json()) + assert {"user": data_requests[0]["user"], **json.loads(data)} == json.loads( + DataRequest(**response.json()).model_dump_json() + ) - async def test_invalid_authors(self, fake, async_client): + async def test_invalid_authors(self, fake, async_client, collection_route): data = json.loads(fake.data_request().model_dump_json()) data["authors"] = [] - response = await async_client.post("/v1/data-requests/", json=data) + response = await async_client.post(collection_route, json=data) assert response.status_code == 422 - async def test_invalid_uncollapsible_geometry(self, fake, async_client): + async def test_invalid_uncollapsible_geometry(self, fake, async_client, collection_route): data = { **json.loads(fake.data_request().model_dump_json()), "geometry": json.loads(fake.uncollapsible_geojson().model_dump_json()), } - response = await async_client.post("/v1/data-requests/", json=data) + response = await async_client.post(collection_route, json=data) assert response.status_code == 422 +class TestPostUser(_TestPost, _TestUser): ... + + +class TestPostAdmin(_TestPost, _TestAdmin): + @pytest.fixture + def collection_route(self, data_requests): + return f"/v1/admin/data-requests/?user={data_requests[0]['user']}" + + class _TestUpdate: @pytest.fixture async def loaded_data(self, fake): @@ -166,72 +265,110 @@ async def loaded_data(self, fake): model["id"] = str(resp.inserted_id) return model + @pytest.fixture + async def data_requests(self, loaded_data): + return [{"_id": loaded_data["id"], **loaded_data}] + -class TestPatch(_TestUpdate): - async def test_valid(self, loaded_data, async_client, fake): +class _TestPatch(_TestUpdate): + async def test_valid(self, loaded_data, async_client, fake, member_route): title = fake.sentence() update = {"title": title} - response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update) + response = await async_client.patch(member_route, json=update) assert response.status_code == 200 loaded_data.update(update) assert loaded_data == response.json() - async def test_valid_multiple(self, loaded_data, async_client, fake): + async def test_valid_multiple(self, loaded_data, async_client, fake, member_route): title = fake.sentence() authors = [fake.author(), fake.author()] update = {"title": title, "authors": authors} - response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update) + response = await async_client.patch(member_route, json=update) assert response.status_code == 200 loaded_data.update(update) assert loaded_data == response.json() - async def test_update_nothing(self, loaded_data, async_client): - response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json={}) + async def test_update_nothing(self, loaded_data, async_client, member_route): + response = await async_client.patch(member_route, json={}) assert response.status_code == 200 assert loaded_data == response.json() - async def test_update_everything(self, loaded_data, async_client, fake): - update = json.loads(fake.data_request().model_dump_json()) - response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update) - assert response.status_code == 200 - update["id"] = loaded_data["id"] - assert update == response.json() - - async def test_no_id_update(self, loaded_data, async_client): + async def test_no_id_update(self, loaded_data, async_client, member_route): update = {"id": str(bson.ObjectId())} - response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json=update) + response = await async_client.patch(member_route, json=update) assert response.status_code == 200 assert response.json()["id"] == loaded_data["id"] assert response.json()["id"] != update["id"] assert loaded_data == response.json() - async def test_invalid_unset_value(self, loaded_data, async_client): - response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json={"title": None}) + async def test_invalid_unset_value(self, async_client, member_route): + response = await async_client.patch(member_route, json={"title": None}) assert response.status_code == 422 - async def test_invalid_bad_type(self, loaded_data, async_client): - response = await async_client.patch(f"/v1/data-requests/{loaded_data['id']}", json={"title": 10}) + async def test_invalid_bad_type(self, async_client, member_route): + response = await async_client.patch(member_route, json={"title": 10}) assert response.status_code == 422 - async def test_invalid_uncollapsible_geometry(self, fake, loaded_data, async_client): + async def test_invalid_uncollapsible_geometry(self, fake, async_client, member_route): response = await async_client.patch( - f"/v1/data-requests/{loaded_data['id']}", + member_route, json={"geometry": json.loads(fake.uncollapsible_geojson().model_dump_json())}, ) assert response.status_code == 422 - async def test_bad_id(self, async_client): - resp = await async_client.patch("/v1/data-requests/id-does-not-exist", json={}) + async def test_bad_id(self, async_client, collection_route): + resp = await async_client.patch(f"{collection_route}/id-does-not-exist", json={}) assert resp.status_code == 404, resp.json() -class TestDelete(_TestUpdate): - async def test_exists(self, loaded_data, async_client): - response = await async_client.delete(f"/v1/data-requests/{loaded_data['id']}") +class TestPatchUser(_TestPatch, _TestUser): + async def test_update_everything(self, loaded_data, async_client, fake, member_route): + update = json.loads(fake.data_request().model_dump_json(exclude=["user"])) + response = await async_client.patch(member_route, json=update) + assert response.status_code == 200 + update["id"] = loaded_data["id"] + update["user"] = loaded_data["user"] + assert update == response.json() + + async def test_no_update_user(self, loaded_data, async_client, member_route): + new_user = loaded_data["user"] + "suffix" + response = await async_client.patch(member_route, json={"user": new_user}) + assert response.status_code == 403 + + +class TestPatchAdmin(_TestPatch, _TestAdmin): + async def test_update_everything(self, loaded_data, async_client, fake, member_route): + update = json.loads(fake.data_request().model_dump_json()) + response = await async_client.patch(member_route, json=update) + assert response.status_code == 200 + update["id"] = loaded_data["id"] + assert update == response.json() + + async def test_update_user(self, loaded_data, async_client, member_route): + new_user = loaded_data["user"] + "suffix" + response = await async_client.patch(member_route, json={"user": new_user}) + assert response.status_code == 200 + assert response.json()["user"] == new_user + + +class _TestDelete(_TestUpdate): + async def test_exists(self, loaded_data, async_client, member_route): + response = await async_client.delete(member_route) assert response.status_code == 204 resp = await client.db.get_collection("data-request").find_one({"_id": bson.ObjectId(loaded_data["id"])}) assert resp is None - async def test_bad_id(self, async_client): - resp = await async_client.delete("/v1/data-requests/id-does-not-exist") - assert resp.status_code == 404, resp.json() + async def test_bad_id(self, async_client, member_route): + route = "/" + "/".join(member_route.strip("/").split("/")) + "bad-id-suffix" + resp = await async_client.delete(route) + assert resp.status_code == 404 + + +class TestDeleteUser(_TestDelete, _TestUser): + async def test_bad_user(self, loaded_data, async_client, member_route): + route = f"/v1/users/someotheruser/data-requests/{loaded_data['id']}" + response = await async_client.delete(route) + assert response.status_code == 404 + + +class TestDeleteAdmin(_TestDelete, _TestAdmin): ... diff --git a/test/unit/versions/v1/data_request/test_models.py b/test/unit/versions/v1/data_request/test_models.py index 20fd6ed..065fc05 100644 --- a/test/unit/versions/v1/data_request/test_models.py +++ b/test/unit/versions/v1/data_request/test_models.py @@ -2,6 +2,7 @@ import pytest from pydantic import TypeAdapter, ValidationError +from pydantic_core import PydanticSerializationError from pystac import Item from marble_api.utils.geojson import collapse_geometries @@ -33,7 +34,7 @@ def fake_class(self, fake): def test_id_dumped(self, fake_class): assert "id" not in fake_class().model_dump() - @pytest.mark.parametrize("field", ["user", "title", "description", "authors", "path", "contact"]) + @pytest.mark.parametrize("field", ["title", "description", "authors", "path", "contact"]) def test_text_fields_not_empty(self, fake_class, field): with pytest.raises(ValidationError): fake_class(**{field: ""}) @@ -41,7 +42,6 @@ def test_text_fields_not_empty(self, fake_class, field): @pytest.mark.parametrize( "field", [ - "user", "title", "authors", "temporal", @@ -57,6 +57,13 @@ def test_fields_not_nullable(self, fake_class, field): with pytest.raises(ValidationError): fake_class(**{field: None}) + @pytest.mark.parametrize("value", [None, ""]) + def test_user_field_present_when_serialized(self, fake_class, value): + model = fake_class() + model.user = value + with pytest.raises(PydanticSerializationError): + model.model_dump() + @pytest.mark.parametrize( "field", [