@@ -182,6 +182,45 @@ pub enum ToolState {
182182 } ,
183183}
184184
185+
186+ impl ToolState {
187+ /// Check if transitioning to the given state is valid.
188+ ///
189+ /// Valid transitions:
190+ /// - Pending -> Running, Completed, Error
191+ /// - Running -> Completed, Error
192+ /// - Completed -> (terminal, no transitions)
193+ /// - Error -> (terminal, no transitions)
194+ ///
195+ /// State machine:
196+ /// ```text
197+ /// Pending -> Running -> Completed
198+ /// | |
199+ /// | +-> Error
200+ /// +-> Completed
201+ /// +-> Error
202+ /// ```
203+ pub fn can_transition_to ( & self , target : & ToolState ) -> bool {
204+ match ( self , target) {
205+ // From Pending, can go to any non-Pending state
206+ ( ToolState :: Pending { .. } , ToolState :: Running { .. } ) => true ,
207+ ( ToolState :: Pending { .. } , ToolState :: Completed { .. } ) => true ,
208+ ( ToolState :: Pending { .. } , ToolState :: Error { .. } ) => true ,
209+
210+ // From Running, can go to Completed or Error
211+ ( ToolState :: Running { .. } , ToolState :: Completed { .. } ) => true ,
212+ ( ToolState :: Running { .. } , ToolState :: Error { .. } ) => true ,
213+
214+ // Terminal states cannot transition
215+ ( ToolState :: Completed { .. } , _) => false ,
216+ ( ToolState :: Error { .. } , _) => false ,
217+
218+ // Any other transition is invalid
219+ _ => false ,
220+ }
221+ }
222+ }
223+
185224/// Subtask execution status.
186225#[ derive( Debug , Clone , Copy , Serialize , Deserialize , PartialEq , Eq , JsonSchema ) ]
187226#[ serde( rename_all = "snake_case" ) ]
@@ -552,6 +591,8 @@ impl MessageWithParts {
552591 }
553592
554593 /// Update a tool state by call ID.
594+ ///
595+ /// Logs a warning if the state transition is invalid (e.g., from a terminal state).
555596 pub fn update_tool_state ( & mut self , call_id : & str , new_state : ToolState ) -> bool {
556597 for part in & mut self . parts {
557598 if let MessagePart :: Tool {
@@ -561,6 +602,14 @@ impl MessageWithParts {
561602 } = & mut part. part
562603 {
563604 if cid == call_id {
605+ if !state. can_transition_to ( & new_state) {
606+ tracing:: warn!(
607+ "Invalid ToolState transition from {:?} to {:?} for call_id {}" ,
608+ state,
609+ new_state,
610+ call_id
611+ ) ;
612+ }
564613 * state = new_state;
565614 return true ;
566615 }
0 commit comments