diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 7b3543a2..8b34d3ff 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -69,7 +69,7 @@ jobs: with: claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }} - allowed_non_write_users: Copilot + allowed_non_write_users: "Copilot,claude[bot]" allowed_bots: "github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" trigger_phrase: "@claude" assignee_trigger: claude[bot] @@ -105,6 +105,8 @@ jobs: with: claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }} + allowed_non_write_users: "Copilot,claude[bot]" + allowed_bots: "github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" trigger_phrase: "@claude" assignee_trigger: claude label_trigger: claude @@ -140,6 +142,8 @@ jobs: with: claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }} + allowed_non_write_users: "Copilot,claude[bot]" + allowed_bots: "github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" trigger_phrase: "@claude" assignee_trigger: claude label_trigger: claude @@ -177,6 +181,8 @@ jobs: with: claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} github_token: ${{ secrets.GITHUB_TOKEN }} + allowed_non_write_users: "Copilot,claude[bot]" + allowed_bots: "github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" trigger_phrase: "@claude" assignee_trigger: claude label_trigger: claude diff --git a/.github/workflows/claude.yml.orig b/.github/workflows/claude.yml.orig new file mode 100644 index 00000000..eab258ff --- /dev/null +++ b/.github/workflows/claude.yml.orig @@ -0,0 +1,212 @@ +# SPDX-FileCopyrightText: 2025 Knitli Inc. 
+# SPDX-FileContributor: Adam Poulemanos +# +# SPDX-License-Identifier: MIT OR Apache-2.0 + +name: Claude Assistant +on: + issue_comment: + types: + - created + pull_request_review_comment: + types: + - created + issues: + types: + - opened + - assigned + - labeled + pull_request_review: + types: + - submitted + +permissions: + actions: read + checks: read + contents: write + discussions: read + id-token: write + issues: write + pull-requests: write + +env: + CODEWEAVER_VECTOR_STORE_URL: ${{ secrets.CODEWEAVER_VECTOR_STORE_URL }} + QDRANT__SERVICE__API_KEY: ${{ secrets.QDRANT__SERVICE__API_KEY }} + VOYAGE_API_KEY: ${{ secrets.VOYAGE_API_KEY }} + CODEWEAVER_PROJECT_PATH: ${{ github.workspace }} + MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + MISE_ENV: "dev" + MISE_YES: 1 + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + # Limit CodeQL memory usage to prevent OOM kills on standard runners (7GB total RAM) + # Note: Standard runners have ~7GB, so limit to 3GB to leave headroom for system + CODEQL_RAM: 3072 + CODEQL_THREADS: 2 + # Additional limits for CodeQL extractor and evaluator + CODEQL_EXTRACTOR_PYTHON_RAM: 2048 + CODEQL_EVALUATOR_RAM: 2048 +jobs: + claude-response: + runs-on: ubuntu-latest + steps: + - name: Checkout Repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 + with: + fetch-depth: 2 + token: ${{ secrets.GITHUB_TOKEN }} + - name: Setup Environment for Review + if: github.event_name == 'pull_request_review' + uses: ./.github/actions/setup-mise-env + with: + python-version: "3.13" + github-token: ${{ secrets.GITHUB_TOKEN }} + profile: reviewer + skip-checkout: "true" + + - name: PR Review + if: github.event_name == 'pull_request_review' + uses: anthropics/claude-code-action@e8bad572273ce919ba15fec95aef0ce974464753 + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + github_token: ${{ secrets.GITHUB_TOKEN }} + allowed_non_write_users: "*" + allowed_bots: 
"github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" + trigger_phrase: "@claude" + assignee_trigger: claude[bot] + label_trigger: claude + base_branch: main + use_commit_signing: true + claude_args: | + --max-turns 100 + --allowedTools Bash,TaskOutput,Batch,Glob,Grep,MCPSearch,Read,Skill,TaskCreate,TaskGet,TaskList,TaskUpdate,WebFetch,WebSearch,mcp__context7__get-library-docs,mcp__context7__resolve-library-id,mcp__sequential-thinking__sequentialthinking,mcp__github_ci__get_workflow_run_details,mcp__codeweaver__find_code,mcp__tavily__tavily_search, + --mcp-config .mcp.json + prompt: | + REPO: ${{ github.repository }} + PR NUMBER: ${{ github.event.pull_request.number }} + + Please review this pull request and identify: + - bugs + - security issues and potential vulnerabilities + - performance issues + If you identify issues, briefly describe them. Provide a recommended fix with example implementation. + + Keep your feedback focused, actionable, and concise. 
+ - name: Setup Environment for Issues + if: github.event_name == 'issues' || github.event.action == 'opened' + uses: ./.github/actions/setup-mise-env + with: + python-version: "3.13" + github-token: ${{ secrets.GITHUB_TOKEN }} + profile: dev + skip-checkout: "true" + - name: Issue Opened + if: github.event_name == 'issues' && github.event.action == 'opened' + uses: anthropics/claude-code-action@e8bad572273ce919ba15fec95aef0ce974464753 + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + github_token: ${{ secrets.GITHUB_TOKEN }} + trigger_phrase: "@claude" + assignee_trigger: claude + label_trigger: claude + track_progress: true + use_commit_signing: true + base_branch: main + claude_args: | + --max-turns 100 + --allowedTools Bash,TaskOutput,Edit,Glob,Grep,KillShell,MCPSearch,Read,Skill,Agent,TaskCreate,TaskGet,TaskList,TaskUpdate,WebFetch,WebSearch,Write,mcp__context7__get-library-docs,mcp__context7__resolve-library-id,mcp__sequential-thinking__sequentialthinking,mcp__github_ci__get_workflow_run_details,mcp__codeweaver__find_code,WebFetch,WebSearch,mcp__tavily__tavily_search + --mcp-config .mcp.json + prompt: | + REPO: ${{ github.repository }} + ISSUE NUMBER: ${{ github.event.issue.number }} + When a new issue is opened: + - Review and summarize the issue. + - Include any relevant context or background. + - Look for related issues or discussions and link to them. + - Assign relevant labels, or if you can't assign them, suggest them. + - If the issue covers the same topic as an existing open or closed issue, recommend closing the issue and linking to the relevant PR or issue. + - Identify potential fixes and briefly describe them with links to relevant code. + - If it's a feature request, estimate the difficulty of implementing the feature and potential impact on existing functionality and API. 
+ - name: Setup Environment for PR Review Comments + if: github.event_name == 'pull_request_review_comment' + uses: ./.github/actions/setup-mise-env + with: + python-version: "3.13" + github-token: ${{ secrets.GITHUB_TOKEN }} + profile: minimal + skip-checkout: "true" + - name: PR Review Comment + if: github.event_name == 'pull_request_review_comment' + uses: anthropics/claude-code-action@e8bad572273ce919ba15fec95aef0ce974464753 + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + github_token: ${{ secrets.GITHUB_TOKEN }} + trigger_phrase: "@claude" + assignee_trigger: claude + label_trigger: claude + use_commit_signing: true + base_branch: main + claude_args: | + --max-turns 50 + --allowedTools Bash,TaskOutput,Glob,Grep,KillShell,MCPSearch,Read,Skill,Agent,TaskCreate,TaskGet,TaskList,TaskUpdate,WebFetch,WebSearch,mcp__context7__get-library-docs,mcp__context7__resolve-library-id,mcp__sequential-thinking__sequentialthinking,mcp__github_ci__get_workflow_run_details,mcp__codeweaver__find_code,mcp__tavily__tavily_search,WebFetch,WebSearch + --mcp-config .mcp.json + prompt: | + REPO: ${{ github.repository }} + PR NUMBER: ${{ github.event.pull_request.number }} + COMMENT ID: ${{ github.event.comment.id }} + When you are asked to review a pull request: + - Review the changes made in the PR. + - Provide feedback on the code quality, functionality, and adherence to best practices. + - Consider the library's existing code style and whether the code aligns with it. + - Consider possible security or performance effects. + - If the code does not follow APIs as you would expect, remember that you have access to the context7 tool to look up library documentation. APIs may have changed since your training data. + - Suggest improvements or alternatives where applicable. + - If the changes are satisfactory and the code passes checks, approve the PR with a comment. 
+ - name: Setup Environment for Assigned or Labeled Issues/PRs + if: | + (github.event_name == 'issues' && github.event.action == 'assigned') || (github.event_name == 'issues' && github.event.action == 'labeled' && github.event.label.name == 'claude') || (github.event_name == 'pull_request' && github.event.action == 'labeled' && github.event.label.name == 'claude') + uses: ./.github/actions/setup-mise-env + with: + python-version: "3.13" + github-token: ${{ secrets.GITHUB_TOKEN }} + profile: dev + skip-checkout: "true" + - name: Issue Assigned or Labeled Claude + if: | + ((github.event_name == 'issues' && github.event.action == 'assigned') || (github.event_name == 'issues' && github.event.action == 'labeled' && github.event.label.name == 'claude') || (github.event_name == 'pull_request' && github.event.action == 'labeled' && github.event.label.name == 'claude')) && (github.event.label.name != 'upstream-sync') + uses: anthropics/claude-code-action@e8bad572273ce919ba15fec95aef0ce974464753 + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + github_token: ${{ secrets.GITHUB_TOKEN }} + trigger_phrase: "@claude" + assignee_trigger: claude + label_trigger: claude + track_progress: true + use_commit_signing: true + base_branch: main + claude_args: | + --max-turns 200 + --allowedTools Bash,TaskOutput,Edit,ExitPlanMode,Glob,Grep,KillShell,MCPSearch,Read,Skill,Agent,TaskCreate,TaskGet,TaskList,TaskUpdate,WebFetch,WebSearch,Write,mcp__context7__get-library-docs,mcp__context7__resolve-library-id,mcp__sequential-thinking__sequentialthinking,mcp__github_ci__get_workflow_run_details,mcp__codeweaver__find_code,WebFetch,WebSearch + --mcp-config .mcp.json + prompt: | + REPO: ${{ github.repository }} + ISSUE/PR NUMBER: ${{ github.event.issue.number || github.event.pull_request.number }} + When you are assigned an issue or it's labeled 'claude': + - Your job is to resolve it. 
+ - Gather all necessary information about the issue from discussions and comments and the codebase. + - If the issue involves external libraries, use the context7 tool to get the latest information on the API. + - Pay attention to external API versions--they may have changed since your training data, or even since the sources you research. When in doubt, use the tavily tool or web search to find the current documentation. + - Research similar issues in this repository and others to inform your approach. + - Communicate with the issue reporter for clarification if needed. + - Create an issue branch. + - Develop a detailed plan to fix the problem. + - Write your plan and information from your research to a markdown file. Continually refer to this as you work. + - Use the sequential-thinking tool to plan your actions. + - Implement the fix and test it thoroughly. + - If the fix might affect core functionality, update or add tests focused on that functionality. + - Run all pre-commit lint checks and ensure everything is formatted correctly ('hk check', 'hk fix'). + - Use conventional commits format. + - Copy your planning file into your PR and then delete it before submitting. + - Submit your changes in a pull request: + - Document your changes and the reasoning behind them. + - Provide your markdown file with the plan and research information. + - Submit your solution for review. 
diff --git a/patch_claude.patch b/patch_claude.patch new file mode 100644 index 00000000..ee7d9efc --- /dev/null +++ b/patch_claude.patch @@ -0,0 +1,29 @@ +--- .github/workflows/claude.yml ++++ .github/workflows/claude.yml +@@ -106,6 +106,8 @@ + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + github_token: ${{ secrets.GITHUB_TOKEN }} ++ allowed_non_write_users: "*" ++ allowed_bots: "github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" + trigger_phrase: "@claude" + assignee_trigger: claude + label_trigger: claude +@@ -141,6 +143,8 @@ + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + github_token: ${{ secrets.GITHUB_TOKEN }} ++ allowed_non_write_users: "*" ++ allowed_bots: "github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" + trigger_phrase: "@claude" + assignee_trigger: claude + label_trigger: claude +@@ -178,6 +182,8 @@ + with: + claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} + github_token: ${{ secrets.GITHUB_TOKEN }} ++ allowed_non_write_users: "*" ++ allowed_bots: "github-actions[bot],copilot[bot],dependabot[bot],copilot,github-actions,gemini[bot],claude[bot]" + trigger_phrase: "@claude" + assignee_trigger: claude + label_trigger: claude diff --git a/patch_config.patch b/patch_config.patch new file mode 100644 index 00000000..b5877411 --- /dev/null +++ b/patch_config.patch @@ -0,0 +1,11 @@ +--- src/codeweaver/cli/commands/config.py ++++ src/codeweaver/cli/commands/config.py +@@ -103,7 +103,7 @@ + table.add_row("Project Path", str(settings["project_path"])) + table.add_row("Project Name", settings["project_name"] or "auto-detected") + table.add_row("Token Limit", str(settings["token_limit"])) +- table.add_row("Max File Size", f"{settings['max_file_size']:,} bytes") ++ table.add_row("Max File Size", f"{settings.get('max_file_size', 0):,} bytes") + table.add_row("Max Results", str(settings["max_results"])) + 
+ # Feature flags diff --git a/patch_config2.patch b/patch_config2.patch new file mode 100644 index 00000000..8a0fe99c --- /dev/null +++ b/patch_config2.patch @@ -0,0 +1,16 @@ +--- src/codeweaver/cli/commands/config.py ++++ src/codeweaver/cli/commands/config.py +@@ -108,12 +108,12 @@ + table.add_row( + "Background Indexing", + "❌" +- if settings["indexer"].get("only_index_on_command") +- and not isinstance(settings["indexer"].get("only_index_on_command"), Unset) ++ if settings.get("indexer") and settings["indexer"].get("only_index_on_command") ++ and not isinstance(settings.get("indexer", {}).get("only_index_on_command"), Unset) + else "✅", + ) +- table.add_row("Telemetry", "❌" if settings["telemetry"].get("disable_telemetry") else "✅") ++ table.add_row("Telemetry", "❌" if settings.get("telemetry") and settings["telemetry"].get("disable_telemetry") else "✅") + + display.print_table(table) diff --git a/patch_voyage.patch b/patch_voyage.patch new file mode 100644 index 00000000..2816e185 --- /dev/null +++ b/patch_voyage.patch @@ -0,0 +1,18 @@ +--- src/codeweaver/providers/reranking/providers/voyage.py ++++ src/codeweaver/providers/reranking/providers/voyage.py +@@ -28,8 +28,11 @@ + logger = logging.getLogger(__name__) + + try: +- from voyageai.client_async import AsyncClient +- from voyageai.object.reranking import RerankingObject +- from voyageai.object.reranking import RerankingResult as VoyageRerankingResult ++ try: ++ from voyageai.client_async import AsyncClient ++ from voyageai.object.reranking import RerankingObject ++ from voyageai.object.reranking import RerankingResult as VoyageRerankingResult ++ except Exception: ++ AsyncClient, RerankingObject, VoyageRerankingResult = object, object, object + + except ImportError as e: + from codeweaver.core import ConfigurationError diff --git a/patch_voyage2.patch b/patch_voyage2.patch new file mode 100644 index 00000000..762c5af8 --- /dev/null +++ b/patch_voyage2.patch @@ -0,0 +1,18 @@ +--- 
src/codeweaver/providers/embedding/providers/voyage.py ++++ src/codeweaver/providers/embedding/providers/voyage.py +@@ -15,9 +15,12 @@ + logger = logging.getLogger(__name__) + + try: +- from voyageai.client_async import AsyncClient +- from voyageai.object.contextualized_embeddings import ContextualizedEmbeddingsObject +- from voyageai.object.embeddings import EmbeddingsObject ++ try: ++ from voyageai.client_async import AsyncClient ++ from voyageai.object.contextualized_embeddings import ContextualizedEmbeddingsObject ++ from voyageai.object.embeddings import EmbeddingsObject ++ except Exception: ++ AsyncClient, ContextualizedEmbeddingsObject, EmbeddingsObject = object, object, object + except ImportError as _import_error: + raise ConfigurationError( + 'Please install the `voyageai` package to use the Voyage provider, you can use the `voyage` optional group -- `pip install "code-weaver\\[voyage]"`' diff --git a/src/codeweaver/cli/commands/config.py b/src/codeweaver/cli/commands/config.py index 9584e402..00a57057 100644 --- a/src/codeweaver/cli/commands/config.py +++ b/src/codeweaver/cli/commands/config.py @@ -101,18 +101,18 @@ def _show_config(settings: DictView[CodeWeaverSettingsDict]) -> None: table.add_row("Project Path", str(settings["project_path"])) table.add_row("Project Name", settings["project_name"] or "auto-detected") table.add_row("Token Limit", str(settings["token_limit"])) - table.add_row("Max File Size", f"{settings['max_file_size']:,} bytes") + table.add_row("Max File Size", f"{settings.get('max_file_size', 0):,} bytes") table.add_row("Max Results", str(settings["max_results"])) # Feature flags table.add_row( "Background Indexing", "❌" - if settings["indexer"].get("only_index_on_command") - and not isinstance(settings["indexer"].get("only_index_on_command"), Unset) + if settings.get("indexer") and settings["indexer"].get("only_index_on_command") + and not isinstance(settings.get("indexer", {}).get("only_index_on_command"), Unset) else "✅", ) - 
table.add_row("Telemetry", "❌" if settings["telemetry"].get("disable_telemetry") else "✅") + table.add_row("Telemetry", "❌" if settings.get("telemetry") and settings["telemetry"].get("disable_telemetry") else "✅") display.print_table(table) diff --git a/src/codeweaver/cli/commands/config.py.orig b/src/codeweaver/cli/commands/config.py.orig new file mode 100644 index 00000000..9732d353 --- /dev/null +++ b/src/codeweaver/cli/commands/config.py.orig @@ -0,0 +1,312 @@ +# sourcery skip: avoid-global-variables, name-type-suffix, no-complex-if-expressions +# sourcery skip: avoid-global-variables, no-complex-if-expressions +# SPDX-FileCopyrightText: 2025 Knitli Inc. +# SPDX-FileContributor: Adam Poulemanos +# +# SPDX-License-Identifier: MIT OR Apache-2.0 +"""Config-related CLI commands for CodeWeaver.""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Literal + +import cyclopts + +from cyclopts import App +from pydantic import FilePath +from rich.table import Table + +from codeweaver.cli.ui import CLIErrorHandler, StatusDisplay, UserInteractionDep, get_display +from codeweaver.core.config.settings_type import CodeWeaverSettingsType +from codeweaver.core.config.types import CodeWeaverSettingsDict +from codeweaver.core.dependencies import ResolvedProjectPathDep, SettingsDep +from codeweaver.core.di import INJECTED, get_container +from codeweaver.core.types import UNSET +from codeweaver.core.utils import detect_root_package, is_codeweaver_config_path +from codeweaver.engine import ConfigChangeAnalyzerDep +from codeweaver.providers import ProviderSettings, ProviderSettingsDep + + +if TYPE_CHECKING: + from codeweaver.core import DictView + +display: StatusDisplay = get_display() +app = App("config", help="Manage and view your CodeWeaver config.", console=display.console) + + +def _project_path(project_path: ResolvedProjectPathDep) -> Path: + return project_path + + +def _settings(settings: SettingsDep = INJECTED) 
-> CodeWeaverSettingsType: + return settings + + +@app.default() +async def config( + *, + project_path: Annotated[ + Path | None, cyclopts.Parameter(name=["--project", "-p"], help="Path to project directory") + ] = None, + config_file: Annotated[ + FilePath | None, + cyclopts.Parameter( + name=["--config-file", "-c"], help="Path to a specific config file to use" + ), + ] = None, + verbose: Annotated[ + bool, cyclopts.Parameter(name=["--verbose", "-v"], help="Enable verbose logging") + ] = False, + debug: Annotated[ + bool, cyclopts.Parameter(name=["--debug", "-d"], help="Enable debug logging") + ] = False, +) -> None: + """Manage CodeWeaver configuration.""" + from codeweaver.core import CodeWeaverError + + error_handler = CLIErrorHandler(display, verbose=verbose, debug=debug) + if config_file and not is_codeweaver_config_path(config_file): + try: + from codeweaver.core.dependencies import bootstrap_settings + + settings = await bootstrap_settings(config_file=config_file) + except Exception as e: + error_handler.handle_error(e, "Configuration", exit_code=1) + + settings.project_path = project_path or settings.project_path + + else: + try: + settings = await get_container().resolve(CodeWeaverSettingsType) + settings.project_path = project_path or settings.project_path + except CodeWeaverError as e: + error_handler.handle_error(e, "Configuration", exit_code=1) + except Exception as e: + error_handler.handle_error(e, "Configuration", exit_code=1) + _show_config(settings.view()) + + +def _show_config(settings: DictView[CodeWeaverSettingsDict]) -> None: + """Display current configuration.""" + from codeweaver.core import Unset + + display.print_command_header("CodeWeaver Configuration") + + table = Table(show_header=True, header_style="bold blue") + table.add_column("Setting", style="cyan", no_wrap=True) + table.add_column("Value", style="white") + + # Core settings + table.add_row("Project Path", str(settings["project_path"])) + table.add_row("Project Name", 
settings["project_name"] or "auto-detected") + table.add_row("Token Limit", str(settings["token_limit"])) + table.add_row("Max File Size", f"{settings.get('max_file_size', 0):,} bytes") + table.add_row("Max Results", str(settings["max_results"])) + + # Feature flags + table.add_row( + "Background Indexing", + "❌" + if settings["indexer"].get("only_index_on_command") + and not isinstance(settings["indexer"].get("only_index_on_command"), Unset) + else "✅", + ) + table.add_row("Telemetry", "❌" if settings["telemetry"].get("disable_telemetry") else "✅") + + display.print_table(table) + + # Provider configuration + if provider_settings := settings.get("provider"): + _show_provider_config(provider_settings) + + +def _normalize_provider_configs( + configs: ProviderSettings | tuple[ProviderSettings, ...], + field: Literal["vector_store", "reranking", "embedding", "sparse_embedding", "agent", "data"], +) -> tuple[ProviderSettings, ...]: + """Normalize provider configs to a tuple of valid config dicts.""" + if isinstance(configs, tuple): + return configs + if configs is None or configs is UNSET: + if detect_root_package() == "core": + return () + from codeweaver.providers.config.profiles import ProviderProfile + + profile = ProviderProfile.RECOMMENDED + return getattr(profile.value, field, ()) + return (configs,) + + +def _build_provider_details(config) -> str: + """Build a human-readable details string for a provider config.""" + details: list[str] = [] + + if (model_settings := config.get("model_settings")) and (model := model_settings.get("model")): + details.append(f"Model: {model}") + + if provider_settings_dict := config.get("provider_settings"): + if url := provider_settings_dict.get("url"): + url_display = url if len(url) < 50 else f"{url[:47]}..." 
+ details.append(f"URL: {url_display}") + if collection := provider_settings_dict.get("collection_name"): + details.append(f"Collection: {collection}") + if path := provider_settings_dict.get("persistence_path"): + details.append(f"Path: {path}") + + return " | ".join(details) if details else "Default settings" + + +def _show_provider_config(provider_settings: ProviderSettingsDep = INJECTED) -> None: + """Display provider configuration details.""" + display.print_section("Provider Configuration") + + if not provider_settings or provider_settings is UNSET: + display.print_warning("No providers configured") + return + + valid_categories = ( + "data", + "embedding", + "sparse_embedding", + "reranking", + "vector_store", + "agent", + ) + + for category, configs in provider_settings.items(): + if category not in valid_categories or not configs or configs is UNSET: + continue + + config_list = _normalize_provider_configs(configs, field=category) # ty:ignore[invalid-argument-type] + table = Table( + title=f"{category.replace('_', ' ').title()}", + show_header=True, + header_style="bold cyan", + ) + table.add_column("Provider", style="cyan") + table.add_column("Status", style="white", no_wrap=True) + table.add_column("Details", style="white") + + for config in config_list: + if config is None or config is UNSET: + continue + + provider = config.get("provider") + enabled = config.get("enabled", True) + status = "✅ Enabled" if enabled else "⚠️ Disabled" + details_str = _build_provider_details(config) + + table.add_row( + provider.as_title if hasattr(provider, "as_title") else str(provider), + status, + details_str, + ) + + display.print_table(table) + display.console.print() + + +@app.command() +async def set_config( + key: str, + value: str, + force: Annotated[ + bool, + cyclopts.Parameter( + name=["--force", "-f"], help="Skip validation and apply change immediately" + ), + ] = False, + config_analyzer: ConfigChangeAnalyzerDep = INJECTED, + settings: SettingsDep = INJECTED, 
+ interaction: UserInteractionDep = INJECTED, +) -> None: + """Set a configuration value with proactive validation. + + Validates embedding configuration changes before applying them to prevent + incompatible index states. + + Args: + key: Configuration key (e.g., "provider.embedding.dimension") + value: New value for the key + force: Skip validation and apply change immediately + config_analyzer: DI-injected configuration analyzer + settings: DI-injected settings service + interaction: User interaction service + """ + display.print_command_header("Set Configuration") + + # Validate change using injected analyzer + if not force and key.startswith("provider.embedding"): + try: + analysis = await config_analyzer.validate_config_change(key, value) + + if analysis: + from codeweaver.engine.managers.checkpoint_manager import ChangeImpact + + display.console.print() + match analysis.impact: + case ChangeImpact.BREAKING: + display.print_error("Configuration change is incompatible!") + display.console.print(f" {analysis.accuracy_impact}") + display.console.print() + display.console.print("Options:") + for i, rec in enumerate(analysis.recommendations, 1): + display.console.print(f" {i}. {rec}") + if not interaction.confirm("Continue anyway?"): + display.print_warning("Configuration change cancelled") + return + + case ChangeImpact.QUANTIZABLE | ChangeImpact.TRANSFORMABLE: + display.print_warning("Transformation available") + display.console.print(f" Accuracy impact: {analysis.accuracy_impact}") + display.console.print( + f" Time estimate: ~{analysis.estimated_time.total_seconds():.0f}s" + ) + display.console.print() + display.console.print("Recommendations:") + for i, rec in enumerate(analysis.recommendations, 1): + display.console.print(f" {i}. 
{rec}") + + case _: + display.print_success("Configuration change is compatible") + + except Exception as e: + display.print_warning(f"Could not validate change: {e}") + if not force and not interaction.confirm("Continue without validation?"): + display.print_warning("Configuration change cancelled") + return + + # Apply change using settings service + try: + # Note: Settings update mechanism needs to be implemented + # For now, this is a placeholder for the actual implementation + # Options: + # 1. Use settings.model_copy(update={key: value}) + # 2. Implement a .set() method on Settings + # 3. Write directly to config file and reload + display.print_warning("Settings update not yet implemented") + display.console.print(f"Would update: {key} = {value}") + # TODO: Implement actual settings update mechanism + except Exception as e: + display.print_error(f"Failed to update configuration: {e}") + raise + + +def main() -> None: + """Main entry point for config CLI.""" + display_instance = StatusDisplay() + error_handler = CLIErrorHandler(display_instance, verbose=True, debug=True) + + try: + app() + except KeyboardInterrupt: + display_instance.print_warning("Operation cancelled by user") + except Exception as e: + error_handler.handle_error(e, "CLI", exit_code=1) + + +if __name__ == "__main__": + main() + +__all__ = () diff --git a/src/codeweaver/cli/commands/config.py.rej b/src/codeweaver/cli/commands/config.py.rej new file mode 100644 index 00000000..b5877411 --- /dev/null +++ b/src/codeweaver/cli/commands/config.py.rej @@ -0,0 +1,11 @@ +--- src/codeweaver/cli/commands/config.py ++++ src/codeweaver/cli/commands/config.py +@@ -103,7 +103,7 @@ + table.add_row("Project Path", str(settings["project_path"])) + table.add_row("Project Name", settings["project_name"] or "auto-detected") + table.add_row("Token Limit", str(settings["token_limit"])) +- table.add_row("Max File Size", f"{settings['max_file_size']:,} bytes") ++ table.add_row("Max File Size", 
f"{settings.get('max_file_size', 0):,} bytes") + table.add_row("Max Results", str(settings["max_results"])) + + # Feature flags diff --git a/src/codeweaver/providers/config/clients/multi.py b/src/codeweaver/providers/config/clients/multi.py index 09074c4b..3e9b3053 100644 --- a/src/codeweaver/providers/config/clients/multi.py +++ b/src/codeweaver/providers/config/clients/multi.py @@ -49,7 +49,10 @@ GoogleCredentials = Any if has_package("fastembed") is not None or has_package("fastembed_gpu") is not None: - from fastembed.common.types import OnnxProvider + try: + from fastembed.common.types import OnnxProvider + except ImportError: + OnnxProvider = object else: OnnxProvider = object diff --git a/src/codeweaver/providers/config/clients/multi.py.orig b/src/codeweaver/providers/config/clients/multi.py.orig new file mode 100644 index 00000000..09074c4b --- /dev/null +++ b/src/codeweaver/providers/config/clients/multi.py.orig @@ -0,0 +1,579 @@ +# SPDX-FileCopyrightText: 2026 Knitli Inc. +# +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""Client options for providers that can be used with multiple categories (e.g. 
embedding, reranking).""" + +from __future__ import annotations + +import contextlib + +from collections.abc import Awaitable, Callable, Hashable, Iterable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor +from typing import Annotated, Any, ClassVar, Literal, Self, TypedDict, cast + +import httpx + +from pydantic import ( + AnyUrl, + Discriminator, + Field, + PositiveFloat, + PositiveInt, + SecretStr, + Tag, + model_validator, +) + +from codeweaver.core.constants import DEFAULT_EMBEDDING_TIMEOUT, ONNX_CUDA_PROVIDER +from codeweaver.core.types import ( + AnonymityConversion, + FilteredKey, + FilteredKeyT, + LiteralProviderType, + Provider, +) +from codeweaver.core.utils import TypeIs, deep_merge_dicts, has_package +from codeweaver.providers.config.clients.base import ClientOptions +from codeweaver.providers.config.clients.utils import ( + AzureOptions, + discriminate_embedding_clients, + try_for_azure_endpoint, + try_for_heroku_endpoint, +) + + +if has_package("google") is not None: + from google.auth.credentials import Credentials as GoogleCredentials +else: + GoogleCredentials = Any + +if has_package("fastembed") is not None or has_package("fastembed_gpu") is not None: + from fastembed.common.types import OnnxProvider +else: + OnnxProvider = object + +if has_package("torch") is not None: + from torch.nn import Module +else: + Module = object +if has_package("sentence_transformers") is not None: + # SentenceTransformerModelCardData contains these forward references: + # - eval_results_dict: dict[SentenceEvaluator, dict[str, Any]] | None + # - model: SentenceTransformer | None + # So if the configured settings are SentenceTransformersClientOptions + # Then we need to have these in the namespace for pydantic to resolve + from sentence_transformers import SentenceTransformer as SentenceTransformer + from sentence_transformers.evaluation import SentenceEvaluator as SentenceEvaluator + from sentence_transformers.model_card import ( + 
SentenceTransformerModelCardData as SentenceTransformerModelCardData, + ) + + +class CohereClientOptions(ClientOptions): + """Client options for Cohere (rerank and embeddings).""" + + _core_provider: ClassVar[Literal[Provider.COHERE]] = Provider.COHERE + _providers: ClassVar[tuple[Provider, ...]] = (Provider.COHERE, Provider.AZURE, Provider.HEROKU) + + tag: Literal["cohere"] = "cohere" + api_key: ( + Annotated[SecretStr | Callable[[], str], Field(description="Cohere API key.")] | None + ) = None + base_url: Annotated[AnyUrl, Field(description="Base URL for the Cohere API.")] | None = None + environment: Literal["production", "staging", "development"] = "production" + client_name: str | None = "codeweaver_cohere_client" + timeout: PositiveFloat | None = None + httpx_client: httpx.Client | None = None + thread_pool_executor: ThreadPoolExecutor | None = None + log_experimental: bool = True # disables warnings about experimental features + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return { + FilteredKey("api_key"): AnonymityConversion.BOOLEAN, + FilteredKey("base_url"): AnonymityConversion.BOOLEAN, + FilteredKey("client_name"): AnonymityConversion.HASH, + FilteredKey("httpx_client"): AnonymityConversion.BOOLEAN, + } + + def computed_base_url(self, provider: LiteralProviderType) -> str | None: + """Return the default base URL for the Cohere client based on the provider.""" + provider = provider if isinstance(provider, Provider) else Provider.from_string(provider) # ty:ignore[invalid-assignment] + if base_url := { + Provider.COHERE: "https://api.cohere.com", + Provider.AZURE: try_for_azure_endpoint( + AzureOptions(api_key=self.api_key, endpoint=str(self.base_url)), cohere=True + ), + Provider.HEROKU: try_for_heroku_endpoint(self.model_dump(), cohere=True), + }.get(provider): + if not self.base_url: + self.base_url = AnyUrl(base_url) + return base_url + return None + + +class OpenAIClientOptions(ClientOptions): + """Client options for 
OpenAI-based embedding providers.""" + + _core_provider: ClassVar[Literal[Provider.OPENAI]] = Provider.OPENAI + _providers: ClassVar[tuple[Provider, ...]] = tuple( + provider for provider in Provider if provider.uses_openai_api + ) + + api_key: ( + SecretStr | Callable[[], str | SecretStr] | Callable[[], Awaitable[str | SecretStr]] | None + ) = None + organization: str | None = None + project: str | None = None + webhook_secret: SecretStr | None = None + base_url: AnyUrl | None = None + websocket_base_url: AnyUrl | None = None + timeout: PositiveFloat | None = None + max_retries: PositiveInt | None = None + default_headers: Mapping[str, str] | None = None + default_query: Mapping[str, object] | None = None + http_client: httpx.Client | None = None + _strict_response_validation: bool = False + + def __init__(self, **data: Any) -> None: + """Initialize the OpenAI client options.""" + object.__setattr__( + self, "_strict_response_validation", data.pop("_strict_response_validation", False) + ) + super().__init__(**data) + + def computed_base_url(self, provider: LiteralProviderType) -> str | None: + """Return the default base URL for the OpenAI client based on the provider.""" + if self.base_url: + return str(self.base_url) + provider = provider if isinstance(provider, Provider) else Provider.from_string(provider) # ty:ignore[invalid-assignment] + return { + Provider.OPENAI: "https://api.openai.com/v1", + Provider.AZURE: try_for_azure_endpoint( + AzureOptions(api_key=self.api_key, endpoint=str(self.base_url)) + ), + Provider.HEROKU: try_for_heroku_endpoint(self.model_dump()), + Provider.GROQ: "https://api.groq.com/openai/v1", + Provider.MORPH: "https://api.morphllm.com/v1", + Provider.OLLAMA: "http://localhost:11434/v1", + Provider.TOGETHER: "https://api.together.xyz/v1", + }.get(provider) + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return { + FilteredKey(name): AnonymityConversion.BOOLEAN + for name in ( + "api_key", + "webhook_secret", 
+ "http_client", + "default_headers", + "default_query", + ) + } | { + FilteredKey(name): AnonymityConversion.HASH + for name in ("organization", "project", "base_url", "websocket_base_url") + } + + +class BedrockClientOptions(ClientOptions): + """Client options for Boto3-based providers like Bedrock. Most of these are required but can be configured in other ways, such as environment variables or AWS config files.""" + + _core_provider: ClassVar[Literal[Provider.BEDROCK]] = Provider.BEDROCK + _providers: ClassVar[tuple[Provider, ...]] = (Provider.BEDROCK,) + + tag: Literal["bedrock"] = "bedrock" + aws_access_key_id: str | None = None + aws_secret_access_key: SecretStr | None = None + aws_session_token: SecretStr | None = None + region_name: str | None = None + profile_name: str | None = None + aws_account_id: str | None = None + botocore_session: Any | None = None + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return {FilteredKey("aws_secret_access_key"): AnonymityConversion.BOOLEAN} | { + FilteredKey(name): AnonymityConversion.HASH + for name in ( + "aws_access_key_id", + "aws_session_token", + "region_name", + "profile_name", + "aws_account_id", + ) + } + + +class GoogleClientOptions(ClientOptions): + """Client options for the GenAI Google provider.""" + + _core_provider: ClassVar[Literal[Provider.GOOGLE]] = Provider.GOOGLE + _providers: ClassVar[tuple[Provider, ...]] = (Provider.GOOGLE,) + + api_key: SecretStr | None = None + vertex_ai: bool = False + credentials: GoogleCredentials | None = None + project: str | None = None + location: str | None = None + debug_config: dict[str, Any] | None = None + http_options: dict[str, Any] | None = None + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return { + FilteredKey(name): AnonymityConversion.BOOLEAN + for name in ("api_key", "location", "http_options", "credentials") + } | {FilteredKey("project"): AnonymityConversion.HASH} + + +class 
FastEmbedClientOptions(ClientOptions): + """Client options for FastEmbed-based embedding providers.""" + + _core_provider: ClassVar[Literal[Provider.FASTEMBED]] = Provider.FASTEMBED + _providers: ClassVar[tuple[Provider, ...]] = (Provider.FASTEMBED,) + + tag: Literal["fastembed"] = "fastembed" + model_name: str + cache_dir: str | None = None + threads: int | None = None + providers: Sequence[OnnxProvider] | None = None + cuda: bool | None = None + device_ids: list[int] | None = None + lazy_load: bool = True + + @model_validator(mode="after") + def _resolve_device_settings(self) -> Self: + """Resolve device settings for FastEmbed client options.""" + from codeweaver.core import effective_cpu_count + + cpu_count = effective_cpu_count() + object.__setattr__(self, "threads", self.threads or cpu_count) + if self.cuda is False: + object.__setattr__(self, "device_ids", []) + return self + from codeweaver.providers.optimize import decide_fastembed_runtime + + decision = decide_fastembed_runtime( + explicit_cuda=self.cuda, explicit_device_ids=self.device_ids + ) + if isinstance(decision, tuple) and len(decision) == 2: + cuda = bool(decision[0]) + device_ids = decision[1] + elif decision == "gpu": + cuda = True + device_ids = [0] + else: + cuda = False + device_ids = [] + object.__setattr__(self, "cuda", cuda) + object.__setattr__(self, "device_ids", device_ids) + if cuda and (not self.providers or ONNX_CUDA_PROVIDER not in self.providers): + object.__setattr__(self, "providers", [ONNX_CUDA_PROVIDER, *(self.providers or [])]) + return self + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return {FilteredKey("cache_dir"): AnonymityConversion.HASH} + + +class SentenceTransformersModelOptions(TypedDict, total=False): + """Options for SentenceTransformers models.""" + + dtype: Literal["float", "float16", "bfloat16", "auto"] | None + attn_implementation: Literal["flash_attention_2", "spda", "eager"] | None + provider: OnnxProvider | None + """Onnx 
Provider if Onnx backend used.""" + file_name: str | None + """Specific file name to load for onnx or openvino models.""" + export: bool | None + """Whether to export the model to onnx/openvino format.""" + + +def _is_str_dict(d: Any) -> TypeIs[dict[str, Any]]: + """Check if the given object is a dictionary with string keys.""" + return isinstance(d, dict) and all(isinstance(k, str) for k in d if k) + + +def _is_hashable_dict(d: Any) -> TypeIs[dict[Hashable, Any]]: + """Check if the given object is a dictionary with hashable keys.""" + return isinstance(d, dict) and all(isinstance(k, Hashable) for k in d if k) + + +class SentenceTransformersClientOptions(ClientOptions): + """Client options for SentenceTransformers-based embedding providers.""" + + _core_provider: ClassVar[Literal[Provider.SENTENCE_TRANSFORMERS]] = ( + Provider.SENTENCE_TRANSFORMERS + ) + _providers: ClassVar[tuple[Provider, ...]] = (Provider.SENTENCE_TRANSFORMERS,) + + tag: Literal["sentence_transformers"] = "sentence_transformers" + model_name_or_path: str | None = None + modules: Iterable[Module] | None = None + device: str | None = None + prompts: dict[str, str] | None = None + default_prompt_name: str | None = None + similarity_fn_name: Literal["cosine", "dot", "euclidean", "manhattan"] | None = None + cache_folder: str | None = None + trust_remote_code: bool = True + revision: str | None = None + local_files_only: bool = False + token: bool | SecretStr | None = None + """Auth token for private/non-public models.""" + truncate_dim: int | None = None + model_kwargs: SentenceTransformersModelOptions | None = None + tokenizer_kwargs: dict[str, Any] | None = None + config_kwargs: dict[str, Any] | None = None + model_card_data: SentenceTransformerModelCardData | None = None + backend: Literal["torch", "onnx", "openvino"] = "torch" + + def __init__(self, **data: Any) -> None: + """Initialize the SentenceTransformers client options.""" + model_name = data.get("model_name_or_path") or "" + data = data 
or {} + default_kwargs = self.default_kwargs_for_model(model=model_name) or {} + if not _is_hashable_dict(default_kwargs): + raise TypeError( + "Expected data and default_kwargs to be dicts with appropriate key types." + ) + merged_data = deep_merge_dicts(default_kwargs, cast(dict[Hashable, Any], data)) + if not _is_str_dict(merged_data): + raise TypeError("Expected merged data to be a dict with string keys.") + super().__init__(**merged_data) + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return { + FilteredKey("cache_folder"): AnonymityConversion.HASH, + FilteredKey("model_name_or_path"): AnonymityConversion.HASH, + } + + def _is_dense_model(self) -> bool: + """Determine if the model is a dense model based on its name.""" + return ( + False + if self.capabilities # ty:ignore[unresolved-attribute] + and "sparse" in self.capabilities.__class__.__name__.lower() # ty:ignore[unresolved-attribute] + else not self.model_name_or_path + or not any( + key + for key in ("sparse", "splade", "bm-25", "bm25", "attentions") + if key in str(self.model_name_or_path.lower()) + ) + ) + + def __model_post_init__(self) -> None: + """Post-initialization adjustments for specific models.""" + if ( + model_name_or_path := self.model_name_or_path + ) and "qwen3" in model_name_or_path.lower(): + object.__setattr__( + self, + "model_kwargs", + (self.model_kwargs or {}) + | { + "dtype": "float16" + if "dtype" not in (self.model_kwargs or {}) + else (self.model_kwargs or {}).get("dtype") + }, + ) + if has_package("flash_attention_2"): + object.__setattr__( + self, + "model_kwargs", + (self.model_kwargs or {}) + | { + "attn_implementation": "flash_attention_2" + if "attn_implementation" not in (self.model_kwargs or {}) + else (self.model_kwargs or {}).get("attn_implementation") + }, + ) + + def default_kwargs_for_model( + self, *, model: str | None = None, query: bool = False + ) -> dict[str, Any]: + """Get default client arguments for a specific model.""" + model 
= model or self.model_name_or_path + if not model: + return {} + extra: dict[str, Any] = {} + float16 = {"model_kwargs": {"dtype": "float16"}} + if "alibaba" in model.lower() and "gte-reranker-modernbert-base" in model.lower(): + extra = {"tokenizer_kwargs": {"padding": True}} + if "qwen3" in model.lower(): + extra = { + "instruction": "Use provided search results of codebase data to retrieve relevant Documents that answer the Query.", + "tokenizer_kwargs": {"padding_side": "left"}, + **float16, + } + if "bge" in model.lower() and "m3" not in model.lower() and query: + extra = { + "prompt_name": "query", + "prompts": { + "query": {"text": "Represent this sentence for searching relevant passages:"} + }, + **float16, + } + if "snowflake" in model.lower() and "v2.0" in model.lower(): + extra = {"prompt_name": "query"} # only for query embeddings + if "intfloat" in model.lower() and "instruct" not in model.lower(): + extra = {"prompt_name": "query"} if query else {"prompt_name": "document"} + if "jina" in model.lower() and "v2" not in model.lower(): + if "v4" in model.lower(): + extra = ( + {"prompt_name": "query", "task": "code"} + if query + else {"task": "code", "prompt_name": "passage"} + ) + else: + extra = ( + {"task": "retrieval.query", "prompt_name": "query"} + if query + else {"task": "retrieval.passage"} + ) + if "nomic" in model.lower(): + extra = {"tokenizer_kwargs": {"padding": True}} + return { + "model_name_or_path": model, + "normalize_embeddings": True, + "trust_remote_code": True, + **extra, + } + + +class HuggingFaceClientOptions(ClientOptions): + """Client options for HuggingFace Inference API-based embedding providers.""" + + _core_provider: ClassVar[Literal[Provider.HUGGINGFACE_INFERENCE]] = ( + Provider.HUGGINGFACE_INFERENCE + ) + _providers: ClassVar[tuple[Provider, ...]] = (Provider.HUGGINGFACE_INFERENCE,) + + model: str | None = None + provider: str | None = None + token: SecretStr | None = None + timeout: PositiveFloat | None = None + 
headers: dict[str, str] | None = None + cookies: dict[str, str] | None = None + trust_env: bool = False + proxies: Any | None = None + bill_to: str | None = None + base_url: AnyUrl | None = None + api_key: SecretStr | None = None + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return { + FilteredKey(name): AnonymityConversion.BOOLEAN + for name in ("token", "api_key", "headers", "cookies", "proxies") + } | { + FilteredKey("base_url"): AnonymityConversion.HASH, + FilteredKey("bill_to"): AnonymityConversion.HASH, + } + + +class MistralClientOptions(ClientOptions): + """Client options for Mistral-based embedding providers.""" + + _core_provider: ClassVar[Literal[Provider.MISTRAL]] = Provider.MISTRAL + _providers: ClassVar[tuple[Provider, ...]] = (Provider.MISTRAL,) + + api_key: ( + SecretStr | Callable[[], str | SecretStr] | Callable[[], Awaitable[str | SecretStr]] | None + ) = None + server: str | None = None + server_url: AnyUrl | None = None + url_params: dict[str, str] | None = None + async_client: httpx.AsyncClient | None = None + retry_config: Any | None = None + timeout_ms: PositiveInt | None = None + debug_logger: Any | None = None + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return { + FilteredKey("api_key"): AnonymityConversion.BOOLEAN, + FilteredKey("server"): AnonymityConversion.HASH, + FilteredKey("server_url"): AnonymityConversion.HASH, + FilteredKey("url_params"): AnonymityConversion.HASH, + FilteredKey("async_client"): AnonymityConversion.BOOLEAN, + } + + +class VoyageClientOptions(ClientOptions): + """Client options for Voyage AI-based embedding and reranking providers.""" + + _core_provider: ClassVar[Literal[Provider.VOYAGE]] = Provider.VOYAGE + _providers: ClassVar[tuple[Provider, ...]] = (Provider.VOYAGE,) + + tag: Literal["voyage"] = "voyage" + api_key: SecretStr | None = None + max_retries: PositiveInt = 0 # we handle retries ourself + timeout: PositiveFloat | None = 
DEFAULT_EMBEDDING_TIMEOUT + + def _telemetry_keys(self) -> dict[FilteredKeyT, AnonymityConversion]: + return {FilteredKey("api_key"): AnonymityConversion.BOOLEAN} + + +# Rebuild Pydantic models to resolve forward references after all imports complete +# This is necessary because SentenceTransformerModelCardData contains SentenceEvaluator references +if ( + has_package("sentence_transformers") is not None + and not SentenceTransformersClientOptions.__pydantic_complete__ +): + # we can rebuild lazily later if this fails + with contextlib.suppress(Exception): + SentenceTransformersClientOptions.model_rebuild() + + +# =========================================================================== +# * Client Discriminators +# =========================================================================== + +type GeneralRerankingClientOptionsType = Annotated[ + Annotated[BedrockClientOptions, Tag(Provider.BEDROCK.variable)] + | Annotated[CohereClientOptions, Tag(Provider.COHERE.variable)] + | Annotated[FastEmbedClientOptions, Tag(Provider.FASTEMBED.variable)] + | Annotated[SentenceTransformersClientOptions, Tag(Provider.SENTENCE_TRANSFORMERS.variable)] + | Annotated[VoyageClientOptions, Tag(Provider.VOYAGE.variable)], + Field(description="Reranking client options type.", discriminator="tag"), +] + + +def discriminate_azure_embedding_client_options(v: Any) -> str: + """Identify the Azure embedding provider settings type for discriminator field.""" + model_settings = v["model_settings"] if isinstance(v, dict) else v.model_settings + model = ( + model_settings.get("model") if isinstance(model_settings, dict) else model_settings.model + ) + if model in ("text-embedding-3-small", "text-embedding-3-large"): + return "openai" + return "cohere" + + +type GeneralEmbeddingClientOptionsType = Annotated[ + Annotated[BedrockClientOptions, Tag(Provider.BEDROCK.variable)] + | Annotated[CohereClientOptions, Tag(Provider.COHERE.variable)] + | Annotated[FastEmbedClientOptions, 
Tag(Provider.FASTEMBED.variable)] + | Annotated[GoogleClientOptions, Tag(Provider.GOOGLE.variable)] + | Annotated[HuggingFaceClientOptions, Tag(Provider.HUGGINGFACE_INFERENCE.variable)] + | Annotated[MistralClientOptions, Tag(Provider.MISTRAL.variable)] + | Annotated[OpenAIClientOptions, Tag(Provider.OPENAI.variable)] + | Annotated[SentenceTransformersClientOptions, Tag(Provider.SENTENCE_TRANSFORMERS.variable)] + | Annotated[VoyageClientOptions, Tag(Provider.VOYAGE.variable)], + Field( + description="Embedding client options type.", + discriminator=Discriminator(discriminate_embedding_clients), + ), +] + + +__all__ = ( + "BedrockClientOptions", + "CohereClientOptions", + "FastEmbedClientOptions", + "GeneralEmbeddingClientOptionsType", + "GeneralRerankingClientOptionsType", + "GoogleClientOptions", + "HuggingFaceClientOptions", + "MistralClientOptions", + "OpenAIClientOptions", + "SentenceTransformersClientOptions", + "SentenceTransformersModelOptions", + "VoyageClientOptions", + "discriminate_azure_embedding_client_options", +) diff --git a/src/codeweaver/providers/embedding/providers/voyage.py b/src/codeweaver/providers/embedding/providers/voyage.py index 1bb122b0..c80e4ecd 100644 --- a/src/codeweaver/providers/embedding/providers/voyage.py +++ b/src/codeweaver/providers/embedding/providers/voyage.py @@ -24,9 +24,12 @@ try: - from voyageai.client_async import AsyncClient - from voyageai.object.contextualized_embeddings import ContextualizedEmbeddingsObject - from voyageai.object.embeddings import EmbeddingsObject + try: + from voyageai.client_async import AsyncClient + from voyageai.object.contextualized_embeddings import ContextualizedEmbeddingsObject + from voyageai.object.embeddings import EmbeddingsObject + except Exception: + AsyncClient, ContextualizedEmbeddingsObject, EmbeddingsObject = object, object, object except ImportError as _import_error: raise ConfigurationError( 'Please install the `voyageai` package to use the Voyage provider, you can use the 
`voyage` optional group -- `pip install "code-weaver\\[voyage]"`' diff --git a/src/codeweaver/providers/embedding/providers/voyage.py.orig b/src/codeweaver/providers/embedding/providers/voyage.py.orig new file mode 100644 index 00000000..1bb122b0 --- /dev/null +++ b/src/codeweaver/providers/embedding/providers/voyage.py.orig @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: 2026 Knitli Inc. +# SPDX-FileContributor: Adam Poulemanos +# +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""VoyageAI embedding provider.""" + +from __future__ import annotations + +import asyncio + +from collections.abc import Callable, Sequence +from typing import Annotated, Any, ClassVar, cast + +from pydantic import PrivateAttr, SkipValidation +from voyageai.object.embeddings import EmbeddingsObject + +from codeweaver.core import CodeChunk, ConfigurationError, Provider +from codeweaver.providers.embedding.providers.base import ( + EmbeddingCustomDeps, + EmbeddingImplementationDeps, + EmbeddingProvider, +) + + +try: + from voyageai.client_async import AsyncClient + from voyageai.object.contextualized_embeddings import ContextualizedEmbeddingsObject + from voyageai.object.embeddings import EmbeddingsObject +except ImportError as _import_error: + raise ConfigurationError( + 'Please install the `voyageai` package to use the Voyage provider, you can use the `voyage` optional group -- `pip install "code-weaver\\[voyage]"`' + ) from _import_error + + +def voyage_context_output_transformer( + result: ContextualizedEmbeddingsObject, +) -> list[list[int | float]] | list[list[int]]: + """Transform the output of the Voyage AI context chunk embedding model.""" + results = result.results + embeddings = [res.embeddings for res in results if res and res.embeddings] + if embeddings and isinstance(embeddings[0][0][0], list): + embeddings = cast( + list[list[int | float]], [emb for sublist in embeddings for emb in sublist] + ) + return cast(list[list[int | float]] | list[list[int]], embeddings) + + +def 
voyage_output_transformer( + result: EmbeddingsObject, +) -> list[list[int | float]] | list[list[int]]: + """Transform the output of the Voyage AI model.""" + return cast(list[list[int | float]] | list[list[int]], result.embeddings) + + +class VoyageEmbeddingProvider(EmbeddingProvider[AsyncClient]): + """VoyageAI embedding provider.""" + + client: SkipValidation[AsyncClient] + _provider: ClassVar[Provider] = Provider.VOYAGE + _output_transformer: Callable[[Any], list[list[float]] | list[list[int]]] = ( + voyage_output_transformer + ) + _is_context_model: Annotated[bool, PrivateAttr()] = False + + def _initialize( + self, + impl_deps: EmbeddingImplementationDeps = None, + custom_deps: EmbeddingCustomDeps = None, + **kwargs: Any, + ) -> None: + """Initialize the VoyageAI client.""" + + def model_post_init(self, __context: Any, /) -> None: + """Post-initialization hook to detect context models and set options. + + Args: + __context: Pydantic context (unused). + """ + config = self.config.embedding_config + # Set model name, input type, and output parameters for embedding and query options + self.embed_options["model"] = self.model_name + self.embed_options["input_type"] = "document" + self.embed_options["output_dimension"] = config.dimension or self.caps.default_dimension + self.embed_options["output_dtype"] = config.datatype or self.caps.default_dtype + + self.query_options["model"] = self.model_name + self.query_options["input_type"] = "query" + self.query_options["output_dimension"] = config.dimension or self.caps.default_dimension + self.query_options["output_dtype"] = config.datatype or self.caps.default_dtype + # Detect if this is a context model based on the model name + if self.caps and "context" in self.caps.name.lower(): + object.__setattr__(self, "_is_context_model", True) + + def _process_output(self, output_data: Any) -> list[list[float]] | list[list[int]]: + """Process output data using the appropriate transformer.""" + transformer = ( + 
voyage_context_output_transformer + if self._is_context_model + else voyage_output_transformer + ) + return transformer(output_data) + + @property + def name(self) -> Provider: + """Get the name of the embedding provider.""" + return Provider.VOYAGE + + @property + def base_url(self) -> str | None: + """Get the base URL of the embedding provider.""" + return "https://api.voyageai.com/v1" + + async def _embed_documents( + self, documents: Sequence[CodeChunk], **kwargs: Any + ) -> list[list[int | float]] | list[list[int]]: + """Embed a list of documents into vectors.""" + import logging + + logger = logging.getLogger(__name__) + ready_documents = cast(list[str], self.chunks_to_strings(documents)) + try: + results: EmbeddingsObject = await self.client.embed( + texts=ready_documents, **kwargs | self.embed_options + ) + await asyncio.sleep(0) + loop = await self._get_loop() + self._fire_and_forget( + lambda: self._update_token_stats(token_count=results.total_tokens), loop=loop + ) + except Exception as e: + error_msg = str(e) + if "max allowed tokens per submitted batch" in error_msg.lower() and len(documents) > 1: + logger.warning( + "Voyage batch token limit exceeded (%s), splitting batch of %d chunks in half and retrying", + error_msg.split("Your batch has")[1].split("tokens")[0].strip() + if "Your batch has" in error_msg + else "unknown", + len(documents), + ) + mid = len(documents) // 2 + first_half = await self._embed_documents(documents[:mid], **kwargs) + second_half = await self._embed_documents(documents[mid:], **kwargs) + return cast(list[list[int | float]] | list[list[int]], first_half + second_half) + raise + else: + return self._process_output(results) + + async def _embed_query( + self, query: Sequence[str], **kwargs: Any + ) -> list[list[int | float]] | list[list[int]]: + """Embed a query or group of queries into vectors.""" + results: EmbeddingsObject = await self.client.embed( + texts=list(query), **kwargs | self.query_options + ) + loop = await 
self._get_loop() + self._fire_and_forget( + lambda: self._update_token_stats(token_count=results.total_tokens), loop=loop + ) + return self._process_output(results) + + @property + def dimension(self) -> int: + """Get the size of the vector for the collection.""" + return self.embed_options.get("output_dimension", self.caps.default_dimension) + + +def _rebuild_voyage_embedding_provider() -> None: + from codeweaver.core import CodeChunk as CodeChunk + + VoyageEmbeddingProvider.model_rebuild() + + +_rebuild_voyage_embedding_provider() + +__all__ = ( + "VoyageEmbeddingProvider", + "voyage_context_output_transformer", + "voyage_output_transformer", +) diff --git a/src/codeweaver/providers/embedding/providers/voyage.py.rej b/src/codeweaver/providers/embedding/providers/voyage.py.rej new file mode 100644 index 00000000..17a967df --- /dev/null +++ b/src/codeweaver/providers/embedding/providers/voyage.py.rej @@ -0,0 +1,15 @@ +--- src/codeweaver/providers/embedding/providers/voyage.py ++++ src/codeweaver/providers/embedding/providers/voyage.py +@@ -14,8 +14,11 @@ + logger = logging.getLogger(__name__) + + try: +- from voyageai.client_async import AsyncClient +- from voyageai.object.embeddings import EmbeddingsObject ++ try: ++ from voyageai.client_async import AsyncClient ++ from voyageai.object.embeddings import EmbeddingsObject ++ except Exception: ++ AsyncClient, EmbeddingsObject = object, object + except ImportError as e: + from codeweaver.core import ConfigurationError diff --git a/src/codeweaver/providers/reranking/providers/voyage.py b/src/codeweaver/providers/reranking/providers/voyage.py index ede60378..be278e97 100644 --- a/src/codeweaver/providers/reranking/providers/voyage.py +++ b/src/codeweaver/providers/reranking/providers/voyage.py @@ -27,9 +27,12 @@ logger = logging.getLogger(__name__) try: - from voyageai.client_async import AsyncClient - from voyageai.object.reranking import RerankingObject - from voyageai.object.reranking import RerankingResult as 
VoyageRerankingResult + try: + from voyageai.client_async import AsyncClient + from voyageai.object.reranking import RerankingObject + from voyageai.object.reranking import RerankingResult as VoyageRerankingResult + except Exception: + AsyncClient, RerankingObject, VoyageRerankingResult = object, object, object except ImportError as e: from codeweaver.core import ConfigurationError diff --git a/src/codeweaver/providers/reranking/providers/voyage.py.orig b/src/codeweaver/providers/reranking/providers/voyage.py.orig new file mode 100644 index 00000000..ede60378 --- /dev/null +++ b/src/codeweaver/providers/reranking/providers/voyage.py.orig @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: 2025 Knitli Inc. +# SPDX-FileContributor: Adam Poulemanos +# +# SPDX-License-Identifier: MIT OR Apache-2.0 + +"""Voyage AI reranking provider implementation.""" + +from __future__ import annotations + +import asyncio +import logging + +from collections.abc import Iterator, Sequence +from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast +from warnings import filterwarnings + +from pydantic import ConfigDict, SkipValidation + +from codeweaver.core import Provider, ProviderError, rpartial +from codeweaver.core.constants import DEFAULT_RERANKING_MAX_RESULTS +from codeweaver.providers.reranking.providers.base import RerankingProvider, RerankingResult + + +if TYPE_CHECKING: + from codeweaver.core import CodeChunk + +logger = logging.getLogger(__name__) + +try: + from voyageai.client_async import AsyncClient + from voyageai.object.reranking import RerankingObject + from voyageai.object.reranking import RerankingResult as VoyageRerankingResult + +except ImportError as e: + from codeweaver.core import ConfigurationError + + raise ConfigurationError( + r"Voyage AI SDK is not installed. Please install it with `pip install code-weaver\[voyage]`." 
+ ) from e + + +# We need to filter UserWarning about shadowing the parent class +filterwarnings("ignore", category=UserWarning, message='.*RerankingProvider" shadows.*') + + +def voyage_reranking_output_transformer( + returned_result: RerankingObject, + _original_chunks: Iterator[CodeChunk] | tuple[CodeChunk, ...], + _instance: VoyageRerankingProvider, +) -> list[RerankingResult]: + """Transform the output of the Voyage AI reranking model.""" + original_chunks = ( + tuple(_original_chunks) if isinstance(_original_chunks, Iterator) else _original_chunks + ) + + def map_result(voyage_result: VoyageRerankingResult, new_index: int) -> RerankingResult: + """Maps a VoyageRerankingResult to a CodeWeaver RerankingResult.""" + return RerankingResult( + original_index=voyage_result.index, + batch_rank=new_index, + score=voyage_result.relevance_score, + chunk=original_chunks[voyage_result.index], + ) + + results, token_count = returned_result.results, returned_result.total_tokens + try: + loop = _instance._loop or asyncio.get_running_loop() + _ = loop.call_soon_threadsafe( + lambda: _instance._update_token_stats(token_count=token_count) + ) + except RuntimeError: + _instance._update_token_stats(token_count=token_count) + # Sort by relevance_score - handle both tuple (x[2]) and attribute (x.relevance_score) access + try: + results.sort(key=lambda x: cast(float, x.relevance_score), reverse=True) + except AttributeError: + results.sort(key=lambda x: cast(float, x[2]), reverse=True) + return [map_result(res, i) for i, res in enumerate(results, 1)] + + +class VoyageRerankingProvider(RerankingProvider[AsyncClient]): + """Voyage AI reranking provider implementation.""" + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + client: SkipValidation[AsyncClient] + _provider: ClassVar[Literal[Provider.VOYAGE]] = Provider.VOYAGE + + def _initialize(self) -> None: + """Initialize after Pydantic setup.""" + self._output_transformer = rpartial( # 
ty:ignore[invalid-assignment] + voyage_reranking_output_transformer, _instance=self + ) + + async def _execute_rerank( + self, + query: str, + documents: Sequence[str], + *, + top_n: int = DEFAULT_RERANKING_MAX_RESULTS, + **kwargs: Any, + ) -> Any: + """Execute the reranking process.""" + try: + # Voyage API doesn't accept extra kwargs - only query, documents, model, top_k + response = await self.client.rerank( + query=query, + documents=[documents] if isinstance(documents, str) else documents, # ty: ignore[invalid-argument-type] + model=self.caps.name, + top_k=top_n, + ) + self._loop = await self._get_loop() + except Exception as e: + raise ProviderError( + f"Voyage AI reranking request failed: {e}", + details={ + "provider": "voyage", + "model": self.caps.name, + "query_length": len(query), + "document_count": len(documents), + "error_type": type(e).__name__, + }, + suggestions=[ + "Check VOYAGE_API_KEY environment variable is set correctly", + "Verify network connectivity to Voyage AI API", + "Check API rate limits and quotas", + "Ensure the reranking model name is valid", + ], + ) from e + else: + return response + + +__all__ = ("VoyageRerankingProvider", "voyage_reranking_output_transformer") diff --git a/src/codeweaver/providers/reranking/providers/voyage.py.rej b/src/codeweaver/providers/reranking/providers/voyage.py.rej new file mode 100644 index 00000000..da19b42b --- /dev/null +++ b/src/codeweaver/providers/reranking/providers/voyage.py.rej @@ -0,0 +1,14 @@ +--- src/codeweaver/providers/reranking/providers/voyage.py ++++ src/codeweaver/providers/reranking/providers/voyage.py +@@ -27,8 +27,11 @@ + import voyageai # noqa: F401 + except ImportError: + logger.error("Failed to import voyageai", exc_info=True) + raise ConfigurationError( + r"Voyage AI is not installed. Please install it with `pip install code-weaver\[voyage]`." 
def test_from_chunk_valid_file(tmp_path: Path):
    """A chunk whose file_path is a real file on disk yields a DiscoveredFile."""
    existing_file = tmp_path / "valid_file.py"
    existing_file.write_text("print('hello world')")

    result = DiscoveredFile.from_chunk(
        CodeChunk(
            content="print('hello world')",
            line_range=Span(start=1, end=1, source_id=uuid7()),
            file_path=existing_file,
        )
    )

    assert isinstance(result, DiscoveredFile)
    assert result.path.name == "valid_file.py"


def test_from_chunk_invalid_file(tmp_path: Path):
    """DiscoveredFile.from_chunk raises ValueError for each unusable file_path."""
    rejected_paths = (
        None,  # Condition 1: no file_path at all
        tmp_path / "does_not_exist.py",  # Condition 2: path to a file that was never created
        tmp_path,  # Condition 3: an existing directory instead of a file
    )
    for bad_path in rejected_paths:
        chunk = CodeChunk(
            content="print('hello')",
            line_range=Span(start=1, end=1, source_id=uuid7()),
            file_path=bad_path,
        )
        with pytest.raises(ValueError, match="CodeChunk must have a valid file_path"):
            DiscoveredFile.from_chunk(chunk)