diff --git a/README.md b/README.md
index eaf756d..6e091f4 100644
--- a/README.md
+++ b/README.md
@@ -23,4 +23,12 @@ This notebook calculates similarity between question and answers with different
#### 1. Calculating Similarity Between Answers With and Without Caption
- This section runs code that:
- Generates embeddings using SBERT 'all-MiniLM-L6-v2' model
- - Calculates cosine similarity between 'answers' and 'answers without caption'
\ No newline at end of file
+ - Calculates cosine similarity between 'answers' and 'answers without caption'
+
+### `/chart_qna/prompt_testing.ipynb`
+This notebook experiments with different prompts to observe the different types of questions that can be generated from the charts
+
+#### 1. Selecting 10 Charts With Caption and Trying out Different Prompts and Checking Response
+- This section runs code that:
+ - Subsets the dataset leaving only 10 charts
+ - Generates the questions for the 10 charts with caption for all the different prompts
\ No newline at end of file
diff --git a/chart_qna/prompt_testing.ipynb b/chart_qna/prompt_testing.ipynb
new file mode 100644
index 0000000..dcae642
--- /dev/null
+++ b/chart_qna/prompt_testing.ipynb
@@ -0,0 +1,430 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "## Initialize the Libraries and Set Up the OpenAI Environment"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# !pip install pandas\n",
+ "# !pip install openai\n",
+ "# !pip install python-dotenv"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import re\n",
+ "import pandas as pd\n",
+ "from openai import OpenAI\n",
+ "from dotenv import load_dotenv; load_dotenv()\n",
+ "\n",
+ "api_key = os.getenv(\"OPENAI_API_KEY\") # Set up OpenAI API key in .env file in root\n",
+ "client = OpenAI(api_key=api_key)\n",
+ "\n",
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Selecting 10 Charts With Caption and Trying out Different Prompts and Checking Response"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " imageid | \n",
+ " full_caption | \n",
+ " image_base64 | \n",
+ " domain | \n",
+ " chart_type | \n",
+ " views | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 82 | \n",
+ " private and publicsector investment in rd clas... | \n",
+ " iVBORw0KGgoAAAANSUhEUgAAA5EAAAEbCAIAAADMBJd/AA... | \n",
+ " Healthcare | \n",
+ " Bar Graph | \n",
+ " single view | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 184 | \n",
+ " Probability ratio (PR) of exceeding (heavy pr... | \n",
+ " iVBORw0KGgoAAAANSUhEUgAAAqcAAAEDCAIAAACH4jo2AA... | \n",
+ " Climate Science | \n",
+ " Line Chart | \n",
+ " composite views | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 196 | \n",
+ " Decomposition of the change in total annual c... | \n",
+ " iVBORw0KGgoAAAANSUhEUgAABCsAAAGhCAIAAADHuqkfAA... | \n",
+ " Climate Science | \n",
+ " Bar Graph | \n",
+ " single view | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 236 | \n",
+ " Projections and uncertainties for global mean ... | \n",
+ " iVBORw0KGgoAAAANSUhEUgAABCMAAAG7CAIAAAB2IMgWAA... | \n",
+ " Energy | \n",
+ " Bar Graph | \n",
+ " single view | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 290 | \n",
+ " The value of improved technology. \\nNote: Mode... | \n",
+ " iVBORw0KGgoAAAANSUhEUgAABEAAAAHrCAIAAABAfn+SAA... | \n",
+ " Energy | \n",
+ " Bar Graph | \n",
+ " composite views | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " imageid full_caption \\\n",
+ "0 82 private and publicsector investment in rd clas... \n",
+ "1 184 Probability ratio (PR) of exceeding (heavy pr... \n",
+ "2 196 Decomposition of the change in total annual c... \n",
+ "3 236 Projections and uncertainties for global mean ... \n",
+ "4 290 The value of improved technology. \\nNote: Mode... \n",
+ "\n",
+ " image_base64 domain \\\n",
+ "0 iVBORw0KGgoAAAANSUhEUgAAA5EAAAEbCAIAAADMBJd/AA... Healthcare \n",
+ "1 iVBORw0KGgoAAAANSUhEUgAAAqcAAAEDCAIAAACH4jo2AA... Climate Science \n",
+ "2 iVBORw0KGgoAAAANSUhEUgAABCsAAAGhCAIAAADHuqkfAA... Climate Science \n",
+ "3 iVBORw0KGgoAAAANSUhEUgAABCMAAAG7CAIAAAB2IMgWAA... Energy \n",
+ "4 iVBORw0KGgoAAAANSUhEUgAABEAAAAHrCAIAAABAfn+SAA... Energy \n",
+ "\n",
+ " chart_type views \n",
+ "0 Bar Graph single view \n",
+ "1 Line Chart composite views \n",
+ "2 Bar Graph single view \n",
+ "3 Bar Graph single view \n",
+ "4 Bar Graph composite views "
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "imageids = [82, 184, 196, 236, 290, 324, 332, 380, 447, 547]\n",
+ "\n",
+ "df = pd.read_csv('../data/200charts.csv')\n",
+ "df = df[df['imageid'].isin(imageids)].reset_index(drop=True)\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt1 = '''\n",
+ " I have a chart **along with its caption**, and I need a list of inference-based questions generated from it. Your task is to create **questions that require interpretation, trend analysis, and reasoning**, rather than just retrieving values.\n",
+ "\n",
+ " ### **Guidelines for Question Generation:** \n",
+ " 1. **Encourage trend identification and pattern recognition.** \n",
+ " - Instead of simply asking for numerical values, prompt reasoning about **how** or **why** trends occur. \n",
+ " - Example: ✅ *\"How does the trend in renewable energy consumption compare to fossil fuel consumption over the past decade?\"* \n",
+ " - ❌ Avoid: *\"What was the renewable energy consumption in 2020?\"* \n",
+ "\n",
+ " 2. **Emphasize comparisons, correlations, and cause-effect relationships (when suggested by the chart).** \n",
+ " - Look for **patterns between different variables** and form questions that explore their relationship. \n",
+ " - Example: ✅ *\"Based on the trend shown, how might an increase in electric vehicle sales impact oil consumption?\"* \n",
+ "\n",
+ " 3. **Encourage reasoning-based and predictive questions.** \n",
+ " - Questions should encourage logical inference rather than direct retrieval of numbers. \n",
+ " - Example: ✅ *\"Given the declining trend in traditional media consumption, what can we infer about digital media’s dominance in the coming years?\"* \n",
+ "\n",
+ " 4. **Ensure all questions are fully answerable using the given chart.** \n",
+ " - The chart should contain enough information to support a logical answer. \n",
+ " - Example: ✅ *\"What does the chart suggest about the relationship between inflation and consumer spending?\"* \n",
+ " - ❌ Avoid: *\"What are the main reasons behind the rise in inflation?\"* (Requires external knowledge) \n",
+ "\n",
+ " ### **Format for Output:** \n",
+ " Generate a numbered list of refined inference-based questions: \n",
+ "\n",
+ " 1. (Generated question 1) \n",
+ " 2. (Generated question 2) \n",
+ " 3. (Generated question 3) \n",
+ " 4. (Generated question 4) \n",
+ "\n",
+ " Now, generate the list of inference-based questions based on the attached chart and caption.\n",
+ "\n",
+ "'''\n",
+ "\n",
+ "prompt2 = '''\n",
+ " I have a chart **along with its caption**, and I need a list of **analytical and reasoning-based questions** generated from it. \n",
+ " Your task is to create **questions that require interpretation, pattern recognition, causal reasoning, \n",
+ " and forecasting—rather than simply retrieving values.** \n",
+ "\n",
+ " ### **Guidelines for Question Generation:** \n",
+ " 1. **Adapt to the specific chart** \n",
+ " - Select the most relevant question types **based on the data presented** rather than forcing a fixed distribution of question types. \n",
+ "\n",
+ " 2. **Encourage trend identification and pattern recognition** \n",
+ " - ✅ Example: *\"How does the trend in renewable energy consumption compare to fossil fuel consumption over the past decade?\"* \n",
+ " - ❌ Avoid: *\"What was the renewable energy consumption in 2020?\"* (Pure retrieval) \n",
+ "\n",
+ " 3. **Use inference-based reasoning to connect data points** \n",
+ " - ✅ Example: *\"What does the trend in inflation suggest about changes in consumer spending patterns?\"* \n",
+ "\n",
+ " 4. **Incorporate explanatory (cause-effect) questions when relationships are implied in the data** \n",
+ " - ✅ Example: *\"What pattern in the chart might explain the sharp drop in production costs in 2018?\"* \n",
+ "\n",
+ " 5. **Use counterfactual questions only when meaningful scenarios exist in the chart** \n",
+ " - ✅ Example: *\"If tax rates had remained unchanged in 2019, how might economic growth have differed?\"* \n",
+ "\n",
+ " 6. **Include predictive questions only if the trend is clear and projectable** \n",
+ " - ✅ Example: *\"If the current trend continues, what would be the projected GDP in 2030?\"* \n",
+ "\n",
+ " 7. **Prioritize evaluative, anomaly detection, mechanistic, analogical, and conceptual questions only if applicable to the data** \n",
+ " - ✅ Example: *\"Which investment sector demonstrated the most stable returns over the last decade?\"* \n",
+ " - ✅ Example: *\"Which year deviates the most from the expected trend in GDP growth?\"* \n",
+ "\n",
+ " ### **Constraints:** \n",
+ " ✅ **Do not force one question per category—choose questions dynamically.** \n",
+ " ✅ **All questions must be fully answerable using only the given chart.** \n",
+ " ✅ **Avoid pure retrieval-based questions.** \n",
+ "\n",
+ " ### **Format for Output:** \n",
+ " Generate a numbered list of refined questions: \n",
+ "\n",
+ " 1. (Generated question 1) \n",
+ " 2. (Generated question 2) \n",
+ " 3. (Generated question 3) \n",
+ " 4. (Generated question 4) \n",
+ "\n",
+ " Now, generate the list of refined questions based on the attached chart and caption.\n",
+ "\n",
+ "'''\n",
+ "\n",
+ "prompt3 = '''\n",
+ " I have a chart **along with its caption**, and I need a list of questions generated from it. Your task is to create **open-ended and thought-provoking questions** that encourage deeper exploration, reasoning, and interpretation while ensuring they are fully answerable using only the given chart. \n",
+ "\n",
+ " ### **Guidelines for Question Generation:** \n",
+ "\n",
+ " 1. **Encourage reflection on underlying themes and insights.** \n",
+ " - Instead of focusing on direct numerical retrieval, prompt discussion on **what the data implies** or **what patterns reveal**. \n",
+ " - Example: ✅ *\"What underlying factors might explain the changes in the trend observed in the chart?\"* \n",
+ "\n",
+ " 2. **Focus on implications and broader interpretations.** \n",
+ " - Questions should encourage reasoning about **how different aspects of the data relate to each other** rather than just reporting values. \n",
+ " - Example: ✅ *\"What potential consequences might arise if the trend shown in the chart continues?\"* \n",
+ "\n",
+ " 3. **Encourage critical thinking and alternative perspectives.** \n",
+ " - Ask about **possible explanations** for patterns or **alternative ways** the data could have been presented. \n",
+ " - Example: ✅ *\"How would a different visualization (e.g., line chart vs. bar chart) change the way we interpret this data?\"* \n",
+ "\n",
+ " 4. **Highlight areas of uncertainty or data limitations.** \n",
+ " - Prompt awareness of **what the chart does not show** and encourage reasoning within those constraints. \n",
+ " - Example: ✅ *\"What are some key insights missing from this chart that would help provide a more complete picture?\"* \n",
+ "\n",
+ " 5. **Explore hypothetical and counterfactual reasoning.** \n",
+ " - Create **\"What if?\"** scenarios that are grounded in the data but push the reader to think beyond its immediate representation. \n",
+ " - Example: ✅ *\"If one category had been excluded from this chart, how would that impact our understanding of the trend?\"* \n",
+ "\n",
+ " ### **Format for Output:** \n",
+ " Generate a numbered list of refined open-ended questions: \n",
+ "\n",
+ " 1. (Generated question 1) \n",
+ " 2. (Generated question 2) \n",
+ " 3. (Generated question 3) \n",
+ " 4. (Generated question 4) \n",
+ "\n",
+ " Now, generate the list of exploratory and conceptual questions based on the attached chart and caption.\n",
+ "\n",
+ "'''\n",
+ "\n",
+ "prompts = [prompt1, prompt2, prompt3]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processed row 0\n",
+ "Processed row 1\n",
+ "Processed row 2\n",
+ "Processed row 3\n",
+ "Processed row 4\n",
+ "Processed row 5\n",
+ "Processed row 6\n",
+ "Processed row 7\n",
+ "Processed row 8\n",
+ "Processed row 9\n",
+ "Saved ../data/prompt_testing/q_prompt1.csv\n",
+ "Processed row 0\n",
+ "Processed row 1\n",
+ "Processed row 2\n",
+ "Processed row 3\n",
+ "Processed row 4\n",
+ "Processed row 5\n",
+ "Processed row 6\n",
+ "Processed row 7\n",
+ "Processed row 8\n",
+ "Processed row 9\n",
+ "Saved ../data/prompt_testing/q_prompt2.csv\n",
+ "Processed row 0\n",
+ "Processed row 1\n",
+ "Processed row 2\n",
+ "Processed row 3\n",
+ "Processed row 4\n",
+ "Processed row 5\n",
+ "Processed row 6\n",
+ "Processed row 7\n",
+ "Processed row 8\n",
+ "Processed row 9\n",
+ "Saved ../data/prompt_testing/q_prompt3.csv\n"
+ ]
+ }
+ ],
+ "source": [
+ "max_epochs = 10\n",
+ "\n",
+ "for i, prompt in enumerate(prompts, start=1): # Iterate over prompts with index starting from 1\n",
+ " current_epoch = 0 \n",
+ " results = []\n",
+ "\n",
+ " for idx, row in df.iterrows():\n",
+ " if current_epoch >= max_epochs: # Ensure we don't exceed max_epochs\n",
+ " break \n",
+ "\n",
+ " chart = row[\"image_base64\"] \n",
+ " caption = row['full_caption']\n",
+ "\n",
+ " try:\n",
+ " response = client.chat.completions.create(\n",
+ " model=\"chatgpt-4o-latest\",\n",
+ " messages=[\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": \"Give me a maximum of 10 questions.\"\n",
+ " },\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": [\n",
+ " {\n",
+ " \"type\": \"text\",\n",
+ " \"text\": prompt, \n",
+ " },\n",
+ " {\n",
+ " \"type\": \"text\",\n",
+ " \"text\": \"Caption: \" + caption,\n",
+ " },\n",
+ " {\n",
+ " \"type\": \"image_url\",\n",
+ " \"image_url\": {\n",
+ " \"url\": f\"data:image/png;base64,{chart}\"\n",
+ " },\n",
+ " },\n",
+ " ],\n",
+ " }\n",
+ " ],\n",
+ " )\n",
+ "\n",
+ " # Extract questions from response\n",
+ " questions = response.choices[0].message.content\n",
+ " questions = re.findall(r\"\\d+\\.\\s(.+?)(?=\\n|$)\", questions)\n",
+ " questions = [q.rstrip() for q in questions]\n",
+ "\n",
+ " # Store result with imageid and extracted questions\n",
+ " result_entry = {'imageid': row['imageid']}\n",
+ " for q_num in range(1, 11):\n",
+ " result_entry[f'Q{q_num}'] = questions[q_num - 1]\n",
+ " results.append(result_entry)\n",
+ "\n",
+ " print(f\"Processed row {idx}\")\n",
+ " current_epoch += 1 \n",
+ "\n",
+ " except Exception as e:\n",
+ " print(f\"Error processing row {idx}: {e}\")\n",
+ " current_epoch += 1 \n",
+ "\n",
+ " # Convert results to DataFrame and save dynamically\n",
+ " q_df = pd.DataFrame(results)\n",
+ " output_path = f\"../data/prompt_testing/q_prompt{i}.csv\"\n",
+ " q_df.to_csv(output_path, index=False)\n",
+ "\n",
+ " print(f\"Saved {output_path}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/chart_qna/sbert_cosine.ipynb b/chart_qna/sbert_cosine.ipynb
index 3571483..8c6e41d 100644
--- a/chart_qna/sbert_cosine.ipynb
+++ b/chart_qna/sbert_cosine.ipynb
@@ -41,7 +41,7 @@
"outputs": [],
"source": [
"# Load the CSV file\n",
- "file_path = \"./qa_200_singleview.csv\" # Update this with your local file path\n",
+ "file_path = \"../data/qa_200_singleview.csv\" # Update this with your local file path\n",
"data = pd.read_csv(file_path)\n",
"\n",
"# Initialize the SentenceTransformer model\n",
@@ -74,7 +74,7 @@
" print(f\"Columns {a_col} or {awc_col} not found in the DataFrame.\")\n",
"\n",
"# Save the updated DataFrame to a new file\n",
- "data.to_csv(\"./qa_200_singleview.csv\", index=False)\n",
+ "data.to_csv(\"../data/qa_200_singleview.csv\", index=False)\n",
"data.head()"
]
}
diff --git a/chart_qna/single_view_qa.ipynb b/chart_qna/single_view_qa.ipynb
index 1e898ff..5531bb0 100644
--- a/chart_qna/single_view_qa.ipynb
+++ b/chart_qna/single_view_qa.ipynb
@@ -157,7 +157,7 @@
}
],
"source": [
- "df = pd.read_csv('./200charts.csv')\n",
+ "df = pd.read_csv('../data/200charts.csv')\n",
"df = df[df['views'] == 'single view'].reset_index(drop=True)\n",
"df.head()"
]
@@ -733,7 +733,7 @@
"\n",
"# Create a new dataframe from the results list\n",
"qa_df = pd.DataFrame(results)\n",
- "qa_df.to_csv(\"qa_200_singleview.csv\", index=False)\n",
+ "qa_df.to_csv(\"../data/qa_200_singleview.csv\", index=False)\n",
"\n",
"qa_df.head()"
]
@@ -747,7 +747,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -858,7 +858,7 @@
}
],
"source": [
- "df = pd.read_csv('../200charts.csv')\n",
+ "df = pd.read_csv('../data/200charts.csv')\n",
"df = df[df['views'] == 'single view'].reset_index(drop=True)\n",
"df.head()"
]
@@ -1444,7 +1444,7 @@
"\n",
"# Create a new dataframe from the results list\n",
"qa_df = pd.DataFrame(results)\n",
- "qa_df.to_csv(\"qa_withcaption_200_singleview.csv\", index=False)\n",
+ "qa_df.to_csv(\"../data/qa_withcaption_200_singleview.csv\", index=False)\n",
"\n",
"qa_df.head()"
]