diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py b/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py index 3616bec84..0d4609087 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py @@ -3,9 +3,23 @@ UnsupportedNodeError, export_flow_to_polars, ) +from flowfile_core.flowfile.code_generator.python_script_rewriter import ( + FlowfileUsageAnalysis, + analyze_flowfile_usage, + build_function_code, + extract_imports, + get_required_packages, + rewrite_flowfile_calls, +) __all__ = [ "FlowGraphToPolarsConverter", "UnsupportedNodeError", "export_flow_to_polars", + "FlowfileUsageAnalysis", + "analyze_flowfile_usage", + "build_function_code", + "extract_imports", + "get_required_packages", + "rewrite_flowfile_calls", ] diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py index fdbbb6843..c0b3d307c 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py @@ -52,6 +52,12 @@ def __init__(self, flow_graph: FlowGraph): self.last_node_var = None self.unsupported_nodes = [] self.custom_node_classes = {} + # Track which artifacts have been published: (kernel_id, artifact_name) → node_id + self._published_artifacts: dict[tuple[str, str], int] = {} + # Track if any python_script nodes exist (to emit _artifacts = {} once) + self._has_python_script_nodes: bool = False + # Track which kernel IDs are used (for initializing per-kernel sub-dicts) + self._kernel_ids_used: list[str] = [] def convert(self) -> str: """ @@ -1118,6 +1124,104 @@ def _handle_polars_code( self._add_code(f"{var_name} = _polars_code_{var_name.replace('df_', '')}({args})") self._add_code("") + def _handle_python_script( + self, settings: input_schema.NodePythonScript, var_name: str, input_vars: dict[str, str] + ) -> None: + """Handle python_script nodes by rewriting flowfile.* calls to plain Python.""" + from flowfile_core.flowfile.code_generator.python_script_rewriter import ( + analyze_flowfile_usage, + build_function_code, + extract_imports, + rewrite_flowfile_calls, + ) + + code = settings.python_script_input.code.strip() + kernel_id = settings.python_script_input.kernel_id + node_id = settings.node_id + + # Handle empty code — pass through input + if not code: + if input_vars: + self._add_code(f"{var_name} = {list(input_vars.values())[0]}") + else: + self._add_code(f"{var_name} = pl.LazyFrame()") + return + + # 1. Analyze flowfile usage + try: + analysis = analyze_flowfile_usage(code) + except SyntaxError as e: + self.unsupported_nodes.append(( + node_id, + "python_script", + f"Syntax error in python_script code: {e}" + )) + return + self._has_python_script_nodes = True + effective_kernel_id = kernel_id or "_default" + if effective_kernel_id not in self._kernel_ids_used: + self._kernel_ids_used.append(effective_kernel_id) + + # 2. Validate artifact dependencies are available (same kernel only) + for artifact_name in analysis.artifacts_consumed: + if (effective_kernel_id, artifact_name) not in self._published_artifacts: + self.unsupported_nodes.append(( + node_id, + "python_script", + f"Artifact '{artifact_name}' is consumed but not published by any " + f"upstream node on kernel '{effective_kernel_id}'" + )) + return + + # 4. 
Extract and register imports + user_imports = extract_imports(code) + for imp in user_imports: + self.imports.add(imp) + + # 5. Add kernel package requirements as comments + if kernel_id: + self._add_kernel_requirements(kernel_id, user_imports) + + # 6. Rewrite the code (kernel_id scopes artifact access) + rewritten, unsupported_markers = rewrite_flowfile_calls(code, analysis, kernel_id=kernel_id) + + # 7. Build and emit the function + func_def, call_code = build_function_code( + node_id, rewritten, analysis, input_vars, + kernel_id=kernel_id, unsupported_markers=unsupported_markers, + ) + + self._add_code(f"# --- Node {node_id}: python_script ---") + for line in func_def.split("\n"): + self._add_code(line) + self._add_code("") + self._add_code(call_code) + + # 8. Track published artifacts for validation of downstream nodes + for artifact_name, _ in analysis.artifacts_published: + self._published_artifacts[(effective_kernel_id, artifact_name)] = node_id + + self._add_code("") + + def _add_kernel_requirements(self, kernel_id: str, user_imports: list[str]) -> None: + """Add a comment block with required packages from kernel config.""" + try: + from flowfile_core.flowfile.code_generator.python_script_rewriter import get_required_packages + from flowfile_core.kernel.manager import get_kernel_manager + + manager = get_kernel_manager() + kernel = manager._kernels.get(kernel_id) + if not kernel or not kernel.packages: + return + + required = get_required_packages(user_imports, kernel.packages) + if required: + self._add_code(f"# Required packages: {', '.join(required)}") + self._add_code(f"# Install with: pip install {' '.join(required)}") + self._add_code("") + except Exception: + pass # Kernel manager not available; skip requirements comment + # Handlers for unsupported node types - these add nodes to the unsupported list def _handle_explore_data( @@ -1639,6 +1743,12 @@ def _build_final_code(self) -> str: lines.append(f" ETL Pipeline: {self.flow_graph.__name__}") lines.append(" Generated from Flowfile") lines.append(' """') + + # Artifact store — one sub-dict per kernel, matching runtime isolation + if self._has_python_script_nodes or self._published_artifacts: + kernel_init = ", ".join(f'"{kid}": {{}}' for kid in self._kernel_ids_used) + lines.append(f" _artifacts = {{{kernel_init}}} # Artifact store (per kernel)") + lines.append(" ") # Add the generated code diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py new file mode 100644 index 000000000..ba587dd9f --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py @@ -0,0 +1,638 @@ +""" +AST-based rewriter for python_script node code generation. + +Transforms flowfile.* API calls in user code into plain Python equivalents, +enabling code generation for python_script nodes that normally execute inside +Docker kernel containers. + +Artifacts are scoped per kernel — each kernel gets its own sub-dict inside +``_artifacts``, matching the runtime behaviour where every kernel container +has an independent artifact store. 
+ +Mapping: + flowfile.read_input() → function parameter (input_df) + flowfile.read_inputs() → function parameter (inputs) + flowfile.publish_output(expr) → return statement + flowfile.publish_artifact("n", o) → _artifacts[""]["n"] = o + flowfile.read_artifact("n") → _artifacts[""]["n"] + flowfile.delete_artifact("n") → del _artifacts[""]["n"] + flowfile.list_artifacts() → _artifacts[""] + flowfile.log(msg, level) → print(f"[{level}] {msg}") +""" + +from __future__ import annotations + +import ast +import textwrap +from dataclasses import dataclass, field +from typing import Literal + +# Maps pip package names to their Python import module names +# when they differ from the package name. +PACKAGE_TO_IMPORT_MAP: dict[str, list[str]] = { + "scikit-learn": ["sklearn"], + "pillow": ["PIL"], + "opencv-python": ["cv2"], + "opencv-python-headless": ["cv2"], + "beautifulsoup4": ["bs4"], + "pyyaml": ["yaml"], + "pytorch": ["torch"], + "tensorflow-gpu": ["tensorflow"], +} + + +def get_import_names(package: str) -> list[str]: + """Return the import name(s) for a pip package.""" + return PACKAGE_TO_IMPORT_MAP.get(package, [package.replace("-", "_")]) + + +@dataclass +class FlowfileUsageAnalysis: + """Results of analyzing flowfile.* API usage in user code.""" + + input_mode: Literal["none", "single", "multi"] = "none" + has_read_input: bool = False + has_read_inputs: bool = False + has_output: bool = False + output_exprs: list[ast.expr] = field(default_factory=list) + passthrough_output: bool = False + + artifacts_published: list[tuple[str, ast.expr]] = field(default_factory=list) + artifacts_consumed: list[str] = field(default_factory=list) + artifacts_deleted: list[str] = field(default_factory=list) + + has_logging: bool = False + has_list_artifacts: bool = False + + # For error reporting + dynamic_artifact_names: list[ast.AST] = field(default_factory=list) + unsupported_calls: list[tuple[str, ast.AST]] = field(default_factory=list) + + +def _is_flowfile_call(node: ast.Call, method: str | None = None) -> bool: + """Check if an AST Call node is a flowfile.* method call.""" + if not isinstance(node, ast.Call): + return False + func = node.func + if isinstance(func, ast.Attribute): + if isinstance(func.value, ast.Name) and func.value.id == "flowfile": + if method is None: + return True + return func.attr == method + return False + + +def _is_passthrough_output(node: ast.Call) -> bool: + """Check if publish_output argument is flowfile.read_input().""" + if not node.args: + return False + arg = node.args[0] + return _is_flowfile_call(arg, "read_input") if isinstance(arg, ast.Call) else False + + +class _FlowfileUsageVisitor(ast.NodeVisitor): + """Walk the AST to collect information about flowfile.* API usage.""" + + def __init__(self) -> None: + self.analysis = FlowfileUsageAnalysis() + + def visit_Call(self, node: ast.Call) -> None: + if _is_flowfile_call(node): + method = node.func.attr + if method == "read_input": + self.analysis.has_read_input = True + if not self.analysis.has_read_inputs: + self.analysis.input_mode = "single" + elif method == "read_inputs": + self.analysis.has_read_inputs = True + self.analysis.input_mode = "multi" + elif method == "publish_output": + self.analysis.has_output = True + if node.args: + self.analysis.output_exprs.append(node.args[0]) + if _is_passthrough_output(node): + self.analysis.passthrough_output = True + elif method == "publish_artifact": + if len(node.args) >= 2: + name_node = node.args[0] + if isinstance(name_node, ast.Constant) and isinstance(name_node.value, 
str): + self.analysis.artifacts_published.append((name_node.value, node.args[1])) + else: + self.analysis.dynamic_artifact_names.append(node) + elif method == "read_artifact": + if node.args: + name_node = node.args[0] + if isinstance(name_node, ast.Constant) and isinstance(name_node.value, str): + self.analysis.artifacts_consumed.append(name_node.value) + else: + self.analysis.dynamic_artifact_names.append(node) + elif method == "delete_artifact": + if node.args: + name_node = node.args[0] + if isinstance(name_node, ast.Constant) and isinstance(name_node.value, str): + self.analysis.artifacts_deleted.append(name_node.value) + else: + self.analysis.dynamic_artifact_names.append(node) + elif method == "log" or method in ("log_info", "log_warning", "log_error"): + self.analysis.has_logging = True + elif method == "list_artifacts": + self.analysis.has_list_artifacts = True + elif method in ("display", "publish_global", "get_global", + "list_global_artifacts", "delete_global_artifact"): + self.analysis.unsupported_calls.append((method, node)) + self.generic_visit(node) + + +def analyze_flowfile_usage(code: str) -> FlowfileUsageAnalysis: + """Parse user code and analyze flowfile.* API usage. + + Args: + code: The raw Python source code from a python_script node. + + Returns: + FlowfileUsageAnalysis with details about how the flowfile API is used. + + Raises: + SyntaxError: If the code cannot be parsed. + """ + tree = ast.parse(code) + visitor = _FlowfileUsageVisitor() + visitor.visit(tree) + return visitor.analysis + + +class _FlowfileCallRewriter(ast.NodeTransformer): + """Rewrite flowfile.* API calls to plain Python equivalents. + + Artifact operations are scoped to a kernel-specific sub-dict so that + each kernel's artifacts stay isolated, matching runtime semantics. 
+ """ + + def __init__(self, analysis: FlowfileUsageAnalysis, kernel_id: str | None = None) -> None: + self.analysis = analysis + self.kernel_id = kernel_id or "_default" + self.input_var = "input_df" if analysis.input_mode == "single" else "inputs" + self._last_output_expr: ast.expr | None = None + self._unsupported_markers: dict[str, str] = {} + # Track which publish_output call is the last one + if analysis.output_exprs: + self._last_output_expr = analysis.output_exprs[-1] + + # --- helpers for kernel-scoped artifact access --- + + def _kernel_artifacts_node(self, ctx: type[ast.expr_context] = ast.Load) -> ast.Subscript: + """Build ``_artifacts[""]`` AST node.""" + return ast.Subscript( + value=ast.Name(id="_artifacts", ctx=ast.Load()), + slice=ast.Constant(value=self.kernel_id), + ctx=ctx(), + ) + + def _artifact_subscript(self, name_node: ast.expr, ctx: type[ast.expr_context] = ast.Load) -> ast.Subscript: + """Build ``_artifacts[""][""]`` AST node.""" + return ast.Subscript( + value=self._kernel_artifacts_node(), + slice=name_node, + ctx=ctx(), + ) + + # --- visitors --- + + def visit_Call(self, node: ast.Call) -> ast.AST: + # First transform any nested calls + self.generic_visit(node) + + if not _is_flowfile_call(node): + return node + + method = node.func.attr + + if method == "read_input": + if self.analysis.input_mode == "multi": + # Both read_input and read_inputs used — read_input() → inputs["main"][0] + return ast.Subscript( + value=ast.Subscript( + value=ast.Name(id="inputs", ctx=ast.Load()), + slice=ast.Constant(value="main"), + ctx=ast.Load(), + ), + slice=ast.Constant(value=0), + ctx=ast.Load(), + ) + # flowfile.read_input() → input_df + return ast.Name(id=self.input_var, ctx=ast.Load()) + + if method == "read_inputs": + # flowfile.read_inputs() → inputs + return ast.Name(id=self.input_var, ctx=ast.Load()) + + if method == "read_artifact": + # flowfile.read_artifact("name") → _artifacts["kernel_id"]["name"] + return self._artifact_subscript(node.args[0]) + + if method == "list_artifacts": + # flowfile.list_artifacts() → _artifacts["kernel_id"] + return self._kernel_artifacts_node() + + if method == "log": + return self._make_log_print(node) + + if method == "log_info": + return self._make_log_print_with_level(node, "INFO") + + if method == "log_warning": + return self._make_log_print_with_level(node, "WARNING") + + if method == "log_error": + return self._make_log_print_with_level(node, "ERROR") + + return node + + def visit_Expr(self, node: ast.Expr) -> ast.AST | None: + # Transform nested calls first + self.generic_visit(node) + + if not isinstance(node.value, ast.Call): + return node + + call = node.value + if not _is_flowfile_call(call): + return node + + method = call.func.attr + + if method == "publish_output": + # Remove publish_output statements — we handle via return + return None + + if method == "publish_artifact": + # flowfile.publish_artifact("name", obj) → _artifacts["kernel_id"]["name"] = obj + if len(call.args) >= 2: + return ast.Assign( + targets=[self._artifact_subscript(call.args[0], ctx=ast.Store)], + value=call.args[1], + lineno=node.lineno, + col_offset=node.col_offset, + ) + return node + + if method == "delete_artifact": + # flowfile.delete_artifact("name") → del _artifacts["kernel_id"]["name"] + if call.args: + return ast.Delete( + targets=[self._artifact_subscript(call.args[0], ctx=ast.Del)], + lineno=node.lineno, + col_offset=node.col_offset, + ) + return node + + # Unsupported calls → replace with a marker that becomes a comment + if 
_is_flowfile_call(call): + source = ast.unparse(node.value) + marker = f"__FLOWFILE_UNSUPPORTED_{len(self._unsupported_markers)}__" + self._unsupported_markers[marker] = source + return ast.Expr(value=ast.Name(id=marker, ctx=ast.Load())) + + return node + + @staticmethod + def _make_log_print(node: ast.Call) -> ast.Call: + """Transform flowfile.log("msg", "LEVEL") → print("[LEVEL] msg").""" + msg_arg = node.args[0] if node.args else ast.Constant(value="") + + # Get level from second arg or keyword + level: ast.expr | None = None + if len(node.args) >= 2: + level = node.args[1] + else: + for kw in node.keywords: + if kw.arg == "level": + level = kw.value + break + + if level is None: + level = ast.Constant(value="INFO") + + # Build print(f"[{level}] {msg}") + format_str = ast.JoinedStr( + values=[ + ast.Constant(value="["), + ast.FormattedValue(value=level, conversion=-1), + ast.Constant(value="] "), + ast.FormattedValue(value=msg_arg, conversion=-1), + ] + ) + + return ast.Call( + func=ast.Name(id="print", ctx=ast.Load()), + args=[format_str], + keywords=[], + ) + + @staticmethod + def _make_log_print_with_level(node: ast.Call, level_str: str) -> ast.Call: + """Transform flowfile.log_info("msg") → print("[INFO] msg").""" + msg_arg = node.args[0] if node.args else ast.Constant(value="") + + format_str = ast.JoinedStr( + values=[ + ast.Constant(value=f"[{level_str}] "), + ast.FormattedValue(value=msg_arg, conversion=-1), + ] + ) + + return ast.Call( + func=ast.Name(id="print", ctx=ast.Load()), + args=[format_str], + keywords=[], + ) + + +def rewrite_flowfile_calls( + code: str, + analysis: FlowfileUsageAnalysis, + kernel_id: str | None = None, +) -> tuple[str, dict[str, str]]: + """Rewrite flowfile.* API calls in user code to plain Python. + + This removes/replaces flowfile API calls but does NOT add function + wrapping, return statements, or import stripping. Those are handled + by ``build_function_code``. + + Args: + code: The raw Python source from a python_script node. + analysis: Pre-computed analysis of flowfile usage. + kernel_id: The kernel ID for scoping artifact operations. + + Returns: + Tuple of (rewritten source code, unsupported call markers dict). + Markers are placeholder variable names mapped to the original source + of unsupported flowfile calls. Callers should replace them with + comments after all AST processing is complete. + """ + tree = ast.parse(code) + rewriter = _FlowfileCallRewriter(analysis, kernel_id=kernel_id) + new_tree = rewriter.visit(tree) + # Remove None nodes (deleted statements) + new_tree.body = [node for node in new_tree.body if node is not None] + ast.fix_missing_locations(new_tree) + result = ast.unparse(new_tree) + + return result, rewriter._unsupported_markers + + +def extract_imports(code: str) -> list[str]: + """Extract import statements from user code, excluding flowfile imports. + + Args: + code: The raw Python source code. + + Returns: + List of import statement strings (each is a full import line). 
+ """ + tree = ast.parse(code) + imports: list[str] = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Import): + # Filter out "import flowfile" + non_flowfile_aliases = [alias for alias in node.names if alias.name != "flowfile"] + if non_flowfile_aliases: + # Reconstruct import with only non-flowfile names + filtered = ast.Import(names=non_flowfile_aliases) + ast.fix_missing_locations(filtered) + imports.append(ast.unparse(filtered)) + elif isinstance(node, ast.ImportFrom): + if node.module and "flowfile" not in node.module: + imports.append(ast.unparse(node)) + elif node.module is None: + imports.append(ast.unparse(node)) + return imports + + +def _strip_imports_and_flowfile(code: str) -> str: + """Remove import statements and flowfile import from code body. + + Returns the code with all top-level import/from-import statements removed. + """ + tree = ast.parse(code) + new_body = [] + for node in tree.body: + if isinstance(node, ast.Import): + # Keep non-flowfile imports? No — imports are extracted separately + continue + elif isinstance(node, ast.ImportFrom): + continue + else: + new_body.append(node) + tree.body = new_body + if not tree.body: + return "" + ast.fix_missing_locations(tree) + return ast.unparse(tree) + + +def build_function_code( + node_id: int, + rewritten_code: str, + analysis: FlowfileUsageAnalysis, + input_vars: dict[str, str], + kernel_id: str | None = None, + unsupported_markers: dict[str, str] | None = None, +) -> tuple[str, str]: + """Assemble rewritten code into a function definition and call. + + Args: + node_id: The node ID for naming. + rewritten_code: The AST-rewritten code (from rewrite_flowfile_calls, + with imports already stripped). + analysis: The flowfile usage analysis. + input_vars: Mapping of input names to variable names from upstream nodes. + kernel_id: The kernel ID (used to scope return expressions). + unsupported_markers: Marker→source mapping from rewrite_flowfile_calls. + These are replaced with comments in the final output. + + Returns: + Tuple of (function_definition, call_code). + E.g.: + ("def _node_5(input_df: pl.LazyFrame) -> pl.LazyFrame:\\n ...", + "df_5 = _node_5(df_3)") + """ + func_name = f"_node_{node_id}" + var_name = f"df_{node_id}" + + # Build parameter list and arguments + params: list[str] = [] + args: list[str] = [] + + if analysis.input_mode == "single": + params.append("input_df: pl.LazyFrame") + main_var = input_vars.get("main") + if main_var is None: + # Multiple main inputs — pick first + for k in sorted(input_vars.keys()): + if k.startswith("main"): + main_var = input_vars[k] + break + args.append(main_var or "pl.LazyFrame()") + elif analysis.input_mode == "multi": + # Runtime returns dict[str, list[pl.LazyFrame]] — each input name + # maps to a *list* of LazyFrames (multiple connections can share a name). + params.append("inputs: dict[str, list[pl.LazyFrame]]") + # Group input_vars by their base name (strip _0, _1 suffixes). 
+ grouped = _group_input_vars(input_vars) + dict_entries = ", ".join( + f'"{k}": [{", ".join(vs)}]' for k, vs in sorted(grouped.items()) + ) + args.append("{" + dict_entries + "}") + + param_str = ", ".join(params) + return_type = "pl.LazyFrame" if params else "pl.LazyFrame | None" + + # Build function body + body_lines: list[str] = [] + + # Add warnings for unsupported calls / dynamic artifact names + if analysis.unsupported_calls: + methods = sorted({m for m, _ in analysis.unsupported_calls}) + body_lines.append(f"# WARNING: The following flowfile API calls are not supported in code generation") + body_lines.append(f"# and will not work outside the kernel runtime: {', '.join(methods)}") + if analysis.dynamic_artifact_names: + body_lines.append("# WARNING: Dynamic artifact names detected — these may not resolve correctly") + + # Strip imports from rewritten code (they go to top-level) + body_code = _strip_imports_and_flowfile(rewritten_code) + + if body_code: + for line in body_code.split("\n"): + body_lines.append(line) + + # Add return statement + if analysis.has_output and analysis.output_exprs: + last_expr = analysis.output_exprs[-1] + if analysis.passthrough_output and analysis.input_mode == "single": + # publish_output(read_input()) → return input_df + body_lines.append("return input_df") + else: + output_return = _build_return_for_output(last_expr, analysis, kernel_id=kernel_id) + body_lines.append(output_return) + elif analysis.input_mode == "single": + # No explicit output — pass through input + body_lines.append("return input_df") + elif analysis.input_mode == "multi": + # Pass through first input list + first_key = sorted(input_vars.keys())[0] if input_vars else "main" + # Strip _0 suffix to get the base name + base_key = _base_input_name(first_key) + body_lines.append(f'return inputs["{base_key}"][0]') + elif not params: + body_lines.append("return None") + + if not body_lines: + body_lines.append("pass") + + # Assemble function definition + indented_body = textwrap.indent("\n".join(body_lines), " ") + func_def = f"def {func_name}({param_str}) -> {return_type}:\n{indented_body}" + + # Replace unsupported-call markers with comments + if unsupported_markers: + for marker, source in unsupported_markers.items(): + func_def = func_def.replace(marker, f"# {source} # not supported outside kernel runtime") + + # Build call + arg_str = ", ".join(args) + call_code = f"{var_name} = {func_name}({arg_str})" + + return func_def, call_code + + +def _base_input_name(key: str) -> str: + """Strip numeric suffix from input var keys: 'main_0' → 'main'.""" + parts = key.rsplit("_", 1) + if len(parts) == 2 and parts[1].isdigit(): + return parts[0] + return key + + +def _group_input_vars(input_vars: dict[str, str]) -> dict[str, list[str]]: + """Group input variable names by their base name. + + E.g. {"main_0": "df_1", "main_1": "df_3"} → {"main": ["df_1", "df_3"]} + {"main": "df_1"} → {"main": ["df_1"]} + """ + grouped: dict[str, list[str]] = {} + for key, var in sorted(input_vars.items()): + base = _base_input_name(key) + grouped.setdefault(base, []).append(var) + return grouped + + +def _build_return_for_output( + output_expr: ast.expr, + analysis: FlowfileUsageAnalysis, + kernel_id: str | None = None, +) -> str: + """Build a return statement from a publish_output expression. + + The expression is the original AST node from publish_output(expr). + We need to rewrite any flowfile calls in it and then produce the return. 
+ """ + # Create a temporary module to transform the expression + temp_code = ast.unparse(output_expr) + + # Check if it's just a variable name — common pattern like publish_output(result) + # In that case, ensure .lazy() is called for DataFrame returns + # We add .lazy() wrapper as a safety measure for DataFrames + rewriter = _FlowfileCallRewriter(analysis, kernel_id=kernel_id) + expr_tree = ast.parse(temp_code, mode="eval") + new_expr = rewriter.visit(expr_tree) + ast.fix_missing_locations(new_expr) + rewritten = ast.unparse(new_expr) + + # Check if the expression already has .lazy() call + if rewritten.endswith(".lazy()"): + return f"return {rewritten}" + + # If the expression is just a variable that likely holds a DataFrame, + # wrap with .lazy() to ensure LazyFrame return + return f"return {rewritten}.lazy()" + + +def get_required_packages( + user_imports: list[str], + kernel_packages: list[str], +) -> list[str]: + """Cross-reference user imports with kernel packages. + + Args: + user_imports: Import statement strings from user code. + kernel_packages: Package names from kernel configuration. + + Returns: + Sorted list of kernel packages that are actually used. + """ + # Build reverse map: import_name → package_name + import_to_package: dict[str, str] = {} + for pkg in kernel_packages: + for imp_name in get_import_names(pkg): + import_to_package[imp_name] = pkg + + # Parse user imports to get root module names + used_packages: set[str] = set() + for imp_str in user_imports: + try: + tree = ast.parse(imp_str) + except SyntaxError: + continue + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + root_module = alias.name.split(".")[0] + if root_module in import_to_package: + used_packages.add(import_to_package[root_module]) + elif isinstance(node, ast.ImportFrom) and node.module: + root_module = node.module.split(".")[0] + if root_module in import_to_package: + used_packages.add(import_to_package[root_module]) + + return sorted(used_packages) diff --git a/flowfile_core/tests/flowfile/test_code_generator_python_script.py b/flowfile_core/tests/flowfile/test_code_generator_python_script.py new file mode 100644 index 000000000..93cc84d92 --- /dev/null +++ b/flowfile_core/tests/flowfile/test_code_generator_python_script.py @@ -0,0 +1,725 @@ +""" +Integration tests for python_script node code generation. + +These tests verify that the FlowGraphToPolarsConverter correctly handles +python_script nodes by building FlowGraphs with python_script nodes +and checking the generated code. 
+""" + +import polars as pl +import pytest + +from flowfile_core.flowfile.code_generator.code_generator import ( + FlowGraphToPolarsConverter, + UnsupportedNodeError, + export_flow_to_polars, +) +from flowfile_core.flowfile.flow_graph import FlowGraph, add_connection +from flowfile_core.schemas import input_schema, schemas, transform_schema + + +def create_flow_settings(flow_id: int = 1) -> schemas.FlowSettings: + """Create basic flow settings for tests.""" + return schemas.FlowSettings( + flow_id=flow_id, + execution_mode="Performance", + execution_location="local", + path="/tmp/test_flow", + ) + + +def create_basic_flow(flow_id: int = 1, name: str = "test_flow") -> FlowGraph: + """Create a basic flow graph for testing.""" + return FlowGraph(flow_settings=create_flow_settings(flow_id), name=name) + + +def verify_code_executes(code: str): + """Execute generated code and verify no exceptions are raised.""" + exec_globals = {} + try: + exec(code, exec_globals) + _ = exec_globals["run_etl_pipeline"]() + except Exception as e: + raise AssertionError(f"Code execution failed:\n{e}\n\nGenerated code:\n{code}") + + +def get_result_from_generated_code(code: str): + """Execute generated code and return the result.""" + exec_globals = {} + exec(code, exec_globals) + return exec_globals["run_etl_pipeline"]() + + +def add_manual_input_node(flow: FlowGraph, node_id: int = 1) -> None: + """Add a manual input node with sample data.""" + raw_data = input_schema.RawData( + columns=[ + input_schema.MinimalFieldInfo(name="id", data_type="Int64"), + input_schema.MinimalFieldInfo(name="name", data_type="String"), + input_schema.MinimalFieldInfo(name="value", data_type="Float64"), + ], + data=[[1, 2, 3], ["Alice", "Bob", "Charlie"], [10.0, 20.0, 30.0]], + ) + settings = input_schema.NodeManualInput( + flow_id=1, + node_id=node_id, + raw_data_format=raw_data, + ) + flow.add_manual_input(settings) + + +def add_python_script_node( + flow: FlowGraph, + node_id: int, + code: str, + depending_on_ids: list[int] | None = None, + kernel_id: str | None = None, +) -> None: + """Add a python_script node to the flow.""" + settings = input_schema.NodePythonScript( + flow_id=1, + node_id=node_id, + depending_on_ids=depending_on_ids or [], + python_script_input=input_schema.PythonScriptInput( + code=code, + kernel_id=kernel_id, + ), + ) + flow.add_python_script(settings) + + +def connect_nodes(flow: FlowGraph, from_id: int, to_id: int, input_type: str = "main") -> None: + """Connect two nodes.""" + connection = input_schema.NodeConnection.create_from_simple_input(from_id, to_id, input_type) + add_connection(flow, connection) + + +# --------------------------------------------------------------------------- +# Basic python_script code generation tests +# --------------------------------------------------------------------------- + + +class TestSimplePythonScriptGeneration: + """Test basic python_script node code generation.""" + + def test_simple_passthrough(self): + """Python script that reads input and publishes it unchanged.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "df = flowfile.read_input()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "flowfile" not in generated + assert "def _node_2" in generated + assert "input_df" in generated + assert "_node_2(" in generated + + verify_code_executes(generated) + + def 
test_transform_with_collect(self): + """Python script that collects, transforms, and publishes output.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.col('value') * 2)\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "flowfile" not in generated + assert "input_df.collect()" in generated + assert "def _node_2" in generated + + verify_code_executes(generated) + + def test_empty_code_passthrough(self): + """Empty python_script code should pass through.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + add_python_script_node(flow, node_id=2, code="", depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + verify_code_executes(generated) + + def test_no_output_passthrough(self): + """Script without publish_output should pass through input.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "df = flowfile.read_input().collect()\nx = len(df)\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "return input_df" in generated + verify_code_executes(generated) + + def test_passthrough_output_pattern(self): + """publish_output(read_input()) should generate return input_df.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "df = flowfile.read_input().collect()\n" + "x = len(df)\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "return input_df" in generated + verify_code_executes(generated) + + +# --------------------------------------------------------------------------- +# Artifact tests +# --------------------------------------------------------------------------- + + +class TestArtifactCodeGeneration: + """Test artifact publish/consume code generation.""" + + def test_artifact_publish(self): + """publish_artifact becomes _artifacts[kernel_id] assignment.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "df = flowfile.read_input().collect()\n" + "model = {'trained': True}\n" + 'flowfile.publish_artifact("my_model", model)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "_artifacts" in generated + assert "my_model" in generated + assert "k1" in generated + assert "flowfile" not in generated + + verify_code_executes(generated) + + def test_artifact_chain_same_kernel(self): + """Artifacts flow correctly between two python_script nodes on same kernel.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + # Producer node + producer_code = ( + "info = {'count': 42}\n" + 'flowfile.publish_artifact("info", info)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=producer_code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + # Consumer node — same kernel + consumer_code = ( + 'info = 
flowfile.read_artifact("info")\n' + "df = flowfile.read_input().collect()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=3, code=consumer_code, depending_on_ids=[2], kernel_id="k1") + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + + assert "_artifacts" in generated + assert "info" in generated + assert "flowfile" not in generated + + verify_code_executes(generated) + + def test_artifact_cross_kernel_error(self): + """Consuming an artifact from a different kernel should fail.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + # Producer on kernel k1 + producer_code = ( + 'flowfile.publish_artifact("model", {"x": 1})\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=producer_code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + # Consumer on kernel k2 — different kernel, should fail + consumer_code = ( + 'model = flowfile.read_artifact("model")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=3, code=consumer_code, depending_on_ids=[2], kernel_id="k2") + connect_nodes(flow, 2, 3) + + with pytest.raises(UnsupportedNodeError, match="model"): + export_flow_to_polars(flow) + + def test_artifact_delete(self): + """delete_artifact becomes del _artifacts[kernel][...].""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "obj = {'x': 1}\n" + 'flowfile.publish_artifact("temp", obj)\n' + 'flowfile.delete_artifact("temp")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "del _artifacts" in generated + assert "temp" in generated + assert "k1" in generated + verify_code_executes(generated) + + def test_unconsumed_artifact_error(self): + """Consuming an artifact not published upstream should fail.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + 'model = flowfile.read_artifact("missing_model")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + with pytest.raises(UnsupportedNodeError, match="missing_model"): + export_flow_to_polars(flow) + + +# --------------------------------------------------------------------------- +# Logging tests +# --------------------------------------------------------------------------- + + +class TestLoggingCodeGeneration: + """Test that flowfile.log becomes print.""" + + def test_log_becomes_print(self): + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + 'flowfile.log("processing data")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "print" in generated + assert "flowfile" not in generated + verify_code_executes(generated) + + +# --------------------------------------------------------------------------- +# Error handling tests +# --------------------------------------------------------------------------- + + +class TestErrorHandling: + """Test error cases in python_script code generation.""" + + def test_dynamic_artifact_name(self): + """Dynamic artifact names should produce a warning comment, not 
an error.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "name = 'model'\n" + "flowfile.read_artifact(name)\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + result = export_flow_to_polars(flow) + assert "WARNING" in result + assert "Dynamic artifact names" in result + + def test_syntax_error_in_code(self): + """Syntax errors should produce UnsupportedNodeError.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "def foo(:\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + with pytest.raises(UnsupportedNodeError, match="Syntax error"): + export_flow_to_polars(flow) + + def test_unsupported_display_call(self): + """flowfile.display should produce a warning comment, not an error.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "flowfile.display('hello')\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + result = export_flow_to_polars(flow) + assert "WARNING" in result + assert "not supported in code generation" in result + assert "display" in result + + +# --------------------------------------------------------------------------- +# Import handling tests +# --------------------------------------------------------------------------- + + +class TestImportHandling: + """Test that imports from python_script nodes are handled correctly.""" + + def test_user_imports_added(self): + """User imports should appear in generated code.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import json\n" + "df = flowfile.read_input().collect()\n" + "data = json.dumps({'count': len(df)})\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "import json" in generated + assert "flowfile" not in generated + + def test_flowfile_import_excluded(self): + """import flowfile should not appear in generated code.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import flowfile\n" + "import json\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + # "import flowfile" should not be present; "import json" should be + lines = generated.split("\n") + assert not any(line.strip() == "import flowfile" for line in lines) + assert "import json" in generated + + +# --------------------------------------------------------------------------- +# Mixed node type tests +# --------------------------------------------------------------------------- + + +class TestMixedNodeTypes: + """Test flows mixing python_script with other node types.""" + + def test_manual_input_then_python_script(self): + """Manual input → python_script → output.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.lit('new').alias('new_col'))\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=2, code=code, 
depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + verify_code_executes(generated) + + result = get_result_from_generated_code(generated) + if hasattr(result, "collect"): + result = result.collect() + assert "new_col" in result.columns + + def test_python_script_then_filter(self): + """python_script → filter node.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + # Add filter node + filter_settings = input_schema.NodeFilter( + flow_id=1, + node_id=3, + depending_on_id=2, + filter_input=transform_schema.FilterInput( + mode="basic", + basic_filter=transform_schema.BasicFilter( + field="value", + operator=transform_schema.FilterOperator.GREATER_THAN, + value="15", + ), + ), + ) + flow.add_filter(filter_settings) + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + verify_code_executes(generated) + + def test_multiple_python_script_nodes(self): + """Chain of multiple python_script nodes.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code1 = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.col('value').alias('doubled'))\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=2, code=code1, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + code2 = ( + "df = flowfile.read_input().collect()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=3, code=code2, depending_on_ids=[2]) + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + + assert "def _node_2" in generated + assert "def _node_3" in generated + verify_code_executes(generated) + + +# --------------------------------------------------------------------------- +# Artifacts store initialization test +# --------------------------------------------------------------------------- + + +class TestArtifactStoreInitialization: + """Test that _artifacts is emitted properly with per-kernel sub-dicts.""" + + def test_artifacts_dict_emitted_for_python_script(self): + """_artifacts should appear with kernel sub-dict when python_script nodes exist.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "flowfile.publish_output(flowfile.read_input())\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + assert '_artifacts = {"k1": {}}' in generated + + def test_multiple_kernels_initialized(self): + """Multiple kernels should each get their own sub-dict.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "flowfile.publish_output(flowfile.read_input())\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + add_python_script_node(flow, node_id=3, code=code, depending_on_ids=[2], kernel_id="k2") + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + assert "k1" in generated + assert "k2" in generated + # Both should be initialized as empty dicts + assert "_artifacts" in generated + verify_code_executes(generated) + + def test_no_artifacts_dict_without_python_script(self): + """_artifacts should NOT appear when no python_script nodes exist.""" + flow = create_basic_flow() + 
add_manual_input_node(flow, node_id=1) + + # Add a filter node (not python_script) + filter_settings = input_schema.NodeFilter( + flow_id=1, + node_id=2, + depending_on_id=1, + filter_input=transform_schema.FilterInput( + mode="basic", + basic_filter=transform_schema.BasicFilter( + field="value", + operator=transform_schema.FilterOperator.GREATER_THAN, + value="15", + ), + ), + ) + flow.add_filter(filter_settings) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + assert "_artifacts" not in generated + + +# --------------------------------------------------------------------------- +# Full pipeline test matching the spec's appendix example +# --------------------------------------------------------------------------- + + +class TestFullPipelineExample: + """Test the complete example from the specification.""" + + def test_train_predict_pipeline(self): + """Simulate train → predict pipeline with artifacts on same kernel.""" + flow = create_basic_flow() + + # Node 1: Manual input with training data + raw_data = input_schema.RawData( + columns=[ + input_schema.MinimalFieldInfo(name="f1", data_type="Float64"), + input_schema.MinimalFieldInfo(name="f2", data_type="Float64"), + input_schema.MinimalFieldInfo(name="target", data_type="Int64"), + ], + data=[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [0, 1, 0, 1]], + ) + settings = input_schema.NodeManualInput( + flow_id=1, node_id=1, raw_data_format=raw_data + ) + flow.add_manual_input(settings) + + # Node 2: Train model + train_code = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "model = {'trained': True, 'n_features': 2}\n" + 'flowfile.publish_artifact("model", model)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=train_code, depending_on_ids=[1], kernel_id="ml") + connect_nodes(flow, 1, 2) + + # Node 3: Use model (consume artifact — same kernel) + predict_code = ( + "import polars as pl\n" + 'model = flowfile.read_artifact("model")\n' + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.lit(model['n_features']).alias('n_features'))\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=3, code=predict_code, depending_on_ids=[2], kernel_id="ml") + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + + # Verify structure — kernel-scoped artifacts + assert "def _node_2" in generated + assert "def _node_3" in generated + assert "_artifacts" in generated + assert "ml" in generated + assert "model" in generated + assert "flowfile" not in generated + + # Verify execution + verify_code_executes(generated) + + result = get_result_from_generated_code(generated) + if hasattr(result, "collect"): + result = result.collect() + assert "n_features" in result.columns + + def test_list_artifacts_usage(self): + """Test that list_artifacts becomes _artifacts[kernel_id] reference.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "obj = {'x': 1}\n" + 'flowfile.publish_artifact("item", obj)\n' + "arts = flowfile.list_artifacts()\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "flowfile" not in generated + assert "k1" in generated + verify_code_executes(generated) + + def test_node_comment_header(self): + """Generated code should include node header comments.""" + flow = 
create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "flowfile.publish_output(flowfile.read_input())\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "# --- Node 2: python_script ---" in generated + + +# --------------------------------------------------------------------------- +# Multi-input (read_inputs) tests +# --------------------------------------------------------------------------- + + +class TestMultiInputCodeGeneration: + """Test python_script nodes with multiple inputs (read_inputs).""" + + def test_read_inputs_dict_structure(self): + """read_inputs should produce dict[str, list[pl.LazyFrame]] matching runtime.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + add_manual_input_node(flow, node_id=3) + + code = ( + "dfs = flowfile.read_inputs()\n" + "df = dfs['main'][0].collect()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1, 3]) + connect_nodes(flow, 1, 2) + connect_nodes(flow, 3, 2) + + generated = export_flow_to_polars(flow) + + assert "dict[str, list[pl.LazyFrame]]" in generated + assert "flowfile" not in generated + verify_code_executes(generated) diff --git a/flowfile_core/tests/flowfile/test_python_script_rewriter.py b/flowfile_core/tests/flowfile/test_python_script_rewriter.py new file mode 100644 index 000000000..d1a03146c --- /dev/null +++ b/flowfile_core/tests/flowfile/test_python_script_rewriter.py @@ -0,0 +1,459 @@ +""" +Unit tests for python_script_rewriter.py — the AST rewriting engine that +transforms flowfile.* API calls into plain Python equivalents. +""" + +import ast + +import pytest + +from flowfile_core.flowfile.code_generator.python_script_rewriter import ( + FlowfileUsageAnalysis, + analyze_flowfile_usage, + build_function_code, + extract_imports, + get_import_names, + get_required_packages, + rewrite_flowfile_calls, +) + + +# --------------------------------------------------------------------------- +# Tests for analyze_flowfile_usage +# --------------------------------------------------------------------------- + + +class TestAnalyzeFlowfileUsage: + def test_single_input(self): + code = "df = flowfile.read_input()" + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "single" + + def test_multi_input(self): + code = "inputs = flowfile.read_inputs()" + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "multi" + + def test_no_input(self): + code = "x = 1 + 2" + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "none" + + def test_publish_output(self): + code = "flowfile.publish_output(df)" + analysis = analyze_flowfile_usage(code) + assert analysis.has_output is True + assert len(analysis.output_exprs) == 1 + + def test_no_output(self): + code = "df = flowfile.read_input()" + analysis = analyze_flowfile_usage(code) + assert analysis.has_output is False + + def test_passthrough_output(self): + code = "flowfile.publish_output(flowfile.read_input())" + analysis = analyze_flowfile_usage(code) + assert analysis.passthrough_output is True + + def test_non_passthrough_output(self): + code = "flowfile.publish_output(result)" + analysis = analyze_flowfile_usage(code) + assert analysis.passthrough_output is False + + def test_artifact_publish(self): + code = 'flowfile.publish_artifact("model", clf)' + analysis = analyze_flowfile_usage(code) + assert 
len(analysis.artifacts_published) == 1 + assert analysis.artifacts_published[0][0] == "model" + + def test_artifact_consume(self): + code = 'model = flowfile.read_artifact("model")' + analysis = analyze_flowfile_usage(code) + assert analysis.artifacts_consumed == ["model"] + + def test_artifact_delete(self): + code = 'flowfile.delete_artifact("model")' + analysis = analyze_flowfile_usage(code) + assert analysis.artifacts_deleted == ["model"] + + def test_dynamic_artifact_name_detected(self): + code = "flowfile.read_artifact(name_var)" + analysis = analyze_flowfile_usage(code) + assert len(analysis.dynamic_artifact_names) == 1 + + def test_dynamic_publish_artifact_name_detected(self): + code = "flowfile.publish_artifact(name_var, obj)" + analysis = analyze_flowfile_usage(code) + assert len(analysis.dynamic_artifact_names) == 1 + + def test_dynamic_delete_artifact_name_detected(self): + code = "flowfile.delete_artifact(name_var)" + analysis = analyze_flowfile_usage(code) + assert len(analysis.dynamic_artifact_names) == 1 + + def test_logging(self): + code = 'flowfile.log("hello")' + analysis = analyze_flowfile_usage(code) + assert analysis.has_logging is True + + def test_log_with_level(self): + code = 'flowfile.log("hello", "ERROR")' + analysis = analyze_flowfile_usage(code) + assert analysis.has_logging is True + + def test_log_info(self): + code = 'flowfile.log_info("hello")' + analysis = analyze_flowfile_usage(code) + assert analysis.has_logging is True + + def test_list_artifacts(self): + code = "arts = flowfile.list_artifacts()" + analysis = analyze_flowfile_usage(code) + assert analysis.has_list_artifacts is True + + def test_unsupported_display_call(self): + code = "flowfile.display(fig)" + analysis = analyze_flowfile_usage(code) + assert len(analysis.unsupported_calls) == 1 + assert analysis.unsupported_calls[0][0] == "display" + + def test_multiple_artifacts(self): + code = ( + 'flowfile.publish_artifact("model", clf)\n' + 'flowfile.publish_artifact("scaler", sc)\n' + 'x = flowfile.read_artifact("model")\n' + ) + analysis = analyze_flowfile_usage(code) + assert len(analysis.artifacts_published) == 2 + assert analysis.artifacts_consumed == ["model"] + + def test_syntax_error_raises(self): + code = "def foo(:" + with pytest.raises(SyntaxError): + analyze_flowfile_usage(code) + + def test_complete_script(self): + code = ( + "import polars as pl\n" + "from sklearn.ensemble import RandomForestClassifier\n" + "\n" + "df = flowfile.read_input().collect()\n" + "X = df.select(['f1', 'f2']).to_numpy()\n" + "y = df.get_column('target').to_numpy()\n" + "\n" + "model = RandomForestClassifier()\n" + "model.fit(X, y)\n" + "\n" + 'flowfile.publish_artifact("rf_model", model)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "single" + assert analysis.has_output is True + assert analysis.passthrough_output is True + assert len(analysis.artifacts_published) == 1 + assert analysis.artifacts_published[0][0] == "rf_model" + + +# --------------------------------------------------------------------------- +# Tests for rewrite_flowfile_calls +# --------------------------------------------------------------------------- + + +class TestRewriteFlowfileCalls: + def test_read_input_replaced(self): + code = "df = flowfile.read_input()" + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "input_df" in result + + def 
test_read_input_with_collect(self): + code = "df = flowfile.read_input().collect()" + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "input_df.collect()" in result + + def test_read_inputs_replaced(self): + code = 'dfs = flowfile.read_inputs()\ndf = dfs["main"]' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "inputs" in result + + def test_publish_output_removed(self): + code = "x = 1\nflowfile.publish_output(df)\ny = 2" + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "publish_output" not in result + assert "x = 1" in result + assert "y = 2" in result + + def test_publish_artifact_becomes_assignment(self): + code = 'flowfile.publish_artifact("model", clf)' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + assert "flowfile" not in result + assert "_artifacts" in result + assert "k1" in result + assert "model" in result + + def test_read_artifact_becomes_subscript(self): + code = 'model = flowfile.read_artifact("model")' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + assert "flowfile" not in result + assert "_artifacts" in result + assert "k1" in result + assert "model" in result + + def test_delete_artifact_becomes_del(self): + code = 'flowfile.delete_artifact("model")' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + assert "flowfile" not in result + assert "del _artifacts" in result + assert "k1" in result + + def test_list_artifacts_becomes_kernel_dict(self): + code = "arts = flowfile.list_artifacts()" + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + assert "flowfile" not in result + assert "_artifacts" in result + assert "k1" in result + + def test_default_kernel_id_when_none(self): + code = 'flowfile.publish_artifact("model", clf)' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id=None) + assert "_default" in result + + def test_artifacts_scoped_to_kernel(self): + """Verify artifact access includes the kernel_id key.""" + code = 'model = flowfile.read_artifact("model")' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="my_kernel") + assert "my_kernel" in result + assert "model" in result + + def test_log_becomes_print(self): + code = 'flowfile.log("hello")' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "print" in result + assert "flowfile" not in result + + def test_log_with_level_becomes_print(self): + code = 'flowfile.log("hello", "ERROR")' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "print" in result + + def test_log_info_becomes_print(self): + code = 'flowfile.log_info("processing")' + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "print" in result + assert "INFO" in result + + def test_non_flowfile_code_unchanged(self): + code = "x = 1 + 2\ny = x * 3" + analysis = analyze_flowfile_usage(code) + result, _ = rewrite_flowfile_calls(code, analysis) + assert "x = 1 + 2" in result + assert "y = x * 3" in result + + def 
+    def test_chained_collect(self):
+        code = "df = flowfile.read_input().collect()\nresult = df.select(['a'])"
+        analysis = analyze_flowfile_usage(code)
+        result, _ = rewrite_flowfile_calls(code, analysis)
+        assert "input_df.collect()" in result
+        assert "result = df.select" in result
+
+    def test_unsupported_call_returns_marker(self):
+        """Unsupported calls should produce markers for comment replacement."""
+        code = "flowfile.publish_global('model', obj)\nx = 1"
+        analysis = analyze_flowfile_usage(code)
+        result, markers = rewrite_flowfile_calls(code, analysis)
+        assert len(markers) == 1
+        marker = list(markers.keys())[0]
+        assert marker in result
+        source = list(markers.values())[0]
+        assert "publish_global" in source
+
+
+# ---------------------------------------------------------------------------
+# Tests for extract_imports
+# ---------------------------------------------------------------------------
+
+
+class TestExtractImports:
+    def test_standard_import(self):
+        code = "import numpy as np\nx = 1"
+        result = extract_imports(code)
+        assert "import numpy as np" in result
+
+    def test_from_import(self):
+        code = "from sklearn.ensemble import RandomForestClassifier"
+        result = extract_imports(code)
+        assert len(result) == 1
+        assert "RandomForestClassifier" in result[0]
+
+    def test_flowfile_import_excluded(self):
+        code = "import flowfile\nimport numpy as np"
+        result = extract_imports(code)
+        assert len(result) == 1
+        assert "numpy" in result[0]
+
+    def test_polars_import_included(self):
+        code = "import polars as pl"
+        result = extract_imports(code)
+        assert "import polars as pl" in result
+
+    def test_no_imports(self):
+        code = "x = 1 + 2"
+        result = extract_imports(code)
+        assert result == []
+
+    def test_multiple_imports(self):
+        code = "import numpy\nimport pandas\nfrom os import path"
+        result = extract_imports(code)
+        assert len(result) == 3
+
+
+# ---------------------------------------------------------------------------
+# Tests for build_function_code
+# ---------------------------------------------------------------------------
+
+
+class TestBuildFunctionCode:
+    def test_simple_single_input(self):
+        code = "df = input_df.collect()\nresult = df.select(['a'])"
+        analysis = FlowfileUsageAnalysis(input_mode="single", has_output=True)
+        # Add a mock output expr
+        analysis.output_exprs = [ast.parse("result", mode="eval").body]
+        func_def, call_code = build_function_code(
+            node_id=5,
+            rewritten_code=code,
+            analysis=analysis,
+            input_vars={"main": "df_3"},
+        )
+        assert "def _node_5(input_df: pl.LazyFrame)" in func_def
+        assert "return" in func_def
+        assert "df_5 = _node_5(df_3)" == call_code
+
+    def test_no_input(self):
+        code = "x = 42"
+        analysis = FlowfileUsageAnalysis(input_mode="none")
+        func_def, call_code = build_function_code(
+            node_id=1,
+            rewritten_code=code,
+            analysis=analysis,
+            input_vars={},
+        )
+        assert "def _node_1()" in func_def
+        assert "df_1 = _node_1()" == call_code
+
+    def test_multi_input(self):
+        code = 'df = inputs["main"]'
+        analysis = FlowfileUsageAnalysis(input_mode="multi")
+        func_def, call_code = build_function_code(
+            node_id=2,
+            rewritten_code=code,
+            analysis=analysis,
+            input_vars={"main": "df_1", "right": "df_0"},
+        )
+        # Runtime uses dict[str, list[pl.LazyFrame]]
+        assert "inputs: dict[str, list[pl.LazyFrame]]" in func_def
+        assert "df_2 = _node_2(" in call_code
+
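+    # Rough shape of the emitted wrapper implied by the assertions above; the
+    # body, return expression and return annotation are assumptions, since the
+    # tests only check fragments of func_def and call_code:
+    #
+    #   def _node_5(input_df: pl.LazyFrame):
+    #       df = input_df.collect()
+    #       result = df.select(['a'])
+    #       return result
+    #
+    #   df_5 = _node_5(df_3)
+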
+    def test_multi_input_grouped(self):
+        """Multiple main inputs (main_0, main_1) should be grouped into a list."""
+        code = 'dfs = inputs["main"]'
+        analysis = FlowfileUsageAnalysis(input_mode="multi")
+        func_def, call_code = build_function_code(
+            node_id=2,
+            rewritten_code=code,
+            analysis=analysis,
+            input_vars={"main_0": "df_1", "main_1": "df_3"},
+        )
+        assert "inputs: dict[str, list[pl.LazyFrame]]" in func_def
+        # The call should group main_0 and main_1 under "main"
+        assert '"main": [df_1, df_3]' in call_code
+
+    def test_passthrough_return(self):
+        code = "x = 1"
+        analysis = FlowfileUsageAnalysis(input_mode="single", has_output=True, passthrough_output=True)
+        analysis.output_exprs = [ast.parse("flowfile.read_input()", mode="eval").body]
+        func_def, _ = build_function_code(
+            node_id=3,
+            rewritten_code=code,
+            analysis=analysis,
+            input_vars={"main": "df_1"},
+        )
+        assert "return input_df" in func_def
+
+    def test_implicit_passthrough_no_output(self):
+        """If no publish_output is called with single input, pass through."""
+        code = "_artifacts['model'] = clf"
+        analysis = FlowfileUsageAnalysis(input_mode="single")
+        func_def, _ = build_function_code(
+            node_id=4,
+            rewritten_code=code,
+            analysis=analysis,
+            input_vars={"main": "df_2"},
+        )
+        assert "return input_df" in func_def
+
+
+# ---------------------------------------------------------------------------
+# Tests for get_import_names / get_required_packages
+# ---------------------------------------------------------------------------
+
+
+class TestPackageMapping:
+    def test_scikit_learn(self):
+        assert get_import_names("scikit-learn") == ["sklearn"]
+
+    def test_pillow(self):
+        assert get_import_names("pillow") == ["PIL"]
+
+    def test_standard_package(self):
+        assert get_import_names("numpy") == ["numpy"]
+
+    def test_dash_to_underscore(self):
+        assert get_import_names("my-package") == ["my_package"]
+
+
+class TestGetRequiredPackages:
+    def test_basic_match(self):
+        user_imports = ["from sklearn.ensemble import RandomForestClassifier"]
+        kernel_packages = ["scikit-learn", "numpy", "polars"]
+        result = get_required_packages(user_imports, kernel_packages)
+        assert result == ["scikit-learn"]
+
+    def test_multiple_matches(self):
+        user_imports = [
+            "import numpy as np",
+            "from sklearn.ensemble import RandomForestClassifier",
+        ]
+        kernel_packages = ["scikit-learn", "numpy", "polars"]
+        result = get_required_packages(user_imports, kernel_packages)
+        assert result == ["numpy", "scikit-learn"]
+
+    def test_no_match(self):
+        user_imports = ["import polars as pl"]
+        kernel_packages = ["scikit-learn"]
+        result = get_required_packages(user_imports, kernel_packages)
+        assert result == []
+
+    def test_empty_inputs(self):
+        assert get_required_packages([], []) == []
+
+    def test_pillow_mapping(self):
+        user_imports = ["from PIL import Image"]
+        kernel_packages = ["pillow"]
+        result = get_required_packages(user_imports, kernel_packages)
+        assert result == ["pillow"]
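+
+    # Matching rule assumed by these tests: the top-level module of each user
+    # import (e.g. "sklearn" in "from sklearn.ensemble import ...") is compared
+    # against get_import_names(pkg) for every kernel package, so "scikit-learn"
+    # is reported when sklearn is imported and "pillow" when PIL is imported.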