-
Notifications
You must be signed in to change notification settings - Fork 0
Added TA-SQL schema linking module #107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
fi5421
wants to merge
8
commits into
schema-linking
Choose a base branch
from
ta-sql_integration
base: schema-linking
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
a7b0f22
feat: added TA-SQL schema linking module
fi5421 d8fa736
feat:
fi5421 7aad183
chore: cleaned code
fi5421 c2ed3cd
chore: cleaned code
fi5421 d92bb4d
chore: added updated google ai client
fi5421 fd66577
Merge branch 'schema-linking' into ta-sql_integration
fi5421 71629b6
chore:
fi5421 ca93cdc
chore: moved ta sql dummy prompt to zero_shots_prompts from prompt te…
fi5421 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,91 +1,98 @@ | ||
| import time | ||
| import random | ||
| from typing import Optional | ||
|
|
||
| import google.generativeai as genai | ||
| from utilities.constants.LLM_enums import LLMType, ModelType | ||
| from utilities.constants.response_messages import ERROR_API_FAILURE | ||
| from utilities.utility_functions import format_chat | ||
| from utilities.config import ALL_GOOGLE_KEYS | ||
| from services.base_client import Client | ||
| from google.generativeai.types import HarmCategory, HarmBlockThreshold | ||
| import random | ||
|
|
||
| from services.base_client import Client | ||
| from utilities.config import ALL_GOOGLE_KEYS | ||
| from utilities.logging_utils import setup_logger | ||
| from utilities.utility_functions import format_chat | ||
| from utilities.constants.LLM_enums import LLMType, ModelType | ||
| from utilities.constants.response_messages import ( | ||
| ERROR_API_FAILURE, | ||
| WARNING_ALL_API_KEYS_QUOTA_EXCEEDED, | ||
| ) | ||
|
|
||
# Module-level logger used for quota-rotation warnings.
logger = setup_logger(__name__)

# Constants
# Substring that identifies a quota-exceeded (HTTP 429) failure in SDK error text.
QUOTA_EXCEEDED_ERROR_CODE = "429"
# Relax every harm-category filter so generation requests are never blocked
# by the SDK's default safety thresholds.
GENERATION_SAFETY_SETTINGS = [
    {"category": category, "threshold": HarmBlockThreshold.BLOCK_NONE}
    for category in (
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    )
]
|
|
||
class GoogleAIClient(Client):
    """Client for Google Generative AI models with automatic API-key rotation.

    Holds an index into ``ALL_GOOGLE_KEYS`` and rotates to the next key
    whenever a quota-exceeded (429) error is raised; once every key has
    failed consecutively, it sleeps briefly and cycles again.
    """

    def __init__(self, model: ModelType, max_tokens: Optional[int] = 150, temperature: Optional[float] = 0.5):
        """Configure the SDK with a randomly chosen API key and init the base Client.

        Starting at a random key spreads load when several workers start at once.
        """
        self.__current_key_index = random.randint(0, len(ALL_GOOGLE_KEYS) - 1)
        # NOTE(review): genai.configure() returns None, so self.client is always
        # None here — confirm the base Client tolerates a None client handle.
        self.client = genai.configure(api_key=ALL_GOOGLE_KEYS[self.__current_key_index])
        super().__init__(model=model.value, temperature=temperature, max_tokens=max_tokens, client=self.client)

    def __rotate_api_key(self):
        """Advance to the next API key (round-robin) and reconfigure the SDK."""
        self.__current_key_index = (self.__current_key_index + 1) % len(ALL_GOOGLE_KEYS)
        self.client = genai.configure(api_key=ALL_GOOGLE_KEYS[self.__current_key_index])

    def __retry_on_quota_exceeded(self, llm_call):
        """Invoke ``llm_call`` until it succeeds or fails with a non-quota error.

        On a quota (429) error the API key is rotated; after every key has
        failed consecutively, log a warning and wait 5 seconds before cycling
        through the keys again. Any other error is wrapped in ``RuntimeError``
        (chained to the original exception) and raised to the caller.
        """
        # Retry the LLM call until it succeeds or raises a non-quota-exceeded error
        response = None
        consecutive_quota_errors = 0

        while response is None:
            try:
                response = llm_call()
                consecutive_quota_errors = 0  # Reset the error count on success
            except Exception as e:
                if QUOTA_EXCEEDED_ERROR_CODE in str(e):
                    consecutive_quota_errors += 1
                    self.__rotate_api_key()

                    # If we've tried all keys and still getting quota errors wait before retrying
                    if consecutive_quota_errors >= len(ALL_GOOGLE_KEYS):
                        logger.warning(WARNING_ALL_API_KEYS_QUOTA_EXCEEDED.format(llm_type=LLMType.GOOGLE_AI.value))
                        time.sleep(5)
                        consecutive_quota_errors = 0
                else:
                    # Raise errors other than quota exceeded, chaining the
                    # original exception so the real traceback is preserved.
                    raise RuntimeError(ERROR_API_FAILURE.format(
                        llm_type=LLMType.GOOGLE_AI.value,
                        error=str(e)
                    )) from e

        return response

    def execute_prompt(self, prompt: str) -> str:
        """Send a single text prompt and return the model's text response."""

        def generate_text_response():
            # A fresh GenerativeModel per attempt picks up the key set by
            # the most recent genai.configure() call.
            model = genai.GenerativeModel(self.model)
            response = model.generate_content(
                contents=prompt,
                generation_config={
                    "temperature": self.temperature,
                    "max_output_tokens": self.max_tokens,
                },
                safety_settings=GENERATION_SAFETY_SETTINGS,
            )
            return response.text

        return self.__retry_on_quota_exceeded(generate_text_response)

    def execute_chat(self, chat) -> str:
        """Send a multi-turn chat and return the model's text response.

        The first message is used as the system instruction, the last user
        message is sent as the new turn, and everything in between becomes
        the chat history.
        """
        chat = format_chat(chat, {"system": "system", "user": "user", "model": "model", "content": "parts"})

        def generate_chat_response():
            system_msg = chat[0]["parts"]
            user_msg = next((msg["parts"] for msg in reversed(chat) if msg["role"] == "user"), None)
            history = chat[1:-1] if len(chat) > 2 else []

            model = genai.GenerativeModel(
                model_name=self.model,
                system_instruction=system_msg,
            )
            chat_model = model.start_chat(history=history)
            response = chat_model.send_message(user_msg)
            return response.text

        return self.__retry_on_quota_exceeded(generate_chat_response)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.