From 10aa9a05329b985bd7345e80482c54fa600a8783 Mon Sep 17 00:00:00 2001 From: Kevin Liao Date: Thu, 13 Nov 2025 06:08:01 +0800 Subject: [PATCH] fix: update Spark plugin for compatibility with Spark 3.x and 4.x Signed-off-by: Kevin Liao --- .../flytekitplugins/spark/schema.py | 180 +++++++++--------- .../flytekitplugins/spark/sd_transformers.py | 104 +++++----- 2 files changed, 150 insertions(+), 134 deletions(-) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/schema.py b/plugins/flytekit-spark/flytekitplugins/spark/schema.py index 4c423e6894..4a174619ab 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/schema.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/schema.py @@ -1,3 +1,4 @@ +import importlib.util import typing from typing import Type @@ -84,84 +85,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return r.all() -classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe") -ClassicDataFrame = classic_ps_dataframe.DataFrame - - -class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]): - """ - Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema - """ - - def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(from_path, cols, fmt) - - def iter(self, **kwargs) -> typing.Generator[T, None, None]: - raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks") - - def all(self, **kwargs) -> ClassicDataFrame: - if self._fmt == SchemaFormat.PARQUET: - ctx = FlyteContext.current_context().user_space_params - return ctx.spark_session.read.parquet(self.from_path) - raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently") - - -class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]): - """ - Implements how Classic SparkDataFrame should be written using ``open`` method of FlyteSchema - """ - - def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat): - super().__init__(to_path, cols, fmt) - - def write(self, *dfs: ClassicDataFrame, **kwargs): - if dfs is None or len(dfs) == 0: - return - if len(dfs) > 1: - raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently") - if self._fmt == SchemaFormat.PARQUET: - dfs[0].write.mode("overwrite").parquet(self.to_path) - return - raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently") - - -class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]): - """ - Transforms Classic Spark DataFrame's to and from a Schema (typed/untyped) - """ - - def __init__(self): - super().__init__("classic-spark-df-transformer", t=ClassicDataFrame) - - @staticmethod - def _get_schema_type() -> SchemaType: - return SchemaType(columns=[]) - - def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType: - return LiteralType(schema=self._get_schema_type()) - - def to_literal( - self, - ctx: FlyteContext, - python_val: ClassicDataFrame, - python_type: Type[ClassicDataFrame], - expected: LiteralType, - ) -> Literal: - remote_path = ctx.file_access.join( - ctx.file_access.raw_output_prefix, - ctx.file_access.get_random_string(), - ) - w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET) - w.write(python_val) - return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) - - def 
to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
-        if not (lv and lv.scalar and lv.scalar.schema):
-            return ClassicDataFrame()
-        r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
-        return r.all()
-
-
 # %%
 # Registers a handle for Spark DataFrame + Flyte Schema type transition
 # This allows open(pyspark.DataFrame) to be an acceptable type
@@ -175,15 +98,98 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
     )
 )
 
-SchemaEngine.register_handler(
-    SchemaHandler(
-        "pyspark.sql.classic.DataFrame-Schema",
-        ClassicDataFrame,
-        ClassicSparkDataFrameSchemaReader,
-        ClassicSparkDataFrameSchemaWriter,
-        handles_remote_io=True,
-    )
-)
 
 # %%
 # This makes pyspark.DataFrame as a supported output/input type with flytekit.
 TypeEngine.register(SparkDataFrameTransformer())
+
+# pyspark.sql.classic only exists in Spark 4+; register the classic handlers only when it is importable
+try:
+    spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
+except Exception:
+    spec = None
+
+if spec:
+    classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+    ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+    class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
+        """
+        Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
+        """
+
+        def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+            super().__init__(from_path, cols, fmt)
+
+        def iter(self, **kwargs) -> typing.Generator[T, None, None]:
+            raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")
+
+        def all(self, **kwargs) -> ClassicDataFrame:
+            if self._fmt == SchemaFormat.PARQUET:
+                ctx = FlyteContext.current_context().user_space_params
+                return ctx.spark_session.read.parquet(self.from_path)
+            raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+    class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
+        """
+        Implements how Classic SparkDataFrame should be written using the ``open`` method of FlyteSchema
+        """
+
+        def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
+            super().__init__(to_path, cols, fmt)
+
+        def write(self, *dfs: ClassicDataFrame, **kwargs):
+            if dfs is None or len(dfs) == 0:
+                return
+            if len(dfs) > 1:
+                raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
+            if self._fmt == SchemaFormat.PARQUET:
+                dfs[0].write.mode("overwrite").parquet(self.to_path)
+                return
+            raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")
+
+    class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
+        """
+        Transforms Classic Spark DataFrames to and from a Schema (typed/untyped)
+        """
+
+        def __init__(self):
+            super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)
+
+        @staticmethod
+        def _get_schema_type() -> SchemaType:
+            return SchemaType(columns=[])
+
+        def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
+            return LiteralType(schema=self._get_schema_type())
+
+        def to_literal(
+            self,
+            ctx: FlyteContext,
+            python_val: ClassicDataFrame,
+            python_type: Type[ClassicDataFrame],
+            expected: LiteralType,
+        ) -> Literal:
+            remote_path = ctx.file_access.join(
+                ctx.file_access.raw_output_prefix,
+                ctx.file_access.get_random_string(),
+            )
+            w = 
ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET) + w.write(python_val) + return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type()))) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T: + if not (lv and lv.scalar and lv.scalar.schema): + return ClassicDataFrame() + r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET) + return r.all() + + SchemaEngine.register_handler( + SchemaHandler( + "pyspark.sql.classic.DataFrame-Schema", + ClassicDataFrame, + ClassicSparkDataFrameSchemaReader, + ClassicSparkDataFrameSchemaWriter, + handles_remote_io=True, + ) + ) + TypeEngine.register(ClassicSparkDataFrameTransformer()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index a849b711dc..1c3524ff73 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -1,3 +1,4 @@ +import importlib.util import typing from flytekit import FlyteContext, lazy_module @@ -14,6 +15,8 @@ pd = lazy_module("pandas") pyspark = lazy_module("pyspark") + +# Base Spark DataFrame (Spark 3.x or Spark 4 parent) ps_dataframe = lazy_module("pyspark.sql.dataframe") DataFrame = ps_dataframe.DataFrame @@ -47,7 +50,9 @@ def encode( df = typing.cast(DataFrame, structured_dataset.dataframe) ss = pyspark.sql.SparkSession.builder.getOrCreate() # Avoid generating SUCCESS files + ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false") + df.write.mode("overwrite").parquet(path=path) return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -73,51 +78,56 @@ def decode( StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler()) StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer()) -classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe") -ClassicDataFrame = classic_ps_dataframe.DataFrame - - -class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder): - def __init__(self): - super().__init__(ClassicDataFrame, None, PARQUET) - - def encode( - self, - ctx: FlyteContext, - structured_dataset: StructuredDataset, - structured_dataset_type: StructuredDatasetType, - ) -> literals.StructuredDataset: - path = typing.cast(str, structured_dataset.uri) - if not path: - path = ctx.file_access.join( - ctx.file_access.raw_output_prefix, - ctx.file_access.get_random_string(), - ) - df = typing.cast(ClassicDataFrame, structured_dataset.dataframe) - ss = pyspark.sql.SparkSession.builder.getOrCreate() - ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false") - df.write.mode("overwrite").parquet(path=path) - return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type)) - - -class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder): - def __init__(self): - super().__init__(ClassicDataFrame, None, PARQUET) - - def decode( - self, - ctx: FlyteContext, - flyte_value: literals.StructuredDataset, - current_task_metadata: StructuredDatasetMetadata, - ) -> ClassicDataFrame: - user_ctx = FlyteContext.current_context().user_space_params - if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: - columns = [c.name for c in 
current_task_metadata.structured_dataset_type.columns]
-            return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
-        return user_ctx.spark_session.read.parquet(flyte_value.uri)
-
-# Register the handlers
-StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
-StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
-StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())
+# pyspark.sql.classic only exists in Spark 4+; register the classic handlers only when it is importable
+try:
+    spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
+except Exception:
+    spec = None
+if spec:
+    # Spark 4 "classic" concrete DataFrame, if available
+    classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
+    ClassicDataFrame = classic_ps_dataframe.DataFrame
+
+    class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
+        def __init__(self):
+            super().__init__(ClassicDataFrame, None, PARQUET)
+
+        def encode(
+            self,
+            ctx: FlyteContext,
+            structured_dataset: StructuredDataset,
+            structured_dataset_type: StructuredDatasetType,
+        ) -> literals.StructuredDataset:
+            path = typing.cast(str, structured_dataset.uri)
+            if not path:
+                path = ctx.file_access.join(
+                    ctx.file_access.raw_output_prefix,
+                    ctx.file_access.get_random_string(),
+                )
+            df = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
+            ss = pyspark.sql.SparkSession.builder.getOrCreate()
+            ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
+            df.write.mode("overwrite").parquet(path=path)
+            return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))
+
+    class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
+        def __init__(self):
+            super().__init__(ClassicDataFrame, None, PARQUET)
+
+        def decode(
+            self,
+            ctx: FlyteContext,
+            flyte_value: literals.StructuredDataset,
+            current_task_metadata: StructuredDatasetMetadata,
+        ) -> ClassicDataFrame:
+            user_ctx = FlyteContext.current_context().user_space_params
+            if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
+                columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
+                return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
+            return user_ctx.spark_session.read.parquet(flyte_value.uri)
+
+    # Register the handlers
+    StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
+    StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
+    StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())
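
For context, a minimal, illustrative sketch of the user-facing behavior this change preserves. The task and workflow names are invented for the example; it assumes flytekitplugins-spark is installed and that the tasks run with this plugin's Spark task config so a session is available. On Spark 3.x the annotated pyspark.sql.DataFrame is the concrete class handled by the existing encoder/decoder; on Spark 4 the runtime object is the pyspark.sql.classic subclass covered by the conditionally registered handlers above.

    import pyspark.sql
    from flytekit import task, workflow
    from flytekitplugins.spark import Spark  # importing the plugin registers the handlers guarded above

    @task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
    def make_df() -> pyspark.sql.DataFrame:
        sess = pyspark.sql.SparkSession.builder.getOrCreate()
        # Concrete DataFrame on Spark 3.x, pyspark.sql.classic subclass on Spark 4.
        return sess.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])

    @task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
    def count_rows(df: pyspark.sql.DataFrame) -> int:
        # The decoder reads the intermediate Parquet back into a DataFrame.
        return df.count()

    @workflow
    def spark_compat_wf() -> int:
        return count_rows(df=make_df())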