From 6f95cc2fbe77112606070b60627b8ab3bcffabc8 Mon Sep 17 00:00:00 2001 From: Artem Dubrovskiy Date: Mon, 19 May 2025 16:13:49 +0400 Subject: [PATCH] refactoring github toolkit as an example --- src/alita_tools/__init__.py | 1 - src/alita_tools/github/__init__.py | 206 +++++++++++++++-------- src/alita_tools/github/api_wrapper.py | 228 ++++++++++++-------------- 3 files changed, 246 insertions(+), 189 deletions(-) diff --git a/src/alita_tools/__init__.py b/src/alita_tools/__init__.py index 33cb60c0..f68e65b9 100644 --- a/src/alita_tools/__init__.py +++ b/src/alita_tools/__init__.py @@ -48,7 +48,6 @@ def get_tools(tools_list, alita: 'AlitaClient', llm: 'LLMLikeObject', *args, **kwargs): tools = [] - print("Tools") for tool in tools_list: tool['settings']['alita'] = alita tool['settings']['llm'] = llm diff --git a/src/alita_tools/github/__init__.py b/src/alita_tools/github/__init__.py index 8188e23b..7586c051 100644 --- a/src/alita_tools/github/__init__.py +++ b/src/alita_tools/github/__init__.py @@ -1,7 +1,9 @@ -from typing import Dict, List, Optional, Literal +from functools import lru_cache +from typing import Type +from typing import List, Optional from langchain_core.tools import BaseTool, BaseToolkit -from pydantic import create_model, BaseModel, ConfigDict, Field, SecretStr +from pydantic import BaseModel, Field, SecretStr, field_validator, computed_field from .api_wrapper import AlitaGitHubAPIWrapper from .tool import GitHubAction @@ -10,6 +12,105 @@ name = "github" + +@lru_cache(maxsize=1) +def get_available_tools() -> dict[str, dict]: + api_wrapper = AlitaGitHubAPIWrapper.model_construct() + available_tools: dict = { + x['name']: x['args_schema'].model_json_schema() for x in + api_wrapper.get_available_tools() + } + return available_tools + +toolkit_max_length = lru_cache(maxsize=1)(lambda: get_max_toolkit_length(get_available_tools())) + + +class AlitaGitHubToolkitConfig(BaseModel): + class Config: + title = name + json_schema_extra = { + 'metadata': { + "label": "GitHub", + "icon_url": None, + "sections": { + "auth": { + "required": False, + "subsections": [ + { + "name": "Token", + "fields": ["access_token"] + }, + { + "name": "Password", + "fields": ["username", "password"] + }, + { + "name": "App private key", + "fields": ["app_id", "app_private_key"] + }, + ] + } + } + } + } + + base_url: Optional[str] = Field( + default="https://api.github.com", + description="Base API URL", + json_schema_extra={'configuration': True, 'configuration_title': True} + ) + app_id: Optional[str] = Field( + default=None, + description="Github APP ID", + json_schema_extra={'configuration': True}, + ) + app_private_key: Optional[SecretStr] = Field( + default=None, + description="Github APP private key", + json_schema_extra={'secret': True, 'configuration': True}, + ) + access_token: Optional[SecretStr] = Field( + default=None, + description="Github Access Token", + json_schema_extra={'secret': True, 'configuration': True}, + ) + username: Optional[str] = Field( + default=None, + description="Github Username", + json_schema_extra={'configuration': True}, + ) + password: Optional[SecretStr] = Field( + default=None, + description="Github Password", + json_schema_extra={'secret': True, 'configuration': True}, + ) + repository: str = Field( + description="Github repository", + json_schema_extra={ + 'toolkit_name': True, + 'max_toolkit_length': 100 # Example limit; adjust as needed + }, + ) + active_branch: Optional[str] = Field( + default="main", + description="Active branch", + ) + base_branch: Optional[str] = Field( + default="main", + description="Github Base branch", + ) + selected_tools: List[str] = Field( + default=[], + description="Selected tools", + json_schema_extra={'args_schemas': get_available_tools()}, + ) + + @field_validator('selected_tools', mode='before', check_fields=False) + @classmethod + def selected_tools_validator(cls, value: List[str]) -> list[str]: + return [i for i in value if i in get_available_tools()] + + def _get_toolkit(tool) -> BaseToolkit: return AlitaGitHubToolkit().get_toolkit( selected_tools=tool['settings'].get('selected_tools', []), @@ -26,85 +127,56 @@ def _get_toolkit(tool) -> BaseToolkit: toolkit_name=tool.get('toolkit_name') ) + def get_toolkit(): return AlitaGitHubToolkit.toolkit_config_schema() + def get_tools(tool): return _get_toolkit(tool).get_tools() + class AlitaGitHubToolkit(BaseToolkit): tools: List[BaseTool] = [] - toolkit_max_length: int = 0 - @staticmethod - def toolkit_config_schema() -> BaseModel: - selected_tools = {x['name']: x['args_schema'].schema() for x in AlitaGitHubAPIWrapper.model_construct().get_available_tools()} - AlitaGitHubToolkit.toolkit_max_length = get_max_toolkit_length(selected_tools) - return create_model( - name, - __config__=ConfigDict( - json_schema_extra={ - 'metadata': { - "label": "GitHub", - "icon_url": None, - "sections": { - "auth": { - "required": False, - "subsections": [ - { - "name": "Token", - "fields": ["access_token"] - }, - { - "name": "Password", - "fields": ["username", "password"] - }, - { - "name": "App private key", - "fields": ["app_id", "app_private_key"] - } - ] - } - } - }, - } - ), - base_url=(Optional[str], Field(description="Base API URL", default="https://api.github.com", json_schema_extra={'configuration': True, 'configuration_title': True})), - app_id=(Optional[str], Field(description="Github APP ID", default=None, json_schema_extra={'configuration': True})), - app_private_key=(Optional[SecretStr], Field(description="Github APP private key", default=None, json_schema_extra={'secret': True, 'configuration': True})), + api_wrapper: Optional[AlitaGitHubAPIWrapper] = Field(default_factory=AlitaGitHubAPIWrapper.model_construct) + toolkit_name: Optional[str] = None - access_token=(Optional[SecretStr], Field(description="Github Access Token", default=None, json_schema_extra={'secret': True, 'configuration': True})), + @computed_field + @property + def tool_prefix(self) -> str: + return clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER if self.toolkit_name else '' - username=(Optional[str], Field(description="Github Username", default=None, json_schema_extra={'configuration': True})), - password=(Optional[SecretStr], Field(description="Github Password", default=None, json_schema_extra={'secret': True, 'configuration': True})), + @computed_field + @property + def available_tools(self) -> List[dict]: + return self.api_wrapper.get_available_tools() - repository=(str, Field(description="Github repository", json_schema_extra={'toolkit_name': True, 'max_toolkit_length': AlitaGitHubToolkit.toolkit_max_length})), - active_branch=(Optional[str], Field(description="Active branch", default="main")), - base_branch=(Optional[str], Field(description="Github Base branch", default="main")), - selected_tools=(List[Literal[tuple(selected_tools)]], Field(default=[], json_schema_extra={'args_schemas': selected_tools})) - ) + @staticmethod + def toolkit_config_schema() -> Type[BaseModel]: + return AlitaGitHubToolkitConfig @classmethod - def get_toolkit(cls, selected_tools: list[str] | None = None, toolkit_name: Optional[str] = None, **kwargs): - if selected_tools is None: - selected_tools = [] + def get_toolkit(cls, selected_tools: list[str] | None = None, toolkit_name: Optional[str] = None, **kwargs) -> "AlitaGitHubToolkit": github_api_wrapper = AlitaGitHubAPIWrapper(**kwargs) - available_tools: List[Dict] = github_api_wrapper.get_available_tools() - tools = [] - prefix = clean_string(toolkit_name, AlitaGitHubToolkit.toolkit_max_length) + TOOLKIT_SPLITTER if toolkit_name else '' - for tool in available_tools: - if selected_tools: - if tool["name"] not in selected_tools: - continue - tools.append(GitHubAction( - api_wrapper=github_api_wrapper, - name=prefix + tool["name"], - mode=tool["mode"], - # set unique description for declared tools to differentiate the same methods for different toolkits - description=f"Repository: {github_api_wrapper.github_repository}\n" + tool["description"], - args_schema=tool["args_schema"] - )) - return cls(tools=tools) + instance = cls( + tools=[], + api_wrapper=github_api_wrapper, + toolkit_name=toolkit_name + ) + if selected_tools: + selected_tools = set(selected_tools) + for t in instance.available_tools: + if t["name"] in selected_tools: + instance.tools.append(GitHubAction( + api_wrapper=instance.api_wrapper, + name=instance.tool_prefix + t["name"], + mode=t["mode"], + # set unique description for declared tools to differentiate the same methods for different toolkits + description=f"Repository: {github_api_wrapper.github_repository}\n" + t["description"], + args_schema=t["args_schema"] + )) + return instance def get_tools(self): - return self.tools \ No newline at end of file + return self.tools diff --git a/src/alita_tools/github/api_wrapper.py b/src/alita_tools/github/api_wrapper.py index 590136e9..4a7c77ab 100644 --- a/src/alita_tools/github/api_wrapper.py +++ b/src/alita_tools/github/api_wrapper.py @@ -1,141 +1,127 @@ -from typing import Any, Dict, List, Optional, Union, Tuple +import json import logging import traceback -import json -import re -from pydantic import BaseModel, model_validator, Field +from typing import Any, Optional, Self + +from langchain_core.callbacks import dispatch_custom_event +from langchain_core.utils import get_from_env +from pydantic import model_validator, Field, field_validator, SecretStr, ConfigDict +from pydantic_core.core_schema import ValidationInfo -from .github_client import GitHubClient -from .graphql_client_wrapper import GraphQLClientWrapper # Add imports for the executor and generator from .executor.github_code_executor import GitHubCodeExecutor from .generator.github_code_generator import GitHubCodeGenerator +from .github_client import GitHubClient +from .graphql_client_wrapper import GraphQLClientWrapper from .schemas import ( GitHubAuthConfig, - GitHubRepoConfig, - ProcessGitHubQueryModel + GitHubRepoConfig ) -from langchain_core.callbacks import dispatch_custom_event - logger = logging.getLogger(__name__) -# Import prompts for tools -from .tool_prompts import ( - UPDATE_FILE_PROMPT, - CREATE_ISSUE_PROMPT, - UPDATE_ISSUE_PROMPT, - CREATE_ISSUE_ON_PROJECT_PROMPT, - UPDATE_ISSUE_ON_PROJECT_PROMPT, - CODE_AND_RUN -) - -class AlitaGitHubAPIWrapper(BaseModel): +class AlitaGitHubAPIWrapper(GitHubAuthConfig, GitHubRepoConfig): """ Wrapper for GitHub API that integrates both REST and GraphQL functionality. """ + model_config = ConfigDict( + arbitrary_types_allowed=True, + from_attributes=True + ) + # Authentication config - github_access_token: Optional[str] = None - github_username: Optional[str] = None - github_password: Optional[str] = None - github_app_id: Optional[str] = None - github_app_private_key: Optional[str] = None - github_base_url: Optional[str] = None + github_access_token: Optional[SecretStr] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_ACCESS_TOKEN'}) + github_username: Optional[str] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_USERNAME'}) + github_password: Optional[SecretStr] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_PASSWORD'}) + github_app_id: Optional[str] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_APP_ID'}) + github_app_private_key: Optional[SecretStr] = Field(default=None, + json_schema_extra={'env_key': 'GITHUB_APP_PRIVATE_KEY'}) + github_base_url: Optional[str] = Field(default='https://api.github.com', + json_schema_extra={'env_key': 'GITHUB_BASE_URL'}) # Repository config - github_repository: Optional[str] = None - active_branch: Optional[str] = None - github_base_branch: Optional[str] = None + github_repository: Optional[str] = Field(default=None, json_schema_extra={'env_key': 'GITHUB_REPOSITORY'}) + github_base_branch: Optional[str] = 'main' # Add LLM instance llm: Optional[Any] = None - # Client instances - renamed without leading underscores and marked as exclude=True - github_client_instance: Optional[GitHubClient] = Field(default=None, exclude=True) - graphql_client_instance: Optional[GraphQLClientWrapper] = Field(default=None, exclude=True) + _github_client_instance: Optional[GitHubClient] = None + _graphql_client_instance: Optional[GraphQLClientWrapper] = None - class Config: - arbitrary_types_allowed = True - - @model_validator(mode='before') @classmethod - def validate_environment(cls, values: Dict) -> Dict: - """ - Initialize GitHub clients based on the provided values. + def model_construct(cls, *args, **kwargs) -> Self: + klass = super().model_construct(*args, **kwargs) + klass._github_client_instance = GitHubClient.model_construct() + klass._graphql_client_instance = GraphQLClientWrapper.model_construct() + return klass - Args: - values (Dict): Configuration values for GitHub API wrapper - - Returns: - Dict: Updated values dictionary - """ - from langchain.utils import get_from_dict_or_env - - # Get all authentication values - github_access_token = get_from_dict_or_env(values, "github_access_token", "GITHUB_ACCESS_TOKEN", default='') - github_username = get_from_dict_or_env(values, "github_username", "GITHUB_USERNAME", default='') - github_password = get_from_dict_or_env(values, "github_password", "GITHUB_PASSWORD", default='') - github_app_id = get_from_dict_or_env(values, "github_app_id", "GITHUB_APP_ID", default='') - github_app_private_key = get_from_dict_or_env(values, "github_app_private_key", "GITHUB_APP_PRIVATE_KEY", default='') - github_base_url = get_from_dict_or_env(values, "github_base_url", "GITHUB_BASE_URL", default='https://api.github.com') + @property + def auth_config(self): + return GitHubAuthConfig( + github_access_token=self.github_access_token, + github_username=self.github_username, + github_password=self.github_password, + github_app_id=self.github_app_id, # This will be None if not provided - GitHubAuthConfig should allow this + github_app_private_key=self.github_app_private_key, + github_base_url=self.github_base_url + ) + @field_validator( + 'github_access_token', + 'github_username', + 'github_password', + 'github_app_id', + 'github_app_private_key', + 'github_repository', + 'github_base_url', + mode='before', check_fields=False + ) + def set_from_values_or_env(cls, value: str, info: ValidationInfo) -> Optional[str]: + if value is None: + if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra: + if env_key := json_schema_extra.get('env_key'): + try: + return get_from_env(key=info.field_name, env_key=env_key, + default=cls.model_fields[info.field_name].default) + except ValueError: + return None + return value + + @field_validator('github_repository', mode='after') + def clean_value(cls, value: str) -> str: + return GitHubClient.clean_repository_name(value) + + @model_validator(mode='after') + def validate_auth(self) -> Self: # Check that at least one authentication method is provided - if not (github_access_token or (github_username and github_password) or github_app_id): + if not (self.github_access_token or (self.github_username and self.github_password) or self.github_app_id): raise ValueError( "You must provide either a GitHub access token, username/password, or app credentials." ) - - auth_config = GitHubAuthConfig( - github_access_token=github_access_token, - github_username=github_username, - github_password=github_password, - github_app_id=github_app_id, # This will be None if not provided - GitHubAuthConfig should allow this - github_app_private_key=github_app_private_key, - github_base_url=github_base_url - ) - - # Rest of initialization code remains the same - github_repository = get_from_dict_or_env(values, "github_repository", "GITHUB_REPOSITORY") - github_repository = GitHubClient.clean_repository_name(github_repository) - - repo_config = GitHubRepoConfig( - github_repository=github_repository, - active_branch=get_from_dict_or_env(values, "active_branch", "ACTIVE_BRANCH", default='main'), # Change from 'ai' to 'main' - github_base_branch=get_from_dict_or_env(values, "github_base_branch", "GITHUB_BASE_BRANCH", default="main") - ) - - # Initialize GitHub client with keyword arguments - github_client = GitHubClient(auth_config=auth_config, repo_config=repo_config) - # Initialize GraphQL client with keyword argument - graphql_client = GraphQLClientWrapper(github_graphql_instance=github_client.github_api._Github__requester) - # Set client attributes on the class (renamed from _github_client to github_client_instance) - values["github_client_instance"] = github_client - values["graphql_client_instance"] = graphql_client - - # Update values - values["github_repository"] = github_repository - values["active_branch"] = repo_config.active_branch - values["github_base_branch"] = repo_config.github_base_branch - - # Ensure LLM is available in values if needed - if "llm" not in values: - values["llm"] = None - - return values + return self # Expose GitHub REST client methods directly via property @property def github_client(self) -> GitHubClient: """Access to GitHub REST client methods""" - return self.github_client_instance + if not self._github_client_instance: + self._github_client_instance = GitHubClient( + auth_config=self.auth_config, + repo_config=GitHubRepoConfig.model_validate(self) + ) + return self._github_client_instance - # Expose GraphQL client methods directly via property + # Expose GraphQL client methods directly via property @property def graphql_client(self) -> GraphQLClientWrapper: """Access to GitHub GraphQL client methods""" - return self.graphql_client_instance - + if not self._graphql_client_instance: + self._graphql_client_instance = GraphQLClientWrapper( + github_graphql_instance=self.github_client.github_api._Github__requester + ) + return self._graphql_client_instance def process_github_query(self, query: str) -> Any: try: @@ -167,7 +153,6 @@ def process_github_query(self, query: str) -> Any: logger.error(f"Error processing GitHub query: {e}\n{traceback.format_exc()}") return f"Error processing GitHub query: {e}" - def generate_github_code(self, task_to_solve: str, error_trace: str = None) -> str: """Generate Python code using LLM based on the GitHub task to solve.""" if not self.llm: @@ -224,7 +209,7 @@ def generate_code_with_retries(self, query: str) -> str: generated_code = self.generate_github_code(query, error_context) # Basic validation: check if code seems runnable (contains 'self.run') if "self.run(" in generated_code: - return generated_code + return generated_code else: raise ValueError("Generated code does not seem to call any GitHub tools.") except Exception as e: @@ -237,20 +222,21 @@ def generate_code_with_retries(self, query: str) -> str: logger.error( f"Maximum retry attempts exceeded for GitHub code generation. Last error: {last_error}" ) - raise Exception(f"Failed to generate valid GitHub code after {max_retries} retries. Last error: {e}") from e + raise Exception( + f"Failed to generate valid GitHub code after {max_retries} retries. Last error: {e}") from e # Should not be reached if logic is correct, but added for safety raise Exception("Failed to generate GitHub code.") - def get_available_tools(self): + @property + def github_tools(self) -> list: + return self.github_client.get_available_tools() + + @property + def graphql_tools(self) -> list: + return self.graphql_client.get_available_tools() + + def get_available_tools(self) -> list[dict[str, Any]]: # this is horrible, I need to think on something better - if not self.github_client_instance: - github_tools = GitHubClient.model_construct().get_available_tools() - else: - github_tools = self.github_client_instance.get_available_tools() - if not self.graphql_client_instance: - graphql_tools = GraphQLClientWrapper.model_construct().get_available_tools() - else: - graphql_tools = self.graphql_client_instance.get_available_tools() code_gen = [ # { # "ref": self.process_github_query, @@ -260,7 +246,7 @@ def get_available_tools(self): # "args_schema": ProcessGitHubQueryModel # } ] - tools = github_tools + graphql_tools + code_gen + tools = self.github_tools + self.graphql_tools + code_gen return tools def run(self, name: str, *args: Any, **kwargs: Any): @@ -268,18 +254,18 @@ def run(self, name: str, *args: Any, **kwargs: Any): if tool["name"] == name: # Handle potential dictionary input for args when only one dict is passed if len(args) == 1 and isinstance(args[0], dict) and not kwargs: - kwargs = args[0] - args = () # Clear args + kwargs = args[0] + args = () # Clear args try: return tool["ref"](*args, **kwargs) except TypeError as e: - # Attempt to call with kwargs only if args fail and kwargs exist - if kwargs and not args: - try: - return tool["ref"](**kwargs) - except TypeError: - raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e - else: - raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e + # Attempt to call with kwargs only if args fail and kwargs exist + if kwargs and not args: + try: + return tool["ref"](**kwargs) + except TypeError: + raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e + else: + raise ValueError(f"Argument mismatch for tool '{name}'. Error: {e}") from e else: raise ValueError(f"Unknown tool name: {name}")