From 7a8bdc4089417404c9528b4c584c831a5317473d Mon Sep 17 00:00:00 2001 From: Adam Ashenfelter Date: Mon, 29 Dec 2025 19:02:19 -0800 Subject: [PATCH 1/2] feat: commit updates on task result events --- agex/agent/loop/async_loop.py | 38 +++++++++++++++++++++++++++++++++-- agex/agent/loop/common.py | 2 ++ agex/agent/loop/sync_loop.py | 38 +++++++++++++++++++++++++++++++++-- agex/state/__init__.py | 3 ++- agex/state/versioned.py | 12 +++++++++-- docs/api/events.md | 4 +++- 6 files changed, 89 insertions(+), 8 deletions(-) diff --git a/agex/agent/loop/async_loop.py b/agex/agent/loop/async_loop.py index b507363..121dff6 100644 --- a/agex/agent/loop/async_loop.py +++ b/agex/agent/loop/async_loop.py @@ -49,8 +49,9 @@ create_task_start_event, create_unsaved_warning, events, - get_events_from_log, # State helpers + get_commit_hash, + get_events_from_log, initialize_exec_state, is_live_root, yield_new_events, @@ -183,12 +184,17 @@ def setup_on_event(event): for iteration in range(self.max_iterations): # Check for cancellation at the start of each iteration if check_cancellation(task_name, versioned_state, exec_state): + # Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + # Record CancelledEvent in the log FIRST cancelled_event = CancelledEvent( agent_name=self.name, task_name=task_name, iterations_completed=iteration, ) + cancelled_event.commit_hash = next_commit add_event_to_log(exec_state, cancelled_event, on_event=None) if on_event: res = call_sync_or_async(on_event, cancelled_event) @@ -198,7 +204,7 @@ def setup_on_event(event): # Snapshot AFTER adding the event so it's included if versioned_state is not None: - versioned_state.snapshot() + versioned_state.snapshot(commit_hash=next_commit) raise TaskCancelled( message=f"Task '{task_name}' was cancelled", @@ -256,13 +262,23 @@ def setup_on_event(event): yield event events_yielded = len(events(exec_state)) + # 
Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + success_event = create_success_event(self.name, task_signal.result) + success_event.commit_hash = next_commit add_event_to_log(exec_state, success_event, on_event=None) if on_event: res = call_sync_or_async(on_event, success_event) if inspect.isawaitable(res): await res yield success_event + + # Snapshot with the pre-generated hash so event.commit_hash matches + if versioned_state is not None: + versioned_state.snapshot(commit_hash=next_commit) + return except TaskContinue: @@ -276,7 +292,12 @@ def setup_on_event(event): yield event events_yielded = len(events(exec_state)) + # Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + clarify_event = create_clarify_event(self.name, task_clarify.message) + clarify_event.commit_hash = next_commit add_event_to_log(exec_state, clarify_event, on_event=None) if on_event: res = call_sync_or_async(on_event, clarify_event) @@ -284,6 +305,10 @@ def setup_on_event(event): await res yield clarify_event + # Snapshot with the pre-generated hash so event.commit_hash matches + if versioned_state is not None: + versioned_state.snapshot(commit_hash=next_commit) + if isinstance(state, Namespaced): raise EvalError( f"Sub-agent needs clarification: {task_clarify.message}", None @@ -296,7 +321,12 @@ def setup_on_event(event): yield event events_yielded = len(events(exec_state)) + # Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + fail_event = create_fail_event(self.name, task_fail.message) + fail_event.commit_hash = next_commit add_event_to_log(exec_state, fail_event, on_event=None) if on_event: res = call_sync_or_async(on_event, fail_event) @@ -304,6 +334,10 @@ def 
setup_on_event(event): await res yield fail_event + # Snapshot with the pre-generated hash so event.commit_hash matches + if versioned_state is not None: + versioned_state.snapshot(commit_hash=next_commit) + if isinstance(state, Namespaced): raise EvalError(f"Sub-agent failed: {task_fail.message}", None) else: diff --git a/agex/agent/loop/common.py b/agex/agent/loop/common.py index 75a1cbd..3b782c2 100644 --- a/agex/agent/loop/common.py +++ b/agex/agent/loop/common.py @@ -41,6 +41,7 @@ Namespaced, Versioned, events, + get_commit_hash, is_live_root, ) from agex.state.log import add_event_to_log, get_events_from_log @@ -60,6 +61,7 @@ "create_guidance_output", "create_unsaved_warning", # State helpers + "get_commit_hash", "initialize_exec_state", "check_for_task_call", "strip_namespace_prefix", diff --git a/agex/agent/loop/sync_loop.py b/agex/agent/loop/sync_loop.py index f7043c1..d1b25ef 100644 --- a/agex/agent/loop/sync_loop.py +++ b/agex/agent/loop/sync_loop.py @@ -43,8 +43,9 @@ create_task_start_event, create_unsaved_warning, events, - get_events_from_log, # State helpers + get_commit_hash, + get_events_from_log, initialize_exec_state, is_live_root, yield_new_events, @@ -143,18 +144,23 @@ def setup_on_event(event): for iteration in range(self.max_iterations): # Check for cancellation at the start of each iteration if check_cancellation(task_name, versioned_state, exec_state): + # Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + # Record CancelledEvent in the log FIRST cancelled_event = CancelledEvent( agent_name=self.name, task_name=task_name, iterations_completed=iteration, ) + cancelled_event.commit_hash = next_commit add_event_to_log(exec_state, cancelled_event, on_event=on_event) yield cancelled_event # Snapshot AFTER adding the event so it's included if versioned_state is not None: - versioned_state.snapshot() + 
versioned_state.snapshot(commit_hash=next_commit) raise TaskCancelled( message=f"Task '{task_name}' was cancelled", @@ -201,9 +207,19 @@ def setup_on_event(event): yield event events_yielded = len(events(exec_state)) + # Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + success_event = create_success_event(self.name, task_signal.result) + success_event.commit_hash = next_commit add_event_to_log(exec_state, success_event, on_event=on_event) yield success_event + + # Snapshot with the pre-generated hash so event.commit_hash matches + if versioned_state is not None: + versioned_state.snapshot(commit_hash=next_commit) + return task_signal.result except TaskContinue: @@ -217,10 +233,19 @@ def setup_on_event(event): yield event events_yielded = len(events(exec_state)) + # Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + clarify_event = create_clarify_event(self.name, task_clarify.message) + clarify_event.commit_hash = next_commit add_event_to_log(exec_state, clarify_event, on_event=on_event) yield clarify_event + # Snapshot with the pre-generated hash so event.commit_hash matches + if versioned_state is not None: + versioned_state.snapshot(commit_hash=next_commit) + if isinstance(state, Namespaced): raise EvalError( f"Sub-agent needs clarification: {task_clarify.message}", None @@ -233,10 +258,19 @@ def setup_on_event(event): yield event events_yielded = len(events(exec_state)) + # Pre-generate commit hash so the terminal event can reference + # the commit that will include it + next_commit = get_commit_hash() if versioned_state else None + fail_event = create_fail_event(self.name, task_fail.message) + fail_event.commit_hash = next_commit add_event_to_log(exec_state, fail_event, on_event=on_event) yield fail_event + # Snapshot with the pre-generated 
hash so event.commit_hash matches + if versioned_state is not None: + versioned_state.snapshot(commit_hash=next_commit) + if isinstance(state, Namespaced): raise EvalError(f"Sub-agent failed: {task_fail.message}", None) else: diff --git a/agex/state/__init__.py b/agex/state/__init__.py index b483081..5149aab 100644 --- a/agex/state/__init__.py +++ b/agex/state/__init__.py @@ -10,12 +10,13 @@ from .live import Live from .namespaced import Namespaced from .scoped import Scoped -from .versioned import ConcurrencyError, Versioned +from .versioned import ConcurrencyError, Versioned, get_commit_hash __all__ = [ "State", "StateConfig", "is_live_root", + "get_commit_hash", "Live", "KVStore", "Namespaced", diff --git a/agex/state/versioned.py b/agex/state/versioned.py index 32d2420..2d7b2d7 100644 --- a/agex/state/versioned.py +++ b/agex/state/versioned.py @@ -354,7 +354,15 @@ def _detect_mutations(self) -> tuple[dict[str, bytes], list[str]]: return mutations, unsavable_keys - def snapshot(self) -> SnapshotResult: + def snapshot(self, commit_hash: str | None = None) -> SnapshotResult: + """Create a new commit with the current changes. + + Args: + commit_hash: Optional pre-generated commit hash. If provided, this hash + will be used for the new commit. This is useful when the commit hash + needs to be known before the snapshot is taken (e.g., for stamping + terminal events with their post-snapshot commit). 
+ """ # First, detect any mutations in accessed objects mutations, unsavable_keys = self._detect_mutations() unsaved_keys = list(unsavable_keys) @@ -364,7 +372,7 @@ def snapshot(self) -> SnapshotResult: self.accessed_objects.clear() # Clear tracking return SnapshotResult(self.current_commit, unsaved_keys) - new_hash = _get_commit_hash() + new_hash = commit_hash or _get_commit_hash() diffs = {} new_commit_keys = {} new_meta: dict[str, MetaEntry] = {} diff --git a/docs/api/events.md b/docs/api/events.md index d6517f7..eec2522 100644 --- a/docs/api/events.md +++ b/docs/api/events.md @@ -161,7 +161,9 @@ All events share these common properties from `BaseEvent`: - **`timestamp`**: `datetime` - UTC timestamp when the event occurred. - **`agent_name`**: `str` - Name of the agent that generated the event. - **`full_namespace`**: `str` - The agent's namespace path. Equals `agent_name` for the agent that owns the state. -- **`commit_hash`**: `str | None` - If using `Versioned` state, the commit hash of the agent's state before this event occurred. See [Inspecting Historical State](state.md#inspecting-historical-state) for how to use this for debugging. +- **`commit_hash`**: `str | None` - The commit hash linking this event to `Versioned` state. Only populated when using `Versioned` state (see [State Management](state.md)); `None` for `Live` or ephemeral state. + - **For action events** (`ActionEvent`, `OutputEvent`, etc.): The commit hash *before* the action—useful for inspecting what the agent saw when it made a decision. + - **For task result events** (`SuccessEvent`, `FailEvent`, `ClarifyEvent`, `CancelledEvent`): The commit hash *after* the result is recorded—enabling "revert to this outcome" workflows via `state.checkout(event.commit_hash)`. - **`source`**: `Literal["setup", "main"]` - The execution phase that generated the event. Defaults to `"main"`. Events generated by the `setup` parameter of `@agent.task` are tagged with `"setup"`. 
- **`full_detail_tokens`**: `int` - Cached token estimate for full-detail rendering. Computed automatically at event creation. - **`low_detail_tokens`**: `int` - Cached token estimate for low-detail rendering (used when event age triggers compression). Typically 25-50% of `full_detail_tokens`. Computed automatically for `TaskStartEvent`, `OutputEvent`, and `SuccessEvent`; equals `full_detail_tokens` for other event types. From 3f959456f5da1f088fe1f483d047aedc2cf43307 Mon Sep 17 00:00:00 2001 From: Adam Ashenfelter Date: Tue, 30 Dec 2025 11:01:14 -0800 Subject: [PATCH 2/2] feat: Versioned.initial_commit & revert_to --- agex/state/versioned.py | 67 +++++++++++++- docs/api/events.md | 2 +- docs/api/state.md | 21 ++++- tests/agex/state/test_versioned.py | 137 +++++++++++++++++++++++++++++ 4 files changed, 222 insertions(+), 5 deletions(-) diff --git a/agex/state/versioned.py b/agex/state/versioned.py index 2d7b2d7..8fd75b1 100644 --- a/agex/state/versioned.py +++ b/agex/state/versioned.py @@ -308,6 +308,15 @@ def history(self, commit_hash: str | None = None) -> Iterable[str]: else: current_hash = None + @property + def initial_commit(self) -> str: + """ + Return the hash of the initial (root) commit. + Useful for reverting state completely to the beginning. + """ + # History yields newest-first, so the last item is the initial commit + return list(self.history())[-1] + def _detect_mutations(self) -> tuple[dict[str, bytes], list[str]]: """Detect mutations in accessed objects and auto-save them. @@ -545,12 +554,66 @@ def checkout(self, commit_hash: str) -> "Versioned | None": Args: commit_hash: The commit to checkout """ - # First, validate that the commit is in our history. 
- if commit_hash not in list(self.history()): + # Validate that the commit exists + if self.long_term.get(COMMIT_KEYSET % commit_hash) is None: return None return Versioned(self.long_term, commit_hash=commit_hash) + def revert_to(self, commit_hash: str) -> bool: + """ + Reset HEAD to a previous commit, discarding later history. + + This moves HEAD backward to the specified commit, making all + commits after it orphaned. Orphaned commits can be cleaned up + by GCVersioned. + + Args: + commit_hash: The commit to revert to (must be in history) + + Returns: + True if revert succeeded, False if commit not in history + """ + # Validate that the commit exists + if self.long_term.get(COMMIT_KEYSET % commit_hash) is None: + return False + + # Update HEAD to point to this commit + self.long_term.set(HEAD_COMMIT, pickle.dumps(commit_hash)) + + # Reset local state to match (same logic as reset()) + self.current_commit = commit_hash + self.base_commit = commit_hash + + # Reload commit keys + commit_keyset_bytes = self.long_term.get(COMMIT_KEYSET % commit_hash) + if commit_keyset_bytes is not None: + self.commit_keys = pickle.loads(commit_keyset_bytes) + else: + self.commit_keys = {} + + # Reload metadata + meta_bytes = self.long_term.get(META_KEY % commit_hash) + if meta_bytes is not None: + try: + self.meta = pickle.loads(meta_bytes) + except Exception: + self.meta = {} + else: + self.meta = {} + + # Reset working state + self.live = Live() + self.removed = set() + self.accessed_objects.clear() + self._touch_counter = ( + max((entry.last_touch for entry in self.meta.values()), default=0) + if self.meta + else 0 + ) + + return True + def diffs(self, commit_hash: str | None = None) -> dict[str, Any]: """ Returns the state changes for a given commit. 
diff --git a/docs/api/events.md b/docs/api/events.md index eec2522..799522a 100644 --- a/docs/api/events.md +++ b/docs/api/events.md @@ -163,7 +163,7 @@ All events share these common properties from `BaseEvent`: - **`full_namespace`**: `str` - The agent's namespace path. Equals `agent_name` for the agent that owns the state. - **`commit_hash`**: `str | None` - The commit hash linking this event to `Versioned` state. Only populated when using `Versioned` state (see [State Management](state.md)); `None` for `Live` or ephemeral state. - **For action events** (`ActionEvent`, `OutputEvent`, etc.): The commit hash *before* the action—useful for inspecting what the agent saw when it made a decision. - - **For task result events** (`SuccessEvent`, `FailEvent`, `ClarifyEvent`, `CancelledEvent`): The commit hash *after* the result is recorded—enabling "revert to this outcome" workflows via `state.checkout(event.commit_hash)`. + - **For task result events** (`SuccessEvent`, `FailEvent`, `ClarifyEvent`, `CancelledEvent`): The commit hash *after* the result is recorded—enabling "revert to this outcome" workflows via `state.revert_to(event.commit_hash)`. - **`source`**: `Literal["setup", "main"]` - The execution phase that generated the event. Defaults to `"main"`. Events generated by the `setup` parameter of `@agent.task` are tagged with `"setup"`. - **`full_detail_tokens`**: `int` - Cached token estimate for full-detail rendering. Computed automatically at event creation. - **`low_detail_tokens`**: `int` - Cached token estimate for low-detail rendering (used when event age triggers compression). Typically 25-50% of `full_detail_tokens`. Computed automatically for `TaskStartEvent`, `OutputEvent`, and `SuccessEvent`; equals `full_detail_tokens` for other event types. 
diff --git a/docs/api/state.md b/docs/api/state.md index f95744a..66eada9 100644 --- a/docs/api/state.md +++ b/docs/api/state.md @@ -142,7 +142,10 @@ state = connect_state(type="versioned", storage="disk", path="/var/agex/state") ### Automatic Checkpointing -Every agent iteration creates a snapshot. You can inspect or rollback to any point: +Every agent iteration creates a snapshot. You can inspect historical states or revert the agent to a previous point in time. + +**Inspecting History (Read-Only)** +Use `checkout()` to get a read-only view of the state at a specific commit: ```python from agex import events, view @@ -150,12 +153,26 @@ from agex import events, view # Get events after a task run all_events = events(resolved_state) -# Each event has a commit_hash for time-travel debugging +# Inspect state as it was when an action occurred action = all_events[0] historical = resolved_state.checkout(action.commit_hash) print(view(historical, focus="full")) ``` +**Reverting State (Destructive)** +Use `revert_to()` to move the agent's HEAD back to a previous commit. This orphans all subsequent commits (which can be cleaned up by GC). + +```python +# Revert to the state after a specific successful task +success_event = all_events[-1] +resolved_state.revert_to(success_event.commit_hash) + +# The agent continues from this point as if later actions never happened +``` + +> [!TIP] +> Use `state.initial_commit` to get the hash of the very first commit (the empty root state). This is useful for resetting the agent completely. 
+ ### Concurrent Task Handling Versioned state handles concurrent execution safely via the `on_conflict` parameter on tasks: diff --git a/tests/agex/state/test_versioned.py b/tests/agex/state/test_versioned.py index 623050b..34f813e 100644 --- a/tests/agex/state/test_versioned.py +++ b/tests/agex/state/test_versioned.py @@ -364,3 +364,140 @@ def test_reset_reloads_from_head(): assert state2.get("b") is None assert state2.current_commit == state1.current_commit assert state2.base_commit == state1.current_commit + + +def test_revert_to_moves_head_to_earlier_commit(): + """Test that revert_to moves HEAD to an earlier commit.""" + import pickle + + from agex.state.versioned import HEAD_COMMIT + + store = kv.Memory() + state = Versioned(store) + + # Create some history + state.set("a", 1) + state.snapshot() + state.merge() + commit1 = state.current_commit + + state.set("a", 2) + state.snapshot() + state.merge() + + state.set("a", 3) + state.snapshot() + state.merge() + commit3 = state.current_commit + + # Verify we're at commit3 + assert state.get("a") == 3 + assert state.current_commit == commit3 + + # Revert to commit1 + result = state.revert_to(commit1) + assert result is True + + # Verify state is now at commit1 + assert state.get("a") == 1 + assert state.current_commit == commit1 + assert state.base_commit == commit1 + + # HEAD should be updated in the store + head = pickle.loads(store.get(HEAD_COMMIT)) + assert head == commit1 + + +def test_revert_to_orphans_later_commits(): + """Test that revert_to leaves later commits as orphans.""" + store = kv.Memory() + state = Versioned(store) + + # Create some history + state.set("a", 1) + state.snapshot() + state.merge() + commit1 = state.current_commit + + state.set("a", 2) + state.snapshot() + state.merge() + commit2 = state.current_commit + + # Revert to commit1 + state.revert_to(commit1) + + # History should only go back to commit1, not to commit2 + history = list(state.history()) + assert commit1 in history + assert 
commit2 not in history # commit2 is now orphaned + + +def test_revert_to_returns_false_for_invalid_commit(): + """Test that revert_to returns False for commits not in history.""" + store = kv.Memory() + state = Versioned(store) + + state.set("a", 1) + state.snapshot() + state.merge() + + # Try to revert to a non-existent commit + result = state.revert_to("nonexistent_hash") + assert result is False + + # State should be unchanged + assert state.get("a") == 1 + + +def test_revert_to_clears_local_changes(): + """Test that revert_to clears any uncommitted local changes.""" + store = kv.Memory() + state = Versioned(store) + + state.set("a", 1) + state.snapshot() + state.merge() + commit1 = state.current_commit + + # Make local changes without committing + state.set("b", 2) + + # Revert to commit1 + state.revert_to(commit1) + + # Local changes should be gone + assert state.get("a") == 1 + assert state.get("b") is None + assert "b" not in list(state.keys()) + + +def test_versioned_initial_commit(): + """Test that initial_commit returns the root commit hash.""" + store = kv.Memory() + state = Versioned(store) + + # Initial state has one commit (the root) + root = state.current_commit + assert root is not None + assert state.initial_commit == root + + # Add more commits + state.set("a", 1) + state.snapshot() + commit1 = state.current_commit + + state.set("b", 2) + state.snapshot() + commit2 = state.current_commit + + # initial_commit should remain the same + assert state.initial_commit == root + assert state.initial_commit != commit1 + assert state.initial_commit != commit2 + + # Revert to initial works + state.revert_to(state.initial_commit) + assert state.current_commit == root + assert state.get("a") is None + assert state.get("b") is None