1+ import pytest
2+ from unittest .mock import MagicMock , patch
3+ from service .image_captioning import ImageCaptioning
4+
5+ @pytest .fixture
6+ def dummy_captioner ():
7+ with patch ('service.image_captioning.LlavaForConditionalGeneration.from_pretrained' ) as mock_model , \
8+ patch ('service.image_captioning.AutoProcessor.from_pretrained' ) as mock_processor :
9+ mock_model .return_value = MagicMock ()
10+ processor_instance = MagicMock ()
11+ # Set the processor mock to return our processor_instance
12+ mock_processor .return_value = processor_instance
13+ captioner = ImageCaptioning ("dummy-model" )
14+ # Mock methods used in the pipeline
15+ captioner ._model .generate = MagicMock (return_value = [[1 , 2 , 3 , 4 ]])
16+ processor_instance .apply_chat_template = MagicMock (return_value = "prompt" )
17+ processor_instance .return_value = {
18+ "input_ids" : MagicMock (shape = (1 , 2 )),
19+ "pixel_values" : MagicMock ()
20+ }
21+ processor_instance .tokenizer .decode = MagicMock (return_value = "A cat on a mat." )
22+ # Make sure the processor mock is used for __call__
23+ captioner ._processor = processor_instance
24+ return captioner
25+
26+ def test_image_inputs (dummy_captioner , tmp_path ):
27+ # Create a dummy image
28+ img_path = tmp_path / "test.jpg"
29+ from PIL import Image
30+ Image .new ("RGB" , (10 , 10 )).save (img_path )
31+ inputs = dummy_captioner ._image_inputs (img_path , "Describe this image." )
32+ assert "input_ids" in inputs
33+
34+ def test_generate_caption (dummy_captioner ):
35+ dummy_inputs = {
36+ "input_ids" : MagicMock (shape = (1 , 2 )),
37+ "pixel_values" : MagicMock ()
38+ }
39+ caption = dummy_captioner ._generate_caption (dummy_inputs )
40+ assert isinstance (caption , str )
41+ assert caption == "A cat on a mat."
42+
43+ def test_save_caption (dummy_captioner , tmp_path ):
44+ img_path = tmp_path / "test.jpg"
45+ img_path .touch ()
46+ caption = "A cat on a mat."
47+ output_file = dummy_captioner .save_caption (caption , img_path )
48+ assert output_file .exists ()
49+ with open (output_file , "r" , encoding = "utf-8" ) as f :
50+ assert f .read () == caption
0 commit comments