@@ -2,15 +2,15 @@
import polars as pl

from pl_fuzzy_frame_match.models import FuzzyMapping

from io import StringIO
from flowfile_core.flowfile.flow_graph import FlowGraph
from flowfile_core.flowfile.flow_data_engine.flow_file_column.main import FlowfileColumn, convert_pl_type_to_string
from flowfile_core.flowfile.flow_data_engine.flow_file_column.utils import cast_str_to_polars_type
from flowfile_core.flowfile.flow_node.flow_node import FlowNode
from flowfile_core.flowfile.util.execution_orderer import determine_execution_order
from flowfile_core.schemas import input_schema, transform_schema
from flowfile_core.configs import logger

from functools import lru_cache

class FlowGraphToPolarsConverter:
"""
@@ -22,17 +22,16 @@ class FlowGraphToPolarsConverter:
flow_graph: FlowGraph
node_var_mapping: Dict[int, str]
imports: Set[str]
code_lines: List[str]
output_nodes: List[Tuple[int, str]] = []
last_node_var: Optional[str] = None

def __init__(self, flow_graph: FlowGraph):
self.flow_graph = flow_graph
self.node_var_mapping: Dict[int, str] = {} # Maps node_id to variable name
self.imports: Set[str] = {"import polars as pl"}
self.code_lines: List[str] = []
self.output_nodes = []
self.last_node_var = None
self.code_buffer = StringIO()

def convert(self) -> str:
"""
@@ -510,7 +509,6 @@ def _handle_left_inner_join_keys(self, settings: input_schema.NodeJoin, right_df
- reverse_action: Mapping to rename __DROP__ columns after join
- after_join_drop_cols: Left join keys marked for dropping
"""
left_join_keys_to_keep = [jk.new_name for jk in settings.join_input.left_select.join_key_selects if jk.keep]

join_key_duplication_command = [
f'pl.col("{rjk.old_name}").alias("__DROP__{rjk.new_name}__DROP__")'
@@ -682,7 +680,7 @@ def _execute_join_with_post_processing(self, settings: input_schema.NodeJoin, va

# Convert back to lazy for right joins
if settings.join_input.how == 'right':
self._add_code(f".lazy()")
self._add_code(".lazy()")

self._add_code(")")

@@ -734,7 +732,7 @@ def _handle_pivot_no_index(self, settings: input_schema.NodePivot, var_name: str
self._add_code(' .with_columns(pl.lit(1).alias("__temp_index__"))')
self._add_code(' .pivot(')
self._add_code(f' values="{pivot_input.value_col}",')
self._add_code(f' index=["__temp_index__"],')
self._add_code( ' index=["__temp_index__"],')
self._add_code(f' columns="{pivot_input.pivot_column}",')
self._add_code(f' aggregate_function="{agg_func}"')
self._add_code(" )")
@@ -900,7 +898,7 @@ def _handle_record_id(self, settings: input_schema.NodeRecordId, var_name: str,
# Row number within groups
self._add_code(f"{var_name} = ({input_df}")
self._add_code(f" .with_columns(pl.lit(1).alias('{record_input.output_column_name}'))")
self._add_code(f" .with_columns([")
self._add_code( " .with_columns([")
self._add_code(f" (pl.cum_count('{record_input.output_column_name}').over({record_input.group_by_columns}) + {record_input.offset} - 1)")
self._add_code(f" .alias('{record_input.output_column_name}')")
self._add_code("])")
@@ -928,24 +926,24 @@ def _handle_cloud_storage_writer(self, settings: input_schema.NodeCloudStorageWr
self.imports.add("import flowfile as ff")
self._add_code(f"(ff.FlowFrame({input_df})")
if output_settings.file_format == "csv":
self._add_code(f' .write_csv_to_cloud_storage(')
self._add_code( ' .write_csv_to_cloud_storage(')
self._add_code(f' path="{output_settings.resource_path}",')
self._add_code(f' connection_name="{output_settings.connection_name}",')
self._add_code(f' delimiter="{output_settings.csv_delimiter}",')
self._add_code(f' encoding="{output_settings.csv_encoding}",')
self._add_code(f' description="{settings.description}"')
elif output_settings.file_format == "parquet":
self._add_code(f' .write_parquet_to_cloud_storage(')
self._add_code( ' .write_parquet_to_cloud_storage(')
self._add_code(f' path="{output_settings.resource_path}",')
self._add_code(f' connection_name="{output_settings.connection_name}",')
self._add_code(f' description="{settings.description}"')
elif output_settings.file_format == "json":
self._add_code(f' .write_json_to_cloud_storage(')
self._add_code( ' .write_json_to_cloud_storage(')
self._add_code(f' path="{output_settings.resource_path}",')
self._add_code(f' connection_name="{output_settings.connection_name}",')
self._add_code(f' description="{settings.description}"')
elif output_settings.file_format == "delta":
self._add_code(f' .write_delta(')
self._add_code( ' .write_delta(')
self._add_code(f' path="{output_settings.resource_path}",')
self._add_code(f' write_mode="{output_settings.write_mode}",')
self._add_code(f' connection_name="{output_settings.connection_name}",')
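For reviewers, the shape of the snippet this branch emits for the CSV case, reconstructed by hand (the path, connection name, and settings values are invented placeholders, the closing parentheses sit in the collapsed lines, and actually running it requires a configured flowfile cloud connection):

import flowfile as ff

(ff.FlowFrame(df_1)
 .write_csv_to_cloud_storage(
     path="s3://my-bucket/out.csv",
     connection_name="my_s3_conn",
     delimiter=",",
     encoding="utf8",
     description="example writer node"
 )
)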
@@ -1003,7 +1001,7 @@ def _handle_polars_code(self, settings: input_schema.NodePolarsCode, var_name: s
is_expression = "output_df" not in code

# Wrap the code in a function
self._add_code(f"# Custom Polars code")
self._add_code("# Custom Polars code")
self._add_code(f"def _polars_code_{var_name.replace('df_', '')}({params}):")

# Handle the code based on its structure
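The expression-versus-statement split is worth spelling out: user code that never assigns `output_df` is treated as a single expression. A guess at the overall wrapper shape, since the body handling sits in the collapsed lines below (function and variable names here are illustrative):

def wrap_polars_code(code: str, params: str) -> str:
    # Expression form: no output_df assignment, so return the code directly.
    is_expression = "output_df" not in code
    lines = [f"def _polars_code_1({params}):"]
    if is_expression:
        lines.append(f"    return {code}")
    else:
        # Statement form: indent the user code and return its output_df.
        lines.extend(f"    {ln}" for ln in code.splitlines())
        lines.append("    return output_df")
    return "\n".join(lines)

print(wrap_polars_code("input_df.select(pl.len())", "input_df"))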
@@ -1036,11 +1034,11 @@ def _handle_polars_code(self, settings: input_schema.NodePolarsCode, var_name: s

def _add_code(self, line: str) -> None:
"""Add a line of code."""
self.code_lines.append(line)
self.code_buffer.write(line + '\n')

def _add_comment(self, comment: str) -> None:
"""Add a comment line."""
self.code_lines.append(comment)
self.code_buffer.write(comment + '\n')

def _parse_filter_expression(self, expr: str) -> str:
"""Parse Flowfile filter expression to Polars expression."""
@@ -1093,6 +1091,7 @@ def _create_basic_filter_expr(self, basic: transform_schema.BasicFilter) -> str:
return f"pl.col('{col}').is_in({values})"
return col

@lru_cache(maxsize=1024)
def _get_polars_dtype(self, dtype_str: str) -> str:
"""Convert Flowfile dtype string to Polars dtype."""
dtype_map = {
@@ -1110,6 +1109,7 @@
}
return dtype_map.get(dtype_str, 'pl.Utf8')

@lru_cache(maxsize=1024)
def _get_agg_function(self, agg: str) -> str:
"""Get Polars aggregation function name."""
agg_map = {
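One caveat on the two new `@lru_cache(maxsize=1024)` decorators: on instance methods the cache keys include `self`, so converter instances stay referenced by the cache until eviction. Since both lookups are pure table lookups independent of instance state, a module-level helper sidesteps that; a sketch with an abbreviated mapping:

from functools import lru_cache

@lru_cache(maxsize=1024)
def get_polars_dtype(dtype_str: str) -> str:
    # Mapping abbreviated here; the full table lives in the method above.
    dtype_map = {"Int64": "pl.Int64", "Float64": "pl.Float64", "String": "pl.Utf8"}
    return dtype_map.get(dtype_str, "pl.Utf8")

print(get_polars_dtype("Int64"))  # computed once
print(get_polars_dtype("Int64"))  # served from the cache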
@@ -1181,15 +1181,7 @@ def _build_final_code(self) -> str:
lines.append(' Generated from Flowfile')
lines.append(' """')
lines.append(" ")

# Add the generated code
for line in self.code_lines:
if line:
lines.append(f" {line}")
else:
lines.append("")
# Add main block
lines.append("")
lines.append(self.code_buffer.getvalue())
self.add_return_code(lines)
lines.append("")
lines.append("")