From 0a352ac9ded77535da6ac8b176ebcdbbee1108c7 Mon Sep 17 00:00:00 2001 From: Stefano Fioravanzo Date: Mon, 2 Feb 2026 14:26:23 +0100 Subject: [PATCH 1/5] feat: Add AST-based import parsing module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add new imports.py module with centralized package name resolution: - STDLIB_MODULES: Comprehensive set of Python 3.10+ stdlib modules - PACKAGE_NAME_MAP: Centralized registry mapping import names to PyPI package names (sklearn→scikit-learn, cv2→opencv-python, etc.) - ImportInfo dataclass: Structured representation of parsed imports - parse_imports_ast(): Parse all import forms using Python AST - get_packages_to_install(): Extract pip package names from code - is_stdlib_module(): Check if a module is part of stdlib This replaces the brittle string-based import parsing in the compiler with proper AST analysis that handles all Python import forms including multi-line imports, parenthesized imports, and aliases. Signed-off-by: Stefano Fioravanzo --- backend/kale/common/imports.py | 251 +++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 backend/kale/common/imports.py diff --git a/backend/kale/common/imports.py b/backend/kale/common/imports.py new file mode 100644 index 00000000..e17e38b2 --- /dev/null +++ b/backend/kale/common/imports.py @@ -0,0 +1,251 @@ +# Copyright 2026 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AST-based import parsing and package name resolution. + +This module provides utilities for parsing Python import statements using AST +and resolving import names to their corresponding PyPI package names. + +The key components are: +- STDLIB_MODULES: Set of Python standard library module names +- PACKAGE_NAME_MAP: Mapping from import names to PyPI package names +- parse_imports_ast(): Parse import statements from code using AST +- get_packages_to_install(): Get pip-installable package names from code +""" + +import ast +import sys +from dataclasses import dataclass +from typing import Dict, List, Optional, Set + + +# Python standard library modules (should not be pip installed) +# This is a comprehensive list for Python 3.10+ +STDLIB_MODULES: Set[str] = { + # Built-in modules + "abc", "aifc", "argparse", "array", "ast", "asynchat", "asyncio", + "asyncore", "atexit", "audioop", "base64", "bdb", "binascii", + "binhex", "bisect", "builtins", "bz2", + "calendar", "cgi", "cgitb", "chunk", "cmath", "cmd", "code", + "codecs", "codeop", "collections", "colorsys", "compileall", + "concurrent", "configparser", "contextlib", "contextvars", "copy", + "copyreg", "cProfile", "crypt", "csv", "ctypes", "curses", + "dataclasses", "datetime", "dbm", "decimal", "difflib", "dis", + "distutils", "doctest", + "email", "encodings", "enum", "errno", + "faulthandler", "fcntl", "filecmp", "fileinput", "fnmatch", + "fractions", "ftplib", "functools", + "gc", "getopt", "getpass", "gettext", "glob", "graphlib", "grp", "gzip", + "hashlib", "heapq", "hmac", "html", "http", + "idlelib", "imaplib", "imghdr", "imp", "importlib", "inspect", "io", + "ipaddress", "itertools", + "json", + "keyword", + "lib2to3", "linecache", "locale", "logging", "lzma", + "mailbox", "mailcap", "marshal", "math", "mimetypes", "mmap", + "modulefinder", "multiprocessing", + "netrc", "nis", "nntplib", "numbers", + "operator", "optparse", "os", "ossaudiodev", + "pathlib", "pdb", "pickle", "pickletools", "pipes", "pkgutil", + "platform", "plistlib", "poplib", "posix", "posixpath", "pprint", + "profile", "pstats", "pty", "pwd", "py_compile", "pyclbr", "pydoc", + "queue", "quopri", + "random", "re", "readline", "reprlib", "resource", "rlcompleter", + "runpy", + "sched", "secrets", "select", "selectors", "shelve", "shlex", + "shutil", "signal", "site", "smtpd", "smtplib", "sndhdr", "socket", + "socketserver", "spwd", "sqlite3", "ssl", "stat", "statistics", + "string", "stringprep", "struct", "subprocess", "sunau", "symtable", + "sys", "sysconfig", "syslog", + "tabnanny", "tarfile", "telnetlib", "tempfile", "termios", "test", + "textwrap", "threading", "time", "timeit", "tkinter", "token", + "tokenize", "trace", "traceback", "tracemalloc", "tty", "turtle", + "turtledemo", "types", "typing", + "unicodedata", "unittest", "urllib", "uu", "uuid", + "venv", + "warnings", "wave", "weakref", "webbrowser", "winreg", "winsound", + "wsgiref", + "xdrlib", "xml", "xmlrpc", + "zipapp", "zipfile", "zipimport", "zlib", "zoneinfo", + # Common submodules that might be imported directly + "os.path", "urllib.parse", "urllib.request", "collections.abc", + "typing_extensions", # Not stdlib but often bundled +} + + +# Mapping from Python import names to PyPI package names. +# - If the import name matches the PyPI name, it doesn't need to be here +# - If the value is None, the package should be skipped (e.g., stdlib) +# - Add new mappings here as needed +PACKAGE_NAME_MAP: Dict[str, Optional[str]] = { + # Common packages where import name differs from PyPI name + "sklearn": "scikit-learn", + "cv2": "opencv-python", + "PIL": "pillow", + "yaml": "pyyaml", + "bs4": "beautifulsoup4", + "skimage": "scikit-image", + "dateutil": "python-dateutil", + "dotenv": "python-dotenv", + "jwt": "pyjwt", + "magic": "python-magic", + "serial": "pyserial", + "usb": "pyusb", + "git": "gitpython", + "Bio": "biopython", + "OpenSSL": "pyopenssl", + "Crypto": "pycryptodome", + "google.protobuf": "protobuf", + "google.cloud": "google-cloud-core", +} + + +@dataclass +class ImportInfo: + """Structured import information extracted from AST parsing. + + Attributes: + module: The full module path (e.g., "sklearn.ensemble") + names: List of imported names (e.g., ["RandomForestClassifier"]) + alias: The alias if used (e.g., "rf" from "import x as rf") + is_from: True if this is a "from X import Y" statement + line_number: Source line number for error reporting + """ + module: str + names: List[str] + alias: Optional[str] + is_from: bool + line_number: int + + @property + def top_level_package(self) -> str: + """Get the top-level package name (before the first dot).""" + return self.module.split('.')[0] + + def get_pypi_package(self) -> Optional[str]: + """Get the PyPI package name for this import. + + Returns: + The PyPI package name, or None if this is a stdlib module. + """ + top_level = self.top_level_package + + # Check if it's a stdlib module + if top_level in STDLIB_MODULES or self.module in STDLIB_MODULES: + return None + + # Check for explicit mapping (try full module path first, then top-level) + if self.module in PACKAGE_NAME_MAP: + return PACKAGE_NAME_MAP[self.module] + if top_level in PACKAGE_NAME_MAP: + return PACKAGE_NAME_MAP[top_level] + + # Default: assume import name matches PyPI package name + return top_level + + +def parse_imports_ast(code: str) -> List[ImportInfo]: + """Parse all import statements from Python code using AST. + + This function properly handles all Python import forms: + - import foo + - import foo.bar + - import foo as f + - from foo import bar + - from foo.bar import baz + - from foo import bar, baz + - from foo import (bar, baz) + - from foo import bar as b + + Args: + code: Python source code as a string + + Returns: + List of ImportInfo objects representing each import statement + + Raises: + SyntaxError: If the code cannot be parsed + """ + imports: List[ImportInfo] = [] + + try: + tree = ast.parse(code) + except SyntaxError: + # If we can't parse the code, return empty list + # The caller can decide how to handle this + return imports + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + # Handle: import foo, import foo.bar, import foo as f + for alias in node.names: + imports.append(ImportInfo( + module=alias.name, + names=[alias.name.split('.')[-1]], + alias=alias.asname, + is_from=False, + line_number=node.lineno + )) + elif isinstance(node, ast.ImportFrom): + # Handle: from foo import bar, from foo import bar as b + if node.module is not None: + # Regular from import + imports.append(ImportInfo( + module=node.module, + names=[a.name for a in node.names], + alias=node.names[0].asname if len(node.names) == 1 else None, + is_from=True, + line_number=node.lineno + )) + # else: relative import like "from . import x" - skip these + # as they won't need to be pip installed + + return imports + + +def get_packages_to_install(code: str) -> Set[str]: + """Extract pip-installable package names from Python code. + + This function parses the import statements in the code and returns + the set of PyPI package names that would need to be installed. + + Standard library modules are automatically filtered out. + + Args: + code: Python source code as a string + + Returns: + Set of PyPI package names to install + """ + packages: Set[str] = set() + + for imp in parse_imports_ast(code): + pkg = imp.get_pypi_package() + if pkg is not None: + packages.add(pkg) + + return packages + + +def is_stdlib_module(module_name: str) -> bool: + """Check if a module name is part of the Python standard library. + + Args: + module_name: The module name to check (can be dotted like "os.path") + + Returns: + True if the module is part of stdlib, False otherwise + """ + top_level = module_name.split('.')[0] + return top_level in STDLIB_MODULES or module_name in STDLIB_MODULES From ac0265425c968544d6cde97870eeb5f7e3c66280 Mon Sep 17 00:00:00 2001 From: Stefano Fioravanzo Date: Mon, 2 Feb 2026 14:33:33 +0100 Subject: [PATCH 2/5] test: Add comprehensive unit tests for imports module Add test_imports.py with 70 test cases covering: - TestStdlibModules: Verify stdlib module detection - TestPackageNameMap: Verify import-to-PyPI mappings - TestImportInfo: Test the ImportInfo dataclass methods - TestParseImportsAst: Test AST parsing of various import forms - Simple imports, aliases, from imports - Multiple names, nested modules, parenthesized imports - Dotted imports, mixed code, line number tracking - TestGetPackagesToInstall: Test package extraction - stdlib filtering, package name mapping, deduplication - Real-world data science import patterns - TestIsStdlibModule: Test stdlib detection helper Signed-off-by: Stefano Fioravanzo --- backend/kale/tests/unit_tests/test_imports.py | 370 ++++++++++++++++++ 1 file changed, 370 insertions(+) create mode 100644 backend/kale/tests/unit_tests/test_imports.py diff --git a/backend/kale/tests/unit_tests/test_imports.py b/backend/kale/tests/unit_tests/test_imports.py new file mode 100644 index 00000000..553710cf --- /dev/null +++ b/backend/kale/tests/unit_tests/test_imports.py @@ -0,0 +1,370 @@ +# Copyright 2026 The Kubeflow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the imports module.""" + +import pytest + +from kale.common.imports import ( + ImportInfo, + STDLIB_MODULES, + PACKAGE_NAME_MAP, + parse_imports_ast, + get_packages_to_install, + is_stdlib_module, +) + + +class TestStdlibModules: + """Tests for STDLIB_MODULES set.""" + + @pytest.mark.parametrize("module", [ + "os", "sys", "re", "json", "random", "collections", + "typing", "pathlib", "datetime", "functools", "itertools", + "math", "subprocess", "threading", "multiprocessing", + ]) + def test_common_stdlib_modules_included(self, module): + """Verify common stdlib modules are in the set.""" + assert module in STDLIB_MODULES + + @pytest.mark.parametrize("module", [ + "numpy", "pandas", "sklearn", "tensorflow", "torch", + "requests", "flask", "django", "pytest", + ]) + def test_third_party_modules_not_included(self, module): + """Verify third-party modules are not in stdlib set.""" + assert module not in STDLIB_MODULES + + +class TestPackageNameMap: + """Tests for PACKAGE_NAME_MAP dictionary.""" + + @pytest.mark.parametrize("import_name,pypi_name", [ + ("sklearn", "scikit-learn"), + ("cv2", "opencv-python"), + ("PIL", "pillow"), + ("yaml", "pyyaml"), + ("bs4", "beautifulsoup4"), + ("skimage", "scikit-image"), + ]) + def test_common_mappings_exist(self, import_name, pypi_name): + """Verify common import-to-PyPI mappings are correct.""" + assert PACKAGE_NAME_MAP.get(import_name) == pypi_name + + +class TestImportInfo: + """Tests for ImportInfo dataclass.""" + + def test_top_level_package_simple(self): + """Test top_level_package with simple module.""" + info = ImportInfo( + module="numpy", + names=["array"], + alias=None, + is_from=True, + line_number=1 + ) + assert info.top_level_package == "numpy" + + def test_top_level_package_nested(self): + """Test top_level_package with nested module.""" + info = ImportInfo( + module="sklearn.ensemble", + names=["RandomForestClassifier"], + alias=None, + is_from=True, + line_number=1 + ) + assert info.top_level_package == "sklearn" + + def test_get_pypi_package_direct_mapping(self): + """Test get_pypi_package with direct mapping.""" + info = ImportInfo( + module="sklearn", + names=["sklearn"], + alias=None, + is_from=False, + line_number=1 + ) + assert info.get_pypi_package() == "scikit-learn" + + def test_get_pypi_package_nested_mapping(self): + """Test get_pypi_package with nested module mapping.""" + info = ImportInfo( + module="sklearn.ensemble", + names=["RandomForestClassifier"], + alias=None, + is_from=True, + line_number=1 + ) + assert info.get_pypi_package() == "scikit-learn" + + def test_get_pypi_package_stdlib_returns_none(self): + """Test get_pypi_package returns None for stdlib.""" + info = ImportInfo( + module="os", + names=["path"], + alias=None, + is_from=True, + line_number=1 + ) + assert info.get_pypi_package() is None + + def test_get_pypi_package_no_mapping(self): + """Test get_pypi_package with no mapping (uses import name).""" + info = ImportInfo( + module="numpy", + names=["array"], + alias=None, + is_from=True, + line_number=1 + ) + assert info.get_pypi_package() == "numpy" + + +class TestParseImportsAst: + """Tests for parse_imports_ast function.""" + + def test_empty_code(self): + """Test parsing empty code.""" + result = parse_imports_ast("") + assert result == [] + + def test_simple_import(self): + """Test parsing simple import statement.""" + code = "import numpy" + result = parse_imports_ast(code) + assert len(result) == 1 + assert result[0].module == "numpy" + assert result[0].is_from is False + + def test_import_with_alias(self): + """Test parsing import with alias.""" + code = "import numpy as np" + result = parse_imports_ast(code) + assert len(result) == 1 + assert result[0].module == "numpy" + assert result[0].alias == "np" + assert result[0].is_from is False + + def test_from_import(self): + """Test parsing from import statement.""" + code = "from sklearn import ensemble" + result = parse_imports_ast(code) + assert len(result) == 1 + assert result[0].module == "sklearn" + assert result[0].names == ["ensemble"] + assert result[0].is_from is True + + def test_from_import_multiple_names(self): + """Test parsing from import with multiple names.""" + code = "from os import path, getcwd, listdir" + result = parse_imports_ast(code) + assert len(result) == 1 + assert result[0].module == "os" + assert sorted(result[0].names) == sorted(["path", "getcwd", "listdir"]) + assert result[0].is_from is True + + def test_from_import_nested_module(self): + """Test parsing from import with nested module.""" + code = "from sklearn.ensemble import RandomForestClassifier" + result = parse_imports_ast(code) + assert len(result) == 1 + assert result[0].module == "sklearn.ensemble" + assert result[0].names == ["RandomForestClassifier"] + assert result[0].is_from is True + + def test_multiple_imports(self): + """Test parsing multiple import statements.""" + code = """ +import os +import numpy as np +from pandas import DataFrame, Series +from sklearn.ensemble import RandomForestClassifier +""" + result = parse_imports_ast(code) + assert len(result) == 4 + + modules = [r.module for r in result] + assert "os" in modules + assert "numpy" in modules + assert "pandas" in modules + assert "sklearn.ensemble" in modules + + def test_import_with_parentheses(self): + """Test parsing from import with parentheses.""" + code = """from collections import ( + OrderedDict, + defaultdict, + Counter +)""" + result = parse_imports_ast(code) + assert len(result) == 1 + assert result[0].module == "collections" + assert sorted(result[0].names) == sorted( + ["OrderedDict", "defaultdict", "Counter"]) + + def test_dotted_import(self): + """Test parsing dotted import.""" + code = "import os.path" + result = parse_imports_ast(code) + assert len(result) == 1 + assert result[0].module == "os.path" + assert result[0].is_from is False + + def test_invalid_code_returns_empty(self): + """Test that invalid code returns empty list.""" + code = "def foo(\n pass" # Invalid syntax + result = parse_imports_ast(code) + assert result == [] + + def test_code_with_imports_and_other_statements(self): + """Test parsing code with mixed statements.""" + code = """ +import numpy as np + +x = 5 +y = np.array([1, 2, 3]) + +from pandas import DataFrame + +df = DataFrame({'a': [1, 2, 3]}) +""" + result = parse_imports_ast(code) + assert len(result) == 2 + modules = [r.module for r in result] + assert "numpy" in modules + assert "pandas" in modules + + def test_line_numbers(self): + """Test that line numbers are captured.""" + code = """import os +import sys +from json import loads""" + result = parse_imports_ast(code) + assert result[0].line_number == 1 + assert result[1].line_number == 2 + assert result[2].line_number == 3 + + +class TestGetPackagesToInstall: + """Tests for get_packages_to_install function.""" + + def test_empty_code(self): + """Test with empty code.""" + result = get_packages_to_install("") + assert result == set() + + def test_stdlib_only(self): + """Test code with only stdlib imports.""" + code = """ +import os +import sys +import json +from collections import defaultdict +""" + result = get_packages_to_install(code) + assert result == set() + + def test_single_package(self): + """Test code with single third-party package.""" + code = "import numpy" + result = get_packages_to_install(code) + assert result == {"numpy"} + + def test_multiple_packages(self): + """Test code with multiple third-party packages.""" + code = """ +import numpy +import pandas +from sklearn.ensemble import RandomForestClassifier +""" + result = get_packages_to_install(code) + assert result == {"numpy", "pandas", "scikit-learn"} + + def test_mixed_stdlib_and_third_party(self): + """Test code with mixed stdlib and third-party imports.""" + code = """ +import os +import sys +import numpy as np +from json import loads +from pandas import DataFrame +""" + result = get_packages_to_install(code) + assert result == {"numpy", "pandas"} + + def test_package_name_mapping(self): + """Test that package name mapping is applied.""" + code = """ +import sklearn +from cv2 import imread +import PIL +import yaml +""" + result = get_packages_to_install(code) + expected = {"scikit-learn", "opencv-python", "pillow", "pyyaml"} + assert result == expected + + def test_deduplication(self): + """Test that duplicate packages are deduplicated.""" + code = """ +import numpy +import numpy as np +from numpy import array +from numpy.random import rand +""" + result = get_packages_to_install(code) + assert result == {"numpy"} + + def test_real_world_data_science_imports(self): + """Test with realistic data science imports.""" + code = """ +import os +import sys +import json +import numpy as np +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import accuracy_score +import matplotlib.pyplot as plt +import seaborn as sns +""" + result = get_packages_to_install(code) + expected = { + "numpy", "pandas", "scikit-learn", + "matplotlib", "seaborn" + } + assert result == expected + + +class TestIsStdlibModule: + """Tests for is_stdlib_module function.""" + + @pytest.mark.parametrize("module", [ + "os", "sys", "re", "json", "random", + "os.path", "collections.abc", "urllib.parse", + ]) + def test_stdlib_modules(self, module): + """Test that stdlib modules are correctly identified.""" + assert is_stdlib_module(module) is True + + @pytest.mark.parametrize("module", [ + "numpy", "pandas", "sklearn", "tensorflow", + "numpy.random", "pandas.core", + ]) + def test_non_stdlib_modules(self, module): + """Test that non-stdlib modules are correctly identified.""" + assert is_stdlib_module(module) is False From e907eeec733fe8f32255e76661a3ec53cc250ac3 Mon Sep 17 00:00:00 2001 From: Stefano Fioravanzo Date: Mon, 2 Feb 2026 14:34:49 +0100 Subject: [PATCH 3/5] refactor: Use AST-based import parsing in compiler Replace the brittle string-based import parsing in compiler.py with the new AST-based approach from the imports module. Before: - String splitting on 'import ' and 'from ' prefixes - Hardcoded if/elif chains for package name mapping - Only handled 'random' and 'sklearn' special cases - Could not handle multi-line imports or parenthesized imports - Added stdlib modules to packages_to_install After: - Proper AST parsing via get_packages_to_install() - Centralized PACKAGE_NAME_MAP with 12+ common mappings - Handles all Python import forms correctly - Filters out stdlib modules automatically - Extensible: just add to PACKAGE_NAME_MAP for new cases The _get_package_list_from_imports method now delegates to the imports module, reducing it from 25 lines to 10 lines while significantly improving correctness and maintainability. Signed-off-by: Stefano Fioravanzo --- backend/kale/compiler.py | 39 +++++++++++++-------------------------- 1 file changed, 13 insertions(+), 26 deletions(-) diff --git a/backend/kale/compiler.py b/backend/kale/compiler.py index a965feb1..5b443d46 100644 --- a/backend/kale/compiler.py +++ b/backend/kale/compiler.py @@ -23,6 +23,7 @@ from kale import __version__ as KALE_VERSION from kale.common import graphutils, kfputils, utils +from kale.common.imports import get_packages_to_install from kale.pipeline import Pipeline, PipelineParam, Step log = logging.getLogger(__name__) @@ -260,42 +261,28 @@ def generate_pipeline(self, lightweight_components): return autopep8.fix_code(pipeline_code) def _get_package_list_from_imports(self): - """Extracts unique package names from the tagged imports cell. + """Extract pip-installable package names from imports using AST. - Args: - imports_str: A string containing Python import statements. + Uses the imports module to parse Python import statements via AST + and resolve them to their corresponding PyPI package names. This + properly handles all import forms and filters out stdlib modules. Returns: - A list of unique top-level package names. + A sorted list of unique PyPI package names to install. """ package_names = set() + + # Always include kale and kfp as dependencies if KALE_VERSION != "0+unknown": package_names.add(f"kubeflow-kale=={KALE_VERSION}") else: package_names.add("kubeflow-kale") package_names.add("kfp>=2.0.0") - lines = self.imports_and_functions.strip().split("\n") - - for line in lines: - line = line.strip() - if line.startswith("import "): - # For 'import package' or 'import package as alias' - parts = line.split(" ") - if len(parts) > 1: - package_name = parts[1].split(".")[0] - if package_name == "random": - package_name = "random2" - if package_name == "sklearn": - package_name = "scikit-learn" - package_names.add(package_name) - elif line.startswith("from "): - parts = line.split(" ") - if len(parts) > 1: - package_name = parts[1].split(".")[0] - if package_name == "sklearn": - package_name = "scikit-learn" - package_names.add(package_name) - return sorted(package_names) + + # Parse imports using AST and resolve to PyPI package names + package_names.update(get_packages_to_install(self.imports_and_functions)) + + return sorted(list(package_names)) def _get_templating_env(self, templates_path=None): if self.templating_env: From 74926a43894b610efd51cdc4303f3b317d6728e2 Mon Sep 17 00:00:00 2001 From: Stefano Fioravanzo Date: Wed, 4 Feb 2026 21:26:23 +0100 Subject: [PATCH 4/5] style: Fix import sorting in test_imports.py Sort imports alphabetically to satisfy ruff linter. Signed-off-by: Stefano Fioravanzo --- backend/kale/tests/unit_tests/test_imports.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/kale/tests/unit_tests/test_imports.py b/backend/kale/tests/unit_tests/test_imports.py index 553710cf..d947b660 100644 --- a/backend/kale/tests/unit_tests/test_imports.py +++ b/backend/kale/tests/unit_tests/test_imports.py @@ -17,12 +17,12 @@ import pytest from kale.common.imports import ( - ImportInfo, - STDLIB_MODULES, PACKAGE_NAME_MAP, - parse_imports_ast, + STDLIB_MODULES, + ImportInfo, get_packages_to_install, is_stdlib_module, + parse_imports_ast, ) From 965e4f4817fbe26277c9e4037b1846df8ef04e3a Mon Sep 17 00:00:00 2001 From: Stefano Fioravanzo Date: Thu, 5 Feb 2026 22:21:03 +0100 Subject: [PATCH 5/5] style: Apply ruff lint fixes for modern Python type annotations Signed-off-by: Stefano Fioravanzo --- backend/kale/common/astutils.py | 29 +- backend/kale/common/imports.py | 299 ++++++++++++++---- backend/kale/common/kfputils.py | 2 +- backend/kale/compiler.py | 2 +- backend/kale/processors/pyprocessor.py | 2 +- backend/kale/tests/unit_tests/test_imports.py | 136 ++++---- 6 files changed, 320 insertions(+), 150 deletions(-) diff --git a/backend/kale/common/astutils.py b/backend/kale/common/astutils.py index f77f2386..26a50f96 100644 --- a/backend/kale/common/astutils.py +++ b/backend/kale/common/astutils.py @@ -59,13 +59,13 @@ def get_list_tuple_names(node): Returns: a list of all names of the tuple """ - assert isinstance(node, (ast.Tuple, ast.List)) + assert isinstance(node, ast.Tuple | ast.List) names = [] for _n in node.elts: - if isinstance(_n, (ast.Tuple, ast.List)): + if isinstance(_n, ast.Tuple | ast.List): # recursive call names.extend(get_list_tuple_names(_n)) - elif isinstance(_n, (ast.Name,)): + elif isinstance(_n, ast.Name): names.append(_n.id) return names @@ -135,21 +135,18 @@ def get_marshal_candidates(code): for node in walk(block, stop_at=contexts): if isinstance(node, contexts): names.add(node.name) - if isinstance(node, (ast.Name,)): + if isinstance(node, ast.Name): names.add(node.id) if isinstance( node, - ( - ast.Import, - ast.ImportFrom, - ), + ast.Import | ast.ImportFrom, ): for _n in node.names: if _n.asname is None: names.add(_n.name) else: names.add(_n.asname) - if isinstance(node, (ast.Tuple, ast.List)): + if isinstance(node, ast.Tuple | ast.List): names.update(get_list_tuple_names(node)) return names @@ -171,7 +168,7 @@ def parse_functions(code): tree = ast.parse(code) for block in tree.body: for node in walk(block, stop_at=(ast.FunctionDef,), ignore=(ast.ClassDef,)): - if isinstance(node, (ast.FunctionDef,)): + if isinstance(node, ast.FunctionDef): fn_name = node.name fns[fn_name] = astor.to_source(node) return fns @@ -208,7 +205,7 @@ def get_function_calls(code): # a function call. We check the attribute func to be ast.Name # because it could also be a ast.Attribute node, in case of # function calls like obj.foo() - if isinstance(node, (ast.Call,)) and isinstance(node.func, (ast.Name,)): + if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): fns.add(node.func.id) return fns @@ -230,10 +227,7 @@ def get_function_and_class_names(code): for node in walk(block): if isinstance( node, - ( - ast.FunctionDef, - ast.ClassDef, - ), + ast.FunctionDef | ast.ClassDef, ): names.add(node.name) return names @@ -256,10 +250,7 @@ def parse_assignments_expressions(code): if ( isinstance( targets[0], - ( - ast.Tuple, - ast.List, - ), + ast.Tuple | ast.List, ) or len(targets) > 1 ): diff --git a/backend/kale/common/imports.py b/backend/kale/common/imports.py index e17e38b2..8a48fc23 100644 --- a/backend/kale/common/imports.py +++ b/backend/kale/common/imports.py @@ -25,61 +25,220 @@ """ import ast -import sys from dataclasses import dataclass -from typing import Dict, List, Optional, Set - # Python standard library modules (should not be pip installed) # This is a comprehensive list for Python 3.10+ -STDLIB_MODULES: Set[str] = { +STDLIB_MODULES: set[str] = { # Built-in modules - "abc", "aifc", "argparse", "array", "ast", "asynchat", "asyncio", - "asyncore", "atexit", "audioop", "base64", "bdb", "binascii", - "binhex", "bisect", "builtins", "bz2", - "calendar", "cgi", "cgitb", "chunk", "cmath", "cmd", "code", - "codecs", "codeop", "collections", "colorsys", "compileall", - "concurrent", "configparser", "contextlib", "contextvars", "copy", - "copyreg", "cProfile", "crypt", "csv", "ctypes", "curses", - "dataclasses", "datetime", "dbm", "decimal", "difflib", "dis", - "distutils", "doctest", - "email", "encodings", "enum", "errno", - "faulthandler", "fcntl", "filecmp", "fileinput", "fnmatch", - "fractions", "ftplib", "functools", - "gc", "getopt", "getpass", "gettext", "glob", "graphlib", "grp", "gzip", - "hashlib", "heapq", "hmac", "html", "http", - "idlelib", "imaplib", "imghdr", "imp", "importlib", "inspect", "io", - "ipaddress", "itertools", + "abc", + "aifc", + "argparse", + "array", + "ast", + "asynchat", + "asyncio", + "asyncore", + "atexit", + "audioop", + "base64", + "bdb", + "binascii", + "binhex", + "bisect", + "builtins", + "bz2", + "calendar", + "cgi", + "cgitb", + "chunk", + "cmath", + "cmd", + "code", + "codecs", + "codeop", + "collections", + "colorsys", + "compileall", + "concurrent", + "configparser", + "contextlib", + "contextvars", + "copy", + "copyreg", + "cProfile", + "crypt", + "csv", + "ctypes", + "curses", + "dataclasses", + "datetime", + "dbm", + "decimal", + "difflib", + "dis", + "distutils", + "doctest", + "email", + "encodings", + "enum", + "errno", + "faulthandler", + "fcntl", + "filecmp", + "fileinput", + "fnmatch", + "fractions", + "ftplib", + "functools", + "gc", + "getopt", + "getpass", + "gettext", + "glob", + "graphlib", + "grp", + "gzip", + "hashlib", + "heapq", + "hmac", + "html", + "http", + "idlelib", + "imaplib", + "imghdr", + "imp", + "importlib", + "inspect", + "io", + "ipaddress", + "itertools", "json", "keyword", - "lib2to3", "linecache", "locale", "logging", "lzma", - "mailbox", "mailcap", "marshal", "math", "mimetypes", "mmap", - "modulefinder", "multiprocessing", - "netrc", "nis", "nntplib", "numbers", - "operator", "optparse", "os", "ossaudiodev", - "pathlib", "pdb", "pickle", "pickletools", "pipes", "pkgutil", - "platform", "plistlib", "poplib", "posix", "posixpath", "pprint", - "profile", "pstats", "pty", "pwd", "py_compile", "pyclbr", "pydoc", - "queue", "quopri", - "random", "re", "readline", "reprlib", "resource", "rlcompleter", + "lib2to3", + "linecache", + "locale", + "logging", + "lzma", + "mailbox", + "mailcap", + "marshal", + "math", + "mimetypes", + "mmap", + "modulefinder", + "multiprocessing", + "netrc", + "nis", + "nntplib", + "numbers", + "operator", + "optparse", + "os", + "ossaudiodev", + "pathlib", + "pdb", + "pickle", + "pickletools", + "pipes", + "pkgutil", + "platform", + "plistlib", + "poplib", + "posix", + "posixpath", + "pprint", + "profile", + "pstats", + "pty", + "pwd", + "py_compile", + "pyclbr", + "pydoc", + "queue", + "quopri", + "random", + "re", + "readline", + "reprlib", + "resource", + "rlcompleter", "runpy", - "sched", "secrets", "select", "selectors", "shelve", "shlex", - "shutil", "signal", "site", "smtpd", "smtplib", "sndhdr", "socket", - "socketserver", "spwd", "sqlite3", "ssl", "stat", "statistics", - "string", "stringprep", "struct", "subprocess", "sunau", "symtable", - "sys", "sysconfig", "syslog", - "tabnanny", "tarfile", "telnetlib", "tempfile", "termios", "test", - "textwrap", "threading", "time", "timeit", "tkinter", "token", - "tokenize", "trace", "traceback", "tracemalloc", "tty", "turtle", - "turtledemo", "types", "typing", - "unicodedata", "unittest", "urllib", "uu", "uuid", + "sched", + "secrets", + "select", + "selectors", + "shelve", + "shlex", + "shutil", + "signal", + "site", + "smtpd", + "smtplib", + "sndhdr", + "socket", + "socketserver", + "spwd", + "sqlite3", + "ssl", + "stat", + "statistics", + "string", + "stringprep", + "struct", + "subprocess", + "sunau", + "symtable", + "sys", + "sysconfig", + "syslog", + "tabnanny", + "tarfile", + "telnetlib", + "tempfile", + "termios", + "test", + "textwrap", + "threading", + "time", + "timeit", + "tkinter", + "token", + "tokenize", + "trace", + "traceback", + "tracemalloc", + "tty", + "turtle", + "turtledemo", + "types", + "typing", + "unicodedata", + "unittest", + "urllib", + "uu", + "uuid", "venv", - "warnings", "wave", "weakref", "webbrowser", "winreg", "winsound", + "warnings", + "wave", + "weakref", + "webbrowser", + "winreg", + "winsound", "wsgiref", - "xdrlib", "xml", "xmlrpc", - "zipapp", "zipfile", "zipimport", "zlib", "zoneinfo", + "xdrlib", + "xml", + "xmlrpc", + "zipapp", + "zipfile", + "zipimport", + "zlib", + "zoneinfo", # Common submodules that might be imported directly - "os.path", "urllib.parse", "urllib.request", "collections.abc", + "os.path", + "urllib.parse", + "urllib.request", + "collections.abc", "typing_extensions", # Not stdlib but often bundled } @@ -88,7 +247,7 @@ # - If the import name matches the PyPI name, it doesn't need to be here # - If the value is None, the package should be skipped (e.g., stdlib) # - Add new mappings here as needed -PACKAGE_NAME_MAP: Dict[str, Optional[str]] = { +PACKAGE_NAME_MAP: dict[str, str | None] = { # Common packages where import name differs from PyPI name "sklearn": "scikit-learn", "cv2": "opencv-python", @@ -122,18 +281,19 @@ class ImportInfo: is_from: True if this is a "from X import Y" statement line_number: Source line number for error reporting """ + module: str - names: List[str] - alias: Optional[str] + names: list[str] + alias: str | None is_from: bool line_number: int @property def top_level_package(self) -> str: """Get the top-level package name (before the first dot).""" - return self.module.split('.')[0] + return self.module.split(".")[0] - def get_pypi_package(self) -> Optional[str]: + def get_pypi_package(self) -> str | None: """Get the PyPI package name for this import. Returns: @@ -155,7 +315,7 @@ def get_pypi_package(self) -> Optional[str]: return top_level -def parse_imports_ast(code: str) -> List[ImportInfo]: +def parse_imports_ast(code: str) -> list[ImportInfo]: """Parse all import statements from Python code using AST. This function properly handles all Python import forms: @@ -177,7 +337,7 @@ def parse_imports_ast(code: str) -> List[ImportInfo]: Raises: SyntaxError: If the code cannot be parsed """ - imports: List[ImportInfo] = [] + imports: list[ImportInfo] = [] try: tree = ast.parse(code) @@ -190,31 +350,32 @@ def parse_imports_ast(code: str) -> List[ImportInfo]: if isinstance(node, ast.Import): # Handle: import foo, import foo.bar, import foo as f for alias in node.names: - imports.append(ImportInfo( - module=alias.name, - names=[alias.name.split('.')[-1]], - alias=alias.asname, - is_from=False, - line_number=node.lineno - )) - elif isinstance(node, ast.ImportFrom): + imports.append( + ImportInfo( + module=alias.name, + names=[alias.name.split(".")[-1]], + alias=alias.asname, + is_from=False, + line_number=node.lineno, + ) + ) + elif isinstance(node, ast.ImportFrom) and node.module is not None: # Handle: from foo import bar, from foo import bar as b - if node.module is not None: - # Regular from import - imports.append(ImportInfo( + # Skip relative imports like "from . import x" as they won't need pip install + imports.append( + ImportInfo( module=node.module, names=[a.name for a in node.names], alias=node.names[0].asname if len(node.names) == 1 else None, is_from=True, - line_number=node.lineno - )) - # else: relative import like "from . import x" - skip these - # as they won't need to be pip installed + line_number=node.lineno, + ) + ) return imports -def get_packages_to_install(code: str) -> Set[str]: +def get_packages_to_install(code: str) -> set[str]: """Extract pip-installable package names from Python code. This function parses the import statements in the code and returns @@ -228,7 +389,7 @@ def get_packages_to_install(code: str) -> Set[str]: Returns: Set of PyPI package names to install """ - packages: Set[str] = set() + packages: set[str] = set() for imp in parse_imports_ast(code): pkg = imp.get_pypi_package() @@ -247,5 +408,5 @@ def is_stdlib_module(module_name: str) -> bool: Returns: True if the module is part of stdlib, False otherwise """ - top_level = module_name.split('.')[0] + top_level = module_name.split(".")[0] return top_level in STDLIB_MODULES or module_name in STDLIB_MODULES diff --git a/backend/kale/common/kfputils.py b/backend/kale/common/kfputils.py index 811c21ac..3b45c569 100644 --- a/backend/kale/common/kfputils.py +++ b/backend/kale/common/kfputils.py @@ -298,7 +298,7 @@ def generate_mlpipeline_metrics(metrics): """ metadata = [] for name, value in metrics.items(): - if not isinstance(value, (int, float)): + if not isinstance(value, int | float): try: value = float(value) except ValueError: diff --git a/backend/kale/compiler.py b/backend/kale/compiler.py index e8c59463..52b59b9a 100644 --- a/backend/kale/compiler.py +++ b/backend/kale/compiler.py @@ -282,7 +282,7 @@ def _get_package_list_from_imports(self): # Parse imports using AST and resolve to PyPI package names package_names.update(get_packages_to_install(self.imports_and_functions)) - return sorted(list(package_names)) + return sorted(package_names) def _get_templating_env(self, templates_path=None): if self.templating_env: diff --git a/backend/kale/processors/pyprocessor.py b/backend/kale/processors/pyprocessor.py index da4a59da..36392a3b 100644 --- a/backend/kale/processors/pyprocessor.py +++ b/backend/kale/processors/pyprocessor.py @@ -155,7 +155,7 @@ def _fn_args_ensure_supported_types(self): # parameters are in _ALLOWED_ARG_KINDS and they have defaults # FIXME: Ensure we support all the KFP-supported types # https://github.com/kubeflow/pipelines/blob/9af3e79c10b9bb1ac1adc7bf8c1354a16fa7b461/sdk/python/kfp/components/_data_passing.py#L107-L116 - if not isinstance(param.default, (int, float, str, bool)): + if not isinstance(param.default, int | float | str | bool): raise RuntimeError( "Pipeline parameters must be of primitive" " types: int, float, str, or bool. Pipeline" diff --git a/backend/kale/tests/unit_tests/test_imports.py b/backend/kale/tests/unit_tests/test_imports.py index d947b660..b23402ab 100644 --- a/backend/kale/tests/unit_tests/test_imports.py +++ b/backend/kale/tests/unit_tests/test_imports.py @@ -29,19 +29,44 @@ class TestStdlibModules: """Tests for STDLIB_MODULES set.""" - @pytest.mark.parametrize("module", [ - "os", "sys", "re", "json", "random", "collections", - "typing", "pathlib", "datetime", "functools", "itertools", - "math", "subprocess", "threading", "multiprocessing", - ]) + @pytest.mark.parametrize( + "module", + [ + "os", + "sys", + "re", + "json", + "random", + "collections", + "typing", + "pathlib", + "datetime", + "functools", + "itertools", + "math", + "subprocess", + "threading", + "multiprocessing", + ], + ) def test_common_stdlib_modules_included(self, module): """Verify common stdlib modules are in the set.""" assert module in STDLIB_MODULES - @pytest.mark.parametrize("module", [ - "numpy", "pandas", "sklearn", "tensorflow", "torch", - "requests", "flask", "django", "pytest", - ]) + @pytest.mark.parametrize( + "module", + [ + "numpy", + "pandas", + "sklearn", + "tensorflow", + "torch", + "requests", + "flask", + "django", + "pytest", + ], + ) def test_third_party_modules_not_included(self, module): """Verify third-party modules are not in stdlib set.""" assert module not in STDLIB_MODULES @@ -50,14 +75,17 @@ def test_third_party_modules_not_included(self, module): class TestPackageNameMap: """Tests for PACKAGE_NAME_MAP dictionary.""" - @pytest.mark.parametrize("import_name,pypi_name", [ - ("sklearn", "scikit-learn"), - ("cv2", "opencv-python"), - ("PIL", "pillow"), - ("yaml", "pyyaml"), - ("bs4", "beautifulsoup4"), - ("skimage", "scikit-image"), - ]) + @pytest.mark.parametrize( + "import_name,pypi_name", + [ + ("sklearn", "scikit-learn"), + ("cv2", "opencv-python"), + ("PIL", "pillow"), + ("yaml", "pyyaml"), + ("bs4", "beautifulsoup4"), + ("skimage", "scikit-image"), + ], + ) def test_common_mappings_exist(self, import_name, pypi_name): """Verify common import-to-PyPI mappings are correct.""" assert PACKAGE_NAME_MAP.get(import_name) == pypi_name @@ -68,13 +96,7 @@ class TestImportInfo: def test_top_level_package_simple(self): """Test top_level_package with simple module.""" - info = ImportInfo( - module="numpy", - names=["array"], - alias=None, - is_from=True, - line_number=1 - ) + info = ImportInfo(module="numpy", names=["array"], alias=None, is_from=True, line_number=1) assert info.top_level_package == "numpy" def test_top_level_package_nested(self): @@ -84,18 +106,14 @@ def test_top_level_package_nested(self): names=["RandomForestClassifier"], alias=None, is_from=True, - line_number=1 + line_number=1, ) assert info.top_level_package == "sklearn" def test_get_pypi_package_direct_mapping(self): """Test get_pypi_package with direct mapping.""" info = ImportInfo( - module="sklearn", - names=["sklearn"], - alias=None, - is_from=False, - line_number=1 + module="sklearn", names=["sklearn"], alias=None, is_from=False, line_number=1 ) assert info.get_pypi_package() == "scikit-learn" @@ -106,30 +124,18 @@ def test_get_pypi_package_nested_mapping(self): names=["RandomForestClassifier"], alias=None, is_from=True, - line_number=1 + line_number=1, ) assert info.get_pypi_package() == "scikit-learn" def test_get_pypi_package_stdlib_returns_none(self): """Test get_pypi_package returns None for stdlib.""" - info = ImportInfo( - module="os", - names=["path"], - alias=None, - is_from=True, - line_number=1 - ) + info = ImportInfo(module="os", names=["path"], alias=None, is_from=True, line_number=1) assert info.get_pypi_package() is None def test_get_pypi_package_no_mapping(self): """Test get_pypi_package with no mapping (uses import name).""" - info = ImportInfo( - module="numpy", - names=["array"], - alias=None, - is_from=True, - line_number=1 - ) + info = ImportInfo(module="numpy", names=["array"], alias=None, is_from=True, line_number=1) assert info.get_pypi_package() == "numpy" @@ -212,8 +218,7 @@ def test_import_with_parentheses(self): result = parse_imports_ast(code) assert len(result) == 1 assert result[0].module == "collections" - assert sorted(result[0].names) == sorted( - ["OrderedDict", "defaultdict", "Counter"]) + assert sorted(result[0].names) == sorted(["OrderedDict", "defaultdict", "Counter"]) def test_dotted_import(self): """Test parsing dotted import.""" @@ -343,28 +348,41 @@ def test_real_world_data_science_imports(self): import seaborn as sns """ result = get_packages_to_install(code) - expected = { - "numpy", "pandas", "scikit-learn", - "matplotlib", "seaborn" - } + expected = {"numpy", "pandas", "scikit-learn", "matplotlib", "seaborn"} assert result == expected class TestIsStdlibModule: """Tests for is_stdlib_module function.""" - @pytest.mark.parametrize("module", [ - "os", "sys", "re", "json", "random", - "os.path", "collections.abc", "urllib.parse", - ]) + @pytest.mark.parametrize( + "module", + [ + "os", + "sys", + "re", + "json", + "random", + "os.path", + "collections.abc", + "urllib.parse", + ], + ) def test_stdlib_modules(self, module): """Test that stdlib modules are correctly identified.""" assert is_stdlib_module(module) is True - @pytest.mark.parametrize("module", [ - "numpy", "pandas", "sklearn", "tensorflow", - "numpy.random", "pandas.core", - ]) + @pytest.mark.parametrize( + "module", + [ + "numpy", + "pandas", + "sklearn", + "tensorflow", + "numpy.random", + "pandas.core", + ], + ) def test_non_stdlib_modules(self, module): """Test that non-stdlib modules are correctly identified.""" assert is_stdlib_module(module) is False