Skip to content

Commit a66e9d7

Browse files
committed
fix(protocol): add ToolState transition validation
1 parent c398212 commit a66e9d7

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/cortex-protocol/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ uuid = { workspace = true, features = ["serde", "v4"] }
2020
chrono = { workspace = true }
2121
strum_macros = "0.27"
2222
base64 = { workspace = true }
23+
tracing = { workspace = true }
2324

2425
[dev-dependencies]
2526
pretty_assertions = { workspace = true }

src/cortex-protocol/src/protocol/message_parts.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,45 @@ pub enum ToolState {
182182
},
183183
}
184184

185+
186+
impl ToolState {
187+
/// Check if transitioning to the given state is valid.
188+
///
189+
/// Valid transitions:
190+
/// - Pending -> Running, Completed, Error
191+
/// - Running -> Completed, Error
192+
/// - Completed -> (terminal, no transitions)
193+
/// - Error -> (terminal, no transitions)
194+
///
195+
/// State machine:
196+
/// ```text
197+
/// Pending -> Running -> Completed
198+
/// | |
199+
/// | +-> Error
200+
/// +-> Completed
201+
/// +-> Error
202+
/// ```
203+
pub fn can_transition_to(&self, target: &ToolState) -> bool {
204+
match (self, target) {
205+
// From Pending, can go to any non-Pending state
206+
(ToolState::Pending { .. }, ToolState::Running { .. }) => true,
207+
(ToolState::Pending { .. }, ToolState::Completed { .. }) => true,
208+
(ToolState::Pending { .. }, ToolState::Error { .. }) => true,
209+
210+
// From Running, can go to Completed or Error
211+
(ToolState::Running { .. }, ToolState::Completed { .. }) => true,
212+
(ToolState::Running { .. }, ToolState::Error { .. }) => true,
213+
214+
// Terminal states cannot transition
215+
(ToolState::Completed { .. }, _) => false,
216+
(ToolState::Error { .. }, _) => false,
217+
218+
// Any other transition is invalid
219+
_ => false,
220+
}
221+
}
222+
}
223+
185224
/// Subtask execution status.
186225
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
187226
#[serde(rename_all = "snake_case")]
@@ -552,6 +591,8 @@ impl MessageWithParts {
552591
}
553592

554593
/// Update a tool state by call ID.
594+
///
595+
/// Logs a warning if the state transition is invalid (e.g., from a terminal state).
555596
pub fn update_tool_state(&mut self, call_id: &str, new_state: ToolState) -> bool {
556597
for part in &mut self.parts {
557598
if let MessagePart::Tool {
@@ -561,6 +602,14 @@ impl MessageWithParts {
561602
} = &mut part.part
562603
{
563604
if cid == call_id {
605+
if !state.can_transition_to(&new_state) {
606+
tracing::warn!(
607+
"Invalid ToolState transition from {:?} to {:?} for call_id {}",
608+
state,
609+
new_state,
610+
call_id
611+
);
612+
}
564613
*state = new_state;
565614
return true;
566615
}

0 commit comments

Comments
 (0)