diff --git a/.github/ISSUE_TEMPLATE/bug.md b/.github/ISSUE_TEMPLATE/bug.md index 2ea6707..14e7918 100644 --- a/.github/ISSUE_TEMPLATE/bug.md +++ b/.github/ISSUE_TEMPLATE/bug.md @@ -40,7 +40,7 @@ labels: [bug] diff --git a/README.md b/README.md index 76addcc..30957e7 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,6 @@ from pollux import Config config = Config( provider="gemini", model="gemini-2.5-flash-lite", - enable_caching=True, # Gemini-only in v1.0 ) ``` diff --git a/cookbook/optimization/cache-warming-and-ttl.py b/cookbook/optimization/cache-warming-and-ttl.py index 174ad1a..9a66227 100644 --- a/cookbook/optimization/cache-warming-and-ttl.py +++ b/cookbook/optimization/cache-warming-and-ttl.py @@ -7,7 +7,7 @@ Pattern: - Keep prompts and sources fixed. - - Enable caching with a meaningful TTL. + - Create a persistent cache via ``create_cache()``. - Run once to warm and once to reuse (back-to-back). - Compare tokens and cache signal. """ @@ -16,7 +16,6 @@ import argparse import asyncio -from dataclasses import replace from pathlib import Path from typing import TYPE_CHECKING @@ -28,15 +27,15 @@ print_section, ) from cookbook.utils.runtime import add_runtime_args, build_config_or_exit, usage_tokens -from pollux import Config, Source, run_many +from pollux import Config, Options, Source, create_cache, run_many if TYPE_CHECKING: from pollux.result import ResultEnvelope -PROMPTS = [ +PROMPTS = ( "List 5 key concepts with one-sentence explanations.", "Extract three actionable recommendations.", -] +) def describe(run_name: str, envelope: ResultEnvelope) -> None: @@ -52,15 +51,17 @@ def describe(run_name: str, envelope: ResultEnvelope) -> None: ) -async def main_async(directory: Path, *, limit: int, config: Config) -> None: +async def main_async(directory: Path, *, limit: int, config: Config, ttl: int) -> None: files = sorted(path for path in directory.rglob("*") if path.is_file())[:limit] if not files: raise SystemExit(f"No files found under: {directory}") 
sources = [Source.from_file(path) for path in files] - warm = await run_many(PROMPTS, sources=sources, config=config) - reuse = await run_many(PROMPTS, sources=sources, config=config) + handle = await create_cache(sources, config=config, ttl_seconds=ttl) + + warm = await run_many(PROMPTS, config=config, options=Options(cache=handle)) + reuse = await run_many(PROMPTS, config=config, options=Options(cache=handle)) warm_tokens = usage_tokens(warm) reuse_tokens = usage_tokens(reuse) saved = None @@ -107,16 +108,14 @@ def main() -> None: hint="No input directory found. Run `just demo-data` or pass --input /path/to/dir.", ) config = build_config_or_exit(args) - cached_config = replace( - config, enable_caching=True, ttl_seconds=max(1, int(args.ttl)) - ) - print_header("Cache warming and TTL", config=cached_config) + print_header("Cache warming and TTL", config=config) asyncio.run( main_async( directory, limit=max(1, int(args.limit)), - config=cached_config, + config=config, + ttl=max(1, int(args.ttl)), ) ) diff --git a/cookbook/utils/runtime.py b/cookbook/utils/runtime.py index 1bbd46b..9308cb7 100644 --- a/cookbook/utils/runtime.py +++ b/cookbook/utils/runtime.py @@ -60,14 +60,11 @@ def build_config_or_exit(args: argparse.Namespace) -> Config: def print_run_mode(config: Config) -> None: """Print a compact runtime mode line for recipe users.""" mode = "mock" if config.use_mock else "real-api" - caching = f"on(ttl={config.ttl_seconds}s)" if config.enable_caching else "off" extra = "" # Keep the mode line compact; only call out non-default concurrency. 
if getattr(config, "request_concurrency", 6) != 6: extra = f" | request_concurrency={config.request_concurrency}" - print( - f"Mode: {mode} | provider={config.provider} | model={config.model} | caching={caching}{extra}" - ) + print(f"Mode: {mode} | provider={config.provider} | model={config.model}{extra}") def usage_tokens(envelope: ResultEnvelope) -> int | None: diff --git a/docs/caching.md b/docs/caching.md index 3758c4b..3169379 100644 --- a/docs/caching.md +++ b/docs/caching.md @@ -1,5 +1,5 @@ @@ -73,29 +73,30 @@ compare_efficiency(946_800, 10) More questions on the same content = greater savings. -## Enabling Caching +## Creating a Cache -Let's see this in practice. Two flags in `Config` control caching: +Use `create_cache()` to upload content to the provider once, then pass the +returned handle to `run()` or `run_many()` via `Options(cache=handle)`: ```python import asyncio -from pollux import Config, Source, run_many +from pollux import Config, Options, Source, create_cache, run_many async def main() -> None: config = Config( provider="gemini", model="gemini-2.5-flash-lite", - enable_caching=True, - ttl_seconds=3600, ) - prompts = ["Summarize in one sentence.", "List 3 keywords."] sources = [Source.from_text( "ACME Corp Q3 2025 earnings: revenue $4.2B (+12% YoY), " "operating margin 18.5%, guidance raised for Q4." )] - first = await run_many(prompts=prompts, sources=sources, config=config) - second = await run_many(prompts=prompts, sources=sources, config=config) + handle = await create_cache(sources, config=config, ttl_seconds=3600) + + prompts = ["Summarize in one sentence.", "List 3 keywords."] + first = await run_many(prompts=prompts, config=config, options=Options(cache=handle)) + second = await run_many(prompts=prompts, config=config, options=Options(cache=handle)) print("first:", first["status"]) print("second:", second["status"]) @@ -106,27 +107,43 @@ asyncio.run(main()) ### Step-by-Step Walkthrough -1. 
**Set `enable_caching=True`.** This tells Pollux to upload content to the - provider's cache on the first call, rather than sending it inline. +1. **Call `create_cache()`.** Pass your sources, config, and a TTL. Pollux + uploads the content to the provider and returns a `CacheHandle`. 2. **Set `ttl_seconds`.** The TTL controls how long the cached content lives on the provider. Match it to your reuse window. 3600s (1 hour) is a reasonable default for interactive sessions. -3. **Run the same sources with different prompts.** The first `run_many()` call - uploads the content and creates a cache entry. The second call detects the - same content hash and reuses the cached reference. +3. **Pass the handle via `Options(cache=handle)`.** Each `run()` or `run_many()` + call that uses this handle references the cached content instead of + re-uploading it. 4. **Verify with `metrics.cache_used`.** Check `result["metrics"]["cache_used"]` on subsequent calls. `True` confirms the provider served content from cache rather than re-uploading. -Pollux computes cache identity from model + source content hash. The second -call reuses the cached context automatically. +Pollux computes cache identity from model + source content hash. Calls with +the same handle reuse the cached context automatically. + +!!! warning "Options restricted when using a cache handle" + When `Options(cache=handle)` is set, the following fields **cannot** be + passed alongside it: + + - `system_instruction` — bake it into `create_cache(system_instruction=...)` + instead. + - `tools` — bake them into `create_cache(tools=...)` instead (when + supported). + - `tool_choice` — remove it when using cached content. + - `sources` — bake them into `create_cache()` instead. + + Pollux raises `ConfigurationError` immediately if it detects these + conflicts. 
This mirrors a hard constraint in the Gemini API, where + `cached_content` cannot coexist with `system_instruction`, `tools`, or + `tool_config` in the same `GenerateContent` request. ## Cache Identity -Cache keys are deterministic: `hash(model + content hashes of sources)`. +Cache keys are deterministic: `hash(model + provider + api_key + system +instruction + tools + content hashes of sources)`. This means: @@ -156,8 +173,8 @@ behavior. Usage counters are provider-dependent. ## Tuning TTL -The default TTL is 3600 seconds (1 hour). Tune `ttl_seconds` to match your -expected reuse window: +Pass `ttl_seconds` to `create_cache()` to control the cache lifetime. The +default is 3600 seconds (1 hour). Tune it to match your expected reuse window: - **Too short:** the cache expires before you reuse it, wasting the warm-up cost. @@ -181,9 +198,9 @@ caching and enable it when you see repeated context in your workload. ## Provider Dependency -Context caching is **Gemini-only**. Enabling it with OpenAI raises -an actionable error. See -[Provider Capabilities](reference/provider-capabilities.md) for the full +Persistent context caching is **Gemini-only**. Calling `create_cache()` with +a provider that lacks `persistent_cache` support raises an actionable error. +See [Provider Capabilities](reference/provider-capabilities.md) for the full matrix.
--- diff --git a/docs/configuration.md b/docs/configuration.md index 23b24ddd..c1387f4 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -35,8 +35,6 @@ All fields and their defaults: | `model` | `str` | *(required)* | Model identifier | | `api_key` | `str \| None` | `None` | Explicit key; auto-resolved from env if omitted | | `use_mock` | `bool` | `False` | Use mock provider (no network calls) | -| `enable_caching` | `bool` | `False` | Enable provider-side context caching | -| `ttl_seconds` | `int` | `3600` | Cache time-to-live in seconds | | `request_concurrency` | `int` | `6` | Max concurrent API calls in multi-prompt execution | | `retry` | `RetryPolicy` | `RetryPolicy()` | Retry configuration | @@ -78,8 +76,6 @@ pipeline logic, testing integrations, and CI. config = Config( provider="gemini", model="gemini-2.5-flash-lite", - enable_caching=True, # Reuse uploaded context (Gemini-only) - ttl_seconds=3600, # Cache lifetime request_concurrency=6, # Concurrent API calls ) ``` @@ -87,7 +83,7 @@ config = Config( | Need | Direction | |---|---| | Fast iteration without API calls | `use_mock=True` | -| Reduce token spend on repeated context | `enable_caching=True`. See [Reducing Costs with Context Caching](caching.md) | +| Reduce token spend on repeated context | Use `create_cache()`. See [Reducing Costs with Context Caching](caching.md) | | Higher throughput for many prompts/sources | Increase `request_concurrency` | | Better resilience to transient failures | Customize `retry=RetryPolicy(...)` | @@ -151,6 +147,7 @@ options = Options( | `delivery_mode` | `str` | `"realtime"` | Only `"realtime"` is supported; `"deferred"` raises an error | | `history` | `list[dict] \| None` | `None` | Conversation history. See [Continuing Conversations Across Turns](conversations-and-agents.md) | | `continue_from` | `ResultEnvelope \| None` | `None` | Resume from a prior result. 
See [Continuing Conversations Across Turns](conversations-and-agents.md) | +| `cache` | `CacheHandle \| None` | `None` | Persistent context cache. See [Reducing Costs with Context Caching](caching.md) | !!! note OpenAI GPT-5 family models (`gpt-5`, `gpt-5-mini`, `gpt-5-nano`) reject @@ -159,6 +156,13 @@ options = Options( See [Writing Portable Code Across Providers](portable-code.md#model-specific-constraints) for the full constraints mapping. +!!! warning "Cache handle restrictions" + When `cache` is set, `system_instruction`, `tools`, and `tool_choice` + **must not** be passed in the same `Options`. `system_instruction` and + `tools` can be baked into `create_cache()`, while `tool_choice` must be + set only on uncached calls. See + [Reducing Costs with Context Caching](caching.md) for details. + ## Safety Notes - `Config` is immutable (`frozen=True`). Create a new instance to change values. diff --git a/docs/portable-code.md b/docs/portable-code.md index 7fb3b5a..878199b 100644 --- a/docs/portable-code.md +++ b/docs/portable-code.md @@ -35,8 +35,7 @@ varying parts in config; keep the stable parts in functions. ## Complete Example -A document analysis function that works on any provider. Caching is used -when available, skipped otherwise. +A document analysis function that works on any provider. 
```python import asyncio @@ -55,27 +54,22 @@ class DocumentSummary(BaseModel): @dataclass class ProviderConfig: - """Maps a provider to a model and capability flags.""" + """Maps a provider to a model.""" provider: str model: str - supports_caching: bool = False # Provider-specific details live here, not in your pipeline logic PROVIDERS = { - "gemini": ProviderConfig("gemini", "gemini-2.5-flash-lite", supports_caching=True), + "gemini": ProviderConfig("gemini", "gemini-2.5-flash-lite"), "openai": ProviderConfig("openai", "gpt-5-nano"), } -def make_config(provider_name: str, *, enable_caching: bool = False) -> Config: +def make_config(provider_name: str) -> Config: """Build a Config for the given provider with safe defaults.""" pc = PROVIDERS[provider_name] - return Config( - provider=pc.provider, - model=pc.model, - enable_caching=enable_caching and pc.supports_caching, - ) + return Config(provider=pc.provider, model=pc.model) async def analyze_document( @@ -83,10 +77,9 @@ async def analyze_document( prompt: str, *, provider_name: str = "gemini", - enable_caching: bool = False, ) -> DocumentSummary: """Analyze a document — works with any supported provider.""" - config = make_config(provider_name, enable_caching=enable_caching) + config = make_config(provider_name) options = Options(response_schema=DocumentSummary) result = await run( @@ -118,12 +111,12 @@ asyncio.run(main()) ### Step-by-Step Walkthrough 1. **Centralize provider details.** `ProviderConfig` maps each provider to - its model and capability flags. Your analysis functions never reference - provider names or models directly. + its model. Your analysis functions never reference provider names or + models directly. -2. **Guard capability-specific features.** `make_config` only enables caching - when both the caller requests it *and* the provider supports it. This - avoids `ConfigurationError` at runtime. +2. 
**Use `create_cache()` for persistent caching.** Caching is now + opt-in via `create_cache()` and `Options(cache=handle)`. Only call + it when the provider supports `persistent_cache` (e.g. Gemini). 3. **Write provider-agnostic functions.** `analyze_document` accepts a provider name and builds the config internally. The prompt, source, and diff --git a/docs/reference/api.md b/docs/reference/api.md index eccca32..1324bda 100644 --- a/docs/reference/api.md +++ b/docs/reference/api.md @@ -14,10 +14,14 @@ The primary execution functions are exported from `pollux`: ::: pollux.continue_tool +::: pollux.create_cache + ## Core Types ::: pollux.Source +::: pollux.CacheHandle + ::: pollux.Config ::: pollux.Options diff --git a/docs/reference/provider-capabilities.md b/docs/reference/provider-capabilities.md index 2f3e3a8..a6638f8 100644 --- a/docs/reference/provider-capabilities.md +++ b/docs/reference/provider-capabilities.md @@ -84,20 +84,19 @@ Pollux is **capability-transparent**, not capability-equalizing: providers are a When a requested feature is unsupported for the selected provider or release scope, Pollux raises `ConfigurationError` or `APIError` with a concrete hint, instead of degrading silently. -For example, enabling caching with OpenAI: +For example, creating a persistent cache with OpenAI: ```python -from pollux import Config - -config = Config( - provider="openai", - model="gpt-5-nano", - enable_caching=True, # not supported for OpenAI +from pollux import Config, Source, create_cache + +config = Config(provider="openai", model="gpt-5-nano") +# This raises immediately: +# ConfigurationError: Provider 'openai' does not support persistent caching +# hint: "Use a provider that supports persistent_cache (e.g. Gemini)." +handle = await create_cache( + [Source.from_text("hello")], config=config ) -# At execution time, Pollux raises: -# ConfigurationError: Provider does not support caching -# hint: "Disable caching or choose a provider with caching support." 
``` -The error is raised at execution time (not at `Config` creation) because -caching support is a provider capability checked during plan execution. +The error is raised at `create_cache()` call time because persistent caching +is a provider capability checked before the upload attempt. diff --git a/docs/source-patterns.md b/docs/source-patterns.md index 01105ce..5ec596f 100644 --- a/docs/source-patterns.md +++ b/docs/source-patterns.md @@ -232,7 +232,7 @@ async def process_to_jsonl(directory: str, output: str) -> None: - **Memory with large collections.** Each `Source.from_file()` reads the file for hashing. For very large collections, process in batches rather than loading all sources at once. -- **Caching helps fan-out, not iteration.** `enable_caching=True` saves +- **Caching helps fan-out, not iteration.** `create_cache()` saves tokens when the *same source* gets reused across multiple prompts. It does not help when each file is different. See [Reducing Costs with Context Caching](caching.md). 
diff --git a/src/pollux/__init__.py b/src/pollux/__init__.py index bf76d12..e07db9e 100644 --- a/src/pollux/__init__.py +++ b/src/pollux/__init__.py @@ -3,6 +3,7 @@ Public API: - run(): Single prompt execution - run_many(): Multi-prompt source-pattern execution + - create_cache(): Create a persistent context cache - Source: Explicit input types - Config: Configuration dataclass """ @@ -13,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Any -from pollux.cache import CacheRegistry +from pollux.cache import CacheHandle from pollux.config import Config from pollux.errors import ( APIError, @@ -48,9 +49,6 @@ logger = logging.getLogger(__name__) -# Module-level cache registry for reuse across calls -_registry = CacheRegistry() - async def run( prompt: str | None = None, @@ -113,17 +111,9 @@ async def run_many( provider = _get_provider(request.config) try: - trace = await execute_plan(plan, provider, _registry) + trace = await execute_plan(plan, provider) finally: - aclose = getattr(provider, "aclose", None) - if callable(aclose): - try: - await aclose() - except asyncio.CancelledError: - raise - except Exception as exc: - # Cleanup should never mask the primary failure. - logger.warning("Provider cleanup failed: %s", exc) + await _close_provider(provider) return build_result(plan, trace) @@ -169,6 +159,67 @@ async def continue_tool( return await run(prompt=None, config=config, options=merged_options) +async def create_cache( + sources: tuple[Source, ...] | list[Source], + *, + config: Config, + system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, + ttl_seconds: int = 3600, +) -> CacheHandle: + """Create a persistent context cache for use with ``run()`` / ``run_many()``. + + Args: + sources: Content to cache (files, text, URLs). + config: Configuration specifying provider and model. + system_instruction: Optional system-level instruction baked into the cache. + tools: Optional tools to bake into the cache. 
+ ttl_seconds: Time-to-live in seconds (must be ≥ 1). + + Returns: + A ``CacheHandle`` that can be passed via ``Options(cache=handle)``. + + Raises: + ConfigurationError: If the provider does not support persistent caching + or *ttl_seconds* is invalid. + + Example: + cfg = Config(provider="gemini", model="gemini-2.5-flash") + handle = await create_cache( + [Source.from_file("book.pdf")], + config=cfg, + ttl_seconds=3600, + ) + result = await run("Summarize.", config=cfg, options=Options(cache=handle)) + """ + from pollux.cache import create_cache_impl + + provider = _get_provider(config) + try: + return await create_cache_impl( + sources, + provider=provider, + config=config, + system_instruction=system_instruction, + tools=tools, + ttl_seconds=ttl_seconds, + ) + finally: + await _close_provider(provider) + + +async def _close_provider(provider: Provider) -> None: + """Close provider resources without masking primary errors.""" + aclose = getattr(provider, "aclose", None) + if callable(aclose): + try: + await aclose() + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Provider cleanup failed: %s", exc) + + def _get_provider(config: Config) -> Provider: """Get the appropriate provider based on configuration.""" if config.use_mock: @@ -210,6 +261,7 @@ def _get_provider(config: Config) -> Provider: __all__ = [ "APIError", "CacheError", + "CacheHandle", "Config", "ConfigurationError", "InternalError", @@ -222,6 +274,7 @@ def _get_provider(config: Config) -> Provider: "Source", "SourceError", "continue_tool", + "create_cache", "run", "run_many", ] diff --git a/src/pollux/cache.py b/src/pollux/cache.py index 9b4a9e2..fe06eed 100644 --- a/src/pollux/cache.py +++ b/src/pollux/cache.py @@ -6,58 +6,95 @@ from dataclasses import dataclass, field import hashlib import logging +from pathlib import Path import time from typing import TYPE_CHECKING, Any from pollux._singleflight import singleflight_cached +from pollux.errors import 
ConfigurationError, InternalError from pollux.retry import RetryPolicy, retry_async, should_retry_side_effect if TYPE_CHECKING: + from pollux.config import Config from pollux.providers.base import Provider from pollux.source import Source logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class CacheHandle: + """Opaque handle returned by ``create_cache()``. + + Pass instances via ``Options(cache=handle)`` to reuse a persistent + context cache across ``run()`` / ``run_many()`` calls. + """ + + name: str + model: str + provider: str + expires_at: float + + @dataclass class CacheRegistry: """Registry tracking cache entries with expiration.""" _entries: dict[str, tuple[str, float]] = field(default_factory=dict) - _inflight: dict[str, asyncio.Future[str]] = field(default_factory=dict) + _inflight: dict[str, asyncio.Future[tuple[str, float]]] = field( + default_factory=dict + ) _lock: asyncio.Lock = field(default_factory=asyncio.Lock) - def get(self, key: str) -> str | None: - """Get cache name if exists and not expired.""" + def get(self, key: str) -> tuple[str, float] | None: + """Get cache entry if exists and not expired.""" entry = self._entries.get(key) if entry is None: return None - name, expires_at = entry + _, expires_at = entry if time.time() >= expires_at: del self._entries[key] logger.debug("Cache expired key=%s…", key[:8]) return None - return name + return entry - def set(self, key: str, name: str, ttl_seconds: int) -> None: + def set(self, key: str, value: tuple[str, float]) -> None: """Store cache entry with expiration time.""" - expires_at = time.time() + max(0, ttl_seconds) - self._entries[key] = (name, expires_at) + self._entries[key] = value def compute_cache_key( model: str, sources: tuple[Source, ...], + provider: str | None = None, + api_key: str | None = None, system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ) -> str: """Compute deterministic cache key using content hashes. 
- Key = hash(model + system + content digests of sources) - This fixes the cache identity collision bug where identifier+size was used. + Key = hash(model + provider + api_key + system + content digests of sources). + Including ``api_key`` prevents cross-account handle reuse when multiple + keys for the same provider/model coexist in one process. """ parts = [model] + if provider: + parts.append(provider) + if api_key: + parts.append(api_key) if system_instruction: parts.append(system_instruction) + if tools: + import json + + try: + # sort_keys to ensure deterministic JSON representation + parts.append(json.dumps(tools, sort_keys=True)) + except TypeError as e: + raise ConfigurationError( + "Tools provided to create_cache() must be JSON serializable.", + hint="If using custom objects or Pydantic models for tools, convert them to dicts first.", + ) from e for source in sources: # Use content hash, not identifier @@ -73,44 +110,204 @@ async def get_or_create_cache( *, key: str, model: str, - parts: list[Any], + raw_parts: list[Any], system_instruction: str | None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int, retry_policy: RetryPolicy | None = None, -) -> str | None: +) -> tuple[str, float] | None: """Get existing cache or create new one with single-flight protection. - Single-flight: concurrent requests for the same key share one creation call. + File placeholders in *raw_parts* are resolved inside the single-flight + work function so concurrent callers share both uploads and cache creation. 
""" - if not provider.supports_caching: + if not provider.capabilities.persistent_cache: return None - async def _work() -> str: + async def _work() -> tuple[str, float]: logger.debug("Creating cache key=%s…", key[:8]) - if retry_policy is None or retry_policy.max_attempts <= 1: - return await provider.create_cache( + policy = retry_policy or RetryPolicy(max_attempts=1) + parts = await _resolve_file_parts(raw_parts, provider, policy) + if policy.max_attempts <= 1: + name = await provider.create_cache( model=model, parts=parts, system_instruction=system_instruction, + tools=tools, ttl_seconds=ttl_seconds, ) + return name, time.time() + max(0, ttl_seconds) - return await retry_async( + name = await retry_async( lambda: provider.create_cache( model=model, parts=parts, system_instruction=system_instruction, + tools=tools, ttl_seconds=ttl_seconds, ), - policy=retry_policy, + policy=policy, should_retry=should_retry_side_effect, ) + return name, time.time() + max(0, ttl_seconds) return await singleflight_cached( key, lock=registry._lock, inflight=registry._inflight, cache_get=registry.get, - cache_set=lambda k, v: registry.set(k, v, ttl_seconds), + cache_set=registry.set, work=_work, ) + + +# Module-level registry shared across create_cache calls. +_registry = CacheRegistry() + + +async def _resolve_file_parts( + parts: list[Any], + provider: Provider, + retry_policy: RetryPolicy, +) -> list[Any]: + """Replace file placeholders with uploaded assets. + + Memoizes by ``(file_path, mime_type)`` so duplicate file sources in a + single ``create_cache()`` call share one upload. No singleflight + needed because ``create_cache`` is sequential, not concurrent fan-out. 
+ """ + resolved: list[Any] = [] + seen: dict[tuple[str, str], Any] = {} + for part in parts: + if ( + isinstance(part, dict) + and isinstance(part.get("file_path"), str) + and isinstance(part.get("mime_type"), str) + ): + fp, mt = part["file_path"], part["mime_type"] + key = (fp, mt) + if key in seen: + resolved.append(seen[key]) + continue + if retry_policy.max_attempts <= 1: + asset = await provider.upload_file(Path(fp), mt) + else: + + async def _upload(_fp: str = fp, _mt: str = mt) -> Any: + return await provider.upload_file(Path(_fp), _mt) + + asset = await retry_async( + _upload, + policy=retry_policy, + should_retry=should_retry_side_effect, + ) + seen[key] = asset + resolved.append(asset) + else: + resolved.append(part) + return resolved + + +async def create_cache_impl( + sources: tuple[Source, ...] | list[Source], + *, + provider: Provider, + config: Config, + system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, + ttl_seconds: int = 3600, +) -> CacheHandle: + """Core implementation of ``create_cache()``. + + Receives an already-initialized provider; the caller manages its lifecycle. + + All input validation is intentionally front-loaded before any I/O + (uploads, API calls). If the parameter surface grows beyond the + current five axes, consider a validated ``CacheSpec`` dataclass to + keep this boundary manageable. 
+ """ + from pollux.plan import build_shared_parts + from pollux.source import Source as SourceCls + + if not isinstance(ttl_seconds, int) or ttl_seconds < 1: + raise ConfigurationError( + f"ttl_seconds must be an integer ≥ 1, got {ttl_seconds!r}", + hint="Pass a positive integer for the cache TTL.", + ) + + if system_instruction is not None and not isinstance(system_instruction, str): + raise ConfigurationError( + f"system_instruction must be a string, got {type(system_instruction).__name__}", + hint="Pass a string for the system instruction.", + ) + + if not provider.capabilities.persistent_cache: + raise ConfigurationError( + f"Provider {config.provider!r} does not support persistent caching", + hint="Use a provider that supports persistent_cache (e.g. Gemini).", + ) + + src_tuple = tuple(sources) if not isinstance(sources, tuple) else sources + + for s in src_tuple: + if not isinstance(s, SourceCls): + raise ConfigurationError( + f"Expected Source, got {type(s).__name__}", + hint="Use Source.from_file(), Source.from_text(), etc.", + ) + + if tools is not None: + for i, t in enumerate(tools): + if not isinstance(t, dict): + raise ConfigurationError( + f"Tool at index {i} must be a dictionary, got {type(t).__name__}", + hint="Ensure all items in the tools list are dictionaries.", + ) + + key = compute_cache_key( + config.model, + src_tuple, + provider=config.provider, + api_key=config.api_key, + system_instruction=system_instruction, + tools=tools, + ) + + cached = _registry.get(key) + if cached is not None: + cache_name, expires_at = cached + return CacheHandle( + name=cache_name, + model=config.model, + provider=config.provider, + expires_at=expires_at, + ) + + raw_parts = build_shared_parts(src_tuple) + + result = await get_or_create_cache( + provider, + _registry, + key=key, + model=config.model, + raw_parts=raw_parts, + system_instruction=system_instruction, + tools=tools, + ttl_seconds=ttl_seconds, + retry_policy=config.retry, + ) + + if result is None: + 
raise InternalError( + "Cache creation returned None unexpectedly", + hint="This is a Pollux internal error. Please report it.", + ) + + cache_name, expires_at = result + + return CacheHandle( + name=cache_name, + model=config.model, + provider=config.provider, + expires_at=expires_at, + ) diff --git a/src/pollux/config.py b/src/pollux/config.py index d4d815f..9f72445 100644 --- a/src/pollux/config.py +++ b/src/pollux/config.py @@ -38,9 +38,6 @@ class Config: #: Auto-resolved from ``GEMINI_API_KEY`` or ``OPENAI_API_KEY`` when *None*. api_key: str | None = None use_mock: bool = False - #: Gemini-only in v1.0; silently ignored for other providers. - enable_caching: bool = False - ttl_seconds: int = 3600 request_concurrency: int = 6 retry: RetryPolicy = field(default_factory=RetryPolicy) @@ -59,21 +56,11 @@ def __post_init__(self) -> None: f"request_concurrency must be an integer, got {type(self.request_concurrency).__name__}", hint="Pass a whole number ≥ 1 for request_concurrency.", ) - if not isinstance(self.ttl_seconds, int): - raise ConfigurationError( - f"ttl_seconds must be an integer, got {type(self.ttl_seconds).__name__}", - hint="Pass a whole number ≥ 0 for ttl_seconds.", - ) if self.request_concurrency < 1: raise ConfigurationError( f"request_concurrency must be ≥ 1, got {self.request_concurrency}", hint="This controls how many API calls run in parallel.", ) - if self.ttl_seconds < 0: - raise ConfigurationError( - f"ttl_seconds must be ≥ 0, got {self.ttl_seconds}", - hint="This controls the cache time-to-live in seconds (0 disables caching TTL).", - ) # Auto-resolve API key from environment if not provided if self.api_key is None and not self.use_mock: diff --git a/src/pollux/execute.py b/src/pollux/execute.py index 2c7b3ca..b3c2842 100644 --- a/src/pollux/execute.py +++ b/src/pollux/execute.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any from pollux._singleflight import singleflight_cached -from pollux.cache import CacheRegistry, 
get_or_create_cache from pollux.errors import APIError, ConfigurationError, InternalError, PolluxError from pollux.providers.models import ( Message, @@ -68,9 +67,7 @@ class ExecutionTrace: conversation_state: dict[str, Any] | None = None -async def execute_plan( - plan: Plan, provider: Provider, registry: CacheRegistry -) -> ExecutionTrace: +async def execute_plan(plan: Plan, provider: Provider) -> ExecutionTrace: """Execute the plan with the given provider. Handles: @@ -116,8 +113,19 @@ async def execute_plan( "Conversation continuity currently supports exactly one prompt per call", hint="Use run() or run_many() with a single prompt when passing history/continue_from.", ) + # Runtime safety net: reject hand-built handles targeting providers + # that lack persistent caching (the planner already validates other + # cache conflicts). + if plan.cache_name is not None and not caps.persistent_cache: + raise ConfigurationError( + "Provider does not support persistent caching", + hint=( + "Remove options.cache or choose a provider with " + "persistent_cache support." 
+ ), + ) - if (not provider.supports_uploads) and any( + if (not provider.capabilities.uploads) and any( isinstance(p, dict) and isinstance(p.get("file_path"), str) and isinstance(p.get("mime_type"), str) @@ -169,48 +177,12 @@ async def execute_plan( upload_inflight: dict[tuple[str, str], asyncio.Future[ProviderFileAsset]] = {} upload_lock = asyncio.Lock() retry_policy = config.retry - cache_name: str | None = None responses: list[dict[str, Any]] = [] total_usage: dict[str, int] = {} conversation_state: dict[str, Any] | None = None try: - # Handle caching - if plan.use_cache: - # Resolve uploads for shared parts first, as cache creation requires URIs - shared_parts = list(plan.shared_parts) - if shared_parts: - shared_parts = await _substitute_upload_parts( - shared_parts, - provider=provider, - call_idx=None, - upload_cache=upload_cache, - upload_inflight=upload_inflight, - upload_lock=upload_lock, - retry_policy=retry_policy, - ) - - if plan.cache_key: - try: - cache_name = await get_or_create_cache( - provider, - registry, - key=plan.cache_key, - model=config.model, - parts=shared_parts, # Use resolved parts with URIs - system_instruction=options.system_instruction, - ttl_seconds=config.ttl_seconds, - retry_policy=retry_policy, - ) - except asyncio.CancelledError: - raise - except PolluxError: - raise - except Exception as e: - raise InternalError( - f"Cache creation failed: {type(e).__name__}: {e}", - hint="This is a Pollux internal error. 
Please report it.", - ) from e + cache_name = plan.cache_name # Execute calls with concurrency control concurrency = config.request_concurrency diff --git a/src/pollux/options.py b/src/pollux/options.py index 778e2bd..0ea6ee2 100644 --- a/src/pollux/options.py +++ b/src/pollux/options.py @@ -10,6 +10,7 @@ from pollux.errors import ConfigurationError if TYPE_CHECKING: + from pollux.cache import CacheHandle from pollux.result import ResultEnvelope ReasoningEffort = str @@ -44,6 +45,8 @@ class Options: continue_from: ResultEnvelope | None = None #: Hard limit on the model's total output tokens. Provider-specific semantics. max_tokens: int | None = None + #: Persistent context cache obtained from ``create_cache()``. + cache: CacheHandle | None = None def __post_init__(self) -> None: """Validate option shapes early for clear errors.""" @@ -101,6 +104,14 @@ def __post_init__(self) -> None: "continue_from must be a prior Pollux result envelope", hint="Pass the dict returned by run() or run_many().", ) + if self.cache is not None: + from pollux.cache import CacheHandle + + if not isinstance(self.cache, CacheHandle): + raise ConfigurationError( + "cache must be a CacheHandle from create_cache()", + hint="Call create_cache() first, then pass Options(cache=handle).", + ) def response_schema_json(self) -> dict[str, Any] | None: """Return JSON Schema for provider APIs.""" diff --git a/src/pollux/plan.py b/src/pollux/plan.py index 59663df..3dbe740 100644 --- a/src/pollux/plan.py +++ b/src/pollux/plan.py @@ -3,8 +3,11 @@ from __future__ import annotations from dataclasses import dataclass +import time from typing import TYPE_CHECKING, Any +from pollux.errors import ConfigurationError + if TYPE_CHECKING: from pollux.request import Request from pollux.source import Source @@ -12,12 +15,11 @@ @dataclass(frozen=True) class Plan: - """Execution plan with shared context and cache identity.""" + """Execution plan with shared context and optional cache reference.""" request: Request 
shared_parts: tuple[Any, ...] = () - use_cache: bool = False - cache_key: str | None = None + cache_name: str | None = None @property def n_calls(self) -> int: @@ -29,35 +31,78 @@ def build_plan(request: Request) -> Plan: """Build execution plan from normalized request. Handles both single-prompt and vectorized (multi-prompt) scenarios. + Validates cache handle conflicts eagerly so callers get clear errors + before any network I/O. """ - config = request.config sources = request.sources - - # Build shared parts from sources - shared_parts = _build_shared_parts(sources) - - # Determine if caching should be used - use_cache = config.enable_caching and len(shared_parts) > 0 - cache_key = None - - if use_cache: - from pollux.cache import compute_cache_key - - cache_key = compute_cache_key( - config.model, - sources, - system_instruction=request.options.system_instruction, - ) + shared_parts = build_shared_parts(sources) + + cache_name: str | None = None + if request.options.cache is not None: + cache = request.options.cache + if time.time() >= cache.expires_at: + raise ConfigurationError( + "cache handle has expired", + hint="Create a new cache with create_cache().", + ) + if cache.provider != request.config.provider: + raise ConfigurationError( + "cache handle provider does not match config provider", + hint=( + f"Create the cache with provider={request.config.provider!r} and " + "reuse it with the same provider." + ), + ) + if cache.model != request.config.model: + raise ConfigurationError( + "cache handle model does not match config model", + hint=( + f"Create the cache with model={request.config.model!r} and reuse it " + "with the same model." + ), + ) + if request.options.system_instruction is not None: + raise ConfigurationError( + "system_instruction cannot be used with a cache handle", + hint=( + "Bake the system instruction into create_cache() instead, " + "or remove the cache handle." 
+ ), + ) + if request.options.tools is not None: + raise ConfigurationError( + "tools cannot be used with a cache handle", + hint=( + "Bake tools into create_cache() instead, " + "or remove the cache handle." + ), + ) + if request.options.tool_choice is not None: + raise ConfigurationError( + "tool_choice cannot be used with a cache handle", + hint=( + "Remove tool_choice when using a cache handle, " + "or remove the cache handle." + ), + ) + if shared_parts: + raise ConfigurationError( + "sources cannot be used with a cache handle", + hint=( + "Bake sources into create_cache() instead, " + "or remove the cache handle." + ), + ) + cache_name = cache.name return Plan( request=request, shared_parts=tuple(shared_parts), - use_cache=use_cache, - cache_key=cache_key, + cache_name=cache_name, ) -def _build_shared_parts(sources: tuple[Source, ...]) -> list[Any]: +def build_shared_parts(sources: tuple[Source, ...]) -> list[Any]: """Convert sources to API parts.""" parts: list[Any] = [] diff --git a/src/pollux/providers/_errors.py b/src/pollux/providers/_errors.py index 99d513a..074cd5d 100644 --- a/src/pollux/providers/_errors.py +++ b/src/pollux/providers/_errors.py @@ -16,6 +16,7 @@ from pollux.errors import ( APIError, CacheError, + ConfigurationError, RateLimitError, _walk_exception_chain, ) @@ -138,6 +139,11 @@ def wrap_provider_error( if isinstance(exc, asyncio.CancelledError): raise exc + # ConfigurationError should propagate as-is, not be re-wrapped as + # CacheError/APIError — it signals a caller mistake, not a provider failure. + if isinstance(exc, ConfigurationError): + raise exc + # Already wrapped — fill in missing context only. 
if isinstance(exc, APIError): if exc.provider is None: diff --git a/src/pollux/providers/anthropic.py b/src/pollux/providers/anthropic.py index 9437382..2c90e71 100644 --- a/src/pollux/providers/anthropic.py +++ b/src/pollux/providers/anthropic.py @@ -57,21 +57,11 @@ def _get_client(self) -> Any: self._client = AsyncAnthropic(api_key=self.api_key) return self._client - @property - def supports_caching(self) -> bool: - """Whether this provider supports context caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=False, + persistent_cache=False, uploads=True, structured_outputs=True, reasoning=True, @@ -351,10 +341,11 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: """Raise because Anthropic caching is deferred.""" - _ = model, parts, system_instruction, ttl_seconds + _ = model, parts, system_instruction, tools, ttl_seconds raise APIError("Anthropic provider does not support context caching") async def aclose(self) -> None: diff --git a/src/pollux/providers/base.py b/src/pollux/providers/base.py index 757f79b..ee9bebe 100644 --- a/src/pollux/providers/base.py +++ b/src/pollux/providers/base.py @@ -19,7 +19,7 @@ class ProviderCapabilities: """Feature flags exposed by providers.""" - caching: bool + persistent_cache: bool uploads: bool structured_outputs: bool = False reasoning: bool = False @@ -48,21 +48,12 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: """Create a cache and return its name.""" ... 
- @property - def supports_caching(self) -> bool: - """Whether this provider supports caching.""" - ... - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - ... - @property def capabilities(self) -> ProviderCapabilities: """Feature capabilities for strict option validation.""" diff --git a/src/pollux/providers/gemini.py b/src/pollux/providers/gemini.py index d63b17e..13a62ff 100644 --- a/src/pollux/providers/gemini.py +++ b/src/pollux/providers/gemini.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any import uuid -from pollux.errors import APIError +from pollux.errors import APIError, ConfigurationError from pollux.providers._errors import wrap_provider_error from pollux.providers.base import ProviderCapabilities from pollux.providers.models import ( @@ -48,21 +48,11 @@ def _get_client(self) -> Any: self._client = genai.Client(api_key=self.api_key) return self._client - @property - def supports_caching(self) -> bool: - """Whether this provider supports context caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, @@ -114,6 +104,11 @@ def _normalize_tools(self, tools: list[dict[str, Any]]) -> list[Any]: tool_objs: list[Any] = [] for t in tools: + if not isinstance(t, dict): + raise ConfigurationError( + f"Tool must be a dictionary, got {type(t).__name__}", + hint="Ensure all items in the tools list are dictionaries.", + ) if "name" in t: raw_params = t.get("parameters") params = ( @@ -425,20 +420,28 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, 
ttl_seconds: int = 3600, ) -> str: """Create a cached content entry.""" client = self._get_client() from google.genai import types + config_kwargs: dict[str, Any] = { + "contents": self._convert_parts(parts), + "system_instruction": system_instruction, + "ttl": f"{ttl_seconds}s", + } + try: + if tools is not None: + tool_objs = self._normalize_tools(tools) + if tool_objs: + config_kwargs["tools"] = tool_objs + result = await client.aio.caches.create( model=model, - config=types.CreateCachedContentConfig( - contents=self._convert_parts(parts), - system_instruction=system_instruction, - ttl=f"{ttl_seconds}s", - ), + config=types.CreateCachedContentConfig(**config_kwargs), ) return str(result.name) except asyncio.CancelledError: diff --git a/src/pollux/providers/mock.py b/src/pollux/providers/mock.py index 6301ef5..c7dc2c5 100644 --- a/src/pollux/providers/mock.py +++ b/src/pollux/providers/mock.py @@ -17,21 +17,11 @@ class MockProvider: Supports caching and uploads but returns synthetic responses. 
""" - @property - def supports_caching(self) -> bool: - """Whether this provider supports caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -70,6 +60,7 @@ async def create_cache( model: str, parts: list[Any], # noqa: ARG002 system_instruction: str | None = None, # noqa: ARG002 + tools: list[dict[str, Any]] | list[Any] | None = None, # noqa: ARG002 ttl_seconds: int = 3600, # noqa: ARG002 ) -> str: """Return a mock cache name.""" diff --git a/src/pollux/providers/openai.py b/src/pollux/providers/openai.py index 39f7e23..a9efd01 100644 --- a/src/pollux/providers/openai.py +++ b/src/pollux/providers/openai.py @@ -44,21 +44,11 @@ def _get_client(self) -> Any: self._client = AsyncOpenAI(api_key=self.api_key) return self._client - @property - def supports_caching(self) -> bool: - """Whether this provider supports context caching.""" - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - """Whether this provider supports file uploads.""" - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: """Return supported feature flags.""" return ProviderCapabilities( - caching=False, + persistent_cache=False, uploads=True, structured_outputs=True, reasoning=True, @@ -336,10 +326,11 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: """Raise because OpenAI context caching is not supported.""" - _ = model, parts, system_instruction, ttl_seconds + _ = model, parts, system_instruction, tools, ttl_seconds raise 
APIError("OpenAI provider does not support context caching") async def delete_file(self, file_id: str) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 0536563..74d68c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,7 +32,7 @@ class FakeProvider: last_generate_kwargs: dict[str, Any] | None = None _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -41,14 +41,6 @@ class FakeProvider: ) ) - @property - def supports_caching(self) -> bool: - return self.capabilities.caching - - @property - def supports_uploads(self) -> bool: - return self.capabilities.uploads - @property def capabilities(self) -> ProviderCapabilities: return self._capabilities @@ -86,9 +78,10 @@ async def create_cache( model: str, parts: list[Any], system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, ttl_seconds: int = 3600, ) -> str: - del model, parts, system_instruction, ttl_seconds + del model, parts, system_instruction, tools, ttl_seconds self.cache_calls += 1 return "cachedContents/test" diff --git a/tests/helpers.py b/tests/helpers.py index ae9d354..3503f63 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -72,7 +72,7 @@ def __post_init__(self) -> None: self, "_capabilities", ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, diff --git a/tests/test_config.py b/tests/test_config.py index 5c143f2..1cbe33b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -184,31 +184,6 @@ def test_non_integer_request_concurrency_raises_clear_error(gemini_model: str) - assert exc.value.hint is not None -def test_negative_ttl_seconds_raises_clear_error(gemini_model: str) -> None: - """Negative TTL is nonsensical; must fail at construction.""" - with pytest.raises(ConfigurationError, match="ttl_seconds must be") as exc: 
- Config(provider="gemini", model=gemini_model, use_mock=True, ttl_seconds=-1) - assert exc.value.hint is not None - - -def test_non_integer_ttl_seconds_raises_clear_error(gemini_model: str) -> None: - """Non-integer TTL should raise ConfigurationError, not TypeError.""" - with pytest.raises(ConfigurationError, match="must be an integer") as exc: - Config( - provider="gemini", - model=gemini_model, - use_mock=True, - ttl_seconds="3600", # type: ignore[arg-type] - ) - assert exc.value.hint is not None - - -def test_zero_ttl_seconds_is_allowed(gemini_model: str) -> None: - """ttl_seconds=0 is valid (disables caching TTL).""" - cfg = Config(provider="gemini", model=gemini_model, use_mock=True, ttl_seconds=0) - assert cfg.ttl_seconds == 0 - - def test_config_str_and_repr_redact_api_key(gemini_model: str) -> None: """String representations must not leak secrets.""" secret = "top-secret-key" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index fea7aef..a7a078b 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -10,7 +10,8 @@ import pytest import pollux -from pollux.cache import CacheRegistry, compute_cache_key +import pollux.cache +from pollux.cache import CacheHandle, CacheRegistry, compute_cache_key from pollux.config import Config from pollux.errors import APIError, ConfigurationError, PlanningError, SourceError from pollux.options import Options @@ -202,7 +203,7 @@ async def upload_file(self, path: Any, mime_type: str) -> ProviderFileAsset: # async def test_cache_error_attributes_provider_without_call_index( monkeypatch: pytest.MonkeyPatch, ) -> None: - """Cache failures should carry provider and phase but no call index.""" + """Cache failures from create_cache() should carry provider and phase.""" @dataclass class _Provider(FakeProvider): @@ -222,21 +223,18 @@ async def create_cache(self, **kwargs: Any) -> str: provider="gemini", model=CACHE_MODEL, use_mock=True, - enable_caching=True, retry=RetryPolicy(max_attempts=1), ) with 
pytest.raises(APIError) as exc: - await pollux.run_many( - ("Q",), - sources=(Source.from_text("cache me"),), + await pollux.create_cache( + (Source.from_text("cache me"),), config=cfg, ) err = exc.value assert err.provider == "gemini" assert err.phase == "cache" - assert err.call_idx is None # ============================================================================= @@ -285,6 +283,55 @@ async def aclose(self) -> None: assert fake.closed == 1, name +@pytest.mark.asyncio +async def test_create_cache_closes_provider(monkeypatch: pytest.MonkeyPatch) -> None: + """create_cache should close provider resources on success and failure.""" + + @dataclass + class _Provider(FakeProvider): + closed: int = 0 + fail_cache: bool = False + + async def create_cache( + self, + *, + model: str, + parts: list[Any], + system_instruction: str | None = None, + tools: list[dict[str, Any]] | list[Any] | None = None, + ttl_seconds: int = 3600, + ) -> str: + if self.fail_cache: + raise APIError("cache failed", provider="gemini", phase="cache") + return await super().create_cache( + model=model, + parts=parts, + system_instruction=system_instruction, + tools=tools, + ttl_seconds=ttl_seconds, + ) + + async def aclose(self) -> None: + self.closed += 1 + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + for fail_cache in (False, True): + fake = _Provider(fail_cache=fail_cache) + monkeypatch.setattr(pollux, "_get_provider", lambda _config, _fake=fake: _fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + if fail_cache: + with pytest.raises(APIError, match="cache failed"): + await pollux.create_cache((Source.from_text("cache me"),), config=cfg) + else: + handle = await pollux.create_cache( + (Source.from_text("cache me"),), config=cfg + ) + assert isinstance(handle, CacheHandle) + + assert fake.closed == 1 + + # ============================================================================= # Retry Behavior (Boundary) # 
============================================================================= @@ -397,30 +444,74 @@ async def test_source_from_json_is_sent_as_inline_content() -> None: # ============================================================================= +@pytest.mark.asyncio +async def test_cache_single_flight_deduplicates_file_uploads( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """Concurrent create_cache() calls should share uploads via single-flight.""" + gate = asyncio.Event() + entered = asyncio.Event() + + @dataclass + class _SlowCacheProvider(FakeProvider): + async def create_cache(self, **kwargs: Any) -> str: # noqa: ARG002 + self.cache_calls += 1 + entered.set() + await gate.wait() + return "cachedContents/test" + + fake = _SlowCacheProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + file_path = tmp_path / "shared.txt" + file_path.write_text("shared content", encoding="utf-8") + + cfg = Config( + provider="gemini", + model=CACHE_MODEL, + use_mock=True, + retry=RetryPolicy(max_attempts=1), + ) + source = Source.from_file(file_path) + + t1 = asyncio.create_task(pollux.create_cache((source,), config=cfg)) + await entered.wait() + t2 = asyncio.create_task(pollux.create_cache((source,), config=cfg)) + # Small yield to let t2 join the singleflight waiters. 
+ await asyncio.sleep(0) + gate.set() + + results = await asyncio.gather(t1, t2, return_exceptions=True) + assert all(isinstance(r, CacheHandle) for r in results) + assert fake.upload_calls == 1, "concurrent calls should share uploads" + assert fake.cache_calls == 1, "concurrent calls should share cache creation" + + @pytest.mark.asyncio async def test_cache_single_flight_propagates_failure_and_clears_inflight( monkeypatch: pytest.MonkeyPatch, ) -> None: - """If cache creation fails, all waiters should see the error and future calls can recover.""" + """If cache creation fails, concurrent callers see the error; future calls recover.""" fake = GateProvider(kind="cache") monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) - monkeypatch.setattr(pollux, "_registry", CacheRegistry()) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) cfg = Config( provider="gemini", model=CACHE_MODEL, use_mock=True, - enable_caching=True, retry=RetryPolicy(max_attempts=1), ) source = Source.from_text("cache me", identifier="same-id") t1 = asyncio.create_task( - pollux.run_many(("A",), sources=(source,), config=cfg), + pollux.create_cache((source,), config=cfg), ) await fake.started.wait() t2 = asyncio.create_task( - pollux.run_many(("B",), sources=(source,), config=cfg), + pollux.create_cache((source,), config=cfg), ) fake.release.set() @@ -430,13 +521,13 @@ async def test_cache_single_flight_propagates_failure_and_clears_inflight( assert fake.cache_calls == 1 # After the failure, the registry should not be stuck; it should be able to create a cache. - result = await pollux.run_many(("C",), sources=(source,), config=cfg) - assert result["status"] == "ok" + handle = await pollux.create_cache((source,), config=cfg) + assert isinstance(handle, CacheHandle) assert fake.cache_calls == 2 # And after a successful cache, additional calls should not recreate it. 
- result2 = await pollux.run_many(("D",), sources=(source,), config=cfg) - assert result2["status"] == "ok" + handle2 = await pollux.create_cache((source,), config=cfg) + assert isinstance(handle2, CacheHandle) assert fake.cache_calls == 2 @@ -732,46 +823,238 @@ async def delete_file(self, file_id: str) -> None: # noqa: ARG002 @pytest.mark.asyncio -async def test_cached_context_is_not_resent_on_each_call( +async def test_cached_context_rejects_sources( monkeypatch: pytest.MonkeyPatch, ) -> None: - """When cache is active, call payloads should include only prompt-specific parts.""" - - @dataclass - class PartsCaptureProvider(FakeProvider): - received_parts: list[list[Any]] = field(default_factory=list) - cache_names: list[str | None] = field(default_factory=list) + """When cache is active via Options(cache=handle), passing sources raises ConfigurationError.""" + import time - async def generate(self, request: ProviderRequest) -> ProviderResponse: - self.received_parts.append(request.parts) - self.cache_names.append(request.cache_name) - prompt = ( - request.parts[-1] - if request.parts and isinstance(request.parts[-1], str) - else "" - ) - return ProviderResponse(text=f"ok:{prompt}", usage={"total_tokens": 1}) - - fake = PartsCaptureProvider() + fake = FakeProvider() monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) - monkeypatch.setattr(pollux, "_registry", CacheRegistry()) - cfg = Config( provider="gemini", model=CACHE_MODEL, use_mock=True, - enable_caching=True, ) - await pollux.run_many( - prompts=("A", "B"), - sources=(Source.from_text("shared context"),), - config=cfg, + + handle = CacheHandle( + name="cachedContents/test", + model=CACHE_MODEL, + provider="gemini", + expires_at=time.time() + 3600, ) - assert fake.cache_calls == 1 - assert fake.cache_names == ["cachedContents/test", "cachedContents/test"] - assert fake.received_parts == [["A"], ["B"]] + with pytest.raises(ConfigurationError, match="sources cannot be used"): + await 
pollux.run_many( + prompts=("A", "B"), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle), + ) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_requires_persistent_cache_capability( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Passing Options(cache=...) should fail on providers without persistent caching.""" + import time + + fake = FakeProvider( + _capabilities=ProviderCapabilities( + persistent_cache=False, + uploads=True, + structured_outputs=False, + reasoning=False, + deferred_delivery=False, + conversation=False, + ) + ) + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="openai", model=OPENAI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=OPENAI_MODEL, + provider="openai", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="persistent caching"): + await pollux.run_many( + prompts=("Q",), + config=cfg, + options=Options(cache=handle), + ) + + +@pytest.mark.asyncio +async def test_options_cache_rejects_expired_handle( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Expired cache handles must be rejected before any network I/O.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + expired = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() - 1, + ) + + with pytest.raises(ConfigurationError, match="expired"): + await pollux.run("Q", config=cfg, options=Options(cache=expired)) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_rejects_provider_and_model_mismatch( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Cache handles must match the active provider and model.""" + import time + + fake = FakeProvider() + 
monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + bad_provider = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="openai", + expires_at=time.time() + 3600, + ) + bad_model = CacheHandle( + name="cachedContents/test", + model=OPENAI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="provider does not match"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=bad_provider), + ) + with pytest.raises(ConfigurationError, match="model does not match"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=bad_model), + ) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_rejects_system_instruction( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """system_instruction cannot coexist with a cache handle.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="system_instruction cannot be used"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle, system_instruction="Be concise."), + ) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_rejects_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """tools cannot coexist with a cache handle.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = 
Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + tools = [ + { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + } + ] + + with pytest.raises(ConfigurationError, match="tools cannot be used"): + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle, tools=tools), + ) + + assert fake.last_parts is None + + +@pytest.mark.asyncio +async def test_options_cache_rejects_tool_choice( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """tool_choice cannot coexist with a cache handle.""" + import time + + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + handle = CacheHandle( + name="cachedContents/test", + model=GEMINI_MODEL, + provider="gemini", + expires_at=time.time() + 3600, + ) + + with pytest.raises(ConfigurationError, match="tool_choice cannot be used") as exc: + await pollux.run_many( + prompts=("Q",), + sources=(Source.from_text("shared context"),), + config=cfg, + options=Options(cache=handle, tool_choice="required"), + ) + + assert exc.value.hint is not None + assert "Remove tool_choice" in exc.value.hint + assert fake.last_parts is None def test_cache_identity_uses_content_digest_not_identifier_only() -> None: @@ -818,6 +1101,157 @@ def test_cache_identity_includes_system_instruction() -> None: assert concise != verbose +def test_cache_identity_includes_provider() -> None: + """Identical model/content across providers must not share cache keys.""" + source = Source.from_text("shared context") + gemini = compute_cache_key(GEMINI_MODEL, (source,), provider="gemini") + openai = compute_cache_key(GEMINI_MODEL, (source,), 
provider="openai") + + assert gemini != openai + + +def test_cache_identity_includes_api_key() -> None: + """Different API keys for the same provider/model must not share cache keys.""" + source = Source.from_text("shared context") + key_a = compute_cache_key( + GEMINI_MODEL, (source,), provider="gemini", api_key="key-aaa" + ) + key_b = compute_cache_key( + GEMINI_MODEL, (source,), provider="gemini", api_key="key-bbb" + ) + + assert key_a != key_b + + +@pytest.mark.asyncio +async def test_create_cache_returns_handle(monkeypatch: pytest.MonkeyPatch) -> None: + """create_cache() should return a CacheHandle with the expected fields.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + handle = await pollux.create_cache( + (Source.from_text("hello"),), + config=cfg, + ttl_seconds=600, + ) + + assert isinstance(handle, CacheHandle) + assert handle.name == "cachedContents/test" + assert handle.model == CACHE_MODEL + assert handle.provider == "gemini" + assert handle.expires_at > 0 + assert fake.cache_calls == 1 + + +@pytest.mark.asyncio +async def test_create_cache_cache_hit_skips_uploads( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """Repeated create_cache() calls for the same key should not re-upload files.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + file_path = tmp_path / "cache-me.txt" + file_path.write_text("hello cache", encoding="utf-8") + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + first = await pollux.create_cache((Source.from_file(file_path),), config=cfg) + second = await pollux.create_cache((Source.from_file(file_path),), config=cfg) + + assert first.name == second.name + assert fake.cache_calls == 1 + assert 
fake.upload_calls == 1 + + +@pytest.mark.asyncio +async def test_create_cache_deduplicates_file_uploads_within_call( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Any, +) -> None: + """Duplicate file sources in a single create_cache() should upload only once.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + file_path = tmp_path / "dup.txt" + file_path.write_text("same content", encoding="utf-8") + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + src = Source.from_file(file_path) + handle = await pollux.create_cache((src, src), config=cfg) + + assert isinstance(handle, CacheHandle) + assert fake.upload_calls == 1 + + +@pytest.mark.asyncio +async def test_create_cache_rejects_unserializable_tools( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """create_cache() should raise ConfigurationError for non-dict tool items.""" + fake = FakeProvider() + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + monkeypatch.setattr(pollux.cache, "_registry", CacheRegistry()) + + cfg = Config(provider="gemini", model=CACHE_MODEL, use_mock=True) + + class CustomTool: + pass + + with pytest.raises(ConfigurationError, match="must be a dictionary") as exc: + await pollux.create_cache( + (Source.from_text("hello"),), + config=cfg, + tools=[CustomTool()], + ) + + assert exc.value.hint is not None + assert fake.upload_calls == 0 + assert fake.cache_calls == 0 + + +@pytest.mark.asyncio +async def test_create_cache_rejects_unsupported_provider( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """create_cache() should raise ConfigurationError for providers without persistent_cache.""" + fake = FakeProvider( + _capabilities=ProviderCapabilities( + persistent_cache=False, + uploads=True, + structured_outputs=False, + reasoning=False, + ) + ) + monkeypatch.setattr(pollux, "_get_provider", lambda _config: fake) + + cfg = 
Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + with pytest.raises(ConfigurationError, match="persistent caching"): + await pollux.create_cache( + (Source.from_text("hello"),), + config=cfg, + ) + + +@pytest.mark.asyncio +async def test_create_cache_validates_ttl() -> None: + """create_cache() should reject invalid ttl_seconds synchronously (via coroutine).""" + cfg = Config(provider="gemini", model=GEMINI_MODEL, use_mock=True) + + with pytest.raises(ConfigurationError, match="ttl_seconds"): + await pollux.create_cache( + (Source.from_text("hello"),), config=cfg, ttl_seconds=0 + ) + with pytest.raises(ConfigurationError, match="ttl_seconds"): + await pollux.create_cache( + (Source.from_text("hello"),), config=cfg, ttl_seconds=-1 + ) + + @pytest.mark.asyncio async def test_options_response_schema_requires_provider_capability() -> None: """Strict capability checks reject unsupported structured outputs.""" @@ -859,7 +1293,7 @@ class ExampleSchema(BaseModel): fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, @@ -898,7 +1332,7 @@ async def test_delivery_mode_deferred_is_explicitly_not_implemented( """Deferred delivery should fail clearly until backend support lands.""" fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, @@ -931,7 +1365,7 @@ class Paper(BaseModel): class _StructuredProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -986,7 +1420,7 @@ async def test_conversation_options_are_forwarded_when_provider_supports_them( """Conversation options should pass through when provider supports the feature.""" fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + 
persistent_cache=True, uploads=True, structured_outputs=True, reasoning=True, @@ -1015,7 +1449,7 @@ async def test_conversation_requires_single_prompt_per_call( """Conversation continuity is single-turn per API call in v1.1.""" fake = FakeProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -1041,7 +1475,7 @@ async def test_continue_from_requires_conversation_state( """continue_from must include _conversation_state; valid state is forwarded.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -1090,7 +1524,7 @@ async def test_conversation_result_includes_conversation_state( class _ConversationProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1137,7 +1571,7 @@ async def test_conversation_state_preserves_provider_state_from_response( class _ProviderStateConversationProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=True, @@ -1288,7 +1722,7 @@ class Paper(BaseModel): fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=True, reasoning=False, @@ -1350,7 +1784,7 @@ async def test_tool_call_response_populates_conversation_state_without_history( class _ToolCallProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1410,7 +1844,7 @@ async def test_tool_calls_preserved_in_conversation_state( class 
_ToolCallProvider(FakeProvider): _capabilities: ProviderCapabilities = field( default_factory=lambda: ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1452,7 +1886,7 @@ async def test_continue_from_preserves_tool_history_items( """continue_from with tool messages in history passes them to provider.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1501,7 +1935,7 @@ async def test_continue_from_forwards_provider_state_with_history_items( """History item provider_state should be forwarded in request.provider_state.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=True, @@ -1561,7 +1995,7 @@ async def test_continue_tool_mechanics(monkeypatch: pytest.MonkeyPatch) -> None: """continue_tool should neatly append tool results and allow None prompt.""" fake = KwargsCaptureProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, structured_outputs=False, reasoning=False, @@ -1621,7 +2055,7 @@ async def test_reasoning_surfaced_in_result_envelope( """Provider reasoning text should appear in ResultEnvelope.reasoning.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, reasoning=True, ), @@ -1653,7 +2087,7 @@ async def test_reasoning_omitted_when_absent( """ResultEnvelope should not include reasoning key when provider omits it.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, reasoning=True, ), @@ -1676,7 +2110,7 @@ async def test_reasoning_mixed_across_multi_prompt( """Multi-prompt: reasoning=None for calls without thinking content.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - 
caching=True, + persistent_cache=True, uploads=True, reasoning=True, ), @@ -1705,7 +2139,7 @@ async def test_reasoning_tokens_aggregate_in_result_usage( """Pipeline should preserve and sum reasoning_tokens across calls.""" fake = ScriptedProvider( _capabilities=ProviderCapabilities( - caching=True, + persistent_cache=True, uploads=True, reasoning=True, ),