diff --git a/.env.example b/.env.example
index 2c139c3a..03cf7873 100644
--- a/.env.example
+++ b/.env.example
@@ -34,6 +34,22 @@
 # ANTHROPIC_DEFAULT_SONNET_MODEL=claude-sonnet-4-5@20250929
 # ANTHROPIC_DEFAULT_HAIKU_MODEL=claude-3-5-haiku@20241022
 
+# ===================
+# Alternative API Providers
+# ===================
+# NOTE: These env vars are the legacy way to configure providers.
+# The recommended way is to use the Settings UI (API Provider section).
+# UI settings take precedence when api_provider != "claude".
+
+# Kimi K2.5 (Moonshot) Configuration (Optional)
+# Get an API key at: https://kimi.com
+#
+# ANTHROPIC_BASE_URL=https://api.kimi.com/coding/
+# ANTHROPIC_API_KEY=your-kimi-api-key
+# ANTHROPIC_DEFAULT_SONNET_MODEL=kimi-k2.5
+# ANTHROPIC_DEFAULT_OPUS_MODEL=kimi-k2.5
+# ANTHROPIC_DEFAULT_HAIKU_MODEL=kimi-k2.5
+
 # GLM/Alternative API Configuration (Optional)
 # To use Zhipu AI's GLM models instead of Claude, uncomment and set these variables.
 # This only affects AutoForge - your global Claude Code settings remain unchanged.
diff --git a/bin/autoforge.js b/bin/autoforge.js
old mode 100644
new mode 100755
diff --git a/env_constants.py b/env_constants.py
index e284c542..45737c41 100644
--- a/env_constants.py
+++ b/env_constants.py
@@ -15,6 +15,7 @@
     # Core API configuration
     "ANTHROPIC_BASE_URL",  # Custom API endpoint (e.g., https://api.z.ai/api/anthropic)
     "ANTHROPIC_AUTH_TOKEN",  # API authentication token
+    "ANTHROPIC_API_KEY",  # API key (used by Kimi and other providers)
    "API_TIMEOUT_MS",  # Request timeout in milliseconds
     # Model tier overrides
     "ANTHROPIC_DEFAULT_SONNET_MODEL",  # Model override for Sonnet
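For orientation, the legacy path referenced in the .env notes above only forwards the allow-listed variables that are actually set. A minimal standalone sketch of that filtering (the variable list is copied, abbreviated, from env_constants.py above; forward_api_env is an illustrative name, not a real helper in this repo):

```python
import os

# Abbreviated copy of the allow-list above; the real list lives in env_constants.py.
API_ENV_VARS = [
    "ANTHROPIC_BASE_URL",
    "ANTHROPIC_AUTH_TOKEN",
    "ANTHROPIC_API_KEY",
    "API_TIMEOUT_MS",
    "ANTHROPIC_DEFAULT_OPUS_MODEL",
    "ANTHROPIC_DEFAULT_SONNET_MODEL",
    "ANTHROPIC_DEFAULT_HAIKU_MODEL",
]

def forward_api_env() -> dict[str, str]:
    """Return only the provider-related variables that are actually set."""
    return {var: value for var in API_ENV_VARS if (value := os.getenv(var))}

if __name__ == "__main__":
    os.environ["ANTHROPIC_BASE_URL"] = "https://api.kimi.com/coding/"
    os.environ["ANTHROPIC_API_KEY"] = "your-kimi-api-key"
    print(forward_api_env())
```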
diff --git a/registry.py b/registry.py
index 4668b153..6770bd5c 100644
--- a/registry.py
+++ b/registry.py
@@ -612,3 +612,120 @@ def get_all_settings() -> dict[str, str]:
     except Exception as e:
         logger.warning("Failed to read settings: %s", e)
         return {}
+
+
+# =============================================================================
+# API Provider Definitions
+# =============================================================================
+
+API_PROVIDERS: dict[str, dict[str, Any]] = {
+    "claude": {
+        "name": "Claude (Anthropic)",
+        "base_url": None,
+        "requires_auth": False,
+        "models": [
+            {"id": "claude-opus-4-5-20251101", "name": "Claude Opus 4.5"},
+            {"id": "claude-sonnet-4-5-20250929", "name": "Claude Sonnet 4.5"},
+        ],
+        "default_model": "claude-opus-4-5-20251101",
+    },
+    "kimi": {
+        "name": "Kimi K2.5 (Moonshot)",
+        "base_url": "https://api.kimi.com/coding/",
+        "requires_auth": True,
+        "auth_env_var": "ANTHROPIC_API_KEY",
+        "models": [{"id": "kimi-k2.5", "name": "Kimi K2.5"}],
+        "default_model": "kimi-k2.5",
+    },
+    "glm": {
+        "name": "GLM (Zhipu AI)",
+        "base_url": "https://api.z.ai/api/anthropic",
+        "requires_auth": True,
+        "auth_env_var": "ANTHROPIC_AUTH_TOKEN",
+        "models": [
+            {"id": "glm-4.7", "name": "GLM 4.7"},
+            {"id": "glm-4.5-air", "name": "GLM 4.5 Air"},
+        ],
+        "default_model": "glm-4.7",
+    },
+    "ollama": {
+        "name": "Ollama (Local)",
+        "base_url": "http://localhost:11434",
+        "requires_auth": False,
+        "models": [
+            {"id": "qwen3-coder", "name": "Qwen3 Coder"},
+            {"id": "deepseek-coder-v2", "name": "DeepSeek Coder V2"},
+        ],
+        "default_model": "qwen3-coder",
+    },
+    "custom": {
+        "name": "Custom Provider",
+        "base_url": "",
+        "requires_auth": True,
+        "auth_env_var": "ANTHROPIC_AUTH_TOKEN",
+        "models": [],
+        "default_model": "",
+    },
+}
+
+
+def get_effective_sdk_env() -> dict[str, str]:
+    """Build environment variable dict for Claude SDK based on current API provider settings.
+
+    When api_provider is "claude" (or unset), falls back to existing env vars (current behavior).
+    For other providers, builds env dict from stored settings (api_base_url, api_auth_token, api_model).
+
+    Returns:
+        Dict ready to merge into subprocess env or pass to SDK.
+    """
+    all_settings = get_all_settings()
+    provider_id = all_settings.get("api_provider", "claude")
+
+    if provider_id == "claude":
+        # Default behavior: forward existing env vars
+        from env_constants import API_ENV_VARS
+        sdk_env: dict[str, str] = {}
+        for var in API_ENV_VARS:
+            value = os.getenv(var)
+            if value:
+                sdk_env[var] = value
+        return sdk_env
+
+    # Alternative provider: build env from settings
+    provider = API_PROVIDERS.get(provider_id)
+    if not provider:
+        logger.warning("Unknown API provider '%s', falling back to claude", provider_id)
+        from env_constants import API_ENV_VARS
+        sdk_env = {}
+        for var in API_ENV_VARS:
+            value = os.getenv(var)
+            if value:
+                sdk_env[var] = value
+        return sdk_env
+
+    sdk_env = {}
+
+    # Base URL
+    base_url = all_settings.get("api_base_url") or provider.get("base_url")
+    if base_url:
+        sdk_env["ANTHROPIC_BASE_URL"] = base_url
+
+    # Auth token
+    auth_token = all_settings.get("api_auth_token")
+    if auth_token:
+        auth_env_var = provider.get("auth_env_var", "ANTHROPIC_AUTH_TOKEN")
+        sdk_env[auth_env_var] = auth_token
+
+    # Model - set all three tier overrides to the same model
+    model = all_settings.get("api_model") or provider.get("default_model")
+    if model:
+        sdk_env["ANTHROPIC_DEFAULT_OPUS_MODEL"] = model
+        sdk_env["ANTHROPIC_DEFAULT_SONNET_MODEL"] = model
+        sdk_env["ANTHROPIC_DEFAULT_HAIKU_MODEL"] = model
+
+    # Timeout
+    timeout = all_settings.get("api_timeout_ms")
+    if timeout:
+        sdk_env["API_TIMEOUT_MS"] = timeout
+
+    return sdk_env
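To make the alternative-provider branch of get_effective_sdk_env() above easier to follow, here is a trimmed, self-contained sketch under stated assumptions: `stored` stands in for get_all_settings(), and PROVIDER mirrors the API_PROVIDERS["kimi"] entry. None of these names are the real registry code.

```python
# Self-contained sketch of the alternative-provider branch above.
PROVIDER = {
    "base_url": "https://api.kimi.com/coding/",
    "auth_env_var": "ANTHROPIC_API_KEY",
    "default_model": "kimi-k2.5",
}

def effective_sdk_env(stored: dict[str, str]) -> dict[str, str]:
    env: dict[str, str] = {}
    base_url = stored.get("api_base_url") or PROVIDER["base_url"]
    if base_url:
        env["ANTHROPIC_BASE_URL"] = base_url
    token = stored.get("api_auth_token")
    if token:
        env[PROVIDER.get("auth_env_var", "ANTHROPIC_AUTH_TOKEN")] = token
    model = stored.get("api_model") or PROVIDER["default_model"]
    if model:
        # All three tiers point at the same model for non-Claude providers.
        for tier in ("OPUS", "SONNET", "HAIKU"):
            env[f"ANTHROPIC_DEFAULT_{tier}_MODEL"] = model
    return env

print(effective_sdk_env({"api_auth_token": "sk-example"}))
# -> base URL, ANTHROPIC_API_KEY, and all three tier overrides set to kimi-k2.5
```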
diff --git a/server/routers/assistant_chat.py b/server/routers/assistant_chat.py
index ceae8bd2..92091288 100644
--- a/server/routers/assistant_chat.py
+++ b/server/routers/assistant_chat.py
@@ -26,7 +26,7 @@
     get_conversations,
 )
 from ..utils.project_helpers import get_project_path as _get_project_path
-from ..utils.validation import is_valid_project_name as validate_project_name
+from ..utils.validation import validate_project_name
 
 logger = logging.getLogger(__name__)
 
@@ -217,20 +217,26 @@ async def assistant_chat_websocket(websocket: WebSocket, project_name: str):
     - {"type": "error", "content": "..."} - Error message
     - {"type": "pong"} - Keep-alive pong
     """
-    if not validate_project_name(project_name):
+    # Always accept WebSocket first to avoid opaque 403 errors
+    await websocket.accept()
+
+    try:
+        project_name = validate_project_name(project_name)
+    except HTTPException:
+        await websocket.send_json({"type": "error", "content": "Invalid project name"})
         await websocket.close(code=4000, reason="Invalid project name")
         return
 
     project_dir = _get_project_path(project_name)
     if not project_dir:
+        await websocket.send_json({"type": "error", "content": "Project not found in registry"})
         await websocket.close(code=4004, reason="Project not found in registry")
         return
 
     if not project_dir.exists():
+        await websocket.send_json({"type": "error", "content": "Project directory not found"})
         await websocket.close(code=4004, reason="Project directory not found")
         return
-
-    await websocket.accept()
 
     logger.info(f"Assistant WebSocket connected for project: {project_name}")
 
     session: Optional[AssistantChatSession] = None
diff --git a/server/routers/expand_project.py b/server/routers/expand_project.py
index 5b55824e..d680b952 100644
--- a/server/routers/expand_project.py
+++ b/server/routers/expand_project.py
@@ -104,19 +104,26 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str):
     - {"type": "error", "content": "..."} - Error message
     - {"type": "pong"} - Keep-alive pong
     """
+    # Always accept the WebSocket first to avoid opaque 403 errors.
+    # Starlette returns 403 if we close before accepting.
+    await websocket.accept()
+
     try:
         project_name = validate_project_name(project_name)
     except HTTPException:
+        await websocket.send_json({"type": "error", "content": "Invalid project name"})
         await websocket.close(code=4000, reason="Invalid project name")
         return
 
     # Look up project directory from registry
     project_dir = _get_project_path(project_name)
     if not project_dir:
+        await websocket.send_json({"type": "error", "content": "Project not found in registry"})
         await websocket.close(code=4004, reason="Project not found in registry")
         return
 
     if not project_dir.exists():
+        await websocket.send_json({"type": "error", "content": "Project directory not found"})
         await websocket.close(code=4004, reason="Project directory not found")
         return
 
@@ -124,11 +131,10 @@ async def expand_project_websocket(websocket: WebSocket, project_name: str):
     from autoforge_paths import get_prompts_dir
     spec_path = get_prompts_dir(project_dir) / "app_spec.txt"
     if not spec_path.exists():
+        await websocket.send_json({"type": "error", "content": "Project has no spec. Create a spec first before expanding."})
         await websocket.close(code=4004, reason="Project has no spec. Create spec first.")
         return
 
-    await websocket.accept()
-
     session: Optional[ExpandChatSession] = None
 
     try:
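Every WebSocket route touched by this change follows the same accept-first pattern: Starlette replies with an opaque 403 if the connection is closed before accept(), so each handler now accepts, sends a structured error payload, and then closes with an application close code. A stripped-down FastAPI sketch of the pattern (the route path and the validation check are placeholders, not one of the real endpoints):

```python
from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/ws/{project_name}")
async def example_ws(websocket: WebSocket, project_name: str):
    # Accept first so the client sees a structured error instead of a bare 403.
    await websocket.accept()

    if not project_name.isidentifier():  # placeholder validation, not the real check
        await websocket.send_json({"type": "error", "content": "Invalid project name"})
        await websocket.close(code=4000, reason="Invalid project name")
        return

    await websocket.send_json({"type": "connected", "project": project_name})
```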
""" + current_provider = get_setting("api_provider", "claude") or "claude" + provider = API_PROVIDERS.get(current_provider) + + if provider and current_provider != "claude": + provider_models = provider.get("models", []) + return ModelsResponse( + models=[ModelInfo(id=m["id"], name=m["name"]) for m in provider_models], + default=provider.get("default_model", ""), + ) + + # Default: return Claude models return ModelsResponse( models=[ModelInfo(id=m["id"], name=m["name"]) for m in AVAILABLE_MODELS], default=DEFAULT_MODEL, @@ -85,14 +114,24 @@ async def get_settings(): """Get current global settings.""" all_settings = get_all_settings() + api_provider = all_settings.get("api_provider", "claude") + + # Compute glm_mode / ollama_mode from api_provider for backward compat + glm_mode = api_provider == "glm" or _is_glm_mode() + ollama_mode = api_provider == "ollama" or _is_ollama_mode() + return SettingsResponse( yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")), model=all_settings.get("model", DEFAULT_MODEL), - glm_mode=_is_glm_mode(), - ollama_mode=_is_ollama_mode(), + glm_mode=glm_mode, + ollama_mode=ollama_mode, testing_agent_ratio=_parse_int(all_settings.get("testing_agent_ratio"), 1), playwright_headless=_parse_bool(all_settings.get("playwright_headless"), default=True), batch_size=_parse_int(all_settings.get("batch_size"), 3), + api_provider=api_provider, + api_base_url=all_settings.get("api_base_url"), + api_has_auth_token=bool(all_settings.get("api_auth_token")), + api_model=all_settings.get("api_model"), ) @@ -114,14 +153,47 @@ async def update_settings(update: SettingsUpdate): if update.batch_size is not None: set_setting("batch_size", str(update.batch_size)) + # API provider settings + if update.api_provider is not None: + old_provider = get_setting("api_provider", "claude") + set_setting("api_provider", update.api_provider) + + # When provider changes, auto-set defaults for the new provider + if update.api_provider != old_provider: + provider = API_PROVIDERS.get(update.api_provider) + if provider: + # Auto-set base URL from provider definition + if provider.get("base_url"): + set_setting("api_base_url", provider["base_url"]) + # Auto-set model to provider's default + if provider.get("default_model") and update.api_model is None: + set_setting("api_model", provider["default_model"]) + + if update.api_base_url is not None: + set_setting("api_base_url", update.api_base_url) + + if update.api_auth_token is not None: + set_setting("api_auth_token", update.api_auth_token) + + if update.api_model is not None: + set_setting("api_model", update.api_model) + # Return updated settings all_settings = get_all_settings() + api_provider = all_settings.get("api_provider", "claude") + glm_mode = api_provider == "glm" or _is_glm_mode() + ollama_mode = api_provider == "ollama" or _is_ollama_mode() + return SettingsResponse( yolo_mode=_parse_yolo_mode(all_settings.get("yolo_mode")), model=all_settings.get("model", DEFAULT_MODEL), - glm_mode=_is_glm_mode(), - ollama_mode=_is_ollama_mode(), + glm_mode=glm_mode, + ollama_mode=ollama_mode, testing_agent_ratio=_parse_int(all_settings.get("testing_agent_ratio"), 1), playwright_headless=_parse_bool(all_settings.get("playwright_headless"), default=True), batch_size=_parse_int(all_settings.get("batch_size"), 3), + api_provider=api_provider, + api_base_url=all_settings.get("api_base_url"), + api_has_auth_token=bool(all_settings.get("api_auth_token")), + api_model=all_settings.get("api_model"), ) diff --git a/server/routers/spec_creation.py 
diff --git a/server/routers/spec_creation.py b/server/routers/spec_creation.py
index cb7263ce..44b8d048 100644
--- a/server/routers/spec_creation.py
+++ b/server/routers/spec_creation.py
@@ -21,7 +21,7 @@
     remove_session,
 )
 from ..utils.project_helpers import get_project_path as _get_project_path
-from ..utils.validation import is_valid_project_name as validate_project_name
+from ..utils.validation import is_valid_project_name, validate_project_name
 
 logger = logging.getLogger(__name__)
 
@@ -49,7 +49,7 @@ async def list_spec_sessions():
 @router.get("/sessions/{project_name}", response_model=SpecSessionStatus)
 async def get_session_status(project_name: str):
     """Get status of a spec creation session."""
-    if not validate_project_name(project_name):
+    if not is_valid_project_name(project_name):
         raise HTTPException(status_code=400, detail="Invalid project name")
 
     session = get_session(project_name)
@@ -67,7 +67,7 @@ async def get_session_status(project_name: str):
 @router.delete("/sessions/{project_name}")
 async def cancel_session(project_name: str):
     """Cancel and remove a spec creation session."""
-    if not validate_project_name(project_name):
+    if not is_valid_project_name(project_name):
         raise HTTPException(status_code=400, detail="Invalid project name")
 
     session = get_session(project_name)
@@ -95,7 +95,7 @@ async def get_spec_file_status(project_name: str):
     This is used for polling to detect when Claude has finished writing spec files.
     Claude writes this status file as the final step after completing all spec work.
     """
-    if not validate_project_name(project_name):
+    if not is_valid_project_name(project_name):
         raise HTTPException(status_code=400, detail="Invalid project name")
 
     project_dir = _get_project_path(project_name)
@@ -166,22 +166,28 @@ async def spec_chat_websocket(websocket: WebSocket, project_name: str):
     - {"type": "error", "content": "..."} - Error message
     - {"type": "pong"} - Keep-alive pong
     """
-    if not validate_project_name(project_name):
+    # Always accept WebSocket first to avoid opaque 403 errors
+    await websocket.accept()
+
+    try:
+        project_name = validate_project_name(project_name)
+    except HTTPException:
+        await websocket.send_json({"type": "error", "content": "Invalid project name"})
         await websocket.close(code=4000, reason="Invalid project name")
         return
 
     # Look up project directory from registry
     project_dir = _get_project_path(project_name)
     if not project_dir:
+        await websocket.send_json({"type": "error", "content": "Project not found in registry"})
         await websocket.close(code=4004, reason="Project not found in registry")
         return
 
     if not project_dir.exists():
+        await websocket.send_json({"type": "error", "content": "Project directory not found"})
         await websocket.close(code=4004, reason="Project directory not found")
         return
 
-    await websocket.accept()
-
     session: Optional[SpecChatSession] = None
 
     try:
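spec_creation.py now uses both validation helpers: is_valid_project_name for REST handlers that raise their own 400, and validate_project_name for WebSocket handlers that want a raising, normalizing variant. Their implementations are not part of this diff; the sketch below is an assumed illustration of the split, and the regex is made up:

```python
# Illustrative only: the real helpers live in server/utils/validation.py and are
# not shown in this diff. The import changes above rely on two distinct shapes.
import re
from fastapi import HTTPException

_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_-]{0,63}$")  # assumed pattern

def is_valid_project_name(name: str) -> bool:
    """Boolean check used by REST endpoints (they raise 400 at the call site)."""
    return bool(_NAME_RE.fullmatch(name))

def validate_project_name(name: str) -> str:
    """Raising variant used by WebSocket endpoints; returns the normalized name."""
    name = name.strip()
    if not is_valid_project_name(name):
        raise HTTPException(status_code=400, detail="Invalid project name")
    return name
```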
diff --git a/server/routers/terminal.py b/server/routers/terminal.py
index a53b9abe..6845a27e 100644
--- a/server/routers/terminal.py
+++ b/server/routers/terminal.py
@@ -221,8 +221,12 @@ async def terminal_websocket(websocket: WebSocket, project_name: str, terminal_i
     - {"type": "pong"} - Keep-alive response
     - {"type": "error", "message": "..."} - Error message
     """
+    # Always accept WebSocket first to avoid opaque 403 errors
+    await websocket.accept()
+
     # Validate project name
     if not validate_project_name(project_name):
+        await websocket.send_json({"type": "error", "message": "Invalid project name"})
         await websocket.close(
             code=TerminalCloseCode.INVALID_PROJECT_NAME, reason="Invalid project name"
         )
@@ -230,6 +234,7 @@ async def terminal_websocket(websocket: WebSocket, project_name: str, terminal_i
 
     # Validate terminal ID
     if not validate_terminal_id(terminal_id):
+        await websocket.send_json({"type": "error", "message": "Invalid terminal ID"})
         await websocket.close(
             code=TerminalCloseCode.INVALID_PROJECT_NAME, reason="Invalid terminal ID"
         )
@@ -238,6 +243,7 @@ async def terminal_websocket(websocket: WebSocket, project_name: str, terminal_i
 
     # Look up project directory from registry
     project_dir = _get_project_path(project_name)
     if not project_dir:
+        await websocket.send_json({"type": "error", "message": "Project not found in registry"})
         await websocket.close(
             code=TerminalCloseCode.PROJECT_NOT_FOUND,
             reason="Project not found in registry",
@@ -245,6 +251,7 @@ async def terminal_websocket(websocket: WebSocket, project_name: str, terminal_i
         return
 
     if not project_dir.exists():
+        await websocket.send_json({"type": "error", "message": "Project directory not found"})
         await websocket.close(
             code=TerminalCloseCode.PROJECT_NOT_FOUND,
             reason="Project directory not found",
@@ -254,14 +261,13 @@ async def terminal_websocket(websocket: WebSocket, project_name: str, terminal_i
 
     # Verify terminal exists in metadata
     terminal_info = get_terminal_info(project_name, terminal_id)
     if not terminal_info:
+        await websocket.send_json({"type": "error", "message": "Terminal not found"})
         await websocket.close(
             code=TerminalCloseCode.PROJECT_NOT_FOUND,
             reason="Terminal not found",
         )
         return
 
-    await websocket.accept()
-
     # Get or create terminal session for this project/terminal
     session = get_terminal_session(project_name, project_dir, terminal_id)
diff --git a/server/schemas.py b/server/schemas.py
index e15f1b3d..8365e7ad 100644
--- a/server/schemas.py
+++ b/server/schemas.py
@@ -391,6 +391,22 @@ class ModelInfo(BaseModel):
     name: str
 
 
+class ProviderInfo(BaseModel):
+    """Information about an API provider."""
+    id: str
+    name: str
+    base_url: str | None = None
+    models: list[ModelInfo]
+    default_model: str
+    requires_auth: bool = False
+
+
+class ProvidersResponse(BaseModel):
+    """Response schema for available providers list."""
+    providers: list[ProviderInfo]
+    current: str
+
+
 class SettingsResponse(BaseModel):
     """Response schema for global settings."""
     yolo_mode: bool = False
@@ -400,6 +416,10 @@
     testing_agent_ratio: int = 1  # Regression testing agents (0-3)
     playwright_headless: bool = True
     batch_size: int = 3  # Features per coding agent batch (1-3)
+    api_provider: str = "claude"
+    api_base_url: str | None = None
+    api_has_auth_token: bool = False  # Never expose actual token
+    api_model: str | None = None
 
 
 class ModelsResponse(BaseModel):
@@ -415,12 +435,21 @@
     testing_agent_ratio: int | None = None  # 0-3
     playwright_headless: bool | None = None
     batch_size: int | None = None  # Features per agent batch (1-3)
+    api_provider: str | None = None
+    api_base_url: str | None = None
+    api_auth_token: str | None = None  # Write-only, never returned
+    api_model: str | None = None
 
     @field_validator('model')
     @classmethod
-    def validate_model(cls, v: str | None) -> str | None:
-        if v is not None and v not in VALID_MODELS:
-            raise ValueError(f"Invalid model. Must be one of: {VALID_MODELS}")
+    def validate_model(cls, v: str | None, info) -> str | None:  # type: ignore[override]
+        if v is not None:
+            # Skip VALID_MODELS check when using an alternative API provider
+            api_provider = info.data.get("api_provider")
+            if api_provider and api_provider != "claude":
+                return v
+            if v not in VALID_MODELS:
+                raise ValueError(f"Invalid model. Must be one of: {VALID_MODELS}")
         return v
 
     @field_validator('testing_agent_ratio')
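The validate_model change relies on Pydantic v2's ValidationInfo.data, which only contains fields that have already been validated, i.e. fields declared before the one being checked. A compact standalone sketch of the pattern, with api_provider declared ahead of model so the validator can see it; the class name and model list are illustrative, not the real schema:

```python
from pydantic import BaseModel, ValidationInfo, field_validator

VALID_MODELS = {"claude-opus-4-5-20251101", "claude-sonnet-4-5-20250929"}  # illustrative

class UpdateSketch(BaseModel):
    # Declared before `model` so it appears in info.data when `model` is validated.
    api_provider: str | None = None
    model: str | None = None

    @field_validator("model")
    @classmethod
    def validate_model(cls, v: str | None, info: ValidationInfo) -> str | None:
        if v is not None:
            api_provider = info.data.get("api_provider")
            if api_provider and api_provider != "claude":
                return v  # alternative providers may use arbitrary model names
            if v not in VALID_MODELS:
                raise ValueError(f"Invalid model. Must be one of: {VALID_MODELS}")
        return v

print(UpdateSketch(api_provider="kimi", model="kimi-k2.5"))               # accepted
print(UpdateSketch(api_provider=None, model="claude-opus-4-5-20251101"))  # accepted
```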
diff --git a/server/services/assistant_chat_session.py b/server/services/assistant_chat_session.py
index 73d3dfb7..bd4c9cae 100755
--- a/server/services/assistant_chat_session.py
+++ b/server/services/assistant_chat_session.py
@@ -25,7 +25,7 @@
     create_conversation,
     get_messages,
 )
-from .chat_constants import API_ENV_VARS, ROOT_DIR
+from .chat_constants import ROOT_DIR
 
 # Load environment variables from .env file if present
 load_dotenv()
@@ -258,15 +258,11 @@ async def start(self) -> AsyncGenerator[dict, None]:
         system_cli = shutil.which("claude")
 
         # Build environment overrides for API configuration
-        sdk_env: dict[str, str] = {}
-        for var in API_ENV_VARS:
-            value = os.getenv(var)
-            if value:
-                sdk_env[var] = value
-
-        # Determine model from environment or use default
-        # This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
-        model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
+        from registry import get_effective_sdk_env
+        sdk_env = get_effective_sdk_env()
+
+        # Determine model from SDK env (provider-aware) or fallback to env/default
+        model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
 
         try:
             logger.info("Creating ClaudeSDKClient...")
diff --git a/server/services/expand_chat_session.py b/server/services/expand_chat_session.py
index 35224532..9a1839a5 100644
--- a/server/services/expand_chat_session.py
+++ b/server/services/expand_chat_session.py
@@ -22,7 +22,7 @@
 from dotenv import load_dotenv
 
 from ..schemas import ImageAttachment
-from .chat_constants import API_ENV_VARS, ROOT_DIR, make_multimodal_message
+from .chat_constants import ROOT_DIR, make_multimodal_message
 
 # Load environment variables from .env file if present
 load_dotenv()
@@ -154,16 +154,11 @@ async def start(self) -> AsyncGenerator[dict, None]:
         system_prompt = skill_content.replace("$ARGUMENTS", project_path)
 
         # Build environment overrides for API configuration
-        # Filter to only include vars that are actually set (non-None)
-        sdk_env: dict[str, str] = {}
-        for var in API_ENV_VARS:
-            value = os.getenv(var)
-            if value:
-                sdk_env[var] = value
-
-        # Determine model from environment or use default
-        # This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
-        model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
+        from registry import get_effective_sdk_env
+        sdk_env = get_effective_sdk_env()
+
+        # Determine model from SDK env (provider-aware) or fallback to env/default
+        model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
 
         # Build MCP servers config for feature creation
         mcp_servers = {
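All three chat sessions now resolve the model the same way: prefer the provider-aware value that get_effective_sdk_env() placed in the env dict, then the process environment, then a hard-coded default. A one-function sketch of that precedence (resolve_model is an illustrative name, not a real helper):

```python
import os

def resolve_model(sdk_env: dict[str, str], default: str = "claude-opus-4-5-20251101") -> str:
    """Provider-aware model first, then the process env, then the built-in default."""
    return (
        sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL")
        or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL")
        or default
    )

print(resolve_model({"ANTHROPIC_DEFAULT_OPUS_MODEL": "kimi-k2.5"}))  # kimi-k2.5
print(resolve_model({}))  # falls back to the env var or the default
```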
diff --git a/server/services/process_manager.py b/server/services/process_manager.py
index 32a6a7e3..352a8146 100644
--- a/server/services/process_manager.py
+++ b/server/services/process_manager.py
@@ -227,6 +227,46 @@ def _remove_lock(self) -> None:
         """Remove lock file."""
         self.lock_file.unlink(missing_ok=True)
 
+    def _cleanup_stale_features(self) -> None:
+        """Clear in_progress flag for all features when agent stops/crashes.
+
+        When the agent process exits (normally or crash), any features left
+        with in_progress=True were being worked on and didn't complete.
+        Reset them so they can be picked up on next agent start.
+        """
+        try:
+            from autoforge_paths import get_features_db_path
+            features_db = get_features_db_path(self.project_dir)
+            if not features_db.exists():
+                return
+
+            from sqlalchemy import create_engine
+            from sqlalchemy.orm import sessionmaker
+
+            from api.database import Feature
+
+            engine = create_engine(f"sqlite:///{features_db}")
+            Session = sessionmaker(bind=engine)
+            session = Session()
+            try:
+                stuck = session.query(Feature).filter(
+                    Feature.in_progress == True,  # noqa: E712
+                    Feature.passes == False,  # noqa: E712
+                ).all()
+                if stuck:
+                    for f in stuck:
+                        f.in_progress = False
+                    session.commit()
+                    logger.info(
+                        "Cleaned up %d stuck feature(s) for %s",
+                        len(stuck), self.project_name,
+                    )
+            finally:
+                session.close()
+                engine.dispose()
+        except Exception as e:
+            logger.warning("Failed to cleanup features for %s: %s", self.project_name, e)
+
     async def _broadcast_output(self, line: str) -> None:
         """Broadcast output line to all registered callbacks."""
         with self._callbacks_lock:
@@ -288,6 +328,7 @@ async def _stream_output(self) -> None:
                 self.status = "crashed"
             elif self.status == "running":
                 self.status = "stopped"
+        self._cleanup_stale_features()
         self._remove_lock()
 
     async def start(
@@ -320,6 +361,9 @@ async def start(
         if not self._check_lock():
            return False, "Another agent instance is already running for this project"
 
+        # Clean up features stuck from a previous crash/stop
+        self._cleanup_stale_features()
+
         # Store for status queries
         self.yolo_mode = yolo_mode
         self.model = model
@@ -359,12 +403,22 @@
         # stdin=DEVNULL prevents blocking if Claude CLI or child process tries to read stdin
         # CREATE_NO_WINDOW on Windows prevents console window pop-ups
         # PYTHONUNBUFFERED ensures output isn't delayed
+        # Build subprocess environment with API provider settings
+        from registry import get_effective_sdk_env
+        api_env = get_effective_sdk_env()
+        subprocess_env = {
+            **os.environ,
+            "PYTHONUNBUFFERED": "1",
+            "PLAYWRIGHT_HEADLESS": "true" if playwright_headless else "false",
+            **api_env,
+        }
+
         popen_kwargs: dict[str, Any] = {
             "stdin": subprocess.DEVNULL,
             "stdout": subprocess.PIPE,
             "stderr": subprocess.STDOUT,
             "cwd": str(self.project_dir),
-            "env": {**os.environ, "PYTHONUNBUFFERED": "1", "PLAYWRIGHT_HEADLESS": "true" if playwright_headless else "false"},
+            "env": subprocess_env,
         }
         if sys.platform == "win32":
             popen_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
@@ -425,6 +479,7 @@ async def stop(self) -> tuple[bool, str]:
                 result.children_terminated, result.children_killed
             )
 
+        self._cleanup_stale_features()
         self._remove_lock()
         self.status = "stopped"
         self.process = None
@@ -502,6 +557,7 @@ async def healthcheck(self) -> bool:
         if poll is not None:
             # Process has terminated
             if self.status in ("running", "paused"):
+                self._cleanup_stale_features()
                 self.status = "crashed"
                 self._remove_lock()
             return False
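_cleanup_stale_features opens the project's features SQLite database directly and resets anything left in_progress without passing. The standalone sketch below shows the same reset; the inline Feature model is a stand-in (the real one lives in api/database.py, and only its in_progress/passes columns are visible in this diff):

```python
from sqlalchemy import Boolean, Column, Integer, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Feature(Base):  # stand-in for api.database.Feature
    __tablename__ = "features"
    id = Column(Integer, primary_key=True)
    in_progress = Column(Boolean, default=False)
    passes = Column(Boolean, default=False)

def cleanup_stale_features(db_path: str) -> int:
    """Reset in_progress for features that never passed; return how many were reset."""
    engine = create_engine(f"sqlite:///{db_path}")
    Base.metadata.create_all(engine)  # no-op if the table already exists
    session = sessionmaker(bind=engine)()
    try:
        stuck = session.query(Feature).filter(
            Feature.in_progress == True,  # noqa: E712
            Feature.passes == False,  # noqa: E712
        ).all()
        for f in stuck:
            f.in_progress = False
        session.commit()
        return len(stuck)
    finally:
        session.close()
        engine.dispose()

if __name__ == "__main__":
    print(cleanup_stale_features("features.db"))
```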
diff --git a/server/services/spec_chat_session.py b/server/services/spec_chat_session.py
index 4c0e4852..a0639e38 100644
--- a/server/services/spec_chat_session.py
+++ b/server/services/spec_chat_session.py
@@ -19,7 +19,7 @@
 from dotenv import load_dotenv
 
 from ..schemas import ImageAttachment
-from .chat_constants import API_ENV_VARS, ROOT_DIR, make_multimodal_message
+from .chat_constants import ROOT_DIR, make_multimodal_message
 
 # Load environment variables from .env file if present
 load_dotenv()
@@ -140,16 +140,11 @@ async def start(self) -> AsyncGenerator[dict, None]:
         system_cli = shutil.which("claude")
 
         # Build environment overrides for API configuration
-        # Filter to only include vars that are actually set (non-None)
-        sdk_env: dict[str, str] = {}
-        for var in API_ENV_VARS:
-            value = os.getenv(var)
-            if value:
-                sdk_env[var] = value
-
-        # Determine model from environment or use default
-        # This allows using alternative APIs (e.g., GLM via z.ai) that may not support Claude model names
-        model = os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
+        from registry import get_effective_sdk_env
+        sdk_env = get_effective_sdk_env()
+
+        # Determine model from SDK env (provider-aware) or fallback to env/default
+        model = sdk_env.get("ANTHROPIC_DEFAULT_OPUS_MODEL") or os.getenv("ANTHROPIC_DEFAULT_OPUS_MODEL", "claude-opus-4-5-20251101")
 
         try:
             self.client = ClaudeSDKClient(
diff --git a/server/websocket.py b/server/websocket.py
index dfb4dee7..e6600643 100644
--- a/server/websocket.py
+++ b/server/websocket.py
@@ -640,9 +640,7 @@ def __init__(self):
         self._lock = asyncio.Lock()
 
     async def connect(self, websocket: WebSocket, project_name: str):
-        """Accept a WebSocket connection for a project."""
-        await websocket.accept()
-
+        """Register a WebSocket connection for a project (must already be accepted)."""
         async with self._lock:
             if project_name not in self.active_connections:
                 self.active_connections[project_name] = set()
@@ -727,16 +725,22 @@ async def project_websocket(websocket: WebSocket, project_name: str):
     - Agent status changes
     - Agent stdout/stderr lines
     """
+    # Always accept WebSocket first to avoid opaque 403 errors
+    await websocket.accept()
+
     if not validate_project_name(project_name):
+        await websocket.send_json({"type": "error", "content": "Invalid project name"})
         await websocket.close(code=4000, reason="Invalid project name")
         return
 
     project_dir = _get_project_path(project_name)
     if not project_dir:
+        await websocket.send_json({"type": "error", "content": "Project not found in registry"})
         await websocket.close(code=4004, reason="Project not found in registry")
         return
 
     if not project_dir.exists():
+        await websocket.send_json({"type": "error", "content": "Project directory not found"})
         await websocket.close(code=4004, reason="Project directory not found")
         return
 
@@ -879,8 +883,7 @@ async def on_dev_status_change(status: str):
                     break
             except json.JSONDecodeError:
                 logger.warning(f"Invalid JSON from WebSocket: {data[:100] if data else 'empty'}")
-            except Exception as e:
-                logger.warning(f"WebSocket error: {e}")
+            except Exception:
                 break
 
     finally:
diff --git a/ui/src/App.tsx b/ui/src/App.tsx
index 8a185a27..9e1b6e52 100644
--- a/ui/src/App.tsx
+++ b/ui/src/App.tsx
@@ -178,8 +178,8 @@ function App() {
       setShowAddFeature(true)
     }
-    // E : Expand project with AI (when project selected and has features)
-    if ((e.key === 'e' || e.key === 'E') && selectedProject && features &&
+    // E : Expand project with AI (when project selected, has spec and has features)
+    if ((e.key === 'e' || e.key === 'E') && selectedProject && hasSpec && features &&
         (features.pending.length + features.in_progress.length + features.done.length) > 0) {
       e.preventDefault()
       setShowExpandProject(true)
     }
@@ -239,7 +239,7 @@ function App() {
 
     window.addEventListener('keydown', handleKeyDown)
     return () => window.removeEventListener('keydown', handleKeyDown)
-  }, [selectedProject, showAddFeature, 
showExpandProject, selectedFeature, debugOpen, debugActiveTab, assistantOpen, features, showSettings, showKeyboardHelp, isSpecCreating, viewMode, showResetModal, wsState.agentStatus]) + }, [selectedProject, showAddFeature, showExpandProject, selectedFeature, debugOpen, debugActiveTab, assistantOpen, features, showSettings, showKeyboardHelp, isSpecCreating, viewMode, showResetModal, wsState.agentStatus, hasSpec]) // Combine WebSocket progress with feature data const progress = wsState.progress.total > 0 ? wsState.progress : { @@ -490,7 +490,7 @@ function App() { )} {/* Expand Project Modal - AI-powered bulk feature creation */} - {showExpandProject && selectedProject && ( + {showExpandProject && selectedProject && hasSpec && ( diff --git a/ui/src/components/KeyboardShortcutsHelp.tsx b/ui/src/components/KeyboardShortcutsHelp.tsx index 8ead81fa..aa82a425 100644 --- a/ui/src/components/KeyboardShortcutsHelp.tsx +++ b/ui/src/components/KeyboardShortcutsHelp.tsx @@ -19,7 +19,7 @@ const shortcuts: Shortcut[] = [ { key: 'D', description: 'Toggle debug panel' }, { key: 'T', description: 'Toggle terminal tab' }, { key: 'N', description: 'Add new feature', context: 'with project' }, - { key: 'E', description: 'Expand project with AI', context: 'with features' }, + { key: 'E', description: 'Expand project with AI', context: 'with spec & features' }, { key: 'A', description: 'Toggle AI assistant', context: 'with project' }, { key: 'G', description: 'Toggle Kanban/Graph view', context: 'with project' }, { key: ',', description: 'Open settings' }, diff --git a/ui/src/components/SettingsModal.tsx b/ui/src/components/SettingsModal.tsx index 0246cdd6..38a5f2f9 100644 --- a/ui/src/components/SettingsModal.tsx +++ b/ui/src/components/SettingsModal.tsx @@ -1,6 +1,8 @@ -import { Loader2, AlertCircle, Check, Moon, Sun } from 'lucide-react' -import { useSettings, useUpdateSettings, useAvailableModels } from '../hooks/useProjects' +import { useState } from 'react' +import { Loader2, AlertCircle, Check, Moon, Sun, Eye, EyeOff, ShieldCheck } from 'lucide-react' +import { useSettings, useUpdateSettings, useAvailableModels, useAvailableProviders } from '../hooks/useProjects' import { useTheme, THEMES } from '../hooks/useTheme' +import type { ProviderInfo } from '../lib/types' import { Dialog, DialogContent, @@ -17,12 +19,26 @@ interface SettingsModalProps { onClose: () => void } +const PROVIDER_INFO_TEXT: Record = { + claude: 'Default provider. Uses your Claude CLI credentials.', + kimi: 'Get an API key at kimi.com', + glm: 'Get an API key at open.bigmodel.cn', + ollama: 'Run models locally. 
Install from ollama.com', + custom: 'Connect to any OpenAI-compatible API endpoint.', +} + export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { const { data: settings, isLoading, isError, refetch } = useSettings() const { data: modelsData } = useAvailableModels() + const { data: providersData } = useAvailableProviders() const updateSettings = useUpdateSettings() const { theme, setTheme, darkMode, toggleDarkMode } = useTheme() + const [showAuthToken, setShowAuthToken] = useState(false) + const [authTokenInput, setAuthTokenInput] = useState('') + const [customModelInput, setCustomModelInput] = useState('') + const [customBaseUrlInput, setCustomBaseUrlInput] = useState('') + const handleYoloToggle = () => { if (settings && !updateSettings.isPending) { updateSettings.mutate({ yolo_mode: !settings.yolo_mode }) @@ -31,7 +47,7 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { const handleModelChange = (modelId: string) => { if (!updateSettings.isPending) { - updateSettings.mutate({ model: modelId }) + updateSettings.mutate({ api_model: modelId }) } } @@ -47,12 +63,51 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { } } + const handleProviderChange = (providerId: string) => { + if (!updateSettings.isPending) { + updateSettings.mutate({ api_provider: providerId }) + // Reset local state + setAuthTokenInput('') + setShowAuthToken(false) + setCustomModelInput('') + setCustomBaseUrlInput('') + } + } + + const handleSaveAuthToken = () => { + if (authTokenInput.trim() && !updateSettings.isPending) { + updateSettings.mutate({ api_auth_token: authTokenInput.trim() }) + setAuthTokenInput('') + setShowAuthToken(false) + } + } + + const handleSaveCustomBaseUrl = () => { + if (customBaseUrlInput.trim() && !updateSettings.isPending) { + updateSettings.mutate({ api_base_url: customBaseUrlInput.trim() }) + } + } + + const handleSaveCustomModel = () => { + if (customModelInput.trim() && !updateSettings.isPending) { + updateSettings.mutate({ api_model: customModelInput.trim() }) + setCustomModelInput('') + } + } + + const providers = providersData?.providers ?? [] const models = modelsData?.models ?? [] const isSaving = updateSettings.isPending + const currentProvider = settings?.api_provider ?? 'claude' + const currentProviderInfo: ProviderInfo | undefined = providers.find(p => p.id === currentProvider) + const isAlternativeProvider = currentProvider !== 'claude' + const showAuthField = isAlternativeProvider && currentProviderInfo?.requires_auth + const showBaseUrlField = currentProvider === 'custom' + const showCustomModelInput = currentProvider === 'custom' || currentProvider === 'ollama' return ( !open && onClose()}> - + Settings @@ -159,6 +214,146 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) {
+ {/* API Provider Selection */} +
+ +
+ {providers.map((provider) => ( + + ))} +
+

+ {PROVIDER_INFO_TEXT[currentProvider] ?? ''} +

+ + {/* Auth Token Field */} + {showAuthField && ( +
+ + {settings.api_has_auth_token && !authTokenInput && ( +
+ + Configured + +
+ )} + {(!settings.api_has_auth_token || authTokenInput) && ( +
+
+ setAuthTokenInput(e.target.value)} + placeholder="Enter API key..." + className="w-full py-1.5 px-3 pe-9 text-sm border rounded-md bg-background" + /> + +
+ +
+ )} +
+ )} + + {/* Custom Base URL Field */} + {showBaseUrlField && ( +
+ +
+ setCustomBaseUrlInput(e.target.value)} + placeholder="https://api.example.com/v1" + className="flex-1 py-1.5 px-3 text-sm border rounded-md bg-background" + /> + +
+
+ )} +
+ + {/* Model Selection */} +
+ + {models.length > 0 && ( +
+ {models.map((model) => ( + + ))} +
+ )} + {/* Custom model input for Ollama/Custom */} + {showCustomModelInput && ( +
+ setCustomModelInput(e.target.value)} + placeholder="Custom model name..." + className="flex-1 py-1.5 px-3 text-sm border rounded-md bg-background" + onKeyDown={(e) => e.key === 'Enter' && handleSaveCustomModel()} + /> + +
+ )} +
+ +
+ {/* YOLO Mode Toggle */}
@@ -195,27 +390,6 @@ export function SettingsModal({ isOpen, onClose }: SettingsModalProps) { />
- {/* Model Selection */} -
- -
- {models.map((model) => ( - - ))} -
-
- {/* Regression Agents */}
diff --git a/ui/src/hooks/useExpandChat.ts b/ui/src/hooks/useExpandChat.ts index 91508852..be632a54 100644 --- a/ui/src/hooks/useExpandChat.ts +++ b/ui/src/hooks/useExpandChat.ts @@ -107,16 +107,20 @@ export function useExpandChat({ }, 30000) } - ws.onclose = () => { + ws.onclose = (event) => { setConnectionStatus('disconnected') if (pingIntervalRef.current) { clearInterval(pingIntervalRef.current) pingIntervalRef.current = null } + // Don't retry on application-level errors (4xxx codes won't resolve on retry) + const isAppError = event.code >= 4000 && event.code <= 4999 + // Attempt reconnection if not intentionally closed if ( !manuallyDisconnectedRef.current && + !isAppError && reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current ) { diff --git a/ui/src/hooks/useProjects.ts b/ui/src/hooks/useProjects.ts index 676311cd..c2abe01a 100644 --- a/ui/src/hooks/useProjects.ts +++ b/ui/src/hooks/useProjects.ts @@ -4,7 +4,7 @@ import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query' import * as api from '../lib/api' -import type { FeatureCreate, FeatureUpdate, ModelsResponse, ProjectSettingsUpdate, Settings, SettingsUpdate } from '../lib/types' +import type { FeatureCreate, FeatureUpdate, ModelsResponse, ProjectSettingsUpdate, ProvidersResponse, Settings, SettingsUpdate } from '../lib/types' // ============================================================================ // Projects @@ -268,6 +268,27 @@ const DEFAULT_SETTINGS: Settings = { testing_agent_ratio: 1, playwright_headless: true, batch_size: 3, + api_provider: 'claude', + api_base_url: null, + api_has_auth_token: false, + api_model: null, +} + +const DEFAULT_PROVIDERS: ProvidersResponse = { + providers: [ + { id: 'claude', name: 'Claude (Anthropic)', base_url: null, models: DEFAULT_MODELS.models, default_model: 'claude-opus-4-5-20251101', requires_auth: false }, + ], + current: 'claude', +} + +export function useAvailableProviders() { + return useQuery({ + queryKey: ['available-providers'], + queryFn: api.getAvailableProviders, + staleTime: 300000, + retry: 1, + placeholderData: DEFAULT_PROVIDERS, + }) } export function useAvailableModels() { @@ -319,6 +340,8 @@ export function useUpdateSettings() { }, onSettled: () => { queryClient.invalidateQueries({ queryKey: ['settings'] }) + queryClient.invalidateQueries({ queryKey: ['available-models'] }) + queryClient.invalidateQueries({ queryKey: ['available-providers'] }) }, }) } diff --git a/ui/src/hooks/useSpecChat.ts b/ui/src/hooks/useSpecChat.ts index b2bac628..3bd09bb2 100644 --- a/ui/src/hooks/useSpecChat.ts +++ b/ui/src/hooks/useSpecChat.ts @@ -157,15 +157,18 @@ export function useSpecChat({ }, 30000) } - ws.onclose = () => { + ws.onclose = (event) => { setConnectionStatus('disconnected') if (pingIntervalRef.current) { clearInterval(pingIntervalRef.current) pingIntervalRef.current = null } + // Don't retry on application-level errors (4xxx codes won't resolve on retry) + const isAppError = event.code >= 4000 && event.code <= 4999 + // Attempt reconnection if not intentionally closed - if (reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current) { + if (!isAppError && reconnectAttempts.current < maxReconnectAttempts && !isCompleteRef.current) { reconnectAttempts.current++ const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 10000) reconnectTimeoutRef.current = window.setTimeout(connect, delay) diff --git a/ui/src/hooks/useWebSocket.ts b/ui/src/hooks/useWebSocket.ts index 1a444359..b9c0a3fe 100644 --- 
a/ui/src/hooks/useWebSocket.ts
+++ b/ui/src/hooks/useWebSocket.ts
@@ -335,10 +335,14 @@ export function useProjectWebSocket(projectName: string | null) {
       }
     }
 
-    ws.onclose = () => {
+    ws.onclose = (event) => {
       setState(prev => ({ ...prev, isConnected: false }))
       wsRef.current = null
 
+      // Don't retry on application-level errors (4xxx codes won't resolve on retry)
+      const isAppError = event.code >= 4000 && event.code <= 4999
+      if (isAppError) return
+
       // Exponential backoff reconnection
       const delay = Math.min(1000 * Math.pow(2, reconnectAttempts.current), 30000)
       reconnectAttempts.current++
diff --git a/ui/src/lib/api.ts b/ui/src/lib/api.ts
index 48ce30a8..10b577b4 100644
--- a/ui/src/lib/api.ts
+++ b/ui/src/lib/api.ts
@@ -24,6 +24,7 @@ import type {
   Settings,
   SettingsUpdate,
   ModelsResponse,
+  ProvidersResponse,
   DevServerStatusResponse,
   DevServerConfig,
   TerminalInfo,
@@ -399,6 +400,10 @@ export async function getAvailableModels(): Promise<ModelsResponse> {
   return fetchJSON('/settings/models')
 }
 
+export async function getAvailableProviders(): Promise<ProvidersResponse> {
+  return fetchJSON('/settings/providers')
+}
+
 export async function getSettings(): Promise<Settings> {
   return fetchJSON('/settings')
 }
diff --git a/ui/src/lib/types.ts b/ui/src/lib/types.ts
index cec91ec8..b75d6146 100644
--- a/ui/src/lib/types.ts
+++ b/ui/src/lib/types.ts
@@ -525,6 +525,20 @@ export interface ModelsResponse {
   default: string
 }
 
+export interface ProviderInfo {
+  id: string
+  name: string
+  base_url: string | null
+  models: ModelInfo[]
+  default_model: string
+  requires_auth: boolean
+}
+
+export interface ProvidersResponse {
+  providers: ProviderInfo[]
+  current: string
+}
+
 export interface Settings {
   yolo_mode: boolean
   model: string
@@ -533,6 +547,10 @@ export interface Settings {
   testing_agent_ratio: number // Regression testing agents (0-3)
   playwright_headless: boolean
   batch_size: number // Features per coding agent batch (1-3)
+  api_provider: string
+  api_base_url: string | null
+  api_has_auth_token: boolean
+  api_model: string | null
 }
 
 export interface SettingsUpdate {
@@ -541,6 +559,10 @@
   testing_agent_ratio?: number
   playwright_headless?: boolean
   batch_size?: number
+  api_provider?: string
+  api_base_url?: string
+  api_auth_token?: string
+  api_model?: string
 }
 
 export interface ProjectSettingsUpdate {
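The hooks above now share one close-handling rule: application-level close codes (4000-4999) are treated as permanent and never retried, while other closes reconnect with capped exponential backoff. A small Python mirror of that decision logic; the attempt cap is a parameter here because the hooks' actual maxReconnectAttempts value is not shown in this diff:

```python
def should_reconnect(close_code: int, attempts: int,
                     max_attempts: int = 5, completed: bool = False) -> bool:
    """Mirror of the hooks' rule: never retry 4xxx application-level closes."""
    if 4000 <= close_code <= 4999:
        return False
    return not completed and attempts < max_attempts

def backoff_delay_ms(attempts: int, cap_ms: int = 10_000) -> int:
    """1s, 2s, 4s, ... capped (10s in the chat hooks, 30s in useProjectWebSocket)."""
    return min(1000 * 2 ** attempts, cap_ms)

for code in (1006, 4004):
    print(code, should_reconnect(code, attempts=1), backoff_delay_ms(1))
# 1006 True 2000   <- network-style close: retry with backoff
# 4004 False 2000  <- "project not found": retrying will not help
```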