Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions flytekit/extras/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.TENSORFLOW_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)

Expand All @@ -34,14 +34,14 @@ async def async_to_literal(
meta = BlobMetadata(
type=_core_types.BlobType(
format=self.TENSORFLOW_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)

local_path = ctx.file_access.get_random_local_path()
local_path = ctx.file_access.get_random_local_path() + ".keras"
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

# save model in SavedModel format
# save model in Keras format
tf.keras.models.save_model(python_val, local_path)

remote_path = await ctx.file_access.async_put_raw_data(local_path)
Expand All @@ -55,16 +55,16 @@ async def async_to_python_value(
except AttributeError:
TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
await ctx.file_access.async_get_data(uri, local_path, is_multipart=True)
local_path = ctx.file_access.get_random_local_path() + ".keras"
await ctx.file_access.async_get_data(uri, local_path, is_multipart=False)

# load model
return tf.keras.models.load_model(local_path)

def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.TENSORFLOW_FORMAT
):
return tf.keras.Model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_tf_model() -> tf.keras.Model:
)
def test_get_literal_type(transformer, python_type, format):
lt = transformer.get_literal_type(python_type)
assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.MULTIPART))
assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE))


@pytest.mark.parametrize(
Expand All @@ -58,7 +58,7 @@ def test_to_python_value_and_literal(transformer, python_type, format, python_va
assert lv.scalar.blob.metadata == BlobMetadata(
type=BlobType(
format=format,
dimensionality=BlobType.BlobDimensionality.MULTIPART,
dimensionality=BlobType.BlobDimensionality.SINGLE,
)
)
assert lv.scalar.blob.uri is not None
Expand Down
Loading