diff --git a/requirements.txt b/requirements.txt
index cd03819a..ccd5cc5f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -58,4 +58,6 @@ statsmodels==0.14.4
 tabulate==0.9.0
 pysnc==1.1.10
 shortuuid==1.0.13
-textract-py3==2.1.1
\ No newline at end of file
+textract-py3==2.1.1
+deltalake==1.0.2
+google_cloud_bigquery==3.34.0
\ No newline at end of file
diff --git a/src/alita_tools/__init__.py b/src/alita_tools/__init__.py
index e9dea343..44397823 100644
--- a/src/alita_tools/__init__.py
+++ b/src/alita_tools/__init__.py
@@ -1,6 +1,8 @@
 import logging
 from importlib import import_module
 
+from httpx import get
+
 from .github import get_tools as get_github, AlitaGitHubToolkit
 from .openapi import get_tools as get_openapi
 from .jira import get_tools as get_jira, JiraToolkit
@@ -44,6 +46,8 @@
 from .carrier import get_tools as get_carrier, AlitaCarrierToolkit
 from .ocr import get_tools as get_ocr, OCRToolkit
 from .pptx import get_tools as get_pptx, PPTXToolkit
+from .aws import get_tools as get_delta_lake, DeltaLakeToolkit
+from .google import get_tools as get_bigquery, BigQueryToolkit
 
 
 logger = logging.getLogger(__name__)
@@ -120,6 +124,10 @@ def get_tools(tools_list, alita: 'AlitaClient', llm: 'LLMLikeObject', *args, **k
             tools.extend(get_ocr(tool))
         elif tool['type'] == 'pptx':
             tools.extend(get_pptx(tool))
+        elif tool['type'] == 'delta_lake':
+            tools.extend(get_delta_lake(tool))
+        elif tool['type'] == 'bigquery':
+            tools.extend(get_bigquery(tool))
         else:
             if tool.get("settings", {}).get("module"):
                 try:
@@ -175,4 +183,6 @@
         AlitaCarrierToolkit.toolkit_config_schema(),
         OCRToolkit.toolkit_config_schema(),
         PPTXToolkit.toolkit_config_schema(),
+        DeltaLakeToolkit.toolkit_config_schema(),
+        BigQueryToolkit.toolkit_config_schema(),
     ]
diff --git a/src/alita_tools/aws/__init__.py b/src/alita_tools/aws/__init__.py
new file mode 100644
index 00000000..32e5bf13
--- /dev/null
+++ b/src/alita_tools/aws/__init__.py
@@ -0,0 +1,8 @@
+from .delta_lake import DeltaLakeToolkit
+from .delta_lake import get_tools as get_delta_lake_tools
+
+name = "aws"
+
+def get_tools(tool):
+    if tool['type'] == 'delta_lake':
+        return get_delta_lake_tools(tool)
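For context, the new `delta_lake` and `bigquery` branches above are driven by entries in `tools_list`. A sketch of what such an entry presumably looks like, inferred from the `tool['type']` dispatch and the `tool["settings"]` lookups in the toolkits below; all values are placeholders:

    delta_lake_tool = {
        "type": "delta_lake",
        "toolkit_name": "lakehouse",
        "settings": {
            "selected_tools": ["query_table", "get_table_schema"],
            "aws_access_key_id": "<key-id>",
            "aws_secret_access_key": "<secret>",
            "aws_region": "us-east-1",
            "s3_path": "s3://my-bucket/delta-table",
        },
    }
    # alita_client and llm stand in for whatever the caller already has
    tools = get_tools([delta_lake_tool], alita_client, llm)
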
"fields": ["aws_secret_access_key"]}, + {"name": "AWS Session Token", "fields": ["aws_session_token"]}, + {"name": "AWS Region", "fields": ["aws_region"]}, + ], + }, + "connection": { + "required": False, + "subsections": [ + {"name": "Delta Lake S3 Path", "fields": ["s3_path"]}, + {"name": "Delta Lake Table Path", "fields": ["table_path"]}, + ], + }, + }, + } + } + + aws_access_key_id: Optional[SecretStr] = Field(default=None, description="AWS access key ID", json_schema_extra={"secret": True, "configuration": True}) + aws_secret_access_key: Optional[SecretStr] = Field(default=None, description="AWS secret access key", json_schema_extra={"secret": True, "configuration": True}) + aws_session_token: Optional[SecretStr] = Field(default=None, description="AWS session token (optional)", json_schema_extra={"secret": True, "configuration": True}) + aws_region: Optional[str] = Field(default=None, description="AWS region for Delta Lake storage", json_schema_extra={"configuration": True}) + s3_path: Optional[str] = Field(default=None, description="S3 path to Delta Lake data (e.g., s3://bucket/path)", json_schema_extra={"configuration": True}) + table_path: Optional[str] = Field(default=None, description="Delta Lake table path (if not using s3_path)", json_schema_extra={"configuration": True}) + selected_tools: List[str] = Field(default=[], description="Selected tools", json_schema_extra={"args_schemas": get_available_tools()}) + + @field_validator("selected_tools", mode="before", check_fields=False) + @classmethod + def selected_tools_validator(cls, value: List[str]) -> list[str]: + return [i for i in value if i in get_available_tools()] + +def _get_toolkit(tool) -> BaseToolkit: + return DeltaLakeToolkit().get_toolkit( + selected_tools=tool["settings"].get("selected_tools", []), + aws_access_key_id=tool["settings"].get("aws_access_key_id", None), + aws_secret_access_key=tool["settings"].get("aws_secret_access_key", None), + aws_session_token=tool["settings"].get("aws_session_token", None), + aws_region=tool["settings"].get("aws_region", None), + s3_path=tool["settings"].get("s3_path", None), + table_path=tool["settings"].get("table_path", None), + toolkit_name=tool.get("toolkit_name"), + ) + +def get_toolkit(): + return DeltaLakeToolkit.toolkit_config_schema() + +def get_tools(tool): + return _get_toolkit(tool).get_tools() + +class DeltaLakeToolkit(BaseToolkit): + tools: List[BaseTool] = [] + api_wrapper: Optional[DeltaLakeApiWrapper] = Field(default_factory=DeltaLakeApiWrapper.model_construct) + toolkit_name: Optional[str] = None + + @computed_field + @property + def tool_prefix(self) -> str: + return ( + clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER + if self.toolkit_name + else "" + ) + + @computed_field + @property + def available_tools(self) -> List[dict]: + return self.api_wrapper.get_available_tools() + + @staticmethod + def toolkit_config_schema() -> Type[BaseModel]: + return DeltaLakeToolkitConfig + + @classmethod + def get_toolkit( + cls, + selected_tools: list[str] | None = None, + toolkit_name: Optional[str] = None, + **kwargs, + ) -> "DeltaLakeToolkit": + delta_lake_api_wrapper = DeltaLakeApiWrapper(**kwargs) + instance = cls( + tools=[], api_wrapper=delta_lake_api_wrapper, toolkit_name=toolkit_name + ) + if selected_tools: + selected_tools = set(selected_tools) + for t in instance.available_tools: + if t["name"] in selected_tools: + instance.tools.append( + DeltaLakeAction( + api_wrapper=instance.api_wrapper, + name=instance.tool_prefix + t["name"], + 
description=f"S3 Path: {getattr(instance.api_wrapper, 's3_path', '')} Table Path: {getattr(instance.api_wrapper, 'table_path', '')}\n" + t["description"], + args_schema=t["args_schema"], + ) + ) + return instance + + def get_tools(self): + return self.tools \ No newline at end of file diff --git a/src/alita_tools/aws/delta_lake/api_wrapper.py b/src/alita_tools/aws/delta_lake/api_wrapper.py new file mode 100644 index 00000000..22142b64 --- /dev/null +++ b/src/alita_tools/aws/delta_lake/api_wrapper.py @@ -0,0 +1,220 @@ +import functools +import json +import logging +from typing import Any, List, Optional + +from deltalake import DeltaTable +from langchain_core.tools import ToolException +from pydantic import ( + ConfigDict, + Field, + PrivateAttr, + SecretStr, + field_validator, + model_validator, +) +from pydantic_core.core_schema import ValidationInfo +from ...elitea_base import BaseToolApiWrapper +from .schemas import ArgsSchema + + +def process_output(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + try: + result = func(self, *args, **kwargs) + if isinstance(result, Exception): + return ToolException(str(result)) + if isinstance(result, (dict, list)): + return json.dumps(result, default=str) + return str(result) + except Exception as e: + logging.error(f"Error in '{func.__name__}': {str(e)}") + return ToolException(str(e)) + return wrapper + + +class DeltaLakeApiWrapper(BaseToolApiWrapper): + """ + API Wrapper for AWS Delta Lake. Handles authentication, querying, and utility methods. + """ + model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True) + + aws_access_key_id: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_ACCESS_KEY_ID"}) + aws_secret_access_key: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SECRET_ACCESS_KEY"}) + aws_session_token: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SESSION_TOKEN"}) + aws_region: Optional[str] = Field(default=None, json_schema_extra={"env_key": "AWS_REGION"}) + s3_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_S3_PATH"}) + table_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_TABLE_PATH"}) + _delta_table: Optional[DeltaTable] = PrivateAttr(default=None) + + @classmethod + def model_construct(cls, *args, **kwargs): + klass = super().model_construct(*args, **kwargs) + klass._delta_table = None + return klass + + @field_validator( + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region", + "s3_path", + "table_path", + mode="before", + check_fields=False, + ) + @classmethod + def set_from_values_or_env(cls, value, info: ValidationInfo): + if value is None: + if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra: + if env_key := json_schema_extra.get("env_key"): + try: + from langchain_core.utils import get_from_env + return get_from_env( + key=info.field_name, + env_key=env_key, + default=cls.model_fields[info.field_name].default, + ) + except Exception: + return None + return value + + @model_validator(mode="after") + def validate_auth(self) -> "DeltaLakeApiWrapper": + if not (self.aws_access_key_id and self.aws_secret_access_key and self.aws_region): + raise ValueError("You must provide AWS credentials and region.") + if not (self.s3_path or self.table_path): + raise ValueError("You must provide either s3_path or table_path.") + return self + + @property + def delta_table(self) -> DeltaTable: 
diff --git a/src/alita_tools/aws/delta_lake/api_wrapper.py b/src/alita_tools/aws/delta_lake/api_wrapper.py
new file mode 100644
index 00000000..22142b64
--- /dev/null
+++ b/src/alita_tools/aws/delta_lake/api_wrapper.py
@@ -0,0 +1,220 @@
+import functools
+import json
+import logging
+from typing import Any, List, Optional
+
+from deltalake import DeltaTable
+from langchain_core.tools import ToolException
+from pydantic import (
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    field_validator,
+    model_validator,
+)
+from pydantic_core.core_schema import ValidationInfo
+from ...elitea_base import BaseToolApiWrapper
+from .schemas import ArgsSchema
+
+
+def process_output(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        try:
+            result = func(self, *args, **kwargs)
+            if isinstance(result, Exception):
+                return ToolException(str(result))
+            if isinstance(result, (dict, list)):
+                return json.dumps(result, default=str)
+            return str(result)
+        except Exception as e:
+            logging.error(f"Error in '{func.__name__}': {str(e)}")
+            return ToolException(str(e))
+    return wrapper
+
+
+class DeltaLakeApiWrapper(BaseToolApiWrapper):
+    """
+    API Wrapper for AWS Delta Lake. Handles authentication, querying, and utility methods.
+    """
+    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)
+
+    aws_access_key_id: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_ACCESS_KEY_ID"})
+    aws_secret_access_key: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SECRET_ACCESS_KEY"})
+    aws_session_token: Optional[SecretStr] = Field(default=None, json_schema_extra={"env_key": "AWS_SESSION_TOKEN"})
+    aws_region: Optional[str] = Field(default=None, json_schema_extra={"env_key": "AWS_REGION"})
+    s3_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_S3_PATH"})
+    table_path: Optional[str] = Field(default=None, json_schema_extra={"env_key": "DELTA_LAKE_TABLE_PATH"})
+    _delta_table: Optional[DeltaTable] = PrivateAttr(default=None)
+
+    @classmethod
+    def model_construct(cls, *args, **kwargs):
+        klass = super().model_construct(*args, **kwargs)
+        klass._delta_table = None
+        return klass
+
+    @field_validator(
+        "aws_access_key_id",
+        "aws_secret_access_key",
+        "aws_session_token",
+        "aws_region",
+        "s3_path",
+        "table_path",
+        mode="before",
+        check_fields=False,
+    )
+    @classmethod
+    def set_from_values_or_env(cls, value, info: ValidationInfo):
+        if value is None:
+            if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra:
+                if env_key := json_schema_extra.get("env_key"):
+                    try:
+                        from langchain_core.utils import get_from_env
+                        return get_from_env(
+                            key=info.field_name,
+                            env_key=env_key,
+                            default=cls.model_fields[info.field_name].default,
+                        )
+                    except Exception:
+                        return None
+        return value
+
+    @model_validator(mode="after")
+    def validate_auth(self) -> "DeltaLakeApiWrapper":
+        if not (self.aws_access_key_id and self.aws_secret_access_key and self.aws_region):
+            raise ValueError("You must provide AWS credentials and region.")
+        if not (self.s3_path or self.table_path):
+            raise ValueError("You must provide either s3_path or table_path.")
+        return self
+
+    @property
+    def delta_table(self) -> DeltaTable:
+        if not self._delta_table:
+            path = self.table_path or self.s3_path
+            if not path:
+                raise ToolException("Delta Lake table path (table_path or s3_path) must be specified.")
+            try:
+                storage_options = {
+                    "AWS_ACCESS_KEY_ID": self.aws_access_key_id.get_secret_value() if self.aws_access_key_id else None,
+                    "AWS_SECRET_ACCESS_KEY": self.aws_secret_access_key.get_secret_value() if self.aws_secret_access_key else None,
+                    "AWS_REGION": self.aws_region,
+                }
+                if self.aws_session_token:
+                    storage_options["AWS_SESSION_TOKEN"] = self.aws_session_token.get_secret_value()
+                storage_options = {k: v for k, v in storage_options.items() if v is not None}
+                self._delta_table = DeltaTable(path, storage_options=storage_options)
+            except Exception as e:
+                raise ToolException(f"Error initializing DeltaTable: {e}")
+        return self._delta_table
+
+    @process_output
+    def query_table(self, query: Optional[str] = None, columns: Optional[List[str]] = None, filters: Optional[dict] = None) -> List[dict]:
+        """
+        Query Delta Lake table. Supports pandas-like filtering, column selection, and SQL-like queries (via pandas.DataFrame.query).
+        Args:
+            query: SQL-like query string (pandas.DataFrame.query syntax)
+            columns: List of columns to select
+            filters: Dict of column:value pairs for pandas-like filtering
+        Returns:
+            List of dicts representing rows
+        """
+        dt = self.delta_table
+        df = dt.to_pandas()
+        if filters:
+            for col, val in filters.items():
+                df = df[df[col] == val]
+        if query:
+            try:
+                df = df.query(query)
+            except Exception as e:
+                raise ToolException(f"Error in query param: {e}")
+        if columns:
+            df = df[columns]
+        return df.to_dict(orient="records")
+
+ """ + import numpy as np + + dt = self.delta_table + df = dt.to_pandas() + if embedding_column not in df.columns: + raise ToolException(f"Embedding column '{embedding_column}' not found in table.") + + # Filter out rows with missing embeddings + df = df[df[embedding_column].notnull()] + if df.empty: + return [] + # Convert embeddings to numpy arrays + emb_matrix = np.array(df[embedding_column].tolist()) + query_vec = np.array(embedding) + + # Normalize for cosine similarity + emb_matrix_norm = emb_matrix / np.linalg.norm(emb_matrix, axis=1, keepdims=True) + query_vec_norm = query_vec / np.linalg.norm(query_vec) + similarities = np.dot(emb_matrix_norm, query_vec_norm) + + # Get top k indices + top_k_idx = np.argsort(similarities)[-k:][::-1] + top_rows = df.iloc[top_k_idx] + return top_rows.to_dict(orient="records") + + @process_output + def get_table_schema(self) -> str: + dt = self.delta_table + return dt.schema().to_pyarrow().to_string() + + def get_available_tools(self) -> List[dict]: + return [ + { + "name": "query_table", + "description": self.query_table.__doc__, + "args_schema": ArgsSchema.QueryTableArgs.value, + "ref": self.query_table, + }, + { + "name": "vector_search", + "description": self.vector_search.__doc__, + "args_schema": ArgsSchema.VectorSearchArgs.value, + "ref": self.vector_search, + }, + { + "name": "get_table_schema", + "description": self.get_table_schema.__doc__, + "args_schema": ArgsSchema.NoInput.value, + "ref": self.get_table_schema, + }, + ] + + def run(self, name: str, *args: Any, **kwargs: Any): + for tool in self.get_available_tools(): + if tool["name"] == name: + if len(args) == 1 and isinstance(args[0], dict) and not kwargs: + kwargs = args[0] + args = () + try: + return tool["ref"](*args, **kwargs) + except TypeError as e: + if kwargs and not args: + try: + return tool["ref"](**kwargs) + except TypeError: + raise ValueError( + f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError( + f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError(f"Unknown tool name: {name}") \ No newline at end of file diff --git a/src/alita_tools/aws/delta_lake/schemas.py b/src/alita_tools/aws/delta_lake/schemas.py new file mode 100644 index 00000000..29b9ea94 --- /dev/null +++ b/src/alita_tools/aws/delta_lake/schemas.py @@ -0,0 +1,20 @@ + +from enum import Enum +from typing import List, Optional + +from pydantic import Field, create_model + +class ArgsSchema(Enum): + NoInput = create_model("NoInput") + QueryTableArgs = create_model( + "QueryTableArgs", + query=(Optional[str], Field(default=None, description="SQL query to execute on Delta Lake table. 
diff --git a/src/alita_tools/aws/delta_lake/schemas.py b/src/alita_tools/aws/delta_lake/schemas.py
new file mode 100644
index 00000000..29b9ea94
--- /dev/null
+++ b/src/alita_tools/aws/delta_lake/schemas.py
@@ -0,0 +1,20 @@
+
+from enum import Enum
+from typing import List, Optional
+
+from pydantic import Field, create_model
+
+class ArgsSchema(Enum):
+    NoInput = create_model("NoInput")
+    QueryTableArgs = create_model(
+        "QueryTableArgs",
+        query=(Optional[str], Field(default=None, description="SQL query to execute on Delta Lake table. If None, returns all data.")),
+        columns=(Optional[List[str]], Field(default=None, description="List of columns to select.")),
+        filters=(Optional[dict], Field(default=None, description="Dict of column:value pairs for pandas-like filtering.")),
+    )
+    VectorSearchArgs = create_model(
+        "VectorSearchArgs",
+        embedding=(List[float], Field(description="Embedding vector for similarity search.")),
+        k=(int, Field(default=5, description="Number of top results to return.")),
+        embedding_column=(Optional[str], Field(default="embedding", description="Name of the column containing embeddings.")),
+    )
diff --git a/src/alita_tools/aws/delta_lake/tool.py b/src/alita_tools/aws/delta_lake/tool.py
new file mode 100644
index 00000000..60235ca7
--- /dev/null
+++ b/src/alita_tools/aws/delta_lake/tool.py
@@ -0,0 +1,35 @@
+
+from typing import Optional, Type
+
+from langchain_core.callbacks import CallbackManagerForToolRun
+from pydantic import BaseModel, field_validator, Field
+from langchain_core.tools import BaseTool
+from traceback import format_exc
+from .api_wrapper import DeltaLakeApiWrapper
+
+
+class DeltaLakeAction(BaseTool):
+    """Tool for interacting with the Delta Lake API on AWS."""
+
+    api_wrapper: DeltaLakeApiWrapper = Field(default_factory=DeltaLakeApiWrapper)
+    name: str
+    description: str = ""
+    args_schema: Optional[Type[BaseModel]] = None
+
+    @field_validator('name', mode='before')
+    @classmethod
+    def remove_spaces(cls, v):
+        return v.replace(' ', '')
+
+    def _run(
+        self,
+        *args,
+        run_manager: Optional[CallbackManagerForToolRun] = None,
+        **kwargs,
+    ) -> str:
+        """Use the Delta Lake API to run an operation."""
+        try:
+            # Use the tool name to dispatch to the correct API wrapper method
+            return self.api_wrapper.run(self.name, *args, **kwargs)
+        except Exception as e:
+            return f"Error: {format_exc()}"
\ No newline at end of file
diff --git a/src/alita_tools/google/__init__.py b/src/alita_tools/google/__init__.py
new file mode 100644
index 00000000..b8f4b963
--- /dev/null
+++ b/src/alita_tools/google/__init__.py
@@ -0,0 +1,8 @@
+from .bigquery import BigQueryToolkit
+from .bigquery import get_tools as get_bigquery_tools
+
+name = "google"
+
+def get_tools(tool):
+    if tool['type'] == 'bigquery':
+        return get_bigquery_tools(tool)
"fields": ["api_key"]}, + ], + } + }, + } + } + + api_key: Optional[SecretStr] = Field( + default=None, + description="GCP API key", + json_schema_extra={"secret": True, "configuration": True}, + ) + project: Optional[str] = Field( + default=None, + description="BigQuery project ID", + json_schema_extra={"configuration": True}, + ) + location: Optional[str] = Field( + default=None, + description="BigQuery location", + json_schema_extra={"configuration": True}, + ) + dataset: Optional[str] = Field( + default=None, + description="BigQuery dataset name", + json_schema_extra={"configuration": True}, + ) + table: Optional[str] = Field( + default=None, + description="BigQuery table name", + json_schema_extra={"configuration": True}, + ) + selected_tools: List[str] = Field( + default=[], + description="Selected tools", + json_schema_extra={"args_schemas": get_available_tools()}, + ) + + @field_validator("selected_tools", mode="before", check_fields=False) + @classmethod + def selected_tools_validator(cls, value: List[str]) -> list[str]: + return [i for i in value if i in get_available_tools()] + + +def _get_toolkit(tool) -> BaseToolkit: + return BigQueryToolkit().get_toolkit( + selected_tools=tool["settings"].get("selected_tools", []), + api_key=tool["settings"].get("api_key", ""), + toolkit_name=tool.get("toolkit_name"), + ) + + +def get_toolkit(): + return BigQueryToolkit.toolkit_config_schema() + + +def get_tools(tool): + return _get_toolkit(tool).get_tools() + + +class BigQueryToolkit(BaseToolkit): + tools: List[BaseTool] = [] + api_wrapper: Optional[BigQueryApiWrapper] = Field( + default_factory=BigQueryApiWrapper.model_construct + ) + toolkit_name: Optional[str] = None + + @computed_field + @property + def tool_prefix(self) -> str: + return ( + clean_string(self.toolkit_name, toolkit_max_length()) + TOOLKIT_SPLITTER + if self.toolkit_name + else "" + ) + + @computed_field + @property + def available_tools(self) -> List[dict]: + return self.api_wrapper.get_available_tools() + + @staticmethod + def toolkit_config_schema() -> Type[BaseModel]: + return BigQueryToolkitConfig + + @classmethod + def get_toolkit( + cls, + selected_tools: list[str] | None = None, + toolkit_name: Optional[str] = None, + **kwargs, + ) -> "BigQueryToolkit": + bigquery_api_wrapper = BigQueryApiWrapper(**kwargs) + instance = cls( + tools=[], api_wrapper=bigquery_api_wrapper, toolkit_name=toolkit_name + ) + if selected_tools: + selected_tools = set(selected_tools) + for t in instance.available_tools: + if t["name"] in selected_tools: + instance.tools.append( + BigQueryAction( + api_wrapper=instance.api_wrapper, + name=instance.tool_prefix + t["name"], + # set unique description for declared tools to differentiate the same methods for different toolkits + description=f"Project: {getattr(instance.api_wrapper, 'project', '')}\n" + + t["description"], + args_schema=t["args_schema"], + ) + ) + return instance + + def get_tools(self): + return self.tools diff --git a/src/alita_tools/google/bigquery/api_wrapper.py b/src/alita_tools/google/bigquery/api_wrapper.py new file mode 100644 index 00000000..793ff145 --- /dev/null +++ b/src/alita_tools/google/bigquery/api_wrapper.py @@ -0,0 +1,502 @@ +import functools +import json +import logging +from typing import Any, Dict, List, Optional, Union + +from google.cloud import bigquery +from langchain_core.tools import ToolException +from pydantic import ( + ConfigDict, + Field, + PrivateAttr, + SecretStr, + field_validator, + model_validator, +) +from pydantic_core.core_schema import 
diff --git a/src/alita_tools/google/bigquery/api_wrapper.py b/src/alita_tools/google/bigquery/api_wrapper.py
new file mode 100644
index 00000000..793ff145
--- /dev/null
+++ b/src/alita_tools/google/bigquery/api_wrapper.py
@@ -0,0 +1,502 @@
+import functools
+import json
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from google.cloud import bigquery
+from langchain_core.tools import ToolException
+from pydantic import (
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    SecretStr,
+    field_validator,
+    model_validator,
+)
+from pydantic_core.core_schema import ValidationInfo
+
+from ...elitea_base import BaseToolApiWrapper
+from .schemas import ArgsSchema
+
+
+def process_output(func):
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        try:
+            result = func(self, *args, **kwargs)
+            if isinstance(result, Exception):
+                return ToolException(str(result))
+            if isinstance(result, (dict, list)):
+                return json.dumps(result, default=str)
+            return str(result)
+        except Exception as e:
+            logging.error(f"Error in '{func.__name__}': {str(e)}")
+            return ToolException(str(e))
+
+    return wrapper
+
+
+class BigQueryApiWrapper(BaseToolApiWrapper):
+    model_config = ConfigDict(arbitrary_types_allowed=True, from_attributes=True)
+
+    api_key: Optional[SecretStr] = Field(
+        default=None, json_schema_extra={"env_key": "BIGQUERY_API_KEY"}
+    )
+    project: Optional[str] = Field(
+        default=None, json_schema_extra={"env_key": "BIGQUERY_PROJECT"}
+    )
+    location: Optional[str] = Field(
+        default=None, json_schema_extra={"env_key": "BIGQUERY_LOCATION"}
+    )
+    dataset: Optional[str] = Field(
+        default=None, json_schema_extra={"env_key": "BIGQUERY_DATASET"}
+    )
+    table: Optional[str] = Field(
+        default=None, json_schema_extra={"env_key": "BIGQUERY_TABLE"}
+    )
+    embedding: Optional[Any] = None
+    _client: Optional[bigquery.Client] = PrivateAttr(default=None)
+
+    @classmethod
+    def model_construct(cls, *args, **kwargs):
+        klass = super().model_construct(*args, **kwargs)
+        klass._client = None
+        return klass
+
+    @field_validator(
+        "api_key",
+        "project",
+        "location",
+        "dataset",
+        "table",
+        mode="before",
+        check_fields=False,
+    )
+    @classmethod
+    def set_from_values_or_env(cls, value, info: ValidationInfo):
+        if value is None:
+            if json_schema_extra := cls.model_fields[info.field_name].json_schema_extra:
+                if env_key := json_schema_extra.get("env_key"):
+                    try:
+                        from langchain_core.utils import get_from_env
+
+                        return get_from_env(
+                            key=info.field_name,
+                            env_key=env_key,
+                            default=cls.model_fields[info.field_name].default,
+                        )
+                    except Exception:
+                        return None
+        return value
+
+    @model_validator(mode="after")
+    def validate_auth(self) -> "BigQueryApiWrapper":
+        if not self.api_key:
+            raise ValueError("You must provide a BigQuery API key.")
+        return self
+
+    @property
+    def bigquery_client(self) -> bigquery.Client:
+        if not self._client:
+            api_key = self.api_key.get_secret_value() if self.api_key else None
+            if not api_key:
+                raise ToolException("BigQuery API key is not set.")
+            try:
+                api_key_dict = json.loads(api_key)
+                credentials = bigquery.Client.from_service_account_info(
+                    api_key_dict
+                )._credentials
+                self._client = bigquery.Client(
+                    credentials=credentials,
+                    project=self.project,
+                    location=self.location,
+                )
+            except Exception as e:
+                raise ToolException(f"Error initializing GCP credentials: {str(e)}")
+        return self._client
+
+    def _get_table_id(self):
+        if not (self.project and self.dataset and self.table):
+            raise ToolException("Project, dataset, and table must be specified.")
+        return f"{self.project}.{self.dataset}.{self.table}"
+
+    def _create_filters(
+        self, filter: Optional[Union[Dict[str, Any], str]] = None
+    ) -> str:
+        if filter:
+            if isinstance(filter, dict):
+                filter_expressions = []
+                for k, v in filter.items():
+                    if isinstance(v, (int, float)):
+                        filter_expressions.append(f"{k} = {v}")
+                    else:
+                        filter_expressions.append(f"{k} = '{v}'")
+                return " AND ".join(filter_expressions)
+            else:
+                return filter
+        return "TRUE"
+
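+    # Illustration of the mapping performed by _create_filters (example values only):
+    #   {"category": "news", "year": 2024}  ->  "category = 'news' AND year = 2024"
+    #   "year >= 2024"                      ->  "year >= 2024"   (strings pass through)
+    #   None                                ->  "TRUE"
+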
+    def job_stats(self, job_id: str) -> Dict:
+        return self.bigquery_client.get_job(job_id)._properties.get("statistics", {})
+
+    def create_vector_index(self):
+        table_id = self._get_table_id()
+        index_name = f"{self.table}_langchain_index"
+        sql = f"""
+        CREATE VECTOR INDEX IF NOT EXISTS
+        `{index_name}`
+        ON `{table_id}`
+        (embedding)
+        OPTIONS(distance_type="EUCLIDEAN", index_type="IVF")
+        """
+        try:
+            self.bigquery_client.query(sql).result()
+            return f"Vector index '{index_name}' created or already exists."
+        except Exception as ex:
+            logging.error(f"Vector index creation failed: {ex}")
+            return ToolException(f"Vector index creation failed: {ex}")
+
+    @process_output
+    def get_documents(
+        self,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Union[Dict[str, Any], str]] = None,
+    ):
+        table_id = self._get_table_id()
+        job_config = None
+        id_expr = "TRUE"
+        if ids:
+            job_config = bigquery.QueryJobConfig(
+                query_parameters=[bigquery.ArrayQueryParameter("ids", "STRING", ids)]
+            )
+            id_expr = "doc_id IN UNNEST(@ids)"
+        where_filter_expr = self._create_filters(filter)
+        query = f"SELECT * FROM `{table_id}` WHERE {id_expr} AND {where_filter_expr}"
+        job = self.bigquery_client.query(query, job_config=job_config)
+        return [dict(row) for row in job]
+
+    @process_output
+    def similarity_search(
+        self,
+        query: str,
+        k: int = 5,
+        filter: Optional[Union[Dict[str, Any], str]] = None,
+    ):
+        """Search for top `k` docs most similar to input query using vector similarity search."""
+        if not hasattr(self, "embedding") or self.embedding is None:
+            raise ToolException("Embedding model is not set on the wrapper.")
+        embedding_vector = self.embedding.embed_query(query)
+        # Prepare the vector search query
+        table_id = self._get_table_id()
+        where_filter_expr = "TRUE"
+        if filter:
+            if isinstance(filter, dict):
+                filter_expressions = [f"{k} = '{v}'" for k, v in filter.items()]
+                where_filter_expr = " AND ".join(filter_expressions)
+            else:
+                where_filter_expr = filter
+        # BigQuery vector search SQL (using VECTOR_SEARCH if available)
+        sql = f"""
+        SELECT *,
+            VECTOR_DISTANCE(embedding, @query_embedding) AS score
+        FROM `{table_id}`
+        WHERE {where_filter_expr}
+        ORDER BY score ASC
+        LIMIT {k}
+        """
+        job_config = bigquery.QueryJobConfig(
+            query_parameters=[
+                bigquery.ArrayQueryParameter(
+                    "query_embedding", "FLOAT64", embedding_vector
+                )
+            ]
+        )
+        job = self.bigquery_client.query(sql, job_config=job_config)
+        return [dict(row) for row in job]
+
+    @process_output
+    def batch_search(
+        self,
+        queries: Optional[List[str]] = None,
+        embeddings: Optional[List[List[float]]] = None,
+        k: int = 5,
+        filter: Optional[Union[Dict[str, Any], str]] = None,
+    ):
+        """Batch vector similarity search. Accepts either queries (to embed) or embeddings."""
+        if queries is not None and embeddings is not None:
+            raise ToolException("Provide only one of 'queries' or 'embeddings'.")
+        if queries is not None:
+            if not hasattr(self, "embedding") or self.embedding is None:
+                raise ToolException("Embedding model is not set on the wrapper.")
+            embeddings = [self.embedding.embed_query(q) for q in queries]
+        if not embeddings:
+            raise ToolException("No embeddings or queries provided.")
+        table_id = self._get_table_id()
+        where_filter_expr = "TRUE"
+        if filter:
+            if isinstance(filter, dict):
+                filter_expressions = [f"{k} = '{v}'" for k, v in filter.items()]
+                where_filter_expr = " AND ".join(filter_expressions)
+            else:
+                where_filter_expr = filter
+        results = []
+        for emb in embeddings:
+            sql = f"""
+            SELECT *,
+                VECTOR_DISTANCE(embedding, @query_embedding) AS score
+            FROM `{table_id}`
+            WHERE {where_filter_expr}
+            ORDER BY score ASC
+            LIMIT {k}
+            """
+            job_config = bigquery.QueryJobConfig(
+                query_parameters=[
+                    bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", emb)
+                ]
+            )
+            job = self.bigquery_client.query(sql, job_config=job_config)
+            results.append([dict(row) for row in job])
+        return results
+
+    def similarity_search_by_vector(
+        self, embedding: List[float], k: int = 5, **kwargs
+    ) -> List[Dict]:
+        """Return docs most similar to embedding vector."""
+        table_id = self._get_table_id()
+        sql = f"""
+        SELECT *, VECTOR_DISTANCE(embedding, @query_embedding) AS score
+        FROM `{table_id}`
+        ORDER BY score ASC
+        LIMIT {k}
+        """
+        job_config = bigquery.QueryJobConfig(
+            query_parameters=[
+                bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", embedding)
+            ]
+        )
+        job = self.bigquery_client.query(sql, job_config=job_config)
+        return [self._row_to_document(row) for row in job]
+
+    def similarity_search_by_vector_with_score(
+        self,
+        embedding: List[float],
+        filter: Optional[Union[Dict[str, Any], str]] = None,
+        k: int = 5,
+        **kwargs,
+    ) -> List[Dict]:
+        """Return docs most similar to embedding vector with scores."""
+        table_id = self._get_table_id()
+        where_filter_expr = self._create_filters(filter)
+        sql = f"""
+        SELECT *, VECTOR_DISTANCE(embedding, @query_embedding) AS score
+        FROM `{table_id}`
+        WHERE {where_filter_expr}
+        ORDER BY score ASC
+        LIMIT {k}
+        """
+        job_config = bigquery.QueryJobConfig(
+            query_parameters=[
+                bigquery.ArrayQueryParameter("query_embedding", "FLOAT64", embedding)
+            ]
+        )
+        job = self.bigquery_client.query(sql, job_config=job_config)
+        return [self._row_to_document(row) for row in job]
+
+    def similarity_search_with_score(
+        self,
+        query: str,
+        filter: Optional[Union[Dict[str, Any], str]] = None,
+        k: int = 5,
+        **kwargs,
+    ) -> List[Dict]:
+        """Search for top `k` docs most similar to input query, returns both docs and scores."""
+        embedding = self.embedding.embed_query(query)
+        return self.similarity_search_by_vector_with_score(
+            embedding, filter=filter, k=k, **kwargs
+        )
+
+    def similarity_search_by_vectors(
+        self,
+        embeddings: List[List[float]],
+        filter: Optional[Union[Dict[str, Any], str]] = None,
+        k: int = 5,
+        with_scores: bool = False,
+        with_embeddings: bool = False,
+        **kwargs,
+    ) -> Any:
+        """Core similarity search function. Handles a list of embedding vectors, optionally returning scores and embeddings."""
+        results = []
+        for emb in embeddings:
+            docs = self.similarity_search_by_vector_with_score(
+                emb, filter=filter, k=k, **kwargs
+            )
+            if not with_scores and not with_embeddings:
+                docs = [d for d in docs]
+            elif not with_embeddings:
+                docs = [{**d, "score": d.get("score")} for d in docs]
+            elif not with_scores:
+                docs = [{**d, "embedding": emb} for d in docs]
+            results.append(docs)
+        return results
+
+    def execute(self, method: str, *args, **kwargs):
+        """
+        Universal method to call any method from google.cloud.bigquery.Client.
+        Args:
+            method: Name of the method to call on the BigQuery client.
+            *args: Positional arguments for the method.
+            **kwargs: Keyword arguments for the method.
+        Returns:
+            The result of the called method.
+        Raises:
+            ToolException: If the client is not initialized or method does not exist.
+        """
+        if not self._client:
+            raise ToolException("BigQuery client is not initialized.")
+        if not hasattr(self._client, method):
+            raise ToolException(f"BigQuery client has no method '{method}'")
+        func = getattr(self._client, method)
+        try:
+            result = func(*args, **kwargs)
+            return result
+        except Exception as e:
+            logging.error(f"Error executing '{method}': {e}")
+            raise ToolException(f"Error executing '{method}': {e}")
+
+ """ + dataset = dataset or self.dataset + project = project or self.project + if not (project and dataset and table_name and connection_id and source_uris): + raise ToolException("project, dataset, table_name, connection_id, and source_uris are required.") + client = self.bigquery_client + table_ref = bigquery.TableReference( + bigquery.DatasetReference(project, dataset), table_name + ) + external_config = bigquery.ExternalConfig("DELTA_LAKE") + external_config.autodetect = autodetect + external_config.source_uris = source_uris + external_config.connection_id = connection_id + table = bigquery.Table(table_ref) + table.external_data_configuration = external_config + try: + created_table = client.create_table(table, exists_ok=True) + return created_table.to_api_repr() + except Exception as e: + raise ToolException(f"Failed to create Delta Lake table: {e}") + + def get_available_tools(self) -> List[Dict[str, Any]]: + return [ + { + "name": "get_documents", + "description": self.get_documents.__doc__, + "args_schema": ArgsSchema.GetDocuments.value, + "ref": self.get_documents, + }, + { + "name": "similarity_search", + "description": self.similarity_search.__doc__, + "args_schema": ArgsSchema.SimilaritySearch.value, + "ref": self.similarity_search, + }, + { + "name": "batch_search", + "description": self.batch_search.__doc__, + "args_schema": ArgsSchema.BatchSearch.value, + "ref": self.batch_search, + }, + { + "name": "create_vector_index", + "description": self.create_vector_index.__doc__, + "args_schema": ArgsSchema.NoInput.value, + "ref": self.create_vector_index, + }, + { + "name": "job_stats", + "description": self.job_stats.__doc__, + "args_schema": ArgsSchema.JobStatsArgs.value, + "ref": self.job_stats, + }, + { + "name": "similarity_search_by_vector", + "description": self.similarity_search_by_vector.__doc__, + "args_schema": ArgsSchema.SimilaritySearchByVectorArgs.value, + "ref": self.similarity_search_by_vector, + }, + { + "name": "similarity_search_by_vector_with_score", + "description": self.similarity_search_by_vector_with_score.__doc__, + "args_schema": ArgsSchema.SimilaritySearchByVectorWithScoreArgs.value, + "ref": self.similarity_search_by_vector_with_score, + }, + { + "name": "similarity_search_with_score", + "description": self.similarity_search_with_score.__doc__, + "args_schema": ArgsSchema.SimilaritySearchWithScoreArgs.value, + "ref": self.similarity_search_with_score, + }, + { + "name": "similarity_search_by_vectors", + "description": self.similarity_search_by_vectors.__doc__, + "args_schema": ArgsSchema.SimilaritySearchByVectorsArgs.value, + "ref": self.similarity_search_by_vectors, + }, + { + "name": "execute", + "description": self.execute.__doc__, + "args_schema": ArgsSchema.ExecuteArgs.value, + "ref": self.execute, + }, + { + "name": "create_delta_lake_table", + "description": self.create_delta_lake_table.__doc__, + "args_schema": ArgsSchema.CreateDeltaLakeTable.value, + "ref": self.create_delta_lake_table, + }, + ] + + def run(self, name: str, *args: Any, **kwargs: Any): + for tool in self.get_available_tools(): + if tool["name"] == name: + # Handle potential dictionary input for args when only one dict is passed + if len(args) == 1 and isinstance(args[0], dict) and not kwargs: + kwargs = args[0] + args = () # Clear args + try: + return tool["ref"](*args, **kwargs) + except TypeError as e: + # Attempt to call with kwargs only if args fail and kwargs exist + if kwargs and not args: + try: + return tool["ref"](**kwargs) + except TypeError: + raise ValueError( + 
f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError( + f"Argument mismatch for tool '{name}'. Error: {e}" + ) from e + else: + raise ValueError(f"Unknown tool name: {name}") diff --git a/src/alita_tools/google/bigquery/schemas.py b/src/alita_tools/google/bigquery/schemas.py new file mode 100644 index 00000000..068b6e0f --- /dev/null +++ b/src/alita_tools/google/bigquery/schemas.py @@ -0,0 +1,102 @@ +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +from pydantic import Field, create_model + + +class ArgsSchema(Enum): + NoInput = create_model("NoInput") + GetDocuments = create_model( + "GetDocuments", + ids=( + Optional[List[str]], + Field(default=None, description="List of document IDs to retrieve."), + ), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + ) + SimilaritySearch = create_model( + "SimilaritySearch", + query=(str, Field(description="Text query to search for similar documents.")), + k=(int, Field(default=5, description="Number of top results to return.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + ) + BatchSearch = create_model( + "BatchSearch", + queries=( + Optional[List[str]], + Field(default=None, description="List of text queries."), + ), + embeddings=( + Optional[List[List[float]]], + Field(default=None, description="List of embedding vectors."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + ) + JobStatsArgs = create_model( + "JobStatsArgs", job_id=(str, Field(description="BigQuery job ID.")) + ) + SimilaritySearchByVectorArgs = create_model( + "SimilaritySearchByVectorArgs", + embedding=(List[float], Field(description="Embedding vector.")), + k=(int, Field(default=5, description="Number of top results to return.")), + ) + SimilaritySearchByVectorWithScoreArgs = create_model( + "SimilaritySearchByVectorWithScoreArgs", + embedding=(List[float], Field(description="Embedding vector.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + ) + SimilaritySearchWithScoreArgs = create_model( + "SimilaritySearchWithScoreArgs", + query=(str, Field(description="Text query.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + ) + SimilaritySearchByVectorsArgs = create_model( + "SimilaritySearchByVectorsArgs", + embeddings=(List[List[float]], Field(description="List of embedding vectors.")), + filter=( + Optional[Union[Dict[str, Any], str]], + Field(default=None, description="Filter as dict or SQL WHERE clause."), + ), + k=(int, Field(default=5, description="Number of top results to return.")), + with_scores=(bool, Field(default=False)), + with_embeddings=(bool, Field(default=False)), + ) + ExecuteArgs = create_model( + "ExecuteArgs", + method=(str, Field(description="Name of the BigQuery client method to call.")), + args=( + Optional[List[Any]], + Field(default=None, description="Positional arguments for the method."), + ), + kwargs=( + Optional[Dict[str, Any]], + 
Field(default=None, description="Keyword arguments for the method."), + ), + ) + CreateDeltaLakeTable = create_model( + "CreateDeltaLakeTable", + table_name=(str, Field(description="Name of the Delta Lake table to create in BigQuery.")), + dataset=(Optional[str], Field(default=None, description="BigQuery dataset to contain the table (defaults to self.dataset).")), + connection_id=(str, Field(description="Fully qualified connection ID (project.region.connection_id).")), + source_uris=(list, Field(description="List of GCS URIs (prefixes) for the Delta Lake table.")), + autodetect=(bool, Field(default=True, description="Whether to autodetect schema (default: True).")), + project=(Optional[str], Field(default=None, description="GCP project ID (defaults to self.project).")), + ) diff --git a/src/alita_tools/google/bigquery/tool.py b/src/alita_tools/google/bigquery/tool.py new file mode 100644 index 00000000..bc05543f --- /dev/null +++ b/src/alita_tools/google/bigquery/tool.py @@ -0,0 +1,34 @@ +from typing import Optional, Type + +from langchain_core.callbacks import CallbackManagerForToolRun +from pydantic import BaseModel, field_validator, Field +from langchain_core.tools import BaseTool +from traceback import format_exc +from .api_wrapper import BigQueryApiWrapper + + +class BigQueryAction(BaseTool): + """Tool for interacting with the BigQuery API.""" + + api_wrapper: BigQueryApiWrapper = Field(default_factory=BigQueryApiWrapper) + name: str + mode: str = "" + description: str = "" + args_schema: Optional[Type[BaseModel]] = None + + @field_validator('name', mode='before') + @classmethod + def remove_spaces(cls, v): + return v.replace(' ', '') + + def _run( + self, + *args, + run_manager: Optional[CallbackManagerForToolRun] = None, + **kwargs, + ) -> str: + """Use the GitHub API to run an operation.""" + try: + return self.api_wrapper.run(self.mode, *args, **kwargs) + except Exception as e: + return f"Error: {format_exc()}" \ No newline at end of file