Skip to content
17 changes: 16 additions & 1 deletion flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,25 @@ def serialize_flyte_dir(self) -> Dict[str, str]:
)
return {"path": lv.scalar.blob.uri}

def _was_initialized_via_init(self) -> bool:
    """Report whether this instance was built through ``__init__``.

    During Pydantic deserialization a FlyteDirectory is created without
    ``__init__`` running; only ``path`` is assigned, so the private
    ``_downloader`` and ``_remote_source`` attributes never come into
    existence. The presence of both attributes is therefore treated as the
    marker of a normally constructed object, while a missing attribute
    means the instance still has to go through the transformer so those
    internals get set up.

    Returns:
        bool: True when both internal attributes exist (regular
            construction); False when at least one is missing (created by
            Pydantic deserialization).
    """
    # Checked in this order; evaluation stops at the first missing attribute.
    internal_attrs = ("_downloader", "_remote_source")
    return all(hasattr(self, attr_name) for attr_name in internal_attrs)

@model_validator(mode="after")
def deserialize_flyte_dir(self, info) -> FlyteDirectory:
if info.context is None or info.context.get("deserialize") is not True:
return self
if self._was_initialized_via_init():
return self

pv = FlyteDirToMultipartBlobTransformer().to_python_value(
FlyteContextManager.current_context(),
Expand Down
17 changes: 16 additions & 1 deletion flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,25 @@ def serialize_flyte_file(self) -> Dict[str, typing.Any]:
out["metadata"] = lv.metadata
return out

def _was_initialized_via_init(self) -> bool:
    """Report whether this instance was built through ``__init__``.

    During Pydantic deserialization a FlyteFile is created without
    ``__init__`` running; only ``path`` is assigned, so the private
    ``_downloader`` and ``_remote_source`` attributes never come into
    existence. The presence of both attributes is therefore treated as the
    marker of a normally constructed object, while a missing attribute
    means the instance still has to go through the transformer so those
    internals get set up.

    Returns:
        bool: True when both internal attributes exist (regular
            construction); False when at least one is missing (created by
            Pydantic deserialization).
    """
    # Checked in this order; evaluation stops at the first missing attribute.
    internal_attrs = ("_downloader", "_remote_source")
    return all(hasattr(self, attr_name) for attr_name in internal_attrs)

@model_validator(mode="after")
def deserialize_flyte_file(self, info) -> "FlyteFile":
if info.context is None or info.context.get("deserialize") is not True:
return self
if self._was_initialized_via_init():
return self

pv = FlyteFilePathTransformer().to_python_value(
FlyteContextManager.current_context(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1022,3 +1022,109 @@ def mock_resolve_remote_path(flyte_uri: str):

bm_revived = TypeEngine.to_python_value(ctx, lit, BM)
assert bm_revived.s.literal.uri == "/my/replaced/val"


def test_flytefile_pydantic_model_dump_validate_cycle():
    """Round-trip a remote FlyteFile through model_dump / model_validate.

    Checks that the remote source is preserved after re-validation and that
    dumping the re-validated model yields the same dict as the first dump.
    """

    class BM(BaseModel):
        ff: FlyteFile

    original = BM(ff=FlyteFile.from_source("s3://my-bucket/file.txt"))
    assert original.ff.remote_source == "s3://my-bucket/file.txt"

    first_dump = original.model_dump()
    restored = BM.model_validate(first_dump)

    assert isinstance(restored.ff, FlyteFile)
    assert restored.ff.remote_source == "s3://my-bucket/file.txt"

    # The second dump has to match the first one exactly.
    second_dump = restored.model_dump()
    assert first_dump == second_dump


def test_flytefile_pydantic_with_local_file(local_dummy_file):
    """Round-trip a local FlyteFile and check its internal attributes exist.

    After model_dump followed by model_validate, the restored FlyteFile is
    expected to carry ``_downloader`` and ``_remote_source``, and a further
    model_dump call is exercised on it.
    """

    class BM(BaseModel):
        ff: FlyteFile

    source_model = BM(ff=FlyteFile(path=local_dummy_file))

    dumped = source_model.model_dump()
    restored = BM.model_validate(dumped)

    assert isinstance(restored.ff, FlyteFile)
    for internal_attr in ("_downloader", "_remote_source"):
        assert hasattr(restored.ff, internal_attr)

    # Any exception raised by a second dump fails the test.
    restored.model_dump()


def test_flytefile_pydantic_with_metadata(local_dummy_file):
    """Round-trip a local FlyteFile that carries metadata.

    Verifies that the restored FlyteFile has its internal attributes, that
    the metadata dict comes back unchanged, and exercises a further
    model_dump call on the restored model.
    """

    class BM(BaseModel):
        ff: FlyteFile

    source_model = BM(ff=FlyteFile(path=local_dummy_file, metadata={"key": "value"}))

    dumped = source_model.model_dump()
    restored = BM.model_validate(dumped)

    assert isinstance(restored.ff, FlyteFile)
    for internal_attr in ("_downloader", "_remote_source"):
        assert hasattr(restored.ff, internal_attr)
    assert restored.ff.metadata == {"key": "value"}

    # Any exception raised by a second dump fails the test.
    restored.model_dump()


def test_flytefile_pydantic_direct_dict_validate(local_dummy_file):
    """Validate a model straight from a plain dict holding only ``path``.

    The resulting FlyteFile is expected to have both ``_downloader`` and
    ``_remote_source`` attributes.
    """

    class BM(BaseModel):
        ff: FlyteFile

    payload = {"ff": {"path": local_dummy_file}}
    validated = BM.model_validate(payload)

    assert isinstance(validated.ff, FlyteFile)
    for internal_attr in ("_downloader", "_remote_source"):
        assert hasattr(validated.ff, internal_attr)


def test_flytedirectory_pydantic_direct_dict_validate(local_dummy_directory):
    """Validate a model straight from a plain dict holding only ``path``.

    The resulting FlyteDirectory is expected to have both ``_downloader``
    and ``_remote_source`` attributes.
    """

    class BM(BaseModel):
        fd: FlyteDirectory

    payload = {"fd": {"path": local_dummy_directory}}
    validated = BM.model_validate(payload)

    assert isinstance(validated.fd, FlyteDirectory)
    for internal_attr in ("_downloader", "_remote_source"):
        assert hasattr(validated.fd, internal_attr)


def test_flytedirectory_pydantic_model_dump_validate_cycle():
    """Round-trip a remote FlyteDirectory through model_dump / model_validate.

    Checks that the remote source is preserved after re-validation and
    exercises a further model_dump call on the restored model.
    """

    class BM(BaseModel):
        fd: FlyteDirectory

    original = BM(fd=FlyteDirectory.from_source("s3://my-bucket/my-dir"))
    assert original.fd.remote_source == "s3://my-bucket/my-dir"

    first_dump = original.model_dump()
    restored = BM.model_validate(first_dump)

    assert isinstance(restored.fd, FlyteDirectory)
    assert restored.fd.remote_source == "s3://my-bucket/my-dir"

    # Any exception raised by a second dump fails the test.
    restored.model_dump()


def test_flytedirectory_pydantic_with_local_directory(local_dummy_directory):
    """Round-trip a local FlyteDirectory and check its internal attributes.

    After model_dump followed by model_validate, the restored FlyteDirectory
    is expected to carry ``_downloader`` and ``_remote_source``, and a
    further model_dump call is exercised on it.
    """

    class BM(BaseModel):
        fd: FlyteDirectory

    source_model = BM(fd=FlyteDirectory(path=local_dummy_directory))

    dumped = source_model.model_dump()
    restored = BM.model_validate(dumped)

    assert isinstance(restored.fd, FlyteDirectory)
    for internal_attr in ("_downloader", "_remote_source"):
        assert hasattr(restored.fd, internal_attr)

    # Any exception raised by a second dump fails the test.
    restored.model_dump()
Loading