Skip to content

Commit 3d6e052

Browse files
committed
Implement image upscaling service with format conversion and add unit tests
1 parent da6e167 commit 3d6e052

File tree

2 files changed

+112
-5
lines changed

2 files changed

+112
-5
lines changed
Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,43 @@
1-
from spandrel import ImageModelDescriptor
1+
from spandrel import ModelLoader
2+
from pathlib import Path
3+
from PIL import Image
4+
import torch
5+
import torchvision.transforms as transforms
6+
from config.image_formats import SUPPORTED_OUTPUT_FORMATS, SUPPORTED_INPUT_EXTENSIONS, get_supported_format
27

3-
class image_service():
4-
def __init__(self, model: ImageModelDescriptor):
5-
self._model = model
8+
class ImageService():
9+
def __init__(self, model_path, device):
10+
loader = ModelLoader()
11+
model = loader.load_from_file(model_path)
12+
self.model = model.to(device).eval()
613

7-
14+
@staticmethod
15+
def get_supported_output_formats():
16+
return get_supported_format()
17+
18+
19+
def upscale(self, images_path: str, save_path: str, format: str = "jpg"):
20+
images_path = Path(images_path)
21+
save_path = Path(save_path)
22+
23+
image_files = [f for f in images_path.iterdir() if f.is_file() and f.suffix.lower() in SUPPORTED_INPUT_EXTENSIONS]
24+
format_info = SUPPORTED_OUTPUT_FORMATS[format.lower()]
25+
26+
for image_file in image_files:
27+
image = Image.open(image_file)
28+
29+
if image.mode == "RGBA":
30+
image = image.convert('RGB')
31+
32+
33+
to_tensor = transforms.ToTensor()
34+
image_tensor = to_tensor(image).unsqueeze(0).to(self.model.device)
35+
36+
with torch.no_grad():
37+
result = self.model(image_tensor)
38+
39+
to_pil = transforms.ToPILImage()
40+
output_image = to_pil(result.cpu().squeeze(0))
41+
42+
file_name = image_file.stem + "_upscaled" + format_info["extension"]
43+
output_image.save(save_path / file_name, format=format_info["pil_format"])
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import pytest
2+
import tempfile
3+
import shutil
4+
from pathlib import Path
5+
from PIL import Image
6+
from unittest.mock import Mock, patch
7+
from service.image_processing import ImageService
8+
import torch
9+
10+
11+
@pytest.fixture
12+
def temp_dirs():
13+
"""Create temporary input and output directories."""
14+
input_dir = Path(tempfile.mkdtemp())
15+
output_dir = Path(tempfile.mkdtemp())
16+
17+
test_image = Image.new('RGB', (100, 100), color='red')
18+
test_image.save(input_dir / "test.jpg", "JPEG")
19+
20+
yield input_dir, output_dir
21+
22+
shutil.rmtree(input_dir)
23+
shutil.rmtree(output_dir)
24+
25+
26+
@pytest.fixture
27+
def image_service():
28+
with patch('service.image_processing.ModelLoader') as mock_loader:
29+
# Create a mock model that returns the same tensor (no actual upscaling)
30+
mock_model = Mock()
31+
mock_model.to.return_value = mock_model
32+
mock_model.eval.return_value = mock_model
33+
mock_model.device = "cpu"
34+
mock_model.return_value = torch.rand(1, 3, 100, 100)
35+
36+
mock_loader.return_value.load_from_file.return_value = mock_model
37+
38+
return ImageService("dummy_model_path.pth", "cpu")
39+
40+
41+
def test_upscale_default_format(image_service, temp_dirs):
42+
"""Test upscaling with default JPG format."""
43+
input_dir, output_dir = temp_dirs
44+
45+
image_service.upscale(str(input_dir), str(output_dir))
46+
47+
output_files = list(output_dir.glob("*_upscaled.*"))
48+
assert len(output_files) == 1
49+
assert output_files[0].suffix == ".jpg"
50+
51+
52+
def test_upscale_format_conversion(image_service, temp_dirs):
53+
"""Test upscaling with format conversion."""
54+
input_dir, output_dir = temp_dirs
55+
56+
image_service.upscale(str(input_dir), str(output_dir), "png")
57+
58+
output_files = list(output_dir.glob("*_upscaled.png"))
59+
assert len(output_files) == 1
60+
61+
with Image.open(output_files[0]) as img:
62+
assert img.format == "PNG"
63+
64+
65+
def test_get_supported_formats():
66+
"""Test getting supported output formats."""
67+
formats = ImageService.get_supported_output_formats()
68+
69+
assert isinstance(formats, list)
70+
assert len(formats) > 0
71+
assert all("value" in fmt and "label" in fmt for fmt in formats)

0 commit comments

Comments
 (0)