11import pytest
22from unittest .mock import MagicMock , patch
3- from service .image_captioning import ImageCaptioning
3+ from pathlib import Path
4+ from service .image_captioning import ImageCaptioningService
45
56@pytest .fixture
6- def dummy_captioner ():
7+ def dummy_captioner (tmp_path ):
78 with patch ('service.image_captioning.LlavaForConditionalGeneration.from_pretrained' ) as mock_model , \
89 patch ('service.image_captioning.AutoProcessor.from_pretrained' ) as mock_processor :
910 mock_model .return_value = MagicMock ()
1011 processor_instance = MagicMock ()
11- # Set the processor mock to return our processor_instance
1212 mock_processor .return_value = processor_instance
13- captioner = ImageCaptioning ("dummy-model" )
13+
14+ load_path = tmp_path / "load"
15+ save_path = tmp_path / "save"
16+ load_path .mkdir ()
17+ save_path .mkdir ()
18+
19+ captioner = ImageCaptioningService (
20+ model = Path ("dummy-model" ),
21+ load_path = load_path ,
22+ save_path = save_path ,
23+ prompt = "Describe this image."
24+ )
25+
1426 # Mock methods used in the pipeline
1527 captioner ._model .generate = MagicMock (return_value = [[1 , 2 , 3 , 4 ]])
1628 processor_instance .apply_chat_template = MagicMock (return_value = "prompt" )
@@ -19,7 +31,6 @@ def dummy_captioner():
1931 "pixel_values" : MagicMock ()
2032 }
2133 processor_instance .tokenizer .decode = MagicMock (return_value = "A cat on a mat." )
22- # Make sure the processor mock is used for __call__
2334 captioner ._processor = processor_instance
2435 return captioner
2536
@@ -28,7 +39,7 @@ def test_image_inputs(dummy_captioner, tmp_path):
2839 img_path = tmp_path / "test.jpg"
2940 from PIL import Image
3041 Image .new ("RGB" , (10 , 10 )).save (img_path )
31- inputs = dummy_captioner ._image_inputs (img_path , "Describe this image." )
42+ inputs = dummy_captioner ._image_inputs (img_path )
3243 assert "input_ids" in inputs
3344
3445def test_generate_caption (dummy_captioner ):
@@ -44,7 +55,7 @@ def test_save_caption(dummy_captioner, tmp_path):
4455 img_path = tmp_path / "test.jpg"
4556 img_path .touch ()
4657 caption = "A cat on a mat."
47- output_file = dummy_captioner .save_caption (caption , img_path )
58+ output_file = dummy_captioner .save_caption (caption , img_path , tmp_path )
4859 assert output_file .exists ()
4960 with open (output_file , "r" , encoding = "utf-8" ) as f :
5061 assert f .read () == caption
0 commit comments