diff --git a/wake_ai/cli.py b/wake_ai/cli.py index 3596de2b..92e17768 100644 --- a/wake_ai/cli.py +++ b/wake_ai/cli.py @@ -301,6 +301,7 @@ def main(ctx: click.Context, working_dir: str | None, model: str, resume: bool, ctx.obj["model"] = model ctx.obj["working_dir"] = working_dir ctx.obj["execution_dir"] = execution_dir + ctx.obj["resume"] = resume ctx.obj["cleanup_working_dir"] = not no_cleanup ctx.obj["show_progress"] = not no_progress ctx.obj["console"] = console # Pass console for coordinated output diff --git a/wake_ai/core/claude.py b/wake_ai/core/claude.py index 00d1e087..c58ebc80 100644 --- a/wake_ai/core/claude.py +++ b/wake_ai/core/claude.py @@ -520,6 +520,7 @@ def query( prompt: str, max_turns: Optional[int] = None, continue_session: bool = False, + resume_session: Optional[str] = None, ) -> ClaudeCodeResponse: """Execute a query with Claude Code (synchronous wrapper). @@ -534,15 +535,23 @@ def query( # Note: Session resumption logic is handled in the async version + if resume_session and continue_session: + raise ValueError( + "resume_session and continue_session cannot be used together" + ) + if continue_session: logger.debug(f"Continuing session: {continue_session}") # Execute async version using asyncio event loop response = asyncio.run( self.query_async( - prompt=prompt, max_turns=max_turns, continue_session=continue_session + prompt=prompt, max_turns=max_turns, continue_session=continue_session, resume_session=resume_session ) ) + # update only it could reach here. + # it can update last session, however other handiling needs to be done. + self.last_session_id = response.session_id return response diff --git a/wake_ai/core/flow.py b/wake_ai/core/flow.py index d845c8b9..5777ce0f 100644 --- a/wake_ai/core/flow.py +++ b/wake_ai/core/flow.py @@ -148,9 +148,42 @@ class StepExecutionInfo: duration: float # in seconds retries: int status: str # "completed", "skipped", "failed", "running" + session_id: str start_time: Optional[datetime] = None # Track when step started + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "name": self.name, + "turns": self.turns, + "cost": self.cost, + "duration": self.duration, + "retries": self.retries, + "status": self.status, + "start_time": self.start_time.isoformat() if self.start_time else None, + "session_id": self.session_id + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "StepExecutionInfo": + """Create StepExecutionInfo from dictionary.""" + start_time = None + if data["start_time"]: + start_time = datetime.fromisoformat(data["start_time"]) + + return cls( + name=data["name"], + turns=data["turns"], + cost=data["cost"], + duration=data["duration"], + retries=data["retries"], + status=data["status"], + start_time=start_time, + session_id=data["session_id"] + ) + + @dataclass class WorkflowState: """State tracking for workflow execution.""" @@ -184,6 +217,47 @@ class AIWorkflow(ABC): _init_called: bool _console: Console + @staticmethod + def _find_latest_session_dir(workflow_name: Optional[str] = None) -> Optional[Path]: + """Find the latest session directory for resuming. + + Args: + workflow_name: Optional workflow name to filter by state files + + Returns: + Path to the latest session directory, or None if not found + """ + wake_ai_dir = Path.cwd() / ".wake" / "ai" + if not wake_ai_dir.exists(): + return None + + # Find all session directories matching the timestamp pattern + session_dirs = [] + for path in wake_ai_dir.iterdir(): + if path.is_dir() and re.match(r'^\d{8}_\d{6}_[a-z0-9]{6}$', path.name): + session_dirs.append(path) + + if not session_dirs: + return None + + # Sort by directory name (which includes timestamp) to get latest + session_dirs.sort(key=lambda p: p.name, reverse=True) + + # If workflow name specified, find the latest session with a matching state file + if workflow_name: + for session_dir in session_dirs: + state_file = session_dir / f"{workflow_name}_state.json" + if state_file.exists(): + logger.debug(f"Found latest session for '{workflow_name}': {session_dir}") + return session_dir + logger.warning(f"No previous session found for workflow '{workflow_name}'") + return None + else: + # Return the latest session directory regardless of workflow + latest_dir = session_dirs[0] + logger.debug(f"Found latest session directory: {latest_dir}") + return latest_dir + def __init__( self, name: Optional[str] = None, @@ -196,7 +270,8 @@ def __init__( disallowed_tools: Optional[List[str]] = None, cleanup_working_dir: Optional[bool] = None, show_progress: Optional[bool] = None, - console: Optional[Console] = None + console: Optional[Console] = None, + resume: Optional[bool] = None ): """Initialize workflow. @@ -213,6 +288,7 @@ def __init__( cleanup_working_dir: Whether to remove working_dir after completion (default: True) show_progress: Whether to show progress bar during execution (default: True) console: Rich Console instance for coordinated output (optional) + resume: Whether to resume from latest session (default: False) """ ctx = click.get_current_context(silent=True) if ctx is None: @@ -233,6 +309,8 @@ def __init__( working_dir = cli.get("working_dir", None) if execution_dir is None: execution_dir = cli.get("execution_dir", None) + if resume is None: + resume = cli.get("resume", False) # Set cleanup behavior (use instance value if provided, else class default) self.cleanup_working_dir = cleanup_working_dir if cleanup_working_dir is not None else cli.get("cleanup_working_dir", True) @@ -254,6 +332,22 @@ def __init__( # Set up working directory if working_dir is not None: self.working_dir = Path(working_dir).resolve() + elif resume: + # Try to find latest session directory for resuming + latest_session = self._find_latest_session_dir(self.name) + if latest_session is not None: + self.working_dir = latest_session + logger.info(f"Resuming from session: {self.working_dir}") + else: + # No previous session found, create new one but warn user + logger.warning("No previous session found to resume from, creating new session") + import random + import string + from datetime import datetime + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=6)) + session_id = f"{timestamp}_{suffix}" + self.working_dir = Path.cwd() / ".wake" / "ai" / session_id else: # Generate session ID for working directory import random @@ -328,9 +422,9 @@ def __init__( console=self._console ) - self.steps: List[WorkflowStep] = [] + self.steps = [] self.state = WorkflowState() - self._dynamic_generators: Dict[str, Callable[[ClaudeCodeResponse, Dict[str, Any]], List[WorkflowStep]]] = {} + self._dynamic_generators = {} # Progress tracking self._status_context = None # console.status context manager @@ -479,10 +573,7 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False # Use status display context manager for the entire execution with self._status_display(): # Initialize progress tracking - try: - self.update_progress("Initializing workflow...") - except Exception as e: - logger.debug(f"Failed to update progress: {e}") + self.update_progress("Initializing workflow...") if resume and (self.working_dir / f"{self.name}_state.json").exists(): logger.info(f"Resuming workflow from saved state in: {self.working_dir / f'{self.name}_state.json'}") @@ -515,7 +606,8 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False turns=0, duration=0.0, retries=0, - status="skipped" + status="skipped", + session_id="" # TODO ) # Update status display with skipped step self._update_status_display() @@ -526,46 +618,46 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False logger.info(f"Executing step {self.state.current_step + 1}/{len(self.steps)}: '{step.name}'") # Update progress message at step start (percentage based on completed steps) + step_msg = f"Starting '{step.name}' ({self.state.current_step + 1}/{len(self.steps)})" + self.update_progress(step_msg) + + # Save original tools and model + original_allowed = self.session.allowed_tools + original_disallowed = self.session.disallowed_tools + original_model = getattr(self.session, 'model', None) + + # Track step execution start time + step_start_time = datetime.now() + + # Mark step as running and update display + self.state.step_info[self.state.current_step] = StepExecutionInfo( + name=step.name, + cost=0.0, + turns=0, + duration=0.0, + retries=0, + status="running", + start_time=step_start_time, + session_id="" # TODO + ) + self._update_status_display() + + # Execute step with retry logic + retry_count = 0 + validation_errors = [] + response = None + + # Change model for this step if specified + if step.model is not None and step.model != original_model: + logger.debug(f"Switching from model '{original_model}' to '{step.model}' for step '{step.name}'") + self.session.model = step.model + + # catching error during query and step operation. try: - step_msg = f"Starting '{step.name}' ({self.state.current_step + 1}/{len(self.steps)})" - self.update_progress(step_msg) - except Exception as e: - logger.debug(f"Failed to update progress: {e}") + while retry_count <= step.max_retries: - try: - # Track step execution start time - step_start_time = datetime.now() - - # Mark step as running and update display - self.state.step_info[self.state.current_step] = StepExecutionInfo( - name=step.name, - cost=0.0, - turns=0, - duration=0.0, - retries=0, - status="running", - start_time=step_start_time - ) - self._update_status_display() - - # Execute step with retry logic - retry_count = 0 - validation_errors = [] - response = None - step_total_cost = 0.0 - step_total_turns = 0 - - # Save original tools and model - original_allowed = self.session.allowed_tools - original_disallowed = self.session.disallowed_tools - original_model = getattr(self.session, 'model', None) - - # Change model for this step if specified - if step.model is not None and step.model != original_model: - logger.debug(f"Switching from model '{original_model}' to '{step.model}' for step '{step.name}'") - self.session.model = step.model + current_step_info = self.state.step_info[self.state.current_step] - while retry_count <= step.max_retries: # Set tools if specified (step overrides workflow defaults) if step.allowed_tools is not None: self.session.allowed_tools = step.allowed_tools @@ -584,14 +676,20 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False prompt = step.format_prompt(self.state.context) # Continue session only if step explicitly requests it + previous_session_id = None + should_continue = step.continue_session + if should_continue and self.state.current_step == 0: + raise ValueError("Cannot continue session for step 0") + if should_continue and self.state.current_step > 0: + previous_session_id = self.state.step_info[self.state.current_step-1].session_id if step.max_cost: logger.debug(f"Querying with cost limit ${step.max_cost} for step '{step.name}' (continue_session={should_continue}, model={getattr(self.session, 'model', 'default')})") - response = self.query_with_cost(prompt, step.max_cost, continue_session=should_continue, step_info=self.state.step_info[self.state.current_step]) + response = self.query_with_cost(prompt, step.max_cost, resume_session=previous_session_id, step_info=current_step_info) else: logger.debug(f"Querying step '{step.name}' (continue_session={should_continue}, model={getattr(self.session, 'model', 'default')})") - response = self.session.query(prompt, continue_session=should_continue) + response = self.session.query(prompt, resume_session=previous_session_id) else: # Retry attempt - add error correction prompt error_prompt = "The following errors occurred, please fix them:\n" @@ -601,16 +699,13 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False logger.info(f"Retrying step '{step.name}' (attempt {retry_count}/{step.max_retries}) - previous attempt failed validation") # Update progress message for retry (don't change percentage) - try: - retry_msg = f"Retrying '{step.name}' (attempt {retry_count}/{step.max_retries})" - self.update_progress_message(retry_msg) - except Exception as e: - logger.debug(f"Failed to update progress message: {e}") + retry_msg = f"Retrying '{step.name}' (attempt {retry_count}/{step.max_retries})" + self.update_progress_message(retry_msg) # Always continue session for retries if step.max_retry_cost: logger.debug(f"Querying retry with cost limit ${step.max_retry_cost} for step '{step.name}' (model={getattr(self.session, 'model', 'default')})") - response = self.query_with_cost(prompt, step.max_retry_cost, continue_session=True, step_info=self.state.step_info[self.state.current_step]) + response = self.query_with_cost(prompt, step.max_retry_cost, continue_session=True, step_info=current_step_info) else: logger.debug(f"Querying retry for step '{step.name}' (model={getattr(self.session, 'model', 'default')})") response = self.session.query(prompt, continue_session=True) @@ -620,19 +715,17 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False logger.debug(f"Claude session ID: {response.session_id}") # Update progress message for validation (don't change percentage) - try: - if retry_count == 0: - validation_msg = f"Validating '{step.name}' output" - else: - validation_msg = f"Validating retry of '{step.name}' (attempt {retry_count}/{step.max_retries})" - self.update_progress_message(validation_msg) - except Exception as e: - logger.debug(f"Failed to update progress message: {e}") + if retry_count == 0: + validation_msg = f"Validating '{step.name}' output" + else: + validation_msg = f"Validating retry of '{step.name}' (attempt {retry_count}/{step.max_retries})" + self.update_progress_message(validation_msg) + # Update cumulative cost and step totals self.state.cumulative_cost += response.cost - step_total_cost += response.cost - step_total_turns += response.num_turns + current_step_info.cost += response.cost + current_step_info.turns += response.num_turns # Validate response success, validation_errors = step.validate_response(response) @@ -640,21 +733,16 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False if success: # Validation passed - log successful completion with total cost/turns retry_msg = f" after {retry_count} retries" if retry_count > 0 else "" - logger.info(f"Step '{step.name}' completed{retry_msg} - cost: ${step_total_cost:.4f}, turns: {step_total_turns}") + logger.info(f"Step '{step.name}' completed{retry_msg} - cost: ${current_step_info.cost:.4f}, turns: {current_step_info.turns}") logger.debug(f"Response: {response.content}") # Calculate step duration - step_duration = (datetime.now() - step_start_time).total_seconds() - - # Record step execution info - self.state.step_info[self.state.current_step] = StepExecutionInfo( - name=step.name, - cost=step_total_cost, - turns=step_total_turns, - duration=step_duration, - retries=retry_count, - status="completed" - ) + step_duration = float((datetime.now() - step_start_time).total_seconds()) + current_step_info.duration += step_duration + current_step_info.retries += retry_count + current_step_info.status = "completed" + current_step_info.session_id = response.session_id + # Update live display with completed step self._update_status_display() @@ -676,11 +764,8 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False self._save_state() # Update progress after step completion - try: - step_msg = f"Completed step '{step.name}' ({len(self.state.completed_steps)}/{len(self.steps)})" - self.update_progress(step_msg) - except Exception as e: - logger.debug(f"Failed to update progress: {e}") + step_msg = f"Completed step '{step.name}' ({len(self.state.completed_steps)}/{len(self.steps)})" + self.update_progress(step_msg) break else: @@ -726,10 +811,7 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False logger.info(f"Workflow '{self.name}' completed successfully (total cost: ${self.state.cumulative_cost:.4f})") # Complete progress tracking - try: - self.update_progress("Workflow completed!", force_percentage=1.0) - except Exception as e: - logger.debug(f"Failed to update final progress: {e}") + self.update_progress("Workflow completed!", force_percentage=1.0) # Format results before cleanup formatted_results = self.format_results(results) @@ -748,7 +830,7 @@ def execute(self, context: Optional[Dict[str, Any]] = None, resume: bool = False return results, formatted_results - def query_with_cost(self, prompt: str, cost_limit: float, turn_step: int = 50, continue_session: bool = False, step_info: Optional[StepExecutionInfo] = None) -> ClaudeCodeResponse: + def query_with_cost(self, prompt: str, cost_limit: float, turn_step: int = 50, continue_session: bool = False, resume_session: Optional[str] = None, step_info: Optional[StepExecutionInfo] = None) -> ClaudeCodeResponse: """Execute queries with cost monitoring and automatic completion. Args: @@ -773,13 +855,17 @@ def query_with_cost(self, prompt: str, cost_limit: float, turn_step: int = 50, c last_response = None iteration = 0 + if resume_session and continue_session: + raise ValueError("resume_session and continue_session cannot be used together") + # First query with the initial prompt logger.debug(f"Iteration {iteration}: Initial query") response = self.session.query( prompt=prompt, max_turns=turn_step, - continue_session=continue_session + continue_session=continue_session, + resume_session=resume_session ) if not response.success: @@ -1009,7 +1095,8 @@ def _save_state(self): "context": self.state.context, "errors": self.state.errors, "cumulative_cost": self.state.cumulative_cost, - "progress_percentage": self.state.progress_percentage + "progress_percentage": self.state.progress_percentage, + "step_info": {k: v.to_dict() for k, v in self.state.step_info.items()} } state_file = self.working_dir / f"{self.name}_state.json" state_file.write_text(json.dumps(state_data, indent=2)) @@ -1020,6 +1107,7 @@ def _load_state(self): state_file = self.working_dir / f"{self.name}_state.json" logger.debug(f"Loading workflow state from {state_file}") data = json.loads(state_file.read_text()) + self.state.step_info = {int(k): StepExecutionInfo.from_dict(v) for k, v in data["step_info"].items()} self.state.current_step = data["current_step"] self.state.completed_steps = data["completed_steps"] self.state.skipped_steps = data["skipped_steps"]