From 7ab8d215b01d7063814b10cb95c303b066e83386 Mon Sep 17 00:00:00 2001 From: Anshul Yadav Date: Tue, 16 Apr 2024 11:22:34 +0530 Subject: [PATCH] changes to gemini configuration and added embedding method for gemini --- lucknowllm/models/gemini_model.py | 66 ++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/lucknowllm/models/gemini_model.py b/lucknowllm/models/gemini_model.py index 95c8f61..8f7ed05 100644 --- a/lucknowllm/models/gemini_model.py +++ b/lucknowllm/models/gemini_model.py @@ -1,42 +1,70 @@ import google.generativeai as genai +import os +from typing import Optional + class GeminiModel: - def __init__(self, api_key, model_name): + def __init__( + self, + api_key=os.environ["GOOGLE_API_KEY"], + model_name="gemini-pro", + temperature: Optional[int] = 0, + top_p: Optional[int] = 1, + top_k: Optional[int] = 1, + max_output_tokens: Optional[int] = 30720, + safety_settings: Optional[dict] = None, + ): # Configure the API with the provided key genai.configure(api_key=api_key) - + # Default configuration settings; can be customized further if needed generation_config = { - "temperature": 0, - "top_p": 1, - "top_k": 1, - "max_output_tokens": 30720, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "max_output_tokens": max_output_tokens, } - - safety_settings = [ - {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}, - {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"}, - {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"}, - {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}, - ] - + safety_settings = safety_settings + # safety_settings = [ + # {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}, + # {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"}, + # { + # "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + # "threshold": "BLOCK_ONLY_HIGH", + # }, + # { + # "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + # "threshold": "BLOCK_ONLY_HIGH", + # }, + # ] + # Set up the model with the provided model name - self.model = genai.GenerativeModel(model_name=model_name, - generation_config=generation_config, - safety_settings=safety_settings) + self.model = genai.GenerativeModel( + model_name=model_name, + generation_config=generation_config, + safety_settings=safety_settings, + ) def generate_content(self, prompts): # Generate content based on the provided prompts response = self.model.generate_content([prompts]) return response.text + def embedding(self, text): + embeddings = genai.embed_content( + model="models/embedding-001", + content=text, + task_type="retrieval_document", + ) + return embeddings + # # Example usage: # if __name__ == "__main__": # api_key = "YOUR_API_KEY" # model_name = "gemini-1.0-pro" - + # gen_ai_model = GeminiModel(api_key, model_name) # prompts = ["hey Hi"] # response_text = gen_ai_model.generate_content(prompts) -# print(response_text) \ No newline at end of file +# print(response_text)