Skip to content

Commit 2fd49d5

Browse files
committed
feat: add image captioning service and endpoint, update dependencies
1 parent 7208a4e commit 2fd49d5

File tree

5 files changed

+92
-19
lines changed

5 files changed

+92
-19
lines changed

backend/main.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pydantic import BaseModel
44
from service.image_rename import RenameService
55
from service.image_upscaling import ImageUpscaleService
6+
from service.image_captioning import ImageCaptioningService
67
from utils.image_util import get_device
78

89
app = FastAPI()
@@ -19,10 +20,12 @@ class UpscaleRequest(BaseModel):
1920
save_path: str
2021
format: str
2122
use_tiling: bool = True
22-
23-
@app.get("/")
24-
def root():
25-
return {"Hello": "World"}
23+
24+
class CaptionRequest(BaseModel):
    # Request body accepted by POST /caption. Every field is a required
    # string; the handler converts the three path fields to pathlib.Path.

    # Location of the captioning model handed to ImageCaptioningService.
    caption_model_path: str
    # Directory that is scanned for images to caption.
    load_path: str
    # Output location forwarded to the service when captions are saved.
    save_path: str
    # Instruction sent to the model as the user message for each image.
    prompt: str
2629

2730
@app.post("/rename")
2831
def rename(request: RenameRequest):
@@ -42,6 +45,13 @@ def upscale(request: UpscaleRequest):
4245
service.upscale(use_tiling=request.use_tiling)
4346
return {"status": "Upscaling Complete"}
4447

48+
@app.post("/caption")
def caption(request: CaptionRequest):
    """Caption every supported image in ``request.load_path``.

    Builds an ImageCaptioningService from the request's model path, input
    directory, output path and prompt, runs it, and reports completion.

    Returns:
        A dict with a single ``"status"`` key.

    Raises:
        HTTPException: 400 when ``load_path`` is not an existing directory.
    """
    # Imported locally so the handler does not rely on the module's
    # top-level import list already exposing HTTPException.
    from fastapi import HTTPException

    load_dir = Path(request.load_path)
    # Validate the input directory before the service is constructed: the
    # constructor loads the model, and a bad path would otherwise only be
    # detected afterwards as an unhandled error while scanning the directory.
    if not load_dir.is_dir():
        raise HTTPException(
            status_code=400,
            detail=f"load_path is not a directory: {request.load_path}",
        )

    service = ImageCaptioningService(
        Path(request.caption_model_path),
        load_dir,
        Path(request.save_path),
        request.prompt,
    )
    service.caption_image()
    return {"status": "Captioning Complete"}
53+
54+
4555
@app.get("/device")
4656
def device():
4757
device = get_device()

backend/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ description = "Add your description here"
55
readme = "README.md"
66
requires-python = ">=3.10"
77
dependencies = [
8+
"accelerate>=1.11.0",
89
"difpy>=4.2.1",
910
"fastapi[standard]>=0.116.2",
1011
"pillow>=11.3.0",

backend/service/image_captioning.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,52 @@
22
from PIL import Image
33
from transformers import AutoProcessor, LlavaForConditionalGeneration, GenerationConfig
44
from pathlib import Path
5+
from config.image_formats import SUPPORTED_INPUT_EXTENSIONS
6+
Image.MAX_IMAGE_PIXELS = None
57

6-
class ImageCaptioning:
7-
def __init__(self, model,max_new_tokens=512, temperature=0.6, top_p=0.9, top_k=None):
8+
class ImageCaptioningService:
9+
def __init__(
    self,
    model: Path,
    load_path: Path,
    save_path: Path,
    prompt: str,
    max_new_tokens=512,
    temperature=0.6,
    top_p=0.9,
    top_k=None,
):
    """Load the captioning model and remember the batch settings.

    Args:
        model: Location passed to ``_load_model`` to obtain the model and
            its processor.
        load_path: Directory whose images will be captioned.
        save_path: Output location passed on when captions are saved.
        prompt: User message sent to the model for every image.
        max_new_tokens: Generation length limit.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling limit, forwarded unchanged (``None`` by
            default).
    """
    # The model is loaded first, exactly as before, so a failed load
    # aborts construction before any other state is set.
    loaded_model, loaded_processor = self._load_model(model)
    self._model = loaded_model
    self._processor = loaded_processor

    self.load_path = load_path
    self.save_path = save_path
    self.prompt = prompt

    # Sampling is always enabled; the remaining values come straight from
    # the constructor arguments.
    generation_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    }
    self._generation_config = GenerationConfig(**generation_options)
