From cabeb41584b280646d2b9bdbf85193cdc8a13d18 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 02:23:18 -0800 Subject: [PATCH 01/14] feat: implement explicit create_cache API This commit refactors the persistent context caching mechanism in Pollux, replacing the implicit `enable_caching=True` flag in `Config` with an explicit `create_cache` API. This decouples cache upload/warm-up from text generation, allowing for stricter validation (e.g., rejecting `system_instruction` or `tools` usage alongside a `CacheHandle` to match Gemini API constraints) and more predictable caching behavior. --- .github/ISSUE_TEMPLATE/bug.md | 2 +- README.md | 1 - .../optimization/cache-warming-and-ttl.py | 29 +- cookbook/utils/runtime.py | 5 +- docs/caching.md | 57 +-- docs/configuration.md | 12 +- docs/portable-code.md | 29 +- docs/reference/api.md | 4 + docs/reference/provider-capabilities.md | 23 +- docs/source-patterns.md | 2 +- src/pollux/__init__.py | 139 +++++++- src/pollux/cache.py | 16 +- src/pollux/config.py | 13 - src/pollux/execute.py | 92 ++--- src/pollux/options.py | 11 + src/pollux/plan.py | 31 +- src/pollux/providers/anthropic.py | 12 +- src/pollux/providers/base.py | 12 +- src/pollux/providers/gemini.py | 12 +- src/pollux/providers/mock.py | 12 +- src/pollux/providers/openai.py | 12 +- tests/conftest.py | 10 +- tests/helpers.py | 2 +- tests/test_config.py | 25 -- tests/test_pipeline.py | 336 ++++++++++++++++-- 25 files changed, 606 insertions(+), 293 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md index 2ea6707..14e7918 100644 --- a/.github/ISSUE_TEMPLATE/bug.md +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -40,7 +40,7 @@ labels: [bug] diff --git a/README.md b/README.md index 76addcc..30957e7 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,6 @@ from pollux import Config config = Config( provider="gemini", model="gemini-2.5-flash-lite", - enable_caching=True, # Gemini-only in v1.0 ) ``` diff --git 
a/cookbook/optimization/cache-warming-and-ttl.py b/cookbook/optimization/cache-warming-and-ttl.py index 174ad1a..eac8957 100644 --- a/cookbook/optimization/cache-warming-and-ttl.py +++ b/cookbook/optimization/cache-warming-and-ttl.py @@ -7,7 +7,7 @@ Pattern: - Keep prompts and sources fixed. - - Enable caching with a meaningful TTL. + - Create a persistent cache via ``create_cache()``. - Run once to warm and once to reuse (back-to-back). - Compare tokens and cache signal. """ @@ -16,7 +16,6 @@ import argparse import asyncio -from dataclasses import replace from pathlib import Path from typing import TYPE_CHECKING @@ -28,15 +27,15 @@ print_section, ) from cookbook.utils.runtime import add_runtime_args, build_config_or_exit, usage_tokens -from pollux import Config, Source, run_many +from pollux import Config, Options, Source, create_cache, run_many if TYPE_CHECKING: from pollux.result import ResultEnvelope -PROMPTS = [ +PROMPTS = ( "List 5 key concepts with one-sentence explanations.", "Extract three actionable recommendations.", -] +) def describe(run_name: str, envelope: ResultEnvelope) -> None: @@ -52,15 +51,21 @@ def describe(run_name: str, envelope: ResultEnvelope) -> None: ) -async def main_async(directory: Path, *, limit: int, config: Config) -> None: +async def main_async(directory: Path, *, limit: int, config: Config, ttl: int) -> None: files = sorted(path for path in directory.rglob("*") if path.is_file())[:limit] if not files: raise SystemExit(f"No files found under: {directory}") sources = [Source.from_file(path) for path in files] - warm = await run_many(PROMPTS, sources=sources, config=config) - reuse = await run_many(PROMPTS, sources=sources, config=config) + handle = await create_cache(sources, config=config, ttl_seconds=ttl) + + warm = await run_many( + PROMPTS, sources=sources, config=config, options=Options(cache=handle) + ) + reuse = await run_many( + PROMPTS, sources=sources, config=config, options=Options(cache=handle) + ) warm_tokens = 
usage_tokens(warm) reuse_tokens = usage_tokens(reuse) saved = None @@ -107,16 +112,14 @@ def main() -> None: hint="No input directory found. Run `just demo-data` or pass --input /path/to/dir.", ) config = build_config_or_exit(args) - cached_config = replace( - config, enable_caching=True, ttl_seconds=max(1, int(args.ttl)) - ) - print_header("Cache warming and TTL", config=cached_config) + print_header("Cache warming and TTL", config=config) asyncio.run( main_async( directory, limit=max(1, int(args.limit)), - config=cached_config, + config=config, + ttl=max(1, int(args.ttl)), ) ) diff --git a/cookbook/utils/runtime.py b/cookbook/utils/runtime.py index 1bbd46b..9308cb7 100644 --- a/cookbook/utils/runtime.py +++ b/cookbook/utils/runtime.py @@ -60,14 +60,11 @@ def build_config_or_exit(args: argparse.Namespace) -> Config: def print_run_mode(config: Config) -> None: """Print a compact runtime mode line for recipe users.""" mode = "mock" if config.use_mock else "real-api" - caching = f"on(ttl={config.ttl_seconds}s)" if config.enable_caching else "off" extra = "" # Keep the mode line compact; only call out non-default concurrency. if getattr(config, "request_concurrency", 6) != 6: extra = f" | request_concurrency={config.request_concurrency}" - print( - f"Mode: {mode} | provider={config.provider} | model={config.model} | caching={caching}{extra}" - ) + print(f"Mode: {mode} | provider={config.provider} | model={config.model}{extra}") def usage_tokens(envelope: ResultEnvelope) -> int | None: diff --git a/docs/caching.md b/docs/caching.md index 3758c4b..5dfb8c9 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -1,5 +1,5 @@ @@ -73,29 +73,30 @@ compare_efficiency(946_800, 10) More questions on the same content = greater savings. -## Enabling Caching +## Creating a Cache -Let's see this in practice. 
Two flags in `Config` control caching: +Use `create_cache()` to upload content to the provider once, then pass the +returned handle to `run()` or `run_many()` via `Options(cache=handle)`: ```python import asyncio -from pollux import Config, Source, run_many +from pollux import Config, Options, Source, create_cache, run_many async def main() -> None: config = Config( provider="gemini", model="gemini-2.5-flash-lite", - enable_caching=True, - ttl_seconds=3600, ) - prompts = ["Summarize in one sentence.", "List 3 keywords."] sources = [Source.from_text( "ACME Corp Q3 2025 earnings: revenue $4.2B (+12% YoY), " "operating margin 18.5%, guidance raised for Q4." )] - first = await run_many(prompts=prompts, sources=sources, config=config) - second = await run_many(prompts=prompts, sources=sources, config=config) + handle = await create_cache(sources, config=config, ttl_seconds=3600) + + prompts = ["Summarize in one sentence.", "List 3 keywords."] + first = await run_many(prompts=prompts, sources=sources, config=config, options=Options(cache=handle)) + second = await run_many(prompts=prompts, sources=sources, config=config, options=Options(cache=handle)) print("first:", first["status"]) print("second:", second["status"]) @@ -106,23 +107,37 @@ asyncio.run(main()) ### Step-by-Step Walkthrough -1. **Set `enable_caching=True`.** This tells Pollux to upload content to the - provider's cache on the first call, rather than sending it inline. +1. **Call `create_cache()`.** Pass your sources, config, and a TTL. Pollux + uploads the content to the provider and returns a `CacheHandle`. 2. **Set `ttl_seconds`.** The TTL controls how long the cached content lives on the provider. Match it to your reuse window. 3600s (1 hour) is a reasonable default for interactive sessions. -3. **Run the same sources with different prompts.** The first `run_many()` call - uploads the content and creates a cache entry. The second call detects the - same content hash and reuses the cached reference. +3. 
**Pass the handle via `Options(cache=handle)`.** Each `run()` or `run_many()` + call that uses this handle references the cached content instead of + re-uploading it. 4. **Verify with `metrics.cache_used`.** Check `result["metrics"]["cache_used"]` on subsequent calls. `True` confirms the provider served content from cache rather than re-uploading. -Pollux computes cache identity from model + source content hash. The second -call reuses the cached context automatically. +Pollux computes cache identity from model + source content hash. Calls with +the same handle reuse the cached context automatically. + +!!! warning "Options restricted when using a cache handle" + When `Options(cache=handle)` is set, the following fields **cannot** be + passed alongside it: + + - `system_instruction` — bake it into `create_cache(system_instruction=...)` + instead. + - `tools` / `tool_choice` — bake them into `create_cache()` instead (when + supported). + + Pollux raises `ConfigurationError` immediately if it detects these + conflicts. This mirrors a hard constraint in the Gemini API, where + `cached_content` cannot coexist with `system_instruction`, `tools`, or + `tool_config` in the same `GenerateContent` request. ## Cache Identity @@ -156,8 +171,8 @@ behavior. Usage counters are provider-dependent. ## Tuning TTL -The default TTL is 3600 seconds (1 hour). Tune `ttl_seconds` to match your -expected reuse window: +Pass `ttl_seconds` to `create_cache()` to control the cache lifetime. The +default is 3600 seconds (1 hour). Tune it to match your expected reuse window: - **Too short:** the cache expires before you reuse it, wasting the warm-up cost. @@ -181,9 +196,9 @@ caching and enable it when you see repeated context in your workload. ## Provider Dependency -Context caching is **Gemini-only**. Enabling it with OpenAI raises -an actionable error. See -[Provider Capabilities](reference/provider-capabilities.md) for the full +Persistent context caching is **Gemini-only**. 
Calling `create_cache()` with +a provider that lacks `persistent_cache` support raises an actionable error. +See [Provider Capabilities](reference/provider-capabilities.md) for the full matrix. --- diff --git a/docs/configuration.md b/docs/configuration.md index 23b24ddd..e22bea3 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,8 +35,6 @@ All fields and their defaults: | `model` | `str` | *(required)* | Model identifier | | `api_key` | `str \| None` | `None` | Explicit key; auto-resolved from env if omitted | | `use_mock` | `bool` | `False` | Use mock provider (no network calls) | -| `enable_caching` | `bool` | `False` | Enable provider-side context caching | -| `ttl_seconds` | `int` | `3600` | Cache time-to-live in seconds | | `request_concurrency` | `int` | `6` | Max concurrent API calls in multi-prompt execution | | `retry` | `RetryPolicy` | `RetryPolicy()` | Retry configuration | @@ -78,8 +76,6 @@ pipeline logic, testing integrations, and CI. config = Config( provider="gemini", model="gemini-2.5-flash-lite", - enable_caching=True, # Reuse uploaded context (Gemini-only) - ttl_seconds=3600, # Cache lifetime request_concurrency=6, # Concurrent API calls ) ``` @@ -87,7 +83,7 @@ config = Config( | Need | Direction | |---|---| | Fast iteration without API calls | `use_mock=True` | -| Reduce token spend on repeated context | `enable_caching=True`. See [Reducing Costs with Context Caching](caching.md) | +| Reduce token spend on repeated context | Use `create_cache()`. See [Reducing Costs with Context Caching](caching.md) | | Higher throughput for many prompts/sources | Increase `request_concurrency` | | Better resilience to transient failures | Customize `retry=RetryPolicy(...)` | @@ -151,6 +147,7 @@ options = Options( | `delivery_mode` | `str` | `"realtime"` | Only `"realtime"` is supported; `"deferred"` raises an error | | `history` | `list[dict] \| None` | `None` | Conversation history. 
See [Continuing Conversations Across Turns](conversations-and-agents.md) | | `continue_from` | `ResultEnvelope \| None` | `None` | Resume from a prior result. See [Continuing Conversations Across Turns](conversations-and-agents.md) | +| `cache` | `CacheHandle \| None` | `None` | Persistent context cache. See [Reducing Costs with Context Caching](caching.md) | !!! note OpenAI GPT-5 family models (`gpt-5`, `gpt-5-mini`, `gpt-5-nano`) reject @@ -159,6 +156,11 @@ options = Options( See [Writing Portable Code Across Providers](portable-code.md#model-specific-constraints) for the full constraints mapping. +!!! warning "Cache handle restrictions" + When `cache` is set, `system_instruction` and `tools` **must not** be + passed in the same `Options`. Bake them into `create_cache()` instead. + See [Reducing Costs with Context Caching](caching.md) for details. + ## Safety Notes - `Config` is immutable (`frozen=True`). Create a new instance to change values. diff --git a/docs/portable-code.md b/docs/portable-code.md index 7fb3b5a..878199b 100644 --- a/docs/portable-code.md +++ b/docs/portable-code.md @@ -35,8 +35,7 @@ varying parts in config; keep the stable parts in functions. ## Complete Example -A document analysis function that works on any provider. Caching is used -when available, skipped otherwise. +A document analysis function that works on any provider. 
```python import asyncio @@ -55,27 +54,22 @@ class DocumentSummary(BaseModel): @dataclass class ProviderConfig: - """Maps a provider to a model and capability flags.""" + """Maps a provider to a model.""" provider: str model: str - supports_caching: bool = False # Provider-specific details live here, not in your pipeline logic PROVIDERS = { - "gemini": ProviderConfig("gemini", "gemini-2.5-flash-lite", supports_caching=True), + "gemini": ProviderConfig("gemini", "gemini-2.5-flash-lite"), "openai": ProviderConfig("openai", "gpt-5-nano"), } -def make_config(provider_name: str, *, enable_caching: bool = False) -> Config: +def make_config(provider_name: str) -> Config: """Build a Config for the given provider with safe defaults.""" pc = PROVIDERS[provider_name] - return Config( - provider=pc.provider, - model=pc.model, - enable_caching=enable_caching and pc.supports_caching, - ) + return Config(provider=pc.provider, model=pc.model) async def analyze_document( @@ -83,10 +77,9 @@ async def analyze_document( prompt: str, *, provider_name: str = "gemini", - enable_caching: bool = False, ) -> DocumentSummary: """Analyze a document — works with any supported provider.""" - config = make_config(provider_name, enable_caching=enable_caching) + config = make_config(provider_name) options = Options(response_schema=DocumentSummary) result = await run( @@ -118,12 +111,12 @@ asyncio.run(main()) ### Step-by-Step Walkthrough 1. **Centralize provider details.** `ProviderConfig` maps each provider to - its model and capability flags. Your analysis functions never reference - provider names or models directly. + its model. Your analysis functions never reference provider names or + models directly. -2. **Guard capability-specific features.** `make_config` only enables caching - when both the caller requests it *and* the provider supports it. This - avoids `ConfigurationError` at runtime. +2. 
**Use `create_cache()` for persistent caching.** Caching is now + opt-in via `create_cache()` and `Options(cache=handle)`. Only call + it when the provider supports `persistent_cache` (e.g. Gemini). 3. **Write provider-agnostic functions.** `analyze_document` accepts a provider name and builds the config internally. The prompt, source, and diff --git a/docs/reference/api.md b/docs/reference/api.md index eccca32..1324bda 100644 --- a/docs/reference/api.md +++ b/docs/reference/api.md @@ -14,10 +14,14 @@ The primary execution functions are exported from `pollux`: ::: pollux.continue_tool +::: pollux.create_cache + ## Core Types ::: pollux.Source +::: pollux.CacheHandle + ::: pollux.Config ::: pollux.Options diff --git a/docs/reference/provider-capabilities.md b/docs/reference/provider-capabilities.md index 2f3e3a8..a6638f8 100644 --- a/docs/reference/provider-capabilities.md +++ b/docs/reference/provider-capabilities.md @@ -84,20 +84,19 @@ Pollux is **capability-transparent**, not capability-equalizing: providers are a When a requested feature is unsupported for the selected provider or release scope, Pollux raises `ConfigurationError` or `APIError` with a concrete hint, instead of degrading silently. -For example, enabling caching with OpenAI: +For example, creating a persistent cache with OpenAI: ```python -from pollux import Config - -config = Config( - provider="openai", - model="gpt-5-nano", - enable_caching=True, # not supported for OpenAI +from pollux import Config, Source, create_cache + +config = Config(provider="openai", model="gpt-5-nano") +# This raises immediately: +# ConfigurationError: Provider 'openai' does not support persistent caching +# hint: "Use a provider that supports persistent_cache (e.g. Gemini)." +handle = await create_cache( + [Source.from_text("hello")], config=config ) -# At execution time, Pollux raises: -# ConfigurationError: Provider does not support caching -# hint: "Disable caching or choose a provider with caching support." 
``` -The error is raised at execution time (not at `Config` creation) because -caching support is a provider capability checked during plan execution. +The error is raised at `create_cache()` call time because persistent caching +is a provider capability checked before the upload attempt. diff --git a/docs/source-patterns.md b/docs/source-patterns.md index 01105ce..5ec596f 100644 --- a/docs/source-patterns.md +++ b/docs/source-patterns.md @@ -232,7 +232,7 @@ async def process_to_jsonl(directory: str, output: str) -> None: - **Memory with large collections.** Each `Source.from_file()` reads the file for hashing. For very large collections, process in batches rather than loading all sources at once. -- **Caching helps fan-out, not iteration.** `enable_caching=True` saves +- **Caching helps fan-out, not iteration.** `create_cache()` saves tokens when the *same source* gets reused across multiple prompts. It does not help when each file is different. See [Reducing Costs with Context Caching](caching.md). 
diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index bf76d12..d601ad5 100644 --- a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -3,6 +3,7 @@ Public API: - run(): Single prompt execution - run_many(): Multi-prompt source-pattern execution + - create_cache(): Create a persistent context cache - Source: Explicit input types - Config: Configuration dataclass """ @@ -11,9 +12,10 @@ import asyncio import logging +import time from typing import TYPE_CHECKING, Any -from pollux.cache import CacheRegistry +from pollux.cache import CacheHandle, CacheRegistry from pollux.config import Config from pollux.errors import ( APIError, @@ -113,17 +115,9 @@ async def run_many( provider = _get_provider(request.config) try: - trace = await execute_plan(plan, provider, _registry) + trace = await execute_plan(plan, provider) finally: - aclose = getattr(provider, "aclose", None) - if callable(aclose): - try: - await aclose() - except asyncio.CancelledError: - raise - except Exception as exc: - # Cleanup should never mask the primary failure. - logger.warning("Provider cleanup failed: %s", exc) + await _close_provider(provider) return build_result(plan, trace) @@ -169,6 +163,127 @@ async def continue_tool( return await run(prompt=None, config=config, options=merged_options) +async def create_cache( + sources: tuple[Source, ...] | list[Source], + *, + config: Config, + system_instruction: str | None = None, + ttl_seconds: int = 3600, +) -> CacheHandle: + """Create a persistent context cache for use with ``run()`` / ``run_many()``. + + Args: + sources: Content to cache (files, text, URLs). + config: Configuration specifying provider and model. + system_instruction: Optional system-level instruction baked into the cache. + ttl_seconds: Time-to-live in seconds (must be ≥ 1). + + Returns: + A ``CacheHandle`` that can be passed via ``Options(cache=handle)``. + + Raises: + ConfigurationError: If the provider does not support persistent caching + or *ttl_seconds* is invalid. 
+ + Example: + handle = await create_cache( + [Source.from_file("book.pdf")], + config=Config(provider="gemini", model="gemini-2.5-flash"), + ttl_seconds=3600, + ) + result = await run("Summarize.", config=config, options=Options(cache=handle)) + """ + from pollux.cache import get_or_create_cache + from pollux.execute import _substitute_upload_parts + from pollux.plan import build_shared_parts + + if not isinstance(ttl_seconds, int) or ttl_seconds < 1: + raise ConfigurationError( + f"ttl_seconds must be an integer ≥ 1, got {ttl_seconds!r}", + hint="Pass a positive integer for the cache TTL.", + ) + + provider = _get_provider(config) + try: + if not provider.capabilities.persistent_cache: + raise ConfigurationError( + f"Provider {config.provider!r} does not support persistent caching", + hint="Use a provider that supports persistent_cache (e.g. Gemini).", + ) + + src_tuple = tuple(sources) if not isinstance(sources, tuple) else sources + + # Validate sources + for s in src_tuple: + if not isinstance(s, Source): + raise ConfigurationError( + f"Expected Source, got {type(s).__name__}", + hint="Use Source.from_file(), Source.from_text(), etc.", + ) + + parts = build_shared_parts(src_tuple) + + # Resolve file uploads. 
+ upload_cache: dict[tuple[str, str], Any] = {} + upload_inflight: dict[tuple[str, str], asyncio.Future[Any]] = {} + upload_lock = asyncio.Lock() + retry_policy = config.retry + + parts = await _substitute_upload_parts( + parts, + provider=provider, + call_idx=None, + upload_cache=upload_cache, + upload_inflight=upload_inflight, + upload_lock=upload_lock, + retry_policy=retry_policy, + ) + + from pollux.cache import compute_cache_key + + key = compute_cache_key( + config.model, src_tuple, system_instruction=system_instruction + ) + + cache_name = await get_or_create_cache( + provider, + _registry, + key=key, + model=config.model, + parts=parts, + system_instruction=system_instruction, + ttl_seconds=ttl_seconds, + retry_policy=retry_policy, + ) + + if cache_name is None: + raise InternalError( + "Cache creation returned None unexpectedly", + hint="This is a Pollux internal error. Please report it.", + ) + + return CacheHandle( + name=cache_name, + model=config.model, + provider=config.provider, + expires_at=time.time() + ttl_seconds, + ) + finally: + await _close_provider(provider) + + +async def _close_provider(provider: Provider) -> None: + """Close provider resources without masking primary errors.""" + aclose = getattr(provider, "aclose", None) + if callable(aclose): + try: + await aclose() + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Provider cleanup failed: %s", exc) + + def _get_provider(config: Config) -> Provider: """Get the appropriate provider based on configuration.""" if config.use_mock: @@ -210,6 +325,7 @@ def _get_provider(config: Config) -> Provider: __all__ = [ "APIError", "CacheError", + "CacheHandle", "Config", "ConfigurationError", "InternalError", @@ -222,6 +338,7 @@ def _get_provider(config: Config) -> Provider: "Source", "SourceError", "continue_tool", + "create_cache", "run", "run_many", ] diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 9b4a9e2..00b79c3 100644 --- a/src/pollux/cache.py +++ 
b/src/pollux/cache.py @@ -19,6 +19,20 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class CacheHandle: + """Opaque handle returned by ``create_cache()``. + + Pass instances via ``Options(cache=handle)`` to reuse a persistent + context cache across ``run()`` / ``run_many()`` calls. + """ + + name: str + model: str + provider: str + expires_at: float + + @dataclass class CacheRegistry: """Registry tracking cache entries with expiration.""" @@ -82,7 +96,7 @@ async def get_or_create_cache( Single-flight: concurrent requests for the same key share one creation call. """ - if not provider.supports_caching: + if not provider.capabilities.persistent_cache: return None async def _work() -> str: diff --git a/src/pollux/config.py b/src/pollux/config.py index d4d815f..9f72445 100644 --- a/src/pollux/config.py +++ b/src/pollux/config.py @@ -38,9 +38,6 @@ class Config: #: Auto-resolved from ``GEMINI_API_KEY`` or ``OPENAI_API_KEY`` when *None*. api_key: str | None = None use_mock: bool = False - #: Gemini-only in v1.0; silently ignored for other providers. 
- enable_caching: bool = False - ttl_seconds: int = 3600 request_concurrency: int = 6 retry: RetryPolicy = field(default_factory=RetryPolicy) @@ -59,21 +56,11 @@ def __post_init__(self) -> None: f"request_concurrency must be an integer, got {type(self.request_concurrency).__name__}", hint="Pass a whole number ≥ 1 for request_concurrency.", ) - if not isinstance(self.ttl_seconds, int): - raise ConfigurationError( - f"ttl_seconds must be an integer, got {type(self.ttl_seconds).__name__}", - hint="Pass a whole number ≥ 0 for ttl_seconds.", - ) if self.request_concurrency < 1: raise ConfigurationError( f"request_concurrency must be ≥ 1, got {self.request_concurrency}", hint="This controls how many API calls run in parallel.", ) - if self.ttl_seconds < 0: - raise ConfigurationError( - f"ttl_seconds must be ≥ 0, got {self.ttl_seconds}", - hint="This controls the cache time-to-live in seconds (0 disables caching TTL).", - ) # Auto-resolve API key from environment if not provided if self.api_key is None and not self.use_mock: diff --git a/src/pollux/execute.py b/src/pollux/execute.py index 2c7b3ca..4dc6d82 100644 --- a/src/pollux/execute.py +++ b/src/pollux/execute.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any from pollux._singleflight import singleflight_cached -from pollux.cache import CacheRegistry, get_or_create_cache from pollux.errors import APIError, ConfigurationError, InternalError, PolluxError from pollux.providers.models import ( Message, @@ -68,9 +67,7 @@ class ExecutionTrace: conversation_state: dict[str, Any] | None = None -async def execute_plan( - plan: Plan, provider: Provider, registry: CacheRegistry -) -> ExecutionTrace: +async def execute_plan(plan: Plan, provider: Provider) -> ExecutionTrace: """Execute the plan with the given provider. 
Handles: @@ -116,8 +113,54 @@ async def execute_plan( "Conversation continuity currently supports exactly one prompt per call", hint="Use run() or run_many() with a single prompt when passing history/continue_from.", ) + if options.cache is not None: + cache_handle = options.cache + if not caps.persistent_cache: + raise ConfigurationError( + "Provider does not support persistent caching", + hint=( + "Remove options.cache or choose a provider with " + "persistent_cache support." + ), + ) + if cache_handle.provider != config.provider: + raise ConfigurationError( + "cache handle provider does not match config provider", + hint=( + f"Create the cache with provider={config.provider!r} and " + "reuse it with the same provider." + ), + ) + if cache_handle.model != model: + raise ConfigurationError( + "cache handle model does not match config model", + hint=( + f"Create the cache with model={model!r} and reuse it " + "with the same model." + ), + ) + # Gemini (and potentially other providers) reject requests that pass + # system_instruction or tools alongside cached_content. Catch the + # conflict early so users get a clear Pollux error instead of a + # provider 400. + if options.system_instruction is not None: + raise ConfigurationError( + "system_instruction cannot be used with a cache handle", + hint=( + "Bake the system instruction into create_cache() instead, " + "or remove the cache handle." + ), + ) + if options.tools is not None: + raise ConfigurationError( + "tools cannot be used with a cache handle", + hint=( + "Bake tools into create_cache() instead, " + "or remove the cache handle." 
+ ), + ) - if (not provider.supports_uploads) and any( + if (not provider.capabilities.uploads) and any( isinstance(p, dict) and isinstance(p.get("file_path"), str) and isinstance(p.get("mime_type"), str) @@ -169,48 +212,13 @@ async def execute_plan( upload_inflight: dict[tuple[str, str], asyncio.Future[ProviderFileAsset]] = {} upload_lock = asyncio.Lock() retry_policy = config.retry - cache_name: str | None = None responses: list[dict[str, Any]] = [] total_usage: dict[str, int] = {} conversation_state: dict[str, Any] | None = None try: - # Handle caching - if plan.use_cache: - # Resolve uploads for shared parts first, as cache creation requires URIs - shared_parts = list(plan.shared_parts) - if shared_parts: - shared_parts = await _substitute_upload_parts( - shared_parts, - provider=provider, - call_idx=None, - upload_cache=upload_cache, - upload_inflight=upload_inflight, - upload_lock=upload_lock, - retry_policy=retry_policy, - ) - - if plan.cache_key: - try: - cache_name = await get_or_create_cache( - provider, - registry, - key=plan.cache_key, - model=config.model, - parts=shared_parts, # Use resolved parts with URIs - system_instruction=options.system_instruction, - ttl_seconds=config.ttl_seconds, - retry_policy=retry_policy, - ) - except asyncio.CancelledError: - raise - except PolluxError: - raise - except Exception as e: - raise InternalError( - f"Cache creation failed: {type(e).__name__}: {e}", - hint="This is a Pollux internal error. Please report it.", - ) from e + # Use pre-created cache name from Options.cache (via plan). 
+ cache_name = options.cache.name if options.cache is not None else None # Execute calls with concurrency control concurrency = config.request_concurrency diff --git a/src/pollux/options.py b/src/pollux/options.py index 778e2bd..0ea6ee2 100644 --- a/src/pollux/options.py +++ b/src/pollux/options.py @@ -10,6 +10,7 @@ from pollux.errors import ConfigurationError if TYPE_CHECKING: + from pollux.cache import CacheHandle from pollux.result import ResultEnvelope ReasoningEffort = str @@ -44,6 +45,8 @@ class Options: continue_from: ResultEnvelope | None = None #: Hard limit on the model's total output tokens. Provider-specific semantics. max_tokens: int | None = None + #: Persistent context cache obtained from ``create_cache()``. + cache: CacheHandle | None = None def __post_init__(self) -> None: """Validate option shapes early for clear errors.""" @@ -101,6 +104,14 @@ def __post_init__(self) -> None: "continue_from must be a prior Pollux result envelope", hint="Pass the dict returned by run() or run_many().", ) + if self.cache is not None: + from pollux.cache import CacheHandle + + if not isinstance(self.cache, CacheHandle): + raise ConfigurationError( + "cache must be a CacheHandle from create_cache()", + hint="Call create_cache() first, then pass Options(cache=handle).", + ) def response_schema_json(self) -> dict[str, Any] | None: """Return JSON Schema for provider APIs.""" diff --git a/src/pollux/plan.py b/src/pollux/plan.py index 59663df..433a340 100644 --- a/src/pollux/plan.py +++ b/src/pollux/plan.py @@ -12,12 +12,11 @@ @dataclass(frozen=True) class Plan: - """Execution plan with shared context and cache identity.""" + """Execution plan with shared context and optional cache reference.""" request: Request shared_parts: tuple[Any, ...] 
= () - use_cache: bool = False - cache_key: str | None = None + cache_name: str | None = None @property def n_calls(self) -> int: @@ -30,34 +29,22 @@ def build_plan(request: Request) -> Plan: Handles both single-prompt and vectorized (multi-prompt) scenarios. """ - config = request.config sources = request.sources + shared_parts = build_shared_parts(sources) - # Build shared parts from sources - shared_parts = _build_shared_parts(sources) - - # Determine if caching should be used - use_cache = config.enable_caching and len(shared_parts) > 0 - cache_key = None - - if use_cache: - from pollux.cache import compute_cache_key - - cache_key = compute_cache_key( - config.model, - sources, - system_instruction=request.options.system_instruction, - ) + # Resolve cache_name from Options.cache if provided. + cache_name: str | None = None + if request.options.cache is not None: + cache_name = request.options.cache.name return Plan( request=request, shared_parts=tuple(shared_parts), - use_cache=use_cache, - cache_key=cache_key, + cache_name=cache_name, ) -def _build_shared_parts(sources: tuple[Source, ...]) -> list[Any]: +def build_shared_parts(sources: tuple[Source, ...]) -> list[Any]: """Convert sources to API parts.""" parts: list[Any] = [] diff --git a/src/pollux/providers/anthropic.py b/src/pollux/providers/anthropic.py index 9437382..ce835a9 100644 --- a/src/pollux/providers/anthropic.py +++ b/src/pollux/providers/anthropic.py @@ -57,21 +57,11 @@ def _get_client(self) -> Any: self._client = AsyncAnthropic(api_key=self.api_key) return self._client - @property - def supports_caching(self) -> bool: - """Whether this provider supports context caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=False, + 
persistent_cache=False, uploads=True, structured_outputs=True, reasoning=True, diff --git a/src/pollux/providers/base.py b/src/pollux/providers/base.py index 757f79b..e832327 100644 --- a/src/pollux/providers/base.py +++ b/src/pollux/providers/base.py @@ -19,7 +19,7 @@ class ProviderCapabilities: """Feature flags exposed by providers.""" - caching: bool + persistent_cache: bool uploads: bool structured_outputs: bool = False reasoning: bool = False @@ -53,16 +53,6 @@ async def create_cache( """Create a cache and return its name.""" ... - @property - def supports_caching(self) -> bool: - """Whether this provider supports caching.""" - ... - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - ... - @property def capabilities(self) -> ProviderCapabilities: """Feature capabilities for strict option validation.""" diff --git a/src/pollux/providers/gemini.py b/src/pollux/providers/gemini.py index d63b17e..e009291 100644 --- a/src/pollux/providers/gemini.py +++ b/src/pollux/providers/gemini.py @@ -48,21 +48,11 @@ def _get_client(self) -> Any: self._client = genai.Client(api_key=self.api_key) return self._client - @property - def supports_caching(self) -> bool: - """Whether this provider supports context caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, diff --git a/src/pollux/providers/mock.py b/src/pollux/providers/mock.py index 6301ef5..023a1c9 100644 --- a/src/pollux/providers/mock.py +++ b/src/pollux/providers/mock.py @@ -17,21 +17,11 @@ class MockProvider: Supports caching and uploads but returns synthetic responses. 
""" - @property - def supports_caching(self) -> bool: - """Whether this provider supports caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, diff --git a/src/pollux/providers/openai.py b/src/pollux/providers/openai.py index 39f7e23..ea9fd45 100644 --- a/src/pollux/providers/openai.py +++ b/src/pollux/providers/openai.py @@ -44,21 +44,11 @@ def _get_client(self) -> Any: self._client = AsyncOpenAI(api_key=self.api_key) return self._client - @property - def supports_caching(self) -> bool: - """Whether this provider supports context caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=False, + persistent_cache=False, uploads=True, structured_outputs=True, reasoning=True, diff --git a/tests/conftest.py b/tests/conftest.py index 0536563..53a06b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ class FakeProvider: last_generate_kwargs: dict[str, Any] | None = None _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -41,14 +41,6 @@ class FakeProvider: ) ) - @property - def supports_caching(self) -> bool: - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: return 
self._capabilities diff --git a/tests/helpers.py b/tests/helpers.py index ae9d354..3503f63 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -72,7 +72,7 @@ def __post_init__(self) -> None: self, "_capabilities", ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, diff --git a/tests/test_config.py b/tests/test_config.py index 5c143f2..1cbe33b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -184,31 +184,6 @@ def test_non_integer_request_concurrency_raises_clear_error(gemini_model: str) - assert exc.value.hint is not None -def test_negative_ttl_seconds_raises_clear_error(gemini_model: str) -> None: - """Negative TTL is nonsensical; must fail at construction.""" - with pytest.raises(ConfigurationError, match="ttl_seconds must be") as exc: - Config(provider="gemini", model=gemini_model, use_mock=True, ttl_seconds=-1) - assert exc.value.hint is not None - - -def test_non_integer_ttl_seconds_raises_clear_error(gemini_model: str) -> None: - """Non-integer TTL should raise ConfigurationError, not TypeError.""" - with pytest.raises(ConfigurationError, match="must be an integer") as exc: - Config( - provider="gemini", - model=gemini_model, - use_mock=True, - ttl_seconds="3600", # type: ignore[arg-type] - ) - assert exc.value.hint is not None - - -def test_zero_ttl_seconds_is_allowed(gemini_model: str) -> None: - """ttl_seconds=0 is valid (disables caching TTL).""" - cfg = Config(provider="gemini", model=gemini_model, use_mock=True, ttl_seconds=0) - assert cfg.ttl_seconds == 0 - - def test_config_str_and_repr_redact_api_key(gemini_model: str) -> None: """String representations must not leak secrets.""" secret = "top-secret-key" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index fea7aef..ba195da 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -10,7 +10,7 @@ import pytest import pollux -from pollux.cache import CacheRegistry, compute_cache_key +from 
pollux.cache import CacheHandle, CacheRegistry, compute_cache_key from pollux.config import Config from pollux.errors import APIError, ConfigurationError, PlanningError, SourceError from pollux.options import Options @@ -202,7 +202,7 @@ async def upload_file(self, path: Any, mime_type: str) -> ProviderFileAsset: # async def test_cache_error_attributes_provider_without_call_index( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Cache failures should carry provider and phase but no call index.""" + """Cache failures from create_cache() should carry provider and phase.""" @dataclass class _Provider(FakeProvider): @@ -222,21 +222,18 @@ async def create_cache(self, **kwargs: Any) -> str: provider="gemini", model=CACHE_MODEL, use_mock=True, - enable_caching=True, retry=RetryPolicy(max_attempts=1), ) with pytest.raises(APIError) as exc: - await pollux.run_many( - ("Q",), - sources=(Source.from_text("cache me"),), + await pollux.create_cache( + (Source.from_text("cache me"),), config=cfg, ) err = exc.value assert err.provider == "gemini" assert err.phase == "cache" - assert err.call_idx is None # ============================================================================= @@ -285,6 +282,53 @@ async def aclose(self) -> None: assert fake.closed == 1, name +@pytest.mark.asyncio +async def test_create_cache_closes_provider(monkeypatch: pytest.MonkeyPatch) -> None: + """create_cache should close provider resources on success and failure.""" + + @dataclass + class _Provider(FakeProvider): + closed: int = 0 + fail_cache: bool = False + + async def create_cache( + self, + *, + model: str, + parts: list[Any], + system_instruction: str | None = None, + ttl_seconds: int = 3600, + ) -> str: + if self.fail_cache: + raise APIError("cache failed", provider="gemini", phase="cache") + return await super().create_cache( + model=model, + parts=parts, + system_instruction=system_instruction, + ttl_seconds=ttl_seconds, + ) + + async def aclose(self) -> None: + self.closed += 1 + + cfg = 
Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + for fail_cache in (False, True): + fake = _Provider(fail_cache=fail_cache) + monkeypatch.setattr(pollux, "_get_provider", lambda _config, _fake=fake: _fake) + monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + + if fail_cache: + with pytest.raises(APIError, match="cache failed"): + await pollux.create_cache((Source.from_text("cache me"),), config=cfg) + else: + handle = await pollux.create_cache( + (Source.from_text("cache me"),), config=cfg + ) + assert isinstance(handle, CacheHandle) + + assert fake.closed == 1 + + # ============================================================================= # Retry Behavior (Boundary) # ============================================================================= @@ -401,7 +445,7 @@ async def test_source_from_json_is_sent_as_inline_content() -> None: async def test_cache_single_flight_propagates_failure_and_clears_inflight( monkeypatch: pytest.MonkeyPatch, ) -> None: - """If cache creation fails, all waiters should see the error and future calls can recover.""" + """If cache creation fails, concurrent callers see the error; future calls recover.""" fake = GateProvider(kind="cache") monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) monkeypatch.setattr(pollux, "_registry", CacheRegistry()) @@ -410,17 +454,16 @@ async def test_cache_single_flight_propagates_failure_and_clears_inflight( provider="gemini", model=CACHE_MODEL, use_mock=True, - enable_caching=True, retry=RetryPolicy(max_attempts=1), ) source = Source.from_text("cache me", identifier="same-id") t1 = asyncio.create_task( - pollux.run_many(("A",), sources=(source,), config=cfg), + pollux.create_cache((source,), config=cfg), ) await fake.started.wait() t2 = asyncio.create_task( - pollux.run_many(("B",), sources=(source,), config=cfg), + pollux.create_cache((source,), config=cfg), ) fake.release.set() @@ -430,13 +473,13 @@ async def 
test_cache_single_flight_propagates_failure_and_clears_inflight( assert fake.cache_calls == 1 # After the failure, the registry should not be stuck; it should be able to create a cache. - result = await pollux.run_many(("C",), sources=(source,), config=cfg) - assert result["status"] == "ok" + handle = await pollux.create_cache((source,), config=cfg) + assert isinstance(handle, CacheHandle) assert fake.cache_calls == 2 # And after a successful cache, additional calls should not recreate it. - result2 = await pollux.run_many(("D",), sources=(source,), config=cfg) - assert result2["status"] == "ok" + handle2 = await pollux.create_cache((source,), config=cfg) + assert isinstance(handle2, CacheHandle) assert fake.cache_calls == 2 @@ -735,7 +778,8 @@ async def delete_file(self, file_id: str) -> None: # noqa: ARG002 async def test_cached_context_is_not_resent_on_each_call( monkeypatch: pytest.MonkeyPatch, ) -> None: - """When cache is active, call payloads should include only prompt-specific parts.""" + """When cache is active via Options(cache=handle), payloads include only prompts.""" + import time @dataclass class PartsCaptureProvider(FakeProvider): @@ -755,25 +799,179 @@ async def generate(self, request: ProviderRequest) -> ProviderResponse: fake = PartsCaptureProvider() monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) - monkeypatch.setattr(pollux, "_registry", CacheRegistry()) - cfg = Config( provider="gemini", model=CACHE_MODEL, use_mock=True, - enable_caching=True, ) + + handle = CacheHandle( + name="cachedContents/test", + model=CACHE_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + await pollux.run_many( prompts=("A", "B"), sources=(Source.from_text("shared context"),), config=cfg, + options=Options(cache=handle), ) - assert fake.cache_calls == 1 assert fake.cache_names == ["cachedContents/test", "cachedContents/test"] assert fake.received_parts == [["A"], ["B"]] +@pytest.mark.asyncio +async def 
test_options_cache_requires_persistent_cache_capability( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Passing Options(cache=...) should fail on providers without persistent caching.""" + import time + + fake = FakeProvider( + _capabilities=ProviderCapabilities( + persistent_cache=False, + uploads=True, + structured_outputs=False, + reasoning=False, + deferred_delivery=False, + conversation=False, + ) + ) + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="openai", model=OPENAI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=OPENAI_MODEL, + provider="openai", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="persistent caching"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle), + ) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_rejects_provider_and_model_mismatch( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Cache handles must match the active provider and model.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + bad_provider = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="openai", + expires_at=time.time() + 3600, + ) + bad_model = CacheHandle( + name="cachedContents/test", + model=OPENAI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="provider does not match"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=bad_provider), + ) + with pytest.raises(ConfigurationError, match="model does not match"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + 
config=cfg, + options=Options(cache=bad_model), + ) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_rejects_system_instruction( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """system_instruction cannot coexist with a cache handle.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="system_instruction cannot be used"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle, system_instruction="Be concise."), + ) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_rejects_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """tools cannot coexist with a cache handle.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + tools = [ + { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + } + ] + + with pytest.raises(ConfigurationError, match="tools cannot be used"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle, tools=tools), + ) + + assert fake.last_parts is None + + def test_cache_identity_uses_content_digest_not_identifier_only() -> None: """Regression: cache identity keys must not collide across distinct sources.""" model = GEMINI_MODEL @@ 
-818,6 +1016,68 @@ def test_cache_identity_includes_system_instruction() -> None: assert concise != verbose +@pytest.mark.asyncio +async def test_create_cache_returns_handle(monkeypatch: pytest.MonkeyPatch) -> None: + """create_cache() should return a CacheHandle with the expected fields.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + handle = await pollux.create_cache( + (Source.from_text("hello"),), + config=cfg, + ttl_seconds=600, + ) + + assert isinstance(handle, CacheHandle) + assert handle.name == "cachedContents/test" + assert handle.model == CACHE_MODEL + assert handle.provider == "gemini" + assert handle.expires_at > 0 + assert fake.cache_calls == 1 + + +@pytest.mark.asyncio +async def test_create_cache_rejects_unsupported_provider( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """create_cache() should raise ConfigurationError for providers without persistent_cache.""" + fake = FakeProvider( + _capabilities=ProviderCapabilities( + persistent_cache=False, + uploads=True, + structured_outputs=False, + reasoning=False, + ) + ) + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + with pytest.raises(ConfigurationError, match="persistent caching"): + await pollux.create_cache( + (Source.from_text("hello"),), + config=cfg, + ) + + +def test_create_cache_validates_ttl() -> None: + """create_cache() should reject invalid ttl_seconds synchronously (via coroutine).""" + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + + async def _check() -> None: + with pytest.raises(ConfigurationError, match="ttl_seconds"): + await pollux.create_cache( + (Source.from_text("hello"),), config=cfg, ttl_seconds=0 + ) + with pytest.raises(ConfigurationError, match="ttl_seconds"): + await 
pollux.create_cache( + (Source.from_text("hello"),), config=cfg, ttl_seconds=-1 + ) + + asyncio.get_event_loop().run_until_complete(_check()) + + @pytest.mark.asyncio async def test_options_response_schema_requires_provider_capability() -> None: """Strict capability checks reject unsupported structured outputs.""" @@ -859,7 +1119,7 @@ class ExampleSchema(BaseModel): fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, @@ -898,7 +1158,7 @@ async def test_delivery_mode_deferred_is_explicitly_not_implemented( """Deferred delivery should fail clearly until backend support lands.""" fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, @@ -931,7 +1191,7 @@ class Paper(BaseModel): class _StructuredProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -986,7 +1246,7 @@ async def test_conversation_options_are_forwarded_when_provider_supports_them( """Conversation options should pass through when provider supports the feature.""" fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, @@ -1015,7 +1275,7 @@ async def test_conversation_requires_single_prompt_per_call( """Conversation continuity is single-turn per API call in v1.1.""" fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -1041,7 +1301,7 @@ async def test_continue_from_requires_conversation_state( """continue_from must include _conversation_state; valid state is forwarded.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + 
persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -1090,7 +1350,7 @@ async def test_conversation_result_includes_conversation_state( class _ConversationProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1137,7 +1397,7 @@ async def test_conversation_state_preserves_provider_state_from_response( class _ProviderStateConversationProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=True, @@ -1288,7 +1548,7 @@ class Paper(BaseModel): fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -1350,7 +1610,7 @@ async def test_tool_call_response_populates_conversation_state_without_history( class _ToolCallProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1410,7 +1670,7 @@ async def test_tool_calls_preserved_in_conversation_state( class _ToolCallProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1452,7 +1712,7 @@ async def test_continue_from_preserves_tool_history_items( """continue_from with tool messages in history passes them to provider.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1501,7 +1761,7 @@ async def test_continue_from_forwards_provider_state_with_history_items( 
"""History item provider_state should be forwarded in request.provider_state.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=True, @@ -1561,7 +1821,7 @@ async def test_continue_tool_mechanics(monkeypatch: pytest.MonkeyPatch) -> None: """continue_tool should neatly append tool results and allow None prompt.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1621,7 +1881,7 @@ async def test_reasoning_surfaced_in_result_envelope( """Provider reasoning text should appear in ResultEnvelope.reasoning.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, reasoning=True, ), @@ -1653,7 +1913,7 @@ async def test_reasoning_omitted_when_absent( """ResultEnvelope should not include reasoning key when provider omits it.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, reasoning=True, ), @@ -1676,7 +1936,7 @@ async def test_reasoning_mixed_across_multi_prompt( """Multi-prompt: reasoning=None for calls without thinking content.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, reasoning=True, ), @@ -1705,7 +1965,7 @@ async def test_reasoning_tokens_aggregate_in_result_usage( """Pipeline should preserve and sum reasoning_tokens across calls.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, reasoning=True, ), From f1004a5bc3000ac7d661ee5c35b7f18b9141103a Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 02:34:04 -0800 Subject: [PATCH 02/14] test: Convert test_create_cache_validates_ttl to async --- tests/test_pipeline.py | 22 ++++++++++------------ 1 file changed, 10 
insertions(+), 12 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ba195da..c50c4b0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1061,21 +1061,19 @@ async def test_create_cache_rejects_unsupported_provider( ) -def test_create_cache_validates_ttl() -> None: +@pytest.mark.asyncio +async def test_create_cache_validates_ttl() -> None: """create_cache() should reject invalid ttl_seconds synchronously (via coroutine).""" cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) - async def _check() -> None: - with pytest.raises(ConfigurationError, match="ttl_seconds"): - await pollux.create_cache( - (Source.from_text("hello"),), config=cfg, ttl_seconds=0 - ) - with pytest.raises(ConfigurationError, match="ttl_seconds"): - await pollux.create_cache( - (Source.from_text("hello"),), config=cfg, ttl_seconds=-1 - ) - - asyncio.get_event_loop().run_until_complete(_check()) + with pytest.raises(ConfigurationError, match="ttl_seconds"): + await pollux.create_cache( + (Source.from_text("hello"),), config=cfg, ttl_seconds=0 + ) + with pytest.raises(ConfigurationError, match="ttl_seconds"): + await pollux.create_cache( + (Source.from_text("hello"),), config=cfg, ttl_seconds=-1 + ) @pytest.mark.asyncio From bc62c025957149c3f8f6e3814d90f1d4495f90b5 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 02:51:57 -0800 Subject: [PATCH 03/14] fix(cache): plumb expires_at and tools through create_cache --- src/pollux/__init__.py | 14 ++++++++---- src/pollux/cache.py | 38 ++++++++++++++++++++----------- src/pollux/providers/anthropic.py | 3 ++- src/pollux/providers/base.py | 1 + src/pollux/providers/gemini.py | 17 ++++++++++---- src/pollux/providers/mock.py | 1 + src/pollux/providers/openai.py | 3 ++- tests/conftest.py | 3 ++- tests/test_pipeline.py | 2 ++ 9 files changed, 56 insertions(+), 26 deletions(-) diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index d601ad5..dc897ca 100644 --- 
a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -12,7 +12,6 @@ import asyncio import logging -import time from typing import TYPE_CHECKING, Any from pollux.cache import CacheHandle, CacheRegistry @@ -168,6 +167,7 @@ async def create_cache( *, config: Config, system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> CacheHandle: """Create a persistent context cache for use with ``run()`` / ``run_many()``. @@ -176,6 +176,7 @@ async def create_cache( sources: Content to cache (files, text, URLs). config: Configuration specifying provider and model. system_instruction: Optional system-level instruction baked into the cache. + tools: Optional tools to bake into the cache. ttl_seconds: Time-to-live in seconds (must be ≥ 1). Returns: @@ -242,31 +243,34 @@ async def create_cache( from pollux.cache import compute_cache_key key = compute_cache_key( - config.model, src_tuple, system_instruction=system_instruction + config.model, src_tuple, system_instruction=system_instruction, tools=tools ) - cache_name = await get_or_create_cache( + result = await get_or_create_cache( provider, _registry, key=key, model=config.model, parts=parts, system_instruction=system_instruction, + tools=tools, ttl_seconds=ttl_seconds, retry_policy=retry_policy, ) - if cache_name is None: + if result is None: raise InternalError( "Cache creation returned None unexpectedly", hint="This is a Pollux internal error. 
Please report it.", ) + cache_name, expires_at = result + return CacheHandle( name=cache_name, model=config.model, provider=config.provider, - expires_at=time.time() + ttl_seconds, + expires_at=expires_at, ) finally: await _close_provider(provider) diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 00b79c3..7640001 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -38,31 +38,33 @@ class CacheRegistry: """Registry tracking cache entries with expiration.""" _entries: dict[str, tuple[str, float]] = field(default_factory=dict) - _inflight: dict[str, asyncio.Future[str]] = field(default_factory=dict) + _inflight: dict[str, asyncio.Future[tuple[str, float]]] = field( + default_factory=dict + ) _lock: asyncio.Lock = field(default_factory=asyncio.Lock) - def get(self, key: str) -> str | None: - """Get cache name if exists and not expired.""" + def get(self, key: str) -> tuple[str, float] | None: + """Get cache entry if exists and not expired.""" entry = self._entries.get(key) if entry is None: return None - name, expires_at = entry + _, expires_at = entry if time.time() >= expires_at: del self._entries[key] logger.debug("Cache expired key=%s…", key[:8]) return None - return name + return entry - def set(self, key: str, name: str, ttl_seconds: int) -> None: + def set(self, key: str, value: tuple[str, float]) -> None: """Store cache entry with expiration time.""" - expires_at = time.time() + max(0, ttl_seconds) - self._entries[key] = (name, expires_at) + self._entries[key] = value def compute_cache_key( model: str, sources: tuple[Source, ...], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ) -> str: """Compute deterministic cache key using content hashes. 
@@ -72,6 +74,11 @@ def compute_cache_key( parts = [model] if system_instruction: parts.append(system_instruction) + if tools: + import json + + # sort_keys to ensure deterministic JSON representation + parts.append(json.dumps(tools, sort_keys=True)) for source in sources: # Use content hash, not identifier @@ -89,9 +96,10 @@ async def get_or_create_cache( model: str, parts: list[Any], system_instruction: str | None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int, retry_policy: RetryPolicy | None = None, -) -> str | None: +) -> tuple[str, float] | None: """Get existing cache or create new one with single-flight protection. Single-flight: concurrent requests for the same key share one creation call. @@ -99,32 +107,36 @@ async def get_or_create_cache( if not provider.capabilities.persistent_cache: return None - async def _work() -> str: + async def _work() -> tuple[str, float]: logger.debug("Creating cache key=%s…", key[:8]) if retry_policy is None or retry_policy.max_attempts <= 1: - return await provider.create_cache( + name = await provider.create_cache( model=model, parts=parts, system_instruction=system_instruction, + tools=tools, ttl_seconds=ttl_seconds, ) + return name, time.time() + max(0, ttl_seconds) - return await retry_async( + name = await retry_async( lambda: provider.create_cache( model=model, parts=parts, system_instruction=system_instruction, + tools=tools, ttl_seconds=ttl_seconds, ), policy=retry_policy, should_retry=should_retry_side_effect, ) + return name, time.time() + max(0, ttl_seconds) return await singleflight_cached( key, lock=registry._lock, inflight=registry._inflight, cache_get=registry.get, - cache_set=lambda k, v: registry.set(k, v, ttl_seconds), + cache_set=registry.set, work=_work, ) diff --git a/src/pollux/providers/anthropic.py b/src/pollux/providers/anthropic.py index ce835a9..2c90e71 100644 --- a/src/pollux/providers/anthropic.py +++ b/src/pollux/providers/anthropic.py @@ -341,10 +341,11 @@ async def 
create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: """Raise because Anthropic caching is deferred.""" - _ = model, parts, system_instruction, ttl_seconds + _ = model, parts, system_instruction, tools, ttl_seconds raise APIError("Anthropic provider does not support context caching") async def aclose(self) -> None: diff --git a/src/pollux/providers/base.py b/src/pollux/providers/base.py index e832327..ee9bebe 100644 --- a/src/pollux/providers/base.py +++ b/src/pollux/providers/base.py @@ -48,6 +48,7 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: """Create a cache and return its name.""" diff --git a/src/pollux/providers/gemini.py b/src/pollux/providers/gemini.py index e009291..f671d8e 100644 --- a/src/pollux/providers/gemini.py +++ b/src/pollux/providers/gemini.py @@ -415,20 +415,27 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: """Create a cached content entry.""" client = self._get_client() from google.genai import types + config_kwargs: dict[str, Any] = { + "contents": self._convert_parts(parts), + "system_instruction": system_instruction, + "ttl": f"{ttl_seconds}s", + } + if tools is not None: + tool_objs = self._normalize_tools(tools) + if tool_objs: + config_kwargs["tools"] = tool_objs + try: result = await client.aio.caches.create( model=model, - config=types.CreateCachedContentConfig( - contents=self._convert_parts(parts), - system_instruction=system_instruction, - ttl=f"{ttl_seconds}s", - ), + config=types.CreateCachedContentConfig(**config_kwargs), ) return str(result.name) except asyncio.CancelledError: diff --git a/src/pollux/providers/mock.py 
b/src/pollux/providers/mock.py index 023a1c9..c7dc2c5 100644 --- a/src/pollux/providers/mock.py +++ b/src/pollux/providers/mock.py @@ -60,6 +60,7 @@ async def create_cache( model: str, parts: list[Any], # noqa: ARG002 system_instruction: str | None = None, # noqa: ARG002 + tools: list[dict[str, Any]] | list[Any] | None = None, # noqa: ARG002 ttl_seconds: int = 3600, # noqa: ARG002 ) -> str: """Return a mock cache name.""" diff --git a/src/pollux/providers/openai.py b/src/pollux/providers/openai.py index ea9fd45..a9efd01 100644 --- a/src/pollux/providers/openai.py +++ b/src/pollux/providers/openai.py @@ -326,10 +326,11 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: """Raise because OpenAI context caching is not supported.""" - _ = model, parts, system_instruction, ttl_seconds + _ = model, parts, system_instruction, tools, ttl_seconds raise APIError("OpenAI provider does not support context caching") async def delete_file(self, file_id: str) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 53a06b9..74d68c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,9 +78,10 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: - del model, parts, system_instruction, ttl_seconds + del model, parts, system_instruction, tools, ttl_seconds self.cache_calls += 1 return "cachedContents/test" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index c50c4b0..5e26daf 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -297,6 +297,7 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: if self.fail_cache: @@ -305,6 +306,7 @@ async 
def create_cache( model=model, parts=parts, system_instruction=system_instruction, + tools=tools, ttl_seconds=ttl_seconds, ) From 9b01360f6a8fc34e62e980ed824e22ce225b2104 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 03:01:34 -0800 Subject: [PATCH 04/14] fix(cache): raise ConfigurationError on unserializable tools in create_cache --- src/pollux/cache.py | 11 +++++++++-- tests/test_pipeline.py | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 7640001..82a0f9a 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any from pollux._singleflight import singleflight_cached +from pollux.errors import ConfigurationError from pollux.retry import RetryPolicy, retry_async, should_retry_side_effect if TYPE_CHECKING: @@ -77,8 +78,14 @@ def compute_cache_key( if tools: import json - # sort_keys to ensure deterministic JSON representation - parts.append(json.dumps(tools, sort_keys=True)) + try: + # sort_keys to ensure deterministic JSON representation + parts.append(json.dumps(tools, sort_keys=True)) + except TypeError as e: + raise ConfigurationError( + "Tools provided to create_cache() must be JSON serializable.", + hint="If using custom objects or Pydantic models for tools, convert them to dicts first.", + ) from e for source in sources: # Use content hash, not identifier diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5e26daf..9da02ac 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1040,6 +1040,30 @@ async def test_create_cache_returns_handle(monkeypatch: pytest.MonkeyPatch) -> N assert fake.cache_calls == 1 +@pytest.mark.asyncio +async def test_create_cache_rejects_unserializable_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """create_cache() should raise ConfigurationError with a hint for non-JSON tools.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, 
"_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + + class CustomTool: + pass + + with pytest.raises(ConfigurationError, match="JSON serializable") as exc: + await pollux.create_cache( + (Source.from_text("hello"),), + config=cfg, + tools=[CustomTool()], + ) + + assert "convert them to dicts" in str(exc.value.hint) + + @pytest.mark.asyncio async def test_create_cache_rejects_unsupported_provider( monkeypatch: pytest.MonkeyPatch, From e187537c15714db1e141df1a836afd3063d867f7 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 03:15:15 -0800 Subject: [PATCH 05/14] fix: validate cache sources and tool formats Addresses P1 and P2 findings from code review: - Raises ConfigurationError if options.cache is used alongside sources. - Validates tools as dictionaries before creation in Gemini provider. --- .../optimization/cache-warming-and-ttl.py | 8 +--- src/pollux/execute.py | 8 ++++ src/pollux/providers/gemini.py | 16 +++++--- tests/test_pipeline.py | 37 ++++++------------- 4 files changed, 32 insertions(+), 37 deletions(-) diff --git a/cookbook/optimization/cache-warming-and-ttl.py b/cookbook/optimization/cache-warming-and-ttl.py index eac8957..9a66227 100644 --- a/cookbook/optimization/cache-warming-and-ttl.py +++ b/cookbook/optimization/cache-warming-and-ttl.py @@ -60,12 +60,8 @@ async def main_async(directory: Path, *, limit: int, config: Config, ttl: int) - handle = await create_cache(sources, config=config, ttl_seconds=ttl) - warm = await run_many( - PROMPTS, sources=sources, config=config, options=Options(cache=handle) - ) - reuse = await run_many( - PROMPTS, sources=sources, config=config, options=Options(cache=handle) - ) + warm = await run_many(PROMPTS, config=config, options=Options(cache=handle)) + reuse = await run_many(PROMPTS, config=config, options=Options(cache=handle)) warm_tokens = usage_tokens(warm) reuse_tokens = 
usage_tokens(reuse) saved = None diff --git a/src/pollux/execute.py b/src/pollux/execute.py index 4dc6d82..a4393bc 100644 --- a/src/pollux/execute.py +++ b/src/pollux/execute.py @@ -159,6 +159,14 @@ async def execute_plan(plan: Plan, provider: Provider) -> ExecutionTrace: "or remove the cache handle." ), ) + if plan.shared_parts: + raise ConfigurationError( + "sources cannot be used with a cache handle", + hint=( + "Bake sources into create_cache() instead, " + "or remove the cache handle." + ), + ) if (not provider.capabilities.uploads) and any( isinstance(p, dict) diff --git a/src/pollux/providers/gemini.py b/src/pollux/providers/gemini.py index f671d8e..13a62ff 100644 --- a/src/pollux/providers/gemini.py +++ b/src/pollux/providers/gemini.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any import uuid -from pollux.errors import APIError +from pollux.errors import APIError, ConfigurationError from pollux.providers._errors import wrap_provider_error from pollux.providers.base import ProviderCapabilities from pollux.providers.models import ( @@ -104,6 +104,11 @@ def _normalize_tools(self, tools: list[dict[str, Any]]) -> list[Any]: tool_objs: list[Any] = [] for t in tools: + if not isinstance(t, dict): + raise ConfigurationError( + f"Tool must be a dictionary, got {type(t).__name__}", + hint="Ensure all items in the tools list are dictionaries.", + ) if "name" in t: raw_params = t.get("parameters") params = ( @@ -427,12 +432,13 @@ async def create_cache( "system_instruction": system_instruction, "ttl": f"{ttl_seconds}s", } - if tools is not None: - tool_objs = self._normalize_tools(tools) - if tool_objs: - config_kwargs["tools"] = tool_objs try: + if tools is not None: + tool_objs = self._normalize_tools(tools) + if tool_objs: + config_kwargs["tools"] = tool_objs + result = await client.aio.caches.create( model=model, config=types.CreateCachedContentConfig(**config_kwargs), diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 9da02ac..70203e8 
100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -777,28 +777,13 @@ async def delete_file(self, file_id: str) -> None: # noqa: ARG002 @pytest.mark.asyncio -async def test_cached_context_is_not_resent_on_each_call( +async def test_cached_context_rejects_sources( monkeypatch: pytest.MonkeyPatch, ) -> None: - """When cache is active via Options(cache=handle), payloads include only prompts.""" + """When cache is active via Options(cache=handle), passing sources raises ConfigurationError.""" import time - @dataclass - class PartsCaptureProvider(FakeProvider): - received_parts: list[list[Any]] = field(default_factory=list) - cache_names: list[str | None] = field(default_factory=list) - - async def generate(self, request: ProviderRequest) -> ProviderResponse: - self.received_parts.append(request.parts) - self.cache_names.append(request.cache_name) - prompt = ( - request.parts[-1] - if request.parts and isinstance(request.parts[-1], str) - else "" - ) - return ProviderResponse(text=f"ok:{prompt}", usage={"total_tokens": 1}) - - fake = PartsCaptureProvider() + fake = FakeProvider() monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) cfg = Config( @@ -814,15 +799,15 @@ async def generate(self, request: ProviderRequest) -> ProviderResponse: expires_at=time.time() + 3600, ) - await pollux.run_many( - prompts=("A", "B"), - sources=(Source.from_text("shared context"),), - config=cfg, - options=Options(cache=handle), - ) + with pytest.raises(ConfigurationError, match="sources cannot be used"): + await pollux.run_many( + prompts=("A", "B"), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle), + ) - assert fake.cache_names == ["cachedContents/test", "cachedContents/test"] - assert fake.received_parts == [["A"], ["B"]] + assert fake.last_parts is None @pytest.mark.asyncio From d5eb8dd23b61308da6f62cc143eab61b2b5074c7 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 03:18:26 -0800 Subject: 
[PATCH 06/14] docs: remove sources from caching example Updates caching.md to not pass sources alongside a cache handle, which is now explicitly rejected by ConfigurationError. --- docs/caching.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/caching.md b/docs/caching.md index 5dfb8c9..8936e26 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -95,8 +95,8 @@ async def main() -> None: handle = await create_cache(sources, config=config, ttl_seconds=3600) prompts = ["Summarize in one sentence.", "List 3 keywords."] - first = await run_many(prompts=prompts, sources=sources, config=config, options=Options(cache=handle)) - second = await run_many(prompts=prompts, sources=sources, config=config, options=Options(cache=handle)) + first = await run_many(prompts=prompts, config=config, options=Options(cache=handle)) + second = await run_many(prompts=prompts, config=config, options=Options(cache=handle)) print("first:", first["status"]) print("second:", second["status"]) @@ -133,6 +133,7 @@ the same handle reuse the cached context automatically. instead. - `tools` / `tool_choice` — bake them into `create_cache()` instead (when supported). + - `sources` — bake them into `create_cache()` instead. Pollux raises `ConfigurationError` immediately if it detects these conflicts. 
This mirrors a hard constraint in the Gemini API, where From 19b4bb3cd6f8cda0c9b504b03a4aa63a98a7a4f1 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 03:43:26 -0800 Subject: [PATCH 07/14] fix(cache): avoid eager uploads and reject tool_choice with cache handles --- docs/configuration.md | 7 +++--- src/pollux/__init__.py | 22 ++++++++++++------ src/pollux/execute.py | 8 +++++++ tests/test_pipeline.py | 53 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 10 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index e22bea3..2a18152 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -157,9 +157,10 @@ options = Options( for the full constraints mapping. !!! warning "Cache handle restrictions" - When `cache` is set, `system_instruction` and `tools` **must not** be - passed in the same `Options`. Bake them into `create_cache()` instead. - See [Reducing Costs with Context Caching](caching.md) for details. + When `cache` is set, `system_instruction`, `tools`, and `tool_choice` + **must not** be passed in the same `Options`. Bake them into + `create_cache()` instead. See + [Reducing Costs with Context Caching](caching.md) for details. 
## Safety Notes diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index dc897ca..277a755 100644 --- a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -194,7 +194,7 @@ async def create_cache( ) result = await run("Summarize.", config=config, options=Options(cache=handle)) """ - from pollux.cache import get_or_create_cache + from pollux.cache import compute_cache_key, get_or_create_cache from pollux.execute import _substitute_upload_parts from pollux.plan import build_shared_parts @@ -222,6 +222,20 @@ async def create_cache( hint="Use Source.from_file(), Source.from_text(), etc.", ) + key = compute_cache_key( + config.model, src_tuple, system_instruction=system_instruction, tools=tools + ) + + cached = _registry.get(key) + if cached is not None: + cache_name, expires_at = cached + return CacheHandle( + name=cache_name, + model=config.model, + provider=config.provider, + expires_at=expires_at, + ) + parts = build_shared_parts(src_tuple) # Resolve file uploads. @@ -240,12 +254,6 @@ async def create_cache( retry_policy=retry_policy, ) - from pollux.cache import compute_cache_key - - key = compute_cache_key( - config.model, src_tuple, system_instruction=system_instruction, tools=tools - ) - result = await get_or_create_cache( provider, _registry, diff --git a/src/pollux/execute.py b/src/pollux/execute.py index a4393bc..7a9a5a9 100644 --- a/src/pollux/execute.py +++ b/src/pollux/execute.py @@ -159,6 +159,14 @@ async def execute_plan(plan: Plan, provider: Provider) -> ExecutionTrace: "or remove the cache handle." ), ) + if options.tool_choice is not None: + raise ConfigurationError( + "tool_choice cannot be used with a cache handle", + hint=( + "Bake tools/tool_choice into create_cache() instead, " + "or remove the cache handle." 
+ ), + ) if plan.shared_parts: raise ConfigurationError( "sources cannot be used with a cache handle", diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 70203e8..0adc9f9 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -959,6 +959,35 @@ async def test_options_cache_rejects_tools( assert fake.last_parts is None +@pytest.mark.asyncio +async def test_options_cache_rejects_tool_choice( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """tool_choice cannot coexist with a cache handle.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="tool_choice cannot be used"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle, tool_choice="required"), + ) + + assert fake.last_parts is None + + def test_cache_identity_uses_content_digest_not_identifier_only() -> None: """Regression: cache identity keys must not collide across distinct sources.""" model = GEMINI_MODEL @@ -1025,6 +1054,28 @@ async def test_create_cache_returns_handle(monkeypatch: pytest.MonkeyPatch) -> N assert fake.cache_calls == 1 +@pytest.mark.asyncio +async def test_create_cache_cache_hit_skips_uploads( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """Repeated create_cache() calls for the same key should not re-upload files.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + + file_path = tmp_path / "cache-me.txt" + file_path.write_text("hello cache", encoding="utf-8") + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + first = await 
pollux.create_cache((Source.from_file(file_path),), config=cfg) + second = await pollux.create_cache((Source.from_file(file_path),), config=cfg) + + assert first.name == second.name + assert fake.cache_calls == 1 + assert fake.upload_calls == 1 + + @pytest.mark.asyncio async def test_create_cache_rejects_unserializable_tools( monkeypatch: pytest.MonkeyPatch, @@ -1047,6 +1098,8 @@ class CustomTool: ) assert "convert them to dicts" in str(exc.value.hint) + assert fake.upload_calls == 0 + assert fake.cache_calls == 0 @pytest.mark.asyncio From cdd2af67139ac86f03a31a8f2af93321b9ed04a7 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 18:54:38 -0800 Subject: [PATCH 08/14] fix(cache): scope cache identity by provider --- docs/caching.md | 5 +++-- docs/configuration.md | 5 +++-- src/pollux/__init__.py | 6 +++++- src/pollux/cache.py | 5 ++++- src/pollux/execute.py | 2 +- tests/test_pipeline.py | 13 ++++++++++++- 6 files changed, 28 insertions(+), 8 deletions(-) diff --git a/docs/caching.md b/docs/caching.md index 8936e26..3169379 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -131,8 +131,9 @@ the same handle reuse the cached context automatically. - `system_instruction` — bake it into `create_cache(system_instruction=...)` instead. - - `tools` / `tool_choice` — bake them into `create_cache()` instead (when + - `tools` — bake them into `create_cache(tools=...)` instead (when supported). + - `tool_choice` — remove it when using cached content. - `sources` — bake them into `create_cache()` instead. Pollux raises `ConfigurationError` immediately if it detects these @@ -142,7 +143,7 @@ the same handle reuse the cached context automatically. ## Cache Identity -Cache keys are deterministic: `hash(model + content hashes of sources)`. +Cache keys are deterministic: `hash(model + provider + system instruction + tools + content hashes of sources)`.
This means: diff --git a/docs/configuration.md b/docs/configuration.md index 2a18152..c1387f4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -158,8 +158,9 @@ options = Options( !!! warning "Cache handle restrictions" When `cache` is set, `system_instruction`, `tools`, and `tool_choice` - **must not** be passed in the same `Options`. Bake them into - `create_cache()` instead. See + **must not** be passed in the same `Options`. `system_instruction` and + `tools` can be baked into `create_cache()`, while `tool_choice` must be + set only on uncached calls. See [Reducing Costs with Context Caching](caching.md) for details. ## Safety Notes diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index 277a755..b1bb354 100644 --- a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -223,7 +223,11 @@ async def create_cache( ) key = compute_cache_key( - config.model, src_tuple, system_instruction=system_instruction, tools=tools + config.model, + src_tuple, + provider=config.provider, + system_instruction=system_instruction, + tools=tools, ) cached = _registry.get(key) diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 82a0f9a..e64382f 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -64,15 +64,18 @@ def set(self, key: str, value: tuple[str, float]) -> None: def compute_cache_key( model: str, sources: tuple[Source, ...], + provider: str | None = None, system_instruction: str | None = None, tools: list[dict[str, Any]] | list[Any] | None = None, ) -> str: """Compute deterministic cache key using content hashes. - Key = hash(model + system + content digests of sources) + Key = hash(model + provider + system + content digests of sources) This fixes the cache identity collision bug where identifier+size was used. 
""" parts = [model] + if provider: + parts.append(provider) if system_instruction: parts.append(system_instruction) if tools: diff --git a/src/pollux/execute.py b/src/pollux/execute.py index 7a9a5a9..db5436c 100644 --- a/src/pollux/execute.py +++ b/src/pollux/execute.py @@ -163,7 +163,7 @@ async def execute_plan(plan: Plan, provider: Provider) -> ExecutionTrace: raise ConfigurationError( "tool_choice cannot be used with a cache handle", hint=( - "Bake tools/tool_choice into create_cache() instead, " + "Remove tool_choice when using a cache handle, " "or remove the cache handle." ), ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 0adc9f9..2fa071e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -977,7 +977,7 @@ async def test_options_cache_rejects_tool_choice( expires_at=time.time() + 3600, ) - with pytest.raises(ConfigurationError, match="tool_choice cannot be used"): + with pytest.raises(ConfigurationError, match="tool_choice cannot be used") as exc: await pollux.run_many( prompts=("Q",), sources=(Source.from_text("shared context"),), @@ -985,6 +985,8 @@ async def test_options_cache_rejects_tool_choice( options=Options(cache=handle, tool_choice="required"), ) + assert exc.value.hint is not None + assert "Remove tool_choice" in exc.value.hint assert fake.last_parts is None @@ -1032,6 +1034,15 @@ def test_cache_identity_includes_system_instruction() -> None: assert concise != verbose +def test_cache_identity_includes_provider() -> None: + """Identical model/content across providers must not share cache keys.""" + source = Source.from_text("shared context") + gemini = compute_cache_key(GEMINI_MODEL, (source,), provider="gemini") + openai = compute_cache_key(GEMINI_MODEL, (source,), provider="openai") + + assert gemini != openai + + @pytest.mark.asyncio async def test_create_cache_returns_handle(monkeypatch: pytest.MonkeyPatch) -> None: """create_cache() should return a CacheHandle with the expected fields.""" From 
3d494830297974fecb110dcf444ff332b79ab2e3 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 19:42:49 -0800 Subject: [PATCH 09/14] refactor(cache): consolidate caching internals into cache.py Move create_cache implementation from __init__.py into cache.py (create_cache_impl, _resolve_file_parts, module-level _registry), leaving __init__.create_cache as a thin provider-lifecycle wrapper. Shift cache-handle conflict validation (provider/model mismatch, system_instruction/tools/tool_choice/sources conflicts) from execute_plan() into build_plan() so errors surface at planning time before any network I/O. Retain a single runtime persistent_cache capability check in execute_plan() as a safety net for hand-built handles. The new _resolve_file_parts memoizes uploads by (file_path, mime_type) to avoid duplicate uploads for repeated file sources within a single create_cache() call. --- src/pollux/__init__.py | 91 ++------------------------- src/pollux/cache.py | 138 ++++++++++++++++++++++++++++++++++++++++- src/pollux/execute.py | 76 ++++------------------- src/pollux/plan.py | 56 ++++++++++++++++- tests/test_pipeline.py | 35 ++++++++--- 5 files changed, 235 insertions(+), 161 deletions(-) diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index b1bb354..9fb52f8 100644 --- a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Any -from pollux.cache import CacheHandle, CacheRegistry +from pollux.cache import CacheHandle from pollux.config import Config from pollux.errors import ( APIError, @@ -49,9 +49,6 @@ logger = logging.getLogger(__name__) -# Module-level cache registry for reuse across calls -_registry = CacheRegistry() - async def run( prompt: str | None = None, @@ -194,95 +191,17 @@ async def create_cache( ) result = await run("Summarize.", config=config, options=Options(cache=handle)) """ - from pollux.cache import compute_cache_key, get_or_create_cache - from pollux.execute import 
_substitute_upload_parts - from pollux.plan import build_shared_parts - - if not isinstance(ttl_seconds, int) or ttl_seconds < 1: - raise ConfigurationError( - f"ttl_seconds must be an integer ≥ 1, got {ttl_seconds!r}", - hint="Pass a positive integer for the cache TTL.", - ) + from pollux.cache import create_cache_impl provider = _get_provider(config) try: - if not provider.capabilities.persistent_cache: - raise ConfigurationError( - f"Provider {config.provider!r} does not support persistent caching", - hint="Use a provider that supports persistent_cache (e.g. Gemini).", - ) - - src_tuple = tuple(sources) if not isinstance(sources, tuple) else sources - - # Validate sources - for s in src_tuple: - if not isinstance(s, Source): - raise ConfigurationError( - f"Expected Source, got {type(s).__name__}", - hint="Use Source.from_file(), Source.from_text(), etc.", - ) - - key = compute_cache_key( - config.model, - src_tuple, - provider=config.provider, - system_instruction=system_instruction, - tools=tools, - ) - - cached = _registry.get(key) - if cached is not None: - cache_name, expires_at = cached - return CacheHandle( - name=cache_name, - model=config.model, - provider=config.provider, - expires_at=expires_at, - ) - - parts = build_shared_parts(src_tuple) - - # Resolve file uploads. 
- upload_cache: dict[tuple[str, str], Any] = {} - upload_inflight: dict[tuple[str, str], asyncio.Future[Any]] = {} - upload_lock = asyncio.Lock() - retry_policy = config.retry - - parts = await _substitute_upload_parts( - parts, + return await create_cache_impl( + sources, provider=provider, - call_idx=None, - upload_cache=upload_cache, - upload_inflight=upload_inflight, - upload_lock=upload_lock, - retry_policy=retry_policy, - ) - - result = await get_or_create_cache( - provider, - _registry, - key=key, - model=config.model, - parts=parts, + config=config, system_instruction=system_instruction, tools=tools, ttl_seconds=ttl_seconds, - retry_policy=retry_policy, - ) - - if result is None: - raise InternalError( - "Cache creation returned None unexpectedly", - hint="This is a Pollux internal error. Please report it.", - ) - - cache_name, expires_at = result - - return CacheHandle( - name=cache_name, - model=config.model, - provider=config.provider, - expires_at=expires_at, ) finally: await _close_provider(provider) diff --git a/src/pollux/cache.py b/src/pollux/cache.py index e64382f..26a1086 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -6,14 +6,16 @@ from dataclasses import dataclass, field import hashlib import logging +from pathlib import Path import time from typing import TYPE_CHECKING, Any from pollux._singleflight import singleflight_cached -from pollux.errors import ConfigurationError +from pollux.errors import ConfigurationError, InternalError from pollux.retry import RetryPolicy, retry_async, should_retry_side_effect if TYPE_CHECKING: + from pollux.config import Config from pollux.providers.base import Provider from pollux.source import Source @@ -150,3 +152,137 @@ async def _work() -> tuple[str, float]: cache_set=registry.set, work=_work, ) + + +# Module-level registry shared across create_cache calls. 
+_registry = CacheRegistry() + + +async def _resolve_file_parts( + parts: list[Any], + provider: Provider, + retry_policy: RetryPolicy, +) -> list[Any]: + """Replace file placeholders with uploaded assets. + + Memoizes by ``(file_path, mime_type)`` so duplicate file sources in a + single ``create_cache()`` call share one upload. No singleflight + needed because ``create_cache`` is sequential, not concurrent fan-out. + """ + resolved: list[Any] = [] + seen: dict[tuple[str, str], Any] = {} + for part in parts: + if ( + isinstance(part, dict) + and isinstance(part.get("file_path"), str) + and isinstance(part.get("mime_type"), str) + ): + fp, mt = part["file_path"], part["mime_type"] + key = (fp, mt) + if key in seen: + resolved.append(seen[key]) + continue + if retry_policy.max_attempts <= 1: + asset = await provider.upload_file(Path(fp), mt) + else: + + async def _upload(_fp: str = fp, _mt: str = mt) -> Any: + return await provider.upload_file(Path(_fp), _mt) + + asset = await retry_async( + _upload, + policy=retry_policy, + should_retry=should_retry_side_effect, + ) + seen[key] = asset + resolved.append(asset) + else: + resolved.append(part) + return resolved + + +async def create_cache_impl( + sources: tuple[Source, ...] | list[Source], + *, + provider: Provider, + config: Config, + system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, + ttl_seconds: int = 3600, +) -> CacheHandle: + """Core implementation of ``create_cache()``. + + Receives an already-initialized provider; the caller manages its lifecycle. 
+ """ + from pollux.plan import build_shared_parts + from pollux.source import Source as SourceCls + + if not isinstance(ttl_seconds, int) or ttl_seconds < 1: + raise ConfigurationError( + f"ttl_seconds must be an integer ≥ 1, got {ttl_seconds!r}", + hint="Pass a positive integer for the cache TTL.", + ) + + if not provider.capabilities.persistent_cache: + raise ConfigurationError( + f"Provider {config.provider!r} does not support persistent caching", + hint="Use a provider that supports persistent_cache (e.g. Gemini).", + ) + + src_tuple = tuple(sources) if not isinstance(sources, tuple) else sources + + for s in src_tuple: + if not isinstance(s, SourceCls): + raise ConfigurationError( + f"Expected Source, got {type(s).__name__}", + hint="Use Source.from_file(), Source.from_text(), etc.", + ) + + key = compute_cache_key( + config.model, + src_tuple, + provider=config.provider, + system_instruction=system_instruction, + tools=tools, + ) + + cached = _registry.get(key) + if cached is not None: + cache_name, expires_at = cached + return CacheHandle( + name=cache_name, + model=config.model, + provider=config.provider, + expires_at=expires_at, + ) + + parts = build_shared_parts(src_tuple) + retry_policy = config.retry + parts = await _resolve_file_parts(parts, provider, retry_policy) + + result = await get_or_create_cache( + provider, + _registry, + key=key, + model=config.model, + parts=parts, + system_instruction=system_instruction, + tools=tools, + ttl_seconds=ttl_seconds, + retry_policy=retry_policy, + ) + + if result is None: + raise InternalError( + "Cache creation returned None unexpectedly", + hint="This is a Pollux internal error. 
Please report it.", + ) + + cache_name, expires_at = result + + return CacheHandle( + name=cache_name, + model=config.model, + provider=config.provider, + expires_at=expires_at, + ) diff --git a/src/pollux/execute.py b/src/pollux/execute.py index db5436c..b3c2842 100644 --- a/src/pollux/execute.py +++ b/src/pollux/execute.py @@ -113,68 +113,17 @@ async def execute_plan(plan: Plan, provider: Provider) -> ExecutionTrace: "Conversation continuity currently supports exactly one prompt per call", hint="Use run() or run_many() with a single prompt when passing history/continue_from.", ) - if options.cache is not None: - cache_handle = options.cache - if not caps.persistent_cache: - raise ConfigurationError( - "Provider does not support persistent caching", - hint=( - "Remove options.cache or choose a provider with " - "persistent_cache support." - ), - ) - if cache_handle.provider != config.provider: - raise ConfigurationError( - "cache handle provider does not match config provider", - hint=( - f"Create the cache with provider={config.provider!r} and " - "reuse it with the same provider." - ), - ) - if cache_handle.model != model: - raise ConfigurationError( - "cache handle model does not match config model", - hint=( - f"Create the cache with model={model!r} and reuse it " - "with the same model." - ), - ) - # Gemini (and potentially other providers) reject requests that pass - # system_instruction or tools alongside cached_content. Catch the - # conflict early so users get a clear Pollux error instead of a - # provider 400. - if options.system_instruction is not None: - raise ConfigurationError( - "system_instruction cannot be used with a cache handle", - hint=( - "Bake the system instruction into create_cache() instead, " - "or remove the cache handle." - ), - ) - if options.tools is not None: - raise ConfigurationError( - "tools cannot be used with a cache handle", - hint=( - "Bake tools into create_cache() instead, " - "or remove the cache handle." 
- ), - ) - if options.tool_choice is not None: - raise ConfigurationError( - "tool_choice cannot be used with a cache handle", - hint=( - "Remove tool_choice when using a cache handle, " - "or remove the cache handle." - ), - ) - if plan.shared_parts: - raise ConfigurationError( - "sources cannot be used with a cache handle", - hint=( - "Bake sources into create_cache() instead, " - "or remove the cache handle." - ), - ) + # Runtime safety net: reject hand-built handles targeting providers + # that lack persistent caching (the planner already validates other + # cache conflicts). + if plan.cache_name is not None and not caps.persistent_cache: + raise ConfigurationError( + "Provider does not support persistent caching", + hint=( + "Remove options.cache or choose a provider with " + "persistent_cache support." + ), + ) if (not provider.capabilities.uploads) and any( isinstance(p, dict) @@ -233,8 +182,7 @@ async def execute_plan(plan: Plan, provider: Provider) -> ExecutionTrace: conversation_state: dict[str, Any] | None = None try: - # Use pre-created cache name from Options.cache (via plan). - cache_name = options.cache.name if options.cache is not None else None + cache_name = plan.cache_name # Execute calls with concurrency control concurrency = config.request_concurrency diff --git a/src/pollux/plan.py b/src/pollux/plan.py index 433a340..da6ab40 100644 --- a/src/pollux/plan.py +++ b/src/pollux/plan.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any +from pollux.errors import ConfigurationError + if TYPE_CHECKING: from pollux.request import Request from pollux.source import Source @@ -28,14 +30,64 @@ def build_plan(request: Request) -> Plan: """Build execution plan from normalized request. Handles both single-prompt and vectorized (multi-prompt) scenarios. + Validates cache handle conflicts eagerly so callers get clear errors + before any network I/O. 
""" sources = request.sources shared_parts = build_shared_parts(sources) - # Resolve cache_name from Options.cache if provided. cache_name: str | None = None if request.options.cache is not None: - cache_name = request.options.cache.name + cache = request.options.cache + if cache.provider != request.config.provider: + raise ConfigurationError( + "cache handle provider does not match config provider", + hint=( + f"Create the cache with provider={request.config.provider!r} and " + "reuse it with the same provider." + ), + ) + if cache.model != request.config.model: + raise ConfigurationError( + "cache handle model does not match config model", + hint=( + f"Create the cache with model={request.config.model!r} and reuse it " + "with the same model." + ), + ) + if request.options.system_instruction is not None: + raise ConfigurationError( + "system_instruction cannot be used with a cache handle", + hint=( + "Bake the system instruction into create_cache() instead, " + "or remove the cache handle." + ), + ) + if request.options.tools is not None: + raise ConfigurationError( + "tools cannot be used with a cache handle", + hint=( + "Bake tools into create_cache() instead, " + "or remove the cache handle." + ), + ) + if request.options.tool_choice is not None: + raise ConfigurationError( + "tool_choice cannot be used with a cache handle", + hint=( + "Remove tool_choice when using a cache handle, " + "or remove the cache handle." + ), + ) + if shared_parts: + raise ConfigurationError( + "sources cannot be used with a cache handle", + hint=( + "Bake sources into create_cache() instead, " + "or remove the cache handle." 
+ ), + ) + cache_name = cache.name return Plan( request=request, diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 2fa071e..80f0c15 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -10,6 +10,7 @@ import pytest import pollux +import pollux.cache from pollux.cache import CacheHandle, CacheRegistry, compute_cache_key from pollux.config import Config from pollux.errors import APIError, ConfigurationError, PlanningError, SourceError @@ -317,7 +318,7 @@ async def aclose(self) -> None: for fail_cache in (False, True): fake = _Provider(fail_cache=fail_cache) monkeypatch.setattr(pollux, "_get_provider", lambda _config, _fake=fake: _fake) - monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) if fail_cache: with pytest.raises(APIError, match="cache failed"): @@ -450,7 +451,7 @@ async def test_cache_single_flight_propagates_failure_and_clears_inflight( """If cache creation fails, concurrent callers see the error; future calls recover.""" fake = GateProvider(kind="cache") monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) - monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) cfg = Config( provider="gemini", @@ -840,13 +841,10 @@ async def test_options_cache_requires_persistent_cache_capability( with pytest.raises(ConfigurationError, match="persistent caching"): await pollux.run_many( prompts=("Q",), - sources=(Source.from_text("shared context"),), config=cfg, options=Options(cache=handle), ) - assert fake.last_parts is None - @pytest.mark.asyncio async def test_options_cache_rejects_provider_and_model_mismatch( @@ -1048,7 +1046,7 @@ async def test_create_cache_returns_handle(monkeypatch: pytest.MonkeyPatch) -> N """create_cache() should return a CacheHandle with the expected fields.""" fake = FakeProvider() monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) - 
monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) handle = await pollux.create_cache( @@ -1073,7 +1071,7 @@ async def test_create_cache_cache_hit_skips_uploads( """Repeated create_cache() calls for the same key should not re-upload files.""" fake = FakeProvider() monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) - monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) file_path = tmp_path / "cache-me.txt" file_path.write_text("hello cache", encoding="utf-8") @@ -1087,6 +1085,27 @@ async def test_create_cache_cache_hit_skips_uploads( assert fake.upload_calls == 1 +@pytest.mark.asyncio +async def test_create_cache_deduplicates_file_uploads_within_call( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """Duplicate file sources in a single create_cache() should upload only once.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + file_path = tmp_path / "dup.txt" + file_path.write_text("same content", encoding="utf-8") + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + src = Source.from_file(file_path) + handle = await pollux.create_cache((src, src), config=cfg) + + assert isinstance(handle, CacheHandle) + assert fake.upload_calls == 1 + + @pytest.mark.asyncio async def test_create_cache_rejects_unserializable_tools( monkeypatch: pytest.MonkeyPatch, @@ -1094,7 +1113,7 @@ async def test_create_cache_rejects_unserializable_tools( """create_cache() should raise ConfigurationError with a hint for non-JSON tools.""" fake = FakeProvider() monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) - monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + monkeypatch.setattr(pollux.cache, 
"_registry", CacheRegistry()) cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) From efd19f79fd2dd91f6efec9393be25d91c7797406 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 19:53:24 -0800 Subject: [PATCH 10/14] fix(cache): deduplicate file uploads across concurrent create_cache calls Move _resolve_file_parts() into the single-flight work function inside get_or_create_cache() so concurrent callers for the same cache key share both uploads and cache creation. Previously uploads ran before the single-flight boundary, causing duplicate uploads when two coroutines raced past the registry miss. get_or_create_cache() now accepts raw_parts (unresolved placeholders) and resolves them inside _work(). Add test_cache_single_flight_deduplicates_file_uploads to verify upload_calls==1 and cache_calls==1 under concurrency. --- src/pollux/cache.py | 19 +++++++++--------- tests/test_pipeline.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 26a1086..6da4e3a 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -106,7 +106,7 @@ async def get_or_create_cache( *, key: str, model: str, - parts: list[Any], + raw_parts: list[Any], system_instruction: str | None, tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int, @@ -114,14 +114,17 @@ async def get_or_create_cache( ) -> tuple[str, float] | None: """Get existing cache or create new one with single-flight protection. - Single-flight: concurrent requests for the same key share one creation call. + File placeholders in *raw_parts* are resolved inside the single-flight + work function so concurrent callers share both uploads and cache creation. 
""" if not provider.capabilities.persistent_cache: return None async def _work() -> tuple[str, float]: logger.debug("Creating cache key=%s…", key[:8]) - if retry_policy is None or retry_policy.max_attempts <= 1: + policy = retry_policy or RetryPolicy(max_attempts=1) + parts = await _resolve_file_parts(raw_parts, provider, policy) + if policy.max_attempts <= 1: name = await provider.create_cache( model=model, parts=parts, @@ -139,7 +142,7 @@ async def _work() -> tuple[str, float]: tools=tools, ttl_seconds=ttl_seconds, ), - policy=retry_policy, + policy=policy, should_retry=should_retry_side_effect, ) return name, time.time() + max(0, ttl_seconds) @@ -256,20 +259,18 @@ async def create_cache_impl( expires_at=expires_at, ) - parts = build_shared_parts(src_tuple) - retry_policy = config.retry - parts = await _resolve_file_parts(parts, provider, retry_policy) + raw_parts = build_shared_parts(src_tuple) result = await get_or_create_cache( provider, _registry, key=key, model=config.model, - parts=parts, + raw_parts=raw_parts, system_instruction=system_instruction, tools=tools, ttl_seconds=ttl_seconds, - retry_policy=retry_policy, + retry_policy=config.retry, ) if result is None: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 80f0c15..a540e29 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -444,6 +444,51 @@ async def test_source_from_json_is_sent_as_inline_content() -> None: # ============================================================================= +@pytest.mark.asyncio +async def test_cache_single_flight_deduplicates_file_uploads( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """Concurrent create_cache() calls should share uploads via single-flight.""" + gate = asyncio.Event() + entered = asyncio.Event() + + @dataclass + class _SlowCacheProvider(FakeProvider): + async def create_cache(self, **kwargs: Any) -> str: # noqa: ARG002 + self.cache_calls += 1 + entered.set() + await gate.wait() + return 
"cachedContents/test" + + fake = _SlowCacheProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + file_path = tmp_path / "shared.txt" + file_path.write_text("shared content", encoding="utf-8") + + cfg = Config( + provider="gemini", + model=CACHE_MODEL, + use_mock=True, + retry=RetryPolicy(max_attempts=1), + ) + source = Source.from_file(file_path) + + t1 = asyncio.create_task(pollux.create_cache((source,), config=cfg)) + await entered.wait() + t2 = asyncio.create_task(pollux.create_cache((source,), config=cfg)) + # Small yield to let t2 join the singleflight waiters. + await asyncio.sleep(0) + gate.set() + + results = await asyncio.gather(t1, t2, return_exceptions=True) + assert all(isinstance(r, CacheHandle) for r in results) + assert fake.upload_calls == 1, "concurrent calls should share uploads" + assert fake.cache_calls == 1, "concurrent calls should share cache creation" + + @pytest.mark.asyncio async def test_cache_single_flight_propagates_failure_and_clears_inflight( monkeypatch: pytest.MonkeyPatch, From 5fa9c290e10e9db00ddce1f2c31599406ea0a21a Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 20:09:25 -0800 Subject: [PATCH 11/14] fix(cache): scope cache identity by api_key and fix docstring example Include api_key in compute_cache_key() so different credentials for the same provider/model produce distinct cache entries. Prevents silent cross-account handle reuse in multi-tenant or multi-key scenarios. Also fix the create_cache() docstring example which referenced an undefined `config` variable (now uses `cfg` consistently). 
--- src/pollux/__init__.py | 5 +++-- src/pollux/cache.py | 9 +++++++-- tests/test_pipeline.py | 13 +++++++++++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index 9fb52f8..e07db9e 100644 --- a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -184,12 +184,13 @@ async def create_cache( or *ttl_seconds* is invalid. Example: + cfg = Config(provider="gemini", model="gemini-2.5-flash") handle = await create_cache( [Source.from_file("book.pdf")], - config=Config(provider="gemini", model="gemini-2.5-flash"), + config=cfg, ttl_seconds=3600, ) - result = await run("Summarize.", config=config, options=Options(cache=handle)) + result = await run("Summarize.", config=cfg, options=Options(cache=handle)) """ from pollux.cache import create_cache_impl diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 6da4e3a..83d0849 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -67,17 +67,21 @@ def compute_cache_key( model: str, sources: tuple[Source, ...], provider: str | None = None, + api_key: str | None = None, system_instruction: str | None = None, tools: list[dict[str, Any]] | list[Any] | None = None, ) -> str: """Compute deterministic cache key using content hashes. - Key = hash(model + provider + system + content digests of sources) - This fixes the cache identity collision bug where identifier+size was used. + Key = hash(model + provider + api_key + system + content digests of sources). + Including ``api_key`` prevents cross-account handle reuse when multiple + keys for the same provider/model coexist in one process. 
""" parts = [model] if provider: parts.append(provider) + if api_key: + parts.append(api_key) if system_instruction: parts.append(system_instruction) if tools: @@ -245,6 +249,7 @@ async def create_cache_impl( config.model, src_tuple, provider=config.provider, + api_key=config.api_key, system_instruction=system_instruction, tools=tools, ) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a540e29..b28ce37 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1086,6 +1086,19 @@ def test_cache_identity_includes_provider() -> None: assert gemini != openai +def test_cache_identity_includes_api_key() -> None: + """Different API keys for the same provider/model must not share cache keys.""" + source = Source.from_text("shared context") + key_a = compute_cache_key( + GEMINI_MODEL, (source,), provider="gemini", api_key="key-aaa" + ) + key_b = compute_cache_key( + GEMINI_MODEL, (source,), provider="gemini", api_key="key-bbb" + ) + + assert key_a != key_b + + @pytest.mark.asyncio async def test_create_cache_returns_handle(monkeypatch: pytest.MonkeyPatch) -> None: """create_cache() should return a CacheHandle with the expected fields.""" From bb741327c5b20d722228e33c0aa8d5d8837896e3 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 20:38:19 -0800 Subject: [PATCH 12/14] fix(cache): validate tools and system_instruction before file uploads - Validate tool items are dicts in create_cache_impl before uploads, preventing wasted file uploads on invalid input - Validate system_instruction type at the API boundary, converting a raw TypeError into a ConfigurationError with hint - Pass through ConfigurationError in wrap_provider_error instead of re-wrapping as CacheError --- src/pollux/cache.py | 14 ++++++++++++++ src/pollux/providers/_errors.py | 6 ++++++ tests/test_pipeline.py | 6 +++--- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 83d0849..d01296b 100644 --- 
a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -230,6 +230,12 @@ async def create_cache_impl( hint="Pass a positive integer for the cache TTL.", ) + if system_instruction is not None and not isinstance(system_instruction, str): + raise ConfigurationError( + f"system_instruction must be a string, got {type(system_instruction).__name__}", + hint="Pass a string for the system instruction.", + ) + if not provider.capabilities.persistent_cache: raise ConfigurationError( f"Provider {config.provider!r} does not support persistent caching", @@ -245,6 +251,14 @@ async def create_cache_impl( hint="Use Source.from_file(), Source.from_text(), etc.", ) + if tools is not None: + for i, t in enumerate(tools): + if not isinstance(t, dict): + raise ConfigurationError( + f"Tool at index {i} must be a dictionary, got {type(t).__name__}", + hint="Ensure all items in the tools list are dictionaries.", + ) + key = compute_cache_key( config.model, src_tuple, diff --git a/src/pollux/providers/_errors.py b/src/pollux/providers/_errors.py index 99d513a..074cd5d 100644 --- a/src/pollux/providers/_errors.py +++ b/src/pollux/providers/_errors.py @@ -16,6 +16,7 @@ from pollux.errors import ( APIError, CacheError, + ConfigurationError, RateLimitError, _walk_exception_chain, ) @@ -138,6 +139,11 @@ def wrap_provider_error( if isinstance(exc, asyncio.CancelledError): raise exc + # ConfigurationError should propagate as-is, not be re-wrapped as + # CacheError/APIError — it signals a caller mistake, not a provider failure. + if isinstance(exc, ConfigurationError): + raise exc + # Already wrapped — fill in missing context only. 
if isinstance(exc, APIError): if exc.provider is None: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index b28ce37..73f4e21 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1168,7 +1168,7 @@ async def test_create_cache_deduplicates_file_uploads_within_call( async def test_create_cache_rejects_unserializable_tools( monkeypatch: pytest.MonkeyPatch, ) -> None: - """create_cache() should raise ConfigurationError with a hint for non-JSON tools.""" + """create_cache() should raise ConfigurationError for non-dict tool items.""" fake = FakeProvider() monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) @@ -1178,14 +1178,14 @@ async def test_create_cache_rejects_unserializable_tools( class CustomTool: pass - with pytest.raises(ConfigurationError, match="JSON serializable") as exc: + with pytest.raises(ConfigurationError, match="must be a dictionary") as exc: await pollux.create_cache( (Source.from_text("hello"),), config=cfg, tools=[CustomTool()], ) - assert "convert them to dicts" in str(exc.value.hint) + assert exc.value.hint is not None assert fake.upload_calls == 0 assert fake.cache_calls == 0 From 865186e754579d6ca885c3e072a7ca6fd6a38359 Mon Sep 17 00:00:00 2001 From: Sean Brar Date: Wed, 4 Mar 2026 20:42:27 -0800 Subject: [PATCH 13/14] fix(cache): reject expired cache handles in build_plan An expired CacheHandle passed via Options(cache=handle) was silently accepted, leading to a cryptic provider error. Now caught eagerly with a clear ConfigurationError before any network I/O. 
--- src/pollux/plan.py | 6 ++++++ tests/test_pipeline.py | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/pollux/plan.py b/src/pollux/plan.py index da6ab40..3dbe740 100644 --- a/src/pollux/plan.py +++ b/src/pollux/plan.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +import time from typing import TYPE_CHECKING, Any from pollux.errors import ConfigurationError @@ -39,6 +40,11 @@ def build_plan(request: Request) -> Plan: cache_name: str | None = None if request.options.cache is not None: cache = request.options.cache + if time.time() >= cache.expires_at: + raise ConfigurationError( + "cache handle has expired", + hint="Create a new cache with create_cache().", + ) if cache.provider != request.config.provider: raise ConfigurationError( "cache handle provider does not match config provider", diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 73f4e21..a7a078b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -891,6 +891,30 @@ async def test_options_cache_requires_persistent_cache_capability( ) +@pytest.mark.asyncio +async def test_options_cache_rejects_expired_handle( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Expired cache handles must be rejected before any network I/O.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + expired = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() - 1, + ) + + with pytest.raises(ConfigurationError, match="expired"): + await pollux.run("Q", config=cfg, options=Options(cache=expired)) + + assert fake.last_parts is None + + @pytest.mark.asyncio async def test_options_cache_rejects_provider_and_model_mismatch( monkeypatch: pytest.MonkeyPatch, From 3df3d276f0a8cfdfaec99b3f80dedd70ecd219fb Mon Sep 17 00:00:00 2001 From: Sean Brar Date: 
Wed, 4 Mar 2026 21:23:09 -0800 Subject: [PATCH 14/14] docs(cache): note validation-before-I/O design and scaling threshold --- src/pollux/cache.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/pollux/cache.py b/src/pollux/cache.py index d01296b..fe06eed 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -220,6 +220,11 @@ async def create_cache_impl( """Core implementation of ``create_cache()``. Receives an already-initialized provider; the caller manages its lifecycle. + + All input validation is intentionally front-loaded before any I/O + (uploads, API calls). If the parameter surface grows beyond the + current five axes, consider a validated ``CacheSpec`` dataclass to + keep this boundary manageable. """ from pollux.plan import build_shared_parts from pollux.source import Source as SourceCls