Skip to content

Commit ff552ab

Browse files
committed
add tiling support based on chaiNNer's implementation
1 parent 156dfbc commit ff552ab

File tree

5 files changed

+388
-0
lines changed

5 files changed

+388
-0
lines changed

backend/tiling/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from .tile_splitter import TileSplitter
2+
from .tile_blender import TileBlender
3+
from .blend_functions import linear_blend, half_sin_blend, sin_blend_fn
4+
from .models import Tile, Region, Padding, TileOverlap
5+
6+
__all__ = [
7+
'TileSplitter',
8+
'TileBlender',
9+
'Tile',
10+
'Region',
11+
'Padding',
12+
'TileOverlap',
13+
'linear_blend',
14+
'half_sin_blend',
15+
'sin_blend_fn',
16+
]
17+
18+
__version__ = '1.0.0'
19+
__author__ = 'based on chaiNNer\'s tiling implementation'

backend/tiling/blend_functions.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""blending weight functions for smooth tile merging"""
2+
import math
3+
import numpy as np
4+
5+
def linear_blend(distance: float, overlap: float):
6+
"""
7+
linear gradient blending
8+
weight increase from 0.0 to 1.0 linearly
9+
"""
10+
if overlap == 0:
11+
return 1.0
12+
return distance / (overlap - 1)
13+
14+
def sin_blend_fn(x: float):
15+
"""
16+
chaiNNer's sine blending function
17+
Formula: (sin(x * π - π/2) + 1) / 2
18+
"""
19+
return (math.sin(x * math.pi - math.pi / 2) + 1) / 2
20+
21+
def half_sin_blend(distance: float, overlap: float):
22+
"""
23+
chaiNNer's half-sine blending function
24+
"""
25+
if overlap == 0:
26+
return 1.0
27+
28+
normalized = distance / (overlap - 1)
29+
compressed = min(max(normalized * 2 - 0.5, 0), 1)
30+
return sin_blend_fn(compressed)
31+
32+
def create_blend_mask(
33+
tile_width: int,
34+
tile_height: int,
35+
overlap: int,
36+
blend_fn=half_sin_blend,
37+
is_top_edge: bool = False,
38+
is_bottom_edge: bool = False,
39+
is_right_edge: bool = False,
40+
is_left_edge: bool = False
41+
):
42+
"""
43+
Create 2D weight mask for tile blending.
44+
Edges at image boundaries don't get blended (full weight)
45+
Internal edges get gradual blending based on blend_fn
46+
47+
Returns: 2D numpy array of weights (0.0 to 1.0)
48+
"""
49+
50+
# start with full contribution
51+
mask = np.ones((tile_height, tile_width), dtype=np.float32)
52+
53+
if overlap <= 1:
54+
return mask
55+
56+
# apply blending gradient to each edge that overlaps another tile
57+
58+
# top edge: blend from 0.0 to 1.0 going top to bottom
59+
if not is_top_edge:
60+
for i in range(overlap):
61+
weight = blend_fn(i, overlap)
62+
mask[i, :] *= weight
63+
64+
# bottom edge: blend from 1.0 to 0.0 approaching bottom
65+
if not is_bottom_edge:
66+
for i in range(overlap):
67+
weight = blend_fn(i, overlap)
68+
mask[-(i+1), :] *= weight
69+
70+
# right edge: blend from 1.0 to 0.0 approaching right edge
71+
if not is_right_edge:
72+
for i in range(overlap):
73+
weight = blend_fn(i, overlap)
74+
mask[:, -(i+1)] *= weight
75+
76+
# left edge: blend from 0.0 to 1.0 going left to right
77+
if not is_left_edge:
78+
for i in range(overlap):
79+
weight = blend_fn(i, overlap)
80+
mask[:, i] *= weight
81+
82+
return mask

backend/tiling/models.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""data models for tiling operations"""
2+
"""based on chaiNNer's region and tileoverlap structures"""
3+
from dataclasses import dataclass
4+
from PIL import Image
5+
import numpy as np
6+
7+
@dataclass
8+
class Region:
9+
"""Rectangular region in an image"""
10+
x: int
11+
y: int
12+
width: int
13+
height:int
14+
15+
def intersect(self, other: 'Region'):
16+
"""
17+
Get the intersection of two regions.
18+
Ensures tiles don't extend beyond image boundaries.
19+
"""
20+
x1 = max(self.x, other.x)
21+
y1 = max(self.y, other.y)
22+
x2 = min(self.x + self.width, other.x + other.width)
23+
y2 = min(self.y + self.height, other.y + other.height)
24+
return Region(x1, y1, max(0, x2 - x1), max(0, y2 - y1))
25+
26+
def add_padding(self, padding: 'Padding'):
27+
"""
28+
Expand this region by adding padding on all sides.
29+
Used to create overlap between tiles.
30+
"""
31+
return Region(
32+
x=self.x - padding.left,
33+
y=self.y - padding.top,
34+
width=self.width + padding.left + padding.right,
35+
height=self.height + padding.top + padding.bottom
36+
)
37+
38+
def child_padding(self, child: 'Region'):
39+
"""
40+
Calculate how much padding a child region has relative to this region.
41+
Used to determine overlap amounts for tiles.
42+
"""
43+
return Padding(
44+
top=child.y - self.y,
45+
bottom=(self.y + self.height) - (child.y + child.height),
46+
left=child.x - self.x,
47+
right=(self.x + self.width) - (child.x + child.width)
48+
)
49+
50+
def read_from(self, img: np.ndarray) -> np.ndarray:
51+
"""Extract this region from an image array"""
52+
return img[self.y:self.y+self.height, self.x:self.x+self.width, ...]
53+
54+
@dataclass
55+
class Padding:
56+
"""Padding amounts for each side"""
57+
top: int
58+
bottom: int
59+
left: int
60+
right: int
61+
62+
def min(self, max_padding: int):
63+
"""
64+
Limit padding to a maximum value.
65+
Prevents overlap from being larger than necessary.
66+
"""
67+
return Padding(
68+
top=min(self.top, max_padding),
69+
bottom=min(self.bottom, max_padding),
70+
left=min(self.left, max_padding),
71+
right=min(self.right, max_padding)
72+
)
73+
74+
@dataclass
75+
class TileOverlap:
76+
"""
77+
ChaiNNer's structure for tracking overlap on a single axis.
78+
Start overlap is on the leading edge, end overlap is on the trailing edge.
79+
"""
80+
start: int # Overlap at the start (left for X, top for Y)
81+
end: int # Overlap at the end (right for X, bottom for Y)
82+
83+
@property
84+
def total(self) -> int:
85+
"""Total overlap amount"""
86+
return self.start + self.end
87+
88+
@dataclass
89+
class Tile:
90+
"""
91+
A tile with image data and position
92+
used to track tiles during split -> process -> merge
93+
"""
94+
image: Image.Image
95+
region: Region
96+
x: int #x pos of orig image
97+
y: int #y pos of orig image
98+
padding: Padding = None

