Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions agex/host/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,31 @@
from agex.host.dependencies import Dependencies
from agex.state import State
from agex.state.config import StateConfig
from agex.state.kv import KVStore


def apply_init_if_fresh(
    state: "State",
    kv: "KVStore",
    init: "Callable[[], dict[str, Any]] | dict[str, Any] | None",
) -> None:
    """Seed ``state`` with init variables unless it was already initialized.

    A session counts as initialized when the ``__agex_init__`` sentinel is
    visible either in the raw ``kv`` store or through ``state`` itself. When
    neither shows it, every key/value pair from ``init`` is written with
    ``state.set``, the sentinel is set to ``True``, and a snapshot is taken
    if ``state`` is a ``Versioned`` instance.

    Args:
        state: The state object to initialize variables on.
        kv: The underlying KVStore, checked for the sentinel.
        init: Callable returning a dict, a dict of init variables, or
            ``None`` to skip initialization entirely. A callable is only
            invoked when initialization actually happens.
    """
    sentinel = "__agex_init__"

    if init is None:
        return

    # The sentinel is written via ``state.set`` below, so it must also be
    # looked up via ``state``. Checking only the raw store can miss it.
    # NOTE(review): this assumes a state layer may not keep keys verbatim in
    # ``kv`` -- confirm against the Versioned implementation. If it did miss,
    # init would run again and overwrite the session's current values.
    if sentinel in kv or state.get(sentinel) is not None:
        return

    init_vars = init() if callable(init) else init
    for key, value in init_vars.items():
        state.set(key, value)
    state.set(sentinel, True)

    # Commit a snapshot so the seeded values are recorded for versioned state.
    from agex.state import Versioned

    if isinstance(state, Versioned):
        state.snapshot()


class Host(ABC):
Expand Down
12 changes: 9 additions & 3 deletions agex/host/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def _create_state(self, config: "StateConfig", kv: Any) -> "State":
from agex.state import Live, Versioned
from agex.state.gc import GCVersioned

from .base import apply_init_if_fresh

if config.type == "versioned":
state: "State" = Versioned(store=kv)
# Wrap with GC if high_water_bytes is set
Expand All @@ -94,11 +96,15 @@ def _create_state(self, config: "StateConfig", kv: Any) -> "State":
high_water_bytes=config.high_water_bytes,
low_water_bytes=config.low_water_bytes,
)
return state
elif config.type == "live":
return Live()
state = Live()
else: # ephemeral
return Live()
state = Live()

# Apply init if provided and state is fresh
apply_init_if_fresh(state, kv, config.init)

return state

