Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 61 additions & 10 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,13 +1145,54 @@ def assert_type(self, python_type: Type, python_val: T):
return base_transformer.assert_type(base_type, python_val)


def _resolve_json_schema_ref(property_val: dict, schema: typing.Optional[dict] = None) -> dict:
"""
Resolve a JSON schema $ref reference to its actual definition.

Args:
property_val: The property value that may contain a $ref
schema: The full schema containing the definitions

Returns:
The resolved property value (dereferenced if it was a $ref)

Raises:
ValueError: If the $ref path is invalid or cannot be resolved
"""
if "$ref" in property_val and schema is not None:
# Handle references like "#/$defs/ModelName" or "#/definitions/ModelName"
ref_path = property_val["$ref"]
if ref_path.startswith("#/"):
path_parts = ref_path[2:].split("/")
resolved = schema
try:
for part in path_parts:
resolved = resolved[part]
return resolved
except (KeyError, TypeError) as e:
raise ValueError(
f"Failed to resolve JSON schema reference '{ref_path}': {e}. "
f"Make sure the referenced definition exists in the schema."
) from e
return property_val


def _handle_json_schema_property(
property_key: str,
property_val: dict,
schema: typing.Optional[dict] = None,
) -> typing.Tuple[str, typing.Any]:
"""
A helper to handle the properties of a JSON schema and returns their equivalent Flyte attribute name and type.

Args:
property_key: The name of the property
property_val: The property schema definition
schema: The full schema containing definitions (needed for resolving $ref)
"""

# Resolve $ref before processing
property_val = _resolve_json_schema_ref(property_val, schema)

# Handle Optional[T] or Union[T1, T2, ...] at the top level for proper recursion
if property_val.get("anyOf"):
Expand All @@ -1166,7 +1207,7 @@ def _handle_json_schema_property(
)
attr_types = []
for item in property_val["anyOf"]:
_, attr_type = _handle_json_schema_property(property_key, item)
_, attr_type = _handle_json_schema_property(property_key, item, schema)
attr_types.append(attr_type)

