From 42ff6735b539484a65f1901826e735251320b6eb Mon Sep 17 00:00:00 2001 From: nochore <40186790+nochore@users.noreply.github.com> Date: Wed, 18 Jun 2025 19:00:11 +0400 Subject: [PATCH 1/4] Feat: add BigQuery tool --- src/alita_tools/__init__.py | 4 + src/alita_tools/google/__init__.py | 0 src/alita_tools/google/bigquery/__init__.py | 153 ++++++ .../google/bigquery/api_wrapper.py | 459 ++++++++++++++++++ src/alita_tools/google/bigquery/schemas.py | 93 ++++ src/alita_tools/google/bigquery/tool.py | 34 ++ 6 files changed, 743 insertions(+) create mode 100644 src/alita_tools/google/__init__.py create mode 100644 src/alita_tools/google/bigquery/__init__.py create mode 100644 src/alita_tools/google/bigquery/api_wrapper.py create mode 100644 src/alita_tools/google/bigquery/schemas.py create mode 100644 src/alita_tools/google/bigquery/tool.py diff --git a/src/alita_tools/__init__.py b/src/alita_tools/__init__.py index e9dea343..feb16b38 100644 --- a/src/alita_tools/__init__.py +++ b/src/alita_tools/__init__.py @@ -44,6 +44,7 @@ from .carrier import get_tools as get_carrier, AlitaCarrierToolkit from .ocr import get_tools as get_ocr, OCRToolkit from .pptx import get_tools as get_pptx, PPTXToolkit +from .google.bigquery import get_tools as get_bigquery, BigQueryToolkit logger = logging.getLogger(__name__) @@ -120,6 +121,8 @@ def get_tools(tools_list, alita: 'AlitaClient', llm: 'LLMLikeObject', *args, **k tools.extend(get_ocr(tool)) elif tool['type'] == 'pptx': tools.extend(get_pptx(tool)) + elif tool['type'] == 'bigquery': + tools.extend(get_bigquery(tool)) else: if tool.get("settings", {}).get("module"): try: @@ -175,4 +178,5 @@ def get_toolkits(): AlitaCarrierToolkit.toolkit_config_schema(), OCRToolkit.toolkit_config_schema(), PPTXToolkit.toolkit_config_schema(), + BigQueryToolkit.toolkit_config_schema(), ] diff --git a/src/alita_tools/google/__init__.py b/src/alita_tools/google/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/alita_tools/google/bigquery/__init__.py b/src/alita_tools/google/bigquery/__init__.py new file mode 100644 index 00000000..e1ea1d39 --- /dev/null +++ b/src/alita_tools/google/bigquery/__init__.py @@ -0,0 +1,153 @@ +from functools import lru_cache +from typing import List, Optional, Type + +from langchain_core.tools import BaseTool, BaseToolkit +from pydantic import BaseModel, Field, SecretStr, computed_field, field_validator + +from ...utils import TOOLKIT_SPLITTER, clean_string, get_max_toolkit_length +from .api_wrapper import BigQueryApiWrapper +from .tool import BigQueryAction + +name = "bigquery" + + +@lru_cache(maxsize=1) +def get_available_tools() -> dict[str, dict]: + api_wrapper = BigQueryApiWrapper.model_construct() + available_tools: dict = { + x["name"]: x["args_schema"].model_json_schema() + for x in api_wrapper.get_available_tools() + } + return available_tools + + +toolkit_max_length = lru_cache(maxsize=1)( + lambda: get_max_toolkit_length(get_available_tools()) +) + + +class BigQueryToolkitConfig(BaseModel): + class Config: + title = name + json_schema_extra = { + "metadata": { + "label": "Cloud GCP", + "icon_url": "google.svg", + "sections": { + "auth": { + "required": False, + "subsections": [ + {"name": "API Key", "fields": ["api_key"]}, + ], + } + }, + } + } + + api_key: Optional[SecretStr] = Field( + default=None, + description="GCP API key", + json_schema_extra={"secret": True, "configuration": True}, + ) + project: Optional[str] = Field( + default=None, + description="BigQuery project ID", + json_schema_extra={"configuration": 
True}, + ) + location: Optional[str] = Field( + default=None, + description="BigQuery location", + json_schema_extra={"configuration": True}, + ) + dataset: Optional[str] = Field( + default=None, + description="BigQuery dataset name", + json_schema_extra={"configuration": True}, + ) + table: Optional[str] = Field( + default=None, + description="BigQuery table name", + json_schema_extra={"configuration": True}, + ) + selected_tools: List[str] = Field( + default=[], + description="Selected tools", + json_schema_extra={"args_schemas": get_available_tools()}, + ) + + @field_validator("selected_tools", mode="before", check_fields=False) + @classmethod + def selected_tools_validator(cls, value: List[str]) -> list[str]: + return [i for i in value if i in get_available_tools()] + + +def _get_toolkit(tool) -> BaseToolkit: + return BigQueryToolkit().get_toolkit( + selected_tools=tool["settings"].get("selected_tools", []), + api_key=tool["settings"].get("api_key", ""), + toolkit_name=tool.get("toolkit_name"), + ) + + +def get_toolkit(): + return BigQueryToolkit.toolkit_config_schema() + + +def get_tools(tool): + return _get_toolkit(tool).get_tools() + + +class BigQueryToolkit(BaseToolkit): + tools: List[BaseTool] = [] + api_wrapper: Optional[BigQueryApiWrapper] = Field( + default_factory=BigQueryApiWrapper.model_construct + ) + toolkit_name: Optional[str] = None + + @computed_field + @property + def tool_prefix(self) -> str: + return ( + clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER + if self.toolkit_name + else "" + ) + + @computed_field + @property + def available_tools(self) -> List[dict]: + return self.api_wrapper.get_available_tools() + + @staticmethod + def toolkit_config_schema() -> Type[BaseModel]: + return BigQueryToolkitConfig + + @classmethod + def get_toolkit( + cls, + selected_tools: list[str] | None = None, + toolkit_name: Optional[str] = None, + **kwargs, + ) -> "BigQueryToolkit": + bigquery_api_wrapper = BigQueryApiWrapper(**kwargs) + instance = cls( + tools=[], api_wrapper=bigquery_api_wrapper, toolkit_name=toolkit_name + ) + if selected_tools: + selected_tools = set(selected_tools) + for t in instance.available_tools: + if t["name"] in selected_tools: + instance.tools.append( + BigQueryAction( + api_wrapper=instance.api_wrapper, + name=instance.tool_prefix + t["name"], + # set unique description for declared tools to differentiate the same methods for different toolkits + description=f"Project: {getattr(instance.api_wrapper, 'project', '')}\n" + + t["description"], + args_schema=t["args_schema"], + ) + ) + return instance + + def get_tools(self): + return self.tools diff --git a/src/alita_tools/google/bigquery/api_wrapper.py b/src/alita_tools/google/bigquery/api_wrapper.py new file mode 100644 index 00000000..2a265388 --- /dev/null +++ b/src/alita_tools/google/bigquery/api_wrapper.py @@ -0,0 +1,459 @@ +import functools +import json +import logging +from typing import Any, Dict, List, Optional, Union + +from google.cloud import bigquery +from langchain_core.tools import ToolException +from pydantic import ( + ConfigDict, + Field, + PrivateAttr, + SecretStr, + field_validator, + model_validator, +) +from pydantic_core.core_schema import ValidationInfo + +from ...elitea_base import BaseToolApiWrapper +from .schemas import ArgsSchema + + +def process_output(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + try: + result = func(self, *args, **kwargs) + if isinstance(result, Exception): + return ToolException(str(result)) + if 
isinstance(result, (dict, list)): + return json.dumps(result, default=str) + return str(result) + except Exception as e: + logging.error(f"Error in '{func.__name__}': {str(e)}") + return ToolException(str(e)) + + return wrapper + + +class BigQueryApiWrapper(BaseToolApiWrapper): + model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) + + api_key: Optional[SecretStr] = Field( + default=None, json_schema_extra={"env_key": "BIGQUERY_API_KEY"} + ) + project: Optional[str] = Field( + default=None, json_schema_extra={"env_key": "BIGQUERY_PROJECT"} + ) + location: Optional[str] = Field( + default=None, json_schema_extra={"env_key": "BIGQUERY_LOCATION"} + ) + dataset: Optional[str] = Field( + default=None, json_schema_extra={"env_key": "BIGQUERY_DATASET"} + ) + table: Optional[str] = Field( + default=None, json_schema_extra={"env_key": "BIGQUERY_TABLE"} + ) + embedding: Optional[Any] = None + _client: Optional[bigquery.Client] = PrivateAttr(default=None) + + @classmethod + def model_construct(cls, *args, **kwargs): + klass = super().model_construct(*args, **kwargs) + klass._client = None + return klass + + @field_validator( + "api_key", + "project", + "location", + "dataset", + "table", + mode="before", + check_fields=False, + ) + @classmethod + def set_from_values_or_env(cls, value, info: ValidationInfo): + if value is None: + if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra: + if env_key := json_schema_extra.get("env_key"): + try: + from langchain_core.utils import get_from_env + + return get_from_env( + key=info.field_name, + env_key=env_key, + default=cls.model_fields[info.field_name].default, + ) + except Exception: + return None + return value + + @model_validator(mode="after") + def validate_auth(self) -> "BigQueryApiWrapper": + if not self.api_key: + raise ValueError("You must provide a BigQuery API key.") + return self + + @property + def bigquery_client(self) -> bigquery.Client: + if not self._client: + api_key = self.api_key.get_secret_value() if self.api_key else None + if not api_key: + raise ToolException("BigQuery API key is not set.") + try: + api_key_dict = json.loads(api_key) + credentials = bigquery.Client.from_service_account_info( + api_key_dict + )._credentials + self._client = bigquery.Client( + credentials=credentials, + project=self.project, + location=self.location, + ) + except Exception as e: + raise ToolException(f"Error initializing GCP credentials: {str(e)}") + return self._client + + def _get_table_id(self): + if not (self.project and self.dataset and self.table): + raise ToolException("Project, dataset, and table must be specified.") + return f"{self.project}.{self.dataset}.{self.table}" + + def _create_filters( + self, filter: Optional[Union[Dict[str, Any], str]] = None + ) -> str: + if filter: + if isinstance(filter, dict): + filter_expressions = [] + for k, v in filter.items(): + if isinstance(v, (int, float)): + filter_expressions.append(f"{k} = {v}") + else: + filter_expressions.append(f"{k} = '{v}'") + return " AND ".join(filter_expressions) + else: + return filter + return "TRUE" + + def job_stats(self, job_id: str) -> Dict: + return self.bigquery_client.get_job(job_id)._properties.get("statistics", {}) + + def create_vector_index(self): + table_id = self._get_table_id() + index_name = f"{self.table}_langchain_index" + sql = f""" + CREATE VECTOR INDEX IF NOT EXISTS + `{index_name}` + ON `{table_id}` + (embedding) + OPTIONS(distance_type="EUCLIDEAN", index_type="IVF") + """ + try: + 
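+            # Submit the DDL and block until the index job finishes; IF NOT EXISTS makes the
+            # statement idempotent, so repeated calls are safe.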
self.bigquery_client.query(sql).result() + return f"Vector index '{index_name}' created or already exists." + except Exception as ex: + logging.error(f"Vector index creation failed: {ex}") + return ToolException(f"Vector index creation failed: {ex}") + + @process_output + def get_documents( + self, + ids: Optional[List[str]] = None, + filter: Optional[Union[Dict[str, Any], str]] = None, + ): + table_id = self._get_table_id() + job_config = None + id_expr = "TRUE" + if ids: + job_config = bigquery.QueryJobConfig( + query_parameters=[bigquery.ArrayQueryParameter("ids", "STRING", ids)] + ) + id_expr = "doc_id IN UNNEST(@ids)" + where_filter_expr = self._create_filters(filter) + query = f"SELECT * FROM `{table_id}` WHERE {id_expr} AND {where_filter_expr}" + job = self.bigquery_client.query(query, job_config=job_config) + return [dict(row) for row in job] + + @process_output + def similarity_search( + self, + query: str, + k: int = 5, + filter: Optional[Union[Dict[str, Any], str]] = None, + ): + """Search for top `k` docs most similar to input query using vector similarity search.""" + if not hasattr(self, "embedding") or self.embedding is None: + raise ToolException("Embedding model is not set on the wrapper.") + embedding_vector = self.embedding.embed_query(query) + # Prepare the vector search query + table_id = self._get_table_id() + where_filter_expr = "TRUE" + if filter: + if isinstance(filter, dict): + filter_expressions = [f"{k} = '{v}'" for k, v in filter.items()] + where_filter_expr = " AND ".join(filter_expressions) + else: + where_filter_expr = filter + # BigQuery vector search SQL (using VECTOR_SEARCH if available) + sql = f""" + SELECT *, + VECTOR_DISTANCE(embedding, @query_embedding) AS score + FROM `{table_id}` + WHERE {where_filter_expr} + ORDER BY score ASC + LIMIT {k} + """ + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ArrayQueryParameter( + "query_embedding", "FLOAT64", embedding_vector + ) + ] + ) + job = self.bigquery_client.query(sql, job_config=job_config) + return [dict(row) for row in job] + + @process_output + def batch_search( + self, + queries: Optional[List[str]] = None, + embeddings: Optional[List[List[float]]] = None, + k: int = 5, + filter: Optional[Union[Dict[str, Any], str]] = None, + ): + """Batch vector similarity search. 
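+        Each query (or embedding) is searched independently against the configured table and
+        the per-query result lists are returned in input order.
+
+        Example (illustrative; `wrapper` is a configured BigQueryApiWrapper with an embedding
+        model set, and vec1/vec2 stand for placeholder float vectors):
+            results = wrapper.batch_search(queries=["refund policy", "shipping times"], k=3)
+            results = wrapper.batch_search(embeddings=[vec1, vec2], k=3)
+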
Accepts either queries (to embed) or embeddings.""" + if queries is not None and embeddings is not None: + raise ToolException("Provide only one of 'queries' or 'embeddings'.") + if queries is not None: + if not hasattr(self, "embedding") or self.embedding is None: + raise ToolException("Embedding model is not set on the wrapper.") + embeddings = [self.embedding.embed_query(q) for q in queries] + if not embeddings: + raise ToolException("No embeddings or queries provided.") + table_id = self._get_table_id() + where_filter_expr = "TRUE" + if filter: + if isinstance(filter, dict): + filter_expressions = [f"{k} = '{v}'" for k, v in filter.items()] + where_filter_expr = " AND ".join(filter_expressions) + else: + where_filter_expr = filter + results = [] + for emb in embeddings: + sql = f""" + SELECT *, + VECTOR_DISTANCE(embedding, @query_embedding) AS score + FROM `{table_id}` + WHERE {where_filter_expr} + ORDER BY score ASC + LIMIT {k} + """ + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", emb) + ] + ) + job = self.bigquery_client.query(sql, job_config=job_config) + results.append([dict(row) for row in job]) + return results + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 5, **kwargs + ) -> List[Dict]: + """Return docs most similar to embedding vector.""" + table_id = self._get_table_id() + sql = f""" + SELECT *, VECTOR_DISTANCE(embedding, @query_embedding) AS score + FROM `{table_id}` + ORDER BY score ASC + LIMIT {k} + """ + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", embedding) + ] + ) + job = self.bigquery_client.query(sql, job_config=job_config) + return [self._row_to_document(row) for row in job] + + def similarity_search_by_vector_with_score( + self, + embedding: List[float], + filter: Optional[Union[Dict[str, Any], str]] = None, + k: int = 5, + **kwargs, + ) -> List[Dict]: + """Return docs most similar to embedding vector with scores.""" + table_id = self._get_table_id() + where_filter_expr = self._create_filters(filter) + sql = f""" + SELECT *, VECTOR_DISTANCE(embedding, @query_embedding) AS score + FROM `{table_id}` + WHERE {where_filter_expr} + ORDER BY score ASC + LIMIT {k} + """ + job_config = bigquery.QueryJobConfig( + query_parameters=[ + bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", embedding) + ] + ) + job = self.bigquery_client.query(sql, job_config=job_config) + return [self._row_to_document(row) for row in job] + + def similarity_search_with_score( + self, + query: str, + filter: Optional[Union[Dict[str, Any], str]] = None, + k: int = 5, + **kwargs, + ) -> List[Dict]: + """Search for top `k` docs most similar to input query, returns both docs and scores.""" + embedding = self.embedding.embed_query(query) + return self.similarity_search_by_vector_with_score( + embedding, filter=filter, k=k, **kwargs + ) + + def similarity_search_by_vectors( + self, + embeddings: List[List[float]], + filter: Optional[Union[Dict[str, Any], str]] = None, + k: int = 5, + with_scores: bool = False, + with_embeddings: bool = False, + **kwargs, + ) -> Any: + """Core similarity search function. 
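+        Each input vector is resolved through similarity_search_by_vector_with_score, and the
+        returned rows are then trimmed according to the with_scores / with_embeddings flags.
+
+        Example (illustrative; `wrapper` is a configured BigQueryApiWrapper and `query_vecs`
+        is any list of float vectors):
+            hits = wrapper.similarity_search_by_vectors(query_vecs, k=10, with_scores=True)
+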
Handles a list of embedding vectors, optionally returning scores and embeddings.""" + results = [] + for emb in embeddings: + docs = self.similarity_search_by_vector_with_score( + emb, filter=filter, k=k, **kwargs + ) + if not with_scores and not with_embeddings: + docs = [d for d in docs] + elif not with_embeddings: + docs = [{**d, "score": d.get("score")} for d in docs] + elif not with_scores: + docs = [{**d, "embedding": emb} for d in docs] + results.append(docs) + return results + + def execute(self, method: str, *args, **kwargs): + """ + Universal method to call any method from google.cloud.bigquery.Client. + Args: + method: Name of the method to call on the BigQuery client. + *args: Positional arguments for the method. + **kwargs: Keyword arguments for the method. + Returns: + The result of the called method. + Raises: + ToolException: If the client is not initialized or method does not exist. + """ + if not self._client: + raise ToolException("BigQuery client is not initialized.") + if not hasattr(self._client, method): + raise ToolException(f"BigQuery client has no method '{method}'") + func = getattr(self._client, method) + try: + result = func(*args, **kwargs) + return result + except Exception as e: + logging.error(f"Error executing '{method}': {e}") + raise ToolException(f"Error executing '{method}': {e}") + + def get_available_tools(self) -> List[Dict[str, Any]]: + return [ + { + "name": "get_documents", + "description": self.get_documents.__doc__, + "args_schema": ArgsSchema.GetDocuments.value, + "ref": self.get_documents, + }, + { + "name": "similarity_search", + "description": self.similarity_search.__doc__, + "args_schema": ArgsSchema.SimilaritySearch.value, + "ref": self.similarity_search, + }, + { + "name": "batch_search", + "description": self.batch_search.__doc__, + "args_schema": ArgsSchema.BatchSearch.value, + "ref": self.batch_search, + }, + { + "name": "create_vector_index", + "description": self.create_vector_index.__doc__, + "args_schema": ArgsSchema.NoInput.value, + "ref": self.create_vector_index, + }, + { + "name": "job_stats", + "description": self.job_stats.__doc__, + "args_schema": ArgsSchema.JobStatsArgs.value, + "ref": self.job_stats, + }, + { + "name": "similarity_search_by_vector", + "description": self.similarity_search_by_vector.__doc__, + "args_schema": ArgsSchema.SimilaritySearchByVectorArgs.value, + "ref": self.similarity_search_by_vector, + }, + { + "name": "similarity_search_by_vector_with_score", + "description": self.similarity_search_by_vector_with_score.__doc__, + "args_schema": ArgsSchema.SimilaritySearchByVectorWithScoreArgs.value, + "ref": self.similarity_search_by_vector_with_score, + }, + { + "name": "similarity_search_with_score", + "description": self.similarity_search_with_score.__doc__, + "args_schema": ArgsSchema.SimilaritySearchWithScoreArgs.value, + "ref": self.similarity_search_with_score, + }, + { + "name": "similarity_search_by_vectors", + "description": self.similarity_search_by_vectors.__doc__, + "args_schema": ArgsSchema.SimilaritySearchByVectorsArgs.value, + "ref": self.similarity_search_by_vectors, + }, + { + "name": "to_vertex_fs_vector_store", + "description": self.to_vertex_fs_vector_store.__doc__, + "args_schema": ArgsSchema.NoInput.value, + "ref": self.to_vertex_fs_vector_store, + }, + { + "name": "execute", + "description": self.execute.__doc__, + "args_schema": ArgsSchema.ExecuteArgs.value, + "ref": self.execute, + }, + ] + + def run(self, name: str, *args: Any, **kwargs: Any): + for tool in 
self.get_available_tools(): + if tool["name"] == name: + # Handle potential dictionary input for args when only one dict is passed + if len(args) == 1 and isinstance(args[0], dict) and not kwargs: + kwargs = args[0] + args = () # Clear args + try: + return tool["ref"](*args, **kwargs) + except TypeError as e: + # Attempt to call with kwargs only if args fail and kwargs exist + if kwargs and not args: + try: + return tool["ref"](**kwargs) + except TypeError: + raise ValueError( + f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError( + f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError(f"Unknown tool name: {name}") diff --git a/src/alita_tools/google/bigquery/schemas.py b/src/alita_tools/google/bigquery/schemas.py new file mode 100644 index 00000000..7d98c343 --- /dev/null +++ b/src/alita_tools/google/bigquery/schemas.py @@ -0,0 +1,93 @@ +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import Field, create_model + + +class ArgsSchema(Enum): + NoInput = create_model("NoInput") + GetDocuments = create_model( + "GetDocuments", + ids=( + Optional[List[str]], + Field(default=None, description="List of document IDs to retrieve."), + ), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + ) + SimilaritySearch = create_model( + "SimilaritySearch", + query=(str, Field(description="Text query to search for similar documents.")), + k=(int, Field(default=5, description="Number of top results to return.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + ) + BatchSearch = create_model( + "BatchSearch", + queries=( + Optional[List[str]], + Field(default=None, description="List of text queries."), + ), + embeddings=( + Optional[List[List[float]]], + Field(default=None, description="List of embedding vectors."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + ) + JobStatsArgs = create_model( + "JobStatsArgs", job_id=(str, Field(description="BigQuery job ID.")) + ) + SimilaritySearchByVectorArgs = create_model( + "SimilaritySearchByVectorArgs", + embedding=(List[float], Field(description="Embedding vector.")), + k=(int, Field(default=5, description="Number of top results to return.")), + ) + SimilaritySearchByVectorWithScoreArgs = create_model( + "SimilaritySearchByVectorWithScoreArgs", + embedding=(List[float], Field(description="Embedding vector.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + ) + SimilaritySearchWithScoreArgs = create_model( + "SimilaritySearchWithScoreArgs", + query=(str, Field(description="Text query.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + ) + SimilaritySearchByVectorsArgs = create_model( + "SimilaritySearchByVectorsArgs", + embeddings=(List[List[float]], Field(description="List of embedding vectors.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE 
clause."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + with_scores=(bool, Field(default=False)), + with_embeddings=(bool, Field(default=False)), + ) + ExecuteArgs = create_model( + "ExecuteArgs", + method=(str, Field(description="Name of the BigQuery client method to call.")), + args=( + Optional[List[Any]], + Field(default=None, description="Positional arguments for the method."), + ), + kwargs=( + Optional[Dict[str, Any]], + Field(default=None, description="Keyword arguments for the method."), + ), + ) diff --git a/src/alita_tools/google/bigquery/tool.py b/src/alita_tools/google/bigquery/tool.py new file mode 100644 index 00000000..bc05543f --- /dev/null +++ b/src/alita_tools/google/bigquery/tool.py @@ -0,0 +1,34 @@ +from typing import Optional, Type + +from langchain_core.callbacks import CallbackManagerForToolRun +from pydantic import BaseModel, field_validator, Field +from langchain_core.tools import BaseTool +from traceback import format_exc +from .api_wrapper import BigQueryApiWrapper + + +class BigQueryAction(BaseTool): + """Tool for interacting with the BigQuery API.""" + + api_wrapper: BigQueryApiWrapper = Field(default_factory=BigQueryApiWrapper) + name: str + mode: str = "" + description: str = "" + args_schema: Optional[Type[BaseModel]] = None + + @field_validator('name', mode='before') + @classmethod + def remove_spaces(cls, v): + return v.replace(' ', '') + + def _run( + self, + *args, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs, + ) -> str: + """Use the GitHub API to run an operation.""" + try: + return self.api_wrapper.run(self.mode, *args, **kwargs) + except Exception as e: + return f"Error: {format_exc()}" \ No newline at end of file From 26f79a379dd0210001e970ad1e4c4fd771884503 Mon Sep 17 00:00:00 2001 From: nochore <40186790+nochore@users.noreply.github.com> Date: Fri, 20 Jun 2025 13:26:19 +0400 Subject: [PATCH 2/4] Feat: add delta lake for AWS --- requirements.txt | 5 +- src/alita_tools/__init__.py | 4 - src/alita_tools/aws/__init__.py | 0 src/alita_tools/aws/delta_lake/__init__.py | 172 +++++++++++++++++ src/alita_tools/aws/delta_lake/api_wrapper.py | 182 ++++++++++++++++++ src/alita_tools/aws/delta_lake/schemas.py | 0 src/alita_tools/aws/delta_lake/tool.py | 34 ++++ .../google/bigquery/api_wrapper.py | 55 +++++- src/alita_tools/google/bigquery/schemas.py | 9 + 9 files changed, 449 insertions(+), 12 deletions(-) create mode 100644 src/alita_tools/aws/__init__.py create mode 100644 src/alita_tools/aws/delta_lake/__init__.py create mode 100644 src/alita_tools/aws/delta_lake/api_wrapper.py create mode 100644 src/alita_tools/aws/delta_lake/schemas.py create mode 100644 src/alita_tools/aws/delta_lake/tool.py diff --git a/requirements.txt b/requirements.txt index 115027a6..4c9a48f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ lxml==5.2.2 beautifulsoup4 pymupdf==1.24.9 yagmail==0.15.293 -gitpython==3.1.43 +gitpython==3.1.43 qtest-swagger-client==0.0.3 requests>=2.3.0 testrail-api==1.13.2 @@ -57,4 +57,5 @@ factor_analyzer==0.5.1 statsmodels==0.14.4 tabulate==0.9.0 pysnc==1.1.10 -shortuuid==1.0.13 \ No newline at end of file +shortuuid==1.0.13 +deltalake==1.0.2 \ No newline at end of file diff --git a/src/alita_tools/__init__.py b/src/alita_tools/__init__.py index feb16b38..e9dea343 100644 --- a/src/alita_tools/__init__.py +++ b/src/alita_tools/__init__.py @@ -44,7 +44,6 @@ from .carrier import get_tools as get_carrier, AlitaCarrierToolkit from .ocr import get_tools as get_ocr, 
OCRToolkit from .pptx import get_tools as get_pptx, PPTXToolkit -from .google.bigquery import get_tools as get_bigquery, BigQueryToolkit logger = logging.getLogger(__name__) @@ -121,8 +120,6 @@ def get_tools(tools_list, alita: 'AlitaClient', llm: 'LLMLikeObject', *args, **k tools.extend(get_ocr(tool)) elif tool['type'] == 'pptx': tools.extend(get_pptx(tool)) - elif tool['type'] == 'bigquery': - tools.extend(get_bigquery(tool)) else: if tool.get("settings", {}).get("module"): try: @@ -178,5 +175,4 @@ def get_toolkits(): AlitaCarrierToolkit.toolkit_config_schema(), OCRToolkit.toolkit_config_schema(), PPTXToolkit.toolkit_config_schema(), - BigQueryToolkit.toolkit_config_schema(), ] diff --git a/src/alita_tools/aws/__init__.py b/src/alita_tools/aws/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/alita_tools/aws/delta_lake/__init__.py b/src/alita_tools/aws/delta_lake/__init__.py new file mode 100644 index 00000000..9a5d0d51 --- /dev/null +++ b/src/alita_tools/aws/delta_lake/__init__.py @@ -0,0 +1,172 @@ +from functools import lru_cache +from typing import List, Optional, Type + +from langchain_core.tools import BaseTool, BaseToolkit +from pydantic import BaseModel, Field, SecretStr, computed_field, field_validator + +from ...utils import TOOLKIT_SPLITTER, clean_string, get_max_toolkit_length +from .api_wrapper import DeltaLakeApiWrapper +from .tool import DeltaLakeAction + +name = "delta_lake" + + +@lru_cache(maxsize=1) +def get_available_tools() -> dict[str, dict]: + api_wrapper = DeltaLakeApiWrapper.model_construct() + available_tools: dict = { + x["name"]: x["args_schema"].model_json_schema() + for x in api_wrapper.get_available_tools() + } + return available_tools + + +toolkit_max_length = lru_cache(maxsize=1)( + lambda: get_max_toolkit_length(get_available_tools()) +) + + +class DeltaLakeToolkitConfig(BaseModel): + class Config: + title = name + json_schema_extra = { + "metadata": { + "label": "AWS Delta Lake", + "icon_url": "delta-lake.svg", + "sections": { + "auth": { + "required": False, + "subsections": [ + {"name": "AWS Access Key ID", "fields": ["aws_access_key_id"]}, + {"name": "AWS Secret Access Key", "fields": ["aws_secret_access_key"]}, + {"name": "AWS Session Token", "fields": ["aws_session_token"]}, + {"name": "AWS Region", "fields": ["aws_region"]}, + ], + }, + "connection": { + "required": False, + "subsections": [ + {"name": "Delta Lake S3 Path", "fields": ["s3_path"]}, + {"name": "Delta Lake Table Path", "fields": ["table_path"]}, + ], + }, + }, + } + } + + aws_access_key_id: Optional[SecretStr] = Field( + default=None, + description="AWS access key ID", + json_schema_extra={"secret": True, "configuration": True}, + ) + aws_secret_access_key: Optional[SecretStr] = Field( + default=None, + description="AWS secret access key", + json_schema_extra={"secret": True, "configuration": True}, + ) + aws_session_token: Optional[SecretStr] = Field( + default=None, + description="AWS session token (optional)", + json_schema_extra={"secret": True, "configuration": True}, + ) + aws_region: Optional[str] = Field( + default=None, + description="AWS region for Delta Lake storage", + json_schema_extra={"configuration": True}, + ) + s3_path: Optional[str] = Field( + default=None, + description="S3 path to Delta Lake data (e.g., s3://bucket/path)", + json_schema_extra={"configuration": True}, + ) + table_path: Optional[str] = Field( + default=None, + description="Delta Lake table path (if not using s3_path)", + json_schema_extra={"configuration": True}, + ) + 
selected_tools: List[str] = Field( + default=[], + description="Selected tools", + json_schema_extra={"args_schemas": get_available_tools()}, + ) + + @field_validator("selected_tools", mode="before", check_fields=False) + @classmethod + def selected_tools_validator(cls, value: List[str]) -> list[str]: + return [i for i in value if i in get_available_tools()] + + +def _get_toolkit(tool) -> BaseToolkit: + return DeltaLakeToolkit().get_toolkit( + selected_tools=tool["settings"].get("selected_tools", []), + aws_access_key_id=tool["settings"].get("aws_access_key_id", None), + aws_secret_access_key=tool["settings"].get("aws_secret_access_key", None), + aws_session_token=tool["settings"].get("aws_session_token", None), + aws_region=tool["settings"].get("aws_region", None), + s3_path=tool["settings"].get("s3_path", None), + table_path=tool["settings"].get("table_path", None), + toolkit_name=tool.get("toolkit_name"), + ) + + +def get_toolkit(): + return DeltaLakeToolkit.toolkit_config_schema() + + +def get_tools(tool): + return _get_toolkit(tool).get_tools() + + +class DeltaLakeToolkit(BaseToolkit): + tools: List[BaseTool] = [] + api_wrapper: Optional[DeltaLakeApiWrapper] = Field( + default_factory=DeltaLakeApiWrapper.model_construct + ) + toolkit_name: Optional[str] = None + + @computed_field + @property + def tool_prefix(self) -> str: + return ( + clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER + if self.toolkit_name + else "" + ) + + @computed_field + @property + def available_tools(self) -> List[dict]: + return self.api_wrapper.get_available_tools() + + @staticmethod + def toolkit_config_schema() -> Type[BaseModel]: + return DeltaLakeToolkitConfig + + @classmethod + def get_toolkit( + cls, + selected_tools: list[str] | None = None, + toolkit_name: Optional[str] = None, + **kwargs, + ) -> "DeltaLakeToolkit": + delta_lake_api_wrapper = DeltaLakeApiWrapper(**kwargs) + instance = cls( + tools=[], api_wrapper=delta_lake_api_wrapper, toolkit_name=toolkit_name + ) + if selected_tools: + selected_tools = set(selected_tools) + for t in instance.available_tools: + if t["name"] in selected_tools: + instance.tools.append( + DeltaLakeAction( + api_wrapper=instance.api_wrapper, + name=instance.tool_prefix + t["name"], + description=f"S3 Path: {getattr(instance.api_wrapper, 's3_path', '')} Table Path: {getattr(instance.api_wrapper, 'table_path', '')}\n" + + t["description"], + args_schema=t["args_schema"], + ) + ) + return instance + + def get_tools(self): + return self.tools diff --git a/src/alita_tools/aws/delta_lake/api_wrapper.py b/src/alita_tools/aws/delta_lake/api_wrapper.py new file mode 100644 index 00000000..d950142e --- /dev/null +++ b/src/alita_tools/aws/delta_lake/api_wrapper.py @@ -0,0 +1,182 @@ +import functools +import json +import logging +from typing import Any, Dict, List, Optional + +from deltalake import DeltaTable +from langchain_core.tools import ToolException +from pydantic import ( + ConfigDict, + Field, + PrivateAttr, + SecretStr, + field_validator, + model_validator, +) +from pydantic_core.core_schema import ValidationInfo +from ...elitea_base import BaseToolApiWrapper +from .schemas import ArgsSchema + + +def process_output(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + try: + result = func(self, *args, **kwargs) + if isinstance(result, Exception): + return ToolException(str(result)) + if isinstance(result, (dict, list)): + return json.dumps(result, default=str) + return str(result) + except Exception as e: + logging.error(f"Error in 
'{func.__name__}': {str(e)}") + return ToolException(str(e)) + return wrapper + + +class DeltaLakeApiWrapper(BaseToolApiWrapper): + """ + API Wrapper for Delta Lake on Azure Databricks using PySpark. + Handles authentication, querying, and utility methods. + """ + model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) + + aws_access_key_id: Optional[SecretStr] = Field( + default=None, json_schema_extra={"env_key": "AWS_ACCESS_KEY_ID"} + ) + aws_secret_access_key: Optional[SecretStr] = Field( + default=None, json_schema_extra={"env_key": "AWS_SECRET_ACCESS_KEY"} + ) + aws_session_token: Optional[SecretStr] = Field( + default=None, json_schema_extra={"env_key": "AWS_SESSION_TOKEN"} + ) + aws_region: Optional[str] = Field( + default=None, json_schema_extra={"env_key": "AWS_REGION"} + ) + s3_path: Optional[str] = Field( + default=None, json_schema_extra={"env_key": "DELTA_LAKE_S3_PATH"} + ) + table_path: Optional[str] = Field( + default=None, json_schema_extra={"env_key": "DELTA_LAKE_TABLE_PATH"} + ) + _delta_table: Optional[DeltaTable] = PrivateAttr(default=None) + + @classmethod + def model_construct(cls, *args, **kwargs): + klass = super().model_construct(*args, **kwargs) + klass._delta_table = None + return klass + + @field_validator( + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region", + "s3_path", + "table_path", + mode="before", + check_fields=False, + ) + @classmethod + def set_from_values_or_env(cls, value, info: ValidationInfo): + if value is None: + if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra: + if env_key := json_schema_extra.get("env_key"): + try: + from langchain_core.utils import get_from_env + return get_from_env( + key=info.field_name, + env_key=env_key, + default=cls.model_fields[info.field_name].default, + ) + except Exception: + return None + return value + + @model_validator(mode="after") + def validate_auth(self) -> "DeltaLakeApiWrapper": + if not (self.aws_access_key_id and self.aws_secret_access_key and self.aws_region): + raise ValueError("You must provide AWS credentials and region.") + if not (self.s3_path or self.table_path): + raise ValueError("You must provide either s3_path or table_path.") + return self + + @property + def delta_table(self) -> DeltaTable: + if not self._delta_table: + path = self.table_path or self.s3_path + if not path: + raise ToolException("Delta Lake table path (table_path or s3_path) must be specified.") + try: + storage_options = { + "AWS_ACCESS_KEY_ID": self.aws_access_key_id.get_secret_value() if self.aws_access_key_id else None, + "AWS_SECRET_ACCESS_KEY": self.aws_secret_access_key.get_secret_value() if self.aws_secret_access_key else None, + "AWS_REGION": self.aws_region, + } + if self.aws_session_token: + storage_options["AWS_SESSION_TOKEN"] = self.aws_session_token.get_secret_value() + # Remove None values + storage_options = {k: v for k, v in storage_options.items() if v is not None} + self._delta_table = DeltaTable(path, storage_options=storage_options) + except Exception as e: + raise ToolException(f"Error initializing DeltaTable: {e}") + return self._delta_table + + @process_output + def query_table(self, query: str) -> List[Dict[str, Any]]: + """ + Query is ignored in this lightweight implementation. Returns all data as list of dicts. 
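+        The full table is materialized via DeltaTable.to_pandas(), so this is only suitable
+        for tables that fit in memory.
+
+        Example (illustrative; `wrapper` is a configured DeltaLakeApiWrapper):
+            rows = wrapper.query_table(query="")  # the query text is not evaluated here
+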
+ """ + dt = self.delta_table + df = dt.to_pandas() + return df.to_dict(orient="records") + + @process_output + def get_table_schema(self) -> Dict[str, Any]: + dt = self.delta_table + return dt.schema().to_pyarrow().to_string() + + def get_available_tools(self) -> List[Dict[str, Any]]: + return [ + { + "name": "query_table", + "description": self.query_table.__doc__, + "args_schema": ArgsSchema.QueryTableArgs.value, + "ref": self.query_table, + }, + { + "name": "vector_search", + "description": self.vector_search.__doc__, + "args_schema": ArgsSchema.VectorSearchArgs.value, + "ref": self.vector_search, + }, + { + "name": "get_table_schema", + "description": self.get_table_schema.__doc__, + "args_schema": ArgsSchema.NoInput.value, + "ref": self.get_table_schema, + }, + ] + + def run(self, name: str, *args: Any, **kwargs: Any): + for tool in self.get_available_tools(): + if tool["name"] == name: + if len(args) == 1 and isinstance(args[0], dict) and not kwargs: + kwargs = args[0] + args = () + try: + return tool["ref"](*args, **kwargs) + except TypeError as e: + if kwargs and not args: + try: + return tool["ref"](**kwargs) + except TypeError: + raise ValueError( + f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError( + f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError(f"Unknown tool name: {name}") diff --git a/src/alita_tools/aws/delta_lake/schemas.py b/src/alita_tools/aws/delta_lake/schemas.py new file mode 100644 index 00000000..e69de29b diff --git a/src/alita_tools/aws/delta_lake/tool.py b/src/alita_tools/aws/delta_lake/tool.py new file mode 100644 index 00000000..3401020b --- /dev/null +++ b/src/alita_tools/aws/delta_lake/tool.py @@ -0,0 +1,34 @@ +from typing import Optional, Type + +from langchain_core.callbacks import CallbackManagerForToolRun +from pydantic import BaseModel, field_validator, Field +from langchain_core.tools import BaseTool +from traceback import format_exc +from .api_wrapper import DeltaLakeApiWrapper + + +class DeltaLakeAction(BaseTool): + """Tool for interacting with the Delta Lake API on Azure Databricks.""" + + api_wrapper: DeltaLakeApiWrapper = Field(default_factory=DeltaLakeApiWrapper) + name: str + mode: str = "" + description: str = "" + args_schema: Optional[Type[BaseModel]] = None + + @field_validator('name', mode='before') + @classmethod + def remove_spaces(cls, v): + return v.replace(' ', '') + + def _run( + self, + *args, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs, + ) -> str: + """Use the Delta Lake API to run an operation.""" + try: + return self.api_wrapper.run(self.mode, *args, **kwargs) + except Exception as e: + return f"Error: {format_exc()}" diff --git a/src/alita_tools/google/bigquery/api_wrapper.py b/src/alita_tools/google/bigquery/api_wrapper.py index 2a265388..793ff145 100644 --- a/src/alita_tools/google/bigquery/api_wrapper.py +++ b/src/alita_tools/google/bigquery/api_wrapper.py @@ -363,6 +363,49 @@ def execute(self, method: str, *args, **kwargs): logging.error(f"Error executing '{method}': {e}") raise ToolException(f"Error executing '{method}': {e}") + @process_output + def create_delta_lake_table( + self, + table_name: str, + dataset: Optional[str] = None, + connection_id: str = None, + source_uris: list = None, + autodetect: bool = True, + project: Optional[str] = None, + **kwargs, + ): + """ + Create a Delta Lake external table in BigQuery using the google.cloud.bigquery library. 
+ Args: + table_name: Name of the Delta Lake table to create in BigQuery. + dataset: BigQuery dataset to contain the table (defaults to self.dataset). + connection_id: Fully qualified connection ID (project.region.connection_id). + source_uris: List of GCS URIs (prefixes) for the Delta Lake table. + autodetect: Whether to autodetect schema (default: True). + project: GCP project ID (defaults to self.project). + Returns: + API response as dict. + """ + dataset = dataset or self.dataset + project = project or self.project + if not (project and dataset and table_name and connection_id and source_uris): + raise ToolException("project, dataset, table_name, connection_id, and source_uris are required.") + client = self.bigquery_client + table_ref = bigquery.TableReference( + bigquery.DatasetReference(project, dataset), table_name + ) + external_config = bigquery.ExternalConfig("DELTA_LAKE") + external_config.autodetect = autodetect + external_config.source_uris = source_uris + external_config.connection_id = connection_id + table = bigquery.Table(table_ref) + table.external_data_configuration = external_config + try: + created_table = client.create_table(table, exists_ok=True) + return created_table.to_api_repr() + except Exception as e: + raise ToolException(f"Failed to create Delta Lake table: {e}") + def get_available_tools(self) -> List[Dict[str, Any]]: return [ { @@ -419,18 +462,18 @@ def get_available_tools(self) -> List[Dict[str, Any]]: "args_schema": ArgsSchema.SimilaritySearchByVectorsArgs.value, "ref": self.similarity_search_by_vectors, }, - { - "name": "to_vertex_fs_vector_store", - "description": self.to_vertex_fs_vector_store.__doc__, - "args_schema": ArgsSchema.NoInput.value, - "ref": self.to_vertex_fs_vector_store, - }, { "name": "execute", "description": self.execute.__doc__, "args_schema": ArgsSchema.ExecuteArgs.value, "ref": self.execute, }, + { + "name": "create_delta_lake_table", + "description": self.create_delta_lake_table.__doc__, + "args_schema": ArgsSchema.CreateDeltaLakeTable.value, + "ref": self.create_delta_lake_table, + }, ] def run(self, name: str, *args: Any, **kwargs: Any): diff --git a/src/alita_tools/google/bigquery/schemas.py b/src/alita_tools/google/bigquery/schemas.py index 7d98c343..068b6e0f 100644 --- a/src/alita_tools/google/bigquery/schemas.py +++ b/src/alita_tools/google/bigquery/schemas.py @@ -91,3 +91,12 @@ class ArgsSchema(Enum): Field(default=None, description="Keyword arguments for the method."), ), ) + CreateDeltaLakeTable = create_model( + "CreateDeltaLakeTable", + table_name=(str, Field(description="Name of the Delta Lake table to create in BigQuery.")), + dataset=(Optional[str], Field(default=None, description="BigQuery dataset to contain the table (defaults to self.dataset).")), + connection_id=(str, Field(description="Fully qualified connection ID (project.region.connection_id).")), + source_uris=(list, Field(description="List of GCS URIs (prefixes) for the Delta Lake table.")), + autodetect=(bool, Field(default=True, description="Whether to autodetect schema (default: True).")), + project=(Optional[str], Field(default=None, description="GCP project ID (defaults to self.project).")), + ) From 99cf3dd25b5a719d150d2ff3519ddc0e74b712e7 Mon Sep 17 00:00:00 2001 From: nochore Date: Sun, 13 Jul 2025 18:55:40 +0400 Subject: [PATCH 3/4] fix: improvements --- src/alita_tools/__init__.py | 10 ++ src/alita_tools/aws/__init__.py | 7 ++ src/alita_tools/aws/delta_lake/__init__.py | 59 +++--------- src/alita_tools/aws/delta_lake/api_wrapper.py | 92 
+++++++++++++------ src/alita_tools/aws/delta_lake/schemas.py | 20 ++++ src/alita_tools/aws/delta_lake/tool.py | 9 +- src/alita_tools/google/__init__.py | 7 ++ 7 files changed, 125 insertions(+), 79 deletions(-) diff --git a/src/alita_tools/__init__.py b/src/alita_tools/__init__.py index e9dea343..44397823 100644 --- a/src/alita_tools/__init__.py +++ b/src/alita_tools/__init__.py @@ -1,6 +1,8 @@ import logging from importlib import import_module +from httpx import get + from .github import get_tools as get_github, AlitaGitHubToolkit from .openapi import get_tools as get_openapi from .jira import get_tools as get_jira, JiraToolkit @@ -44,6 +46,8 @@ from .carrier import get_tools as get_carrier, AlitaCarrierToolkit from .ocr import get_tools as get_ocr, OCRToolkit from .pptx import get_tools as get_pptx, PPTXToolkit +from .aws import get_tools as get_delta_lake, DeltaLakeToolkit +from .google import get_tools as get_bigquery, BigQueryToolkit logger = logging.getLogger(__name__) @@ -120,6 +124,10 @@ def get_tools(tools_list, alita: 'AlitaClient', llm: 'LLMLikeObject', *args, **k tools.extend(get_ocr(tool)) elif tool['type'] == 'pptx': tools.extend(get_pptx(tool)) + elif tool['type'] == 'delta_lake': + tools.extend(get_delta_lake(tool)) + elif tool['type'] == 'bigquery': + tools.extend(get_bigquery(tool)) else: if tool.get("settings", {}).get("module"): try: @@ -175,4 +183,6 @@ def get_toolkits(): AlitaCarrierToolkit.toolkit_config_schema(), OCRToolkit.toolkit_config_schema(), PPTXToolkit.toolkit_config_schema(), + DeltaLakeToolkit.toolkit_config_schema(), + BigQueryToolkit.toolkit_config_schema(), ] diff --git a/src/alita_tools/aws/__init__.py b/src/alita_tools/aws/__init__.py index e69de29b..32e5bf13 100644 --- a/src/alita_tools/aws/__init__.py +++ b/src/alita_tools/aws/__init__.py @@ -0,0 +1,7 @@ +from .delta_lake import DeltaLakeToolkit + +name = "aws" + +def get_tools(tool_type, tool): + if tool_type == 'delta_lake': + return DeltaLakeToolkit().get_toolkit().get_tools() \ No newline at end of file diff --git a/src/alita_tools/aws/delta_lake/__init__.py b/src/alita_tools/aws/delta_lake/__init__.py index 9a5d0d51..5f23cb4b 100644 --- a/src/alita_tools/aws/delta_lake/__init__.py +++ b/src/alita_tools/aws/delta_lake/__init__.py @@ -1,3 +1,4 @@ + from functools import lru_cache from typing import List, Optional, Type @@ -10,7 +11,6 @@ name = "delta_lake" - @lru_cache(maxsize=1) def get_available_tools() -> dict[str, dict]: api_wrapper = DeltaLakeApiWrapper.model_construct() @@ -20,12 +20,10 @@ def get_available_tools() -> dict[str, dict]: } return available_tools - toolkit_max_length = lru_cache(maxsize=1)( lambda: get_max_toolkit_length(get_available_tools()) ) - class DeltaLakeToolkitConfig(BaseModel): class Config: title = name @@ -54,48 +52,19 @@ class Config: } } - aws_access_key_id: Optional[SecretStr] = Field( - default=None, - description="AWS access key ID", - json_schema_extra={"secret": True, "configuration": True}, - ) - aws_secret_access_key: Optional[SecretStr] = Field( - default=None, - description="AWS secret access key", - json_schema_extra={"secret": True, "configuration": True}, - ) - aws_session_token: Optional[SecretStr] = Field( - default=None, - description="AWS session token (optional)", - json_schema_extra={"secret": True, "configuration": True}, - ) - aws_region: Optional[str] = Field( - default=None, - description="AWS region for Delta Lake storage", - json_schema_extra={"configuration": True}, - ) - s3_path: Optional[str] = Field( - default=None, - description="S3 
path to Delta Lake data (e.g., s3://bucket/path)", - json_schema_extra={"configuration": True}, - ) - table_path: Optional[str] = Field( - default=None, - description="Delta Lake table path (if not using s3_path)", - json_schema_extra={"configuration": True}, - ) - selected_tools: List[str] = Field( - default=[], - description="Selected tools", - json_schema_extra={"args_schemas": get_available_tools()}, - ) + aws_access_key_id: Optional[SecretStr] = Field(default=None, description="AWS access key ID", json_schema_extra={"secret": True, "configuration": True}) + aws_secret_access_key: Optional[SecretStr] = Field(default=None, description="AWS secret access key", json_schema_extra={"secret": True, "configuration": True}) + aws_session_token: Optional[SecretStr] = Field(default=None, description="AWS session token (optional)", json_schema_extra={"secret": True, "configuration": True}) + aws_region: Optional[str] = Field(default=None, description="AWS region for Delta Lake storage", json_schema_extra={"configuration": True}) + s3_path: Optional[str] = Field(default=None, description="S3 path to Delta Lake data (e.g., s3://bucket/path)", json_schema_extra={"configuration": True}) + table_path: Optional[str] = Field(default=None, description="Delta Lake table path (if not using s3_path)", json_schema_extra={"configuration": True}) + selected_tools: List[str] = Field(default=[], description="Selected tools", json_schema_extra={"args_schemas": get_available_tools()}) @field_validator("selected_tools", mode="before", check_fields=False) @classmethod def selected_tools_validator(cls, value: List[str]) -> list[str]: return [i for i in value if i in get_available_tools()] - def _get_toolkit(tool) -> BaseToolkit: return DeltaLakeToolkit().get_toolkit( selected_tools=tool["settings"].get("selected_tools", []), @@ -108,20 +77,15 @@ def _get_toolkit(tool) -> BaseToolkit: toolkit_name=tool.get("toolkit_name"), ) - def get_toolkit(): return DeltaLakeToolkit.toolkit_config_schema() - def get_tools(tool): return _get_toolkit(tool).get_tools() - class DeltaLakeToolkit(BaseToolkit): tools: List[BaseTool] = [] - api_wrapper: Optional[DeltaLakeApiWrapper] = Field( - default_factory=DeltaLakeApiWrapper.model_construct - ) + api_wrapper: Optional[DeltaLakeApiWrapper] = Field(default_factory=DeltaLakeApiWrapper.model_construct) toolkit_name: Optional[str] = None @computed_field @@ -161,12 +125,11 @@ def get_toolkit( DeltaLakeAction( api_wrapper=instance.api_wrapper, name=instance.tool_prefix + t["name"], - description=f"S3 Path: {getattr(instance.api_wrapper, 's3_path', '')} Table Path: {getattr(instance.api_wrapper, 'table_path', '')}\n" - + t["description"], + description=f"S3 Path: {getattr(instance.api_wrapper, 's3_path', '')} Table Path: {getattr(instance.api_wrapper, 'table_path', '')}\n" + t["description"], args_schema=t["args_schema"], ) ) return instance def get_tools(self): - return self.tools + return self.tools \ No newline at end of file diff --git a/src/alita_tools/aws/delta_lake/api_wrapper.py b/src/alita_tools/aws/delta_lake/api_wrapper.py index d950142e..22142b64 100644 --- a/src/alita_tools/aws/delta_lake/api_wrapper.py +++ b/src/alita_tools/aws/delta_lake/api_wrapper.py @@ -1,7 +1,7 @@ import functools import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional from deltalake import DeltaTable from langchain_core.tools import ToolException @@ -36,29 +36,16 @@ def wrapper(self, *args, **kwargs): class DeltaLakeApiWrapper(BaseToolApiWrapper): """ 
- API Wrapper for Delta Lake on Azure Databricks using PySpark. - Handles authentication, querying, and utility methods. + API Wrapper for AWS Delta Lake. Handles authentication, querying, and utility methods. """ model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) - aws_access_key_id: Optional[SecretStr] = Field( - default=None, json_schema_extra={"env_key": "AWS_ACCESS_KEY_ID"} - ) - aws_secret_access_key: Optional[SecretStr] = Field( - default=None, json_schema_extra={"env_key": "AWS_SECRET_ACCESS_KEY"} - ) - aws_session_token: Optional[SecretStr] = Field( - default=None, json_schema_extra={"env_key": "AWS_SESSION_TOKEN"} - ) - aws_region: Optional[str] = Field( - default=None, json_schema_extra={"env_key": "AWS_REGION"} - ) - s3_path: Optional[str] = Field( - default=None, json_schema_extra={"env_key": "DELTA_LAKE_S3_PATH"} - ) - table_path: Optional[str] = Field( - default=None, json_schema_extra={"env_key": "DELTA_LAKE_TABLE_PATH"} - ) + aws_access_key_id: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_ACCESS_KEY_ID"}) + aws_secret_access_key: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SECRET_ACCESS_KEY"}) + aws_session_token: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SESSION_TOKEN"}) + aws_region: Optional[str] = Field(default=None, json_schema_extra={"env_key": "AWS_REGION"}) + s3_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_S3_PATH"}) + table_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_TABLE_PATH"}) _delta_table: Optional[DeltaTable] = PrivateAttr(default=None) @classmethod @@ -115,7 +102,6 @@ def delta_table(self) -> DeltaTable: } if self.aws_session_token: storage_options["AWS_SESSION_TOKEN"] = self.aws_session_token.get_secret_value() - # Remove None values storage_options = {k: v for k, v in storage_options.items() if v is not None} self._delta_table = DeltaTable(path, storage_options=storage_options) except Exception as e: @@ -123,20 +109,72 @@ def delta_table(self) -> DeltaTable: return self._delta_table @process_output - def query_table(self, query: str) -> List[Dict[str, Any]]: + def query_table(self, query: Optional[str] = None, columns: Optional[List[str]] = None, filters: Optional[dict] = None) -> List[dict]: """ - Query is ignored in this lightweight implementation. Returns all data as list of dicts. + Query Delta Lake table. Supports pandas-like filtering, column selection, and SQL-like queries (via pandas.DataFrame.query). + Args: + query: SQL-like query string (pandas.DataFrame.query syntax) + columns: List of columns to select + filters: Dict of column:value pairs for pandas-like filtering + Returns: + List of dicts representing rows """ dt = self.delta_table df = dt.to_pandas() + if filters: + for col, val in filters.items(): + df = df[df[col] == val] + if query: + try: + df = df.query(query) + except Exception as e: + raise ToolException(f"Error in query param: {e}") + if columns: + df = df[columns] return df.to_dict(orient="records") @process_output - def get_table_schema(self) -> Dict[str, Any]: + def vector_search(self, embedding: List[float], k: int = 5, embedding_column: str = "embedding") -> List[dict]: + """ + Perform a vector similarity search on the Delta Lake table. + Args: + embedding: Query embedding vector. + k: Number of top results to return. + embedding_column: Name of the column containing embeddings. 
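+        Example (illustrative; `wrapper` is a configured DeltaLakeApiWrapper and `query_vec`
+        has the same dimensionality as the stored embeddings; similarity is cosine, computed
+        in memory with numpy):
+            top = wrapper.vector_search(embedding=query_vec, k=3)
+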
+ Returns: + List of dicts for top k most similar rows. + """ + import numpy as np + + dt = self.delta_table + df = dt.to_pandas() + if embedding_column not in df.columns: + raise ToolException(f"Embedding column '{embedding_column}' not found in table.") + + # Filter out rows with missing embeddings + df = df[df[embedding_column].notnull()] + if df.empty: + return [] + # Convert embeddings to numpy arrays + emb_matrix = np.array(df[embedding_column].tolist()) + query_vec = np.array(embedding) + + # Normalize for cosine similarity + emb_matrix_norm = emb_matrix / np.linalg.norm(emb_matrix, axis=1, keepdims=True) + query_vec_norm = query_vec / np.linalg.norm(query_vec) + similarities = np.dot(emb_matrix_norm, query_vec_norm) + + # Get top k indices + top_k_idx = np.argsort(similarities)[-k:][::-1] + top_rows = df.iloc[top_k_idx] + return top_rows.to_dict(orient="records") + + @process_output + def get_table_schema(self) -> str: dt = self.delta_table return dt.schema().to_pyarrow().to_string() - def get_available_tools(self) -> List[Dict[str, Any]]: + def get_available_tools(self) -> List[dict]: return [ { "name": "query_table", @@ -179,4 +217,4 @@ def run(self, name: str, *args: Any, **kwargs: Any): f"Argument mismatch for tool '{name}'. Error: {e}" ) from e else: - raise ValueError(f"Unknown tool name: {name}") + raise ValueError(f"Unknown tool name: {name}") \ No newline at end of file diff --git a/src/alita_tools/aws/delta_lake/schemas.py b/src/alita_tools/aws/delta_lake/schemas.py index e69de29b..29b9ea94 100644 --- a/src/alita_tools/aws/delta_lake/schemas.py +++ b/src/alita_tools/aws/delta_lake/schemas.py @@ -0,0 +1,20 @@ + +from enum import Enum +from typing import List, Optional + +from pydantic import Field, create_model + +class ArgsSchema(Enum): + NoInput = create_model("NoInput") + QueryTableArgs = create_model( + "QueryTableArgs", + query=(Optional[str], Field(default=None, description="SQL query to execute on Delta Lake table. 
If None, returns all data.")), + columns=(Optional[List[str]], Field(default=None, description="List of columns to select.")), + filters=(Optional[dict], Field(default=None, description="Dict of column:value pairs for pandas-like filtering.")), + ) + VectorSearchArgs = create_model( + "VectorSearchArgs", + embedding=(List[float], Field(description="Embedding vector for similarity search.")), + k=(int, Field(default=5, description="Number of top results to return.")), + embedding_column=(Optional[str], Field(default="embedding", description="Name of the column containing embeddings.")), + ) diff --git a/src/alita_tools/aws/delta_lake/tool.py b/src/alita_tools/aws/delta_lake/tool.py index 3401020b..60235ca7 100644 --- a/src/alita_tools/aws/delta_lake/tool.py +++ b/src/alita_tools/aws/delta_lake/tool.py @@ -1,3 +1,4 @@ + from typing import Optional, Type from langchain_core.callbacks import CallbackManagerForToolRun @@ -8,11 +9,10 @@ class DeltaLakeAction(BaseTool): - """Tool for interacting with the Delta Lake API on Azure Databricks.""" + """Tool for interacting with the Delta Lake API on AWS.""" api_wrapper: DeltaLakeApiWrapper = Field(default_factory=DeltaLakeApiWrapper) name: str - mode: str = "" description: str = "" args_schema: Optional[Type[BaseModel]] = None @@ -29,6 +29,7 @@ def _run( ) -> str: """Use the Delta Lake API to run an operation.""" try: - return self.api_wrapper.run(self.mode, *args, **kwargs) + # Use the tool name to dispatch to the correct API wrapper method + return self.api_wrapper.run(self.name, *args, **kwargs) except Exception as e: - return f"Error: {format_exc()}" + return f"Error: {format_exc()}" \ No newline at end of file diff --git a/src/alita_tools/google/__init__.py b/src/alita_tools/google/__init__.py index e69de29b..b8f4b963 100644 --- a/src/alita_tools/google/__init__.py +++ b/src/alita_tools/google/__init__.py @@ -0,0 +1,7 @@ +from .bigquery import BigQueryToolkit + +name = "google" + +def get_tools(tool_type, tool): + if tool_type == 'bigquery': + return BigQueryToolkit().get_toolkit().get_tools() \ No newline at end of file From 95e0b0553d7f46ade3f68999c0b129dcd7376bd0 Mon Sep 17 00:00:00 2001 From: nochore Date: Mon, 14 Jul 2025 20:35:35 +0400 Subject: [PATCH 4/4] Fix after review --- requirements.txt | 3 ++- src/alita_tools/aws/delta_lake/__init__.py | 1 + src/alita_tools/google/bigquery/__init__.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 383f703d..ccd5cc5f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,4 +59,5 @@ tabulate==0.9.0 pysnc==1.1.10 shortuuid==1.0.13 textract-py3==2.1.1 -deltalake==1.0.2 \ No newline at end of file +deltalake==1.0.2 +google_cloud_bigquery==3.34.0 \ No newline at end of file diff --git a/src/alita_tools/aws/delta_lake/__init__.py b/src/alita_tools/aws/delta_lake/__init__.py index 5f23cb4b..ac1da490 100644 --- a/src/alita_tools/aws/delta_lake/__init__.py +++ b/src/alita_tools/aws/delta_lake/__init__.py @@ -29,6 +29,7 @@ class Config: title = name json_schema_extra = { "metadata": { + "hidden": True, "label": "AWS Delta Lake", "icon_url": "delta-lake.svg", "sections": { diff --git a/src/alita_tools/google/bigquery/__init__.py b/src/alita_tools/google/bigquery/__init__.py index e1ea1d39..563c5a0c 100644 --- a/src/alita_tools/google/bigquery/__init__.py +++ b/src/alita_tools/google/bigquery/__init__.py @@ -31,6 +31,7 @@ class Config: title = name json_schema_extra = { "metadata": { + "hidden": True, "label": "Cloud GCP", "icon_url": 
"google.svg", "sections": {