Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 51 additions & 51 deletions evaluation_harness/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,62 @@
import base64
import io
import os
from typing import Optional

import numpy as np
from openai import OpenAI
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from transformers import (
    Blip2ForConditionalGeneration,
    Blip2Processor,
)

from browser_env.utils import pil_to_b64

# NOTE(review): this is the pre-existing local BLIP-2 captioner (the side of
# the diff being removed in favor of the API-backed implementation). The tail
# of this function (its return statements) is not visible in this span.
def get_captioning_fn(
    device, dtype, model_name: str = "Salesforce/blip2-flan-t5-xl"
) -> callable:
    """Build a local BLIP-2 captioning function.

    Loads the BLIP-2 processor and model once, then returns a closure that
    captions batches of PIL images on `device` with the given torch `dtype`.
    Raises NotImplementedError for any non-BLIP-2 `model_name`.
    """
    if "blip2" in model_name:
        captioning_processor = Blip2Processor.from_pretrained(model_name)
        captioning_model = Blip2ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=dtype
        )
    else:
        raise NotImplementedError(
            "Only BLIP-2 models are currently supported"
        )
    captioning_model.to(device)

    def caption_images(
        images: List[Image.Image],
        prompt: List[str] = None,
        max_new_tokens: int = 32,
    ) -> List[str]:
        if prompt is None:
            # Unconditional captioning: images only, no text prompt.
            # (Original comment said "Perform VQA", but the prompted branch
            # below is the VQA-style path — the labels were swapped.)
            inputs = captioning_processor(
                images=images, return_tensors="pt"
            ).to(device, dtype)
            generated_ids = captioning_model.generate(
                **inputs, max_new_tokens=max_new_tokens
            )
            captions = captioning_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
        else:
            # Prompted (VQA-style) captioning: one text prompt per image.
            assert len(images) == len(
                prompt
            ), "Number of images and prompts must match, got {} and {}".format(
                len(images), len(prompt)
            )
            inputs = captioning_processor(
                images=images, text=prompt, return_tensors="pt"
            ).to(device, dtype)
            generated_ids = captioning_model.generate(
                **inputs, max_new_tokens=max_new_tokens
            )
            captions = captioning_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )
def get_captioning_fn(
    device, dtype, model_name: str = "qwen3-vl-plus"
) -> callable:
    """Build a captioning function backed by an OpenAI-compatible chat API.

    Args:
        device: Unused; kept so the signature matches the previous local
            (BLIP-2) implementation and existing callers keep working.
        dtype: Unused; kept for the same interface-compatibility reason.
        model_name: Model name passed to the chat-completions endpoint.

    Returns:
        A callable that maps a list of PIL images (plus optional per-image
        text prompts) to a list of caption strings.
    """
    # Credentials and endpoint come from the environment so no secrets are
    # hard-coded; OPENAI_BASE_URL allows pointing at any compatible server.
    client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
        base_url=os.environ.get("OPENAI_BASE_URL"),
    )

    def caption_images(
        images: List[Image.Image],
        prompt: Optional[List[str]] = None,
        max_new_tokens: int = 300,  # API default; larger than local BLIP-2's 32
    ) -> List[str]:
        # Without explicit prompts, fall back to a generic instruction for
        # every image (plain captioning rather than VQA).
        if prompt is None:
            prompts = ["Describe this image in detail."] * len(images)
        else:
            prompts = prompt

        assert len(images) == len(prompts), "Number of images and prompts must match"

        captions: List[str] = []
        for img, p in zip(images, prompts):
            try:
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": p},
                                {
                                    "type": "image_url",
                                    # pil_to_b64 produces a data: URL the API
                                    # accepts as an inline image.
                                    "image_url": {"url": pil_to_b64(img)},
                                },
                            ],
                        }
                    ],
                    max_tokens=max_new_tokens,
                )
                captions.append(response.choices[0].message.content)
            except Exception as e:
                # Best effort: one failed call yields an empty caption
                # instead of aborting the whole batch.
                print(f"Error calling API: {e}")
                captions.append("")

        return captions

    return caption_images


def get_image_ssim(imageA, imageB):
Expand Down
8 changes: 4 additions & 4 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,15 @@ def config() -> argparse.Namespace:
# Captioning-model CLI options. The `choices` restriction was dropped so any
# OpenAI-compatible model name can be supplied; validation now happens at the
# API call site.
parser.add_argument(
    "--eval_captioning_model",
    type=str,
    default="qwen3-vl-plus",
    help="Captioning backbone for VQA-type evals.",
)
parser.add_argument(
    "--captioning_model",
    type=str,
    default="qwen3-vl-plus",
    help="Captioning backbone for accessibility tree alt text.",
)

Expand Down
55 changes: 55 additions & 0 deletions tests/test_api_captioner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@

import unittest
from unittest.mock import MagicMock, patch
import sys
import os
from PIL import Image

# Add the project root to sys.path to import modules
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from evaluation_harness import image_utils

class TestApiCaptioner(unittest.TestCase):
@patch("openai.OpenAI")
@patch.dict(os.environ, {"OPENAI_API_KEY": "test_key", "OPENAI_BASE_URL": "test_url"})
def test_get_captioning_fn_api(self, mock_openai):
# Setup mock
mock_client = MagicMock()
mock_openai.return_value = mock_client
mock_response = MagicMock()
mock_response.choices = [MagicMock(message=MagicMock(content="A test caption"))]
mock_client.chat.completions.create.return_value = mock_response

# Get the function
device = "cpu"
dtype = "float32"
model_name = "qwen3-vl-plus"
caption_fn = image_utils.get_captioning_fn(device, dtype, model_name)

# Create dummy image
image = Image.new("RGB", (100, 100), color="red")

# Test 1: VQA (no prompt provided, uses default)
captions = caption_fn([image])
self.assertEqual(captions, ["A test caption"])

# Verify API call
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
self.assertEqual(call_args.kwargs["model"], model_name)
self.assertEqual(len(call_args.kwargs["messages"]), 1)
self.assertEqual(call_args.kwargs["messages"][0]["role"], "user")

# Test 2: Custom Prompt
custom_prompt = ["What color is this?"]
captions = caption_fn([image], prompt=custom_prompt)
self.assertEqual(captions, ["A test caption"])

# Verify API call with custom prompt
call_args = mock_client.chat.completions.create.call_args
content = call_args.kwargs["messages"][0]["content"]
self.assertEqual(content[0]["text"], "What color is this?")

if __name__ == "__main__":
unittest.main()