Skip to content

Commit 365e137

Browse files
committed
Add unit tests for image captioning service
1 parent de393ba commit 365e137

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
from unittest.mock import MagicMock, patch
3+
from service.image_captioning import ImageCaptioning
4+
5+
@pytest.fixture
6+
def dummy_captioner():
7+
with patch('service.image_captioning.LlavaForConditionalGeneration.from_pretrained') as mock_model, \
8+
patch('service.image_captioning.AutoProcessor.from_pretrained') as mock_processor:
9+
mock_model.return_value = MagicMock()
10+
processor_instance = MagicMock()
11+
# Set the processor mock to return our processor_instance
12+
mock_processor.return_value = processor_instance
13+
captioner = ImageCaptioning("dummy-model")
14+
# Mock methods used in the pipeline
15+
captioner._model.generate = MagicMock(return_value=[[1, 2, 3, 4]])
16+
processor_instance.apply_chat_template = MagicMock(return_value="prompt")
17+
processor_instance.return_value = {
18+
"input_ids": MagicMock(shape=(1, 2)),
19+
"pixel_values": MagicMock()
20+
}
21+
processor_instance.tokenizer.decode = MagicMock(return_value="A cat on a mat.")
22+
# Make sure the processor mock is used for __call__
23+
captioner._processor = processor_instance
24+
return captioner
25+
26+
def test_image_inputs(dummy_captioner, tmp_path):
27+
# Create a dummy image
28+
img_path = tmp_path / "test.jpg"
29+
from PIL import Image
30+
Image.new("RGB", (10, 10)).save(img_path)
31+
inputs = dummy_captioner._image_inputs(img_path, "Describe this image.")
32+
assert "input_ids" in inputs
33+
34+
def test_generate_caption(dummy_captioner):
35+
dummy_inputs = {
36+
"input_ids": MagicMock(shape=(1, 2)),
37+
"pixel_values": MagicMock()
38+
}
39+
caption = dummy_captioner._generate_caption(dummy_inputs)
40+
assert isinstance(caption, str)
41+
assert caption == "A cat on a mat."
42+
43+
def test_save_caption(dummy_captioner, tmp_path):
44+
img_path = tmp_path / "test.jpg"
45+
img_path.touch()
46+
caption = "A cat on a mat."
47+
output_file = dummy_captioner.save_caption(caption, img_path)
48+
assert output_file.exists()
49+
with open(output_file, "r", encoding="utf-8") as f:
50+
assert f.read() == caption

0 commit comments

Comments
 (0)