diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py
index 286cfafbf6..ec081aca94 100644
--- a/flytekit/types/directory/types.py
+++ b/flytekit/types/directory/types.py
@@ -176,10 +176,25 @@ def serialize_flyte_dir(self) -> Dict[str, str]:
         )
         return {"path": lv.scalar.blob.uri}
 
+    def _was_initialized_via_init(self) -> bool:
+        """Check if object was initialized via __init__ or deserialized by Pydantic.
+
+        When Pydantic deserializes a FlyteDirectory, it bypasses __init__ and directly
+        sets the 'path' attribute. This means the _downloader and _remote_source attributes
+        won't exist. We use this check to determine if we need to go through the transformer
+        to properly set up these internal attributes.
+
+        Returns:
+            bool: True if __init__ was called (normal initialization), False if created via
+                Pydantic deserialization and needs to pass through transformer.
+        """
+        return hasattr(self, "_downloader") and hasattr(self, "_remote_source")
+
     @model_validator(mode="after")
     def deserialize_flyte_dir(self, info) -> FlyteDirectory:
         if info.context is None or info.context.get("deserialize") is not True:
-            return self
+            if self._was_initialized_via_init():
+                return self
 
         pv = FlyteDirToMultipartBlobTransformer().to_python_value(
             FlyteContextManager.current_context(),
diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py
index 780188f9e5..47915add8e 100644
--- a/flytekit/types/file/file.py
+++ b/flytekit/types/file/file.py
@@ -193,10 +193,25 @@ def serialize_flyte_file(self) -> Dict[str, typing.Any]:
             out["metadata"] = lv.metadata
         return out
 
+    def _was_initialized_via_init(self) -> bool:
+        """Check if object was initialized via __init__ or deserialized by Pydantic.
+
+        When Pydantic deserializes a FlyteFile, it bypasses __init__ and directly
+        sets the 'path' attribute. This means the _downloader and _remote_source attributes
+        won't exist. We use this check to determine if we need to go through the transformer
+        to properly set up these internal attributes.
+
+        Returns:
+            bool: True if __init__ was called (normal initialization), False if created via
+                Pydantic deserialization and needs to pass through transformer.
+        """
+        return hasattr(self, "_downloader") and hasattr(self, "_remote_source")
+
     @model_validator(mode="after")
     def deserialize_flyte_file(self, info) -> "FlyteFile":
         if info.context is None or info.context.get("deserialize") is not True:
-            return self
+            if self._was_initialized_via_init():
+                return self
 
         pv = FlyteFilePathTransformer().to_python_value(
             FlyteContextManager.current_context(),
diff --git a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py
index 63127eaee2..6206817bfb 100644
--- a/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py
+++ b/tests/flytekit/unit/extras/pydantic_transformer/test_pydantic_basemodel_transformer.py
@@ -1022,3 +1022,109 @@ def mock_resolve_remote_path(flyte_uri: str):
 
     bm_revived = TypeEngine.to_python_value(ctx, lit, BM)
     assert bm_revived.s.literal.uri == "/my/replaced/val"
+
+
+def test_flytefile_pydantic_model_dump_validate_cycle():
+    class BM(BaseModel):
+        ff: FlyteFile
+
+    bm = BM(ff=FlyteFile.from_source("s3://my-bucket/file.txt"))
+
+    assert bm.ff.remote_source == "s3://my-bucket/file.txt"
+
+    bm_dict = bm.model_dump()
+    bm2 = BM.model_validate(bm_dict)
+
+    assert isinstance(bm2.ff, FlyteFile)
+    assert bm2.ff.remote_source == "s3://my-bucket/file.txt"
+
+    bm2_dict = bm2.model_dump()
+    assert bm_dict == bm2_dict
+
+
+def test_flytefile_pydantic_with_local_file(local_dummy_file):
+    class BM(BaseModel):
+        ff: FlyteFile
+
+    bm = BM(ff=FlyteFile(path=local_dummy_file))
+
+    bm_dict = bm.model_dump()
+    bm2 = BM.model_validate(bm_dict)
+
+    assert isinstance(bm2.ff, FlyteFile)
+    assert hasattr(bm2.ff, "_downloader")
+    assert hasattr(bm2.ff, "_remote_source")
+
+    bm2.model_dump()
+
+
+def test_flytefile_pydantic_with_metadata(local_dummy_file):
+    class BM(BaseModel):
+        ff: FlyteFile
+
+    bm = BM(ff=FlyteFile(path=local_dummy_file, metadata={"key": "value"}))
+
+    bm_dict = bm.model_dump()
+    bm2 = BM.model_validate(bm_dict)
+
+    assert isinstance(bm2.ff, FlyteFile)
+    assert hasattr(bm2.ff, "_downloader")
+    assert hasattr(bm2.ff, "_remote_source")
+    assert bm2.ff.metadata == {"key": "value"}
+
+    bm2.model_dump()
+
+
+def test_flytefile_pydantic_direct_dict_validate(local_dummy_file):
+    class BM(BaseModel):
+        ff: FlyteFile
+
+    bm = BM.model_validate({"ff": {"path": local_dummy_file}})
+
+    assert isinstance(bm.ff, FlyteFile)
+    assert hasattr(bm.ff, "_downloader")
+    assert hasattr(bm.ff, "_remote_source")
+
+
+def test_flytedirectory_pydantic_direct_dict_validate(local_dummy_directory):
+    class BM(BaseModel):
+        fd: FlyteDirectory
+
+    bm = BM.model_validate({"fd": {"path": local_dummy_directory}})
+
+    assert isinstance(bm.fd, FlyteDirectory)
+    assert hasattr(bm.fd, "_downloader")
+    assert hasattr(bm.fd, "_remote_source")
+
+
+def test_flytedirectory_pydantic_model_dump_validate_cycle():
+    class BM(BaseModel):
+        fd: FlyteDirectory
+
+    bm = BM(fd=FlyteDirectory.from_source("s3://my-bucket/my-dir"))
+
+    assert bm.fd.remote_source == "s3://my-bucket/my-dir"
+
+    bm_dict = bm.model_dump()
+    bm2 = BM.model_validate(bm_dict)
+
+    assert isinstance(bm2.fd, FlyteDirectory)
+    assert bm2.fd.remote_source == "s3://my-bucket/my-dir"
+
+    bm2.model_dump()
+
+
+def test_flytedirectory_pydantic_with_local_directory(local_dummy_directory):
+    class BM(BaseModel):
+        fd: FlyteDirectory
+
+    bm = BM(fd=FlyteDirectory(path=local_dummy_directory))
+
+    bm_dict = bm.model_dump()
+    bm2 = BM.model_validate(bm_dict)
+
+    assert isinstance(bm2.fd, FlyteDirectory)
+    assert hasattr(bm2.fd, "_downloader")
+    assert hasattr(bm2.fd, "_remote_source")
+
+    bm2.model_dump()