22from PIL import Image
33from transformers import AutoProcessor , LlavaForConditionalGeneration , GenerationConfig
44from pathlib import Path
5+ from config .image_formats import SUPPORTED_INPUT_EXTENSIONS
# Disable PIL's decompression-bomb guard so arbitrarily large images load.
# NOTE(review): this removes a DoS protection — only safe for trusted,
# locally curated datasets; confirm inputs are never user-supplied.
Image.MAX_IMAGE_PIXELS = None
57
class ImageCaptioningService:
    """Batch image captioner backed by a LLaVA vision-language model."""

    def __init__(self, model: Path, load_path: Path, save_path: Path,
                 prompt: str, max_new_tokens=512, temperature=0.6,
                 top_p=0.9, top_k=None):
        """Load the model and record captioning configuration.

        Args:
            model: checkpoint path (or hub id) handed to ``from_pretrained``.
            load_path: directory scanned for input images.
            save_path: directory where caption files are written.
            prompt: user instruction sent alongside every image.
            max_new_tokens, temperature, top_p, top_k: sampling knobs
                forwarded verbatim to ``GenerationConfig``.
        """
        self._model, self._processor = self._load_model(model)
        self.load_path = load_path
        self.save_path = save_path
        self.prompt = prompt
        # do_sample=True so temperature/top_p/top_k actually take effect;
        # greedy decoding would silently ignore them.
        self._generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
1521
1622 def _load_model (self , model_path ):
17- model = LlavaForConditionalGeneration .from_pretrained (model_path , torch_dtype = torch .bfloat16 , device_map = "auto" )
23+ model = LlavaForConditionalGeneration .from_pretrained (model_path , dtype = torch .bfloat16 , device_map = "auto" )
1824 processor = AutoProcessor .from_pretrained (model_path )
1925
2026 model .eval ()
2127 return model , processor
2228
23- def _image_inputs (self , image_path : Path , prompt : str ):
29+ def _image_inputs (self , image_path : Path ):
2430 image = Image .open (image_path ).convert ('RGB' )
2531
2632 messages = [
2733 {
2834 "role" : "system" ,
29- "content" : "You are a helpful image captioner."
35+ "content" :
36+ """
37+ You are a professional image captioner for machine learning datasets.
38+ Provide accurate, detailed descriptions focusing on visual elements, composition, and relevant details.
39+ Follow user instructions carefully.
40+ """
3041 },
3142 {
3243 "role" : "user" ,
33- "content" : prompt
44+ "content" : self . prompt
3445 }
3546 ]
3647
3748 formatted_prompt = self ._processor .apply_chat_template (
3849 messages ,
39- tokenize = False , #return string, not tokens yet
50+ tokenize = False ,
4051 add_generation_prompt = True
4152 )
4253
@@ -47,20 +58,17 @@ def _image_inputs(self, image_path: Path, prompt: str):
4758 return_tensors = "pt"
4859 )
4960
50- #move inputs to GPU if available
5161 if torch .cuda .is_available ():
5262 moved_inputs = {}
5363
5464 for key , value in inputs .items ():
55- #check if value can be moved to GPU (has a 'to' method)
5665 if hasattr (value , 'to' ):
5766 moved_inputs [key ] = value .to ('cuda' )
5867 else :
5968 moved_inputs [key ] = value
6069
6170 inputs = moved_inputs
6271
63- #convert pixel values to match model dtype (bfloat16)
6472 if 'pixel_values' in inputs :
6573 inputs ['pixel_values' ] = inputs ['pixel_values' ].to (torch .bfloat16 )
6674
@@ -70,7 +78,7 @@ def _generate_caption(self, inputs: dict):
7078 print ("Generation Captions" )
7179
7280 generated_ids = self ._model .generate (
73- ** inputs , #unpack tensors
81+ ** inputs ,
7482 generation_config = self ._generation_config ,
7583 )[0 ]
7684
@@ -85,9 +93,14 @@ def _generate_caption(self, inputs: dict):
8593
8694 return caption .strip ()
8795
88- def caption_image (self , image_path : Path , prompt : str ):
89- inputs = self ._image_inputs (image_path , prompt )
90- return self ._generate_caption (inputs )
96+ def caption_image (self ):
97+ files = [f for f in self .load_path .iterdir ()
98+ if f .is_file () and f .suffix .lower () in SUPPORTED_INPUT_EXTENSIONS ]
99+
100+ for img_file in files :
101+ inputs = self ._image_inputs (img_file )
102+ caption = self ._generate_caption (inputs )
103+ self .save_caption (caption , img_file , self .save_path )
91104
92105 def save_caption (self , caption , image_path : Path , output_dir :Path = None ):
93106 if output_dir :
0 commit comments