def execute(
self,
Expand Down
10 changes: 7 additions & 3 deletions agex/host/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable

from agex.host.base import Host
from agex.host.base import Host, apply_init_if_fresh
from agex.host.local import Local
from agex.state import Live, Versioned
from agex.state.kv.modal_dict import ModalDict
Expand Down Expand Up @@ -297,7 +297,9 @@ def sanitize_name(name: str) -> str:

cache = Disk(str(cache_dir))
kv = Composite([cache, source, volume])
return Versioned(store=kv)
state = Versioned(store=kv)
apply_init_if_fresh(state, kv, config.init)
return state

else:
# Two-tier: Disk → ModalDict (memory storage)
Expand All @@ -314,7 +316,9 @@ def sanitize_name(name: str) -> str:

cache = Disk(str(cache_dir))
kv = Composite([cache, source])
return Versioned(store=kv)
state = Versioned(store=kv)
apply_init_if_fresh(state, kv, config.init)
return state

def execute(
self,
Expand Down
15 changes: 14 additions & 1 deletion agex/state/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A state management system for tic agents."""

from typing import Literal, cast
from typing import Any, Callable, Literal, cast

from ..agent.events import Event
from .config import StateConfig
Expand Down Expand Up @@ -32,6 +32,7 @@
def connect_state(
type: Literal["ephemeral", "versioned", "live"],
storage: str | None = None,
init: "Callable[[], dict[str, Any]] | dict[str, Any] | None" = None,
**kwargs,
) -> StateConfig:
"""
Expand All @@ -40,6 +41,9 @@ def connect_state(
Args:
type: State semantics ("ephemeral", "versioned", or "live")
storage: Storage backend ("memory" or "disk"). Not required for ephemeral.
init: Callable or dict to initialize state variables on first session creation.
If a callable, it will be invoked when the session is first created.
The returned dict keys become variable names in the agent's namespace.
**kwargs: Type and storage-specific arguments

Storage-specific kwargs:
Expand Down Expand Up @@ -67,6 +71,14 @@ def connect_state(
path="~/.agex/state",
high_water_bytes=100_000_000,
)

# Versioned with initial variables
connect_state(
type="versioned",
storage="disk",
path="/tmp/agex/tmnt",
init=lambda: {"leo": load_cal("leo.ics"), ...},
)
"""
# Validate storage requirements
if type != "ephemeral" and storage is None:
Expand Down Expand Up @@ -97,6 +109,7 @@ def connect_state(
high_water_bytes=kwargs.get("high_water_bytes"),
low_water_bytes=kwargs.get("low_water_bytes"),
options=options if options else None,
init=init,
)


Expand Down
7 changes: 6 additions & 1 deletion agex/state/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""State configuration for agent state management."""

from dataclasses import dataclass
from typing import Any, Literal
from typing import Any, Callable, Literal

# Type alias for init parameter
InitVars = Callable[[], dict[str, Any]] | dict[str, Any] | None


@dataclass
Expand All @@ -18,6 +21,7 @@ class StateConfig:
path: Directory path for disk storage
high_water_bytes: Trigger GC when total size exceeds this (versioned only)
low_water_bytes: Target size after GC (versioned only, default: 80% of high_water)
init: Callable or dict to initialize state variables on first session creation
"""

type: Literal["ephemeral", "versioned", "live"]
Expand All @@ -26,6 +30,7 @@ class StateConfig:
high_water_bytes: int | None = None
low_water_bytes: int | None = None
options: dict[str, Any] | None = None
init: InitVars = None

def dump_config(self) -> dict[str, Any]:
"""Serialize for remote reconstruction."""
Expand Down
37 changes: 37 additions & 0 deletions docs/api/state.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ state_config = connect_state(
type: Literal["versioned", "live"] = "versioned",
storage: Literal["memory", "disk"] = "memory",
path: str | None = None, # Required for disk storage
init: Callable[[], dict] | dict | None = None, # Initialize state vars
)
```

Expand All @@ -21,6 +22,7 @@ state_config = connect_state(
| `type` | `str` | `"versioned"` | State type: `"versioned"` (with checkpointing) or `"live"` (in-memory only) |
| `storage` | `str` | `"memory"` | Storage backend: `"memory"` or `"disk"` |
| `path` | `str \| None` | `None` | Path for disk storage (required when `storage="disk"`) |
| `init` | `Callable \| dict \| None` | `None` | Initialize state variables on first session creation |

## State Types

Expand Down Expand Up @@ -138,6 +140,41 @@ state = connect_state(type="versioned", storage="disk", path="/var/agex/state")

**Use for:** Production, remote execution, long-running workflows.

## State Initialization

The `init` parameter lets you populate state variables when a session is first created. This is useful for loading data that should be mutable within state (e.g., calendars, datasets, config objects).

```python
from agex import Agent, connect_state

def load_initial_data():
    """Called once per new session."""
    return {
        "calendar": load_calendar("events.ics"),
        "config": {"theme": "dark", "locale": "en"},
    }

agent = Agent(
    primer="You manage my calendar.",
    state=connect_state(
        type="versioned",
        storage="disk",
        path="/tmp/agex/calendar",
        init=load_initial_data,  # Callable - deferred until first session
    ),
)
```

**How it works:**
1. On first access to a session, `init()` is called (or, if `init` is a dict, the dict is used directly)
2. Each key-value pair is set in state
3. A sentinel (`__agex_init__`) marks the session as initialized
4. For versioned state, a snapshot is committed
5. Subsequent calls skip init (sentinel detected)

> [!TIP]
> Use a callable for lazy initialization - it defers loading until the session is actually created, avoiding work at agent definition time.

## Features of Versioned State

### Automatic Checkpointing
Expand Down
133 changes: 133 additions & 0 deletions tests/agex/state/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Tests for connect_state(init=...) parameter."""

from agex import connect_state
from agex.host.local import Local


class TestStateInit:
    """Exercise the ``init`` argument of ``connect_state`` through a ``Local`` host."""

    @staticmethod
    def _memory_config(init=None):
        """Build a versioned, in-memory state config carrying ``init``."""
        return connect_state(type="versioned", storage="memory", init=init)

    def test_init_callable_runs_on_fresh_state(self):
        """A callable init runs once and its values land in a fresh state."""
        invocations = []

        def seed():
            invocations.append("call")
            return {"x": 42, "y": "hello"}

        cfg = self._memory_config(seed)
        local_host = Local()
        session = local_host.resolve_state(cfg, "test_session")

        assert len(invocations) == 1
        assert session.get("x") == 42
        assert session.get("y") == "hello"

    def test_init_dict_applied_on_fresh_state(self):
        """A plain dict init is copied into a fresh state."""
        cfg = self._memory_config({"a": 1, "b": [1, 2, 3]})
        local_host = Local()
        session = local_host.resolve_state(cfg, "test_session")

        assert session.get("a") == 1
        assert session.get("b") == [1, 2, 3]

    def test_init_not_called_on_existing_state(self):
        """Resolving the same session twice must not run init a second time."""
        invocations = []

        def seed():
            invocations.append("call")
            return {"x": len(invocations)}

        cfg = self._memory_config(seed)
        local_host = Local()

        # First resolution: init is expected to run.
        first = local_host.resolve_state(cfg, "test_session")
        assert len(invocations) == 1
        assert first.get("x") == 1

        # Second resolution: the sentinel-bearing state is reused, no new call.
        second = local_host.resolve_state(cfg, "test_session")
        assert len(invocations) == 1
        assert second.get("x") == 1

    def test_init_sets_sentinel(self):
        """Running init leaves the ``__agex_init__`` sentinel set to True."""
        cfg = self._memory_config({"x": 1})
        local_host = Local()
        session = local_host.resolve_state(cfg, "test_session")

        assert session.get("__agex_init__") is True

    def test_init_none_does_nothing(self):
        """With the default ``init=None`` no sentinel is written."""
        cfg = self._memory_config()
        local_host = Local()
        session = local_host.resolve_state(cfg, "test_session")

        assert session.get("__agex_init__") is None
        assert "__agex_init__" not in session

    def test_init_commits_snapshot(self):
        """Init values are committed, so history holds at least one entry."""
        cfg = self._memory_config({"x": 42})
        local_host = Local()
        session = local_host.resolve_state(cfg, "test_session")

        # At least the init snapshot must be present in the history.
        commits = list(session.history())
        assert len(commits) >= 1
        # The seeded value is readable from the state.
        assert session.get("x") == 42

    def test_init_different_sessions_independent(self):
        """Every distinct session triggers its own init call."""
        invocations = []

        def seed():
            invocations.append("call")
            return {"session_num": len(invocations)}

        cfg = self._memory_config(seed)
        local_host = Local()

        session_a = local_host.resolve_state(cfg, "session_a")
        session_b = local_host.resolve_state(cfg, "session_b")

        assert len(invocations) == 2  # one call per session
        assert session_a.get("session_num") == 1
        assert session_b.get("session_num") == 2

    def test_init_with_disk_storage(self, tmp_path):
        """Init values and sentinel survive a host restart on disk storage."""
        invocations = []

        def seed():
            invocations.append("call")
            return {"x": 42}

        cfg = connect_state(
            type="versioned", storage="disk", path=str(tmp_path), init=seed
        )
        original_host = Local()

        # First resolution: init runs and seeds the state.
        before_restart = original_host.resolve_state(cfg, "test_session")
        assert len(invocations) == 1
        assert before_restart.get("x") == 42

        # A brand-new host stands in for a process restart.
        restarted_host = Local()
        after_restart = restarted_host.resolve_state(cfg, "test_session")
        # Note: the invocation count is deliberately not asserted here; it
        # depends on whether the state is reloaded from disk. Only the
        # seeded value and the sentinel are checked.
        assert after_restart.get("x") == 42
        assert after_restart.get("__agex_init__") is True