Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion datacommons_client/endpoints/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,7 @@ class ResolveRequestPayload(BaseDCModel):
"""

node_dcids: ListOrStr = Field(..., serialization_alias="nodes")
expression: str | list[str] = Field(..., serialization_alias="property")
expression: str | list[str] | None = Field(default=None,
serialization_alias="property")
resolver: str | None = Field(default=None, serialization_alias="resolver")
target: str | None = Field(default=None, serialization_alias="target")
Comment on lines +136 to +139
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

To ensure that ResolveRequestPayload is always used correctly, it's a good practice to add validation directly to the Pydantic model. This will prevent invalid combinations of expression, resolver, and target regardless of how the payload is created.

You can add a model_validator to enforce that:

  1. expression and resolver are mutually exclusive.
  2. target is only provided when resolver is also present.

Here is a suggested implementation to add to the ResolveRequestPayload class:

  @model_validator(mode="after")
  def _validate_resolver_args(self):
    if self.expression is not None and self.resolver is not None:
      raise ValueError(
          "`expression` and `resolver` are mutually exclusive and cannot be used together."
      )
    if self.target is not None and self.resolver is None:
      raise ValueError("`target` can only be used with `resolver`.")
    return self

28 changes: 25 additions & 3 deletions datacommons_client/endpoints/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,20 @@ def __init__(self, api: API):
"""Initializes the ResolveEndpoint instance."""
super().__init__(endpoint="resolve", api=api)

def fetch(self, node_ids: str | list[str],
expression: str | list[str]) -> ResolveResponse:
def fetch(self,
node_ids: str | list[str],
expression: str | list[str] | None = None,
resolver: str | None = None,
target: str | None = None) -> ResolveResponse:
"""
Fetches resolved data for the given nodes and expressions, identified by name,
coordinates, or wiki ID.

Args:
node_ids (str | list[str]): One or more node IDs to resolve.
expression (str): The relation expression to query.
resolver (str | None): The resolver type to use (e.g., "indicator").
target (str | None): The resolution target (e.g., "custom_only").

Returns:
ResolveResponse: The response object containing the resolved data.
Expand All @@ -56,11 +61,28 @@ def fetch(self, node_ids: str | list[str],

# Construct the payload
payload = ResolveRequestPayload(node_dcids=node_ids,
expression=expression).to_dict()
expression=expression,
resolver=resolver,
target=target).to_dict()

# Send the request and return the response
return ResolveResponse.model_validate(self.post(payload))

def fetch_indicators(self,
queries: str | list[str],
target: str | None = None) -> ResolveResponse:
"""
Fetches resolved indicators (StatisticalVariables or Topics) for the given queries.

Args:
queries (str | list[str]): One or more queries (e.g. "population", "gdp").
target (str | None): Optional target for resolution (e.g., "base_only", "custom_only", "base_and_custom").

Returns:
ResolveResponse: The response object containing the resolved indicators.
"""
return self.fetch(node_ids=queries, resolver="indicator", target=target)

def fetch_dcids_by_name(self,
names: str | list[str],
entity_type: Optional[str] = None) -> ResolveResponse:
Expand Down
2 changes: 2 additions & 0 deletions datacommons_client/models/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class Candidate(BaseDCModel):

dcid: NodeDCID = Field(default_factory=str)
dominantType: Optional[DominantType] = None
metadata: dict[str, str] | None = None
typeOf: list[str] | None = None


class Entity(BaseDCModel):
Expand Down
70 changes: 70 additions & 0 deletions datacommons_client/tests/endpoints/test_resolve_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,73 @@ def test_flatten_resolve_response():

# Assertions
assert result == expected


def test_fetch_indicators_calls_endpoints_correctly():
"""Tests the fetch_indicators method."""
api_mock = MagicMock()
# Mock response data structure
mock_response_data = {
"entities": [{
"node":
"population",
"candidates": [{
"dcid": "Count_Person",
"dominantType": "StatisticalVariable",
"metadata": {
"score": "0.9",
"sentence": "population count"
},
"typeOf": ["StatisticalVariable"]
}]
}]
}
api_mock.post = MagicMock(return_value=mock_response_data)
endpoint = ResolveEndpoint(api=api_mock)

# Call the method
response = endpoint.fetch_indicators(queries=["population"],
target="custom_only")

# Verify post was called with correct payload
api_mock.post.assert_called_once_with(payload={
"nodes": ["population"],
"resolver": "indicator",
"target": "custom_only"
},
endpoint="resolve",
all_pages=True,
next_token=None)

# Verify response parsing
expected = ResolveResponse(entities=[
Entity(node="population",
candidates=[
Candidate(dcid="Count_Person",
dominantType="StatisticalVariable",
metadata={
"score": "0.9",
"sentence": "population count"
},
typeOf=["StatisticalVariable"])
])
])
assert response == expected


def test_fetch_still_works_with_expression():
"""Tests that fetch still works with expression (regression test)."""
api_mock = MagicMock()
mock_response_data = {"entities": []}
api_mock.post = MagicMock(return_value=mock_response_data)
endpoint = ResolveEndpoint(api=api_mock)

endpoint.fetch(node_ids=["geoId/06"], expression="<-containedInPlace")

api_mock.post.assert_called_once_with(payload={
"nodes": ["geoId/06"],
"property": "<-containedInPlace"
},
endpoint="resolve",
all_pages=True,
next_token=None)
30 changes: 22 additions & 8 deletions datacommons_client/tests/endpoints/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,19 +856,26 @@ def test_resolve_response_dict():
"candidates": [
{
"dcid": "dcid1",
"dominantType": "Type1"
"dominantType": "Type1",
"metadata": None,
"typeOf": None,
},
{
"dcid": "dcid2",
"dominantType": None
"dominantType": None,
"metadata": None,
"typeOf": None,
},
],
},
{
"node": "entity2",
"node":
"entity2",
"candidates": [{
"dcid": "dcid3",
"dominantType": "Type2"
"dominantType": "Type2",
"metadata": None,
"typeOf": None,
},],
},
]
Expand Down Expand Up @@ -968,19 +975,26 @@ def test_resolve_response_json_string_exclude_none():
"candidates": [
{
"dcid": "dcid1",
"dominantType": "Type1"
"dominantType": "Type1",
"metadata": None,
"typeOf": None,
},
{
"dcid": "dcid2",
"dominantType": None
"dominantType": None,
"metadata": None,
"typeOf": None,
},
],
},
{
"node": "entity2",
"node":
"entity2",
"candidates": [{
"dcid": "dcid3",
"dominantType": "Type2"
"dominantType": "Type2",
"metadata": None,
"typeOf": None,
},],
},
{
Expand Down