From 5846270d84d6f5bd0eda3ed56a6fa6bfb8bbb7ad Mon Sep 17 00:00:00 2001 From: Adam Ashenfelter Date: Mon, 29 Dec 2025 13:14:53 -0800 Subject: [PATCH] fix: modal support for local python modules --- agex/agent/registration.py | 159 ++++++++++++++++++++++++++++++++ agex/host/dependencies.py | 9 +- agex/host/modal.py | 7 ++ tests/agex/test_dependencies.py | 39 ++++++++ 4 files changed, 212 insertions(+), 2 deletions(-) diff --git a/agex/agent/registration.py b/agex/agent/registration.py index b521e0b..dfd2c57 100644 --- a/agex/agent/registration.py +++ b/agex/agent/registration.py @@ -21,6 +21,102 @@ F = TypeVar("F", bound=Callable[..., Any]) +def _is_local_module(module_name: str) -> bool: + """ + Determine whether a module should be treated as local to the project/workspace. + + In this context, "local" means code that lives in the current repository/workspace, + rather than third-party packages installed into site-packages or the standard library. + + A module is considered "local" if: + - It has no distribution metadata (i.e., is not a packaged dependency) + - Or it is installed in editable mode from a location outside site-packages + + Local modules need to be added to Modal images via add_local_python_source. 
+ """ + import os + import site + import sys + from importlib import metadata + + if not module_name: + return False + + top_level = module_name.split(".")[0] + + # Skip stdlib and builtins + if top_level in sys.stdlib_module_names or top_level in sys.builtin_module_names: + return False + + # Skip agex itself - it gets installed on Modal + if top_level == "agex": + return False + + # Check if module has PyPI distribution metadata + try: + metadata.version(top_level) + # Has version = pip installed, but might be editable + # Check if it's in site-packages + try: + mod = sys.modules.get(top_level) + if mod is None: + import importlib + + mod = importlib.import_module(top_level) + + if hasattr(mod, "__file__") and mod.__file__: + module_path = os.path.abspath(mod.__file__) + site_paths = site.getsitepackages() + user_site = site.getusersitepackages() + if user_site: + site_paths = site_paths + [user_site] + + # If module is NOT in site-packages, it's a local/editable install + in_site = any( + module_path.startswith(os.path.abspath(sp)) + for sp in site_paths + if sp + ) + return not in_site + except Exception: + # Broad catch needed: dynamic import/path operations can fail in many ways + # (ImportError, AttributeError, OSError, etc.) 
depending on module state + pass + return False # Has metadata, assume it's from PyPI + except metadata.PackageNotFoundError: + # No distribution metadata - could be local OR an internal module (like _pytest) + # Check if the module is in site-packages + try: + mod = sys.modules.get(top_level) + if mod is None: + import importlib + + mod = importlib.import_module(top_level) + + if hasattr(mod, "__file__") and mod.__file__: + module_path = os.path.abspath(mod.__file__) + site_paths = site.getsitepackages() + user_site = site.getusersitepackages() + if user_site: + site_paths = site_paths + [user_site] + + # If module IS in site-packages, it's an internal PyPI package module + in_site = any( + module_path.startswith(os.path.abspath(sp)) + for sp in site_paths + if sp + ) + if in_site: + return False # In site-packages = not local + except Exception: + # Broad catch needed: dynamic import/path operations can fail in many ways + # (ImportError, AttributeError, OSError, etc.) depending on module state + pass + + # No metadata and not in site-packages = definitely local + return True + + class RegistrationMixin(BaseAgent): @overload def fn( @@ -467,6 +563,50 @@ def _track_module(self, module_name: str | None) -> None: if module_name: self._tracked_modules.add(module_name) + def _get_installed_optional_deps(self, package_name: str) -> set[str]: + """ + Find optional dependencies of a package that are installed locally. + + For packages with extras (e.g., `calgebra[google]`), this detects which + optional dependencies are actually installed in the environment so they + can be included in remote execution images. 
+ + Args: + package_name: Distribution name (e.g., "calgebra", not import name) + + Returns: + Set of installed optional deps as "package==version" strings + """ + import re + from importlib import metadata + + installed_optionals: set[str] = set() + + try: + reqs = metadata.requires(package_name) or [] + except metadata.PackageNotFoundError: + return installed_optionals + + for req in reqs: + # Look for extras markers like: 'gcsa; extra == "google"' + if "extra ==" in req or "extra==" in req: + # Extract the package name (before the semicolon) + dep_spec = req.split(";")[0].strip() + # Remove version specifiers: "gcsa>=1.0" -> "gcsa" + dep_name = re.split(r"[<>=!~\[]", dep_spec)[0].strip() + + if not dep_name: + continue + + # Check if it's actually installed + try: + version = metadata.version(dep_name) + installed_optionals.add(f"{dep_name}=={version}") + except metadata.PackageNotFoundError: + pass # Not installed, skip + + return installed_optionals + @property def dependencies(self) -> "Dependencies": """ @@ -489,6 +629,9 @@ def dependencies(self) -> "Dependencies": # Build packages list from tracked modules (expensive, but only once) packages: set[str] = set() + local_packages: set[str] = set() # Local packages for add_local_python_source + # Track distribution names for optional dep lookup + distribution_names: set[str] = set() # Cache packages_distributions() result for this computation try: @@ -507,6 +650,11 @@ def dependencies(self) -> "Dependencies": # Get top-level package name top_level = module_name.split(".")[0] + # Check if this is a local package first + if _is_local_module(module_name): + local_packages.add(top_level) + continue + try: # Map import name to distribution name (e.g. 
sklearn -> scikit-learn) dist_names = pkg_map.get(top_level) @@ -515,6 +663,7 @@ def dependencies(self) -> "Dependencies": try: version = metadata.version(pkg) packages.add(f"{pkg}=={version}") + distribution_names.add(pkg) except metadata.PackageNotFoundError: pass else: @@ -523,11 +672,19 @@ def dependencies(self) -> "Dependencies": try: version = metadata.version(top_level) packages.add(f"{top_level}=={version}") + distribution_names.add(top_level) except metadata.PackageNotFoundError: pass except Exception: pass + # Detect installed optional dependencies for each tracked package + # This ensures extras like `calgebra[google]` have their optional deps + # (e.g., gcsa) included if they're installed locally + for dist_name in distribution_names: + optional_deps = self._get_installed_optional_deps(dist_name) + packages.update(optional_deps) + # Collect dependencies from sub-agents (hierarchical agents) # When an agent uses sub-agents via @agent.fn @sub_agent.task, # the sub-agent gets serialized in the closure and needs its deps too @@ -546,11 +703,13 @@ def dependencies(self) -> "Dependencies": if hasattr(sub_agent, "dependencies"): sub_deps = sub_agent.dependencies packages.update(sub_deps.packages) + local_packages.update(sub_deps.local_packages) deps = Dependencies( python_version=f"{sys.version_info.major}.{sys.version_info.minor}", agex_version=metadata.version("agex"), packages=sorted(list(packages)), + local_packages=sorted(list(local_packages)), ) self._cached_dependencies = deps return deps diff --git a/agex/host/dependencies.py b/agex/host/dependencies.py index 5014630..26206e6 100644 --- a/agex/host/dependencies.py +++ b/agex/host/dependencies.py @@ -8,6 +8,7 @@ class Dependencies: python_version: str agex_version: str packages: list[str] = field(default_factory=list) + local_packages: list[str] = field(default_factory=list) @property def id(self) -> str: @@ -15,7 +16,11 @@ def id(self) -> str: import hashlib # Sort packages to ensure stable ID - 
class TestOptionalDependencies:
    """Test detection of installed optional dependencies."""

    def test_get_installed_optional_deps_unknown_package(self):
        """Unknown package should return empty set."""
        agent = Agent()
        result = agent._get_installed_optional_deps("nonexistent-fake-package-xyz")
        assert result == set()

    def test_get_installed_optional_deps_no_extras(self):
        """An installed package should yield a set without raising."""
        agent = Agent()
        # The contents depend on which of pytest's extras happen to be
        # installed in this environment, so only the return type is pinned.
        result = agent._get_installed_optional_deps("pytest")
        assert isinstance(result, set)

    def test_optional_deps_helper_detects_installed(self):
        """The helper should report installed optional deps as 'name==version'."""
        from importlib import metadata

        agent = Agent()

        # Use agex itself as the probe package; bail out quietly when it is not
        # installed as a distribution in the test environment.
        try:
            metadata.requires("agex")
        except metadata.PackageNotFoundError:
            return

        optional_deps = agent._get_installed_optional_deps("agex")

        # Each entry must be a well-formed pin that matches what is installed.
        for dep in optional_deps:
            name, separator, version = dep.partition("==")
            assert separator and name and version, (
                f"Expected 'package==version' format, got: {dep}"
            )
            assert metadata.version(name) == version