diff --git a/evaluation_harness/image_utils.py b/evaluation_harness/image_utils.py
index da6782e..ad81644 100644
--- a/evaluation_harness/image_utils.py
+++ b/evaluation_harness/image_utils.py
@@ -3,62 +3,61 @@
 import numpy as np
 from PIL import Image
 from skimage.metrics import structural_similarity as ssim
-from transformers import (
-    Blip2ForConditionalGeneration,
-    Blip2Processor,
-)
+from openai import OpenAI
+import os
+from browser_env.utils import pil_to_b64
 
 
-def get_captioning_fn(
-    device, dtype, model_name: str = "Salesforce/blip2-flan-t5-xl"
-) -> callable:
-    if "blip2" in model_name:
-        captioning_processor = Blip2Processor.from_pretrained(model_name)
-        captioning_model = Blip2ForConditionalGeneration.from_pretrained(
-            model_name, torch_dtype=dtype
-        )
-    else:
-        raise NotImplementedError(
-            "Only BLIP-2 models are currently supported"
-        )
-    captioning_model.to(device)
-
-    def caption_images(
-        images: List[Image.Image],
-        prompt: List[str] = None,
-        max_new_tokens: int = 32,
-    ) -> List[str]:
-        if prompt is None:
-            # Perform VQA
-            inputs = captioning_processor(
-                images=images, return_tensors="pt"
-            ).to(device, dtype)
-            generated_ids = captioning_model.generate(
-                **inputs, max_new_tokens=max_new_tokens
-            )
-            captions = captioning_processor.batch_decode(
-                generated_ids, skip_special_tokens=True
-            )
-        else:
-            # Regular captioning. Prompt is a list of strings, one for each image
-            assert len(images) == len(
-                prompt
-            ), "Number of images and prompts must match, got {} and {}".format(
-                len(images), len(prompt)
-            )
-            inputs = captioning_processor(
-                images=images, text=prompt, return_tensors="pt"
-            ).to(device, dtype)
-            generated_ids = captioning_model.generate(
-                **inputs, max_new_tokens=max_new_tokens
-            )
-            captions = captioning_processor.batch_decode(
-                generated_ids, skip_special_tokens=True
-            )
-
-        return captions
-
-    return caption_images
+def get_captioning_fn(
+    device, dtype, model_name: str = "qwen3-vl-plus"
+) -> callable:
+    # device/dtype are kept for interface compatibility; the API backend
+    # ignores them.
+    client = OpenAI(
+        api_key=os.environ.get("OPENAI_API_KEY"),
+        base_url=os.environ.get("OPENAI_BASE_URL"),
+    )
+
+    def caption_images(
+        images: List[Image.Image],
+        prompt: List[str] = None,
+        max_new_tokens: int = 300,  # larger default suited to API responses
+    ) -> List[str]:
+        captions = []
+
+        # Default to generic captioning when no per-image prompts are given.
+        if prompt is None:
+            prompts = ["Describe this image in detail."] * len(images)
+        else:
+            prompts = prompt
+
+        assert len(images) == len(prompts), "Number of images and prompts must match"
+
+        for img, p in zip(images, prompts):
+            try:
+                response = client.chat.completions.create(
+                    model=model_name,
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": [
+                                {"type": "text", "text": p},
+                                {
+                                    "type": "image_url",
+                                    "image_url": {"url": pil_to_b64(img)},
+                                },
+                            ],
+                        }
+                    ],
+                    max_tokens=max_new_tokens,
+                )
+                # content may be None per the API schema; normalize to "".
+                captions.append(response.choices[0].message.content or "")
+            except Exception as e:
+                # Best-effort: a failed call yields an empty caption rather
+                # than aborting the whole batch.
+                print(f"Error calling API: {e}")
+                captions.append("")
+
+        return captions
+
+    return caption_images
 
 
 def get_image_ssim(imageA, imageB):
diff --git a/run.py b/run.py
index 7a48a2e..4420e8d 100644
--- a/run.py
+++ b/run.py
@@ -136,15 +136,15 @@ def config() -> argparse.Namespace:
     parser.add_argument(
         "--eval_captioning_model",
         type=str,
-        default="Salesforce/blip2-flan-t5-xl",
-        choices=["Salesforce/blip2-flan-t5-xl"],
+        default="qwen3-vl-plus",
+        # choices=["Salesforce/blip2-flan-t5-xl"],  # Allow other models
         help="Captioning backbone for VQA-type evals.",
     )
 
     parser.add_argument(
         "--captioning_model",
         type=str,
-        default="Salesforce/blip2-flan-t5-xl",
-        choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
+        default="qwen3-vl-plus",
+        # choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],  # Allow other models
         help="Captioning backbone for accessibility tree alt text.",
     )
diff --git a/tests/test_api_captioner.py b/tests/test_api_captioner.py
new file mode 100644
index 0000000..549515e
--- /dev/null
+++ b/tests/test_api_captioner.py
@@ -0,0 +1,58 @@
+import unittest
+from unittest.mock import MagicMock, patch
+import sys
+import os
+from PIL import Image
+
+# Add the project root to sys.path to import modules
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+
+from evaluation_harness import image_utils
+
+
+class TestApiCaptioner(unittest.TestCase):
+    # Patch the name where it is looked up (image_utils imported OpenAI via
+    # "from openai import OpenAI"); patching "openai.OpenAI" would leave the
+    # already-bound reference intact and construct a real client.
+    @patch("evaluation_harness.image_utils.OpenAI")
+    @patch.dict(os.environ, {"OPENAI_API_KEY": "test_key", "OPENAI_BASE_URL": "test_url"})
+    def test_get_captioning_fn_api(self, mock_openai):
+        # Setup mock
+        mock_client = MagicMock()
+        mock_openai.return_value = mock_client
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock(message=MagicMock(content="A test caption"))]
+        mock_client.chat.completions.create.return_value = mock_response
+
+        # Get the function
+        device = "cpu"
+        dtype = "float32"
+        model_name = "qwen3-vl-plus"
+        caption_fn = image_utils.get_captioning_fn(device, dtype, model_name)
+
+        # Create dummy image
+        image = Image.new("RGB", (100, 100), color="red")
+
+        # Test 1: VQA (no prompt provided, uses default)
+        captions = caption_fn([image])
+        self.assertEqual(captions, ["A test caption"])
+
+        # Verify API call
+        mock_client.chat.completions.create.assert_called()
+        call_args = mock_client.chat.completions.create.call_args
+        self.assertEqual(call_args.kwargs["model"], model_name)
+        self.assertEqual(len(call_args.kwargs["messages"]), 1)
+        self.assertEqual(call_args.kwargs["messages"][0]["role"], "user")
+
+        # Test 2: Custom Prompt
+        custom_prompt = ["What color is this?"]
+        captions = caption_fn([image], prompt=custom_prompt)
+        self.assertEqual(captions, ["A test caption"])
+
+        # Verify API call with custom prompt
+        call_args = mock_client.chat.completions.create.call_args
+        content = call_args.kwargs["messages"][0]["content"]
+        self.assertEqual(content[0]["text"], "What color is this?")
+
+
+if __name__ == "__main__":
+    unittest.main()