# Gather all the types and return a Union[T1, T2, ...]
Expand All @@ -1181,7 +1222,7 @@ def _handle_json_schema_property(

# Handle list
if property_type == "array":
return (property_key, typing.List[_get_element_type(property_val["items"])]) # type: ignore
return (property_key, typing.List[_get_element_type(property_val["items"], schema)]) # type: ignore
# Handle null types (i.e. None)
elif property_type == "null":
return (property_key, type(None)) # type: ignore
Expand All @@ -1191,7 +1232,7 @@ def _handle_json_schema_property(
# those are handled in the top level of the function with recursion.
if property_val.get("additionalProperties"):
# For typing.Dict type
elem_type = _get_element_type(property_val["additionalProperties"])
elem_type = _get_element_type(property_val["additionalProperties"], schema)
return (property_key, typing.Dict[str, elem_type]) # type: ignore
elif property_val.get("title"):
# For nested dataclass
Expand All @@ -1207,7 +1248,7 @@ def _handle_json_schema_property(
return (property_key, str) # type: ignore
# Handle None, int, float, bool or str
else:
return (property_key, _get_element_type(property_val)) # type: ignore
return (property_key, _get_element_type(property_val, schema)) # type: ignore


def generate_attribute_list_from_dataclass_json_mixin(
Expand All @@ -1216,7 +1257,7 @@ def generate_attribute_list_from_dataclass_json_mixin(
):
attribute_list: typing.List[typing.Tuple[Any, Any]] = []
for property_key, property_val in schema["properties"].items():
attribute_list.append(_handle_json_schema_property(property_key, property_val))
attribute_list.append(_handle_json_schema_property(property_key, property_val, schema))
return attribute_list


Expand Down Expand Up @@ -2452,7 +2493,7 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin
property_type = property_val["type"]
# Handle list
if property_val["type"] == "array":
attribute_list.append((property_key, List[_get_element_type(property_val["items"])])) # type: ignore[misc,index]
attribute_list.append((property_key, List[_get_element_type(property_val["items"], None)])) # type: ignore[misc,index]
# Handle dataclass and dict
elif property_type == "object":
# For nested dataclass
Expand All @@ -2472,11 +2513,11 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin
else:
# typed dict
attribute_list.append(
(property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index]
(property_key, Dict[str, _get_element_type(property_val["additionalProperties"], None)]) # type: ignore[misc,index]
)
# Handle primitive types like int, float, bool or str
else:
attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore
attribute_list.append([property_key, _get_element_type(property_val, None)]) # type: ignore
return attribute_list


Expand All @@ -2502,7 +2543,10 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ
return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))


def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
def _get_element_type(element_property: typing.Dict[str, typing.Any], schema: typing.Optional[dict] = None) -> Type:
# Resolve $ref before processing
element_property = _resolve_json_schema_ref(element_property, schema)

element_type = (
[e_property["type"] for e_property in element_property["anyOf"]] # type: ignore
if element_property.get("anyOf")
Expand All @@ -2512,7 +2556,7 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:

if isinstance(element_type, list):
# Element type of Optional[int] is [integer, None]
return typing.Optional[_get_element_type({"type": element_type[0]})] # type: ignore
return typing.Optional[_get_element_type({"type": element_type[0]}, schema)] # type: ignore

if element_type == "string":
return str
Expand All @@ -2525,6 +2569,13 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
return int
else:
return float
elif element_type == "object":
# Handle nested dataclass objects
if element_property.get("title"):
sub_schema_name = element_property["title"]
return typing.cast(Type, convert_mashumaro_json_schema_to_python_class(element_property, sub_schema_name))
# For untyped dict or other object types
return dict # type: ignore
return str


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1022,3 +1022,69 @@ def mock_resolve_remote_path(flyte_uri: str):

bm_revived = TypeEngine.to_python_value(ctx, lit, BM)
assert bm_revived.s.literal.uri == "/my/replaced/val"


def test_nested_pydantic_model_with_list():
    """
    Round-trip a Pydantic model that embeds another model directly, inside a list,
    and inside an optional list.

    Exercises the handling of JSON schema $ref references produced by nested models.
    """

    class NestedModel(BaseModel):
        name: str
        value: int

    class ParentModel(BaseModel):
        id: str
        # Direct nested model (generates $ref)
        nested: NestedModel
        # List of nested models (generates array with $ref in items)
        nested_list: List[NestedModel]
        # Optional list with $ref
        optional_nested_list: Optional[List[NestedModel]] = None

    # Build the fixture from (name, value) pairs so the expectations can be reused below.
    list_items = [("item1", 1), ("item2", 2)]
    optional_items = [("opt1", 10), ("opt2", 20)]
    parent = ParentModel(
        id="test-id",
        nested=NestedModel(name="direct", value=42),
        nested_list=[NestedModel(name=item_name, value=item_value) for item_name, item_value in list_items],
        optional_nested_list=[
            NestedModel(name=item_name, value=item_value) for item_name, item_value in optional_items
        ],
    )

    ctx = FlyteContextManager.current_context()

    # Python type -> literal type
    literal_type = TypeEngine.to_literal_type(ParentModel)
    assert literal_type is not None

    # Python value -> literal
    literal = TypeEngine.to_literal(ctx, parent, ParentModel, literal_type)
    assert literal is not None

    # Literal -> Python value
    restored = TypeEngine.to_python_value(ctx, literal, ParentModel)
    assert restored.id == "test-id"
    assert (restored.nested.name, restored.nested.value) == ("direct", 42)
    assert [(item.name, item.value) for item in restored.nested_list] == list_items
    assert restored.optional_nested_list is not None
    assert len(restored.optional_nested_list) == 2
    first_optional = restored.optional_nested_list[0]
    assert (first_optional.name, first_optional.value) == optional_items[0]

    # The same model must also pass through a task unchanged.
    @task
    def process_parent(data: ParentModel) -> ParentModel:
        return data

    task_output = process_parent(data=parent)
    assert task_output.id == "test-id"
    assert task_output.nested.name == "direct"
    assert len(task_output.nested_list) == 2