From d60467b4433d4f25ddaa97924ead05d7d12b6e0a Mon Sep 17 00:00:00 2001 From: vai Date: Wed, 26 Feb 2025 16:14:51 -0500 Subject: [PATCH] made small changes to .ipynbs and added prompt_testing.ipynb --- README.md | 10 +- chart_qna/prompt_testing.ipynb | 430 +++++++++++++++++++++++++++++++++ chart_qna/sbert_cosine.ipynb | 4 +- chart_qna/single_view_qa.ipynb | 10 +- 4 files changed, 446 insertions(+), 8 deletions(-) create mode 100644 chart_qna/prompt_testing.ipynb diff --git a/README.md b/README.md index eaf756d..6e091f4 100644 --- a/README.md +++ b/README.md @@ -23,4 +23,12 @@ This notebook calculates similarity between question and answers with different #### 1. Calculating Similarity Between Answers With and Without Caption - This section runs code that: - Generates embeddings using SBERT 'all-MiniLM-L6-v2' model - - Calculates cosine similarity between 'answers' and 'answers without caption' \ No newline at end of file + - Calculates cosine similarity between 'answers' and 'answers without caption' + +### `/chart_qna/prompt_testing.ipynb` +This notebook experiments with different prompts to observe the different types of questions which can be obtained + +#### 1. Selecting 10 Charts With Caption and Trying out Different Prompts and Checking Response +- This section runs code that: + - Subsets the dataset leaving only 10 charts + - Generates the questions for the 10 charts with caption for all the different prompts \ No newline at end of file diff --git a/chart_qna/prompt_testing.ipynb b/chart_qna/prompt_testing.ipynb new file mode 100644 index 0000000..dcae642 --- /dev/null +++ b/chart_qna/prompt_testing.ipynb @@ -0,0 +1,430 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initialize the Libaries and Set Up the OpenAI Environment" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install pandas\n", + "# !pip install openai\n", + "# !pip install python-dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import pandas as pd\n", + "from openai import OpenAI\n", + "from dotenv import load_dotenv; load_dotenv()\n", + "\n", + "api_key = os.getenv(\"OPENAI_API_KEY\") # Set up OpenAI API key in .env file in root\n", + "client = OpenAI(api_key=api_key)\n", + "\n", + "import warnings\n", + "warnings.filterwarnings('ignore')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Selecting 10 Charts With Caption and Trying out Different Prompts and Checking Response" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
imageidfull_captionimage_base64domainchart_typeviews
082private and publicsector investment in rd clas...iVBORw0KGgoAAAANSUhEUgAAA5EAAAEbCAIAAADMBJd/AA...HealthcareBar Graphsingle view
1184Probability ratio (PR) of exceeding (heavy pr...iVBORw0KGgoAAAANSUhEUgAAAqcAAAEDCAIAAACH4jo2AA...Climate ScienceLine Chartcomposite views
2196Decomposition of the change in total annual c...iVBORw0KGgoAAAANSUhEUgAABCsAAAGhCAIAAADHuqkfAA...Climate ScienceBar Graphsingle view
3236Projections and uncertainties for global mean ...iVBORw0KGgoAAAANSUhEUgAABCMAAAG7CAIAAAB2IMgWAA...EnergyBar Graphsingle view
4290The value of improved technology. \\nNote: Mode...iVBORw0KGgoAAAANSUhEUgAABEAAAAHrCAIAAABAfn+SAA...EnergyBar Graphcomposite views
\n", + "
" + ], + "text/plain": [ + " imageid full_caption \\\n", + "0 82 private and publicsector investment in rd clas... \n", + "1 184 Probability ratio (PR) of exceeding (heavy pr... \n", + "2 196 Decomposition of the change in total annual c... \n", + "3 236 Projections and uncertainties for global mean ... \n", + "4 290 The value of improved technology. \\nNote: Mode... \n", + "\n", + " image_base64 domain \\\n", + "0 iVBORw0KGgoAAAANSUhEUgAAA5EAAAEbCAIAAADMBJd/AA... Healthcare \n", + "1 iVBORw0KGgoAAAANSUhEUgAAAqcAAAEDCAIAAACH4jo2AA... Climate Science \n", + "2 iVBORw0KGgoAAAANSUhEUgAABCsAAAGhCAIAAADHuqkfAA... Climate Science \n", + "3 iVBORw0KGgoAAAANSUhEUgAABCMAAAG7CAIAAAB2IMgWAA... Energy \n", + "4 iVBORw0KGgoAAAANSUhEUgAABEAAAAHrCAIAAABAfn+SAA... Energy \n", + "\n", + " chart_type views \n", + "0 Bar Graph single view \n", + "1 Line Chart composite views \n", + "2 Bar Graph single view \n", + "3 Bar Graph single view \n", + "4 Bar Graph composite views " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "imageids = [82, 184, 196, 236, 290, 324, 332, 380, 447, 547]\n", + "\n", + "df = pd.read_csv('../data/200charts.csv')\n", + "df = df[df['imageid'].isin(imageids)].reset_index(drop=True)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "prompt1 = '''\n", + " I have a chart **along with its caption**, and I need a list of inference-based questions generated from it. Your task is to create **questions that require interpretation, trend analysis, and reasoning**, rather than just retrieving values.\n", + "\n", + " ### **Guidelines for Question Generation:** \n", + " 1. **Encourage trend identification and pattern recognition.** \n", + " - Instead of simply asking for numerical values, prompt reasoning about **how** or **why** trends occur. \n", + " - Example: ✅ *\"How does the trend in renewable energy consumption compare to fossil fuel consumption over the past decade?\"* \n", + " - ❌ Avoid: *\"What was the renewable energy consumption in 2020?\"* \n", + "\n", + " 2. **Emphasize comparisons, correlations, and cause-effect relationships (when suggested by the chart).** \n", + " - Look for **patterns between different variables** and form questions that explore their relationship. \n", + " - Example: ✅ *\"Based on the trend shown, how might an increase in electric vehicle sales impact oil consumption?\"* \n", + "\n", + " 3. **Encourage reasoning-based and predictive questions.** \n", + " - Questions should encourage logical inference rather than direct retrieval of numbers. \n", + " - Example: ✅ *\"Given the declining trend in traditional media consumption, what can we infer about digital media’s dominance in the coming years?\"* \n", + "\n", + " 4. **Ensure all questions are fully answerable using the given chart.** \n", + " - The chart should contain enough information to support a logical answer. \n", + " - Example: ✅ *\"What does the chart suggest about the relationship between inflation and consumer spending?\"* \n", + " - ❌ Avoid: *\"What are the main reasons behind the rise in inflation?\"* (Requires external knowledge) \n", + "\n", + " ### **Format for Output:** \n", + " Generate a numbered list of refined inference-based questions: \n", + "\n", + " 1. (Generated question 1) \n", + " 2. (Generated question 2) \n", + " 3. (Generated question 3) \n", + " 4. (Generated question 4) \n", + "\n", + " Now, generate the list of inference-based questions based on the attached chart and caption.\n", + "\n", + "'''\n", + "\n", + "prompt2 = '''\n", + " I have a chart **along with its caption**, and I need a list of **analytical and reasoning-based questions** generated from it. \n", + " Your task is to create **questions that require interpretation, pattern recognition, causal reasoning, \n", + " and forecasting—rather than simply retrieving values.** \n", + "\n", + " ### **Guidelines for Question Generation:** \n", + " 1. **Adapt to the specific chart** \n", + " - Select the most relevant question types **based on the data presented** rather than forcing a fixed distribution of question types. \n", + "\n", + " 2. **Encourage trend identification and pattern recognition** \n", + " - ✅ Example: *\"How does the trend in renewable energy consumption compare to fossil fuel consumption over the past decade?\"* \n", + " - ❌ Avoid: *\"What was the renewable energy consumption in 2020?\"* (Pure retrieval) \n", + "\n", + " 3. **Use inference-based reasoning to connect data points** \n", + " - ✅ Example: *\"What does the trend in inflation suggest about changes in consumer spending patterns?\"* \n", + "\n", + " 4. **Incorporate explanatory (cause-effect) questions when relationships are implied in the data** \n", + " - ✅ Example: *\"What pattern in the chart might explain the sharp drop in production costs in 2018?\"* \n", + "\n", + " 5. **Use counterfactual questions only when meaningful scenarios exist in the chart** \n", + " - ✅ Example: *\"If tax rates had remained unchanged in 2019, how might economic growth have differed?\"* \n", + "\n", + " 6. **Include predictive questions only if the trend is clear and projectable** \n", + " - ✅ Example: *\"If the current trend continues, what would be the projected GDP in 2030?\"* \n", + "\n", + " 7. **Prioritize evaluative, anomaly detection, mechanistic, analogical, and conceptual questions only if applicable to the data** \n", + " - ✅ Example: *\"Which investment sector demonstrated the most stable returns over the last decade?\"* \n", + " - ✅ Example: *\"Which year deviates the most from the expected trend in GDP growth?\"* \n", + "\n", + " ### **Constraints:** \n", + " ✅ **Do not force one question per category—choose questions dynamically.** \n", + " ✅ **All questions must be fully answerable using only the given chart.** \n", + " ✅ **Avoid pure retrieval-based questions.** \n", + "\n", + " ### **Format for Output:** \n", + " Generate a numbered list of refined questions: \n", + "\n", + " 1. (Generated question 1) \n", + " 2. (Generated question 2) \n", + " 3. (Generated question 3) \n", + " 4. (Generated question 4) \n", + "\n", + " Now, generate the list of refined questions based on the attached chart and caption.\n", + "\n", + "'''\n", + "\n", + "prompt3 = '''\n", + " I have a chart **along with its caption**, and I need a list of questions generated from it. Your task is to create **open-ended and thought-provoking questions** that encourage deeper exploration, reasoning, and interpretation while ensuring they are fully answerable using only the given chart. \n", + "\n", + " ### **Guidelines for Question Generation:** \n", + "\n", + " 1. **Encourage reflection on underlying themes and insights.** \n", + " - Instead of focusing on direct numerical retrieval, prompt discussion on **what the data implies** or **what patterns reveal**. \n", + " - Example: ✅ *\"What underlying factors might explain the changes in the trend observed in the chart?\"* \n", + "\n", + " 2. **Focus on implications and broader interpretations.** \n", + " - Questions should encourage reasoning about **how different aspects of the data relate to each other** rather than just reporting values. \n", + " - Example: ✅ *\"What potential consequences might arise if the trend shown in the chart continues?\"* \n", + "\n", + " 3. **Encourage critical thinking and alternative perspectives.** \n", + " - Ask about **possible explanations** for patterns or **alternative ways** the data could have been presented. \n", + " - Example: ✅ *\"How would a different visualization (e.g., line chart vs. bar chart) change the way we interpret this data?\"* \n", + "\n", + " 4. **Highlight areas of uncertainty or data limitations.** \n", + " - Prompt awareness of **what the chart does not show** and encourage reasoning within those constraints. \n", + " - Example: ✅ *\"What are some key insights missing from this chart that would help provide a more complete picture?\"* \n", + "\n", + " 5. **Explore hypothetical and counterfactual reasoning.** \n", + " - Create **\"What if?\"** scenarios that are grounded in the data but push the reader to think beyond its immediate representation. \n", + " - Example: ✅ *\"If one category had been excluded from this chart, how would that impact our understanding of the trend?\"* \n", + "\n", + " ### **Format for Output:** \n", + " Generate a numbered list of refined open-ended questions: \n", + "\n", + " 1. (Generated question 1) \n", + " 2. (Generated question 2) \n", + " 3. (Generated question 3) \n", + " 4. (Generated question 4) \n", + "\n", + " Now, generate the list of exploratory and conceptual questions based on the attached chart and caption.\n", + "\n", + "'''\n", + "\n", + "prompts = [prompt1, prompt2, prompt3]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processed row 0\n", + "Processed row 1\n", + "Processed row 2\n", + "Processed row 3\n", + "Processed row 4\n", + "Processed row 5\n", + "Processed row 6\n", + "Processed row 7\n", + "Processed row 8\n", + "Processed row 9\n", + "Saved ../data/prompt_testing/q_prompt1.csv\n", + "Processed row 0\n", + "Processed row 1\n", + "Processed row 2\n", + "Processed row 3\n", + "Processed row 4\n", + "Processed row 5\n", + "Processed row 6\n", + "Processed row 7\n", + "Processed row 8\n", + "Processed row 9\n", + "Saved ../data/prompt_testing/q_prompt2.csv\n", + "Processed row 0\n", + "Processed row 1\n", + "Processed row 2\n", + "Processed row 3\n", + "Processed row 4\n", + "Processed row 5\n", + "Processed row 6\n", + "Processed row 7\n", + "Processed row 8\n", + "Processed row 9\n", + "Saved ../data/prompt_testing/q_prompt3.csv\n" + ] + } + ], + "source": [ + "max_epochs = 10\n", + "\n", + "for i, prompt in enumerate(prompts, start=1): # Iterate over prompts with index starting from 1\n", + " current_epoch = 0 \n", + " results = []\n", + "\n", + " for idx, row in df.iterrows():\n", + " if current_epoch >= max_epochs: # Ensure we don't exceed max_epochs\n", + " break \n", + "\n", + " chart = row[\"image_base64\"] \n", + " caption = row['full_caption']\n", + "\n", + " try:\n", + " response = client.chat.completions.create(\n", + " model=\"chatgpt-4o-latest\",\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"Give me a maximum of 10 questions.\"\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": prompt, \n", + " },\n", + " {\n", + " \"type\": \"text\",\n", + " \"text\": \"Caption: \" + caption,\n", + " },\n", + " {\n", + " \"type\": \"image_url\",\n", + " \"image_url\": {\n", + " \"url\": f\"data:image/png;base64,{chart}\"\n", + " },\n", + " },\n", + " ],\n", + " }\n", + " ],\n", + " )\n", + "\n", + " # Extract questions from response\n", + " questions = response.choices[0].message.content\n", + " questions = re.findall(r\"\\d+\\.\\s(.+?)(?=\\n|$)\", questions)\n", + " questions = [q.rstrip() for q in questions]\n", + "\n", + " # Store result with imageid and extracted questions\n", + " result_entry = {'imageid': row['imageid']}\n", + " for q_num in range(1, 11):\n", + " result_entry[f'Q{q_num}'] = questions[q_num - 1]\n", + " results.append(result_entry)\n", + "\n", + " print(f\"Processed row {idx}\")\n", + " current_epoch += 1 \n", + "\n", + " except Exception as e:\n", + " print(f\"Error processing row {idx}: {e}\")\n", + " current_epoch += 1 \n", + "\n", + " # Convert results to DataFrame and save dynamically\n", + " q_df = pd.DataFrame(results)\n", + " output_path = f\"../data/prompt_testing/q_prompt{i}.csv\"\n", + " q_df.to_csv(output_path, index=False)\n", + "\n", + " print(f\"Saved {output_path}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/chart_qna/sbert_cosine.ipynb b/chart_qna/sbert_cosine.ipynb index 3571483..8c6e41d 100644 --- a/chart_qna/sbert_cosine.ipynb +++ b/chart_qna/sbert_cosine.ipynb @@ -41,7 +41,7 @@ "outputs": [], "source": [ "# Load the CSV file\n", - "file_path = \"./qa_200_singleview.csv\" # Update this with your local file path\n", + "file_path = \"../data/qa_200_singleview.csv\" # Update this with your local file path\n", "data = pd.read_csv(file_path)\n", "\n", "# Initialize the SentenceTransformer model\n", @@ -74,7 +74,7 @@ " print(f\"Columns {a_col} or {awc_col} not found in the DataFrame.\")\n", "\n", "# Save the updated DataFrame to a new file\n", - "data.to_csv(\"./qa_200_singleview.csv\", index=False)\n", + "data.to_csv(\"../data/qa_200_singleview.csv\", index=False)\n", "data.head()" ] } diff --git a/chart_qna/single_view_qa.ipynb b/chart_qna/single_view_qa.ipynb index 1e898ff..5531bb0 100644 --- a/chart_qna/single_view_qa.ipynb +++ b/chart_qna/single_view_qa.ipynb @@ -157,7 +157,7 @@ } ], "source": [ - "df = pd.read_csv('./200charts.csv')\n", + "df = pd.read_csv('../data/200charts.csv')\n", "df = df[df['views'] == 'single view'].reset_index(drop=True)\n", "df.head()" ] @@ -733,7 +733,7 @@ "\n", "# Create a new dataframe from the results list\n", "qa_df = pd.DataFrame(results)\n", - "qa_df.to_csv(\"qa_200_singleview.csv\", index=False)\n", + "qa_df.to_csv(\"../data/qa_200_singleview.csv\", index=False)\n", "\n", "qa_df.head()" ] @@ -747,7 +747,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -858,7 +858,7 @@ } ], "source": [ - "df = pd.read_csv('../200charts.csv')\n", + "df = pd.read_csv('../data/200charts.csv')\n", "df = df[df['views'] == 'single view'].reset_index(drop=True)\n", "df.head()" ] @@ -1444,7 +1444,7 @@ "\n", "# Create a new dataframe from the results list\n", "qa_df = pd.DataFrame(results)\n", - "qa_df.to_csv(\"qa_withcaption_200_singleview.csv\", index=False)\n", + "qa_df.to_csv(\"../data/qa_withcaption_200_singleview.csv\", index=False)\n", "\n", "qa_df.head()" ]