Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions agex/agent/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,102 @@
F = TypeVar("F", bound=Callable[..., Any])


def _is_local_module(module_name: str) -> bool:
"""
Determine whether a module should be treated as local to the project/workspace.

In this context, "local" means code that lives in the current repository/workspace,
rather than third-party packages installed into site-packages or the standard library.

A module is considered "local" if:
- It has no distribution metadata (i.e., is not a packaged dependency)
- Or it is installed in editable mode from a location outside site-packages

Local modules need to be added to Modal images via add_local_python_source.
"""
import os
import site
import sys
from importlib import metadata

if not module_name:
return False

top_level = module_name.split(".")[0]

# Skip stdlib and builtins
if top_level in sys.stdlib_module_names or top_level in sys.builtin_module_names:
return False

# Skip agex itself - it gets installed on Modal
if top_level == "agex":
return False

# Check if module has PyPI distribution metadata
try:
metadata.version(top_level)
# Has version = pip installed, but might be editable
# Check if it's in site-packages
try:
mod = sys.modules.get(top_level)
if mod is None:
import importlib

mod = importlib.import_module(top_level)

if hasattr(mod, "__file__") and mod.__file__:
module_path = os.path.abspath(mod.__file__)
site_paths = site.getsitepackages()
user_site = site.getusersitepackages()
if user_site:
site_paths = site_paths + [user_site]

# If module is NOT in site-packages, it's a local/editable install
in_site = any(
module_path.startswith(os.path.abspath(sp))
for sp in site_paths
if sp
)
return not in_site
except Exception:
# Broad catch needed: dynamic import/path operations can fail in many ways
# (ImportError, AttributeError, OSError, etc.) depending on module state
pass
return False # Has metadata, assume it's from PyPI
except metadata.PackageNotFoundError:
# No distribution metadata - could be local OR an internal module (like _pytest)
# Check if the module is in site-packages
try:
mod = sys.modules.get(top_level)
if mod is None:
import importlib

mod = importlib.import_module(top_level)

if hasattr(mod, "__file__") and mod.__file__:
module_path = os.path.abspath(mod.__file__)
site_paths = site.getsitepackages()
user_site = site.getusersitepackages()
if user_site:
site_paths = site_paths + [user_site]

# If module IS in site-packages, it's an internal PyPI package module
in_site = any(
module_path.startswith(os.path.abspath(sp))
for sp in site_paths
if sp
)
if in_site:
return False # In site-packages = not local
except Exception:
# Broad catch needed: dynamic import/path operations can fail in many ways
# (ImportError, AttributeError, OSError, etc.) depending on module state
pass

# No metadata and not in site-packages = definitely local
return True


