diff --git a/flytekit/extras/tensorflow/model.py b/flytekit/extras/tensorflow/model.py index b9fbf24d4b..ee0c663f00 100644 --- a/flytekit/extras/tensorflow/model.py +++ b/flytekit/extras/tensorflow/model.py @@ -20,7 +20,7 @@ def get_literal_type(self, t: Type[tf.keras.Model]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( format=self.TENSORFLOW_FORMAT, - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ) @@ -34,14 +34,14 @@ async def async_to_literal( meta = BlobMetadata( type=_core_types.BlobType( format=self.TENSORFLOW_FORMAT, - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ) - local_path = ctx.file_access.get_random_local_path() + local_path = ctx.file_access.get_random_local_path() + ".keras" pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) - # save model in SavedModel format + # save model in Keras format tf.keras.models.save_model(python_val, local_path) remote_path = await ctx.file_access.async_put_raw_data(local_path) @@ -55,8 +55,8 @@ async def async_to_python_value( except AttributeError: TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") - local_path = ctx.file_access.get_random_local_path() - await ctx.file_access.async_get_data(uri, local_path, is_multipart=True) + local_path = ctx.file_access.get_random_local_path() + ".keras" + await ctx.file_access.async_get_data(uri, local_path, is_multipart=False) # load model return tf.keras.models.load_model(local_path) @@ -64,7 +64,7 @@ async def async_to_python_value( def guess_python_type(self, literal_type: LiteralType) -> Type[tf.keras.Model]: if ( literal_type.blob is not None - and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE and literal_type.blob.format == self.TENSORFLOW_FORMAT ): return tf.keras.Model diff --git a/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py index 9d7aa12737..7bc487b0a5 100644 --- a/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py +++ b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py @@ -39,7 +39,7 @@ def get_tf_model() -> tf.keras.Model: ) def test_get_literal_type(transformer, python_type, format): lt = transformer.get_literal_type(python_type) - assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.MULTIPART)) + assert lt == LiteralType(blob=BlobType(format=format, dimensionality=BlobType.BlobDimensionality.SINGLE)) @pytest.mark.parametrize( @@ -58,7 +58,7 @@ def test_to_python_value_and_literal(transformer, python_type, format, python_va assert lv.scalar.blob.metadata == BlobMetadata( type=BlobType( format=format, - dimensionality=BlobType.BlobDimensionality.MULTIPART, + dimensionality=BlobType.BlobDimensionality.SINGLE, ) ) assert lv.scalar.blob.uri is not None