diff --git a/agex/host/base.py b/agex/host/base.py index 1f36ef1..c8207f4 100644 --- a/agex/host/base.py +++ b/agex/host/base.py @@ -12,6 +12,31 @@ from agex.host.dependencies import Dependencies from agex.state import State from agex.state.config import StateConfig + from agex.state.kv import KVStore + + +def apply_init_if_fresh( + state: "State", + kv: "KVStore", + init: "Callable[[], dict[str, Any]] | dict[str, Any] | None", +) -> None: + """Apply init vars if state is fresh (no sentinel). + + Args: + state: The state object to initialize variables on + kv: The underlying KVStore to check for sentinel + init: Callable or dict of init variables (or None to skip) + """ + if init is not None and "__agex_init__" not in kv: + init_vars = init() if callable(init) else init + for key, value in init_vars.items(): + state.set(key, value) + state.set("__agex_init__", True) + # Commit snapshot for versioned state + from agex.state import Versioned + + if isinstance(state, Versioned): + state.snapshot() class Host(ABC): diff --git a/agex/host/local.py b/agex/host/local.py index 2cb329c..559f324 100644 --- a/agex/host/local.py +++ b/agex/host/local.py @@ -85,6 +85,8 @@ def _create_state(self, config: "StateConfig", kv: Any) -> "State": from agex.state import Live, Versioned from agex.state.gc import GCVersioned + from .base import apply_init_if_fresh + if config.type == "versioned": state: "State" = Versioned(store=kv) # Wrap with GC if high_water_bytes is set @@ -94,11 +96,15 @@ def _create_state(self, config: "StateConfig", kv: Any) -> "State": high_water_bytes=config.high_water_bytes, low_water_bytes=config.low_water_bytes, ) - return state elif config.type == "live": - return Live() + state = Live() else: # ephemeral - return Live() + state = Live() + + # Apply init if provided and state is fresh + apply_init_if_fresh(state, kv, config.init) + + return state def execute( self, diff --git a/agex/host/modal.py b/agex/host/modal.py index e829254..71afeea 100644 --- a/agex/host/modal.py +++ b/agex/host/modal.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable -from agex.host.base import Host +from agex.host.base import Host, apply_init_if_fresh from agex.host.local import Local from agex.state import Live, Versioned from agex.state.kv.modal_dict import ModalDict @@ -297,7 +297,9 @@ def sanitize_name(name: str) -> str: cache = Disk(str(cache_dir)) kv = Composite([cache, source, volume]) - return Versioned(store=kv) + state = Versioned(store=kv) + apply_init_if_fresh(state, kv, config.init) + return state else: # Two-tier: Disk → ModalDict (memory storage) @@ -314,7 +316,9 @@ def sanitize_name(name: str) -> str: cache = Disk(str(cache_dir)) kv = Composite([cache, source]) - return Versioned(store=kv) + state = Versioned(store=kv) + apply_init_if_fresh(state, kv, config.init) + return state def execute( self, diff --git a/agex/state/__init__.py b/agex/state/__init__.py index 5149aab..9b8e4c9 100644 --- a/agex/state/__init__.py +++ b/agex/state/__init__.py @@ -1,6 +1,6 @@ """A state management system for tic agents.""" -from typing import Literal, cast +from typing import Any, Callable, Literal, cast from ..agent.events import Event from .config import StateConfig @@ -32,6 +32,7 @@ def connect_state( type: Literal["ephemeral", "versioned", "live"], storage: str | None = None, + init: "Callable[[], dict[str, Any]] | dict[str, Any] | None" = None, **kwargs, ) -> StateConfig: """ @@ -40,6 +41,9 @@ def connect_state( Args: type: State semantics ("ephemeral", "versioned", or "live") storage: Storage backend ("memory" or "disk"). Not required for ephemeral. + init: Callable or dict to initialize state variables on first session creation. + If a callable, it will be invoked when the session is first created. + The returned dict keys become variable names in the agent's namespace. **kwargs: Type and storage-specific arguments Storage-specific kwargs: @@ -67,6 +71,14 @@ def connect_state( path="~/.agex/state", high_water_bytes=100_000_000, ) + + # Versioned with initial variables + connect_state( + type="versioned", + storage="disk", + path="/tmp/agex/tmnt", + init=lambda: {"leo": load_cal("leo.ics"), ...}, + ) """ # Validate storage requirements if type != "ephemeral" and storage is None: @@ -97,6 +109,7 @@ def connect_state( high_water_bytes=kwargs.get("high_water_bytes"), low_water_bytes=kwargs.get("low_water_bytes"), options=options if options else None, + init=init, ) diff --git a/agex/state/config.py b/agex/state/config.py index a9d9720..ad776dd 100644 --- a/agex/state/config.py +++ b/agex/state/config.py @@ -1,7 +1,10 @@ """State configuration for agent state management.""" from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Callable, Literal + +# Type alias for init parameter +InitVars = Callable[[], dict[str, Any]] | dict[str, Any] | None @dataclass @@ -18,6 +21,7 @@ class StateConfig: path: Directory path for disk storage high_water_bytes: Trigger GC when total size exceeds this (versioned only) low_water_bytes: Target size after GC (versioned only, default: 80% of high_water) + init: Callable or dict to initialize state variables on first session creation """ type: Literal["ephemeral", "versioned", "live"] @@ -26,6 +30,7 @@ class StateConfig: high_water_bytes: int | None = None low_water_bytes: int | None = None options: dict[str, Any] | None = None + init: InitVars = None def dump_config(self) -> dict[str, Any]: """Serialize for remote reconstruction.""" diff --git a/docs/api/state.md b/docs/api/state.md index 66eada9..322e3ad 100644 --- a/docs/api/state.md +++ b/docs/api/state.md @@ -11,6 +11,7 @@ state_config = connect_state( type: Literal["versioned", "live"] = "versioned", storage: Literal["memory", "disk"] = "memory", path: str | None = None, # Required for disk storage + init: Callable[[], dict] | dict | None = None, # Initialize state vars ) ``` @@ -21,6 +22,7 @@ state_config = connect_state( | `type` | `str` | `"versioned"` | State type: `"versioned"` (with checkpointing) or `"live"` (in-memory only) | | `storage` | `str` | `"memory"` | Storage backend: `"memory"` or `"disk"` | | `path` | `str \| None` | `None` | Path for disk storage (required when `storage="disk"`) | +| `init` | `Callable \| dict \| None` | `None` | Initialize state variables on first session creation | ## State Types @@ -138,6 +140,41 @@ state = connect_state(type="versioned", storage="disk", path="/var/agex/state") **Use for:** Production, remote execution, long-running workflows. +## State Initialization + +The `init` parameter lets you populate state variables when a session is first created. This is useful for loading data that should be mutable within state (e.g., calendars, datasets, config objects). + +```python +from agex import Agent, connect_state + +def load_initial_data(): + """Called once per new session.""" + return { + "calendar": load_calendar("events.ics"), + "config": {"theme": "dark", "locale": "en"}, + } + +agent = Agent( + primer="You manage my calendar.", + state=connect_state( + type="versioned", + storage="disk", + path="/tmp/agex/calendar", + init=load_initial_data, # Callable - deferred until first session + ), +) +``` + +**How it works:** +1. On first access to a session, `init()` is called (or dict is used directly) +2. Each key-value pair is set in state +3. A sentinel (`__agex_init__`) marks the session as initialized +4. For versioned state, a snapshot is committed +5. Subsequent calls skip init (sentinel detected) + +> [!TIP] +> Use a callable for lazy initialization - it defers loading until the session is actually created, avoiding work at agent definition time. + ## Features of Versioned State ### Automatic Checkpointing diff --git a/tests/agex/state/test_init.py b/tests/agex/state/test_init.py new file mode 100644 index 0000000..699e3c4 --- /dev/null +++ b/tests/agex/state/test_init.py @@ -0,0 +1,133 @@ +"""Tests for connect_state(init=...) parameter.""" + +from agex import connect_state +from agex.host.local import Local + + +class TestStateInit: + """Tests for state initialization via connect_state(init=...).""" + + def test_init_callable_runs_on_fresh_state(self): + """Init callable is invoked when state is fresh.""" + call_count = 0 + + def my_init(): + nonlocal call_count + call_count += 1 + return {"x": 42, "y": "hello"} + + config = connect_state(type="versioned", storage="memory", init=my_init) + host = Local() + state = host.resolve_state(config, "test_session") + + assert call_count == 1 + assert state.get("x") == 42 + assert state.get("y") == "hello" + + def test_init_dict_applied_on_fresh_state(self): + """Init dict is applied when state is fresh.""" + config = connect_state( + type="versioned", storage="memory", init={"a": 1, "b": [1, 2, 3]} + ) + host = Local() + state = host.resolve_state(config, "test_session") + + assert state.get("a") == 1 + assert state.get("b") == [1, 2, 3] + + def test_init_not_called_on_existing_state(self): + """Init is not called when state already has sentinel.""" + call_count = 0 + + def my_init(): + nonlocal call_count + call_count += 1 + return {"x": call_count} + + config = connect_state(type="versioned", storage="memory", init=my_init) + host = Local() + + # First call - should run init + state1 = host.resolve_state(config, "test_session") + assert call_count == 1 + assert state1.get("x") == 1 + + # Second call - should NOT run init (reuses cached state with sentinel) + state2 = host.resolve_state(config, "test_session") + assert call_count == 1 # Still 1, not called again + assert state2.get("x") == 1 + + def test_init_sets_sentinel(self): + """Init sets the __agex_init__ sentinel.""" + config = connect_state(type="versioned", storage="memory", init={"x": 1}) + host = Local() + state = host.resolve_state(config, "test_session") + + assert state.get("__agex_init__") is True + + def test_init_none_does_nothing(self): + """No init when init=None (default).""" + config = connect_state(type="versioned", storage="memory") + host = Local() + state = host.resolve_state(config, "test_session") + + assert state.get("__agex_init__") is None + assert "__agex_init__" not in state + + def test_init_commits_snapshot(self): + """Init vars are persisted in a snapshot.""" + config = connect_state(type="versioned", storage="memory", init={"x": 42}) + host = Local() + state = host.resolve_state(config, "test_session") + + # Check that a snapshot was taken (history has commits) + history = list(state.history()) + assert len(history) >= 1 # At least the init snapshot + # Initial commit should contain x + assert state.get("x") == 42 + + def test_init_different_sessions_independent(self): + """Different sessions each get their own init.""" + call_count = 0 + + def my_init(): + nonlocal call_count + call_count += 1 + return {"session_num": call_count} + + config = connect_state(type="versioned", storage="memory", init=my_init) + host = Local() + + state1 = host.resolve_state(config, "session_a") + state2 = host.resolve_state(config, "session_b") + + assert call_count == 2 # Called once per session + assert state1.get("session_num") == 1 + assert state2.get("session_num") == 2 + + def test_init_with_disk_storage(self, tmp_path): + """Init works with disk storage.""" + call_count = 0 + + def my_init(): + nonlocal call_count + call_count += 1 + return {"x": 42} + + config = connect_state( + type="versioned", storage="disk", path=str(tmp_path), init=my_init + ) + host = Local() + + # First call - init runs + state1 = host.resolve_state(config, "test_session") + assert call_count == 1 + assert state1.get("x") == 42 + + # Create new host (simulates restart) - init should NOT run + host2 = Local() + state2 = host2.resolve_state(config, "test_session") + # Note: call_count check depends on whether we're reloading from disk + # The sentinel should prevent re-init + assert state2.get("x") == 42 + assert state2.get("__agex_init__") is True