From 0b580e859bff94be6c846ca6a5989ef9d46ddf2d Mon Sep 17 00:00:00 2001 From: Ansh Mishra Date: Wed, 12 Nov 2025 01:40:58 +0530 Subject: [PATCH] Refactor GeminiModel to support custom configurations This refactor updates the GeminiModel class to allow passing in custom generation_config and safety_settings during instantiation. Previously, these configurations were hardcoded. Now, the constructor accepts optional config arguments. If none are provided, the class falls back to its original default settings. This makes the class more flexible and reusable for different use cases (e.g., creative vs. factual generation) without modifying the class definition. --- lucknowllm/models/gemini_model.py | 125 +++++++++++++++++++++++------- 1 file changed, 97 insertions(+), 28 deletions(-) diff --git a/lucknowllm/models/gemini_model.py b/lucknowllm/models/gemini_model.py index 95c8f61..9f925e0 100644 --- a/lucknowllm/models/gemini_model.py +++ b/lucknowllm/models/gemini_model.py @@ -1,42 +1,111 @@ import google.generativeai as genai +from typing import Union, List, Dict, Optional class GeminiModel: - def __init__(self, api_key, model_name): - # Configure the API with the provided key + def __init__( + self, + api_key: str, + model_name: str = "gemini-1.0-pro", + generation_config: Optional[Dict] = None, + safety_settings: Optional[List] = None, + ): + """ + Initializes the Gemini model with optional generation and safety configurations. + + Args: + api_key (str): Google Generative AI API key. + model_name (str): Name of the Gemini model (e.g., "gemini-1.0-pro"). + generation_config (dict, optional): Custom generation configuration parameters. + safety_settings (list, optional): Custom safety configuration list. + """ + # Configure API genai.configure(api_key=api_key) - - # Default configuration settings; can be customized further if needed - generation_config = { + + # Default generation config (fallback) + default_generation_config = { "temperature": 0, "top_p": 1, "top_k": 1, "max_output_tokens": 30720, } - - safety_settings = [ + + # Default safety settings (fallback) + default_safety_settings = [ {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"}, {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"}, {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"}, {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"}, ] - - # Set up the model with the provided model name - self.model = genai.GenerativeModel(model_name=model_name, - generation_config=generation_config, - safety_settings=safety_settings) - - def generate_content(self, prompts): - # Generate content based on the provided prompts - response = self.model.generate_content([prompts]) - return response.text - - -# # Example usage: -# if __name__ == "__main__": -# api_key = "YOUR_API_KEY" -# model_name = "gemini-1.0-pro" - -# gen_ai_model = GeminiModel(api_key, model_name) -# prompts = ["hey Hi"] -# response_text = gen_ai_model.generate_content(prompts) -# print(response_text) \ No newline at end of file + + # Use user-provided or default settings + self.model = genai.GenerativeModel( + model_name=model_name, + generation_config=generation_config or default_generation_config, + safety_settings=safety_settings or default_safety_settings, + ) + + def generate_content(self, prompts: Union[str, List[str]]) -> str: + """ + Generate text or content based on the provided prompts. + + Args: + prompts (str | list[str]): The input text prompt(s). + + Returns: + str: The generated response text. + """ + # Ensure prompts are in a list format for the API + if isinstance(prompts, str): + prompts = [prompts] + + try: + response = self.model.generate_content(prompts) + return response.text + except Exception as e: + print(f"An error occurred during content generation: {e}") + # Depending on desired error handling, you might want to return None, + # an empty string, or re-raise the exception. + return "" + + +# Example usage: +if __name__ == "__main__": + # --- IMPORTANT --- + # Replace "YOUR_API_KEY" with your actual Google Generative AI API key + # It's best practice to load this from an environment variable or a + # secure config file instead of hardcoding it. + API_KEY = "YOUR_API_KEY" + + if API_KEY == "YOUR_API_KEY": + print("Please replace 'YOUR_API_KEY' with your actual API key to run the example.") + else: + # --- Example 1: Using default settings --- + print("--- Running with default settings ---") + try: + default_gemini = GeminiModel(api_key=API_KEY, model_name="gemini-1.5-flash") + response_default = default_gemini.generate_content("Hello! Write a one-sentence greeting.") + print(f"Default Response: {response_default}\n") + except Exception as e: + print(f"Error with default settings: {e}\n") + + # --- Example 2: Custom configuration --- + print("--- Running with custom settings ---") + try: + # Example: Custom configuration for more creative output + custom_gen_config = {"temperature": 0.8, "max_output_tokens": 512} + + # Example: Custom safety settings (use with caution) + # This example blocks nothing. + custom_safety = [{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}] + + custom_gemini = GeminiModel( + api_key=API_KEY, + model_name="gemini-1.5-flash", + generation_config=custom_gen_config, + safety_settings=custom_safety + ) + + response_custom = custom_gemini.generate_content("Write a poetic, one-sentence introduction to AI.") + print(f"Custom Response: {response_custom}\n") + except Exception as e: + print(f"Error with custom settings: {e}\n")