Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/mars_patcher/color_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,19 +266,17 @@ def chroma(self) -> float:
neutral gray of the same lightness."""
return math.sqrt(self.a_star * self.a_star + self.b_star * self.b_star)

def shift_hue(self, shift: float) -> "OklabColor":
def shift_hue(self, shift: float) -> None:
"""Shifts hue by the provided amount, measured in radians."""
# Get hue in range 0 to 2pi
hue = self.hue() + math.pi
hue = (hue + shift) % (2 * math.pi)
# Put hue back in range -pi to pi
hue -= math.pi

# Get new A and B values
chroma = self.chroma()
a = chroma * math.cos(hue)
b = chroma * math.sin(hue)
return OklabColor(self.l_star, a, b)
self.a_star = chroma * math.cos(hue)
self.b_star = chroma * math.sin(hue)

@staticmethod
def linear_to_srgb(value: float) -> float:
Expand Down
85 changes: 68 additions & 17 deletions src/mars_patcher/palette.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,68 @@
import math
import random

from mars_patcher.color_spaces import RgbBitSize, RgbColor
from mars_patcher.color_spaces import HsvColor, OklabColor, RgbBitSize, RgbColor
from mars_patcher.rom import Rom

HUE_VARIATION_RANGE = 180.0
"""The maximum range that hue can be additionally rotated."""


class SineWave:
STEP = (2 * math.pi) / 16

def __init__(self, amplitude: float, frequency: float, phase: float):
self.amplitude = amplitude
self.frequency = frequency
self.phase = phase

@staticmethod
def generate(max_range: float) -> "SineWave":
"""
Generates a random sine wave of the form
y = amplitude * sin(frequency * x + phase)
where
0 <= amplitude <= 1
1/4 <= frequency <= 1
x increases in steps of 1/16 of a cycle
0 <= phase <= 2pi (one cycle)
"""
assert 0 <= max_range <= 1
# Prefer amplitudes closer to the max, otherwise the variation is often too subtle
amplitude = random.uniform(max_range / 2, max_range)
frequency = random.uniform(0.25, 1)
phase = random.uniform(0, 2 * math.pi)
return SineWave(amplitude, frequency, phase)

def calculate_variation(self, x: int) -> float:
assert 0 <= x < 16
return self.amplitude * math.sin(self.frequency * x * self.STEP + self.phase)


class ColorChange:
def __init__(self, hue_shift: float, hue_var: SineWave | None):
self.hue_shift = hue_shift
self.hue_var = hue_var

def _get_hue_shift(self, index: int) -> float:
shift = self.hue_shift
if self.hue_var is not None:
factor = HUE_VARIATION_RANGE / 2
shift += self.hue_var.calculate_variation(index) * factor
return shift

def change_hsv(self, hsv: HsvColor, index: int) -> HsvColor:
shift = self._get_hue_shift(index)
hsv.hue = (hsv.hue + shift) % 360
return hsv

def change_oklab(self, lab: OklabColor, index: int) -> OklabColor:
shift = self._get_hue_shift(index)
# Convert hue shift to radians
shift *= math.pi / 180
lab.shift_hue(shift)
return lab


class Palette:
def __init__(self, rows: int, rom: Rom, addr: int):
Expand Down Expand Up @@ -31,11 +91,8 @@ def write(self, rom: Rom, addr: int) -> None:
data = self.byte_data()
rom.write_bytes(addr, data)

def shift_hue_hsv(self, shift: int, excluded_rows: set[int]) -> None:
"""
Shifts hue by the provided amount, measured in degrees.
Uses HSV color space.
"""
def change_colors_hsv(self, change: ColorChange, excluded_rows: set[int]) -> None:
"""Apply a color change using HSV color space."""
black = RgbColor.black()
white = RgbColor.white_5()
for row in range(self.rows()):
Expand All @@ -47,30 +104,24 @@ def shift_hue_hsv(self, shift: int, excluded_rows: set[int]) -> None:
rgb = self.colors[offset + i]
if rgb == black or rgb == white:
continue
# Get HSV and shift hue
orig_luma = rgb.luma()
hsv = rgb.hsv()
hsv.hue = (hsv.hue + shift) % 360
# Get new RGB and rescale luma
hsv = change.change_hsv(rgb.hsv(), i)
rgb = hsv.rgb()
# Rescale luma
luma_ratio = orig_luma / rgb.luma()
rgb.red = min(int(rgb.red * luma_ratio), 255)
rgb.green = min(int(rgb.green * luma_ratio), 255)
rgb.blue = min(int(rgb.blue * luma_ratio), 255)
self.colors[offset + i] = rgb

def shift_hue_oklab(self, shift: int, excluded_rows: set[int]) -> None:
"""
Shifts hue by the provided amount, measured in degrees.
Uses Oklab color space.
"""
def change_colors_oklab(self, change: ColorChange, excluded_rows: set[int]) -> None:
"""Apply a color change using Oklab color space."""
# Convert shift to radians
shift_rads = shift * (math.pi / 180)
for row in range(self.rows()):
if row in excluded_rows:
continue
offset = row * 16
for i in range(16):
rgb = self.colors[offset + i]
lab = rgb.oklab().shift_hue(shift_rads)
lab = change.change_oklab(rgb.oklab(), i)
self.colors[offset + i] = lab.rgb()
94 changes: 55 additions & 39 deletions src/mars_patcher/random_palettes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
NETTORI_EXTRA_PALS,
TILESET_ANIM_PALS,
)
from mars_patcher.palette import Palette
from mars_patcher.palette import ColorChange, Palette, SineWave
from mars_patcher.rom import Game, Rom


Expand All @@ -39,11 +39,13 @@ def __init__(
pal_types: dict[PaletteType, tuple[int, int]], # TODO: change this tuple(int, int)
color_space: MarsschemaPalettesColorspace,
symmetric: bool,
extra_variation: bool,
):
self.seed = seed
self.pal_types = pal_types
self.color_space: MarsschemaPalettesColorspace = color_space
self.symmetric = symmetric
self.extra_variation = extra_variation

@classmethod
def from_json(cls, data: MarsschemaPalettes) -> "PaletteSettings":
Expand All @@ -56,7 +58,8 @@ def from_json(cls, data: MarsschemaPalettes) -> "PaletteSettings":
pal_types[pal_type] = hue_range
color_space = data.get("ColorSpace", "Oklab")
symmetric = data.get("Symmetric", True)
return cls(seed, pal_types, color_space, symmetric)
# Extra variation is always enabled. This could be passed via JSON instead.
return cls(seed, pal_types, color_space, symmetric, True)

@classmethod
def get_hue_range(cls, data: MarsschemaPalettesRandomize) -> tuple[int, int]:
Expand All @@ -82,26 +85,37 @@ def __init__(self, rom: Rom, settings: PaletteSettings):
self.rom = rom
self.settings = settings
if settings.color_space == "HSV":
self.shift_func = self.shift_palette_hsv
self.change_func = self.change_palette_hsv
elif settings.color_space == "Oklab":
self.shift_func = self.shift_palette_oklab
self.change_func = self.change_palette_oklab
else:
raise ValueError(f"Invalid color space '{settings.color_space}' for color space!")

@staticmethod
def shift_palette_hsv(pal: Palette, shift: int, excluded_rows: set[int] = set()) -> None:
pal.shift_hue_hsv(shift, excluded_rows)
def change_palette_hsv(
pal: Palette, change: ColorChange, excluded_rows: set[int] = set()
) -> None:
pal.change_colors_hsv(change, excluded_rows)

@staticmethod
def shift_palette_oklab(pal: Palette, shift: int, excluded_rows: set[int] = set()) -> None:
pal.shift_hue_oklab(shift, excluded_rows)

def get_hue_shift(self, hue_range: tuple[int, int]) -> int:
"""Returns a hue shift in a random direction between hue_min and hue_max."""
shift = random.randint(hue_range[0], hue_range[1])
if self.settings.symmetric and random.random() < 0.5:
shift = 360 - shift
return shift
def change_palette_oklab(
pal: Palette, change: ColorChange, excluded_rows: set[int] = set()
) -> None:
pal.change_colors_oklab(change, excluded_rows)

def generate_palette_change(self, hue_range: tuple[int, int]) -> ColorChange:
"""Generates a random color change. hue_range determines how far each color's hue will be
initially rotated. Individual colors can be additionally rotated using the values of a
random sine wave."""
hue_shift = random.randint(hue_range[0], hue_range[1])
if self.settings.symmetric and random.choice([True, False]):
hue_shift = 360 - hue_shift
if self.settings.extra_variation:
hue_var_range = min(1.0, (hue_range[1] - hue_range[0]) / 180)
hue_var = SineWave.generate(hue_var_range)
else:
hue_var = None
return ColorChange(hue_shift, hue_var)

def randomize(self) -> None:
random.seed(self.settings.seed)
Expand All @@ -120,24 +134,24 @@ def randomize(self) -> None:
if self.rom.is_zm():
self.fix_zm_palettes()

def shift_palettes(self, pals: list[tuple[int, int]], shift: int) -> None:
def change_palettes(self, pals: list[tuple[int, int]], change: ColorChange) -> None:
for addr, rows in pals:
if addr in self.randomized_pals:
continue
pal = Palette(rows, self.rom, addr)
self.shift_func(pal, shift)
self.change_func(pal, change)
pal.write(self.rom, addr)
self.randomized_pals.add(addr)

def randomize_samus(self, hue_range: tuple[int, int]) -> None:
shift = self.get_hue_shift(hue_range)
self.shift_palettes(gd.samus_palettes(self.rom), shift)
self.shift_palettes(gd.helmet_cursor_palettes(self.rom), shift)
self.shift_palettes(gd.sax_palettes(self.rom), shift)
change = self.generate_palette_change(hue_range)
self.change_palettes(gd.samus_palettes(self.rom), change)
self.change_palettes(gd.helmet_cursor_palettes(self.rom), change)
self.change_palettes(gd.sax_palettes(self.rom), change)

def randomize_beams(self, hue_range: tuple[int, int]) -> None:
shift = self.get_hue_shift(hue_range)
self.shift_palettes(gd.beam_palettes(self.rom), shift)
change = self.generate_palette_change(hue_range)
self.change_palettes(gd.beam_palettes(self.rom), change)

def randomize_tilesets(self, hue_range: tuple[int, int]) -> None:
rom = self.rom
Expand All @@ -161,30 +175,30 @@ def randomize_tilesets(self, hue_range: tuple[int, int]) -> None:
excluded_rows = {row}
# Load palette and shift hue
pal = Palette(13, rom, pal_addr)
shift = self.get_hue_shift(hue_range)
self.shift_func(pal, shift, excluded_rows)
change = self.generate_palette_change(hue_range)
self.change_func(pal, change, excluded_rows)
pal.write(rom, pal_addr)
self.randomized_pals.add(pal_addr)
# Check animated palette
anim_pal_id = TILESET_ANIM_PALS.get(pal_addr)
if anim_pal_id is not None:
self.randomize_anim_palette(anim_pal_id, shift)
self.randomize_anim_palette(anim_pal_id, change)
anim_pal_to_randomize.remove(anim_pal_id)

# Go through remaining animated palettes
for anim_pal_id in anim_pal_to_randomize:
shift = self.get_hue_shift(hue_range)
self.randomize_anim_palette(anim_pal_id, shift)
change = self.generate_palette_change(hue_range)
self.randomize_anim_palette(anim_pal_id, change)

def randomize_anim_palette(self, anim_pal_id: int, shift: int) -> None:
def randomize_anim_palette(self, anim_pal_id: int, change: ColorChange) -> None:
rom = self.rom
addr = gd.anim_palette_entries(rom) + anim_pal_id * 8
pal_addr = rom.read_ptr(addr + 4)
if pal_addr in self.randomized_pals:
return
rows = rom.read_8(addr + 2)
pal = Palette(rows, rom, pal_addr)
self.shift_func(pal, shift)
self.change_func(pal, change)
pal.write(rom, pal_addr)
self.randomized_pals.add(pal_addr)

Expand All @@ -198,18 +212,19 @@ def randomize_enemies(self, hue_range: tuple[int, int]) -> None:
# Go through sprites in groups
groups = ENEMY_GROUPS[rom.game]
for _, sprite_ids in groups.items():
shift = self.get_hue_shift(hue_range)
change = self.generate_palette_change(hue_range)
for sprite_id in sprite_ids:
assert sprite_id in to_randomize, f"{sprite_id:X} should be excluded"
self.randomize_enemy(sprite_id, shift)
self.randomize_enemy(sprite_id, change)
to_randomize.remove(sprite_id)

# Go through remaining sprites
for sprite_id in to_randomize:
shift = self.get_hue_shift(hue_range)
self.randomize_enemy(sprite_id, shift)
change = self.generate_palette_change(hue_range)
self.randomize_enemy(sprite_id, change)

def randomize_enemy(self, sprite_id: int, shift: int) -> None:
def randomize_enemy(self, sprite_id: int, change: ColorChange) -> None:
# Get palette address and row count
rom = self.rom
sprite_gfx_id = sprite_id - 0x10
pal_ptr = gd.sprite_palette_ptrs(rom)
Expand All @@ -230,12 +245,13 @@ def randomize_enemy(self, sprite_id: int, shift: int) -> None:
rows = (rom.read_32(gfx_addr) >> 8) // 0x800
else:
raise ValueError("Unknown game!")
# Load palette, change colors, and write to ROM
pal = Palette(rows, rom, pal_addr)
self.shift_func(pal, shift)
self.change_func(pal, change)
pal.write(rom, pal_addr)
self.randomized_pals.add(pal_addr)
if rom.is_mf() and sprite_id == 0x26:
self.fix_nettori(shift)
self.fix_nettori(change)

def get_sprite_addr(self, sprite_id: int) -> int:
addr = gd.sprite_palette_ptrs(self.rom) + (sprite_id - 0x10) * 4
Expand All @@ -245,11 +261,11 @@ def get_tileset_addr(self, sprite_id: int) -> int:
addr = gd.tileset_entries(self.rom) + sprite_id * 0x14 + 4
return self.rom.read_ptr(addr)

def fix_nettori(self, shift: int) -> None:
def fix_nettori(self, change: ColorChange) -> None:
"""Nettori has extra palettes stored separately, so they require the same color change."""
for addr, rows in NETTORI_EXTRA_PALS:
pal = Palette(rows, self.rom, addr)
self.shift_func(pal, shift)
self.change_func(pal, change)
pal.write(self.rom, addr)

def fix_zm_palettes(self) -> None:
Expand Down