From dbe47c39b4e749d137998e66706a9742349d52f2 Mon Sep 17 00:00:00 2001
From: Yee Hing Tong
Date: Thu, 29 May 2025 10:15:07 -0700
Subject: [PATCH] copy/paste for now

Signed-off-by: Yee Hing Tong
---
 .../flytekitplugins/spark/sd_transformers.py | 51 +++++++++++++++++--
 1 file changed, 48 insertions(+), 3 deletions(-)

diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
index 2a0faa1b5d..ec681f3749 100644
--- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
+++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -12,10 +12,12 @@
     StructuredDatasetTransformerEngine,
 )
 
+import pyspark
+from pyspark.sql import dataframe
+from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
+
 pd = lazy_module("pandas")
-pyspark = lazy_module("pyspark")
-ps_dataframe = lazy_module("pyspark.sql.dataframe")
-DataFrame = ps_dataframe.DataFrame
+DataFrame = dataframe.DataFrame
 
 
 class SparkDataFrameRenderer:
@@ -52,6 +54,30 @@ def encode(
         return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
 
 
+class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
+    def __init__(self):
+        super().__init__(ClassicDataFrame, None, PARQUET)
+
+    def encode(
+        self,
+        ctx: FlyteContext,
+        structured_dataset: StructuredDataset,
+        structured_dataset_type: StructuredDatasetType,
+    ) -> literals.StructuredDataset:
+        path = typing.cast(str, structured_dataset.uri)
+        if not path:
+            path = ctx.file_access.join(
+                ctx.file_access.raw_output_prefix,
+                ctx.file_access.get_random_string(),
+            )
+        df = typing.cast(DataFrame, structured_dataset.dataframe)
+        ss = pyspark.sql.SparkSession.builder.getOrCreate()
+        # Avoid generating SUCCESS files
+        ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+        df.write.mode("overwrite").parquet(path=path)
+        return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
+
+
 class ParquetToSparkDecodingHandler(StructuredDatasetDecoder):
     def __init__(self):
         super().__init__(DataFrame, None, PARQUET)
@@ -69,6 +95,25 @@ def decode(
         return user_ctx.spark_session.read.parquet(flyte_value.uri)
 
 
+class ClassicParquetToSparkDecodingHandler(StructuredDatasetDecoder):
+    def __init__(self):
+        super().__init__(ClassicDataFrame, None, PARQUET)
+
+    def decode(
+        self,
+        ctx: FlyteContext,
+        flyte_value: literals.StructuredDataset,
+        current_task_metadata: StructuredDatasetMetadata,
+    ) -> DataFrame:
+        user_ctx = FlyteContext.current_context().user_space_params
+        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
+            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+            return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
+        return user_ctx.spark_session.read.parquet(flyte_value.uri)
+
+
 StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler())
 StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
+StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
+StructuredDatasetTransformerEngine.register(ClassicParquetToSparkDecodingHandler())
 StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())
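
Reviewer note, not part of the patch: a minimal usage sketch of what these handlers enable. On PySpark >= 4.0, pyspark.sql.dataframe.DataFrame is a common base class and concrete (non-Connect) frames are instances of pyspark.sql.classic.dataframe.DataFrame, which is presumably why the Classic* handlers are registered alongside the existing ones (note pyspark.sql.classic does not exist on PySpark 3.x, so this import would fail there). Task names, spark_conf values, and column names below are illustrative assumptions, not taken from the patch.

import pandas as pd
import pyspark.sql
from flytekit import task, workflow
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def make_df() -> pyspark.sql.DataFrame:
    sess = pyspark.sql.SparkSession.builder.getOrCreate()
    # On Spark 4 this concrete frame is a pyspark.sql.classic.dataframe.DataFrame,
    # so encoding to Parquet is dispatched to ClassicSparkToParquetEncodingHandler.
    return sess.createDataFrame(pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}))


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def consume_df(df: pyspark.sql.DataFrame) -> int:
    # The annotation is the base DataFrame type, so the Parquet intermediate is
    # decoded by ParquetToSparkDecodingHandler.
    return df.count()


@workflow
def wf() -> int:
    return consume_df(df=make_df())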