180 changes: 93 additions & 87 deletions plugins/flytekit-spark/flytekitplugins/spark/schema.py
@@ -1,3 +1,4 @@
import importlib.util
import typing
from typing import Type

@@ -84,84 +85,6 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
return r.all()


classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
ClassicDataFrame = classic_ps_dataframe.DataFrame


class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
"""
Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
"""

def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
super().__init__(from_path, cols, fmt)

def iter(self, **kwargs) -> typing.Generator[T, None, None]:
raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")

def all(self, **kwargs) -> ClassicDataFrame:
if self._fmt == SchemaFormat.PARQUET:
ctx = FlyteContext.current_context().user_space_params
return ctx.spark_session.read.parquet(self.from_path)
raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")


class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
"""
Implements how Classic SparkDataFrame should be written using ``open`` method of FlyteSchema
"""

def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
super().__init__(to_path, cols, fmt)

def write(self, *dfs: ClassicDataFrame, **kwargs):
if dfs is None or len(dfs) == 0:
return
if len(dfs) > 1:
raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
if self._fmt == SchemaFormat.PARQUET:
dfs[0].write.mode("overwrite").parquet(self.to_path)
return
raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")


class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
"""
Transforms Classic Spark DataFrame's to and from a Schema (typed/untyped)
"""

def __init__(self):
super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)

@staticmethod
def _get_schema_type() -> SchemaType:
return SchemaType(columns=[])

def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
return LiteralType(schema=self._get_schema_type())

def to_literal(
self,
ctx: FlyteContext,
python_val: ClassicDataFrame,
python_type: Type[ClassicDataFrame],
expected: LiteralType,
) -> Literal:
remote_path = ctx.file_access.join(
ctx.file_access.raw_output_prefix,
ctx.file_access.get_random_string(),
)
w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
w.write(python_val)
return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
if not (lv and lv.scalar and lv.scalar.schema):
return ClassicDataFrame()
r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
return r.all()


# %%
# Registers a handle for Spark DataFrame + Flyte Schema type transition
# This allows open(pyspark.DataFrame) to be an acceptable type
@@ -175,15 +98,98 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
)
)

SchemaEngine.register_handler(
SchemaHandler(
"pyspark.sql.classic.DataFrame-Schema",
ClassicDataFrame,
ClassicSparkDataFrameSchemaReader,
ClassicSparkDataFrameSchemaWriter,
handles_remote_io=True,
)
)
# %%
# This makes pyspark.DataFrame as a supported output/input type with flytekit.
TypeEngine.register(SparkDataFrameTransformer())

# Only for classic pyspark which may not be available in Spark 4+
try:
spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
except Exception:
spec = None

if spec:
classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
ClassicDataFrame = classic_ps_dataframe.DataFrame

class ClassicSparkDataFrameSchemaReader(SchemaReader[ClassicDataFrame]):
"""
Implements how Classic SparkDataFrame should be read using the ``open`` method of FlyteSchema
"""

def __init__(self, from_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
super().__init__(from_path, cols, fmt)

def iter(self, **kwargs) -> typing.Generator[T, None, None]:
raise NotImplementedError("Classic Spark DataFrame reader cannot iterate over individual chunks")

def all(self, **kwargs) -> ClassicDataFrame:
if self._fmt == SchemaFormat.PARQUET:
ctx = FlyteContext.current_context().user_space_params
return ctx.spark_session.read.parquet(self.from_path)
raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")

class ClassicSparkDataFrameSchemaWriter(SchemaWriter[ClassicDataFrame]):
"""
Implements how Classic SparkDataFrame should be written using ``open`` method of FlyteSchema
"""

def __init__(self, to_path: str, cols: typing.Optional[typing.Dict[str, type]], fmt: SchemaFormat):
super().__init__(to_path, cols, fmt)

def write(self, *dfs: ClassicDataFrame, **kwargs):
if dfs is None or len(dfs) == 0:
return
if len(dfs) > 1:
raise AssertionError("Only a single Classic Spark.DataFrame can be written per variable currently")
if self._fmt == SchemaFormat.PARQUET:
dfs[0].write.mode("overwrite").parquet(self.to_path)
return
raise AssertionError("Only Parquet type files are supported for classic spark dataframe currently")

class ClassicSparkDataFrameTransformer(TypeTransformer[ClassicDataFrame]):
"""
Transforms Classic Spark DataFrame's to and from a Schema (typed/untyped)
"""

def __init__(self):
super().__init__("classic-spark-df-transformer", t=ClassicDataFrame)

@staticmethod
def _get_schema_type() -> SchemaType:
return SchemaType(columns=[])

def get_literal_type(self, t: Type[ClassicDataFrame]) -> LiteralType:
return LiteralType(schema=self._get_schema_type())

def to_literal(
self,
ctx: FlyteContext,
python_val: ClassicDataFrame,
python_type: Type[ClassicDataFrame],
expected: LiteralType,
) -> Literal:
remote_path = ctx.file_access.join(
ctx.file_access.raw_output_prefix,
ctx.file_access.get_random_string(),
)
w = ClassicSparkDataFrameSchemaWriter(to_path=remote_path, cols=None, fmt=SchemaFormat.PARQUET)
w.write(python_val)
return Literal(scalar=Scalar(schema=Schema(remote_path, self._get_schema_type())))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ClassicDataFrame]) -> T:
if not (lv and lv.scalar and lv.scalar.schema):
return ClassicDataFrame()
r = ClassicSparkDataFrameSchemaReader(from_path=lv.scalar.schema.uri, cols=None, fmt=SchemaFormat.PARQUET)
return r.all()

SchemaEngine.register_handler(
SchemaHandler(
"pyspark.sql.classic.DataFrame-Schema",
ClassicDataFrame,
ClassicSparkDataFrameSchemaReader,
ClassicSparkDataFrameSchemaWriter,
handles_remote_io=True,
)
)
TypeEngine.register(ClassicSparkDataFrameTransformer())
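
# Reviewer note, not part of the diff: a minimal sketch of what the registrations
# above enable. Once the DataFrame transformer and schema handler are registered,
# a Spark task can pass pyspark.sql.DataFrame values directly as inputs and outputs,
# with flytekit persisting them as Parquet behind the scenes. The task names, column
# names, and spark_conf values below are illustrative assumptions, not from this PR.

import pyspark.sql
from flytekit import task, workflow
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def make_df() -> pyspark.sql.DataFrame:
    sess = pyspark.sql.SparkSession.builder.getOrCreate()
    # The registered writer serializes this DataFrame to Parquet at a random remote path.
    return sess.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])


@task(task_config=Spark(spark_conf={"spark.driver.memory": "1g"}))
def count_rows(df: pyspark.sql.DataFrame) -> int:
    # The registered reader loads the Parquet back through the task's spark_session.
    return df.count()


@workflow
def wf() -> int:
    return count_rows(df=make_df())
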
104 changes: 57 additions & 47 deletions plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py
@@ -1,3 +1,4 @@
import importlib.util
import typing

from flytekit import FlyteContext, lazy_module
@@ -14,6 +15,8 @@

pd = lazy_module("pandas")
pyspark = lazy_module("pyspark")

# Base Spark DataFrame (Spark 3.x or Spark 4 parent)
ps_dataframe = lazy_module("pyspark.sql.dataframe")
DataFrame = ps_dataframe.DataFrame

@@ -47,7 +50,9 @@ def encode(
df = typing.cast(DataFrame, structured_dataset.dataframe)
ss = pyspark.sql.SparkSession.builder.getOrCreate()
# Avoid generating SUCCESS files

ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")

df.write.mode("overwrite").parquet(path=path)
return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))

@@ -73,51 +78,56 @@ def decode(
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(DataFrame, SparkDataFrameRenderer())

classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
ClassicDataFrame = classic_ps_dataframe.DataFrame


class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(ClassicDataFrame, None, PARQUET)

def encode(
self,
ctx: FlyteContext,
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
path = typing.cast(str, structured_dataset.uri)
if not path:
path = ctx.file_access.join(
ctx.file_access.raw_output_prefix,
ctx.file_access.get_random_string(),
)
df = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
ss = pyspark.sql.SparkSession.builder.getOrCreate()
ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
df.write.mode("overwrite").parquet(path=path)
return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))


class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
def __init__(self):
super().__init__(ClassicDataFrame, None, PARQUET)

def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> ClassicDataFrame:
user_ctx = FlyteContext.current_context().user_space_params
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
return user_ctx.spark_session.read.parquet(flyte_value.uri)


# Register the handlers
StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())
# Only for classic pyspark which may not be available in Spark 4+
try:
spec = importlib.util.find_spec("pyspark.sql.classic.dataframe")
except Exception:
spec = None
if spec:
# Spark 4 "classic" concrete DataFrame, if available
classic_ps_dataframe = lazy_module("pyspark.sql.classic.dataframe")
ClassicDataFrame = classic_ps_dataframe.DataFrame

class ClassicSparkToParquetEncodingHandler(StructuredDatasetEncoder):
def __init__(self):
super().__init__(ClassicDataFrame, None, PARQUET)

def encode(
self,
ctx: FlyteContext,
structured_dataset: StructuredDataset,
structured_dataset_type: StructuredDatasetType,
) -> literals.StructuredDataset:
path = typing.cast(str, structured_dataset.uri)
if not path:
path = ctx.file_access.join(
ctx.file_access.raw_output_prefix,
ctx.file_access.get_random_string(),
)
df = typing.cast(ClassicDataFrame, structured_dataset.dataframe)
ss = pyspark.sql.SparkSession.builder.getOrCreate()
ss.conf.set("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")
df.write.mode("overwrite").parquet(path=path)
return literals.StructuredDataset(uri=path, metadata=StructuredDatasetMetadata(structured_dataset_type))

class ParquetToClassicSparkDecodingHandler(StructuredDatasetDecoder):
def __init__(self):
super().__init__(ClassicDataFrame, None, PARQUET)

def decode(
self,
ctx: FlyteContext,
flyte_value: literals.StructuredDataset,
current_task_metadata: StructuredDatasetMetadata,
) -> ClassicDataFrame:
user_ctx = FlyteContext.current_context().user_space_params
if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]
return user_ctx.spark_session.read.parquet(flyte_value.uri).select(*columns)
return user_ctx.spark_session.read.parquet(flyte_value.uri)

# Register the handlers
StructuredDatasetTransformerEngine.register(ClassicSparkToParquetEncodingHandler())
StructuredDatasetTransformerEngine.register(ParquetToClassicSparkDecodingHandler())
StructuredDatasetTransformerEngine.register_renderer(ClassicDataFrame, SparkDataFrameRenderer())
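
# Reviewer note, not part of the diff: a minimal sketch of how the encoder/decoder
# registered above are exercised through the StructuredDataset path. It assumes the
# Spark task config from this plugin; StructuredDataset(dataframe=...) and
# StructuredDataset.open(...).all() are the standard flytekit calls, while the task
# names and sample data are illustrative.

import pyspark.sql
from flytekit import StructuredDataset, task
from flytekitplugins.spark import Spark


@task(task_config=Spark(spark_conf={"spark.executor.instances": "1"}))
def produce() -> StructuredDataset:
    sess = pyspark.sql.SparkSession.builder.getOrCreate()
    df = sess.createDataFrame([(1, "a"), (2, "b")], ["id", "name"])
    # The Parquet encoding handler above writes this DataFrame out as Parquet.
    return StructuredDataset(dataframe=df)


@task(task_config=Spark(spark_conf={"spark.executor.instances": "1"}))
def consume(sd: StructuredDataset) -> int:
    # The decoding handler above reads the Parquet back as a Spark DataFrame.
    df = sd.open(pyspark.sql.DataFrame).all()
    return df.count()
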