Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 80 additions & 73 deletions server/services/google_ai_client.py
Original file line number Diff line number Diff line change
@@ -1,91 +1,98 @@
import time
import random
from typing import Optional

import google.generativeai as genai
from utilities.constants.LLM_enums import LLMType, ModelType
from utilities.constants.response_messages import ERROR_API_FAILURE
from utilities.utility_functions import format_chat
from utilities.config import ALL_GOOGLE_KEYS
from services.base_client import Client
from google.generativeai.types import HarmCategory, HarmBlockThreshold
import random

from services.base_client import Client
from utilities.config import ALL_GOOGLE_KEYS
from utilities.logging_utils import setup_logger
from utilities.utility_functions import format_chat
from utilities.constants.LLM_enums import LLMType, ModelType
from utilities.constants.response_messages import (
ERROR_API_FAILURE,
WARNING_ALL_API_KEYS_QUOTA_EXCEEDED,
)

logger = setup_logger(__name__)

# Constants
# Substring searched for in the text of an exception raised by an LLM call;
# when present, the failure is handled as a quota-exceeded error.
QUOTA_EXCEEDED_ERROR_CODE = "429"
# Safety settings handed to generate_content as `safety_settings`: each of the
# four listed harm categories is given the BLOCK_NONE threshold.
GENERATION_SAFETY_SETTINGS = [
{"category": HarmCategory.HARM_CATEGORY_HATE_SPEECH, "threshold": HarmBlockThreshold.BLOCK_NONE},
{"category": HarmCategory.HARM_CATEGORY_HARASSMENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
{"category": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, "threshold": HarmBlockThreshold.BLOCK_NONE},
{"category": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, "threshold": HarmBlockThreshold.BLOCK_NONE},
]

class GoogleAIClient(Client):
def __init__(self, model: ModelType, max_tokens: Optional[int] = 150, temperature: Optional[float] = 0.5):
self.current_key_idx = random.randint(0,len(ALL_GOOGLE_KEYS)-1)
self.client = genai.configure(api_key=ALL_GOOGLE_KEYS[self.current_key_idx])
self.call_num=0
self.call_limit=5
self.__current_key_index = random.randint(0, len(ALL_GOOGLE_KEYS) - 1)
self.client = genai.configure(api_key=ALL_GOOGLE_KEYS[self.__current_key_index])
super().__init__(model=model.value, temperature=temperature, max_tokens=max_tokens, client=self.client)

def change_client(self):
self.current_key_idx = (self.current_key_idx + 1) % len(ALL_GOOGLE_KEYS)
self.client = genai.configure(api_key=ALL_GOOGLE_KEYS[self.current_key_idx])
self.call_num = 0

def __rotate_api_key(self):
    """Switch to the next key in ALL_GOOGLE_KEYS, wrapping around at the end, and reconfigure genai with it."""
    next_index = (self.__current_key_index + 1) % len(ALL_GOOGLE_KEYS)
    self.__current_key_index = next_index
    # The value returned by genai.configure is stored as the client, as before.
    self.client = genai.configure(api_key=ALL_GOOGLE_KEYS[next_index])

def __retry_on_quota_exceeded(self, llm_call):
    """
    Run ``llm_call`` and return its result, rotating API keys on quota errors.

    An exception whose text contains QUOTA_EXCEEDED_ERROR_CODE is treated as a
    quota error: the client rotates to the next API key and the call is
    retried. After as many consecutive quota errors as there are keys, a
    warning is logged and the loop sleeps 5 seconds before starting another
    round. There is no upper bound on the number of rounds.

    Args:
        llm_call: Zero-argument callable performing the LLM request.

    Returns:
        Whatever ``llm_call`` returns on its first successful invocation.

    Raises:
        RuntimeError: If ``llm_call`` raises an exception that is not
            recognized as a quota error; the original exception is chained.
    """
    consecutive_quota_errors = 0

    while True:
        try:
            # Return as soon as the call succeeds. Looping on "result is None"
            # would re-issue the request forever if the callable returned None.
            return llm_call()
        except Exception as e:
            if QUOTA_EXCEEDED_ERROR_CODE not in str(e):
                # Raise errors other than quota exceeded, keeping the cause.
                raise RuntimeError(ERROR_API_FAILURE.format(
                    llm_type=LLMType.GOOGLE_AI.value,
                    error=str(e)
                )) from e

            consecutive_quota_errors += 1
            self.__rotate_api_key()

            # Every key has been tried in a row without success: back off
            # before starting the next round.
            if consecutive_quota_errors >= len(ALL_GOOGLE_KEYS):
                logger.warning(WARNING_ALL_API_KEYS_QUOTA_EXCEEDED.format(llm_type=LLMType.GOOGLE_AI.value))
                time.sleep(5)
                consecutive_quota_errors = 0

def execute_prompt(self, prompt: str) -> str:
if self.call_num >= self.call_limit:
self.change_client()
try:
def generate_text_response():
model = genai.GenerativeModel(self.model)
response = model.generate_content(
contents=prompt,
generation_config={
'temperature': self.temperature,
'max_output_tokens': self.max_tokens,
"temperature": self.temperature,
"max_output_tokens": self.max_tokens,
},
safety_settings={
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE
}
safety_settings=GENERATION_SAFETY_SETTINGS,
)
self.call_num+=1
return response.text

except Exception as e1:
current_key = self.current_key_idx
self.change_client()
while current_key != self.current_key_idx:
try:
model = genai.GenerativeModel(self.model)
response = model.generate_content(
contents=prompt,
generation_config={
'temperature': self.temperature,
'max_output_tokens': self.max_tokens
}
)
self.call_num+=1
return response.text
except Exception as e:
self.change_client()

raise RuntimeError(ERROR_API_FAILURE.format(llm_type=LLMType.GOOGLE_AI.value, error=str(e1)))

def execute_chat(self, chat):

chat=format_chat(chat, {'system': 'system', 'user':'user', 'model':'model', 'content':'parts'})
changes=0
if self.call_num >= self.call_limit:
self.change_client()
while changes<len(ALL_GOOGLE_KEYS):
try:
system_msg = chat[0]['parts']
user_msg = next((msg['parts'] for msg in reversed(chat) if msg['role'] == 'user'), None)
history = chat[1:-1] if len(chat) > 2 else []

model = genai.GenerativeModel(
model_name=self.model,
system_instruction= system_msg
)
chat_model = model.start_chat(history=history)
response = chat_model.send_message(user_msg)
self.call_num+=1
return response.text
except Exception as e:
changes+=1
error = str(e)
self.change_client()
return self.__retry_on_quota_exceeded(generate_text_response)

def execute_chat(self, chat) -> str:
chat = format_chat(chat, {"system": "system", "user": "user", "model": "model", "content": "parts"})

def generate_chat_response():
system_msg = chat[0]["parts"]
user_msg = next((msg["parts"] for msg in reversed(chat) if msg["role"] == "user"), None)
history = chat[1:-1] if len(chat) > 2 else []

model = genai.GenerativeModel(
model_name=self.model,
system_instruction=system_msg,
)
chat_model = model.start_chat(history=history)
response = chat_model.send_message(user_msg)
return response.text

raise RuntimeError(ERROR_API_FAILURE.format(llm_type=LLMType.GOOGLE_AI.value, error=str(error)))

return self.__retry_on_quota_exceeded(generate_chat_response)
1 change: 1 addition & 0 deletions server/utilities/constants/prompts_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class PromptType(Enum):
DAIL_SQL = "dail_sql"
SEMANTIC_FULL_INFORMATION = "semantic_full_information"
ICL_XIYAN = "icl_xiyan"
TASL_DUMMY_SQL = "tasl_dummy_sql"

class RefinerPromptType(Enum):
BASIC = "basic"
Expand Down
1 change: 1 addition & 0 deletions server/utilities/constants/response_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ERROR_FAILED_FETCH_COLUMN_NAMES = "Failed to fetch column names: {error}"
ERROR_FAILED_FETCH_SCHEMA = "Failed to fetch schema: {error}"
ERROR_FAILED_FETCH_FOREIGN_KEYS = "Failed to fetch foreign keys for table {table_name}: {error}"
ERROR_FAILED_FETCHING_PRIMARY_KEYS = "Failed to fetch primary keys: {error}"

# Cost Estimation related Errors
ERROR_INVALID_MODEL_FOR_TOKEN_ESTIMATION = "Model {model} is not a valid OpenAI model. Only OpenAI models are supported."
Expand Down
2 changes: 2 additions & 0 deletions server/utilities/prompts/prompt_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ def get_prompt_class(prompt_type: PromptType, target_question: str, examples = N
return SemanticAndFullInformationOrganizationPrompt(shots=shots, target_question=target_question, schema_format=schema_format, schema=schema, evidence=evidence, database_name=database_name).get_prompt()
elif prompt_type == PromptType.ICL_XIYAN:
return ICLXiyanPrompt(shots=shots, target_question=target_question, schema=schema, evidence=evidence, database_name=database_name).get_prompt()
elif prompt_type == PromptType.TASL_DUMMY_SQL:
return TASLDummySQLPrompt(target_question=target_question, database_name=database_name, evidence=evidence).get_prompt()
else:
raise ValueError(ERROR_PROMPT_TYPE_NOT_FOUND.format(prompt_type=prompt_type))
63 changes: 61 additions & 2 deletions server/utilities/prompts/zero_shot_prompts.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from utilities.utility_functions import format_schema
from utilities.utility_functions import format_schema, get_schema_dict, get_table_foreign_keys, get_primary_keys
from utilities.prompts.base_prompt import BasePrompt
from utilities.constants.prompts_enums import FormatType
from utilities.config import PATH_CONFIG

import sqlite3
import json
from collections import defaultdict
class BasicPrompt(BasePrompt):

def get_prompt(self) -> str:
Expand Down Expand Up @@ -60,3 +62,60 @@ def get_prompt(self) -> str:
### Response:
SELECT"""
return prompt

class TASLDummySQLPrompt(BasePrompt):
    """Prompt that lays out schema descriptions, keys, question and evidence as a Python-style stub ending in an open SELECT string."""

    def get_prompt(self) -> str:
        """
        Build the prompt text for ``self.database_name``.

        Column descriptions are read from the JSON file at
        ``PATH_CONFIG.column_meaning_path()``, whose keys have the form
        ``db_id|table|column``. Tables, columns, foreign keys and primary keys
        are read from the SQLite database at ``PATH_CONFIG.sqlite_path(...)``.
        Columns without a description get an empty string.

        Returns:
            The formatted prompt string.
        """
        # Close the descriptions file deterministically instead of leaving the
        # handle to the garbage collector.
        with open(PATH_CONFIG.column_meaning_path()) as meanings_file:
            column_meanings = json.load(meanings_file)

        # db_id -> table -> column -> cleaned description
        schema_item_dict = {}
        for key, value in column_meanings.items():
            db_id, table, column = key.split('|')
            value = value.replace('#', '')
            value = value.replace('\n', ', ')
            schema_item_dict.setdefault(db_id, {}).setdefault(table, {})[column] = value

        table_descriptions = schema_item_dict.get(self.database_name, {})
        database_path = PATH_CONFIG.sqlite_path(database_name=self.database_name)

        foreign_keys = {}
        schema_with_descriptions = {}

        connection = sqlite3.connect(database_path)
        try:
            schema_dict = get_schema_dict(database_path)

            for table in schema_dict:
                for relation in get_table_foreign_keys(connection=connection, table_name=table):
                    key = table.lower()+"."+str(relation['from_column'])
                    value = str(relation['to_table'])+"."+str(relation['to_column'])
                    foreign_keys[key] = value

                column_descriptions = table_descriptions.get(table, {})
                schema_with_descriptions[table] = {
                    column: column_descriptions.get(column, "")
                    for column in schema_dict[table]
                }

            primary_keys = get_primary_keys(connection=connection)
        finally:
            # The connection is only needed while reading key metadata.
            connection.close()

        # NOTE(review): the prompt text below is reproduced exactly as it appears
        # in this view; confirm the leading whitespace of the lines under
        # question_to_SQL against the repository file.
        prompt = f"""# the key is the table, the value is a dict which key is original column name and value is the column information including full name, column description, value_description and example values.
database_schema = {json.dumps(schema_with_descriptions, indent=4)}

# the key is the table, the value is the list of its counterpart primary keys
primary_keys = {json.dumps(primary_keys, indent=4)}

# the key is the source column, the value is the target column referenced by foreign key relationship.
foreign_keys = {json.dumps(foreign_keys, indent=4)}

question = "{self.target_question}"

evidence = "{self.evidence}"

def question_to_SQL(question):
# DO NOT select more things other than what the question asks
# Generate the SQL to answer the question considering database_schema, primary_keys and foreign_keys
# Also consider the evidence when generating the SQL
SQL = "SELECT """

        return prompt
1 change: 1 addition & 0 deletions server/utilities/schema_linking/schema_linking_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from utilities.prompts.prompt_templates import SCHEMA_SELECTOR_PROMPT_TEMPLATE
from utilities.utility_functions import format_schema, get_table_foreign_keys
from utilities.constants.script_constants import GOOGLE_RESOURCE_EXHAUSTED_EXCEPTION_STR
from services.base_client import Client

logger = setup_logger(__name__)

Expand Down
81 changes: 79 additions & 2 deletions server/utilities/utility_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import concurrent.futures
import pandas as pd
import yaml

from typing import Dict, List
from utilities.constants.response_messages import (
ERROR_DATABASE_QUERY_FAILURE,
ERROR_SQL_QUERY_REQUIRED,
Expand All @@ -17,6 +17,9 @@
ERROR_UNSUPPORTED_FORMAT_TYPE,
ERROR_FAILED_FETCH_COLUMN_NAMES,
ERROR_FAILED_FETCH_TABLE_NAMES,
ERROR_FAILED_FETCH_FOREIGN_KEYS,
ERROR_FAILED_FETCHING_PRIMARY_KEYS,
ERROR_FAILED_FETCH_SCHEMA,
)

from utilities.constants.LLM_enums import LLMType, ModelType, VALID_LLM_MODELS
Expand Down Expand Up @@ -120,6 +123,52 @@ def get_table_columns(connection: sqlite3.Connection, table_name: str):
raise RuntimeError((ERROR_FAILED_FETCH_COLUMN_NAMES.format(error=str(e))))


def get_table_foreign_keys(connection: sqlite3.Connection, table_name: str):
    """
    Retrieves foreign key information for a given table.

    Args:
        connection: Open SQLite connection to query.
        table_name: Name of the table whose foreign keys are listed.

    Returns:
        A list with one dict per row reported by ``PRAGMA foreign_key_list``,
        each holding ``from_column``, ``to_column`` and ``to_table``.

    Raises:
        RuntimeError: If the query fails; the original exception is chained.
    """
    try:
        # The name is embedded as a double-quoted identifier, so any embedded
        # double quote must be doubled to keep the statement well-formed.
        quoted_name = str(table_name).replace('"', '""')
        cursor = connection.execute(f'PRAGMA foreign_key_list("{quoted_name}");')

        return [
            {
                "from_column": row[3],  # column in current table
                "to_column": row[4],  # referenced column in foreign table
                "to_table": row[2],  # referenced table
            }
            for row in cursor.fetchall()
        ]

    except Exception as e:
        raise RuntimeError(
            ERROR_FAILED_FETCH_FOREIGN_KEYS.format(table_name=table_name, error=str(e))
        ) from e

def get_primary_keys(connection: sqlite3.Connection) -> Dict[str, List[str]]:
    """
    Returns a dictionary mapping each table name to its list of primary key columns.

    Tables are taken from ``get_table_names``; ``sqlite_sequence`` is skipped.
    For a composite primary key the columns are listed in key order rather
    than in column-declaration order. Tables without a primary key map to an
    empty list.

    Args:
        connection: Open SQLite connection to inspect.

    Raises:
        RuntimeError: If any lookup fails; the original exception is chained.
    """
    try:
        pk_dict = {}
        tables = [name for name in get_table_names(connection) if name != 'sqlite_sequence']

        for table_name in tables:
            # Double any embedded quote so the quoted identifier stays valid.
            quoted_name = str(table_name).replace('"', '""')
            columns = connection.execute(f'PRAGMA table_info("{quoted_name}");').fetchall()

            # PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk.
            # pk is 0 for non-key columns and otherwise the 1-based position of
            # the column within the primary key, so sort on it.
            key_columns = sorted(
                (col for col in columns if col[5] > 0),
                key=lambda col: col[5],
            )
            pk_dict[table_name] = [col[1] for col in key_columns]

        return pk_dict
    except Exception as e:
        raise RuntimeError(ERROR_FAILED_FETCHING_PRIMARY_KEYS.format(error=str(e))) from e


def get_array_of_table_and_column_name(database_path: str):
try:
connection = sqlite3.connect(database_path)
Expand All @@ -135,6 +184,34 @@ def get_array_of_table_and_column_name(database_path: str):
connection.close()


def get_schema_dict(database_path: str) -> Dict[str, List[str]]:
    """
    Retrieves schema dictionary from the SQLite database in the format {table_name: [column1, column2, ...]}.

    The ``sqlite_sequence`` table is left out of the result.

    Args:
        database_path: Path of the SQLite database file to open.

    Raises:
        RuntimeError: If the schema cannot be read; the original exception is chained.
    """
    from contextlib import closing

    try:
        # A sqlite3 connection used directly as a context manager only ends
        # the transaction; closing() makes sure the connection is released.
        with closing(sqlite3.connect(database_path)) as connection:
            connection.row_factory = sqlite3.Row

            schema = {
                table_name: get_table_columns(connection, table_name)
                for table_name in get_table_names(connection)
            }
            schema.pop("sqlite_sequence", None)

            return schema

    except Exception as e:
        raise RuntimeError(ERROR_FAILED_FETCH_SCHEMA.format(error=str(e))) from e


def prune_code(ddl, columns, connection, table):
"""
Filters the given DDL statement to retain only the specified columns and
Expand Down Expand Up @@ -513,4 +590,4 @@ def check_config_types(
f"Key '{full_key}' should be of type {expected_type.__name__}"
)

return errors
return errors