backend/tiling/tile_blender.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
merges processed tiles with smooth blending.
3+
based on chaiNNer's directional blending approach.
4+
"""
5+
import numpy as np
6+
from typing import List
7+
from PIL import Image
8+
from .models import Tile
9+
from .blend_functions import create_blend_mask, half_sin_blend
10+
11+
class TileBlender:
12+
"""
13+
Merges upscaled tiles with seamless blending.
14+
Uses chaiNNer's two-stage blending approach:
15+
1. Blend tiles horizontally into rows (BlendDirection.X)
16+
2. Blend rows vertically into final image (BlendDirection.Y)
17+
"""
18+
19+
def __init__(self, overlap: int = 16, scale: int = 1, blend_fn=half_sin_blend):
20+
self.overlap = overlap
21+
self.scale = scale
22+
self.blend_fn = blend_fn
23+
24+
def merge(self, tiles: List[Tile], original_width: int, original_height: int):
25+
"""
26+
Merge tiles into final image with blending.
27+
"""
28+
out_w = original_width * self.scale
29+
out_h = original_height * self.scale
30+
overlap_scaled = self.overlap * self.scale
31+
32+
# Accumulators
33+
# output: accumulates weighted pixel values
34+
# weights: tracks total weight at each pixel
35+
output = np.zeros((out_h, out_w, 3), dtype=np.float32)
36+
weights = np.zeros((out_h, out_w), dtype=np.float32)
37+
38+
for tile in tiles:
39+
# Position and size in output image (scaled)
40+
x = tile.x * self.scale
41+
y = tile.y * self.scale
42+
region_w = tile.region.width * self.scale
43+
region_h = tile.region.height * self.scale
44+
45+
# The tile image includes padding/overlap
46+
tile_img = tile.image
47+
48+
# Determine edge positions
49+
# Tiles at image boundaries don't blend on those edges
50+
is_left = (tile.x == 0)
51+
is_top = (tile.y == 0)
52+
is_right = (tile.x + tile.region.width >= original_width)
53+
is_bottom = (tile.y + tile.region.height >= original_height)
54+
55+
# Create blending mask for the region size
56+
mask = create_blend_mask(
57+
region_w, region_h, overlap_scaled,
58+
self.blend_fn,
59+
is_top_edge=is_top,
60+
is_bottom_edge=is_bottom,
61+
is_right_edge=is_right,
62+
is_left_edge=is_left
63+
)
64+
65+
# Convert tile to numpy array
66+
tile_arr = np.array(tile_img, dtype=np.float32)
67+
68+
# The tile image may have padding - we need to crop it to match the region
69+
# Calculate padding that was added
70+
tile_w_scaled = tile_img.width
71+
tile_h_scaled = tile_img.height
72+
73+
# Padding is the difference between tile image size and region size
74+
pad_left = (tile_w_scaled - region_w) // 2 if not is_left else 0
75+
pad_top = (tile_h_scaled - region_h) // 2 if not is_top else 0
76+
77+
# Crop the tile array to match the region size
78+
tile_arr_cropped = tile_arr[
79+
pad_top:pad_top+region_h,
80+
pad_left:pad_left+region_w,
81+
:
82+
]
83+
84+
# Add weighted contribution to output
85+
# In non-overlap regions, mask = 1.0, so full contribution
86+
# In overlap regions, mask varies, creating blend
87+
for c in range(3):
88+
try:
89+
output[y:y+region_h, x:x+region_w, c] += tile_arr_cropped[:, :, c] * mask
90+
except ValueError as e:
91+
print(f"ERROR at tile ({tile.x}, {tile.y}):")
92+
print(f" x={x}, y={y}, region_w={region_w}, region_h={region_h}")
93+
print(f" Slice would be output[{y}:{y+region_h}, {x}:{x+region_w}, {c}]")
94+
print(f" output.shape={output.shape}")
95+
print(f" tile_arr_cropped[:,:,{c}].shape={tile_arr_cropped[:,:,c].shape}")
96+
print(f" mask.shape={mask.shape}")
97+
print(f" Actual slice shape: {output[y:y+region_h, x:x+region_w, c].shape}")
98+
raise
99+
100+
# Track total weight at each pixel
101+
weights[y:y+region_h, x:x+region_w] += mask
102+
103+
# Normalize: divide accumulated colors by total weights
104+
# ChaiNNer does this via their _fast_mix function
105+
# Example: If two tiles overlap with weights 0.3 and 0.7:
106+
# output = color1 * 0.3 + color2 * 0.7
107+
# weights = 0.3 + 0.7 = 1.0
108+
# final = output / weights = (color1 * 0.3 + color2 * 0.7) / 1.0
109+
for c in range(3):
110+
output[:, :, c] /= np.maximum(weights, 1e-6)
111+
112+
# Convert to 8-bit image
113+
output = np.clip(output, 0, 255).astype(np.uint8)
114+
return Image.fromarray(output)

backend/tiling/tile_splitter.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
splits large images into overlapping tiles.
3+
based on chaiNNer's optimal tile size calculation to prevent uneven tiles.
4+
"""
5+
import math
6+
from PIL import Image
7+
from .models import Tile, Region
8+
9+
10+
class TileSplitter:
11+
"""
12+
Splits images into memory-efficient overlapping tiles.
13+
"""
14+
15+
def __init__(self, tile_size: int = 512, overlap: int = 16):
16+
self.max_tile_size = tile_size
17+
self.overlap = overlap
18+
19+
def calculate_optimal_tile_size(self, width: int, height: int):
20+
# Calculate how many tiles needed in each direction
21+
tile_count_x = math.ceil(width / self.max_tile_size)
22+
tile_count_y = math.ceil(height / self.max_tile_size)
23+
24+
# Distribute image evenly across tiles
25+
optimal_x = math.ceil(width / tile_count_x)
26+
optimal_y = math.ceil(height / tile_count_y)
27+
28+
return optimal_x, optimal_y
29+
30+
def split(self, image: Image.Image):
31+
32+
width, height = image.size
33+
tile_w, tile_h = self.calculate_optimal_tile_size(width, height)
34+
35+
# Image boundary for intersection checks
36+
img_region = Region(0, 0, width, height)
37+
38+
tiles = []
39+
40+
# Step size: tile size minus overlap
41+
# This creates the overlapping pattern
42+
step_x = tile_w - self.overlap
43+
step_y = tile_h - self.overlap
44+
45+
# Generate tiles row by row
46+
for y in range(0, height, step_y):
47+
for x in range(0, width, step_x):
48+
# Create base tile region
49+
tile_region = Region(x, y, tile_w, tile_h)
50+
51+
# Clip to image boundaries (for edge tiles)
52+
tile_region = tile_region.intersect(img_region)
53+
54+
# Calculate padding for overlap
55+
# ChaiNNer's approach: img_region.child_padding(tile).min(overlap)
56+
pad = img_region.child_padding(tile_region).min(self.overlap)
57+
58+
padded_region = tile_region.add_padding(pad)
59+
60+
tile_img = image.crop((
61+
padded_region.x,
62+
padded_region.y,
63+
padded_region.x + padded_region.width,
64+
padded_region.y + padded_region.height
65+
))
66+
67+
tiles.append(Tile(
68+
image=tile_img,
69+
region=tile_region,
70+
x=tile_region.x,
71+
y=tile_region.y,
72+
padding=pad
73+
))
74+
75+
return tiles

0 commit comments

Comments
 (0)