class RegistrationMixin(BaseAgent):
@overload
def fn(
Expand Down Expand Up @@ -467,6 +563,50 @@ def _track_module(self, module_name: str | None) -> None:
if module_name:
self._tracked_modules.add(module_name)

def _get_installed_optional_deps(self, package_name: str) -> set[str]:
"""
Find optional dependencies of a package that are installed locally.

For packages with extras (e.g., `calgebra[google]`), this detects which
optional dependencies are actually installed in the environment so they
can be included in remote execution images.

Args:
package_name: Distribution name (e.g., "calgebra", not import name)

Returns:
Set of installed optional deps as "package==version" strings
"""
import re
from importlib import metadata

installed_optionals: set[str] = set()

try:
reqs = metadata.requires(package_name) or []
except metadata.PackageNotFoundError:
return installed_optionals

for req in reqs:
# Look for extras markers like: 'gcsa; extra == "google"'
if "extra ==" in req or "extra==" in req:
# Extract the package name (before the semicolon)
dep_spec = req.split(";")[0].strip()
# Remove version specifiers: "gcsa>=1.0" -> "gcsa"
dep_name = re.split(r"[<>=!~\[]", dep_spec)[0].strip()

if not dep_name:
continue

# Check if it's actually installed
try:
version = metadata.version(dep_name)
installed_optionals.add(f"{dep_name}=={version}")
except metadata.PackageNotFoundError:
pass # Not installed, skip

return installed_optionals

@property
def dependencies(self) -> "Dependencies":
"""
Expand All @@ -489,6 +629,9 @@ def dependencies(self) -> "Dependencies":

# Build packages list from tracked modules (expensive, but only once)
packages: set[str] = set()
local_packages: set[str] = set() # Local packages for add_local_python_source
# Track distribution names for optional dep lookup
distribution_names: set[str] = set()

# Cache packages_distributions() result for this computation
try:
Expand All @@ -507,6 +650,11 @@ def dependencies(self) -> "Dependencies":
# Get top-level package name
top_level = module_name.split(".")[0]

# Check if this is a local package first
if _is_local_module(module_name):
local_packages.add(top_level)
continue

try:
# Map import name to distribution name (e.g. sklearn -> scikit-learn)
dist_names = pkg_map.get(top_level)
Expand All @@ -515,6 +663,7 @@ def dependencies(self) -> "Dependencies":
try:
version = metadata.version(pkg)
packages.add(f"{pkg}=={version}")
distribution_names.add(pkg)
except metadata.PackageNotFoundError:
pass
else:
Expand All @@ -523,11 +672,19 @@ def dependencies(self) -> "Dependencies":
try:
version = metadata.version(top_level)
packages.add(f"{top_level}=={version}")
distribution_names.add(top_level)
except metadata.PackageNotFoundError:
pass
except Exception:
pass

# Detect installed optional dependencies for each tracked package
# This ensures extras like `calgebra[google]` have their optional deps
# (e.g., gcsa) included if they're installed locally
for dist_name in distribution_names:
optional_deps = self._get_installed_optional_deps(dist_name)
packages.update(optional_deps)

# Collect dependencies from sub-agents (hierarchical agents)
# When an agent uses sub-agents via @agent.fn @sub_agent.task,
# the sub-agent gets serialized in the closure and needs its deps too
Expand All @@ -546,11 +703,13 @@ def dependencies(self) -> "Dependencies":
if hasattr(sub_agent, "dependencies"):
sub_deps = sub_agent.dependencies
packages.update(sub_deps.packages)
local_packages.update(sub_deps.local_packages)

deps = Dependencies(
python_version=f"{sys.version_info.major}.{sys.version_info.minor}",
agex_version=metadata.version("agex"),
packages=sorted(list(packages)),
local_packages=sorted(list(local_packages)),
)
self._cached_dependencies = deps
return deps
9 changes: 7 additions & 2 deletions agex/host/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@ class Dependencies:
python_version: str
agex_version: str
packages: list[str] = field(default_factory=list)
local_packages: list[str] = field(default_factory=list)

@property
def id(self) -> str:
"""Unique identifier for this set of dependencies (for image caching)."""
import hashlib

# Sort packages to ensure stable ID
payload = f"py{self.python_version}-agex{self.agex_version}-" + "-".join(
sorted(self.packages)
# Include both PyPI packages and local packages in hash
payload = (
f"py{self.python_version}-agex{self.agex_version}-"
+ "-".join(sorted(self.packages))
+ "-local-"
+ "-".join(sorted(self.local_packages))
)
return hashlib.sha256(payload.encode()).hexdigest()[:12]
7 changes: 7 additions & 0 deletions agex/host/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,13 @@ def _build_function(
if extra_packages:
image = image.pip_install(*extra_packages)

# Add local packages to the image via add_local_python_source
# IMPORTANT: This must come AFTER all pip_install calls
# Modal mounts these at startup rather than embedding in image layer
if deps.local_packages:
for pkg_name in deps.local_packages:
image = image.add_local_python_source(pkg_name)

# Configure secrets
modal_secrets = [modal.Secret.from_name(s) for s in self.secrets]

Expand Down
39 changes: 39 additions & 0 deletions tests/agex/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,42 @@ def test_warmup_passes_dependencies(self):
assert deps.python_version
assert deps.agex_version
assert any(p.startswith("pytest==") for p in deps.packages)


class TestOptionalDependencies:
"""Test detection of installed optional dependencies."""

def test_get_installed_optional_deps_unknown_package(self):
"""Unknown package should return empty set."""
agent = Agent()
result = agent._get_installed_optional_deps("nonexistent-fake-package-xyz")
assert result == set()

def test_get_installed_optional_deps_no_extras(self):
"""Package with no optional deps should return empty set."""
agent = Agent()
# pytest doesn't have optional deps (just required ones)
result = agent._get_installed_optional_deps("pytest")
# Should not error, may return empty or some deps
assert isinstance(result, set)

def test_optional_deps_helper_detects_installed(self):
"""The helper should detect installed optional deps of a package."""
from importlib import metadata

agent = Agent()

# Find a real package in the environment that has optional deps
# We'll use agex itself which has optional deps like fastapi
try:
metadata.requires("agex")
except metadata.PackageNotFoundError:
# agex not installed as package in test env
return

# Check if any optional deps exist and are installed
optional_deps = agent._get_installed_optional_deps("agex")

# The result should be a set of "package==version" strings
for dep in optional_deps:
assert "==" in dep, f"Expected 'package==version' format, got: {dep}"