diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py index e9cd6847b..a18616c59 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py @@ -2,7 +2,7 @@ import polars as pl from pl_fuzzy_frame_match.models import FuzzyMapping - +from io import StringIO from flowfile_core.flowfile.flow_graph import FlowGraph from flowfile_core.flowfile.flow_data_engine.flow_file_column.main import FlowfileColumn, convert_pl_type_to_string from flowfile_core.flowfile.flow_data_engine.flow_file_column.utils import cast_str_to_polars_type @@ -10,7 +10,7 @@ from flowfile_core.flowfile.util.execution_orderer import determine_execution_order from flowfile_core.schemas import input_schema, transform_schema from flowfile_core.configs import logger - +from functools import lru_cache class FlowGraphToPolarsConverter: """ @@ -22,7 +22,6 @@ class FlowGraphToPolarsConverter: flow_graph: FlowGraph node_var_mapping: Dict[int, str] imports: Set[str] - code_lines: List[str] output_nodes: List[Tuple[int, str]] = [] last_node_var: Optional[str] = None @@ -30,9 +29,9 @@ def __init__(self, flow_graph: FlowGraph): self.flow_graph = flow_graph self.node_var_mapping: Dict[int, str] = {} # Maps node_id to variable name self.imports: Set[str] = {"import polars as pl"} - self.code_lines: List[str] = [] self.output_nodes = [] self.last_node_var = None + self.code_buffer = StringIO() def convert(self) -> str: """ @@ -510,7 +509,6 @@ def _handle_left_inner_join_keys(self, settings: input_schema.NodeJoin, right_df - reverse_action: Mapping to rename __DROP__ columns after join - after_join_drop_cols: Left join keys marked for dropping """ - left_join_keys_to_keep = [jk.new_name for jk in settings.join_input.left_select.join_key_selects if jk.keep] join_key_duplication_command = [ f'pl.col("{rjk.old_name}").alias("__DROP__{rjk.new_name}__DROP__")' @@ -682,7 +680,7 @@ def _execute_join_with_post_processing(self, settings: input_schema.NodeJoin, va # Convert back to lazy for right joins if settings.join_input.how == 'right': - self._add_code(f".lazy()") + self._add_code(".lazy()") self._add_code(")") @@ -734,7 +732,7 @@ def _handle_pivot_no_index(self, settings: input_schema.NodePivot, var_name: str self._add_code(' .with_columns(pl.lit(1).alias("__temp_index__"))') self._add_code(' .pivot(') self._add_code(f' values="{pivot_input.value_col}",') - self._add_code(f' index=["__temp_index__"],') + self._add_code( ' index=["__temp_index__"],') self._add_code(f' columns="{pivot_input.pivot_column}",') self._add_code(f' aggregate_function="{agg_func}"') self._add_code(" )") @@ -900,7 +898,7 @@ def _handle_record_id(self, settings: input_schema.NodeRecordId, var_name: str, # Row number within groups self._add_code(f"{var_name} = ({input_df}") self._add_code(f" .with_columns(pl.lit(1).alias('{record_input.output_column_name}'))") - self._add_code(f" .with_columns([") + self._add_code( " .with_columns([") self._add_code(f" (pl.cum_count('{record_input.output_column_name}').over({record_input.group_by_columns}) + {record_input.offset} - 1)") self._add_code(f" .alias('{record_input.output_column_name}')") self._add_code("])") @@ -928,24 +926,24 @@ def _handle_cloud_storage_writer(self, settings: input_schema.NodeCloudStorageWr self.imports.add("import flowfile as ff") self._add_code(f"(ff.FlowFrame({input_df})") if output_settings.file_format == "csv": - self._add_code(f' .write_csv_to_cloud_storage(') + self._add_code( ' .write_csv_to_cloud_storage(') self._add_code(f' path="{output_settings.resource_path}",') self._add_code(f' connection_name="{output_settings.connection_name}",') self._add_code(f' delimiter="{output_settings.csv_delimiter}",') self._add_code(f' encoding="{output_settings.csv_encoding}",') self._add_code(f' description="{settings.description}"') elif output_settings.file_format == "parquet": - self._add_code(f' .write_parquet_to_cloud_storage(') + self._add_code( ' .write_parquet_to_cloud_storage(') self._add_code(f' path="{output_settings.resource_path}",') self._add_code(f' connection_name="{output_settings.connection_name}",') self._add_code(f' description="{settings.description}"') elif output_settings.file_format == "json": - self._add_code(f' .write_json_to_cloud_storage(') + self._add_code( ' .write_json_to_cloud_storage(') self._add_code(f' path="{output_settings.resource_path}",') self._add_code(f' connection_name="{output_settings.connection_name}",') self._add_code(f' description="{settings.description}"') elif output_settings.file_format == "delta": - self._add_code(f' .write_delta(') + self._add_code( ' .write_delta(') self._add_code(f' path="{output_settings.resource_path}",') self._add_code(f' write_mode="{output_settings.write_mode}",') self._add_code(f' connection_name="{output_settings.connection_name}",') @@ -1003,7 +1001,7 @@ def _handle_polars_code(self, settings: input_schema.NodePolarsCode, var_name: s is_expression = "output_df" not in code # Wrap the code in a function - self._add_code(f"# Custom Polars code") + self._add_code("# Custom Polars code") self._add_code(f"def _polars_code_{var_name.replace('df_', '')}({params}):") # Handle the code based on its structure @@ -1036,11 +1034,11 @@ def _handle_polars_code(self, settings: input_schema.NodePolarsCode, var_name: s def _add_code(self, line: str) -> None: """Add a line of code.""" - self.code_lines.append(line) + self.code_buffer.write(line + '\n') def _add_comment(self, comment: str) -> None: """Add a comment line.""" - self.code_lines.append(comment) + self.code_buffer.write(comment + '\n') def _parse_filter_expression(self, expr: str) -> str: """Parse Flowfile filter expression to Polars expression.""" @@ -1093,6 +1091,7 @@ def _create_basic_filter_expr(self, basic: transform_schema.BasicFilter) -> str: return f"pl.col('{col}').is_in({values})" return col + @lru_cache(maxsize=1024) def _get_polars_dtype(self, dtype_str: str) -> str: """Convert Flowfile dtype string to Polars dtype.""" dtype_map = { @@ -1110,6 +1109,7 @@ def _get_polars_dtype(self, dtype_str: str) -> str: } return dtype_map.get(dtype_str, 'pl.Utf8') + @lru_cache(maxsize=1024) def _get_agg_function(self, agg: str) -> str: """Get Polars aggregation function name.""" agg_map = { @@ -1181,15 +1181,7 @@ def _build_final_code(self) -> str: lines.append(' Generated from Flowfile') lines.append(' """') lines.append(" ") - - # Add the generated code - for line in self.code_lines: - if line: - lines.append(f" {line}") - else: - lines.append("") - # Add main block - lines.append("") + lines.append(self.code_buffer.getvalue()) self.add_return_code(lines) lines.append("") lines.append("")