diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 24a78f184b..551258824b 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1145,13 +1145,54 @@ def assert_type(self, python_type: Type, python_val: T): return base_transformer.assert_type(base_type, python_val) +def _resolve_json_schema_ref(property_val: dict, schema: typing.Optional[dict] = None) -> dict: + """ + Resolve a JSON schema $ref reference to its actual definition. + + Args: + property_val: The property value that may contain a $ref + schema: The full schema containing the definitions + + Returns: + The resolved property value (dereferenced if it was a $ref) + + Raises: + ValueError: If the $ref path is invalid or cannot be resolved + """ + if "$ref" in property_val and schema is not None: + # Handle references like "#/$defs/ModelName" or "#/definitions/ModelName" + ref_path = property_val["$ref"] + if ref_path.startswith("#/"): + path_parts = ref_path[2:].split("/") + resolved = schema + try: + for part in path_parts: + resolved = resolved[part] + return resolved + except (KeyError, TypeError) as e: + raise ValueError( + f"Failed to resolve JSON schema reference '{ref_path}': {e}. " + f"Make sure the referenced definition exists in the schema." + ) from e + return property_val + + def _handle_json_schema_property( property_key: str, property_val: dict, + schema: typing.Optional[dict] = None, ) -> typing.Tuple[str, typing.Any]: """ A helper to handle the properties of a JSON schema and returns their equivalent Flyte attribute name and type. + + Args: + property_key: The name of the property + property_val: The property schema definition + schema: The full schema containing definitions (needed for resolving $ref) """ + + # Resolve $ref before processing + property_val = _resolve_json_schema_ref(property_val, schema) # Handle Optional[T] or Union[T1, T2, ...] at the top level for proper recursion if property_val.get("anyOf"): @@ -1166,7 +1207,7 @@ def _handle_json_schema_property( ) attr_types = [] for item in property_val["anyOf"]: - _, attr_type = _handle_json_schema_property(property_key, item) + _, attr_type = _handle_json_schema_property(property_key, item, schema) attr_types.append(attr_type) # Gather all the types and return a Union[T1, T2, ...] @@ -1181,7 +1222,7 @@ def _handle_json_schema_property( # Handle list if property_type == "array": - return (property_key, typing.List[_get_element_type(property_val["items"])]) # type: ignore + return (property_key, typing.List[_get_element_type(property_val["items"], schema)]) # type: ignore # Handle null types (i.e. None) elif property_type == "null": return (property_key, type(None)) # type: ignore @@ -1191,7 +1232,7 @@ def _handle_json_schema_property( # those are handled in the top level of the function with recursion. if property_val.get("additionalProperties"): # For typing.Dict type - elem_type = _get_element_type(property_val["additionalProperties"]) + elem_type = _get_element_type(property_val["additionalProperties"], schema) return (property_key, typing.Dict[str, elem_type]) # type: ignore elif property_val.get("title"): # For nested dataclass @@ -1207,7 +1248,7 @@ def _handle_json_schema_property( return (property_key, str) # type: ignore # Handle None, int, float, bool or str else: - return (property_key, _get_element_type(property_val)) # type: ignore + return (property_key, _get_element_type(property_val, schema)) # type: ignore def generate_attribute_list_from_dataclass_json_mixin( @@ -1216,7 +1257,7 @@ def generate_attribute_list_from_dataclass_json_mixin( ): attribute_list: typing.List[typing.Tuple[Any, Any]] = [] for property_key, property_val in schema["properties"].items(): - attribute_list.append(_handle_json_schema_property(property_key, property_val)) + attribute_list.append(_handle_json_schema_property(property_key, property_val, schema)) return attribute_list @@ -2452,7 +2493,7 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin property_type = property_val["type"] # Handle list if property_val["type"] == "array": - attribute_list.append((property_key, List[_get_element_type(property_val["items"])])) # type: ignore[misc,index] + attribute_list.append((property_key, List[_get_element_type(property_val["items"], None)])) # type: ignore[misc,index] # Handle dataclass and dict elif property_type == "object": # For nested dataclass @@ -2472,11 +2513,11 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin else: # typed dict attribute_list.append( - (property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index] + (property_key, Dict[str, _get_element_type(property_val["additionalProperties"], None)]) # type: ignore[misc,index] ) # Handle primitive types like int, float, bool or str else: - attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + attribute_list.append([property_key, _get_element_type(property_val, None)]) # type: ignore return attribute_list @@ -2502,7 +2543,10 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) -def _get_element_type(element_property: typing.Dict[str, str]) -> Type: +def _get_element_type(element_property: typing.Dict[str, typing.Any], schema: typing.Optional[dict] = None) -> Type: + # Resolve $ref before processing + element_property = _resolve_json_schema_ref(element_property, schema) + element_type = ( [e_property["type"] for e_property in element_property["anyOf"]] # type: ignore if element_property.get("anyOf") @@ -2512,7 +2556,7 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type: if isinstance(element_type, list): # Element type of Optional[int] is [integer, None] - return typing.Optional[_get_element_type({"type": element_type[0]})] # type: ignore + return typing.Optional[_get_element_type({"type": element_type[0]}, schema)] # type: ignore if element_type == "string": return str @@ -2525,6 +2569,13 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type: return int else: return float + elif element_type == "object": + # Handle nested dataclass objects + if element_property.get("title"): + sub_schema_name = element_property["title"] + return typing.cast(Type, convert_mashumaro_json_schema_to_python_class(element_property, sub_schema_name)) + # For untyped dict or other object types + return dict # type: ignore return str diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py index 63127eaee2..66cf818d01 100644 --- a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py +++ b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py @@ -1022,3 +1022,69 @@ def mock_resolve_remote_path(flyte_uri: str): bm_revived = TypeEngine.to_python_value(ctx, lit, BM) assert bm_revived.s.literal.uri == "/my/replaced/val" + + +def test_nested_pydantic_model_with_list(): + """ + Test that nested Pydantic models in lists are handled correctly. + This tests the fix for JSON schema $ref references in nested models. + """ + class NestedModel(BaseModel): + name: str + value: int + + class ParentModel(BaseModel): + id: str + nested: NestedModel # Direct nested model (generates $ref) + nested_list: List[NestedModel] # List of nested models (generates array with $ref in items) + optional_nested_list: Optional[List[NestedModel]] = None # Optional list with $ref + + # Test TypeEngine can handle the model + ctx = FlyteContextManager.current_context() + + # Create test data + parent = ParentModel( + id="test-id", + nested=NestedModel(name="direct", value=42), + nested_list=[ + NestedModel(name="item1", value=1), + NestedModel(name="item2", value=2), + ], + optional_nested_list=[ + NestedModel(name="opt1", value=10), + NestedModel(name="opt2", value=20), + ] + ) + + # Test to_literal_type + lt = TypeEngine.to_literal_type(ParentModel) + assert lt is not None + + # Test to_literal + literal = TypeEngine.to_literal(ctx, parent, ParentModel, lt) + assert literal is not None + + # Test to_python_value + restored = TypeEngine.to_python_value(ctx, literal, ParentModel) + assert restored.id == "test-id" + assert restored.nested.name == "direct" + assert restored.nested.value == 42 + assert len(restored.nested_list) == 2 + assert restored.nested_list[0].name == "item1" + assert restored.nested_list[0].value == 1 + assert restored.nested_list[1].name == "item2" + assert restored.nested_list[1].value == 2 + assert restored.optional_nested_list is not None + assert len(restored.optional_nested_list) == 2 + assert restored.optional_nested_list[0].name == "opt1" + assert restored.optional_nested_list[0].value == 10 + + # Test with task + @task + def process_parent(data: ParentModel) -> ParentModel: + return data + + result = process_parent(data=parent) + assert result.id == "test-id" + assert result.nested.name == "direct" + assert len(result.nested_list) == 2