-
Notifications
You must be signed in to change notification settings - Fork 3
Gemini Integration #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Archillesjakins
wants to merge
18
commits into
shure-dev:develop
Choose a base branch
from
Archillesjakins:Gemini-integration
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
b4b156e
Merge pull request #27 from shure-dev/develop
shure-dev c3350a3
switch openai model to gemini
Archillesjakins 2e999d6
new test_logger for gemini
Archillesjakins d694282
query gemini integration
Archillesjakins 2b31474
changed setup
Archillesjakins fe503ed
response in json format
Archillesjakins 1f930ff
modified
Archillesjakins c868afb
commits
Archillesjakins 338506d
provider for llms
Archillesjakins 9efbc31
added query for gemini provider
Archillesjakins b35664b
commit
Archillesjakins 1bfa411
modified logllm response
Archillesjakins d5dd6ce
data visualization
Archillesjakins 0669102
commit
Archillesjakins 2b8a620
modified plot_metrics throw exception None metricsor value
Archillesjakins 2cc2147
commit changes
Archillesjakins 679c9ca
plot multiple metrics
Archillesjakins 20e57b0
updated plotter to plot metrics
Archillesjakins File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,88 +1,126 @@ | ||
| from openai import OpenAI | ||
| import google.generativeai as genai | ||
| import wandb | ||
| from .extractor import extract_notebook_code | ||
| from logllm import extract_notebook_code | ||
| import json | ||
|
|
||
| import os | ||
| from dotenv import load_dotenv | ||
| from openai import OpenAI | ||
|
|
||
| def init_wandb(project_name): | ||
| wandb.init(project=project_name, settings=wandb.Settings(_disable_stats=True)) | ||
|
|
||
| def extract_experimental_conditions(code): | ||
| client = OpenAI() | ||
|
|
||
| system_prompt = """ | ||
| # You are advanced machine learning experiment designer. | ||
| # Extract all experimental conditions and results for logging via wandb api. | ||
| # Add your original params in your JSON responce if you want to log other params. | ||
| # Extract all informaiton you can find the given script as int, bool or float value. | ||
| # If you can not describe conditions with int, bool or float value, use list of natural language. | ||
| # Give advice to improve the acc. | ||
| # If you use natural language, answer should be very short. | ||
| # Do not include information already provided in param_name_1 for `condition_as_natural_langauge`. | ||
| # Output JSON schema example: | ||
| This is just a example, make it change as you want. Use nested dictionally if nessasary. | ||
| {{ | ||
| "method":"str", | ||
| "dataset":"str", | ||
| "task":"str", | ||
| "accuracy":"", | ||
| "other_param_here":{ | ||
| "other_param_here":"", | ||
| "other_param_here":"", | ||
| }, | ||
| "other_param_here":"", | ||
| ... | ||
| "condition_as_natural_langauge":["Small dataset."], | ||
| "advice_to_improve_acc":["Use bigger dataset.","Use more simple model."] | ||
| }} | ||
| """.replace(" ","") | ||
|
|
||
| user_prompt = f""" | ||
| # Here is a user's Jupyter Notebook script:{code} | ||
| """ | ||
|
|
||
| response = client.chat.completions.create( | ||
| model="gpt-4o-mini-2024-07-18", | ||
| messages=[ | ||
| {"role": "system", "content": system_prompt}, | ||
| {"role": "user", "content": user_prompt}, | ||
| ], | ||
| response_format={"type": "json_object"}, | ||
| # Load environment variables from a .env file | ||
| load_dotenv() | ||
|
|
||
| # Function to configure Google Generative AI only when needed | ||
| def configure_google_genai(): | ||
| # Set up Google Generative AI with API key and model configuration | ||
| genai.configure(api_key=os.getenv('API_KEY')) | ||
|
|
||
| # Define generation settings for the model | ||
| generation_config = { | ||
| "temperature": 0, # Controls the randomness of the output | ||
| "top_p": 0.95, # Nucleus sampling parameter | ||
| "top_k": 64, # Limits the pool of candidates to the top-k | ||
| "max_output_tokens": 8192, # Maximum number of tokens in the output | ||
| "response_mime_type": "application/json", # Expected response format | ||
| } | ||
|
|
||
| # Initialize and return the GenerativeModel instance | ||
| return genai.GenerativeModel( | ||
| model_name="gemini-1.5-flash", | ||
| generation_config=generation_config, | ||
| ) | ||
|
|
||
| # Parse the JSON string from `response.choices[0].message.content` into a dictionary | ||
| parsed_json = json.loads(response.choices[0].message.content) | ||
|
|
||
| # Format the dictionary to make it more readable (4-space indentation, non-ASCII characters displayed as is) | ||
| formatted_json = json.dumps(parsed_json, indent=4, ensure_ascii=False) | ||
|
|
||
| # Print the formatted JSON data | ||
| print(formatted_json) | ||
|
|
||
| return response.choices[0].message.content | ||
|
|
||
|
|
||
| # System prompt for guiding the AI model to extract experiment details | ||
| system_prompt = """ | ||
| You are an advanced machine learning experiment designer. | ||
| Extract all experimental conditions and results for logging via wandb API. | ||
| Add your original parameters in your JSON response if you want to log other parameters. | ||
| Extract all information you can find in the given script as int, bool, or float values. | ||
| If you cannot describe conditions with int, bool, or float values, use a list of natural language. | ||
| Give advice to improve the accuracy. | ||
| If you use natural language, the answers should be very short. | ||
| Do not include information already provided in param_name_1 for `condition_as_natural_language`. | ||
| Output JSON schema example: | ||
| This is just an example, make changes as necessary. Use nested dictionaries if necessary. | ||
| {{ | ||
| "method":"str", | ||
| "dataset":"str", | ||
| "task":"str", | ||
| "accuracy":"", | ||
| "other_param_here":{ | ||
| "other_param_here":"", | ||
| "other_param_here":"", | ||
| }, | ||
| "other_param_here":"", | ||
| ... | ||
| "condition_as_natural_language":["Small dataset."], | ||
| "advice_to_improve_acc":["Use a bigger dataset.","Use a simpler model."] | ||
| }} | ||
| """.replace(" ", "") | ||
|
|
||
| # Function to extract experimental conditions using the specified provider (Google or OpenAI) | ||
| def extract_experimental_conditions(provider, code): | ||
| # Combine system prompt with user's code input | ||
| user_input = f"{system_prompt}\n\nHere is a user's Jupyter Notebook script: {code}" | ||
|
|
||
| if provider == "gemini": | ||
| # Configure and use Google Generative AI if specified | ||
| model = configure_google_genai() | ||
| chat_session = model.start_chat( | ||
| history=[{"role": "user", "parts": ["Hello! help me analyze data in JSON format only and return only json object nothing else"]}] | ||
| ) | ||
| response = chat_session.send_message(user_input) | ||
| result = response.candidates[0].content.parts[0].text | ||
|
|
||
| elif provider == "openai": | ||
| # Use OpenAI's API to get the response | ||
| client = OpenAI() | ||
| response = client.chat.completions.create( | ||
| model="gpt-4o-mini-2024-07-18", | ||
| messages=[ | ||
| {"role": "system", "content": user_input}, | ||
| ], | ||
| response_format={"type": "json_object"}, | ||
| ) | ||
| result = response.choices[0].message.content | ||
|
|
||
| else: | ||
| # Raise an error if an invalid provider is specified | ||
| raise ValueError("Invalid provider specified. Use 'gemini' or 'openai'.") | ||
|
|
||
| # Parse the result from JSON string to Python dictionary | ||
| result = json.loads(result) | ||
| # Format the JSON output for better readability | ||
| return json.dumps(result, indent=4, ensure_ascii=False) | ||
|
|
||
| # Function to log the extracted information to Weights & Biases (W&B) | ||
| def log_to_wandb(response_text): | ||
| wandb.log(json.loads(response_text)) | ||
|
|
||
| def log_llm(notebook_path, project_name = None, is_logging = False): | ||
|
|
||
|
|
||
| if project_name is None: | ||
| project_name = notebook_path.replace(".ipynb","") | ||
|
|
||
| # Initialize W&B | ||
| try: | ||
| # Parse the JSON response and log it to W&B | ||
| response_dict = json.loads(response_text) | ||
| wandb.log(response_dict) | ||
| except (json.JSONDecodeError, Exception) as e: | ||
| # Handle errors in JSON parsing or W&B logging | ||
| print(f"Error logging to W&B: {e}") | ||
|
|
||
| # Main function to extract and log experimental conditions from a Jupyter Notebook | ||
| def log_llm(notebook_path, project_name=None, is_logging=False, provider=None): | ||
| # Use the notebook file name as the project name if not specified | ||
| project_name = project_name or os.path.basename(notebook_path).replace(".ipynb", "") | ||
| if is_logging: | ||
| # Initialize a new W&B run if logging is enabled | ||
| init_wandb(project_name) | ||
|
|
||
| # Extract code from Jupyter Notebook | ||
| # Extract the code from the notebook | ||
| code_string = extract_notebook_code(notebook_path) | ||
| # Extract the experimental conditions using the specified AI provider | ||
| parsed_json = extract_experimental_conditions(provider, code_string) | ||
|
|
||
| # Send code to OpenAI | ||
| response_text = extract_experimental_conditions(code_string) | ||
|
|
||
| # Log response to W&B | ||
| if is_logging: | ||
| log_to_wandb(response_text) | ||
| if is_logging and parsed_json: | ||
| # Log the extracted information to W&B | ||
| log_to_wandb(parsed_json) | ||
|
|
||
| print("Response from OpenAI logged to W&B.") | ||
| # Inform the user that the process is complete | ||
| print("Response from the provider processed and logged to W&B.") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| import json | ||
| import numpy as np | ||
| import matplotlib.pyplot as plt | ||
|
|
||
| def plot_metrics(*models_results): | ||
| # Process each model to ensure it's in dictionary format | ||
| processed_models = [] | ||
| for model in models_results: | ||
| # Convert to dictionary if the input is a JSON string | ||
| if isinstance(model, str): | ||
| model = json.loads(model) # Convert string to dictionary | ||
| processed_models.append(model) | ||
|
|
||
| # Define keys to exclude from the plot | ||
| exclude_keys = {'cache_size', 'random_state_tts', 'random_state', 'random_state_1', 'n_estimators'} | ||
|
|
||
| # Initialize containers for metrics, values, and model names | ||
| metrics = [] | ||
| model_names = [] | ||
| model_values = {} | ||
|
|
||
| for model in processed_models: | ||
| model_name = model.get("model_name", "Test Model") | ||
| model_names.append(model_name) | ||
|
|
||
| for key, value in model.items(): | ||
| if key.startswith("result_name_"): | ||
| metric_name = value | ||
| metric_index = key.split("_")[-1] # Extract the index (e.g., "1" from "result_name_1") | ||
| metric_value = model.get(f"result_value_{metric_index}", None) | ||
|
|
||
| if metric_value is not None: | ||
| if metric_name not in metrics: | ||
| metrics.append(metric_name) | ||
| if metric_name not in model_values: | ||
| model_values[metric_name] = [] | ||
|
|
||
| # Add the metric value to the list | ||
| model_values[metric_name].append(metric_value) | ||
|
|
||
| # Handle additional numeric values in the model (not using the result_name format) | ||
| for key, value in model.items(): | ||
| # Exclude specific keys and ensure the value is numeric | ||
| if key not in exclude_keys and isinstance(value, (int, float)): | ||
| if key not in metrics: | ||
| metrics.append(key) | ||
| if key not in model_values: | ||
| model_values[key] = [] | ||
| model_values[key].append(value) | ||
|
|
||
| # Handle cases where no valid metrics were provided | ||
| if not metrics or not model_names: | ||
| print("No valid metrics or model names found.") | ||
| return | ||
|
|
||
| # Ensure all models have values for all metrics, filling in with 0 if not available | ||
| for metric in metrics: | ||
| for i in range(len(model_names)): | ||
| if len(model_values[metric]) <= i: | ||
| model_values[metric].append(0) # Default value if missing | ||
|
|
||
| # Plotting side-by-side bar chart | ||
| x = np.arange(len(metrics)) # Label locations | ||
| bar_width = 0.15 # Width of the bars | ||
| fig, ax = plt.subplots(figsize=(10, 6)) | ||
|
|
||
| # Create a bar for each model's performance metrics | ||
| for i, model_name in enumerate(model_names): | ||
| values = [model_values[metric][i] for metric in metrics] | ||
| ax.bar(x + i * bar_width, values, width=bar_width, label=model_name) | ||
|
|
||
| # Customization of the plot | ||
| ax.set_xlabel('Metric', fontsize=14) | ||
| ax.set_ylabel('Value', fontsize=14) | ||
| ax.set_title('Comparison of Model Performance Metrics', fontsize=16) | ||
| ax.set_xticks(x + bar_width * (len(model_names) - 1) / 2) | ||
| ax.set_xticklabels(metrics, fontsize=12) | ||
| ax.legend(title='Models') | ||
| ax.grid(True, axis='y', linestyle='--', alpha=0.7) | ||
|
|
||
| plt.xticks(rotation=45, ha='right') # Rotate x-axis labels for better readability | ||
| plt.tight_layout() | ||
| plt.show() | ||
|
|
||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,24 +1,76 @@ | ||
| from openai import OpenAI | ||
| import google.generativeai as genai | ||
| import os | ||
| from dotenv import load_dotenv | ||
| import openai | ||
| from logllm.log_llm import log_llm | ||
|
|
||
| def query(user_input: str): | ||
| client = OpenAI() | ||
| # Load environment variables | ||
| load_dotenv() | ||
|
|
||
| # Configure Google Generative AI API Key | ||
| genai.configure(api_key=os.getenv('API_KEY')) | ||
| generation_config = { | ||
| "temperature": 0, | ||
| "top_p": 0.95, | ||
| "top_k": 64, | ||
| "max_output_tokens": 8192, | ||
| "response_mime_type": "text/plain", | ||
| } | ||
|
|
||
| # Function to query OpenAI | ||
| def query_openai(user_input: str): | ||
|
|
||
| system_prompt = """ | ||
| # Convert the following query to a W&B API query: | ||
| """.replace(" ","") | ||
| Convert the following query to a W&B API query: | ||
| """.strip() | ||
|
|
||
| user_prompt = f""" | ||
| # Here is a user's :{user_input} | ||
| """ | ||
| Here is a user's query: {user_input} | ||
| """.strip() | ||
|
|
||
| response = client.chat.completions.create( | ||
| response = openai.ChatCompletion.create( | ||
| model="gpt-4o-mini-2024-07-18", | ||
| messages=[ | ||
| {"role": "system", "content": system_prompt}, | ||
| {"role": "user", "content": user_prompt}, | ||
| ] | ||
| ) | ||
|
|
||
| print(response) | ||
|
|
||
|
|
||
| return response['choices'][0]['message']['content'] | ||
|
|
||
| # Function to query Google Gemini | ||
| def query_gemini(user_input: str, code): | ||
| model = genai.GenerativeModel("gemini-1.5-flash", generation_config=generation_config) | ||
| user_input = f"{code}" | ||
|
|
||
| system_prompt = """ | ||
| Please provide the data you want me to convert to a W&B API query: | ||
| """.strip() | ||
|
|
||
| user_prompt = f""" | ||
| Here is a user's query: {user_input} | ||
| """ | ||
|
|
||
| chat_session = model.start_chat( | ||
| history=[ | ||
| {"role": "model", "parts": [system_prompt]}, | ||
| {"role": "user", "parts": [user_prompt]}, | ||
| ] | ||
| ) | ||
|
|
||
| response = chat_session.send_message(user_prompt) | ||
| return response.candidates[0].content.parts[0].text | ||
|
|
||
| # General query function that calls the appropriate provider | ||
|
|
||
| def query(provider): | ||
| if provider == 'openai': | ||
| return query_openai() | ||
| elif provider == 'gemini': | ||
| return query_gemini() | ||
| else: | ||
| raise ValueError("Invalid provider specified. Use 'openai' or 'gemini'.") | ||
|
|
||
|
|
||
| # Usage Example: | ||
|
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
JSON is not necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes please check on the suggestions I made on the issues post.