From ec3763e801873f5321653650a7d06269bcfbb75b Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 6 Feb 2026 15:49:29 +0000 Subject: [PATCH 1/4] Add code generation support for python_script nodes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement AST-based rewriting of flowfile.* API calls in python_script nodes to plain Python equivalents, enabling code export for flows that use kernel-based Python execution. New module python_script_rewriter.py handles: - flowfile.read_input() → function parameter (input_df) - flowfile.read_inputs() → function parameter (inputs dict) - flowfile.publish_output(expr) → return statement - flowfile.publish_artifact/read_artifact/delete_artifact → _artifacts dict - flowfile.log() → print() - Package dependency detection from kernel config Integration with FlowGraphToPolarsConverter: - New _handle_python_script handler method - Artifact tracking across nodes for validation - _artifacts = {} emitted in generated code when needed - Graceful error handling for unsupported patterns https://claude.ai/code/session_01Cn56TDT4iPpFpgFL8Fp1pn --- .../flowfile/code_generator/__init__.py | 14 + .../flowfile/code_generator/code_generator.py | 121 ++++ .../code_generator/python_script_rewriter.py | 545 +++++++++++++++ .../test_code_generator_python_script.py | 648 ++++++++++++++++++ .../flowfile/test_python_script_rewriter.py | 415 +++++++++++ 5 files changed, 1743 insertions(+) create mode 100644 flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py create mode 100644 flowfile_core/tests/flowfile/test_code_generator_python_script.py create mode 100644 flowfile_core/tests/flowfile/test_python_script_rewriter.py diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py b/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py index 3616bec84..0d4609087 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/__init__.py @@ -3,9 +3,23 @@ UnsupportedNodeError, export_flow_to_polars, ) +from flowfile_core.flowfile.code_generator.python_script_rewriter import ( + FlowfileUsageAnalysis, + analyze_flowfile_usage, + build_function_code, + extract_imports, + get_required_packages, + rewrite_flowfile_calls, +) __all__ = [ "FlowGraphToPolarsConverter", "UnsupportedNodeError", "export_flow_to_polars", + "FlowfileUsageAnalysis", + "analyze_flowfile_usage", + "build_function_code", + "extract_imports", + "get_required_packages", + "rewrite_flowfile_calls", ] diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py index fdbbb6843..c919bd8f4 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py @@ -52,6 +52,10 @@ def __init__(self, flow_graph: FlowGraph): self.last_node_var = None self.unsupported_nodes = [] self.custom_node_classes = {} + # Track which artifacts have been published and by which node (for validation) + self._published_artifacts: dict[str, int] = {} # artifact_name → node_id + # Track if any python_script nodes exist (to emit _artifacts = {} once) + self._has_python_script_nodes: bool = False def convert(self) -> str: """ @@ -1118,6 +1122,118 @@ def _handle_polars_code( self._add_code(f"{var_name} = _polars_code_{var_name.replace('df_', '')}({args})") self._add_code("") + def _handle_python_script( + 
self, settings: input_schema.NodePythonScript, var_name: str, input_vars: dict[str, str] + ) -> None: + """Handle python_script nodes by rewriting flowfile.* calls to plain Python.""" + from flowfile_core.flowfile.code_generator.python_script_rewriter import ( + analyze_flowfile_usage, + build_function_code, + extract_imports, + rewrite_flowfile_calls, + ) + + code = settings.python_script_input.code.strip() + kernel_id = settings.python_script_input.kernel_id + node_id = settings.node_id + + # Handle empty code — pass through input + if not code: + if input_vars: + self._add_code(f"{var_name} = {list(input_vars.values())[0]}") + else: + self._add_code(f"{var_name} = pl.LazyFrame()") + return + + # 1. Analyze flowfile usage + try: + analysis = analyze_flowfile_usage(code) + except SyntaxError as e: + self.unsupported_nodes.append(( + node_id, + "python_script", + f"Syntax error in python_script code: {e}" + )) + return + self._has_python_script_nodes = True + + # 2. Check for unsupported patterns + if analysis.dynamic_artifact_names: + self.unsupported_nodes.append(( + node_id, + "python_script", + "Artifact names must be string literals for code generation. " + f"Found dynamic names at lines: {[getattr(n, 'lineno', '?') for n in analysis.dynamic_artifact_names]}" + )) + return + + if analysis.unsupported_calls: + methods = [m for m, _ in analysis.unsupported_calls] + self.unsupported_nodes.append(( + node_id, + "python_script", + f"Unsupported flowfile API calls for code generation: {', '.join(methods)}" + )) + return + + # 3. Validate artifact dependencies are available + for artifact_name in analysis.artifacts_consumed: + if artifact_name not in self._published_artifacts: + self.unsupported_nodes.append(( + node_id, + "python_script", + f"Artifact '{artifact_name}' is consumed but not published by any upstream node" + )) + return + + # 4. Extract and register imports + user_imports = extract_imports(code) + for imp in user_imports: + self.imports.add(imp) + + # 5. Add kernel package requirements as comments + if kernel_id: + self._add_kernel_requirements(kernel_id, user_imports) + + # 6. Rewrite the code + rewritten = rewrite_flowfile_calls(code, analysis) + + # 7. Build and emit the function + func_def, call_code = build_function_code( + node_id, rewritten, analysis, input_vars + ) + + self._add_code(f"# --- Node {node_id}: python_script ---") + for line in func_def.split("\n"): + self._add_code(line) + self._add_code("") + self._add_code(call_code) + + # 8. 
Track published artifacts for validation of downstream nodes + for artifact_name, _ in analysis.artifacts_published: + self._published_artifacts[artifact_name] = node_id + + self._add_code("") + + def _add_kernel_requirements(self, kernel_id: str, user_imports: list[str]) -> None: + """Add a comment block with required packages from kernel config.""" + try: + from flowfile_core.flowfile.code_generator.python_script_rewriter import get_required_packages + from flowfile_core.kernel.manager import get_kernel_manager + + manager = get_kernel_manager() + kernel = manager._kernels.get(kernel_id) + if not kernel or not kernel.packages: + return + + required = get_required_packages(user_imports, kernel.packages) + if required: + self._add_code(f"# Required packages: {', '.join(required)}") + self._add_code(f"# Install with: pip install {' '.join(required)}") + self._add_code("") + except Exception: + pass # Kernel manager not available; skip requirements comment + # Handlers for unsupported node types - these add nodes to the unsupported list def _handle_explore_data( @@ -1639,6 +1755,11 @@ def _build_final_code(self) -> str: lines.append(f" ETL Pipeline: {self.flow_graph.__name__}") lines.append(" Generated from Flowfile") lines.append(' """') + + # Artifact store (only if python_script nodes exist) + if self._has_python_script_nodes or self._published_artifacts: + lines.append(" _artifacts = {} # Shared artifact store") + lines.append(" ") # Add the generated code diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py new file mode 100644 index 000000000..d7a04c4ff --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py @@ -0,0 +1,545 @@ +""" +AST-based rewriter for python_script node code generation. + +Transforms flowfile.* API calls in user code into plain Python equivalents, +enabling code generation for python_script nodes that normally execute inside +Docker kernel containers. + +Mapping: + flowfile.read_input() → function parameter (input_df) + flowfile.read_inputs() → function parameter (inputs) + flowfile.publish_output(expr) → return statement + flowfile.publish_artifact("n", o) → _artifacts["n"] = o + flowfile.read_artifact("n") → _artifacts["n"] + flowfile.delete_artifact("n") → del _artifacts["n"] + flowfile.list_artifacts() → _artifacts + flowfile.log(msg, level) → print(f"[{level}] {msg}") +""" + +from __future__ import annotations + +import ast +import textwrap +from dataclasses import dataclass, field +from typing import Literal + +# Maps pip package names to their Python import module names +# when they differ from the package name. 
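+# Values are lists because one pip package can expose several import roots
+# (e.g. "scikit-learn" is imported as "sklearn"). Packages not listed here
+# fall back to replacing dashes with underscores in get_import_names().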
+PACKAGE_TO_IMPORT_MAP: dict[str, list[str]] = { + "scikit-learn": ["sklearn"], + "pillow": ["PIL"], + "opencv-python": ["cv2"], + "opencv-python-headless": ["cv2"], + "beautifulsoup4": ["bs4"], + "pyyaml": ["yaml"], + "pytorch": ["torch"], + "tensorflow-gpu": ["tensorflow"], +} + + +def get_import_names(package: str) -> list[str]: + """Return the import name(s) for a pip package.""" + return PACKAGE_TO_IMPORT_MAP.get(package, [package.replace("-", "_")]) + + +@dataclass +class FlowfileUsageAnalysis: + """Results of analyzing flowfile.* API usage in user code.""" + + input_mode: Literal["none", "single", "multi"] = "none" + has_output: bool = False + output_exprs: list[ast.expr] = field(default_factory=list) + passthrough_output: bool = False + + artifacts_published: list[tuple[str, ast.expr]] = field(default_factory=list) + artifacts_consumed: list[str] = field(default_factory=list) + artifacts_deleted: list[str] = field(default_factory=list) + + has_logging: bool = False + has_list_artifacts: bool = False + + # For error reporting + dynamic_artifact_names: list[ast.AST] = field(default_factory=list) + unsupported_calls: list[tuple[str, ast.AST]] = field(default_factory=list) + + +def _is_flowfile_call(node: ast.Call, method: str | None = None) -> bool: + """Check if an AST Call node is a flowfile.* method call.""" + if not isinstance(node, ast.Call): + return False + func = node.func + if isinstance(func, ast.Attribute): + if isinstance(func.value, ast.Name) and func.value.id == "flowfile": + if method is None: + return True + return func.attr == method + return False + + +def _is_passthrough_output(node: ast.Call) -> bool: + """Check if publish_output argument is flowfile.read_input().""" + if not node.args: + return False + arg = node.args[0] + return _is_flowfile_call(arg, "read_input") if isinstance(arg, ast.Call) else False + + +class _FlowfileUsageVisitor(ast.NodeVisitor): + """Walk the AST to collect information about flowfile.* API usage.""" + + def __init__(self) -> None: + self.analysis = FlowfileUsageAnalysis() + + def visit_Call(self, node: ast.Call) -> None: + if _is_flowfile_call(node): + method = node.func.attr + if method == "read_input": + self.analysis.input_mode = "single" + elif method == "read_inputs": + self.analysis.input_mode = "multi" + elif method == "publish_output": + self.analysis.has_output = True + if node.args: + self.analysis.output_exprs.append(node.args[0]) + if _is_passthrough_output(node): + self.analysis.passthrough_output = True + elif method == "publish_artifact": + if len(node.args) >= 2: + name_node = node.args[0] + if isinstance(name_node, ast.Constant) and isinstance(name_node.value, str): + self.analysis.artifacts_published.append((name_node.value, node.args[1])) + else: + self.analysis.dynamic_artifact_names.append(node) + elif method == "read_artifact": + if node.args: + name_node = node.args[0] + if isinstance(name_node, ast.Constant) and isinstance(name_node.value, str): + self.analysis.artifacts_consumed.append(name_node.value) + else: + self.analysis.dynamic_artifact_names.append(node) + elif method == "delete_artifact": + if node.args: + name_node = node.args[0] + if isinstance(name_node, ast.Constant) and isinstance(name_node.value, str): + self.analysis.artifacts_deleted.append(name_node.value) + else: + self.analysis.dynamic_artifact_names.append(node) + elif method == "log" or method in ("log_info", "log_warning", "log_error"): + self.analysis.has_logging = True + elif method == "list_artifacts": + self.analysis.has_list_artifacts = 
True + elif method in ("display", "publish_global", "get_global", + "list_global_artifacts", "delete_global_artifact"): + self.analysis.unsupported_calls.append((method, node)) + self.generic_visit(node) + + +def analyze_flowfile_usage(code: str) -> FlowfileUsageAnalysis: + """Parse user code and analyze flowfile.* API usage. + + Args: + code: The raw Python source code from a python_script node. + + Returns: + FlowfileUsageAnalysis with details about how the flowfile API is used. + + Raises: + SyntaxError: If the code cannot be parsed. + """ + tree = ast.parse(code) + visitor = _FlowfileUsageVisitor() + visitor.visit(tree) + return visitor.analysis + + +class _FlowfileCallRewriter(ast.NodeTransformer): + """Rewrite flowfile.* API calls to plain Python equivalents.""" + + def __init__(self, analysis: FlowfileUsageAnalysis) -> None: + self.analysis = analysis + self.input_var = "input_df" if analysis.input_mode == "single" else "inputs" + self._last_output_expr: ast.expr | None = None + # Track which publish_output call is the last one + if analysis.output_exprs: + self._last_output_expr = analysis.output_exprs[-1] + + def visit_Call(self, node: ast.Call) -> ast.AST: + # First transform any nested calls + self.generic_visit(node) + + if not _is_flowfile_call(node): + return node + + method = node.func.attr + + if method == "read_input": + # flowfile.read_input() → input_df + return ast.Name(id=self.input_var, ctx=ast.Load()) + + if method == "read_inputs": + # flowfile.read_inputs() → inputs + return ast.Name(id=self.input_var, ctx=ast.Load()) + + if method == "read_artifact": + # flowfile.read_artifact("name") → _artifacts["name"] + return ast.Subscript( + value=ast.Name(id="_artifacts", ctx=ast.Load()), + slice=node.args[0], + ctx=ast.Load(), + ) + + if method == "list_artifacts": + # flowfile.list_artifacts() → _artifacts + return ast.Name(id="_artifacts", ctx=ast.Load()) + + if method == "log": + return self._make_log_print(node) + + if method == "log_info": + return self._make_log_print_with_level(node, "INFO") + + if method == "log_warning": + return self._make_log_print_with_level(node, "WARNING") + + if method == "log_error": + return self._make_log_print_with_level(node, "ERROR") + + return node + + def visit_Expr(self, node: ast.Expr) -> ast.AST | None: + # Transform nested calls first + self.generic_visit(node) + + if not isinstance(node.value, ast.Call): + return node + + call = node.value + if not _is_flowfile_call(call): + return node + + method = call.func.attr + + if method == "publish_output": + # Remove publish_output statements — we handle via return + return None + + if method == "publish_artifact": + # flowfile.publish_artifact("name", obj) → _artifacts["name"] = obj + if len(call.args) >= 2: + return ast.Assign( + targets=[ + ast.Subscript( + value=ast.Name(id="_artifacts", ctx=ast.Load()), + slice=call.args[0], + ctx=ast.Store(), + ) + ], + value=call.args[1], + lineno=node.lineno, + col_offset=node.col_offset, + ) + return node + + if method == "delete_artifact": + # flowfile.delete_artifact("name") → del _artifacts["name"] + if call.args: + return ast.Delete( + targets=[ + ast.Subscript( + value=ast.Name(id="_artifacts", ctx=ast.Load()), + slice=call.args[0], + ctx=ast.Del(), + ) + ], + lineno=node.lineno, + col_offset=node.col_offset, + ) + return node + + return node + + @staticmethod + def _make_log_print(node: ast.Call) -> ast.Call: + """Transform flowfile.log("msg", "LEVEL") → print("[LEVEL] msg").""" + msg_arg = node.args[0] if node.args else 
ast.Constant(value="") + + # Get level from second arg or keyword + level: ast.expr | None = None + if len(node.args) >= 2: + level = node.args[1] + else: + for kw in node.keywords: + if kw.arg == "level": + level = kw.value + break + + if level is None: + level = ast.Constant(value="INFO") + + # Build print(f"[{level}] {msg}") + format_str = ast.JoinedStr( + values=[ + ast.Constant(value="["), + ast.FormattedValue(value=level, conversion=-1), + ast.Constant(value="] "), + ast.FormattedValue(value=msg_arg, conversion=-1), + ] + ) + + return ast.Call( + func=ast.Name(id="print", ctx=ast.Load()), + args=[format_str], + keywords=[], + ) + + @staticmethod + def _make_log_print_with_level(node: ast.Call, level_str: str) -> ast.Call: + """Transform flowfile.log_info("msg") → print("[INFO] msg").""" + msg_arg = node.args[0] if node.args else ast.Constant(value="") + + format_str = ast.JoinedStr( + values=[ + ast.Constant(value=f"[{level_str}] "), + ast.FormattedValue(value=msg_arg, conversion=-1), + ] + ) + + return ast.Call( + func=ast.Name(id="print", ctx=ast.Load()), + args=[format_str], + keywords=[], + ) + + +def rewrite_flowfile_calls(code: str, analysis: FlowfileUsageAnalysis) -> str: + """Rewrite flowfile.* API calls in user code to plain Python. + + This removes/replaces flowfile API calls but does NOT add function + wrapping, return statements, or import stripping. Those are handled + by ``build_function_code``. + + Args: + code: The raw Python source from a python_script node. + analysis: Pre-computed analysis of flowfile usage. + + Returns: + The rewritten source code with flowfile calls replaced. + """ + tree = ast.parse(code) + rewriter = _FlowfileCallRewriter(analysis) + new_tree = rewriter.visit(tree) + # Remove None nodes (deleted statements) + new_tree.body = [node for node in new_tree.body if node is not None] + ast.fix_missing_locations(new_tree) + return ast.unparse(new_tree) + + +def extract_imports(code: str) -> list[str]: + """Extract import statements from user code, excluding flowfile imports. + + Args: + code: The raw Python source code. + + Returns: + List of import statement strings (each is a full import line). + """ + tree = ast.parse(code) + imports: list[str] = [] + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.Import): + # Filter out "import flowfile" + non_flowfile_aliases = [alias for alias in node.names if alias.name != "flowfile"] + if non_flowfile_aliases: + # Reconstruct import with only non-flowfile names + filtered = ast.Import(names=non_flowfile_aliases) + ast.fix_missing_locations(filtered) + imports.append(ast.unparse(filtered)) + elif isinstance(node, ast.ImportFrom): + if node.module and "flowfile" not in node.module: + imports.append(ast.unparse(node)) + elif node.module is None: + imports.append(ast.unparse(node)) + return imports + + +def _strip_imports_and_flowfile(code: str) -> str: + """Remove import statements and flowfile import from code body. + + Returns the code with all top-level import/from-import statements removed. + """ + tree = ast.parse(code) + new_body = [] + for node in tree.body: + if isinstance(node, ast.Import): + # Keep non-flowfile imports? 
No — imports are extracted separately + continue + elif isinstance(node, ast.ImportFrom): + continue + else: + new_body.append(node) + tree.body = new_body + if not tree.body: + return "" + ast.fix_missing_locations(tree) + return ast.unparse(tree) + + +def build_function_code( + node_id: int, + rewritten_code: str, + analysis: FlowfileUsageAnalysis, + input_vars: dict[str, str], +) -> tuple[str, str]: + """Assemble rewritten code into a function definition and call. + + Args: + node_id: The node ID for naming. + rewritten_code: The AST-rewritten code (from rewrite_flowfile_calls, + with imports already stripped). + analysis: The flowfile usage analysis. + input_vars: Mapping of input names to variable names from upstream nodes. + + Returns: + Tuple of (function_definition, call_code). + E.g.: + ("def _node_5(input_df: pl.LazyFrame) -> pl.LazyFrame:\\n ...", + "df_5 = _node_5(df_3)") + """ + func_name = f"_node_{node_id}" + var_name = f"df_{node_id}" + + # Build parameter list and arguments + params: list[str] = [] + args: list[str] = [] + + if analysis.input_mode == "single": + params.append("input_df: pl.LazyFrame") + main_var = input_vars.get("main") + if main_var is None: + # Multiple main inputs — pick first + for k in sorted(input_vars.keys()): + if k.startswith("main"): + main_var = input_vars[k] + break + args.append(main_var or "pl.LazyFrame()") + elif analysis.input_mode == "multi": + params.append("inputs: dict[str, pl.LazyFrame]") + dict_entries = ", ".join(f'"{k}": {v}' for k, v in sorted(input_vars.items())) + args.append("{" + dict_entries + "}") + + param_str = ", ".join(params) + return_type = "pl.LazyFrame" if params else "pl.LazyFrame | None" + + # Build function body + body_lines: list[str] = [] + + # Strip imports from rewritten code (they go to top-level) + body_code = _strip_imports_and_flowfile(rewritten_code) + + if body_code: + for line in body_code.split("\n"): + body_lines.append(line) + + # Add return statement + if analysis.has_output and analysis.output_exprs: + last_expr = analysis.output_exprs[-1] + if analysis.passthrough_output and analysis.input_mode == "single": + # publish_output(read_input()) → return input_df + body_lines.append("return input_df") + else: + # The output expr was rewritten by the transformer — + # we need to figure out what it became. + # The rewriter removed the publish_output Expr statement. + # We need to produce a return for the last output expression. + # Approach: re-parse and transform just the output expression + output_return = _build_return_for_output(last_expr, analysis) + body_lines.append(output_return) + elif analysis.input_mode == "single": + # No explicit output — pass through input + body_lines.append("return input_df") + elif analysis.input_mode == "multi": + # Pass through first input + first_key = sorted(input_vars.keys())[0] if input_vars else "main" + body_lines.append(f'return inputs["{first_key}"]') + elif not params: + body_lines.append("return None") + + if not body_lines: + body_lines.append("pass") + + # Assemble function definition + indented_body = textwrap.indent("\n".join(body_lines), " ") + func_def = f"def {func_name}({param_str}) -> {return_type}:\n{indented_body}" + + # Build call + arg_str = ", ".join(args) + call_code = f"{var_name} = {func_name}({arg_str})" + + return func_def, call_code + + +def _build_return_for_output(output_expr: ast.expr, analysis: FlowfileUsageAnalysis) -> str: + """Build a return statement from a publish_output expression. 
+ + The expression is the original AST node from publish_output(expr). + We need to rewrite any flowfile calls in it and then produce the return. + """ + # Create a temporary module to transform the expression + temp_code = ast.unparse(output_expr) + + # Check if it's just a variable name — common pattern like publish_output(result) + # In that case, ensure .lazy() is called for DataFrame returns + # We add .lazy() wrapper as a safety measure for DataFrames + rewriter = _FlowfileCallRewriter(analysis) + expr_tree = ast.parse(temp_code, mode="eval") + new_expr = rewriter.visit(expr_tree) + ast.fix_missing_locations(new_expr) + rewritten = ast.unparse(new_expr) + + # Check if the expression already has .lazy() call + if rewritten.endswith(".lazy()"): + return f"return {rewritten}" + + # If the expression is just a variable that likely holds a DataFrame, + # wrap with .lazy() to ensure LazyFrame return + return f"return {rewritten}.lazy()" + + +def get_required_packages( + user_imports: list[str], + kernel_packages: list[str], +) -> list[str]: + """Cross-reference user imports with kernel packages. + + Args: + user_imports: Import statement strings from user code. + kernel_packages: Package names from kernel configuration. + + Returns: + Sorted list of kernel packages that are actually used. + """ + # Build reverse map: import_name → package_name + import_to_package: dict[str, str] = {} + for pkg in kernel_packages: + for imp_name in get_import_names(pkg): + import_to_package[imp_name] = pkg + + # Parse user imports to get root module names + used_packages: set[str] = set() + for imp_str in user_imports: + try: + tree = ast.parse(imp_str) + except SyntaxError: + continue + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + root_module = alias.name.split(".")[0] + if root_module in import_to_package: + used_packages.add(import_to_package[root_module]) + elif isinstance(node, ast.ImportFrom) and node.module: + root_module = node.module.split(".")[0] + if root_module in import_to_package: + used_packages.add(import_to_package[root_module]) + + return sorted(used_packages) diff --git a/flowfile_core/tests/flowfile/test_code_generator_python_script.py b/flowfile_core/tests/flowfile/test_code_generator_python_script.py new file mode 100644 index 000000000..200daa8fd --- /dev/null +++ b/flowfile_core/tests/flowfile/test_code_generator_python_script.py @@ -0,0 +1,648 @@ +""" +Integration tests for python_script node code generation. + +These tests verify that the FlowGraphToPolarsConverter correctly handles +python_script nodes by building FlowGraphs with python_script nodes +and checking the generated code. 
+""" + +import polars as pl +import pytest + +from flowfile_core.flowfile.code_generator.code_generator import ( + FlowGraphToPolarsConverter, + UnsupportedNodeError, + export_flow_to_polars, +) +from flowfile_core.flowfile.flow_graph import FlowGraph, add_connection +from flowfile_core.schemas import input_schema, schemas, transform_schema + + +def create_flow_settings(flow_id: int = 1) -> schemas.FlowSettings: + """Create basic flow settings for tests.""" + return schemas.FlowSettings( + flow_id=flow_id, + execution_mode="Performance", + execution_location="local", + path="/tmp/test_flow", + ) + + +def create_basic_flow(flow_id: int = 1, name: str = "test_flow") -> FlowGraph: + """Create a basic flow graph for testing.""" + return FlowGraph(flow_settings=create_flow_settings(flow_id), name=name) + + +def verify_code_executes(code: str): + """Execute generated code and verify no exceptions are raised.""" + exec_globals = {} + try: + exec(code, exec_globals) + _ = exec_globals["run_etl_pipeline"]() + except Exception as e: + raise AssertionError(f"Code execution failed:\n{e}\n\nGenerated code:\n{code}") + + +def get_result_from_generated_code(code: str): + """Execute generated code and return the result.""" + exec_globals = {} + exec(code, exec_globals) + return exec_globals["run_etl_pipeline"]() + + +def add_manual_input_node(flow: FlowGraph, node_id: int = 1) -> None: + """Add a manual input node with sample data.""" + raw_data = input_schema.RawData( + columns=[ + input_schema.MinimalFieldInfo(name="id", data_type="Int64"), + input_schema.MinimalFieldInfo(name="name", data_type="String"), + input_schema.MinimalFieldInfo(name="value", data_type="Float64"), + ], + data=[[1, 2, 3], ["Alice", "Bob", "Charlie"], [10.0, 20.0, 30.0]], + ) + settings = input_schema.NodeManualInput( + flow_id=1, + node_id=node_id, + raw_data_format=raw_data, + ) + flow.add_manual_input(settings) + + +def add_python_script_node( + flow: FlowGraph, + node_id: int, + code: str, + depending_on_ids: list[int] | None = None, + kernel_id: str | None = None, +) -> None: + """Add a python_script node to the flow.""" + settings = input_schema.NodePythonScript( + flow_id=1, + node_id=node_id, + depending_on_ids=depending_on_ids or [], + python_script_input=input_schema.PythonScriptInput( + code=code, + kernel_id=kernel_id, + ), + ) + flow.add_python_script(settings) + + +def connect_nodes(flow: FlowGraph, from_id: int, to_id: int, input_type: str = "main") -> None: + """Connect two nodes.""" + connection = input_schema.NodeConnection.create_from_simple_input(from_id, to_id, input_type) + add_connection(flow, connection) + + +# --------------------------------------------------------------------------- +# Basic python_script code generation tests +# --------------------------------------------------------------------------- + + +class TestSimplePythonScriptGeneration: + """Test basic python_script node code generation.""" + + def test_simple_passthrough(self): + """Python script that reads input and publishes it unchanged.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "df = flowfile.read_input()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "flowfile" not in generated + assert "def _node_2" in generated + assert "input_df" in generated + assert "_node_2(" in generated + + verify_code_executes(generated) + + def 
test_transform_with_collect(self): + """Python script that collects, transforms, and publishes output.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.col('value') * 2)\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "flowfile" not in generated + assert "input_df.collect()" in generated + assert "def _node_2" in generated + + verify_code_executes(generated) + + def test_empty_code_passthrough(self): + """Empty python_script code should pass through.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + add_python_script_node(flow, node_id=2, code="", depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + verify_code_executes(generated) + + def test_no_output_passthrough(self): + """Script without publish_output should pass through input.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "df = flowfile.read_input().collect()\nx = len(df)\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "return input_df" in generated + verify_code_executes(generated) + + def test_passthrough_output_pattern(self): + """publish_output(read_input()) should generate return input_df.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "df = flowfile.read_input().collect()\n" + "x = len(df)\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "return input_df" in generated + verify_code_executes(generated) + + +# --------------------------------------------------------------------------- +# Artifact tests +# --------------------------------------------------------------------------- + + +class TestArtifactCodeGeneration: + """Test artifact publish/consume code generation.""" + + def test_artifact_publish(self): + """publish_artifact becomes _artifacts assignment.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "df = flowfile.read_input().collect()\n" + "model = {'trained': True}\n" + 'flowfile.publish_artifact("my_model", model)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "_artifacts" in generated + assert "my_model" in generated + assert "flowfile" not in generated + # Should have _artifacts = {} at top level + assert "_artifacts = {}" in generated + + verify_code_executes(generated) + + def test_artifact_chain(self): + """Artifacts flow correctly between two python_script nodes.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + # Producer node + producer_code = ( + "info = {'count': 42}\n" + 'flowfile.publish_artifact("info", info)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=producer_code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + # Consumer node + consumer_code = ( + 'info = flowfile.read_artifact("info")\n' + 
"df = flowfile.read_input().collect()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=3, code=consumer_code, depending_on_ids=[2]) + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + + assert "_artifacts" in generated + assert "info" in generated + assert "flowfile" not in generated + + verify_code_executes(generated) + + def test_artifact_delete(self): + """delete_artifact becomes del _artifacts[...].""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "obj = {'x': 1}\n" + 'flowfile.publish_artifact("temp", obj)\n' + 'flowfile.delete_artifact("temp")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "del _artifacts" in generated + assert "temp" in generated + verify_code_executes(generated) + + def test_unconsumed_artifact_error(self): + """Consuming an artifact not published upstream should fail.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + 'model = flowfile.read_artifact("missing_model")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + with pytest.raises(UnsupportedNodeError, match="missing_model"): + export_flow_to_polars(flow) + + +# --------------------------------------------------------------------------- +# Logging tests +# --------------------------------------------------------------------------- + + +class TestLoggingCodeGeneration: + """Test that flowfile.log becomes print.""" + + def test_log_becomes_print(self): + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + 'flowfile.log("processing data")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "print" in generated + assert "flowfile" not in generated + verify_code_executes(generated) + + +# --------------------------------------------------------------------------- +# Error handling tests +# --------------------------------------------------------------------------- + + +class TestErrorHandling: + """Test error cases in python_script code generation.""" + + def test_dynamic_artifact_name(self): + """Dynamic artifact names should produce UnsupportedNodeError.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "name = 'model'\n" + "flowfile.read_artifact(name)\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + with pytest.raises(UnsupportedNodeError, match="string literals"): + export_flow_to_polars(flow) + + def test_syntax_error_in_code(self): + """Syntax errors should produce UnsupportedNodeError.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "def foo(:\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + with pytest.raises(UnsupportedNodeError, match="Syntax error"): + export_flow_to_polars(flow) + + def test_unsupported_display_call(self): + """flowfile.display should produce UnsupportedNodeError.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) 
+ + code = ( + "flowfile.display('hello')\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + with pytest.raises(UnsupportedNodeError, match="Unsupported flowfile API"): + export_flow_to_polars(flow) + + +# --------------------------------------------------------------------------- +# Import handling tests +# --------------------------------------------------------------------------- + + +class TestImportHandling: + """Test that imports from python_script nodes are handled correctly.""" + + def test_user_imports_added(self): + """User imports should appear in generated code.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import json\n" + "df = flowfile.read_input().collect()\n" + "data = json.dumps({'count': len(df)})\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "import json" in generated + assert "flowfile" not in generated + + def test_flowfile_import_excluded(self): + """import flowfile should not appear in generated code.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import flowfile\n" + "import json\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + # "import flowfile" should not be present; "import json" should be + lines = generated.split("\n") + assert not any(line.strip() == "import flowfile" for line in lines) + assert "import json" in generated + + +# --------------------------------------------------------------------------- +# Mixed node type tests +# --------------------------------------------------------------------------- + + +class TestMixedNodeTypes: + """Test flows mixing python_script with other node types.""" + + def test_manual_input_then_python_script(self): + """Manual input → python_script → output.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.lit('new').alias('new_col'))\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + verify_code_executes(generated) + + result = get_result_from_generated_code(generated) + if hasattr(result, "collect"): + result = result.collect() + assert "new_col" in result.columns + + def test_python_script_then_filter(self): + """python_script → filter node.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + # Add filter node + filter_settings = input_schema.NodeFilter( + flow_id=1, + node_id=3, + depending_on_id=2, + filter_input=transform_schema.FilterInput( + mode="basic", + basic_filter=transform_schema.BasicFilter( + field="value", + operator=transform_schema.FilterOperator.GREATER_THAN, + value="15", + ), + ), + ) + flow.add_filter(filter_settings) + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + 
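+        # Extra sanity check: the rewritten script should appear as a function.
+        assert "def _node_2" in generated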
verify_code_executes(generated) + + def test_multiple_python_script_nodes(self): + """Chain of multiple python_script nodes.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code1 = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.col('value').alias('doubled'))\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=2, code=code1, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + code2 = ( + "df = flowfile.read_input().collect()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=3, code=code2, depending_on_ids=[2]) + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + + assert "def _node_2" in generated + assert "def _node_3" in generated + verify_code_executes(generated) + + +# --------------------------------------------------------------------------- +# Artifacts store initialization test +# --------------------------------------------------------------------------- + + +class TestArtifactStoreInitialization: + """Test that _artifacts = {} is emitted properly.""" + + def test_artifacts_dict_emitted_for_python_script(self): + """_artifacts = {} should appear when python_script nodes exist.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "flowfile.publish_output(flowfile.read_input())\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + assert "_artifacts = {}" in generated + + def test_no_artifacts_dict_without_python_script(self): + """_artifacts should NOT appear when no python_script nodes exist.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + # Add a filter node (not python_script) + filter_settings = input_schema.NodeFilter( + flow_id=1, + node_id=2, + depending_on_id=1, + filter_input=transform_schema.FilterInput( + mode="basic", + basic_filter=transform_schema.BasicFilter( + field="value", + operator=transform_schema.FilterOperator.GREATER_THAN, + value="15", + ), + ), + ) + flow.add_filter(filter_settings) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + assert "_artifacts" not in generated + + +# --------------------------------------------------------------------------- +# Full pipeline test matching the spec's appendix example +# --------------------------------------------------------------------------- + + +class TestFullPipelineExample: + """Test the complete example from the specification.""" + + def test_train_predict_pipeline(self): + """Simulate train → predict pipeline with artifacts.""" + flow = create_basic_flow() + + # Node 1: Manual input with training data + raw_data = input_schema.RawData( + columns=[ + input_schema.MinimalFieldInfo(name="f1", data_type="Float64"), + input_schema.MinimalFieldInfo(name="f2", data_type="Float64"), + input_schema.MinimalFieldInfo(name="target", data_type="Int64"), + ], + data=[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [0, 1, 0, 1]], + ) + settings = input_schema.NodeManualInput( + flow_id=1, node_id=1, raw_data_format=raw_data + ) + flow.add_manual_input(settings) + + # Node 2: Train model + train_code = ( + "import polars as pl\n" + "df = flowfile.read_input().collect()\n" + "model = {'trained': True, 'n_features': 2}\n" + 'flowfile.publish_artifact("model", model)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=train_code, 
depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + # Node 3: Use model (consume artifact) + predict_code = ( + "import polars as pl\n" + 'model = flowfile.read_artifact("model")\n' + "df = flowfile.read_input().collect()\n" + "result = df.with_columns(pl.lit(model['n_features']).alias('n_features'))\n" + "flowfile.publish_output(result)\n" + ) + add_python_script_node(flow, node_id=3, code=predict_code, depending_on_ids=[2]) + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + + # Verify structure + assert "_artifacts = {}" in generated + assert "def _node_2" in generated + assert "def _node_3" in generated + assert "_artifacts" in generated + assert "model" in generated + assert "flowfile" not in generated + + # Verify execution + verify_code_executes(generated) + + result = get_result_from_generated_code(generated) + if hasattr(result, "collect"): + result = result.collect() + assert "n_features" in result.columns + + def test_list_artifacts_usage(self): + """Test that list_artifacts becomes _artifacts reference.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = ( + "obj = {'x': 1}\n" + 'flowfile.publish_artifact("item", obj)\n' + "arts = flowfile.list_artifacts()\n" + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "flowfile" not in generated + verify_code_executes(generated) + + def test_node_comment_header(self): + """Generated code should include node header comments.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "flowfile.publish_output(flowfile.read_input())\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + connect_nodes(flow, 1, 2) + + generated = export_flow_to_polars(flow) + + assert "# --- Node 2: python_script ---" in generated diff --git a/flowfile_core/tests/flowfile/test_python_script_rewriter.py b/flowfile_core/tests/flowfile/test_python_script_rewriter.py new file mode 100644 index 000000000..0b90fea5c --- /dev/null +++ b/flowfile_core/tests/flowfile/test_python_script_rewriter.py @@ -0,0 +1,415 @@ +""" +Unit tests for python_script_rewriter.py — the AST rewriting engine that +transforms flowfile.* API calls into plain Python equivalents. 
+""" + +import ast + +import pytest + +from flowfile_core.flowfile.code_generator.python_script_rewriter import ( + FlowfileUsageAnalysis, + analyze_flowfile_usage, + build_function_code, + extract_imports, + get_import_names, + get_required_packages, + rewrite_flowfile_calls, +) + + +# --------------------------------------------------------------------------- +# Tests for analyze_flowfile_usage +# --------------------------------------------------------------------------- + + +class TestAnalyzeFlowfileUsage: + def test_single_input(self): + code = "df = flowfile.read_input()" + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "single" + + def test_multi_input(self): + code = "inputs = flowfile.read_inputs()" + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "multi" + + def test_no_input(self): + code = "x = 1 + 2" + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "none" + + def test_publish_output(self): + code = "flowfile.publish_output(df)" + analysis = analyze_flowfile_usage(code) + assert analysis.has_output is True + assert len(analysis.output_exprs) == 1 + + def test_no_output(self): + code = "df = flowfile.read_input()" + analysis = analyze_flowfile_usage(code) + assert analysis.has_output is False + + def test_passthrough_output(self): + code = "flowfile.publish_output(flowfile.read_input())" + analysis = analyze_flowfile_usage(code) + assert analysis.passthrough_output is True + + def test_non_passthrough_output(self): + code = "flowfile.publish_output(result)" + analysis = analyze_flowfile_usage(code) + assert analysis.passthrough_output is False + + def test_artifact_publish(self): + code = 'flowfile.publish_artifact("model", clf)' + analysis = analyze_flowfile_usage(code) + assert len(analysis.artifacts_published) == 1 + assert analysis.artifacts_published[0][0] == "model" + + def test_artifact_consume(self): + code = 'model = flowfile.read_artifact("model")' + analysis = analyze_flowfile_usage(code) + assert analysis.artifacts_consumed == ["model"] + + def test_artifact_delete(self): + code = 'flowfile.delete_artifact("model")' + analysis = analyze_flowfile_usage(code) + assert analysis.artifacts_deleted == ["model"] + + def test_dynamic_artifact_name_detected(self): + code = "flowfile.read_artifact(name_var)" + analysis = analyze_flowfile_usage(code) + assert len(analysis.dynamic_artifact_names) == 1 + + def test_dynamic_publish_artifact_name_detected(self): + code = "flowfile.publish_artifact(name_var, obj)" + analysis = analyze_flowfile_usage(code) + assert len(analysis.dynamic_artifact_names) == 1 + + def test_dynamic_delete_artifact_name_detected(self): + code = "flowfile.delete_artifact(name_var)" + analysis = analyze_flowfile_usage(code) + assert len(analysis.dynamic_artifact_names) == 1 + + def test_logging(self): + code = 'flowfile.log("hello")' + analysis = analyze_flowfile_usage(code) + assert analysis.has_logging is True + + def test_log_with_level(self): + code = 'flowfile.log("hello", "ERROR")' + analysis = analyze_flowfile_usage(code) + assert analysis.has_logging is True + + def test_log_info(self): + code = 'flowfile.log_info("hello")' + analysis = analyze_flowfile_usage(code) + assert analysis.has_logging is True + + def test_list_artifacts(self): + code = "arts = flowfile.list_artifacts()" + analysis = analyze_flowfile_usage(code) + assert analysis.has_list_artifacts is True + + def test_unsupported_display_call(self): + code = "flowfile.display(fig)" + analysis = 
analyze_flowfile_usage(code) + assert len(analysis.unsupported_calls) == 1 + assert analysis.unsupported_calls[0][0] == "display" + + def test_multiple_artifacts(self): + code = ( + 'flowfile.publish_artifact("model", clf)\n' + 'flowfile.publish_artifact("scaler", sc)\n' + 'x = flowfile.read_artifact("model")\n' + ) + analysis = analyze_flowfile_usage(code) + assert len(analysis.artifacts_published) == 2 + assert analysis.artifacts_consumed == ["model"] + + def test_syntax_error_raises(self): + code = "def foo(:" + with pytest.raises(SyntaxError): + analyze_flowfile_usage(code) + + def test_complete_script(self): + code = ( + "import polars as pl\n" + "from sklearn.ensemble import RandomForestClassifier\n" + "\n" + "df = flowfile.read_input().collect()\n" + "X = df.select(['f1', 'f2']).to_numpy()\n" + "y = df.get_column('target').to_numpy()\n" + "\n" + "model = RandomForestClassifier()\n" + "model.fit(X, y)\n" + "\n" + 'flowfile.publish_artifact("rf_model", model)\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + analysis = analyze_flowfile_usage(code) + assert analysis.input_mode == "single" + assert analysis.has_output is True + assert analysis.passthrough_output is True + assert len(analysis.artifacts_published) == 1 + assert analysis.artifacts_published[0][0] == "rf_model" + + +# --------------------------------------------------------------------------- +# Tests for rewrite_flowfile_calls +# --------------------------------------------------------------------------- + + +class TestRewriteFlowfileCalls: + def test_read_input_replaced(self): + code = "df = flowfile.read_input()" + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "input_df" in result + + def test_read_input_with_collect(self): + code = "df = flowfile.read_input().collect()" + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "input_df.collect()" in result + + def test_read_inputs_replaced(self): + code = 'dfs = flowfile.read_inputs()\ndf = dfs["main"]' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "inputs" in result + + def test_publish_output_removed(self): + code = "x = 1\nflowfile.publish_output(df)\ny = 2" + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "publish_output" not in result + assert "x = 1" in result + assert "y = 2" in result + + def test_publish_artifact_becomes_assignment(self): + code = 'flowfile.publish_artifact("model", clf)' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "_artifacts" in result + assert "model" in result + + def test_read_artifact_becomes_subscript(self): + code = 'model = flowfile.read_artifact("model")' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "_artifacts" in result + assert "model" in result + + def test_delete_artifact_becomes_del(self): + code = 'flowfile.delete_artifact("model")' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "del _artifacts" in result + + def test_list_artifacts_becomes_dict(self): + code = "arts = flowfile.list_artifacts()" + analysis = analyze_flowfile_usage(code) + result = 
rewrite_flowfile_calls(code, analysis) + assert "flowfile" not in result + assert "_artifacts" in result + + def test_log_becomes_print(self): + code = 'flowfile.log("hello")' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "print" in result + assert "flowfile" not in result + + def test_log_with_level_becomes_print(self): + code = 'flowfile.log("hello", "ERROR")' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "print" in result + + def test_log_info_becomes_print(self): + code = 'flowfile.log_info("processing")' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "print" in result + assert "INFO" in result + + def test_non_flowfile_code_unchanged(self): + code = "x = 1 + 2\ny = x * 3" + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "x = 1 + 2" in result + assert "y = x * 3" in result + + def test_chained_collect(self): + code = "df = flowfile.read_input().collect()\nresult = df.select(['a'])" + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis) + assert "input_df.collect()" in result + assert "result = df.select" in result + + +# --------------------------------------------------------------------------- +# Tests for extract_imports +# --------------------------------------------------------------------------- + + +class TestExtractImports: + def test_standard_import(self): + code = "import numpy as np\nx = 1" + result = extract_imports(code) + assert "import numpy as np" in result + + def test_from_import(self): + code = "from sklearn.ensemble import RandomForestClassifier" + result = extract_imports(code) + assert len(result) == 1 + assert "RandomForestClassifier" in result[0] + + def test_flowfile_import_excluded(self): + code = "import flowfile\nimport numpy as np" + result = extract_imports(code) + assert len(result) == 1 + assert "numpy" in result[0] + + def test_polars_import_included(self): + code = "import polars as pl" + result = extract_imports(code) + assert "import polars as pl" in result + + def test_no_imports(self): + code = "x = 1 + 2" + result = extract_imports(code) + assert result == [] + + def test_multiple_imports(self): + code = "import numpy\nimport pandas\nfrom os import path" + result = extract_imports(code) + assert len(result) == 3 + + +# --------------------------------------------------------------------------- +# Tests for build_function_code +# --------------------------------------------------------------------------- + + +class TestBuildFunctionCode: + def test_simple_single_input(self): + code = "df = input_df.collect()\nresult = df.select(['a'])" + analysis = FlowfileUsageAnalysis(input_mode="single", has_output=True) + # Add a mock output expr + analysis.output_exprs = [ast.parse("result", mode="eval").body] + func_def, call_code = build_function_code( + node_id=5, + rewritten_code=code, + analysis=analysis, + input_vars={"main": "df_3"}, + ) + assert "def _node_5(input_df: pl.LazyFrame)" in func_def + assert "return" in func_def + assert "df_5 = _node_5(df_3)" == call_code + + def test_no_input(self): + code = "x = 42" + analysis = FlowfileUsageAnalysis(input_mode="none") + func_def, call_code = build_function_code( + node_id=1, + rewritten_code=code, + analysis=analysis, + input_vars={}, + ) + assert "def _node_1()" in func_def + assert "df_1 = _node_1()" == call_code + + def test_multi_input(self): + code = 
'df = inputs["main"]' + analysis = FlowfileUsageAnalysis(input_mode="multi") + func_def, call_code = build_function_code( + node_id=2, + rewritten_code=code, + analysis=analysis, + input_vars={"main": "df_1", "right": "df_0"}, + ) + assert "inputs: dict[str, pl.LazyFrame]" in func_def + assert "df_2 = _node_2(" in call_code + + def test_passthrough_return(self): + code = "x = 1" + analysis = FlowfileUsageAnalysis(input_mode="single", has_output=True, passthrough_output=True) + analysis.output_exprs = [ast.parse("flowfile.read_input()", mode="eval").body] + func_def, _ = build_function_code( + node_id=3, + rewritten_code=code, + analysis=analysis, + input_vars={"main": "df_1"}, + ) + assert "return input_df" in func_def + + def test_implicit_passthrough_no_output(self): + """If no publish_output is called with single input, pass through.""" + code = "_artifacts['model'] = clf" + analysis = FlowfileUsageAnalysis(input_mode="single") + func_def, _ = build_function_code( + node_id=4, + rewritten_code=code, + analysis=analysis, + input_vars={"main": "df_2"}, + ) + assert "return input_df" in func_def + + +# --------------------------------------------------------------------------- +# Tests for get_import_names / get_required_packages +# --------------------------------------------------------------------------- + + +class TestPackageMapping: + def test_scikit_learn(self): + assert get_import_names("scikit-learn") == ["sklearn"] + + def test_pillow(self): + assert get_import_names("pillow") == ["PIL"] + + def test_standard_package(self): + assert get_import_names("numpy") == ["numpy"] + + def test_dash_to_underscore(self): + assert get_import_names("my-package") == ["my_package"] + + +class TestGetRequiredPackages: + def test_basic_match(self): + user_imports = ["from sklearn.ensemble import RandomForestClassifier"] + kernel_packages = ["scikit-learn", "numpy", "polars"] + result = get_required_packages(user_imports, kernel_packages) + assert result == ["scikit-learn"] + + def test_multiple_matches(self): + user_imports = [ + "import numpy as np", + "from sklearn.ensemble import RandomForestClassifier", + ] + kernel_packages = ["scikit-learn", "numpy", "polars"] + result = get_required_packages(user_imports, kernel_packages) + assert result == ["numpy", "scikit-learn"] + + def test_no_match(self): + user_imports = ["import polars as pl"] + kernel_packages = ["scikit-learn"] + result = get_required_packages(user_imports, kernel_packages) + assert result == [] + + def test_empty_inputs(self): + assert get_required_packages([], []) == [] + + def test_pillow_mapping(self): + user_imports = ["from PIL import Image"] + kernel_packages = ["pillow"] + result = get_required_packages(user_imports, kernel_packages) + assert result == ["pillow"] From 7d3289bdb06bf02ea4d438b9ca69c00d8ba721ab Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 6 Feb 2026 16:05:14 +0000 Subject: [PATCH 2/4] Scope artifacts by kernel_id and fix read_inputs structure Artifacts are now keyed per kernel in _artifacts, matching the runtime where each kernel container has its own independent artifact store. Cross-kernel artifact access is validated and rejected at code gen time. Also fixes read_inputs() to produce dict[str, list[pl.LazyFrame]] matching the runtime API where each input name maps to a list of LazyFrames (multiple connections can share a name). Input vars with suffixed names (main_0, main_1) are grouped under their base name. 
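Illustrative example (kernel id "k1" is hypothetical): on that kernel,
flowfile.publish_artifact("model", clf) now generates
_artifacts["k1"]["model"] = clf rather than _artifacts["model"] = clf,
and read_inputs() with two connections on "main" yields
{"main": [df_1, df_2]} instead of separate "main_0"/"main_1" keys.
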
https://claude.ai/code/session_01Cn56TDT4iPpFpgFL8Fp1pn --- .../flowfile/code_generator/code_generator.py | 29 ++-- .../code_generator/python_script_rewriter.py | 136 ++++++++++++------ .../test_code_generator_python_script.py | 120 +++++++++++++--- .../flowfile/test_python_script_rewriter.py | 45 +++++- 4 files changed, 246 insertions(+), 84 deletions(-) diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py index c919bd8f4..97c6bd335 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py @@ -52,10 +52,12 @@ def __init__(self, flow_graph: FlowGraph): self.last_node_var = None self.unsupported_nodes = [] self.custom_node_classes = {} - # Track which artifacts have been published and by which node (for validation) - self._published_artifacts: dict[str, int] = {} # artifact_name → node_id + # Track which artifacts have been published: (kernel_id, artifact_name) → node_id + self._published_artifacts: dict[tuple[str, str], int] = {} # Track if any python_script nodes exist (to emit _artifacts = {} once) self._has_python_script_nodes: bool = False + # Track which kernel IDs are used (for initializing per-kernel sub-dicts) + self._kernel_ids_used: list[str] = [] def convert(self) -> str: """ @@ -1156,6 +1158,9 @@ def _handle_python_script( )) return self._has_python_script_nodes = True + effective_kernel_id = kernel_id or "_default" + if effective_kernel_id not in self._kernel_ids_used: + self._kernel_ids_used.append(effective_kernel_id) # 2. Check for unsupported patterns if analysis.dynamic_artifact_names: @@ -1176,13 +1181,14 @@ def _handle_python_script( )) return - # 3. Validate artifact dependencies are available + # 3. Validate artifact dependencies are available (same kernel only) for artifact_name in analysis.artifacts_consumed: - if artifact_name not in self._published_artifacts: + if (effective_kernel_id, artifact_name) not in self._published_artifacts: self.unsupported_nodes.append(( node_id, "python_script", - f"Artifact '{artifact_name}' is consumed but not published by any upstream node" + f"Artifact '{artifact_name}' is consumed but not published by any " + f"upstream node on kernel '{effective_kernel_id}'" )) return @@ -1195,12 +1201,12 @@ def _handle_python_script( if kernel_id: self._add_kernel_requirements(kernel_id, user_imports) - # 6. Rewrite the code - rewritten = rewrite_flowfile_calls(code, analysis) + # 6. Rewrite the code (kernel_id scopes artifact access) + rewritten = rewrite_flowfile_calls(code, analysis, kernel_id=kernel_id) # 7. Build and emit the function func_def, call_code = build_function_code( - node_id, rewritten, analysis, input_vars + node_id, rewritten, analysis, input_vars, kernel_id=kernel_id ) self._add_code(f"# --- Node {node_id}: python_script ---") @@ -1211,7 +1217,7 @@ def _handle_python_script( # 8. 
Track published artifacts for validation of downstream nodes
         for artifact_name, _ in analysis.artifacts_published:
-            self._published_artifacts[artifact_name] = node_id
+            self._published_artifacts[(effective_kernel_id, artifact_name)] = node_id
 
         self._add_code("")
 
@@ -1756,9 +1762,10 @@ def _build_final_code(self) -> str:
             lines.append("    Generated from Flowfile")
             lines.append('    """')
 
-        # Artifact store (only if python_script nodes exist)
+        # Artifact store — one sub-dict per kernel, matching runtime isolation
         if self._has_python_script_nodes or self._published_artifacts:
-            lines.append("    _artifacts = {}  # Shared artifact store")
+            kernel_init = ", ".join(f'"{kid}": {{}}' for kid in self._kernel_ids_used)
+            lines.append(f"    _artifacts = {{{kernel_init}}}  # Artifact store (per kernel)")
             lines.append("    ")
 
diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py
index d7a04c4ff..6fca8e107 100644
--- a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py
+++ b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py
@@ -5,14 +5,18 @@
 enabling code generation for python_script nodes that normally execute
 inside Docker kernel containers.
 
+Artifacts are scoped per kernel — each kernel gets its own sub-dict inside
+``_artifacts``, matching the runtime behaviour where every kernel container
+has an independent artifact store.
+
 Mapping:
     flowfile.read_input()             → function parameter (input_df)
     flowfile.read_inputs()            → function parameter (inputs)
     flowfile.publish_output(expr)     → return statement
-    flowfile.publish_artifact("n", o) → _artifacts["n"] = o
-    flowfile.read_artifact("n")       → _artifacts["n"]
-    flowfile.delete_artifact("n")     → del _artifacts["n"]
-    flowfile.list_artifacts()         → _artifacts
+    flowfile.publish_artifact("n", o) → _artifacts["<kernel_id>"]["n"] = o
+    flowfile.read_artifact("n")       → _artifacts["<kernel_id>"]["n"]
+    flowfile.delete_artifact("n")     → del _artifacts["<kernel_id>"]["n"]
+    flowfile.list_artifacts()         → _artifacts["<kernel_id>"]
    flowfile.log(msg, level)          → print(f"[{level}] {msg}")
 """
 
@@ -153,16 +157,41 @@
 class _FlowfileCallRewriter(ast.NodeTransformer):
-    """Rewrite flowfile.* API calls to plain Python equivalents."""
+    """Rewrite flowfile.* API calls to plain Python equivalents.
+
+    Artifact operations are scoped to a kernel-specific sub-dict so that
+    each kernel's artifacts stay isolated, matching runtime semantics.
+ """ - def __init__(self, analysis: FlowfileUsageAnalysis) -> None: + def __init__(self, analysis: FlowfileUsageAnalysis, kernel_id: str | None = None) -> None: self.analysis = analysis + self.kernel_id = kernel_id or "_default" self.input_var = "input_df" if analysis.input_mode == "single" else "inputs" self._last_output_expr: ast.expr | None = None # Track which publish_output call is the last one if analysis.output_exprs: self._last_output_expr = analysis.output_exprs[-1] + # --- helpers for kernel-scoped artifact access --- + + def _kernel_artifacts_node(self, ctx: type[ast.expr_context] = ast.Load) -> ast.Subscript: + """Build ``_artifacts[""]`` AST node.""" + return ast.Subscript( + value=ast.Name(id="_artifacts", ctx=ast.Load()), + slice=ast.Constant(value=self.kernel_id), + ctx=ctx(), + ) + + def _artifact_subscript(self, name_node: ast.expr, ctx: type[ast.expr_context] = ast.Load) -> ast.Subscript: + """Build ``_artifacts[""][""]`` AST node.""" + return ast.Subscript( + value=self._kernel_artifacts_node(), + slice=name_node, + ctx=ctx(), + ) + + # --- visitors --- + def visit_Call(self, node: ast.Call) -> ast.AST: # First transform any nested calls self.generic_visit(node) @@ -181,16 +210,12 @@ def visit_Call(self, node: ast.Call) -> ast.AST: return ast.Name(id=self.input_var, ctx=ast.Load()) if method == "read_artifact": - # flowfile.read_artifact("name") → _artifacts["name"] - return ast.Subscript( - value=ast.Name(id="_artifacts", ctx=ast.Load()), - slice=node.args[0], - ctx=ast.Load(), - ) + # flowfile.read_artifact("name") → _artifacts["kernel_id"]["name"] + return self._artifact_subscript(node.args[0]) if method == "list_artifacts": - # flowfile.list_artifacts() → _artifacts - return ast.Name(id="_artifacts", ctx=ast.Load()) + # flowfile.list_artifacts() → _artifacts["kernel_id"] + return self._kernel_artifacts_node() if method == "log": return self._make_log_print(node) @@ -224,16 +249,10 @@ def visit_Expr(self, node: ast.Expr) -> ast.AST | None: return None if method == "publish_artifact": - # flowfile.publish_artifact("name", obj) → _artifacts["name"] = obj + # flowfile.publish_artifact("name", obj) → _artifacts["kernel_id"]["name"] = obj if len(call.args) >= 2: return ast.Assign( - targets=[ - ast.Subscript( - value=ast.Name(id="_artifacts", ctx=ast.Load()), - slice=call.args[0], - ctx=ast.Store(), - ) - ], + targets=[self._artifact_subscript(call.args[0], ctx=ast.Store)], value=call.args[1], lineno=node.lineno, col_offset=node.col_offset, @@ -241,16 +260,10 @@ def visit_Expr(self, node: ast.Expr) -> ast.AST | None: return node if method == "delete_artifact": - # flowfile.delete_artifact("name") → del _artifacts["name"] + # flowfile.delete_artifact("name") → del _artifacts["kernel_id"]["name"] if call.args: return ast.Delete( - targets=[ - ast.Subscript( - value=ast.Name(id="_artifacts", ctx=ast.Load()), - slice=call.args[0], - ctx=ast.Del(), - ) - ], + targets=[self._artifact_subscript(call.args[0], ctx=ast.Del)], lineno=node.lineno, col_offset=node.col_offset, ) @@ -311,7 +324,11 @@ def _make_log_print_with_level(node: ast.Call, level_str: str) -> ast.Call: ) -def rewrite_flowfile_calls(code: str, analysis: FlowfileUsageAnalysis) -> str: +def rewrite_flowfile_calls( + code: str, + analysis: FlowfileUsageAnalysis, + kernel_id: str | None = None, +) -> str: """Rewrite flowfile.* API calls in user code to plain Python. 
This removes/replaces flowfile API calls but does NOT add function @@ -321,12 +338,13 @@ def rewrite_flowfile_calls(code: str, analysis: FlowfileUsageAnalysis) -> str: Args: code: The raw Python source from a python_script node. analysis: Pre-computed analysis of flowfile usage. + kernel_id: The kernel ID for scoping artifact operations. Returns: The rewritten source code with flowfile calls replaced. """ tree = ast.parse(code) - rewriter = _FlowfileCallRewriter(analysis) + rewriter = _FlowfileCallRewriter(analysis, kernel_id=kernel_id) new_tree = rewriter.visit(tree) # Remove None nodes (deleted statements) new_tree.body = [node for node in new_tree.body if node is not None] @@ -389,6 +407,7 @@ def build_function_code( rewritten_code: str, analysis: FlowfileUsageAnalysis, input_vars: dict[str, str], + kernel_id: str | None = None, ) -> tuple[str, str]: """Assemble rewritten code into a function definition and call. @@ -398,6 +417,7 @@ def build_function_code( with imports already stripped). analysis: The flowfile usage analysis. input_vars: Mapping of input names to variable names from upstream nodes. + kernel_id: The kernel ID (used to scope return expressions). Returns: Tuple of (function_definition, call_code). @@ -423,8 +443,14 @@ def build_function_code( break args.append(main_var or "pl.LazyFrame()") elif analysis.input_mode == "multi": - params.append("inputs: dict[str, pl.LazyFrame]") - dict_entries = ", ".join(f'"{k}": {v}' for k, v in sorted(input_vars.items())) + # Runtime returns dict[str, list[pl.LazyFrame]] — each input name + # maps to a *list* of LazyFrames (multiple connections can share a name). + params.append("inputs: dict[str, list[pl.LazyFrame]]") + # Group input_vars by their base name (strip _0, _1 suffixes). + grouped = _group_input_vars(input_vars) + dict_entries = ", ".join( + f'"{k}": [{", ".join(vs)}]' for k, vs in sorted(grouped.items()) + ) args.append("{" + dict_entries + "}") param_str = ", ".join(params) @@ -447,20 +473,17 @@ def build_function_code( # publish_output(read_input()) → return input_df body_lines.append("return input_df") else: - # The output expr was rewritten by the transformer — - # we need to figure out what it became. - # The rewriter removed the publish_output Expr statement. - # We need to produce a return for the last output expression. 
- # Approach: re-parse and transform just the output expression - output_return = _build_return_for_output(last_expr, analysis) + output_return = _build_return_for_output(last_expr, analysis, kernel_id=kernel_id) body_lines.append(output_return) elif analysis.input_mode == "single": # No explicit output — pass through input body_lines.append("return input_df") elif analysis.input_mode == "multi": - # Pass through first input + # Pass through first input list first_key = sorted(input_vars.keys())[0] if input_vars else "main" - body_lines.append(f'return inputs["{first_key}"]') + # Strip _0 suffix to get the base name + base_key = _base_input_name(first_key) + body_lines.append(f'return inputs["{base_key}"][0]') elif not params: body_lines.append("return None") @@ -478,7 +501,32 @@ def build_function_code( return func_def, call_code -def _build_return_for_output(output_expr: ast.expr, analysis: FlowfileUsageAnalysis) -> str: +def _base_input_name(key: str) -> str: + """Strip numeric suffix from input var keys: 'main_0' → 'main'.""" + parts = key.rsplit("_", 1) + if len(parts) == 2 and parts[1].isdigit(): + return parts[0] + return key + + +def _group_input_vars(input_vars: dict[str, str]) -> dict[str, list[str]]: + """Group input variable names by their base name. + + E.g. {"main_0": "df_1", "main_1": "df_3"} → {"main": ["df_1", "df_3"]} + {"main": "df_1"} → {"main": ["df_1"]} + """ + grouped: dict[str, list[str]] = {} + for key, var in sorted(input_vars.items()): + base = _base_input_name(key) + grouped.setdefault(base, []).append(var) + return grouped + + +def _build_return_for_output( + output_expr: ast.expr, + analysis: FlowfileUsageAnalysis, + kernel_id: str | None = None, +) -> str: """Build a return statement from a publish_output expression. The expression is the original AST node from publish_output(expr). 
@@ -490,7 +538,7 @@ def _build_return_for_output(output_expr: ast.expr, analysis: FlowfileUsageAnaly # Check if it's just a variable name — common pattern like publish_output(result) # In that case, ensure .lazy() is called for DataFrame returns # We add .lazy() wrapper as a safety measure for DataFrames - rewriter = _FlowfileCallRewriter(analysis) + rewriter = _FlowfileCallRewriter(analysis, kernel_id=kernel_id) expr_tree = ast.parse(temp_code, mode="eval") new_expr = rewriter.visit(expr_tree) ast.fix_missing_locations(new_expr) diff --git a/flowfile_core/tests/flowfile/test_code_generator_python_script.py b/flowfile_core/tests/flowfile/test_code_generator_python_script.py index 200daa8fd..2d1209160 100644 --- a/flowfile_core/tests/flowfile/test_code_generator_python_script.py +++ b/flowfile_core/tests/flowfile/test_code_generator_python_script.py @@ -198,7 +198,7 @@ class TestArtifactCodeGeneration: """Test artifact publish/consume code generation.""" def test_artifact_publish(self): - """publish_artifact becomes _artifacts assignment.""" + """publish_artifact becomes _artifacts[kernel_id] assignment.""" flow = create_basic_flow() add_manual_input_node(flow, node_id=1) @@ -208,21 +208,20 @@ def test_artifact_publish(self): 'flowfile.publish_artifact("my_model", model)\n' "flowfile.publish_output(flowfile.read_input())\n" ) - add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") connect_nodes(flow, 1, 2) generated = export_flow_to_polars(flow) assert "_artifacts" in generated assert "my_model" in generated + assert "k1" in generated assert "flowfile" not in generated - # Should have _artifacts = {} at top level - assert "_artifacts = {}" in generated verify_code_executes(generated) - def test_artifact_chain(self): - """Artifacts flow correctly between two python_script nodes.""" + def test_artifact_chain_same_kernel(self): + """Artifacts flow correctly between two python_script nodes on same kernel.""" flow = create_basic_flow() add_manual_input_node(flow, node_id=1) @@ -232,16 +231,16 @@ def test_artifact_chain(self): 'flowfile.publish_artifact("info", info)\n' "flowfile.publish_output(flowfile.read_input())\n" ) - add_python_script_node(flow, node_id=2, code=producer_code, depending_on_ids=[1]) + add_python_script_node(flow, node_id=2, code=producer_code, depending_on_ids=[1], kernel_id="k1") connect_nodes(flow, 1, 2) - # Consumer node + # Consumer node — same kernel consumer_code = ( 'info = flowfile.read_artifact("info")\n' "df = flowfile.read_input().collect()\n" "flowfile.publish_output(df)\n" ) - add_python_script_node(flow, node_id=3, code=consumer_code, depending_on_ids=[2]) + add_python_script_node(flow, node_id=3, code=consumer_code, depending_on_ids=[2], kernel_id="k1") connect_nodes(flow, 2, 3) generated = export_flow_to_polars(flow) @@ -252,8 +251,32 @@ def test_artifact_chain(self): verify_code_executes(generated) + def test_artifact_cross_kernel_error(self): + """Consuming an artifact from a different kernel should fail.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + # Producer on kernel k1 + producer_code = ( + 'flowfile.publish_artifact("model", {"x": 1})\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=2, code=producer_code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + # Consumer on kernel k2 — different kernel, should fail + consumer_code = ( + 'model = 
flowfile.read_artifact("model")\n' + "flowfile.publish_output(flowfile.read_input())\n" + ) + add_python_script_node(flow, node_id=3, code=consumer_code, depending_on_ids=[2], kernel_id="k2") + connect_nodes(flow, 2, 3) + + with pytest.raises(UnsupportedNodeError, match="model"): + export_flow_to_polars(flow) + def test_artifact_delete(self): - """delete_artifact becomes del _artifacts[...].""" + """delete_artifact becomes del _artifacts[kernel][...].""" flow = create_basic_flow() add_manual_input_node(flow, node_id=1) @@ -263,13 +286,14 @@ def test_artifact_delete(self): 'flowfile.delete_artifact("temp")\n' "flowfile.publish_output(flowfile.read_input())\n" ) - add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") connect_nodes(flow, 1, 2) generated = export_flow_to_polars(flow) assert "del _artifacts" in generated assert "temp" in generated + assert "k1" in generated verify_code_executes(generated) def test_unconsumed_artifact_error(self): @@ -509,19 +533,38 @@ def test_multiple_python_script_nodes(self): class TestArtifactStoreInitialization: - """Test that _artifacts = {} is emitted properly.""" + """Test that _artifacts is emitted properly with per-kernel sub-dicts.""" def test_artifacts_dict_emitted_for_python_script(self): - """_artifacts = {} should appear when python_script nodes exist.""" + """_artifacts should appear with kernel sub-dict when python_script nodes exist.""" flow = create_basic_flow() add_manual_input_node(flow, node_id=1) code = "flowfile.publish_output(flowfile.read_input())\n" - add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") connect_nodes(flow, 1, 2) generated = export_flow_to_polars(flow) - assert "_artifacts = {}" in generated + assert '_artifacts = {"k1": {}}' in generated + + def test_multiple_kernels_initialized(self): + """Multiple kernels should each get their own sub-dict.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + + code = "flowfile.publish_output(flowfile.read_input())\n" + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") + connect_nodes(flow, 1, 2) + + add_python_script_node(flow, node_id=3, code=code, depending_on_ids=[2], kernel_id="k2") + connect_nodes(flow, 2, 3) + + generated = export_flow_to_polars(flow) + assert "k1" in generated + assert "k2" in generated + # Both should be initialized as empty dicts + assert "_artifacts" in generated + verify_code_executes(generated) def test_no_artifacts_dict_without_python_script(self): """_artifacts should NOT appear when no python_script nodes exist.""" @@ -558,7 +601,7 @@ class TestFullPipelineExample: """Test the complete example from the specification.""" def test_train_predict_pipeline(self): - """Simulate train → predict pipeline with artifacts.""" + """Simulate train → predict pipeline with artifacts on same kernel.""" flow = create_basic_flow() # Node 1: Manual input with training data @@ -583,10 +626,10 @@ def test_train_predict_pipeline(self): 'flowfile.publish_artifact("model", model)\n' "flowfile.publish_output(flowfile.read_input())\n" ) - add_python_script_node(flow, node_id=2, code=train_code, depending_on_ids=[1]) + add_python_script_node(flow, node_id=2, code=train_code, depending_on_ids=[1], kernel_id="ml") connect_nodes(flow, 1, 2) - # Node 3: Use model (consume artifact) + # Node 3: Use model 
(consume artifact — same kernel) predict_code = ( "import polars as pl\n" 'model = flowfile.read_artifact("model")\n' @@ -594,16 +637,16 @@ def test_train_predict_pipeline(self): "result = df.with_columns(pl.lit(model['n_features']).alias('n_features'))\n" "flowfile.publish_output(result)\n" ) - add_python_script_node(flow, node_id=3, code=predict_code, depending_on_ids=[2]) + add_python_script_node(flow, node_id=3, code=predict_code, depending_on_ids=[2], kernel_id="ml") connect_nodes(flow, 2, 3) generated = export_flow_to_polars(flow) - # Verify structure - assert "_artifacts = {}" in generated + # Verify structure — kernel-scoped artifacts assert "def _node_2" in generated assert "def _node_3" in generated assert "_artifacts" in generated + assert "ml" in generated assert "model" in generated assert "flowfile" not in generated @@ -616,7 +659,7 @@ def test_train_predict_pipeline(self): assert "n_features" in result.columns def test_list_artifacts_usage(self): - """Test that list_artifacts becomes _artifacts reference.""" + """Test that list_artifacts becomes _artifacts[kernel_id] reference.""" flow = create_basic_flow() add_manual_input_node(flow, node_id=1) @@ -626,12 +669,13 @@ def test_list_artifacts_usage(self): "arts = flowfile.list_artifacts()\n" "flowfile.publish_output(flowfile.read_input())\n" ) - add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1], kernel_id="k1") connect_nodes(flow, 1, 2) generated = export_flow_to_polars(flow) assert "flowfile" not in generated + assert "k1" in generated verify_code_executes(generated) def test_node_comment_header(self): @@ -646,3 +690,33 @@ def test_node_comment_header(self): generated = export_flow_to_polars(flow) assert "# --- Node 2: python_script ---" in generated + + +# --------------------------------------------------------------------------- +# Multi-input (read_inputs) tests +# --------------------------------------------------------------------------- + + +class TestMultiInputCodeGeneration: + """Test python_script nodes with multiple inputs (read_inputs).""" + + def test_read_inputs_dict_structure(self): + """read_inputs should produce dict[str, list[pl.LazyFrame]] matching runtime.""" + flow = create_basic_flow() + add_manual_input_node(flow, node_id=1) + add_manual_input_node(flow, node_id=3) + + code = ( + "dfs = flowfile.read_inputs()\n" + "df = dfs['main'][0].collect()\n" + "flowfile.publish_output(df)\n" + ) + add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1, 3]) + connect_nodes(flow, 1, 2) + connect_nodes(flow, 3, 2) + + generated = export_flow_to_polars(flow) + + assert "dict[str, list[pl.LazyFrame]]" in generated + assert "flowfile" not in generated + verify_code_executes(generated) diff --git a/flowfile_core/tests/flowfile/test_python_script_rewriter.py b/flowfile_core/tests/flowfile/test_python_script_rewriter.py index 0b90fea5c..bb372a020 100644 --- a/flowfile_core/tests/flowfile/test_python_script_rewriter.py +++ b/flowfile_core/tests/flowfile/test_python_script_rewriter.py @@ -193,32 +193,50 @@ def test_publish_output_removed(self): def test_publish_artifact_becomes_assignment(self): code = 'flowfile.publish_artifact("model", clf)' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "_artifacts" in result + assert "k1" in result assert "model" in result def 
test_read_artifact_becomes_subscript(self): code = 'model = flowfile.read_artifact("model")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "_artifacts" in result + assert "k1" in result assert "model" in result def test_delete_artifact_becomes_del(self): code = 'flowfile.delete_artifact("model")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "del _artifacts" in result + assert "k1" in result - def test_list_artifacts_becomes_dict(self): + def test_list_artifacts_becomes_kernel_dict(self): code = "arts = flowfile.list_artifacts()" analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "_artifacts" in result + assert "k1" in result + + def test_default_kernel_id_when_none(self): + code = 'flowfile.publish_artifact("model", clf)' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis, kernel_id=None) + assert "_default" in result + + def test_artifacts_scoped_to_kernel(self): + """Verify artifact access includes the kernel_id key.""" + code = 'model = flowfile.read_artifact("model")' + analysis = analyze_flowfile_usage(code) + result = rewrite_flowfile_calls(code, analysis, kernel_id="my_kernel") + assert "my_kernel" in result + assert "model" in result def test_log_becomes_print(self): code = 'flowfile.log("hello")' @@ -336,9 +354,24 @@ def test_multi_input(self): analysis=analysis, input_vars={"main": "df_1", "right": "df_0"}, ) - assert "inputs: dict[str, pl.LazyFrame]" in func_def + # Runtime uses dict[str, list[pl.LazyFrame]] + assert "inputs: dict[str, list[pl.LazyFrame]]" in func_def assert "df_2 = _node_2(" in call_code + def test_multi_input_grouped(self): + """Multiple main inputs (main_0, main_1) should be grouped into a list.""" + code = 'dfs = inputs["main"]' + analysis = FlowfileUsageAnalysis(input_mode="multi") + func_def, call_code = build_function_code( + node_id=2, + rewritten_code=code, + analysis=analysis, + input_vars={"main_0": "df_1", "main_1": "df_3"}, + ) + assert "inputs: dict[str, list[pl.LazyFrame]]" in func_def + # The call should group main_0 and main_1 under "main" + assert '"main": [df_1, df_3]' in call_code + def test_passthrough_return(self): code = "x = 1" analysis = FlowfileUsageAnalysis(input_mode="single", has_output=True, passthrough_output=True) From 828a274661b60eb218c9cf4848a4b5a77bc1076b Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 7 Feb 2026 15:44:23 +0000 Subject: [PATCH 3/4] Emit warnings instead of errors for unsupported flowfile API calls Unsupported calls (publish_global, display, etc.) and dynamic artifact names no longer block code generation. Instead, the generated function includes WARNING comments so users can see what won't work outside the kernel runtime. Also fixes mixed read_input/read_inputs usage: when both are present, read_input() is rewritten to inputs["main"][0] so it stays valid in the multi-input function signature. 
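As an illustration (hypothetical user code; names are illustrative), a script
that mixes both read APIs and calls an unsupported method:

    df = flowfile.read_input()
    right = flowfile.read_inputs()["right"][0]
    flowfile.display(df)

now exports to a multi-input function whose body starts roughly with:

    # WARNING: The following flowfile API calls are not supported in code generation
    # and will not work outside the kernel runtime: display
    df = inputs["main"][0]
    right = inputs["right"][0]

instead of failing the whole export with UnsupportedNodeError.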
https://claude.ai/code/session_01Cn56TDT4iPpFpgFL8Fp1pn --- .../flowfile/code_generator/code_generator.py | 21 +-------------- .../code_generator/python_script_rewriter.py | 26 ++++++++++++++++++- .../test_code_generator_python_script.py | 15 ++++++----- 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py index 97c6bd335..415e22435 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py @@ -1162,26 +1162,7 @@ def _handle_python_script( if effective_kernel_id not in self._kernel_ids_used: self._kernel_ids_used.append(effective_kernel_id) - # 2. Check for unsupported patterns - if analysis.dynamic_artifact_names: - self.unsupported_nodes.append(( - node_id, - "python_script", - "Artifact names must be string literals for code generation. " - f"Found dynamic names at lines: {[getattr(n, 'lineno', '?') for n in analysis.dynamic_artifact_names]}" - )) - return - - if analysis.unsupported_calls: - methods = [m for m, _ in analysis.unsupported_calls] - self.unsupported_nodes.append(( - node_id, - "python_script", - f"Unsupported flowfile API calls for code generation: {', '.join(methods)}" - )) - return - - # 3. Validate artifact dependencies are available (same kernel only) + # 2. Validate artifact dependencies are available (same kernel only) for artifact_name in analysis.artifacts_consumed: if (effective_kernel_id, artifact_name) not in self._published_artifacts: self.unsupported_nodes.append(( diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py index 6fca8e107..2f300176f 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py @@ -51,6 +51,8 @@ class FlowfileUsageAnalysis: """Results of analyzing flowfile.* API usage in user code.""" input_mode: Literal["none", "single", "multi"] = "none" + has_read_input: bool = False + has_read_inputs: bool = False has_output: bool = False output_exprs: list[ast.expr] = field(default_factory=list) passthrough_output: bool = False @@ -98,8 +100,11 @@ def visit_Call(self, node: ast.Call) -> None: if _is_flowfile_call(node): method = node.func.attr if method == "read_input": - self.analysis.input_mode = "single" + self.analysis.has_read_input = True + if not self.analysis.has_read_inputs: + self.analysis.input_mode = "single" elif method == "read_inputs": + self.analysis.has_read_inputs = True self.analysis.input_mode = "multi" elif method == "publish_output": self.analysis.has_output = True @@ -202,6 +207,17 @@ def visit_Call(self, node: ast.Call) -> ast.AST: method = node.func.attr if method == "read_input": + if self.analysis.input_mode == "multi": + # Both read_input and read_inputs used — read_input() → inputs["main"][0] + return ast.Subscript( + value=ast.Subscript( + value=ast.Name(id="inputs", ctx=ast.Load()), + slice=ast.Constant(value="main"), + ctx=ast.Load(), + ), + slice=ast.Constant(value=0), + ctx=ast.Load(), + ) # flowfile.read_input() → input_df return ast.Name(id=self.input_var, ctx=ast.Load()) @@ -459,6 +475,14 @@ def build_function_code( # Build function body body_lines: list[str] = [] + # Add warnings for unsupported calls / dynamic artifact names + if 
analysis.unsupported_calls: + methods = sorted({m for m, _ in analysis.unsupported_calls}) + body_lines.append(f"# WARNING: The following flowfile API calls are not supported in code generation") + body_lines.append(f"# and will not work outside the kernel runtime: {', '.join(methods)}") + if analysis.dynamic_artifact_names: + body_lines.append("# WARNING: Dynamic artifact names detected — these may not resolve correctly") + # Strip imports from rewritten code (they go to top-level) body_code = _strip_imports_and_flowfile(rewritten_code) diff --git a/flowfile_core/tests/flowfile/test_code_generator_python_script.py b/flowfile_core/tests/flowfile/test_code_generator_python_script.py index 2d1209160..93cc84d92 100644 --- a/flowfile_core/tests/flowfile/test_code_generator_python_script.py +++ b/flowfile_core/tests/flowfile/test_code_generator_python_script.py @@ -347,7 +347,7 @@ class TestErrorHandling: """Test error cases in python_script code generation.""" def test_dynamic_artifact_name(self): - """Dynamic artifact names should produce UnsupportedNodeError.""" + """Dynamic artifact names should produce a warning comment, not an error.""" flow = create_basic_flow() add_manual_input_node(flow, node_id=1) @@ -359,8 +359,9 @@ def test_dynamic_artifact_name(self): add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) connect_nodes(flow, 1, 2) - with pytest.raises(UnsupportedNodeError, match="string literals"): - export_flow_to_polars(flow) + result = export_flow_to_polars(flow) + assert "WARNING" in result + assert "Dynamic artifact names" in result def test_syntax_error_in_code(self): """Syntax errors should produce UnsupportedNodeError.""" @@ -375,7 +376,7 @@ def test_syntax_error_in_code(self): export_flow_to_polars(flow) def test_unsupported_display_call(self): - """flowfile.display should produce UnsupportedNodeError.""" + """flowfile.display should produce a warning comment, not an error.""" flow = create_basic_flow() add_manual_input_node(flow, node_id=1) @@ -386,8 +387,10 @@ def test_unsupported_display_call(self): add_python_script_node(flow, node_id=2, code=code, depending_on_ids=[1]) connect_nodes(flow, 1, 2) - with pytest.raises(UnsupportedNodeError, match="Unsupported flowfile API"): - export_flow_to_polars(flow) + result = export_flow_to_polars(flow) + assert "WARNING" in result + assert "not supported in code generation" in result + assert "display" in result # --------------------------------------------------------------------------- From 22b4e8176e0d017b0545c9e8892b666634ac06a1 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 8 Feb 2026 20:22:59 +0000 Subject: [PATCH 4/4] Comment out unsupported flowfile calls instead of leaving them as code Unsupported calls like flowfile.publish_global() are now replaced with inline comments showing the original call, so users can see what was skipped and why. Uses a marker-based approach to survive AST round-trips. 
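The mechanism in miniature (a standalone sketch, not the module's actual
code; `obj` stands for an arbitrary user variable):

    import ast

    code = "flowfile.publish_global('model', obj)\nx = 1"
    tree = ast.parse(code)
    # Comments have no AST representation, so substitute a bare name that
    # ast.unparse renders verbatim:
    tree.body[0] = ast.Expr(value=ast.Name(id="__FLOWFILE_UNSUPPORTED_0__", ctx=ast.Load()))
    text = ast.unparse(ast.fix_missing_locations(tree))
    # After all AST processing is done, swap the marker for a comment:
    text = text.replace(
        "__FLOWFILE_UNSUPPORTED_0__",
        "# flowfile.publish_global('model', obj) # not supported outside kernel runtime",
    )
    print(text)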
https://claude.ai/code/session_01Cn56TDT4iPpFpgFL8Fp1pn --- .../flowfile/code_generator/code_generator.py | 5 ++- .../code_generator/python_script_rewriter.py | 27 ++++++++++-- .../flowfile/test_python_script_rewriter.py | 41 ++++++++++++------- 3 files changed, 53 insertions(+), 20 deletions(-) diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py index 415e22435..c0b3d307c 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/code_generator.py @@ -1183,11 +1183,12 @@ def _handle_python_script( self._add_kernel_requirements(kernel_id, user_imports) # 6. Rewrite the code (kernel_id scopes artifact access) - rewritten = rewrite_flowfile_calls(code, analysis, kernel_id=kernel_id) + rewritten, unsupported_markers = rewrite_flowfile_calls(code, analysis, kernel_id=kernel_id) # 7. Build and emit the function func_def, call_code = build_function_code( - node_id, rewritten, analysis, input_vars, kernel_id=kernel_id + node_id, rewritten, analysis, input_vars, + kernel_id=kernel_id, unsupported_markers=unsupported_markers, ) self._add_code(f"# --- Node {node_id}: python_script ---") diff --git a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py index 2f300176f..ba587dd9f 100644 --- a/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py +++ b/flowfile_core/flowfile_core/flowfile/code_generator/python_script_rewriter.py @@ -173,6 +173,7 @@ def __init__(self, analysis: FlowfileUsageAnalysis, kernel_id: str | None = None self.kernel_id = kernel_id or "_default" self.input_var = "input_df" if analysis.input_mode == "single" else "inputs" self._last_output_expr: ast.expr | None = None + self._unsupported_markers: dict[str, str] = {} # Track which publish_output call is the last one if analysis.output_exprs: self._last_output_expr = analysis.output_exprs[-1] @@ -285,6 +286,13 @@ def visit_Expr(self, node: ast.Expr) -> ast.AST | None: ) return node + # Unsupported calls → replace with a marker that becomes a comment + if _is_flowfile_call(call): + source = ast.unparse(node.value) + marker = f"__FLOWFILE_UNSUPPORTED_{len(self._unsupported_markers)}__" + self._unsupported_markers[marker] = source + return ast.Expr(value=ast.Name(id=marker, ctx=ast.Load())) + return node @staticmethod @@ -344,7 +352,7 @@ def rewrite_flowfile_calls( code: str, analysis: FlowfileUsageAnalysis, kernel_id: str | None = None, -) -> str: +) -> tuple[str, dict[str, str]]: """Rewrite flowfile.* API calls in user code to plain Python. This removes/replaces flowfile API calls but does NOT add function @@ -357,7 +365,10 @@ def rewrite_flowfile_calls( kernel_id: The kernel ID for scoping artifact operations. Returns: - The rewritten source code with flowfile calls replaced. + Tuple of (rewritten source code, unsupported call markers dict). + Markers are placeholder variable names mapped to the original source + of unsupported flowfile calls. Callers should replace them with + comments after all AST processing is complete. 
""" tree = ast.parse(code) rewriter = _FlowfileCallRewriter(analysis, kernel_id=kernel_id) @@ -365,7 +376,9 @@ def rewrite_flowfile_calls( # Remove None nodes (deleted statements) new_tree.body = [node for node in new_tree.body if node is not None] ast.fix_missing_locations(new_tree) - return ast.unparse(new_tree) + result = ast.unparse(new_tree) + + return result, rewriter._unsupported_markers def extract_imports(code: str) -> list[str]: @@ -424,6 +437,7 @@ def build_function_code( analysis: FlowfileUsageAnalysis, input_vars: dict[str, str], kernel_id: str | None = None, + unsupported_markers: dict[str, str] | None = None, ) -> tuple[str, str]: """Assemble rewritten code into a function definition and call. @@ -434,6 +448,8 @@ def build_function_code( analysis: The flowfile usage analysis. input_vars: Mapping of input names to variable names from upstream nodes. kernel_id: The kernel ID (used to scope return expressions). + unsupported_markers: Marker→source mapping from rewrite_flowfile_calls. + These are replaced with comments in the final output. Returns: Tuple of (function_definition, call_code). @@ -518,6 +534,11 @@ def build_function_code( indented_body = textwrap.indent("\n".join(body_lines), " ") func_def = f"def {func_name}({param_str}) -> {return_type}:\n{indented_body}" + # Replace unsupported-call markers with comments + if unsupported_markers: + for marker, source in unsupported_markers.items(): + func_def = func_def.replace(marker, f"# {source} # not supported outside kernel runtime") + # Build call arg_str = ", ".join(args) call_code = f"{var_name} = {func_name}({arg_str})" diff --git a/flowfile_core/tests/flowfile/test_python_script_rewriter.py b/flowfile_core/tests/flowfile/test_python_script_rewriter.py index bb372a020..d1a03146c 100644 --- a/flowfile_core/tests/flowfile/test_python_script_rewriter.py +++ b/flowfile_core/tests/flowfile/test_python_script_rewriter.py @@ -164,28 +164,28 @@ class TestRewriteFlowfileCalls: def test_read_input_replaced(self): code = "df = flowfile.read_input()" analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "flowfile" not in result assert "input_df" in result def test_read_input_with_collect(self): code = "df = flowfile.read_input().collect()" analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "flowfile" not in result assert "input_df.collect()" in result def test_read_inputs_replaced(self): code = 'dfs = flowfile.read_inputs()\ndf = dfs["main"]' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "flowfile" not in result assert "inputs" in result def test_publish_output_removed(self): code = "x = 1\nflowfile.publish_output(df)\ny = 2" analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "publish_output" not in result assert "x = 1" in result assert "y = 2" in result @@ -193,7 +193,7 @@ def test_publish_output_removed(self): def test_publish_artifact_becomes_assignment(self): code = 'flowfile.publish_artifact("model", clf)' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "_artifacts" in 
result assert "k1" in result @@ -202,7 +202,7 @@ def test_publish_artifact_becomes_assignment(self): def test_read_artifact_becomes_subscript(self): code = 'model = flowfile.read_artifact("model")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "_artifacts" in result assert "k1" in result @@ -211,7 +211,7 @@ def test_read_artifact_becomes_subscript(self): def test_delete_artifact_becomes_del(self): code = 'flowfile.delete_artifact("model")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "del _artifacts" in result assert "k1" in result @@ -219,7 +219,7 @@ def test_delete_artifact_becomes_del(self): def test_list_artifacts_becomes_kernel_dict(self): code = "arts = flowfile.list_artifacts()" analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis, kernel_id="k1") + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="k1") assert "flowfile" not in result assert "_artifacts" in result assert "k1" in result @@ -227,51 +227,62 @@ def test_list_artifacts_becomes_kernel_dict(self): def test_default_kernel_id_when_none(self): code = 'flowfile.publish_artifact("model", clf)' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis, kernel_id=None) + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id=None) assert "_default" in result def test_artifacts_scoped_to_kernel(self): """Verify artifact access includes the kernel_id key.""" code = 'model = flowfile.read_artifact("model")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis, kernel_id="my_kernel") + result, _ = rewrite_flowfile_calls(code, analysis, kernel_id="my_kernel") assert "my_kernel" in result assert "model" in result def test_log_becomes_print(self): code = 'flowfile.log("hello")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "print" in result assert "flowfile" not in result def test_log_with_level_becomes_print(self): code = 'flowfile.log("hello", "ERROR")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "print" in result def test_log_info_becomes_print(self): code = 'flowfile.log_info("processing")' analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "print" in result assert "INFO" in result def test_non_flowfile_code_unchanged(self): code = "x = 1 + 2\ny = x * 3" analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "x = 1 + 2" in result assert "y = x * 3" in result def test_chained_collect(self): code = "df = flowfile.read_input().collect()\nresult = df.select(['a'])" analysis = analyze_flowfile_usage(code) - result = rewrite_flowfile_calls(code, analysis) + result, _ = rewrite_flowfile_calls(code, analysis) assert "input_df.collect()" in result assert "result = df.select" in result + def test_unsupported_call_returns_marker(self): + """Unsupported calls should produce markers for comment 
replacement.""" + code = "flowfile.publish_global('model', obj)\nx = 1" + analysis = analyze_flowfile_usage(code) + result, markers = rewrite_flowfile_calls(code, analysis) + assert len(markers) == 1 + marker = list(markers.keys())[0] + assert marker in result + source = list(markers.values())[0] + assert "publish_global" in source + # --------------------------------------------------------------------------- # Tests for extract_imports