diff --git a/datacommons_client/endpoints/payloads.py b/datacommons_client/endpoints/payloads.py index 75807455..af66334b 100644 --- a/datacommons_client/endpoints/payloads.py +++ b/datacommons_client/endpoints/payloads.py @@ -133,4 +133,7 @@ class ResolveRequestPayload(BaseDCModel): """ node_dcids: ListOrStr = Field(..., serialization_alias="nodes") - expression: str | list[str] = Field(..., serialization_alias="property") + expression: str | list[str] | None = Field(default=None, + serialization_alias="property") + resolver: str | None = Field(default=None, serialization_alias="resolver") + target: str | None = Field(default=None, serialization_alias="target") diff --git a/datacommons_client/endpoints/resolve.py b/datacommons_client/endpoints/resolve.py index 0b91633b..ca7c87c4 100644 --- a/datacommons_client/endpoints/resolve.py +++ b/datacommons_client/endpoints/resolve.py @@ -37,8 +37,11 @@ def __init__(self, api: API): """Initializes the ResolveEndpoint instance.""" super().__init__(endpoint="resolve", api=api) - def fetch(self, node_ids: str | list[str], - expression: str | list[str]) -> ResolveResponse: + def fetch(self, + node_ids: str | list[str], + expression: str | list[str] | None = None, + resolver: str | None = None, + target: str | None = None) -> ResolveResponse: """ Fetches resolved data for the given nodes and expressions, identified by name, coordinates, or wiki ID. @@ -46,6 +49,8 @@ def fetch(self, node_ids: str | list[str], Args: node_ids (str | list[str]): One or more node IDs to resolve. expression (str): The relation expression to query. + resolver (str | None): The resolver type to use (e.g., "indicator"). + target (str | None): The resolution target (e.g., "custom_only"). Returns: ResolveResponse: The response object containing the resolved data. @@ -56,11 +61,28 @@ def fetch(self, node_ids: str | list[str], # Construct the payload payload = ResolveRequestPayload(node_dcids=node_ids, - expression=expression).to_dict() + expression=expression, + resolver=resolver, + target=target).to_dict() # Send the request and return the response return ResolveResponse.model_validate(self.post(payload)) + def fetch_indicators(self, + queries: str | list[str], + target: str | None = None) -> ResolveResponse: + """ + Fetches resolved indicators (StatisticalVariables or Topics) for the given queries. + + Args: + queries (str | list[str]): One or more queries (e.g. "population", "gdp"). + target (str | None): Optional target for resolution (e.g., "base_only", "custom_only", "base_and_custom"). + + Returns: + ResolveResponse: The response object containing the resolved indicators. + """ + return self.fetch(node_ids=queries, resolver="indicator", target=target) + def fetch_dcids_by_name(self, names: str | list[str], entity_type: Optional[str] = None) -> ResolveResponse: diff --git a/datacommons_client/models/resolve.py b/datacommons_client/models/resolve.py index ad69c1e2..919a9e36 100644 --- a/datacommons_client/models/resolve.py +++ b/datacommons_client/models/resolve.py @@ -20,6 +20,8 @@ class Candidate(BaseDCModel): dcid: NodeDCID = Field(default_factory=str) dominantType: Optional[DominantType] = None + metadata: dict[str, str] | None = None + typeOf: list[str] | None = None class Entity(BaseDCModel): diff --git a/datacommons_client/tests/endpoints/test_resolve_endpoint.py b/datacommons_client/tests/endpoints/test_resolve_endpoint.py index 74c06efe..ea337bbb 100644 --- a/datacommons_client/tests/endpoints/test_resolve_endpoint.py +++ b/datacommons_client/tests/endpoints/test_resolve_endpoint.py @@ -155,3 +155,73 @@ def test_flatten_resolve_response(): # Assertions assert result == expected + + +def test_fetch_indicators_calls_endpoints_correctly(): + """Tests the fetch_indicators method.""" + api_mock = MagicMock() + # Mock response data structure + mock_response_data = { + "entities": [{ + "node": + "population", + "candidates": [{ + "dcid": "Count_Person", + "dominantType": "StatisticalVariable", + "metadata": { + "score": "0.9", + "sentence": "population count" + }, + "typeOf": ["StatisticalVariable"] + }] + }] + } + api_mock.post = MagicMock(return_value=mock_response_data) + endpoint = ResolveEndpoint(api=api_mock) + + # Call the method + response = endpoint.fetch_indicators(queries=["population"], + target="custom_only") + + # Verify post was called with correct payload + api_mock.post.assert_called_once_with(payload={ + "nodes": ["population"], + "resolver": "indicator", + "target": "custom_only" + }, + endpoint="resolve", + all_pages=True, + next_token=None) + + # Verify response parsing + expected = ResolveResponse(entities=[ + Entity(node="population", + candidates=[ + Candidate(dcid="Count_Person", + dominantType="StatisticalVariable", + metadata={ + "score": "0.9", + "sentence": "population count" + }, + typeOf=["StatisticalVariable"]) + ]) + ]) + assert response == expected + + +def test_fetch_still_works_with_expression(): + """Tests that fetch still works with expression (regression test).""" + api_mock = MagicMock() + mock_response_data = {"entities": []} + api_mock.post = MagicMock(return_value=mock_response_data) + endpoint = ResolveEndpoint(api=api_mock) + + endpoint.fetch(node_ids=["geoId/06"], expression="<-containedInPlace") + + api_mock.post.assert_called_once_with(payload={ + "nodes": ["geoId/06"], + "property": "<-containedInPlace" + }, + endpoint="resolve", + all_pages=True, + next_token=None) diff --git a/datacommons_client/tests/endpoints/test_response.py b/datacommons_client/tests/endpoints/test_response.py index a8bfd6eb..9faa1d88 100644 --- a/datacommons_client/tests/endpoints/test_response.py +++ b/datacommons_client/tests/endpoints/test_response.py @@ -856,19 +856,26 @@ def test_resolve_response_dict(): "candidates": [ { "dcid": "dcid1", - "dominantType": "Type1" + "dominantType": "Type1", + "metadata": None, + "typeOf": None, }, { "dcid": "dcid2", - "dominantType": None + "dominantType": None, + "metadata": None, + "typeOf": None, }, ], }, { - "node": "entity2", + "node": + "entity2", "candidates": [{ "dcid": "dcid3", - "dominantType": "Type2" + "dominantType": "Type2", + "metadata": None, + "typeOf": None, },], }, ] @@ -968,19 +975,26 @@ def test_resolve_response_json_string_exclude_none(): "candidates": [ { "dcid": "dcid1", - "dominantType": "Type1" + "dominantType": "Type1", + "metadata": None, + "typeOf": None, }, { "dcid": "dcid2", - "dominantType": None + "dominantType": None, + "metadata": None, + "typeOf": None, }, ], }, { - "node": "entity2", + "node": + "entity2", "candidates": [{ "dcid": "dcid3", - "dominantType": "Type2" + "dominantType": "Type2", + "metadata": None, + "typeOf": None, },], }, {