diff --git a/src/codeweaver/core/di/container.py b/src/codeweaver/core/di/container.py index 4da279b6..6689ddc8 100644 --- a/src/codeweaver/core/di/container.py +++ b/src/codeweaver/core/di/container.py @@ -7,6 +7,7 @@ from __future__ import annotations +import ast import asyncio import inspect import logging @@ -81,7 +82,81 @@ def __init__(self) -> None: self._shutdown_hooks: list[Callable[..., Any]] = [] self._cleanup_stack: AsyncExitStack | None = None self._request_cache: dict[Any, Any] = {} # Keys can be types or callables - self._providers_loaded: bool = False # Track if auto-discovery has run # Track if auto-discovery has run # Track if auto-discovery has run + self._providers_loaded: bool = False # Track if auto-discovery has run + + def _safe_eval_type(self, type_str: str, globalns: dict[str, Any]) -> Any: + """Safely evaluate a type string using AST validation. + + Args: + type_str: The string representation of a type. + globalns: The global namespace for evaluation. + + Returns: + The evaluated type object. + + Raises: + ValueError: If the type string contains forbidden constructs. 
+ """ + try: + tree = ast.parse(type_str, mode="eval") + except SyntaxError: + return None + + class TypeValidator(ast.NodeVisitor): + def generic_visit(self, node: ast.AST) -> None: + # Allowed nodes for type annotations, including support for: + # - Generics: List[int], dict[str, Any] (Subscript, Name, Attribute) + # - Unions: int | str (BinOp, BitOr) + # - Annotated: Annotated[int, Depends(...)] (Call, keyword, Tuple, List) + # - Literals: Literal["foo"] (Constant) + if not isinstance( + node, + ( + ast.Expression, + ast.Name, + ast.Attribute, + ast.Subscript, + ast.Constant, + ast.BinOp, + ast.BitOr, + ast.Load, + ast.Tuple, + ast.List, + ast.Call, + ast.keyword, + ), + ): + raise ValueError(f"Forbidden AST node in type string: {type(node).__name__}") + + # Block dunder access to prevent escaping the restricted environment + if isinstance(node, ast.Name) and node.id.startswith("__"): + raise ValueError(f"Forbidden dunder name: {node.id}") + if isinstance(node, ast.Attribute) and node.attr.startswith("__"): + raise ValueError(f"Forbidden dunder attribute: {node.attr}") + + super().generic_visit(node) + + TypeValidator().visit(tree) + + # Restricted eval: only allow basic builtin types to be resolved + # even if they are not in the module's globals. 
+ safe_builtins = { + "int": int, + "float": float, + "str": str, + "bool": bool, + "list": list, + "tuple": tuple, + "dict": dict, + "set": set, + "frozenset": frozenset, + "type": type, + "object": object, + "bytes": bytes, + } + + code = compile(tree, "", "eval") + return eval(code, {"__builtins__": safe_builtins}, globalns) @staticmethod def _unwrap_annotated(annotation: Any) -> Any: @@ -135,9 +210,8 @@ def _resolve_string_type( return None # First, try to evaluate the string as a type reference - # ruff: noqa: S307 - eval is necessary for type resolution, not literal evaluation with suppress(Exception): - return eval(type_str, globalns) + return self._safe_eval_type(type_str, globalns) # If direct eval failed, check if it's an Annotated pattern like "Annotated[SomeType, ...]" # In this case, try to resolve the base type from registered factories if type_str.startswith("Annotated["): @@ -164,7 +238,7 @@ def _resolve_string_type( enhanced_globalns[base_type_str] = factory_type with suppress(Exception): - return eval(type_str, enhanced_globalns) + return self._safe_eval_type(type_str, enhanced_globalns) # Fallback: try to find a factory by matching type name return next( ( diff --git a/tests/di/test_container_security.py b/tests/di/test_container_security.py new file mode 100644 index 00000000..26a05668 --- /dev/null +++ b/tests/di/test_container_security.py @@ -0,0 +1,81 @@ + +# SPDX-FileCopyrightText: 2025 Knitli Inc. +# SPDX-FileContributor: Adam Poulemanos +# +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""Security tests for the dependency injection container. + +This module verifies that the DI container safely resolves string type +annotations, preventing arbitrary code execution while supporting +complex Python type hints including generics, unions, and Annotated types. 
from typing import Annotated, List, Optional, Union, get_args

import pytest

from codeweaver.core.di.container import Container
from codeweaver.core.di.dependency import Depends


def test_safe_type_resolution():
    """Well-formed type strings resolve to the equivalent typing objects."""
    container = Container()
    globalns = {
        "List": List,
        "Optional": Optional,
        "Union": Union,
        "Annotated": Annotated,
        "Depends": Depends,
        "int": int,
        "str": str,
    }

    # Valid type strings
    assert container._resolve_string_type("int", globalns) is int
    assert container._resolve_string_type("List[int]", globalns) == List[int]
    assert container._resolve_string_type("Optional[str]", globalns) == Optional[str]
    assert container._resolve_string_type("int | str", globalns) == (int | str)

    # Annotated with Depends
    resolved_annotated = container._resolve_string_type("Annotated[int, Depends()]", globalns)

    # Check for __metadata__ rather than get_origin(): the attribute is
    # specific to Annotated types and does not depend on how a given
    # environment reports the origin.
    assert hasattr(resolved_annotated, "__metadata__"), (
        f"Expected Annotated type, got {type(resolved_annotated)}"
    )
    assert get_args(resolved_annotated)[0] is int
    assert any(isinstance(m, Depends) for m in resolved_annotated.__metadata__)


# Parametrized so that every payload is reported on its own instead of the
# first failure hiding the remaining cases.
@pytest.mark.parametrize(
    "type_str",
    [
        "__import__('os').system('echo VULNERABLE')",
        "eval('1+1')",
        "getattr(int, '__name__')",
        "int.__class__",
        "(lambda x: x)(1)",
    ],
)
def test_malicious_type_resolution(type_str):
    """Strings attempting code execution or introspection resolve to ``None``."""
    container = Container()
    globalns = {"__name__": "__main__"}

    result = container._resolve_string_type(type_str, globalns)
    assert result is None, f"String '{type_str}' should have been blocked"


def test_dunder_blocking():
    """Dunder names and dunder attributes are never resolved."""
    container = Container()
    globalns = {"int": int}

    # Dunder name blocking, even when the name exists in the namespace
    assert container._resolve_string_type("__name__", {"__name__": "foo"}) is None

    # Dunder attribute blocking
    assert container._resolve_string_type("int.__name__", globalns) is None


def test_safe_builtins_resolution():
    """Basic built-in types resolve even when the namespace does not define them."""
    container = Container()
    namespace = {"__name__": "foo"}

    assert container._resolve_string_type("int", namespace) is int
    assert container._resolve_string_type("list[str]", namespace) == list[str]