1521

1622
def _load_model(self, model_path):
17-
model = LlavaForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
23+
model = LlavaForConditionalGeneration.from_pretrained(model_path, dtype=torch.bfloat16, device_map="auto")
1824
processor = AutoProcessor.from_pretrained(model_path)
1925

2026
model.eval()
2127
return model, processor
2228

23-
def _image_inputs(self, image_path: Path, prompt: str):
29+
def _image_inputs(self, image_path: Path):
2430
image = Image.open(image_path).convert('RGB')
2531

2632
messages = [
2733
{
2834
"role": "system",
29-
"content": "You are a helpful image captioner."
35+
"content":
36+
"""
37+
You are a professional image captioner for machine learning datasets.
38+
Provide accurate, detailed descriptions focusing on visual elements, composition, and relevant details.
39+
Follow user instructions carefully.
40+
"""
3041
},
3142
{
3243
"role": "user",
33-
"content": prompt
44+
"content": self.prompt
3445
}
3546
]
3647

3748
formatted_prompt = self._processor.apply_chat_template(
3849
messages,
39-
tokenize=False, #return string, not tokens yet
50+
tokenize=False,
4051
add_generation_prompt=True
4152
)
4253

@@ -47,20 +58,17 @@ def _image_inputs(self, image_path: Path, prompt: str):
4758
return_tensors="pt"
4859
)
4960

50-
#move inputs to GPU if available
5161
if torch.cuda.is_available():
5262
moved_inputs = {}
5363

5464
for key, value in inputs.items():
55-
#check if value can be moved to GPU (has a 'to' method)
5665
if hasattr(value, 'to'):
5766
moved_inputs[key] = value.to('cuda')
5867
else:
5968
moved_inputs[key] = value
6069

6170
inputs = moved_inputs
6271

63-
#convert pixel values to match model dtype (bfloat16)
6472
if 'pixel_values' in inputs:
6573
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
6674

@@ -70,7 +78,7 @@ def _generate_caption(self, inputs: dict):
7078
print("Generation Captions")
7179

7280
generated_ids = self._model.generate(
73-
**inputs, #unpack tensors
81+
**inputs,
7482
generation_config=self._generation_config,
7583
)[0]
7684

@@ -85,9 +93,14 @@ def _generate_caption(self, inputs: dict):
8593

8694
return caption.strip()
8795

88-
def caption_image(self, image_path: Path, prompt: str):
89-
inputs = self._image_inputs(image_path, prompt)
90-
return self._generate_caption(inputs)
96+
def caption_image(self):
    """Caption every supported image found directly in ``self.load_path``.

    Only regular files whose lower-cased suffix is listed in
    SUPPORTED_INPUT_EXTENSIONS are processed; sub-directories are not
    descended into. Each caption is handed to ``save_caption`` together
    with the source image path and ``self.save_path``.
    """
    # The candidate list is built (and sorted) before any caption is
    # written. Path.iterdir() yields entries in arbitrary order, so sorting
    # makes the processing order reproducible between runs.
    image_files = sorted(
        entry
        for entry in self.load_path.iterdir()
        if entry.is_file() and entry.suffix.lower() in SUPPORTED_INPUT_EXTENSIONS
    )

    for image_file in image_files:
        model_inputs = self._image_inputs(image_file)
        caption = self._generate_caption(model_inputs)
        self.save_caption(caption, image_file, self.save_path)
91104

92105
def save_caption(self, caption, image_path: Path, output_dir:Path = None):
93106
if output_dir:

backend/service/image_upscaling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from config.image_formats import SUPPORTED_OUTPUT_FORMATS, SUPPORTED_INPUT_EXTENSIONS
99

1010
class ImageUpscaleService:
11-
def __init__(self, device, model_path: Path="", load_path: Path="", save_path: Path ="", format:str = "jpg", tile_size: int = 512, tile_overlap:int = 16):
11+
def __init__(self, device, model_path: Path, load_path: Path, save_path: Path, format:str = "jpg", tile_size: int = 512, tile_overlap:int = 16):
1212
self.load_path = load_path
1313
self.save_path = save_path
1414
self.format = format

backend/uv.lock

Lines changed: 49 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)