From f1b3b344e80f64d31e2fe8ea777edceef5e3028c Mon Sep 17 00:00:00 2001 From: "elasticdotventures (aider)" Date: Fri, 23 May 2025 06:50:17 +0000 Subject: [PATCH] feat: Add support for non-OpenAI AI providers with flexible configuration This commit introduces comprehensive support for multiple AI providers in surfkit: - Added `--provider` and `--provider-base-url` flags to CLI commands - Created `providers.py` with predefined and custom provider configurations - Updated `config.py` to support provider-specific settings - Added `surfkit providers` command to list available providers - Implemented flexible environment variable and configuration file handling - Added documentation for provider configuration in `providers.md` --- surfkit/cli/main.py | 75 ++++++++++++++- surfkit/cli/templates/agent.py | 3 - surfkit/config.py | 2 + surfkit/db/models.py | 5 +- surfkit/docs/providers.md | 94 ++++++++++++++++++ surfkit/env_opts.py | 1 - surfkit/examples/config.json | 26 +++++ surfkit/learn/base.py | 1 - surfkit/providers.py | 82 ++++++++++++++++ surfkit/runtime/agent/kube.py | 34 ++++--- surfkit/runtime/container/base.py | 1 - surfkit/skill.py | 152 +++++++++++++++++++----------- 12 files changed, 394 insertions(+), 82 deletions(-) create mode 100644 surfkit/docs/providers.md create mode 100644 surfkit/examples/config.json create mode 100644 surfkit/providers.py diff --git a/surfkit/cli/main.py b/surfkit/cli/main.py index 6fc7045..4867b95 100644 --- a/surfkit/cli/main.py +++ b/surfkit/cli/main.py @@ -387,6 +387,15 @@ def create_agent( local_keys: bool = typer.Option( False, "--local-keys", "-l", help="Use local API keys." 
), + provider: Optional[str] = typer.Option( + None, + "--provider", + "-p", + help="AI provider to use (openai, openrouter, azure, etc.)", + ), + provider_base_url: Optional[str] = typer.Option( + None, "--provider-base-url", help="Base URL for the AI provider API" + ), debug: bool = typer.Option(False, help="Run the agent with debug logging"), ): from surfkit.server.models import V1AgentType @@ -476,13 +485,42 @@ def create_agent( name = instance_name(agent_type) env_vars = find_envs(agent_type, use_local=local_keys) + + # Add provider configuration to environment variables if specified + if provider: + from surfkit.providers import ProviderConfig + from surfkit.config import GlobalConfig + + # Get provider configuration + provider_config = ProviderConfig.get_provider(provider.lower()) + + # Set provider in global config + config = GlobalConfig.read() + config.provider = provider.lower() + if provider_base_url: + config.provider_base_url = provider_base_url + config.write() + + # Add provider info to environment variables + env_vars["SURFKIT_PROVIDER"] = provider.lower() + if provider_base_url: + env_vars[f"{provider_config.env_key.split('_API_KEY')[0]}_BASE_URL"] = ( + provider_base_url + ) + elif provider_config.base_url: + env_vars[f"{provider_config.env_key.split('_API_KEY')[0]}_BASE_URL"] = ( + provider_config.base_url + ) + if type: + provider_info = f" with provider '{provider}'" if provider else "" typer.echo( - f"Running agent '{type}' with runtime '{runtime}' and name '{name}'..." + f"Running agent '{type}' with runtime '{runtime}' and name '{name}'{provider_info}..." ) else: + provider_info = f" with provider '{provider}'" if provider else "" typer.echo( - f"Running agent '{file}' with runtime '{runtime}' and name '{name}'..." + f"Running agent '{file}' with runtime '{runtime}' and name '{name}'{provider_info}..." 
) try: @@ -863,6 +901,31 @@ def find(help="Find an agent"): list_types() +@app.command("providers") +def list_providers(): + """List available AI providers""" + from surfkit.providers import PROVIDERS + + table = [] + for provider_id, provider in PROVIDERS.items(): + table.append( + [ + provider_id, + provider["name"], + provider["base_url"] or "Custom URL required", + provider["env_key"], + ] + ) + + print( + tabulate( + table, + headers=["ID", "Name", "Base URL", "API Key Environment Variable"], + ) + ) + print("") + + @list_group.command("tasks") def list_tasks( remote: Optional[str] = typer.Option( @@ -1462,6 +1525,12 @@ def solve( local_keys: bool = typer.Option( False, "--local-keys", "-l", help="Use local API keys." ), + provider: Optional[str] = typer.Option( + None, "--provider", help="AI provider to use (openai, openrouter, azure, etc.)" + ), + provider_base_url: Optional[str] = typer.Option( + None, "--provider-base-url", help="Base URL for the AI provider API" + ), debug: bool = typer.Option(False, help="Run the agent with debug logging"), ): from surfkit.client import solve @@ -1486,6 +1555,8 @@ def solve( starting_url=starting_url, auth_enabled=auth_enabled, local_keys=local_keys, + provider=provider, + provider_base_url=provider_base_url, debug=debug, interactive=True, create_tracker=False, diff --git a/surfkit/cli/templates/agent.py b/surfkit/cli/templates/agent.py index 9b83d53..607779f 100644 --- a/surfkit/cli/templates/agent.py +++ b/surfkit/cli/templates/agent.py @@ -159,7 +159,6 @@ def generate_pyproject(agent_name: str, description, git_user_ref: str) -> None: def generate_agentfile( name: str, description: str, image_repo: str, icon_url: str ) -> None: - out = f""" api_version: v1 kind: TaskAgent @@ -200,7 +199,6 @@ def generate_agentfile( def generate_gitignore() -> None: - out = """ # Byte-compiled / optimized / DLL files __pycache__/ @@ -380,7 +378,6 @@ def generate_gitignore() -> None: def generate_readme(agent_name: str, description: 
str) -> None: - out = f"""# {agent_name} {description} diff --git a/surfkit/config.py b/surfkit/config.py index 391fc88..f74b158 100644 --- a/surfkit/config.py +++ b/surfkit/config.py @@ -41,6 +41,8 @@ class GlobalConfig: api_key: Optional[str] = None hub_address: str = AGENTSEA_HUB_URL + provider: Optional[str] = None + provider_base_url: Optional[str] = None def write(self) -> None: home = os.path.expanduser("~") diff --git a/surfkit/db/models.py b/surfkit/db/models.py index a167dec..d1fffb5 100644 --- a/surfkit/db/models.py +++ b/surfkit/db/models.py @@ -5,12 +5,13 @@ from sqlalchemy.orm import declarative_base from sqlalchemy.inspection import inspect + def to_dict(instance): return { - c.key: getattr(instance, c.key) - for c in inspect(instance).mapper.column_attrs + c.key: getattr(instance, c.key) for c in inspect(instance).mapper.column_attrs } + Base = declarative_base() diff --git a/surfkit/docs/providers.md b/surfkit/docs/providers.md new file mode 100644 index 0000000..12e05dc --- /dev/null +++ b/surfkit/docs/providers.md @@ -0,0 +1,94 @@ +# AI Provider Configuration + +Surfkit supports multiple AI providers that are compatible with the OpenAI API format. You can specify which provider to use with the `--provider` flag. 
+ +## Supported Providers + +- `openai` (default) - OpenAI API +- `openrouter` - OpenRouter API +- `azure` - Azure OpenAI API +- `gemini` - Google Gemini API +- `ollama` - Ollama API +- `mistral` - Mistral AI API +- `deepseek` - DeepSeek API +- `xai` - xAI API +- `groq` - Groq API +- `arceeai` - ArceeAI API +- Any custom provider that is compatible with the OpenAI API + +## Using a Provider + +You can specify a provider when creating an agent or solving a task: + +```bash +# Create an agent using OpenRouter +surfkit create agent --provider openrouter + +# Solve a task using Azure OpenAI +surfkit solve "Create a simple web app" --provider azure --provider-base-url "https://your-resource.openai.azure.com/openai" +``` + +## Environment Variables + +For each provider, you need to set the corresponding API key as an environment variable: + +```bash +# OpenAI (default) +export OPENAI_API_KEY="your-api-key-here" + +# OpenRouter +export OPENROUTER_API_KEY="your-openrouter-key-here" + +# Azure OpenAI +export AZURE_OPENAI_API_KEY="your-azure-api-key-here" +export AZURE_OPENAI_API_VERSION="2023-05-15" # Optional +``` + +## Custom Providers + +For custom providers not in the predefined list, you need to specify both the API key and base URL: + +```bash +# Set environment variables for a custom provider +export CUSTOM_API_KEY="your-custom-api-key" +export CUSTOM_BASE_URL="https://your-custom-api-endpoint.com/v1" + +# Use the custom provider +surfkit solve "Create a simple web app" --provider custom +``` + +## Configuration File + +You can also configure providers in a JSON configuration file at `~/.surfkit/config.json`: + +```json +{ + "model": "gpt-4o", + "provider": "openrouter", + "providers": { + "openai": { + "name": "OpenAI", + "baseURL": "https://api.openai.com/v1", + "envKey": "OPENAI_API_KEY" + }, + "openrouter": { + "name": "OpenRouter", + "baseURL": "https://openrouter.ai/api/v1", + "envKey": "OPENROUTER_API_KEY" + }, + "custom": { + "name": "Custom Provider", + 
"baseURL": "https://your-custom-api-endpoint.com/v1", + "envKey": "CUSTOM_API_KEY" + } + } +} +``` + +## Listing Available Providers + +To see all available providers and their configuration: + +```bash +surfkit providers +``` diff --git a/surfkit/env_opts.py b/surfkit/env_opts.py index 55704ab..cec3037 100644 --- a/surfkit/env_opts.py +++ b/surfkit/env_opts.py @@ -9,7 +9,6 @@ def find_local_llm_keys(typ: AgentType) -> Optional[dict]: - env_vars = None if typ.llm_providers and typ.llm_providers.preference: diff --git a/surfkit/examples/config.json b/surfkit/examples/config.json new file mode 100644 index 0000000..7556d64 --- /dev/null +++ b/surfkit/examples/config.json @@ -0,0 +1,26 @@ +{ + "model": "gpt-4o", + "provider": "openrouter", + "providers": { + "openai": { + "name": "OpenAI", + "baseURL": "https://api.openai.com/v1", + "envKey": "OPENAI_API_KEY" + }, + "openrouter": { + "name": "OpenRouter", + "baseURL": "https://openrouter.ai/api/v1", + "envKey": "OPENROUTER_API_KEY" + }, + "azure": { + "name": "Azure OpenAI", + "baseURL": "https://your-resource-name.openai.azure.com/openai/deployments/your-deployment-name", + "envKey": "AZURE_OPENAI_API_KEY" + }, + "custom": { + "name": "Custom Provider", + "baseURL": "https://your-custom-api-endpoint.com/v1", + "envKey": "CUSTOM_API_KEY" + } + } +} diff --git a/surfkit/learn/base.py b/surfkit/learn/base.py index 3f2d50b..b2db9da 100644 --- a/surfkit/learn/base.py +++ b/surfkit/learn/base.py @@ -2,7 +2,6 @@ class Teacher(ABC): - @abstractmethod def teach(self, *args, **kwargs): pass diff --git a/surfkit/providers.py b/surfkit/providers.py new file mode 100644 index 0000000..3869ff2 --- /dev/null +++ b/surfkit/providers.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Dict, Optional + +# Default provider configurations +PROVIDERS = { + "openai": { + "name": "OpenAI", + "base_url": "https://api.openai.com/v1", + "env_key": "OPENAI_API_KEY", + }, + 
"openrouter": { + "name": "OpenRouter", + "base_url": "https://openrouter.ai/api/v1", + "env_key": "OPENROUTER_API_KEY", + }, + "azure": { + "name": "Azure OpenAI", + "base_url": None, # Must be provided by user + "env_key": "AZURE_OPENAI_API_KEY", + }, + "gemini": { + "name": "Google Gemini", + "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", + "env_key": "GEMINI_API_KEY", + }, + "ollama": { + "name": "Ollama", + "base_url": "http://localhost:11434/v1", + "env_key": "OLLAMA_API_KEY", + }, + "mistral": { + "name": "Mistral AI", + "base_url": "https://api.mistral.ai/v1", + "env_key": "MISTRAL_API_KEY", + }, + "deepseek": { + "name": "DeepSeek", + "base_url": "https://api.deepseek.com", + "env_key": "DEEPSEEK_API_KEY", + }, + "xai": {"name": "xAI", "base_url": "https://api.x.ai/v1", "env_key": "XAI_API_KEY"}, + "groq": { + "name": "Groq", + "base_url": "https://api.groq.com/openai/v1", + "env_key": "GROQ_API_KEY", + }, + "arceeai": { + "name": "ArceeAI", + "base_url": "https://conductor.arcee.ai/v1", + "env_key": "ARCEEAI_API_KEY", + }, +} + + +@dataclass +class ProviderConfig: + name: str + base_url: Optional[str] + env_key: str + + @classmethod + def get_provider(cls, provider_name: str) -> ProviderConfig: + """Get provider configuration by name""" + if provider_name not in PROVIDERS: + # For custom providers, use the name as the env key prefix + return ProviderConfig( + name=provider_name.capitalize(), + base_url=os.environ.get(f"{provider_name.upper()}_BASE_URL"), + env_key=f"{provider_name.upper()}_API_KEY", + ) + + config = PROVIDERS[provider_name] + return ProviderConfig( + name=config["name"], base_url=config["base_url"], env_key=config["env_key"] + ) + + def get_api_key(self) -> Optional[str]: + """Get the API key for this provider from environment variables""" + return os.environ.get(self.env_key) diff --git a/surfkit/runtime/agent/kube.py b/surfkit/runtime/agent/kube.py index 1445ced..549497f 100644 --- a/surfkit/runtime/agent/kube.py 
+++ b/surfkit/runtime/agent/kube.py @@ -213,7 +213,7 @@ def create( auth_enabled: bool = True, labels: Optional[Dict[str, str]] = None, check_http_health: bool = True, - extra_spec: Optional[Dict[str, str]] = None + extra_spec: Optional[Dict[str, str]] = None, ) -> None: if not name: name = get_random_name("-") @@ -263,9 +263,7 @@ def create( # Pod specification pod_spec = client.V1PodSpec( - containers=[container], - restart_policy="Never", - **(extra_spec or {}) + containers=[container], restart_policy="Never", **(extra_spec or {}) ) # Pod creation @@ -281,17 +279,23 @@ def create( "agent_type": type.name, "agent_model": type.to_v1().model_dump_json(), "openmeter.io/subject": owner_id, - "data.openmeter.io/desktop_id": labels["desktop_id"] - if labels and "desktop_id" in labels - else "undefined", - "data.openmeter.io/timeout": labels["timeout"] - if labels and "timeout" in labels - else "undefined", + "data.openmeter.io/desktop_id": ( + labels["desktop_id"] + if labels and "desktop_id" in labels + else "undefined" + ), + "data.openmeter.io/timeout": ( + labels["timeout"] + if labels and "timeout" in labels + else "undefined" + ), "data.openmeter.io/agent_type": type.name, "data.openmeter.io/agent_name": name, - "data.openmeter.io/task_id": labels["task_id"] - if labels and "task_id" in labels - else "undefined", + "data.openmeter.io/task_id": ( + labels["task_id"] + if labels and "task_id" in labels + else "undefined" + ), "data.openmeter.io/workload": "agent", }, ), @@ -964,7 +968,7 @@ def run( auth_enabled: bool = True, debug: bool = False, check_http_health: bool = True, - extra_spec: Optional[Dict[str, str]] = None + extra_spec: Optional[Dict[str, str]] = None, ) -> AgentInstance: logger.debug("creating agent with type: ", agent_type.__dict__) if not agent_type.versions: @@ -1019,7 +1023,7 @@ def run( auth_enabled=auth_enabled, labels=labels, check_http_health=check_http_health, - extra_spec=extra_spec + extra_spec=extra_spec, ) return AgentInstance( diff 
--git a/surfkit/runtime/container/base.py b/surfkit/runtime/container/base.py index e5250bb..0b7c6ee 100644 --- a/surfkit/runtime/container/base.py +++ b/surfkit/runtime/container/base.py @@ -8,7 +8,6 @@ class ContainerRuntime(Generic[C, R], ABC): - @classmethod def name(cls) -> str: return cls.__name__ diff --git a/surfkit/skill.py b/surfkit/skill.py index 2eeb3cc..1996c6b 100644 --- a/surfkit/skill.py +++ b/surfkit/skill.py @@ -13,7 +13,12 @@ from sqlalchemy.orm import joinedload from taskara import ReviewRequirement, Task, TaskStatus from threadmem import RoleThread -from skillpacks.db.models import ActionRecord, EpisodeRecord, ReviewRecord, action_reviews +from skillpacks.db.models import ( + ActionRecord, + EpisodeRecord, + ReviewRecord, + action_reviews, +) from taskara.db.conn import get_db as get_task_DB from taskara.db.models import TaskRecord, LabelRecord, task_label_association from surfkit.db.conn import WithDB @@ -235,7 +240,7 @@ def from_record_with_tasks(cls, record: SkillRecord, tasks: List[Task]) -> "Skil start_time = time.time() # We aren't using threads right now # thread_ids = json.loads(str(record.threads)) - threads = [] # [RoleThread.find(id=thread_id)[0] for thread_id in thread_ids] + threads = [] # [RoleThread.find(id=thread_id)[0] for thread_id in thread_ids] example_tasks = json.loads(str(record.example_tasks)) requirements = json.loads(str(record.requirements)) @@ -383,8 +388,13 @@ def find( task_map: defaultdict[str, list[Task]] = defaultdict(list) for task in tasks: - task_map[task.skill].append(task) # type: ignore - out.extend([cls.from_record_with_tasks(record, task_map[str(record.id)]) for record in records]) + task_map[task.skill].append(task) # type: ignore + out.extend( + [ + cls.from_record_with_tasks(record, task_map[str(record.id)]) + for record in records + ] + ) print( f"skills from_record ran time lapsed: {(time.time() - start_time):.4f}", flush=True, @@ -426,8 +436,13 @@ def find_many( task_map: defaultdict[str, 
list[Task]] = defaultdict(list) for task in tasks: - task_map[task.skill].append(task) # type: ignore - out.extend([cls.from_record_with_tasks(record, task_map[str(record.id)]) for record in records]) + task_map[task.skill].append(task) # type: ignore + out.extend( + [ + cls.from_record_with_tasks(record, task_map[str(record.id)]) + for record in records + ] + ) print( f"skills from_record ran time lapsed: {(time.time() - start_time):.4f}", flush=True, @@ -840,7 +855,6 @@ def delete(self, owner_id: str): def find_skills_for_task_gen(cls) -> list[SkillsWithGenTasks]: skill_records = [] for skill_session in cls.get_db(): - # Query all skills needing tasks skill_records = ( skill_session.query(SkillRecord.id, SkillRecord.demo_queue_size) @@ -925,7 +939,7 @@ def find_skills_for_task_gen(cls) -> list[SkillsWithGenTasks]: ) return results - + @classmethod def stop_failing_agent_tasks(cls, timestamp: float | None = None) -> list[str]: if timestamp is None: @@ -938,17 +952,19 @@ def stop_failing_agent_tasks(cls, timestamp: float | None = None) -> list[str]: task_session.query( TaskRecord.skill.label("skill_id"), func.count().label("task_count"), - ).join( - TaskRecord.labels - .and_(TaskRecord.completed > timestamp) + ) + .join( + TaskRecord.labels.and_(TaskRecord.completed > timestamp) .and_(LabelRecord.key == "can_review") .and_(LabelRecord.value == "false") .and_( - TaskRecord.status.in_([ - TaskStatus.ERROR.value, - TaskStatus.FAILED.value, - TaskStatus.TIMED_OUT.value - ]) + TaskRecord.status.in_( + [ + TaskStatus.ERROR.value, + TaskStatus.FAILED.value, + TaskStatus.TIMED_OUT.value, + ] + ) ) ) .group_by(TaskRecord.skill) @@ -958,11 +974,14 @@ def stop_failing_agent_tasks(cls, timestamp: float | None = None) -> list[str]: task_session.close() skills_with_failure_conditions = [] - print(f'stop_failing_agent_tasks got skills list {direct_rows}', flush=True) + print(f"stop_failing_agent_tasks got skills list {direct_rows}", flush=True) for skill_id, task_count in 
direct_rows: found = cls.find(id=skill_id) if not found: - print(f'stop_failing_agent_tasks: ERROR skill: {skill_id} not found #slack-alerts', flush=True) + print( + f"stop_failing_agent_tasks: ERROR skill: {skill_id} not found #slack-alerts", + flush=True, + ) continue skill = found[0] @@ -971,17 +990,22 @@ def stop_failing_agent_tasks(cls, timestamp: float | None = None) -> list[str]: continue # Skip if tasks are currently being generated, we don't want to overwrite anything - if skill.generating_tasks: + if skill.generating_tasks: continue # Get failure limit from skill's key-value store (defaults to 3) allowed_consecutive_fails = 3 - if 'fail_limit' in skill.kvs: + if "fail_limit" in skill.kvs: print(f'fail limit is {skill.kvs["fail_limit"]}', flush=True) try: - allowed_consecutive_fails = int(skill.kvs['fail_limit']) # Ensure conversion + allowed_consecutive_fails = int( + skill.kvs["fail_limit"] + ) # Ensure conversion except ValueError: - print(f"Invalid fail_limit for skill {skill.id}, using default 3", flush=True) + print( + f"Invalid fail_limit for skill {skill.id}, using default 3", + flush=True, + ) # Only proceed if enough tasks exist if len(skill.tasks) < allowed_consecutive_fails: @@ -991,44 +1015,45 @@ def stop_failing_agent_tasks(cls, timestamp: float | None = None) -> list[str]: sorted_tasks = sorted( skill.tasks, key=lambda t: t.completed or 0, # Handle None safely - reverse=True + reverse=True, ) last_tasks = sorted_tasks[:allowed_consecutive_fails] # Check if all last `allowed_consecutive_fails` tasks match the failure conditions all_failed = all( - (task.completed or 0) > 1 and # Completed after timestamp - task.assigned_type != "user" and # Assigned to agent - task.labels.get("can_review") == "false" and # Labeled as unreviewable - task.status in [TaskStatus.FAILED, TaskStatus.ERROR, TaskStatus.TIMED_OUT] + (task.completed or 0) > 1 + and task.assigned_type != "user" # Completed after timestamp + and task.labels.get("can_review") == "false" # 
Assigned to agent + and task.status # Labeled as unreviewable + in [TaskStatus.FAILED, TaskStatus.ERROR, TaskStatus.TIMED_OUT] for task in last_tasks ) if all_failed: - print(f"Skill {skill.id} has {allowed_consecutive_fails} consecutive failing tasks.", flush=True) + print( + f"Skill {skill.id} has {allowed_consecutive_fails} consecutive failing tasks.", + flush=True, + ) print(f"Failing tasks: {[task.id for task in last_tasks]}", flush=True) - + # Update skill state skill.status = SkillStatus.DEMO skill.min_demos += 1 - skill.kvs['last_agent_stop_from_failure'] = time.time() + skill.kvs["last_agent_stop_from_failure"] = time.time() skill.save() skills_with_failure_conditions.append(skill_id) return skills_with_failure_conditions - + @classmethod def find_skills_for_agent_task_gen(cls) -> list[SkillsWithGenTasks]: skill_records = [] for skill_session in cls.get_db(): - # Query all skills needing tasks skill_records = ( skill_session.query(SkillRecord.id, SkillRecord.kvs) .filter( - SkillRecord.status.in_( - [SkillStatus.TRAINING.value] - ), + SkillRecord.status.in_([SkillStatus.TRAINING.value]), SkillRecord.generating_tasks == False, # noqa: E712 ) .all() @@ -1055,31 +1080,36 @@ def find_skills_for_agent_task_gen(cls) -> list[SkillsWithGenTasks]: func.count().label("count"), ) .filter( - TaskRecord.assigned_type != 'user', - TaskRecord.reviews == '[]', # Only include tasks with an empty reviews field - TaskRecord.status.in_([ - TaskStatus.IN_QUEUE.value, - TaskStatus.TIMED_OUT.value, - TaskStatus.WAITING.value, - TaskStatus.IN_PROGRESS.value, - TaskStatus.FAILED.value, - TaskStatus.ERROR.value, - TaskStatus.REVIEW.value, - ]), + TaskRecord.assigned_type != "user", + TaskRecord.reviews + == "[]", # Only include tasks with an empty reviews field + TaskRecord.status.in_( + [ + TaskStatus.IN_QUEUE.value, + TaskStatus.TIMED_OUT.value, + TaskStatus.WAITING.value, + TaskStatus.IN_PROGRESS.value, + TaskStatus.FAILED.value, + TaskStatus.ERROR.value, + 
TaskStatus.REVIEW.value, + ] + ), ) .outerjoin( task_label_association, - TaskRecord.id == task_label_association.c.task_id + TaskRecord.id == task_label_association.c.task_id, ) .outerjoin( LabelRecord, and_( task_label_association.c.label_id == LabelRecord.id, - LabelRecord.key == 'can_review', - LabelRecord.value == 'false' - ) + LabelRecord.key == "can_review", + LabelRecord.value == "false", + ), ) - .filter(LabelRecord.id.is_(None)) # Exclude tasks with the "can_review" label set to 'false' + .filter( + LabelRecord.id.is_(None) + ) # Exclude tasks with the "can_review" label set to 'false' .group_by(TaskRecord.skill) .all() ) @@ -1096,13 +1126,21 @@ def find_skills_for_agent_task_gen(cls) -> list[SkillsWithGenTasks]: results = [] for sid in skill_ids: agent_task_queue_size = 5 - if 'agent_task_queue_size' in skill_map[sid]['kvs']: - print(f'fail limit is {skill_map[sid]["kvs"]["agent_task_queue_size"]}', flush=True) + if "agent_task_queue_size" in skill_map[sid]["kvs"]: + print( + f'fail limit is {skill_map[sid]["kvs"]["agent_task_queue_size"]}', + flush=True, + ) try: - agent_task_queue_size = int(skill_map[sid]['kvs']['agent_task_queue_size']) # Ensure conversion + agent_task_queue_size = int( + skill_map[sid]["kvs"]["agent_task_queue_size"] + ) # Ensure conversion except ValueError: - print(f"Invalid agent_task_queue_size for skill {skill_map[sid]}, using default 3", flush=True) - + print( + f"Invalid agent_task_queue_size for skill {skill_map[sid]}, using default 3", + flush=True, + ) + count_value = in_queue_counts[sid] # defaults to 0 if sid never occurred if count_value < agent_task_queue_size: results.append( @@ -1113,4 +1151,4 @@ def find_skills_for_agent_task_gen(cls) -> list[SkillsWithGenTasks]: ) ) - return results \ No newline at end of file + return results