From a007dcab89db4d8a8781b7b197be4593e5eecd0f Mon Sep 17 00:00:00 2001 From: Artem Dubrovskiy Date: Mon, 19 May 2025 16:13:49 +0400 Subject: [PATCH 1/2] refactoring github toolkit as an example --- src/alita_tools/__init__.py | 1 - src/alita_tools/github/__init__.py | 206 +++++++++++++++-------- src/alita_tools/github/api_wrapper.py | 228 +++++++++++++------------- 3 files changed, 249 insertions(+), 186 deletions(-) diff --git a/src/alita_tools/__init__.py b/src/alita_tools/__init__.py index 5178e34f..5d5c5e2e 100644 --- a/src/alita_tools/__init__.py +++ b/src/alita_tools/__init__.py @@ -49,7 +49,6 @@ def get_tools(tools_list, alita: 'AlitaClient', llm: 'LLMLikeObject', *args, **kwargs): tools = [] - print("Tools") for tool in tools_list: tool['settings']['alita'] = alita tool['settings']['llm'] = llm diff --git a/src/alita_tools/github/__init__.py b/src/alita_tools/github/__init__.py index 8188e23b..7586c051 100644 --- a/src/alita_tools/github/__init__.py +++ b/src/alita_tools/github/__init__.py @@ -1,7 +1,9 @@ -from typing import Dict, List, Optional, Literal +from functools import lru_cache +from typing import Type +from typing import List, Optional from langchain_core.tools import BaseTool, BaseToolkit -from pydantic import create_model, BaseModel, ConfigDict, Field, SecretStr +from pydantic import BaseModel, Field, SecretStr, field_validator, computed_field from .api_wrapper import AlitaGitHubAPIWrapper from .tool import GitHubAction @@ -10,6 +12,105 @@ name = "github" + +@lru_cache(maxsize=1) +def get_available_tools() -> dict[str, dict]: + api_wrapper = AlitaGitHubAPIWrapper.model_construct() + available_tools: dict = { + x['name']: x['args_schema'].model_json_schema() for x in + api_wrapper.get_available_tools() + } + return available_tools + +toolkit_max_length = lru_cache(maxsize=1)(lambda: get_max_toolkit_length(get_available_tools())) + + +class AlitaGitHubToolkitConfig(BaseModel): + class Config: + title = name + json_schema_extra = { + 'metadata': { + "label": "GitHub", + "icon_url": None, + "sections": { + "auth": { + "required": False, + "subsections": [ + { + "name": "Token", + "fields": ["access_token"] + }, + { + "name": "Password", + "fields": ["username", "password"] + }, + { + "name": "App private key", + "fields": ["app_id", "app_private_key"] + }, + ] + } + } + } + } + + base_url: Optional[str] = Field( + default="https://api.github.com", + description="Base API URL", + json_schema_extra={'configuration': True, 'configuration_title': True} + ) + app_id: Optional[str] = Field( + default=None, + description="Github APP ID", + json_schema_extra={'configuration': True}, + ) + app_private_key: Optional[SecretStr] = Field( + default=None, + description="Github APP private key", + json_schema_extra={'secret': True, 'configuration': True}, + ) + access_token: Optional[SecretStr] = Field( + default=None, + description="Github Access Token", + json_schema_extra={'secret': True, 'configuration': True}, + ) + username: Optional[str] = Field( + default=None, + description="Github Username", + json_schema_extra={'configuration': True}, + ) + password: Optional[SecretStr] = Field( + default=None, + description="Github Password", + json_schema_extra={'secret': True, 'configuration': True}, + ) + repository: str = Field( + description="Github repository", + json_schema_extra={ + 'toolkit_name': True, + 'max_toolkit_length': 100 # Example limit; adjust as needed + }, + ) + active_branch: Optional[str] = Field( + default="main", + description="Active branch", + ) + base_branch: Optional[str] = Field( + default="main", + description="Github Base branch", + ) + selected_tools: List[str] = Field( + default=[], + description="Selected tools", + json_schema_extra={'args_schemas': get_available_tools()}, + ) + + @field_validator('selected_tools', mode='before', check_fields=False) + @classmethod + def selected_tools_validator(cls, value: List[str]) -> list[str]: + return [i for i in value if i in get_available_tools()] + + def _get_toolkit(tool) -> BaseToolkit: return AlitaGitHubToolkit().get_toolkit( selected_tools=tool['settings'].get('selected_tools', []), @@ -26,85 +127,56 @@ def _get_toolkit(tool) -> BaseToolkit: toolkit_name=tool.get('toolkit_name') ) + def get_toolkit(): return AlitaGitHubToolkit.toolkit_config_schema() + def get_tools(tool): return _get_toolkit(tool).get_tools() + class AlitaGitHubToolkit(BaseToolkit): tools: List[BaseTool] = [] - toolkit_max_length: int = 0 - @staticmethod - def toolkit_config_schema() -> BaseModel: - selected_tools = {x['name']: x['args_schema'].schema() for x in AlitaGitHubAPIWrapper.model_construct().get_available_tools()} - AlitaGitHubToolkit.toolkit_max_length = get_max_toolkit_length(selected_tools) - return create_model( - name, - __config__=ConfigDict( - json_schema_extra={ - 'metadata': { - "label": "GitHub", - "icon_url": None, - "sections": { - "auth": { - "required": False, - "subsections": [ - { - "name": "Token", - "fields": ["access_token"] - }, - { - "name": "Password", - "fields": ["username", "password"] - }, - { - "name": "App private key", - "fields": ["app_id", "app_private_key"] - } - ] - } - } - }, - } - ), - base_url=(Optional[str], Field(description="Base API URL", default="https://api.github.com", json_schema_extra={'configuration': True, 'configuration_title': True})), - app_id=(Optional[str], Field(description="Github APP ID", default=None, json_schema_extra={'configuration': True})), - app_private_key=(Optional[SecretStr], Field(description="Github APP private key", default=None, json_schema_extra={'secret': True, 'configuration': True})), + api_wrapper: Optional[AlitaGitHubAPIWrapper] = Field(default_factory=AlitaGitHubAPIWrapper.model_construct) + toolkit_name: Optional[str] = None - access_token=(Optional[SecretStr], Field(description="Github Access Token", default=None, json_schema_extra={'secret': True, 'configuration': True})), + @computed_field + @property + def tool_prefix(self) -> str: + return clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER if self.toolkit_name else '' - username=(Optional[str], Field(description="Github Username", default=None, json_schema_extra={'configuration': True})), - password=(Optional[SecretStr], Field(description="Github Password", default=None, json_schema_extra={'secret': True, 'configuration': True})), + @computed_field + @property + def available_tools(self) -> List[dict]: + return self.api_wrapper.get_available_tools() - repository=(str, Field(description="Github repository", json_schema_extra={'toolkit_name': True, 'max_toolkit_length': AlitaGitHubToolkit.toolkit_max_length})), - active_branch=(Optional[str], Field(description="Active branch", default="main")), - base_branch=(Optional[str], Field(description="Github Base branch", default="main")), - selected_tools=(List[Literal[tuple(selected_tools)]], Field(default=[], json_schema_extra={'args_schemas': selected_tools})) - ) + @staticmethod + def toolkit_config_schema() -> Type[BaseModel]: + return AlitaGitHubToolkitConfig @classmethod - def get_toolkit(cls, selected_tools: list[str] | None = None, toolkit_name: Optional[str] = None, **kwargs): - if selected_tools is None: - selected_tools = [] + def get_toolkit(cls, selected_tools: list[str] | None = None, toolkit_name: Optional[str] = None, **kwargs) -> "AlitaGitHubToolkit": github_api_wrapper = AlitaGitHubAPIWrapper(**kwargs) - available_tools: List[Dict] = github_api_wrapper.get_available_tools() - tools = [] - prefix = clean_string(toolkit_name, AlitaGitHubToolkit.toolkit_max_length) + TOOLKIT_SPLITTER if toolkit_name else '' - for tool in available_tools: - if selected_tools: - if tool["name"] not in selected_tools: - continue - tools.append(GitHubAction( - api_wrapper=github_api_wrapper, - name=prefix + tool["name"], - mode=tool["mode"], - # set unique description for declared tools to differentiate the same methods for different toolkits - description=f"Repository: {github_api_wrapper.github_repository}\n" + tool["description"], - args_schema=tool["args_schema"] - )) - return cls(tools=tools) + instance = cls( + tools=[], + api_wrapper=github_api_wrapper, + toolkit_name=toolkit_name + ) + if selected_tools: + selected_tools = set(selected_tools) + for t in instance.available_tools: + if t["name"] in selected_tools: + instance.tools.append(GitHubAction( + api_wrapper=instance.api_wrapper, + name=instance.tool_prefix + t["name"], + mode=t["mode"], + # set unique description for declared tools to differentiate the same methods for different toolkits + description=f"Repository: {github_api_wrapper.github_repository}\n" + t["description"], + args_schema=t["args_schema"] + )) + return instance def get_tools(self): - return self.tools \ No newline at end of file + return self.tools diff --git a/src/alita_tools/github/api_wrapper.py b/src/alita_tools/github/api_wrapper.py index 15752198..4a7c77ab 100644 --- a/src/alita_tools/github/api_wrapper.py +++ b/src/alita_tools/github/api_wrapper.py @@ -1,135 +1,127 @@ -from typing import Any, Dict, List, Optional, Union, Tuple +import json import logging import traceback -import json -import re -from pydantic import BaseModel, model_validator, Field +from typing import Any, Optional, Self + +from langchain_core.callbacks import dispatch_custom_event +from langchain_core.utils import get_from_env +from pydantic import model_validator, Field, field_validator, SecretStr, ConfigDict +from pydantic_core.core_schema import ValidationInfo -from .github_client import GitHubClient -from .graphql_client_wrapper import GraphQLClientWrapper # Add imports for the executor and generator from .executor.github_code_executor import GitHubCodeExecutor from .generator.github_code_generator import GitHubCodeGenerator +from .github_client import GitHubClient +from .graphql_client_wrapper import GraphQLClientWrapper from .schemas import ( GitHubAuthConfig, - GitHubRepoConfig, - ProcessGitHubQueryModel + GitHubRepoConfig ) -from langchain_core.callbacks import dispatch_custom_event - logger = logging.getLogger(__name__) -# Import prompts for tools -from .tool_prompts import ( - UPDATE_FILE_PROMPT, - CREATE_ISSUE_PROMPT, - UPDATE_ISSUE_PROMPT, - CREATE_ISSUE_ON_PROJECT_PROMPT, - UPDATE_ISSUE_ON_PROJECT_PROMPT, - CODE_AND_RUN -) - -class AlitaGitHubAPIWrapper(BaseModel): +class AlitaGitHubAPIWrapper(GitHubAuthConfig, GitHubRepoConfig): """ Wrapper for GitHub API that integrates both REST and GraphQL functionality. """ + model_config = ConfigDict( + arbitrary_types_allowed=True, + from_attributes=True + ) + # Authentication config - github_access_token: Optional[str] = None - github_username: Optional[str] = None - github_password: Optional[str] = None - github_app_id: Optional[str] = None - github_app_private_key: Optional[str] = None - github_base_url: Optional[str] = None + github_access_token: Optional[SecretStr] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_ACCESS_TOKEN'}) + github_username: Optional[str] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_USERNAME'}) + github_password: Optional[SecretStr] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_PASSWORD'}) + github_app_id: Optional[str] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_APP_ID'}) + github_app_private_key: Optional[SecretStr] = Field(default=None, + json_schema_extra={'env_key': 'GITHUB_APP_PRIVATE_KEY'}) + github_base_url: Optional[str] = Field(default='https://api.github.com', + json_schema_extra={'env_key': 'GITHUB_BASE_URL'}) # Repository config - github_repository: Optional[str] = None - active_branch: Optional[str] = None - github_base_branch: Optional[str] = None + github_repository: Optional[str] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_REPOSITORY'}) + github_base_branch: Optional[str] = 'main' # Add LLM instance llm: Optional[Any] = None - # Client instances - renamed without leading underscores and marked as exclude=True - github_client_instance: Optional[GitHubClient] = Field(default=None, exclude=True) - graphql_client_instance: Optional[GraphQLClientWrapper] = Field(default=None, exclude=True) + _github_client_instance: Optional[GitHubClient] = None + _graphql_client_instance: Optional[GraphQLClientWrapper] = None - class Config: - arbitrary_types_allowed = True - - @model_validator(mode='before') @classmethod - def validate_environment(cls, values: Dict) -> Dict: - """ - Initialize GitHub clients based on the provided values. - - Args: - values (Dict): Configuration values for GitHub API wrapper - - Returns: - Dict: Updated values dictionary - """ - from langchain.utils import get_from_dict_or_env - - # Get all authentication values - github_access_token = get_from_dict_or_env(values, "github_access_token", "GITHUB_ACCESS_TOKEN", default='') - github_username = get_from_dict_or_env(values, "github_username", "GITHUB_USERNAME", default='') - github_password = get_from_dict_or_env(values, "github_password", "GITHUB_PASSWORD", default='') - github_app_id = get_from_dict_or_env(values, "github_app_id", "GITHUB_APP_ID", default='') - github_app_private_key = get_from_dict_or_env(values, "github_app_private_key", "GITHUB_APP_PRIVATE_KEY", default='') - github_base_url = get_from_dict_or_env(values, "github_base_url", "GITHUB_BASE_URL", default='https://api.github.com') + def model_construct(cls, *args, **kwargs) -> Self: + klass = super().model_construct(*args, **kwargs) + klass._github_client_instance = GitHubClient.model_construct() + klass._graphql_client_instance = GraphQLClientWrapper.model_construct() + return klass - auth_config = GitHubAuthConfig( - github_access_token=github_access_token, - github_username=github_username, - github_password=github_password, - github_app_id=github_app_id, # This will be None if not provided - GitHubAuthConfig should allow this - github_app_private_key=github_app_private_key, - github_base_url=github_base_url - ) - - # Rest of initialization code remains the same - github_repository = get_from_dict_or_env(values, "github_repository", "GITHUB_REPOSITORY") - github_repository = GitHubClient.clean_repository_name(github_repository) - - repo_config = GitHubRepoConfig( - github_repository=github_repository, - active_branch=get_from_dict_or_env(values, "active_branch", "ACTIVE_BRANCH", default='main'), # Change from 'ai' to 'main' - github_base_branch=get_from_dict_or_env(values, "github_base_branch", "GITHUB_BASE_BRANCH", default="main") + @property + def auth_config(self): + return GitHubAuthConfig( + github_access_token=self.github_access_token, + github_username=self.github_username, + github_password=self.github_password, + github_app_id=self.github_app_id, # This will be None if not provided - GitHubAuthConfig should allow this + github_app_private_key=self.github_app_private_key, + github_base_url=self.github_base_url ) - # Initialize GitHub client with keyword arguments - github_client = GitHubClient(auth_config=auth_config, repo_config=repo_config) - # Initialize GraphQL client with keyword argument - graphql_client = GraphQLClientWrapper(github_graphql_instance=github_client.github_api._Github__requester) - # Set client attributes on the class (renamed from _github_client to github_client_instance) - values["github_client_instance"] = github_client - values["graphql_client_instance"] = graphql_client - - # Update values - values["github_repository"] = github_repository - values["active_branch"] = repo_config.active_branch - values["github_base_branch"] = repo_config.github_base_branch - - # Ensure LLM is available in values if needed - if "llm" not in values: - values["llm"] = None - - return values + @field_validator( + 'github_access_token', + 'github_username', + 'github_password', + 'github_app_id', + 'github_app_private_key', + 'github_repository', + 'github_base_url', + mode='before', check_fields=False + ) + def set_from_values_or_env(cls, value: str, info: ValidationInfo) -> Optional[str]: + if value is None: + if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra: + if env_key := json_schema_extra.get('env_key'): + try: + return get_from_env(key=info.field_name, env_key=env_key, + default=cls.model_fields[info.field_name].default) + except ValueError: + return None + return value + + @field_validator('github_repository', mode='after') + def clean_value(cls, value: str) -> str: + return GitHubClient.clean_repository_name(value) + + @model_validator(mode='after') + def validate_auth(self) -> Self: + # Check that at least one authentication method is provided + if not (self.github_access_token or (self.github_username and self.github_password) or self.github_app_id): + raise ValueError( + "You must provide either a GitHub access token, username/password, or app credentials." + ) + return self # Expose GitHub REST client methods directly via property @property def github_client(self) -> GitHubClient: """Access to GitHub REST client methods""" - return self.github_client_instance + if not self._github_client_instance: + self._github_client_instance = GitHubClient( + auth_config=self.auth_config, + repo_config=GitHubRepoConfig.model_validate(self) + ) + return self._github_client_instance - # Expose GraphQL client methods directly via property + # Expose GraphQL client methods directly via property @property def graphql_client(self) -> GraphQLClientWrapper: """Access to GitHub GraphQL client methods""" - return self.graphql_client_instance - + if not self._graphql_client_instance: + self._graphql_client_instance = GraphQLClientWrapper( + github_graphql_instance=self.github_client.github_api._Github__requester + ) + return self._graphql_client_instance def process_github_query(self, query: str) -> Any: try: @@ -161,7 +153,6 @@ def process_github_query(self, query: str) -> Any: logger.error(f"Error processing GitHub query: {e}\n{traceback.format_exc()}") return f"Error processing GitHub query: {e}" - def generate_github_code(self, task_to_solve: str, error_trace: str = None) -> str: """Generate Python code using LLM based on the GitHub task to solve.""" if not self.llm: @@ -218,7 +209,7 @@ def generate_code_with_retries(self, query: str) -> str: generated_code = self.generate_github_code(query, error_context) # Basic validation: check if code seems runnable (contains 'self.run') if "self.run(" in generated_code: - return generated_code + return generated_code else: raise ValueError("Generated code does not seem to call any GitHub tools.") except Exception as e: @@ -231,20 +222,21 @@ def generate_code_with_retries(self, query: str) -> str: logger.error( f"Maximum retry attempts exceeded for GitHub code generation. Last error: {last_error}" ) - raise Exception(f"Failed to generate valid GitHub code after {max_retries} retries. Last error: {e}") from e + raise Exception( + f"Failed to generate valid GitHub code after {max_retries} retries. Last error: {e}") from e # Should not be reached if logic is correct, but added for safety raise Exception("Failed to generate GitHub code.") - def get_available_tools(self): + @property + def github_tools(self) -> list: + return self.github_client.get_available_tools() + + @property + def graphql_tools(self) -> list: + return self.graphql_client.get_available_tools() + + def get_available_tools(self) -> list[dict[str, Any]]: # this is horrible, I need to think on something better - if not self.github_client_instance: - github_tools = GitHubClient.model_construct().get_available_tools() - else: - github_tools = self.github_client_instance.get_available_tools() - if not self.graphql_client_instance: - graphql_tools = GraphQLClientWrapper.model_construct().get_available_tools() - else: - graphql_tools = self.graphql_client_instance.get_available_tools() code_gen = [ # { # "ref": self.process_github_query, @@ -254,7 +246,7 @@ def get_available_tools(self): # "args_schema": ProcessGitHubQueryModel # } ] - tools = github_tools + graphql_tools + code_gen + tools = self.github_tools + self.graphql_tools + code_gen return tools def run(self, name: str, *args: Any, **kwargs: Any): @@ -262,18 +254,18 @@ def run(self, name: str, *args: Any, **kwargs: Any): if tool["name"] == name: # Handle potential dictionary input for args when only one dict is passed if len(args) == 1 and isinstance(args[0], dict) and not kwargs: - kwargs = args[0] - args = () # Clear args + kwargs = args[0] + args = () # Clear args try: return tool["ref"](*args, **kwargs) except TypeError as e: - # Attempt to call with kwargs only if args fail and kwargs exist - if kwargs and not args: - try: - return tool["ref"](**kwargs) - except TypeError: - raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e - else: - raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e + # Attempt to call with kwargs only if args fail and kwargs exist + if kwargs and not args: + try: + return tool["ref"](**kwargs) + except TypeError: + raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e + else: + raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e else: raise ValueError(f"Unknown tool name: {name}") From fe950ffd5dc2e9b0924ba94be5a1bea22f587f2a Mon Sep 17 00:00:00 2001 From: Artem Dubrovskiy Date: Wed, 11 Jun 2025 17:28:05 +0400 Subject: [PATCH 2/2] refactoring init process --- src/alita_tools/__init__.py | 459 ++++++++++++++-------- src/alita_tools/ado/__init__.py | 58 +-- src/alita_tools/ado/repos/__init__.py | 19 + src/alita_tools/ado/test_plan/__init__.py | 11 + src/alita_tools/ado/wiki/__init__.py | 12 + src/alita_tools/ado/work_item/__init__.py | 11 + src/alita_tools/keycloak/__init__.py | 2 +- 7 files changed, 392 insertions(+), 180 deletions(-) diff --git a/src/alita_tools/__init__.py b/src/alita_tools/__init__.py index 5d5c5e2e..d25175ae 100644 --- a/src/alita_tools/__init__.py +++ b/src/alita_tools/__init__.py @@ -1,172 +1,325 @@ import logging from importlib import import_module +from typing import Dict, List + +# ado is an exceptional toolkit, unfortunately +from .ado import get_tools as get_ado_tools, supported_types as ado_supported_types -from .github import get_tools as get_github, AlitaGitHubToolkit -from .openapi import get_tools as get_openapi -from .jira import get_tools as get_jira, JiraToolkit -from .confluence import get_tools as get_confluence, ConfluenceToolkit -from .servicenow import get_tools as get_service_now, ServiceNowToolkit -from .gitlab import get_tools as get_gitlab, AlitaGitlabToolkit -from .gitlab_org import get_tools as get_gitlab_org, AlitaGitlabSpaceToolkit -from .zephyr import get_tools as get_zephyr, ZephyrToolkit -from .browser import get_tools as get_browser, BrowserToolkit -from .report_portal import get_tools as get_report_portal, ReportPortalToolkit -from .bitbucket import get_tools as get_bitbucket, AlitaBitbucketToolkit -from .testrail import get_tools as get_testrail, TestrailToolkit -from .testio import get_tools as get_testio, TestIOToolkit -from .xray import get_tools as get_xray_cloud, XrayToolkit -from .sharepoint import get_tools as get_sharepoint, SharepointToolkit -from .qtest import get_tools as get_qtest, QtestToolkit -from .zephyr_scale import get_tools as get_zephyr_scale, ZephyrScaleToolkit -from .zephyr_enterprise import get_tools as get_zephyr_enterprise, ZephyrEnterpriseToolkit -from .ado import get_tools as get_ado -from .ado.repos import AzureDevOpsReposToolkit -from .ado.test_plan import AzureDevOpsPlansToolkit -from .ado.work_item import AzureDevOpsWorkItemsToolkit -from .ado.wiki import AzureDevOpsWikiToolkit -from .rally import get_tools as get_rally, RallyToolkit -from .sql import get_tools as get_sql, SQLToolkit -from .code.sonar import get_tools as get_sonar, SonarToolkit -from .google_places import get_tools as get_google_places, GooglePlacesToolkit -from .yagmail import get_tools as get_yagmail, AlitaYagmailToolkit -from .cloud.aws import AWSToolkit -from .cloud.azure import AzureToolkit -from .cloud.gcp import GCPToolkit -from .cloud.k8s import KubernetesToolkit -from .custom_open_api import OpenApiToolkit as CustomOpenApiToolkit -from .elastic import ElasticToolkit -from .keycloak import KeycloakToolkit -from .localgit import AlitaLocalGitToolkit -from .pandas import get_tools as get_pandas, PandasToolkit -from .azure_ai.search import AzureSearchToolkit, get_tools as get_azure_search -from .figma import get_tools as get_figma, FigmaToolkit -from .salesforce import get_tools as get_salesforce, SalesforceToolkit -from .carrier import get_tools as get_carrier, AlitaCarrierToolkit -from .ocr import get_tools as get_ocr, OCRToolkit -from .pptx import get_tools as get_pptx, PPTXToolkit logger = logging.getLogger(__name__) -def get_tools(tools_list, alita: 'AlitaClient', llm: 'LLMLikeObject', *args, **kwargs): +# Registry that holds the name of the toolkit, import path, get_tools function, and toolkit class +TOOLKIT_REGISTRY: list[Dict[str, str]] = [ + { + "import_path": ".github", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaGitHubToolkit", + }, + { + "import_path": ".openapi", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaOpenAPIToolkit", + }, + { + "import_path": ".jira", + "get_function_name": "get_tools", + "toolkit_class_name": "JiraToolkit", + }, + { + "import_path": ".confluence", + "get_function_name": "get_tools", + "toolkit_class_name": "ConfluenceToolkit", + }, + { + "import_path": ".servicenow", + "get_function_name": "get_tools", + "toolkit_class_name": "ServiceNowToolkit", + }, + { + "import_path": ".gitlab", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaGitlabToolkit", + }, + { + "import_path": ".gitlab_org", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaGitlabSpaceToolkit", + }, + { + "import_path": ".zephyr", + "get_function_name": "get_tools", + "toolkit_class_name": "ZephyrToolkit", + }, + { + "import_path": ".browser", + "get_function_name": "get_tools", + "toolkit_class_name": "BrowserToolkit", + }, + { + "import_path": ".yagmail", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaYagmailToolkit", + }, + { + "import_path": ".report_portal", + "get_function_name": "get_tools", + "toolkit_class_name": "ReportPortalToolkit", + }, + { + "import_path": ".bitbucket", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaBitbucketToolkit", + }, + { + "import_path": ".testrail", + "get_function_name": "get_tools", + "toolkit_class_name": "TestrailToolkit", + }, + { + "import_path": ".testio", + "get_function_name": "get_tools", + "toolkit_class_name": "TestIOToolkit", + }, + { + "import_path": ".xray", + "get_function_name": "get_tools", + "toolkit_class_name": "XrayToolkit", + }, + { + "import_path": ".sharepoint", + "get_function_name": "get_tools", + "toolkit_class_name": "SharepointToolkit", + }, + { + "import_path": ".qtest", + "get_function_name": "get_tools", + "toolkit_class_name": "QtestToolkit", + }, + { + "import_path": ".zephyr_scale", + "get_function_name": "get_tools", + "toolkit_class_name": "ZephyrScaleToolkit", + }, + { + "import_path": ".zephyr_enterprise", + "get_function_name": "get_tools", + "toolkit_class_name": "ZephyrEnterpriseToolkit", + }, + { + "import_path": ".rally", + "get_function_name": "get_tools", + "toolkit_class_name": "RallyToolkit", + }, + { + "import_path": ".sql", + "get_function_name": "get_tools", + "toolkit_class_name": "SQLToolkit", + }, + { + "import_path": ".code.sonar", + "get_function_name": "get_tools", + "toolkit_class_name": "SonarToolkit", + }, + { + "import_path": ".google_places", + "get_function_name": "get_tools", + "toolkit_class_name": "GooglePlacesToolkit", + }, + { + "import_path": ".azure_ai.search", + "get_function_name": "get_tools", + "toolkit_class_name": "AzureSearchToolkit", + }, + { + "import_path": ".pandas", + "get_function_name": "get_tools", + "toolkit_class_name": "PandasToolkit", + }, + { + "import_path": ".figma", + "get_function_name": "get_tools", + "toolkit_class_name": "FigmaToolkit", + }, + { + "import_path": ".salesforce", + "get_function_name": "get_tools", + "toolkit_class_name": "SalesforceToolkit", + }, + { + "import_path": ".carrier", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaCarrierToolkit", + }, + { + "import_path": ".ocr", + "get_function_name": "get_tools", + "toolkit_class_name": "OCRToolkit", + }, + { + "import_path": ".pptx", + "get_function_name": "get_tools", + "toolkit_class_name": "PPTXToolkit", + }, + { + "import_path": ".ado.repos", + "get_function_name": "get_tools", + "toolkit_class_name": "AzureDevOpsReposToolkit", + }, + { + "import_path": ".ado.work_item", + "get_function_name": "get_tools", + "toolkit_class_name": "AzureDevOpsWorkItemsToolkit", + }, + { + "import_path": ".ado.wiki", + "get_function_name": "get_tools", + "toolkit_class_name": "AzureDevOpsWikiToolkit", + }, + { + "import_path": ".ado.test_plan", + "get_function_name": "get_tools", + "toolkit_class_name": "AzureDevOpsPlansToolkit", + }, + { + "import_path": ".elastic", + "get_function_name": "get_tools", + "toolkit_class_name": "ElasticToolkit", + }, + { + "import_path": ".custom_open_api", + "get_function_name": "get_tools", + "toolkit_class_name": "OpenApiToolkit", + }, + { + "import_path": ".localgit", + "get_function_name": "get_tools", + "toolkit_class_name": "AlitaLocalGitToolkit", + }, + { + "import_path": ".cloud.aws", + "get_function_name": "get_tools", + "toolkit_class_name": "AWSToolkit", + }, + { + "import_path": ".cloud.azure", + "get_function_name": "get_tools", + "toolkit_class_name": "AzureToolkit", + }, + { + "import_path": ".cloud.gcp", + "get_function_name": "get_tools", + "toolkit_class_name": "GCPToolkit", + }, + { + "import_path": ".cloud.k8s", + "get_function_name": "get_tools", + "toolkit_class_name": "KubernetesToolkit", + }, + { + "import_path": ".keycloak", + "get_function_name": "get_tools", + "toolkit_class_name": "KeycloakToolkit", + } +] + + +# Dynamically import everything once and store references +IMPORTED_TOOLKITS = {} +for toolkit_info in TOOLKIT_REGISTRY: + import_path = toolkit_info['import_path'] + try: + mod = import_module(import_path, package=__name__) + try: + toolkit_name = getattr(mod, 'name') + get_function = getattr(mod, toolkit_info["get_function_name"]) + toolkit_class = getattr(mod, toolkit_info["toolkit_class_name"]) + except AttributeError as e: + logger.error(f'Error importing toolkit {toolkit_info["import_path"]}: {e}') + raise + if toolkit_name and get_function and toolkit_class: + IMPORTED_TOOLKITS[toolkit_name] = { + "get_function": get_function, + "toolkit_class": toolkit_class, + 'name': toolkit_name + } + except ImportError as e: + logger.warning( + "Could not import '%s' library; skipping. Reason: %s", + toolkit_info['import_path'], e + ) + + +def get_tools(tools_list: list[dict], alita: "AlitaClient", llm: "LLMLikeObject", *args, **kwargs) -> List[dict]: + + tools = [] for tool in tools_list: - tool['settings']['alita'] = alita - tool['settings']['llm'] = llm - if tool['type'] == 'openapi': - tools.extend(get_openapi(tool)) - elif tool['type'] == 'github': - tools.extend(get_github(tool)) - elif tool['type'] == 'jira': - tools.extend(get_jira(tool)) - elif tool['type'] == 'confluence': - tools.extend(get_confluence(tool)) - elif tool['type'] == 'service_now': - tools.extend(get_service_now(tool)) - elif tool['type'] == 'gitlab': - tools.extend(get_gitlab(tool)) - elif tool['type'] == 'gitlab_org': - tools.extend(get_gitlab_org(tool)) - elif tool['type'] == 'zephyr': - tools.extend(get_zephyr(tool)) - elif tool['type'] == 'browser': - tools.extend(get_browser(tool)) - elif tool['type'] == 'yagmail': - tools.extend(get_yagmail(tool)) - elif tool['type'] == 'report_portal': - tools.extend(get_report_portal(tool)) - elif tool['type'] == 'bitbucket': - tools.extend(get_bitbucket(tool)) - elif tool['type'] == 'testrail': - tools.extend(get_testrail(tool)) - elif tool['type'] in ['ado_boards', 'ado_wiki', 'ado_plans', 'ado_repos', 'azure_devops_repos']: - tools.extend(get_ado(tool['type'], tool)) - elif tool['type'] == 'testio': - tools.extend(get_testio(tool)) - elif tool['type'] == 'xray_cloud': - tools.extend(get_xray_cloud(tool)) - elif tool['type'] == 'sharepoint': - tools.extend(get_sharepoint(tool)) - elif tool['type'] == 'qtest': - tools.extend(get_qtest(tool)) - elif tool['type'] == 'zephyr_scale': - tools.extend(get_zephyr_scale(tool)) - elif tool['type'] == 'zephyr_enterprise': - tools.extend(get_zephyr_enterprise(tool)) - elif tool['type'] == 'rally': - tools.extend(get_rally(tool)) - elif tool['type'] == 'sql': - tools.extend(get_sql(tool)) - elif tool['type'] == 'sonar': - tools.extend(get_sonar(tool)) - elif tool['type'] == 'google_places': - tools.extend(get_google_places(tool)) - elif tool['type'] == 'azure_search': - tools.extend(get_azure_search(tool)) - elif tool['type'] == 'pandas': - tools.extend(get_pandas(tool)) - elif tool['type'] == 'figma': - tools.extend(get_figma(tool)) - elif tool['type'] == 'salesforce': - tools.extend(get_salesforce(tool)) - elif tool['type'] == 'carrier': - tools.extend(get_carrier(tool)) - elif tool['type'] == 'ocr': - tools.extend(get_ocr(tool)) - elif tool['type'] == 'pptx': - tools.extend(get_pptx(tool)) + tool.setdefault("settings", {}) + tool["settings"]["alita"] = alita + tool["settings"]["llm"] = llm + + # Identify the toolkit name or type + ttype: str = tool["type"] + + # If we have a dynamic import function for the type + if ttype in IMPORTED_TOOLKITS: + toolkit_info = IMPORTED_TOOLKITS[ttype] + try: + toolkit_tools = toolkit_info["get_function"](tool) + tools.extend(toolkit_tools) + except Exception as e: + logger.warning("Error getting tools for '%s'. Reason: %s", ttype, e) + elif ttype in ado_supported_types: + try: + toolkit_tools = get_ado_tools(tool_type=ttype, tool=tool) + tools.extend(toolkit_tools) + except Exception as e: + logger.warning("Error getting tools for '%s'. Reason: %s", ttype, e) else: + # Fallback for custom modules if tool.get("settings", {}).get("module"): try: settings = tool.get("settings", {}) mod = import_module(settings.pop("module")) tkitclass = getattr(mod, settings.pop("class")) - toolkit = tkitclass.get_toolkit(**tool["settings"]) + toolkit = tkitclass.get_toolkit(**settings) tools.extend(toolkit.get_tools()) except Exception as e: - logger.error(f"Error in getting toolkit: {e}") + logger.warning( + "Error in getting custom toolkit [%s]. Skipping. Reason: %s", + tool.get("type"), e + ) + else: + logger.warning( + "Unrecognized toolkit '%s' and no custom module provided. Skipping.", + ttype + ) return tools + def get_toolkits(): - return [ - AlitaGitHubToolkit.toolkit_config_schema(), - TestrailToolkit.toolkit_config_schema(), - JiraToolkit.toolkit_config_schema(), - AzureDevOpsPlansToolkit.toolkit_config_schema(), - AzureDevOpsWikiToolkit.toolkit_config_schema(), - AzureDevOpsWorkItemsToolkit.toolkit_config_schema(), - RallyToolkit.toolkit_config_schema(), - QtestToolkit.toolkit_config_schema(), - ReportPortalToolkit.toolkit_config_schema(), - TestIOToolkit.toolkit_config_schema(), - SQLToolkit.toolkit_config_schema(), - SonarToolkit.toolkit_config_schema(), - GooglePlacesToolkit.toolkit_config_schema(), - BrowserToolkit.toolkit_config_schema(), - XrayToolkit.toolkit_config_schema(), - AlitaGitlabToolkit.toolkit_config_schema(), - ConfluenceToolkit.toolkit_config_schema(), - ServiceNowToolkit.toolkit_config_schema(), - AlitaBitbucketToolkit.toolkit_config_schema(), - AlitaGitlabSpaceToolkit.toolkit_config_schema(), - ZephyrScaleToolkit.toolkit_config_schema(), - ZephyrEnterpriseToolkit.toolkit_config_schema(), - ZephyrToolkit.toolkit_config_schema(), - AlitaYagmailToolkit.toolkit_config_schema(), - SharepointToolkit.toolkit_config_schema(), - AzureDevOpsReposToolkit.toolkit_config_schema(), - AWSToolkit.toolkit_config_schema(), - AzureToolkit.toolkit_config_schema(), - GCPToolkit.toolkit_config_schema(), - KubernetesToolkit.toolkit_config_schema(), - CustomOpenApiToolkit.toolkit_config_schema(), - ElasticToolkit.toolkit_config_schema(), - KeycloakToolkit.toolkit_config_schema(), - AlitaLocalGitToolkit.toolkit_config_schema(), - PandasToolkit.toolkit_config_schema(), - AzureSearchToolkit.toolkit_config_schema(), - FigmaToolkit.toolkit_config_schema(), - SalesforceToolkit.toolkit_config_schema(), - AlitaCarrierToolkit.toolkit_config_schema(), - OCRToolkit.toolkit_config_schema(), - PPTXToolkit.toolkit_config_schema(), - ] \ No newline at end of file + schemas = [] + for ttype, info in IMPORTED_TOOLKITS.items(): + # If we have a valid class ref, call the .toolkit_config_schema(), etc. + if info.get("toolkit_class"): + try: + config_schema = info["toolkit_class"].toolkit_config_schema() + schemas.append(config_schema) + except AttributeError: + logger.warning( + "Toolkit class for '%s' does not have 'toolkit_config_schema'. Skipping.", + ttype + ) + return schemas + + +def get_supported_tool_types() -> List[str]: + """ + Get a list of all supported tool types. + + Returns: + List of supported tool type strings + """ + result = set(IMPORTED_TOOLKITS.keys()) + result.update(ado_supported_types) + return list(result) diff --git a/src/alita_tools/ado/__init__.py b/src/alita_tools/ado/__init__.py index e3d07fdf..30c6f4ad 100644 --- a/src/alita_tools/ado/__init__.py +++ b/src/alita_tools/ado/__init__.py @@ -1,30 +1,36 @@ -from .test_plan import AzureDevOpsPlansToolkit -from .wiki import AzureDevOpsWikiToolkit -from .work_item import AzureDevOpsWorkItemsToolkit -from .repos import AzureDevOpsReposToolkit +from typing import List, Literal, Union + +from langchain_core.tools import BaseTool + +from .repos import AzureDevOpsReposToolkit, name as ado_repos_name, get_tools as ado_repos_get_tools +from .test_plan import AzureDevOpsPlansToolkit, name as ado_plans_name, name_alias as ado_plans_name_alias, \ + get_tools as ado_plans_get_tools +from .wiki import AzureDevOpsWikiToolkit, name as ado_wiki_name, name_alias as ado_wiki_name_alias, \ + get_tools as ado_wiki_get_tools +from .work_item import AzureDevOpsWorkItemsToolkit, name as ado_work_items_name, \ + name_alias as ado_work_items_name_alias, get_tools as ado_work_items_get_tools name = "azure_devops" -def get_tools(tool_type, tool): - config_dict = { - # common - "selected_tools": tool['settings'].get('selected_tools', []), - "organization_url": tool['settings']['organization_url'], - "project": tool['settings'].get('project', None), - "token": tool['settings'].get('token', None), - "limit": tool['settings'].get('limit', 5), - # repos only - "repository_id": tool['settings'].get('repository_id', None), - "base_branch": tool['settings'].get('base_branch', None), - "active_branch": tool['settings'].get('active_branch', None), - "toolkit_name": tool.get('toolkit_name', ''), - } - if tool_type == 'ado_plans': - return AzureDevOpsPlansToolkit().get_toolkit(**config_dict).get_tools() - elif tool_type == 'ado_wiki': - return AzureDevOpsWikiToolkit().get_toolkit(**config_dict).get_tools() - elif tool_type == 'ado_repos' or tool_type == 'azure_devops_repos': - return AzureDevOpsReposToolkit().get_toolkit(**config_dict).get_tools() - else: - return AzureDevOpsWorkItemsToolkit().get_toolkit(**config_dict).get_tools() +supported_types: set = { + ado_plans_name, ado_plans_name_alias, + ado_wiki_name, ado_wiki_name_alias, + ado_repos_name, ado_work_items_name, + ado_work_items_name_alias +} + + +def get_tools( + tool_type: Union[*supported_types], + tool: dict +) -> List[BaseTool]: + if tool_type in (ado_plans_name, ado_plans_name_alias): + return ado_plans_get_tools(tool) + elif tool_type in (ado_wiki_name, ado_wiki_name_alias): + return ado_wiki_get_tools(tool) + elif tool_type == ado_repos_name: + return ado_repos_get_tools(tool) + elif tool_type in (ado_work_items_name, ado_work_items_name_alias): + return ado_work_items_get_tools(tool) + raise ValueError(f"Unsupported tool type: {tool_type}") diff --git a/src/alita_tools/ado/repos/__init__.py b/src/alita_tools/ado/repos/__init__.py index 9a822f4a..e7198fc8 100644 --- a/src/alita_tools/ado/repos/__init__.py +++ b/src/alita_tools/ado/repos/__init__.py @@ -9,6 +9,25 @@ name = "ado_repos" +def get_tools(tool: dict) -> List[BaseTool]: + config_dict = { + # common + "selected_tools": tool['settings'].get('selected_tools', []), + "organization_url": tool['settings']['organization_url'], + "project": tool['settings'].get('project', None), + "token": tool['settings'].get('token', None), + "limit": tool['settings'].get('limit', 5), + # repos only + "repository_id": tool['settings'].get('repository_id', None), + "base_branch": tool['settings'].get('base_branch', None), + "active_branch": tool['settings'].get('active_branch', None), + "toolkit_name": tool.get('toolkit_name', ''), + } + + return AzureDevOpsReposToolkit().get_toolkit(**config_dict).get_tools() + + + class AzureDevOpsReposToolkit(BaseToolkit): tools: List[BaseTool] = [] toolkit_max_length: int = 0 diff --git a/src/alita_tools/ado/test_plan/__init__.py b/src/alita_tools/ado/test_plan/__init__.py index b2545161..6e837093 100644 --- a/src/alita_tools/ado/test_plan/__init__.py +++ b/src/alita_tools/ado/test_plan/__init__.py @@ -11,6 +11,17 @@ name = "azure_devops_plans" name_alias = "ado_plans" +def get_tools(tool: dict) -> List[BaseTool]: + config_dict = { + # common + "selected_tools": tool['settings'].get('selected_tools', []), + "organization_url": tool['settings']['organization_url'], + "project": tool['settings'].get('project', None), + "token": tool['settings'].get('token', None), + "limit": tool['settings'].get('limit', 5), + } + return AzureDevOpsPlansToolkit().get_toolkit(**config_dict).get_tools() + class AzureDevOpsPlansToolkit(BaseToolkit): tools: List[BaseTool] = [] diff --git a/src/alita_tools/ado/wiki/__init__.py b/src/alita_tools/ado/wiki/__init__.py index 3bd511e1..83bb3389 100644 --- a/src/alita_tools/ado/wiki/__init__.py +++ b/src/alita_tools/ado/wiki/__init__.py @@ -9,6 +9,18 @@ name = "azure_devops_wiki" name_alias = 'ado_wiki' +def get_tools(tool: dict) -> List[BaseTool]: + config_dict = { + # common + "selected_tools": tool['settings'].get('selected_tools', []), + "organization_url": tool['settings']['organization_url'], + "project": tool['settings'].get('project', None), + "token": tool['settings'].get('token', None), + "limit": tool['settings'].get('limit', 5), + } + return AzureDevOpsWikiToolkit().get_toolkit(**config_dict).get_tools() + + class AzureDevOpsWikiToolkit(BaseToolkit): tools: List[BaseTool] = [] toolkit_max_length: int = 0 diff --git a/src/alita_tools/ado/work_item/__init__.py b/src/alita_tools/ado/work_item/__init__.py index 47939885..ed5638a2 100644 --- a/src/alita_tools/ado/work_item/__init__.py +++ b/src/alita_tools/ado/work_item/__init__.py @@ -9,6 +9,17 @@ name = "azure_devops_boards" name_alias = 'ado_boards' +def get_tools(tool: dict) -> List[BaseTool]: + config_dict = { + # common + "selected_tools": tool['settings'].get('selected_tools', []), + "organization_url": tool['settings']['organization_url'], + "project": tool['settings'].get('project', None), + "token": tool['settings'].get('token', None), + "limit": tool['settings'].get('limit', 5), + } + return AzureDevOpsWorkItemsToolkit().get_toolkit(**config_dict).get_tools() + class AzureDevOpsWorkItemsToolkit(BaseToolkit): tools: List[BaseTool] = [] toolkit_max_length: int = 0 diff --git a/src/alita_tools/keycloak/__init__.py b/src/alita_tools/keycloak/__init__.py index f2427c86..c6007d72 100644 --- a/src/alita_tools/keycloak/__init__.py +++ b/src/alita_tools/keycloak/__init__.py @@ -57,4 +57,4 @@ def get_toolkit(cls, selected_tools: list[str] | None = None, toolkit_name: Opti return cls(tools=tools) def get_tools(self) -> list[BaseTool]: - return self.tools \ No newline at end of file + return self.tools