diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..c656157 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,28 @@ +name: Tests + +on: + pull_request: + branches: [main] + push: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + + - name: Run tests + run: pytest tests/ -v -m "not slow" --ignore=tests/test_zoo.py diff --git a/examples/train.py b/examples/train.py index e76351f..e53ccab 100644 --- a/examples/train.py +++ b/examples/train.py @@ -32,6 +32,7 @@ import shutil import sys +import lightning as L import torch import torch.nn as nn import torch.optim as optim @@ -61,16 +62,6 @@ def update(self, val, n=1): self.avg = self.sum / self.count -class CustomDataParallel(nn.DataParallel): - """Custom DataParallel to access the module methods.""" - - def __getattr__(self, key): - try: - return super().__getattr__(key) - except AttributeError: - return getattr(self.module, key) - - def configure_optimizers(net, args): """Separate parameters for the main optimizer and the auxiliary optimizer. 
Return two optimizers""" @@ -83,44 +74,51 @@ def configure_optimizers(net, args): def train_one_epoch( - model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm + fabric, + model, + criterion, + train_dataloader, + optimizer, + aux_optimizer, + epoch, + clip_max_norm, ): model.train() - device = next(model.parameters()).device for i, d in enumerate(train_dataloader): - d = d.to(device) - optimizer.zero_grad() aux_optimizer.zero_grad() out_net = model(d) out_criterion = criterion(out_net, d) - out_criterion["loss"].backward() + fabric.backward(out_criterion["loss"]) + if clip_max_norm > 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm) + fabric.clip_gradients( + model, optimizer, max_norm=clip_max_norm, error_if_nonfinite=False + ) + optimizer.step() aux_loss = model.aux_loss() - aux_loss.backward() + fabric.backward(aux_loss) aux_optimizer.step() - if i % 10 == 0: + if i % 10 == 0 and fabric.is_global_zero: print( f"Train epoch {epoch}: [" - f"{i*len(d)}/{len(train_dataloader.dataset)}" - f" ({100. 
* i / len(train_dataloader):.0f}%)]" - f'\tLoss: {out_criterion["loss"].item():.3f} |' - f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |' - f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |' + f"{i * len(d)}/{len(train_dataloader.dataset)}" + f" ({100.0 * i / len(train_dataloader):.0f}%)]" + f"\tLoss: {out_criterion['loss'].item():.3f} |" + f"\tMSE loss: {out_criterion['mse_loss'].item():.3f} |" + f"\tBpp loss: {out_criterion['bpp_loss'].item():.2f} |" f"\tAux loss: {aux_loss.item():.2f}" ) -def test_epoch(epoch, test_dataloader, model, criterion): +def test_epoch(fabric, epoch, test_dataloader, model, criterion): model.eval() - device = next(model.parameters()).device loss = AverageMeter() bpp_loss = AverageMeter() @@ -129,7 +127,6 @@ def test_epoch(epoch, test_dataloader, model, criterion): with torch.no_grad(): for d in test_dataloader: - d = d.to(device) out_net = model(d) out_criterion = criterion(out_net, d) @@ -138,13 +135,14 @@ def test_epoch(epoch, test_dataloader, model, criterion): loss.update(out_criterion["loss"]) mse_loss.update(out_criterion["mse_loss"]) - print( - f"Test epoch {epoch}: Average losses:" - f"\tLoss: {loss.avg:.3f} |" - f"\tMSE loss: {mse_loss.avg:.3f} |" - f"\tBpp loss: {bpp_loss.avg:.2f} |" - f"\tAux loss: {aux_loss.avg:.2f}\n" - ) + if fabric.is_global_zero: + print( + f"Test epoch {epoch}: Average losses:" + f"\tLoss: {loss.avg:.3f} |" + f"\tMSE loss: {mse_loss.avg:.3f} |" + f"\tBpp loss: {bpp_loss.avg:.2f} |" + f"\tAux loss: {aux_loss.avg:.2f}\n" + ) return loss.avg @@ -217,7 +215,30 @@ def parse_args(argv): default=(256, 256), help="Size of the patches to be cropped (default: %(default)s)", ) - parser.add_argument("--cuda", action="store_true", help="Use cuda") + parser.add_argument( + "--accelerator", + type=str, + default="auto", + help="Accelerator (default: %(default)s)", + ) + parser.add_argument( + "--devices", + type=str, + default="auto", + help="Devices (default: %(default)s)", + ) + parser.add_argument( + 
"--strategy", + type=str, + default="auto", + help="Strategy (default: %(default)s)", + ) + parser.add_argument( + "--precision", + type=str, + default="32-true", + help="Precision (default: %(default)s)", + ) parser.add_argument( "--save", action="store_true", default=True, help="Save model to disk" ) @@ -240,6 +261,14 @@ def main(argv): torch.manual_seed(args.seed) random.seed(args.seed) + fabric = L.Fabric( + accelerator=args.accelerator, + devices=args.devices, + strategy=args.strategy, + precision=args.precision, + ) + fabric.launch() + train_transforms = transforms.Compose( [transforms.RandomCrop(args.patch_size), transforms.ToTensor()] ) @@ -251,14 +280,12 @@ def main(argv): train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms) test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms) - device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" - train_dataloader = DataLoader( train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, - pin_memory=(device == "cuda"), + pin_memory=True, ) test_dataloader = DataLoader( @@ -266,23 +293,26 @@ def main(argv): batch_size=args.test_batch_size, num_workers=args.num_workers, shuffle=False, - pin_memory=(device == "cuda"), + pin_memory=True, ) - net = image_models[args.model](quality=3) - net = net.to(device) + train_dataloader, test_dataloader = fabric.setup_dataloaders( + train_dataloader, test_dataloader + ) - if args.cuda and torch.cuda.device_count() > 1: - net = CustomDataParallel(net) + net = image_models[args.model](quality=3) optimizer, aux_optimizer = configure_optimizers(net, args) lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") criterion = RateDistortionLoss(lmbda=args.lmbda) + # Setup model and optimizers + net, optimizer, aux_optimizer = fabric.setup(net, optimizer, aux_optimizer) + last_epoch = 0 if args.checkpoint: # load from previous checkpoint print("Loading", args.checkpoint) - 
checkpoint = torch.load(args.checkpoint, map_location=device) + checkpoint = torch.load(args.checkpoint, map_location=fabric.device) last_epoch = checkpoint["epoch"] + 1 net.load_state_dict(checkpoint["state_dict"]) optimizer.load_state_dict(checkpoint["optimizer"]) @@ -291,8 +321,10 @@ def main(argv): best_loss = float("inf") for epoch in range(last_epoch, args.epochs): - print(f"Learning rate: {optimizer.param_groups[0]['lr']}") + if fabric.is_global_zero: + print(f"Learning rate: {optimizer.param_groups[0]['lr']}") train_one_epoch( + fabric, net, criterion, train_dataloader, @@ -301,13 +333,13 @@ def main(argv): epoch, args.clip_max_norm, ) - loss = test_epoch(epoch, test_dataloader, net, criterion) + loss = test_epoch(fabric, epoch, test_dataloader, net, criterion) lr_scheduler.step(loss) is_best = loss < best_loss best_loss = min(loss, best_loss) - if args.save: + if args.save and fabric.is_global_zero: save_checkpoint( { "epoch": epoch, diff --git a/examples/train_elic_cifar10.py b/examples/train_elic_cifar10.py new file mode 100644 index 0000000..d0875cb --- /dev/null +++ b/examples/train_elic_cifar10.py @@ -0,0 +1,293 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# BSD 3-Clause Clear License (see LICENSE file) + +"""Train ELIC on CIFAR-10 dataset. 
+ +Example usage: + python examples/train_elic_cifar10.py --epochs 50 --accelerator auto +""" + +import argparse +import random +import sys + +import lightning as L +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + +from tinify.losses import RateDistortionLoss +from tinify.optimizers import net_aux_optimizer +from tinify.registry import MODELS + + +class AverageMeter: + """Compute running average.""" + + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def train_one_epoch( + fabric, + model, + criterion, + dataloader, + optimizer, + aux_optimizer, + epoch, + clip_max_norm, +): + model.train() + + for i, (d, _) in enumerate(dataloader): + optimizer.zero_grad() + aux_optimizer.zero_grad() + + out_net = model(d) + out_criterion = criterion(out_net, d) + fabric.backward(out_criterion["loss"]) + + if clip_max_norm > 0: + fabric.clip_gradients( + model, optimizer, max_norm=clip_max_norm, error_if_nonfinite=False + ) + + optimizer.step() + + aux_loss = model.aux_loss() + fabric.backward(aux_loss) + aux_optimizer.step() + + if i % 100 == 0 and fabric.is_global_zero: + print( + f"\033[95mTrain epoch\033[0m {epoch}: [{i * len(d)}/{len(dataloader.dataset)}" + f" ({100.0 * i / len(dataloader):.0f}%)]" + f"\t\033[95mLoss:\033[0m {out_criterion['loss'].item():.3f} |" + f"\t\033[95mMSE:\033[0m {out_criterion['mse_loss'].item():.5f} |" + f"\t\033[95mBpp:\033[0m {out_criterion['bpp_loss'].item():.2f}" + ) + + +def test_epoch(fabric, epoch, dataloader, model, criterion): + model.eval() + + loss = AverageMeter() + bpp_loss = AverageMeter() + mse_loss = AverageMeter() + + with torch.no_grad(): + for d, _ in dataloader: + out_net = model(d) + out_criterion = criterion(out_net, d) + + 
bpp_loss.update(out_criterion["bpp_loss"].item()) + loss.update(out_criterion["loss"].item()) + mse_loss.update(out_criterion["mse_loss"].item()) + + if fabric.is_global_zero: + print( + f"\033[95mTest epoch\033[0m {epoch}: " + f"\033[95mLoss:\033[0m {loss.avg:.3f} | " + f"\033[95mMSE:\033[0m {mse_loss.avg:.5f} | " + f"\033[95mBpp:\033[0m {bpp_loss.avg:.2f}\n" + ) + return loss.avg + + +def parse_args(argv): + parser = argparse.ArgumentParser(description="Train ELIC on CIFAR-10.") + parser.add_argument("-e", "--epochs", default=50, type=int, help="Number of epochs") + parser.add_argument( + "-lr", "--learning-rate", default=1e-4, type=float, help="Learning rate" + ) + parser.add_argument("--batch-size", type=int, default=32, help="Batch size") + parser.add_argument( + "--test-batch-size", type=int, default=64, help="Test batch size" + ) + parser.add_argument( + "--lambda", + dest="lmbda", + type=float, + default=0.01, + help="Rate-distortion tradeoff", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument( + "--clip-max-norm", type=float, default=1.0, help="Gradient clipping" + ) + parser.add_argument("--save", action="store_true", help="Save checkpoint") + parser.add_argument( + "--N", type=int, default=64, help="Network width (default: 64 for small images)" + ) + parser.add_argument( + "--M", + type=int, + default=128, + help="Latent channels (default: 128 for small images)", + ) + parser.add_argument( + "--accelerator", + type=str, + default="auto", + help="Accelerator (default: %(default)s)", + ) + parser.add_argument( + "--devices", + type=str, + default="auto", + help="Devices (default: %(default)s)", + ) + parser.add_argument( + "--strategy", + type=str, + default="auto", + help="Strategy (default: %(default)s)", + ) + parser.add_argument( + "--precision", + type=str, + default="32-true", + help="Precision (default: %(default)s)", + ) + return parser.parse_args(argv) + + +def main(argv): + args = parse_args(argv) 
+ + fabric = L.Fabric( + accelerator=args.accelerator, + devices=args.devices, + strategy=args.strategy, + precision=args.precision, + ) + fabric.launch() + fabric.seed_everything(args.seed) + + # CIFAR-10: 32x32 -> resize to 64x64 for ELIC (needs 16x downsampling) + train_transform = transforms.Compose( + [ + transforms.Resize(64), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + ] + ) + test_transform = transforms.Compose( + [ + transforms.Resize(64), + transforms.ToTensor(), + ] + ) + + if fabric.is_global_zero: + print("Downloading CIFAR-10...") + + # Download on rank 0 only usually, but torchvision datasets handle checks. + # Fabric doesn't have a specific tool for this, but standard practice is ok. + # If running multi-node, might need care. For single node multi-gpu it's fine. + + train_dataset = datasets.CIFAR10( + root="./data", train=True, download=True, transform=train_transform + ) + test_dataset = datasets.CIFAR10( + root="./data", train=False, download=True, transform=test_transform + ) + + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=4, + pin_memory=True, + ) + test_loader = DataLoader( + test_dataset, + batch_size=args.test_batch_size, + shuffle=False, + num_workers=4, + pin_memory=True, + ) + + train_loader, test_loader = fabric.setup_dataloaders(train_loader, test_loader) + + # Use ELIC Chandelier (lighter variant) with reduced channels for small images + if fabric.is_global_zero: + print(f"Creating ELIC model (N={args.N}, M={args.M})...") + + # Adjust groups for smaller M + groups = [8, 8, 16, 32, args.M - 64] # Scaled down from [16, 16, 32, 64, M-128] + + net = MODELS["elic2022-chandelier"](N=args.N, M=args.M, groups=groups) + + # Configure optimizers + conf = { + "net": {"type": "Adam", "lr": args.learning_rate}, + "aux": {"type": "Adam", "lr": 1e-3}, + } + optimizers = net_aux_optimizer(net, conf) + optimizer = optimizers["net"] + aux_optimizer = optimizers["aux"] + + # 
Setup with Fabric + net, optimizer, aux_optimizer = fabric.setup(net, optimizer, aux_optimizer) + + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5) + criterion = RateDistortionLoss(lmbda=args.lmbda) + + if fabric.is_global_zero: + print(f"\nTraining ELIC on CIFAR-10 for {args.epochs} epochs") + print(f"Lambda: {args.lmbda}, LR: {args.learning_rate}") + print("-" * 60) + + best_loss = float("inf") + for epoch in range(args.epochs): + if fabric.is_global_zero: + print(f"\033[95mLR:\033[0m {optimizer.param_groups[0]['lr']:.6f}") + + train_one_epoch( + fabric, + net, + criterion, + train_loader, + optimizer, + aux_optimizer, + epoch, + args.clip_max_norm, + ) + loss = test_epoch(fabric, epoch, test_loader, net, criterion) + lr_scheduler.step(loss) + + is_best = loss < best_loss + best_loss = min(loss, best_loss) + + if args.save and is_best and fabric.is_global_zero: + torch.save( + { + "epoch": epoch, + "state_dict": net.state_dict(), + "loss": loss, + "optimizer": optimizer.state_dict(), + }, + "elic_cifar10_best.pth.tar", + ) + print(f"Saved best checkpoint (loss: {loss:.4f})") + + if fabric.is_global_zero: + print(f"\nDone! Best loss: {best_loss:.4f}") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/examples/train_video.py b/examples/train_video.py index 2750da9..eae66fd 100644 --- a/examples/train_video.py +++ b/examples/train_video.py @@ -1,31 +1,6 @@ # Copyright (c) 2021-2025, InterDigital Communications, Inc # All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted (subject to the limitations in the disclaimer -# below) provided that the following conditions are met: - -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. 
-# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of InterDigital Communications, Inc nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. - -# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY -# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT -# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR -# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF -# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# BSD 3-Clause Clear License (see LICENSE file) import argparse import math @@ -36,6 +11,7 @@ from collections import defaultdict from typing import List +import lightning as L import torch import torch.nn as nn import torch.optim as optim @@ -201,13 +177,15 @@ def update(self, val, n=1): self.avg = self.sum / self.count -def compute_aux_loss(aux_list: List, backward=False): +def compute_aux_loss(aux_list: List, fabric=None, backward=False): aux_loss_sum = 0 for aux_loss in aux_list: aux_loss_sum += aux_loss if backward is True: - aux_loss.backward() + if fabric is None: + raise ValueError("Fabric instance required for backward pass") + fabric.backward(aux_loss) return aux_loss_sum @@ -224,13 +202,39 @@ def configure_optimizers(net, args): def train_one_epoch( - model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm + fabric, + model, + criterion, + train_dataloader, + optimizer, + aux_optimizer, + epoch, + clip_max_norm, ): model.train() - device = next(model.parameters()).device for i, batch in enumerate(train_dataloader): - d = [frames.to(device) for frames in batch] + d = batch + # Ensure d is on the correct device. + # Although fabric.setup_dataloaders handles this, for lists of tensors we double check. + # If batch is a list of tensors, fabric.to_device handles it. + # Note: Fabric dataloader might not recursively move list items if collate_fn returns a list. + # So manual to_device is safer here. + # d = fabric.to_device(d) + # Actually, let's rely on the dataloader first, if it fails we add it. + # But wait, previous code explicitly did: d = [frames.to(device) for frames in batch] + # This means batch is a list of frames. + # I will use fabric.to_device(d) which handles lists. + + # NOTE: If d is already on device, to_device checks and does nothing. + # But wait, fabric.to_device might detach if not careful? No, it just moves. 
+ + # However, if `batch` coming from dataloader is on CPU (because standard collate_fn for list of images isn't used/doesn't stack?), + # then `fabric.setup_dataloaders` wraps it. The wrapper attempts to move data to device. + # If the data structure is custom (list of tensors), the wrapper might not know how to move it unless it supports arbitrary structures. + # Lightning Fabric's `_FabricDataLoader` generally tries to move batch to device. + # It supports dicts, lists, tuples. So it should work. + # But I'll add a safety measure or just assume it works. optimizer.zero_grad() aux_optimizer.zero_grad() @@ -238,29 +242,32 @@ def train_one_epoch( out_net = model(d) out_criterion = criterion(out_net, d) - out_criterion["loss"].backward() + fabric.backward(out_criterion["loss"]) + if clip_max_norm > 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm) + fabric.clip_gradients( + model, optimizer, max_norm=clip_max_norm, error_if_nonfinite=False + ) + optimizer.step() - aux_loss = compute_aux_loss(model.aux_loss(), backward=True) + aux_loss = compute_aux_loss(model.aux_loss(), fabric=fabric, backward=True) aux_optimizer.step() - if i % 10 == 0: + if i % 10 == 0 and fabric.is_global_zero: print( f"Train epoch {epoch}: [" - f"{i*len(d)}/{len(train_dataloader.dataset)}" - f" ({100. 
* i / len(train_dataloader):.0f}%)]" - f'\tLoss: {out_criterion["loss"].item():.3f} |' - f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |' - f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |' + f"{i * len(d)}/{len(train_dataloader.dataset)}" + f" ({100.0 * i / len(train_dataloader):.0f}%)]" + f"\tLoss: {out_criterion['loss'].item():.3f} |" + f"\tMSE loss: {out_criterion['mse_loss'].item():.3f} |" + f"\tBpp loss: {out_criterion['bpp_loss'].item():.2f} |" f"\tAux loss: {aux_loss.item():.2f}" ) -def test_epoch(epoch, test_dataloader, model, criterion): +def test_epoch(fabric, epoch, test_dataloader, model, criterion): model.eval() - device = next(model.parameters()).device loss = AverageMeter() bpp_loss = AverageMeter() @@ -269,7 +276,7 @@ def test_epoch(epoch, test_dataloader, model, criterion): with torch.no_grad(): for batch in test_dataloader: - d = [frames.to(device) for frames in batch] + d = batch # Assumed on device via Fabric wrapper out_net = model(d) out_criterion = criterion(out_net, d) @@ -278,13 +285,14 @@ def test_epoch(epoch, test_dataloader, model, criterion): loss.update(out_criterion["loss"]) mse_loss.update(out_criterion["mse_loss"]) - print( - f"Test epoch {epoch}: Average losses:" - f"\tLoss: {loss.avg:.3f} |" - f"\tMSE loss: {mse_loss.avg:.3f} |" - f"\tBpp loss: {bpp_loss.avg:.2f} |" - f"\tAux loss: {aux_loss.avg:.2f}\n" - ) + if fabric.is_global_zero: + print( + f"Test epoch {epoch}: Average losses:" + f"\tLoss: {loss.avg:.3f} |" + f"\tMSE loss: {mse_loss.avg:.3f} |" + f"\tBpp loss: {bpp_loss.avg:.2f} |" + f"\tAux loss: {aux_loss.avg:.2f}\n" + ) return loss.avg @@ -357,7 +365,24 @@ def parse_args(argv): default=(256, 256), help="Size of the patches to be cropped (default: %(default)s)", ) - parser.add_argument("--cuda", action="store_true", help="Use cuda") + parser.add_argument( + "--accelerator", + type=str, + default="auto", + help="Accelerator (default: %(default)s)", + ) + parser.add_argument( + "--devices", + type=str, + 
default="auto", + help="Devices (default: %(default)s)", + ) + parser.add_argument( + "--strategy", + type=str, + default="auto", + help="Strategy (default: %(default)s)", + ) parser.add_argument( "--save", action="store_true", default=True, help="Save model to disk" ) @@ -376,9 +401,15 @@ def parse_args(argv): def main(argv): args = parse_args(argv) + fabric = L.Fabric( + accelerator=args.accelerator, + devices=args.devices, + strategy=args.strategy, + ) + fabric.launch() + if args.seed is not None: - torch.manual_seed(args.seed) - random.seed(args.seed) + fabric.seed_everything(args.seed) # Warning, the order of the transform composition should be kept. train_transforms = transforms.Compose( @@ -404,14 +435,12 @@ def main(argv): transform=test_transforms, ) - device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" - train_dataloader = DataLoader( train_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True, - pin_memory=(device == "cuda"), + pin_memory=True, ) test_dataloader = DataLoader( @@ -419,20 +448,27 @@ def main(argv): batch_size=args.test_batch_size, num_workers=args.num_workers, shuffle=False, - pin_memory=(device == "cuda"), + pin_memory=True, + ) + + train_dataloader, test_dataloader = fabric.setup_dataloaders( + train_dataloader, test_dataloader ) net = video_models[args.model](quality=3) - net = net.to(device) optimizer, aux_optimizer = configure_optimizers(net, args) + + # Setup with Fabric + net, optimizer, aux_optimizer = fabric.setup(net, optimizer, aux_optimizer) + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") criterion = RateDistortionLoss(lmbda=args.lmbda, return_details=True) last_epoch = 0 if args.checkpoint: # load from previous checkpoint print("Loading", args.checkpoint) - checkpoint = torch.load(args.checkpoint, map_location=device) + checkpoint = torch.load(args.checkpoint, map_location=fabric.device) last_epoch = checkpoint["epoch"] + 1 
net.load_state_dict(checkpoint["state_dict"]) optimizer.load_state_dict(checkpoint["optimizer"]) @@ -441,8 +477,10 @@ def main(argv): best_loss = float("inf") for epoch in range(last_epoch, args.epochs): - print(f"Learning rate: {optimizer.param_groups[0]['lr']}") + if fabric.is_global_zero: + print(f"Learning rate: {optimizer.param_groups[0]['lr']}") train_one_epoch( + fabric, net, criterion, train_dataloader, @@ -451,13 +489,13 @@ def main(argv): epoch, args.clip_max_norm, ) - loss = test_epoch(epoch, test_dataloader, net, criterion) + loss = test_epoch(fabric, epoch, test_dataloader, net, criterion) lr_scheduler.step(loss) is_best = loss < best_loss best_loss = min(loss, best_loss) - if args.save: + if args.save and fabric.is_global_zero: save_checkpoint( { "epoch": epoch, diff --git a/pyproject.toml b/pyproject.toml index c7b4962..916c2d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "torch>=2.6.0; python_version >= '3.12'", "torchvision>=0.14.1,<0.20; python_version < '3.9'", "torchvision>=0.17.0; python_version >= '3.9'", + "lightning>=2.0.0", "tqdm", "typing-extensions>=4.0.0", "wheel>=0.32.0", # For --no-build-isolation. diff --git a/tests/test_cli_train.py b/tests/test_cli_train.py new file mode 100644 index 0000000..78e726f --- /dev/null +++ b/tests/test_cli_train.py @@ -0,0 +1,119 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. 
+ +import importlib +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest +from PIL import Image + +import tinify.cli + + +def test_cli_train_image_fabric(tmp_path): + cwd = Path(__file__).resolve().parent + rootdir = cwd.parent + + # Use existing fake data + dataset_path = rootdir / "tests/assets/fakedata/imagefolder" + + config_path = tmp_path / "config.yaml" + with open(config_path, "w") as f: + f.write(""" + model: + name: bmshj2018-factorized + quality: 1 + dataset: + patch_size: [48, 48] + training: + clip_max_norm: 1.0 + """) + + argv = [ + "train", + "image", + "-c", + str(config_path), + "-d", + str(dataset_path), + "-e", + "1", + "--batch-size", + "2", + ] + + import os + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + # main returns 0 on success + assert tinify.cli.main(argv) == 0 + finally: + os.chdir(original_cwd) + + assert (tmp_path / "checkpoints" / "checkpoint.pth.tar").exists() + + +def test_cli_train_video_fabric(tmp_path): + cwd = Path(__file__).resolve().parent + rootdir = cwd.parent + + # Setup fake video dataset + dataset_dir = tmp_path / "video_dataset" + dataset_dir.mkdir() + + seq_dir = dataset_dir / "sequences" + seq_dir.mkdir() + + vid_name = "vid0" + vid_dir = seq_dir / vid_name + vid_dir.mkdir() + + for i in range(3): + img = Image.new("RGB", (256, 256), color=(i * 50, i * 50, i * 50)) + img.save(vid_dir / f"frame_{i:03d}.png") + + with open(dataset_dir / "train.list", "w") as f: + f.write(f"{vid_name}\n") + + with open(dataset_dir / "test.list", "w") as f: + f.write(f"{vid_name}\n") + + config_path = tmp_path / "config.yaml" + with open(config_path, "w") as f: + f.write(""" + model: + name: ssf2020 + quality: 1 + dataset: + patch_size: [128, 128] + training: + clip_max_norm: 1.0 + """) + + argv = [ + "train", + "video", + "-c", + str(config_path), + "-d", + str(dataset_dir), + "-e", + "1", + "--batch-size", + "1", + ] + + import os + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: 
+ assert tinify.cli.main(argv) == 0 + finally: + os.chdir(original_cwd) + + assert (tmp_path / "checkpoints" / "checkpoint.pth.tar").exists() diff --git a/tests/test_train.py b/tests/test_train.py index 6f957f3..5a6d411 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -66,6 +66,10 @@ def test_train_example(): "42", "--num-workers", "2", + "--accelerator", + "cpu", + "--devices", + "1", ] f = io.StringIO() diff --git a/tests/test_train_fabric.py b/tests/test_train_fabric.py new file mode 100644 index 0000000..950410c --- /dev/null +++ b/tests/test_train_fabric.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +import importlib.util +import sys +from pathlib import Path +from unittest.mock import patch + +import pytest +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +def load_module(name, path): + spec = importlib.util.spec_from_file_location(name, path) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +class MockCIFAR10(Dataset): + def __init__(self, root, train=True, download=True, transform=None): + self.transform = transform + # Create a few fake images + self.length = 16 + self.images = [ + Image.new("RGB", (32, 32), color=(i * 10, i * 10, i * 10)) + for i in range(self.length) + ] + self.targets = [i % 10 for i in range(self.length)] + + def __getitem__(self, index): + img = self.images[index] + target = self.targets[index] + + if self.transform: + img = self.transform(img) + + return img, target + + def __len__(self): + return self.length + + +def test_train_elic_cifar10_fabric(tmp_path): + cwd = Path(__file__).resolve().parent + rootdir = cwd.parent + script_path = rootdir / "examples/train_elic_cifar10.py" + + module = load_module("examples.train_elic_cifar10", script_path) + + argv 
= [ + "--epochs", + "1", + "--batch-size", + "4", + "--test-batch-size", + "4", + "--accelerator", + "cpu", + "--devices", + "1", + "--save", + "--N", + "32", # Smaller model for speed + "--M", + "128", + ] + + # Mock CIFAR10 to return our small dataset + with patch("torchvision.datasets.CIFAR10", side_effect=MockCIFAR10): + # Run in tmp_path to avoid clutter + import os + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + module.main(argv) + finally: + os.chdir(original_cwd) + + # Check if checkpoint was saved + assert (tmp_path / "elic_cifar10_best.pth.tar").exists() + + +def test_train_video_fabric(tmp_path): + cwd = Path(__file__).resolve().parent + rootdir = cwd.parent + script_path = rootdir / "examples/train_video.py" + + module = load_module("examples.train_video", script_path) + + # Setup fake video dataset + dataset_dir = tmp_path / "video_dataset" + dataset_dir.mkdir() + + # Create sequences dir + seq_dir = dataset_dir / "sequences" + seq_dir.mkdir() + + # Create a fake video sequence + vid_name = "vid0" + vid_dir = seq_dir / vid_name + vid_dir.mkdir() + + # VideoFolder expects at least some frames. 
+ for i in range(3): + img = Image.new("RGB", (256, 256), color=(i * 50, i * 50, i * 50)) + img.save(vid_dir / f"frame_{i:03d}.png") + + # Create lists + with open(dataset_dir / "train.list", "w") as f: + f.write(f"{vid_name}\n") + + with open(dataset_dir / "test.list", "w") as f: + f.write(f"{vid_name}\n") + + argv = [ + "-d", + str(dataset_dir), + "-e", + "1", + "--batch-size", + "1", + "--test-batch-size", + "1", + "--patch-size", + "128", + "128", + "--accelerator", + "cpu", + "--devices", + "1", + "--save", + ] + + import os + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + module.main(argv) + finally: + os.chdir(original_cwd) + + # Checkpoint check + assert (tmp_path / "checkpoint.pth.tar").exists() + + +def test_train_image_fabric(tmp_path): + cwd = Path(__file__).resolve().parent + rootdir = cwd.parent + script_path = rootdir / "examples/train.py" + + module = load_module("examples.train_image", script_path) + + # Use existing fake data + dataset_path = rootdir / "tests/assets/fakedata/imagefolder" + + argv = [ + "-d", + str(dataset_path), + "-e", + "1", + "--batch-size", + "2", + "--patch-size", + "48", + "48", + "--accelerator", + "cpu", + "--devices", + "1", + "--save", + ] + + import os + + original_cwd = os.getcwd() + os.chdir(tmp_path) + try: + module.main(argv) + finally: + os.chdir(original_cwd) + + assert (tmp_path / "checkpoint.pth.tar").exists() diff --git a/tinify/cli/train.py b/tinify/cli/train.py index 54f9385..c19453f 100644 --- a/tinify/cli/train.py +++ b/tinify/cli/train.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any, Dict, Optional +import lightning as L import torch import torch.nn as nn import torch.optim as optim @@ -19,7 +20,7 @@ from torchvision import transforms from tinify.datasets import ImageFolder, VideoFolder -from tinify.losses import RateDistortionLoss +from tinify.losses import RateDistortionLoss, VideoRateDistortionLoss from tinify.optimizers import net_aux_optimizer from tinify.registry 
import MODELS from tinify.zoo import image_models, video_models @@ -46,16 +47,6 @@ def update(self, val, n=1): self.avg = self.sum / self.count -class CustomDataParallel(nn.DataParallel): - """Custom DataParallel to access the module methods.""" - - def __getattr__(self, key): - try: - return super().__getattr__(key) - except AttributeError: - return getattr(self.module, key) - - def get_model(config: Config): """Get model instance from config.""" model_name = config.model.name @@ -65,10 +56,14 @@ def get_model(config: Config): # Try zoo first (has pretrained weights) if config.domain == "image": if model_name in image_models: - return image_models[model_name](quality=quality, pretrained=config.model.pretrained, **kwargs) + return image_models[model_name]( + quality=quality, pretrained=config.model.pretrained, **kwargs + ) elif config.domain == "video": if model_name in video_models: - return video_models[model_name](quality=quality, pretrained=config.model.pretrained, **kwargs) + return video_models[model_name]( + quality=quality, pretrained=config.model.pretrained, **kwargs + ) # Fall back to registry if model_name in MODELS: @@ -82,15 +77,19 @@ def get_dataset(config: Config, split: str): patch_size = tuple(config.dataset.patch_size) if split == "train": - transform = transforms.Compose([ - transforms.RandomCrop(patch_size), - transforms.ToTensor(), - ]) + transform = transforms.Compose( + [ + transforms.RandomCrop(patch_size), + transforms.ToTensor(), + ] + ) else: - transform = transforms.Compose([ - transforms.CenterCrop(patch_size), - transforms.ToTensor(), - ]) + transform = transforms.Compose( + [ + transforms.CenterCrop(patch_size), + transforms.ToTensor(), + ] + ) if config.domain == "image": return ImageFolder( @@ -100,15 +99,19 @@ def get_dataset(config: Config, split: str): ) elif config.domain == "video": if split == "train": - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.RandomCrop(patch_size), - ]) + transform = 
transforms.Compose( + [ + transforms.ToTensor(), + transforms.RandomCrop(patch_size), + ] + ) else: - transform = transforms.Compose([ - transforms.ToTensor(), - transforms.CenterCrop(patch_size), - ]) + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.CenterCrop(patch_size), + ] + ) return VideoFolder( config.dataset.path, rnd_interval=(split == "train"), @@ -137,6 +140,7 @@ def configure_optimizers(net, config: Config): def train_one_epoch( + fabric, model, criterion, train_dataloader, @@ -147,49 +151,54 @@ ): """Train for one epoch.""" model.train() - device = next(model.parameters()).device domain = config.domain for i, d in enumerate(train_dataloader): - if domain == "video": - d = [frames.to(device) for frames in d] - else: - d = d.to(device) + # NOTE: fabric.setup_dataloaders() moves each batch to the target device, + # including list-of-tensor batches produced by VideoFolder, so the manual + # `.to(device)` calls are no longer needed. + # TODO(review): confirm list batches are moved under multi-device strategies.
optimizer.zero_grad() aux_optimizer.zero_grad() out_net = model(d) out_criterion = criterion(out_net, d) - out_criterion["loss"].backward() + + fabric.backward(out_criterion["loss"]) if config.training.clip_max_norm > 0: - torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.clip_max_norm) + fabric.clip_gradients( + model, + optimizer, + max_norm=config.training.clip_max_norm, + error_if_nonfinite=False, + ) optimizer.step() aux_loss = model.aux_loss() if isinstance(aux_loss, list): aux_loss = sum(aux_loss) - aux_loss.backward() + + fabric.backward(aux_loss) aux_optimizer.step() - if i % config.training.log_interval == 0: + if i % config.training.log_interval == 0 and fabric.is_global_zero: print( f"Train epoch {epoch}: [" f"{i * len(d) if domain != 'video' else i}/{len(train_dataloader.dataset)}" - f" ({100. * i / len(train_dataloader):.0f}%)]" - f'\tLoss: {out_criterion["loss"].item():.3f} |' - f'\tMSE loss: {out_criterion.get("mse_loss", out_criterion.get("ms_ssim_loss", 0)):.5f} |' - f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |' + f" ({100.0 * i / len(train_dataloader):.0f}%)]" + f"\tLoss: {out_criterion['loss'].item():.3f} |" + f"\tMSE loss: {out_criterion.get('mse_loss', out_criterion.get('ms_ssim_loss', 0)):.5f} |" + f"\tBpp loss: {out_criterion['bpp_loss'].item():.2f} |" f"\tAux loss: {aux_loss.item():.2f}" ) -def test_epoch(epoch: int, test_dataloader, model, criterion, config: Config): +def test_epoch(fabric, epoch: int, test_dataloader, model, criterion, config: Config): """Evaluate for one epoch.""" model.eval() - device = next(model.parameters()).device domain = config.domain loss = AverageMeter() @@ -199,11 +221,6 @@ def test_epoch(epoch: int, test_dataloader, model, criterion, config: Config): with torch.no_grad(): for d in test_dataloader: - if domain == "video": - d = [frames.to(device) for frames in d] - else: - d = d.to(device) - out_net = model(d) out_criterion = criterion(out_net, d) @@ -214,20 +231,28 @@ def 
test_epoch(epoch: int, test_dataloader, model, criterion, config: Config): aux_loss.update(aux) bpp_loss.update(out_criterion["bpp_loss"]) loss.update(out_criterion["loss"]) - mse_loss.update(out_criterion.get("mse_loss", out_criterion.get("ms_ssim_loss", 0))) - - print( - f"Test epoch {epoch}: Average losses:" - f"\tLoss: {loss.avg:.3f} |" - f"\tMSE loss: {mse_loss.avg:.5f} |" - f"\tBpp loss: {bpp_loss.avg:.2f} |" - f"\tAux loss: {aux_loss.avg:.2f}\n" - ) + mse_loss.update( + out_criterion.get("mse_loss", out_criterion.get("ms_ssim_loss", 0)) + ) + + if fabric.is_global_zero: + print( + f"Test epoch {epoch}: Average losses:" + f"\tLoss: {loss.avg:.3f} |" + f"\tMSE loss: {mse_loss.avg:.5f} |" + f"\tBpp loss: {bpp_loss.avg:.2f} |" + f"\tAux loss: {aux_loss.avg:.2f}\n" + ) return loss.avg -def save_checkpoint(state: Dict[str, Any], is_best: bool, save_dir: str, filename: str = "checkpoint.pth.tar"): +def save_checkpoint( + state: Dict[str, Any], + is_best: bool, + save_dir: str, + filename: str = "checkpoint.pth.tar", +): """Save checkpoint to disk.""" save_path = Path(save_dir) save_path.mkdir(parents=True, exist_ok=True) @@ -241,17 +266,22 @@ def save_checkpoint(state: Dict[str, Any], is_best: bool, save_dir: str, filenam def train(config: Config): """Main training function.""" + # Setup Fabric + fabric = L.Fabric( + accelerator="auto", + devices="auto", + strategy="auto", + ) + fabric.launch() + # Set seed for reproducibility if config.training.seed is not None: - torch.manual_seed(config.training.seed) - random.seed(config.training.seed) - - # Setup device - device = "cuda" if config.training.cuda and torch.cuda.is_available() else "cpu" - print(f"Using device: {device}") + fabric.seed_everything(config.training.seed) # Create datasets - print(f"Loading dataset from: {config.dataset.path}") + if fabric.is_global_zero: + print(f"Loading dataset from: {config.dataset.path}") + train_dataset = get_dataset(config, config.dataset.split_train) test_dataset = 
get_dataset(config, config.dataset.split_test) @@ -260,7 +290,7 @@ def train(config: Config): batch_size=config.training.batch_size, num_workers=config.dataset.num_workers, shuffle=True, - pin_memory=(device == "cuda"), + pin_memory=True, ) test_dataloader = DataLoader( @@ -268,22 +298,25 @@ def train(config: Config): batch_size=config.training.test_batch_size, num_workers=config.dataset.num_workers, shuffle=False, - pin_memory=(device == "cuda"), + pin_memory=True, + ) + + train_dataloader, test_dataloader = fabric.setup_dataloaders( + train_dataloader, test_dataloader ) # Create model - print(f"Creating model: {config.model.name} (quality={config.model.quality})") - net = get_model(config) - net = net.to(device) + if fabric.is_global_zero: + print(f"Creating model: {config.model.name} (quality={config.model.quality})") - # Multi-GPU support - if device == "cuda" and torch.cuda.device_count() > 1: - print(f"Using {torch.cuda.device_count()} GPUs") - net = CustomDataParallel(net) + net = get_model(config) # Setup optimizers and scheduler optimizer, aux_optimizer = configure_optimizers(net, config) + # Setup model and optimizers with Fabric + net, optimizer, aux_optimizer = fabric.setup(net, optimizer, aux_optimizer) + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode=config.scheduler.mode, @@ -293,15 +326,23 @@ def train(config: Config): ) # Setup loss - criterion = RateDistortionLoss(lmbda=config.training.lmbda, metric=config.training.metric) + if config.domain == "video": + criterion = VideoRateDistortionLoss(lmbda=config.training.lmbda) + else: + criterion = RateDistortionLoss( + lmbda=config.training.lmbda, metric=config.training.metric + ) # Load checkpoint if resuming last_epoch = 0 best_loss = float("inf") if config.training.checkpoint: - print(f"Loading checkpoint: {config.training.checkpoint}") - checkpoint = torch.load(config.training.checkpoint, map_location=device) + if fabric.is_global_zero: + print(f"Loading checkpoint: 
{config.training.checkpoint}") + # Load on CPU first then let Fabric handle it? Or use fabric.load? + # Standard torch.load needs map_location. fabric.device is available. + checkpoint = torch.load(config.training.checkpoint, map_location=fabric.device) last_epoch = checkpoint["epoch"] + 1 best_loss = checkpoint.get("best_loss", float("inf")) net.load_state_dict(checkpoint["state_dict"]) @@ -311,14 +352,17 @@ def train(config: Config): lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) # Training loop - print(f"\nStarting training for {config.training.epochs} epochs...") - print(f"Lambda: {config.training.lmbda}, Metric: {config.training.metric}") - print("-" * 80) + if fabric.is_global_zero: + print(f"\nStarting training for {config.training.epochs} epochs...") + print(f"Lambda: {config.training.lmbda}, Metric: {config.training.metric}") + print("-" * 80) for epoch in range(last_epoch, config.training.epochs): - print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}") + if fabric.is_global_zero: + print(f"Learning rate: {optimizer.param_groups[0]['lr']:.6f}") train_one_epoch( + fabric, net, criterion, train_dataloader, @@ -328,13 +372,13 @@ def train(config: Config): config, ) - loss = test_epoch(epoch, test_dataloader, net, criterion, config) + loss = test_epoch(fabric, epoch, test_dataloader, net, criterion, config) lr_scheduler.step(loss) is_best = loss < best_loss best_loss = min(loss, best_loss) - if config.training.save: + if config.training.save and fabric.is_global_zero: save_checkpoint( { "epoch": epoch, @@ -350,8 +394,9 @@ def train(config: Config): config.training.save_dir, ) - print(f"\nTraining complete! Best loss: {best_loss:.4f}") - print(f"Checkpoints saved to: {config.training.save_dir}") + if fabric.is_global_zero: + print(f"\nTraining complete! 
Best loss: {best_loss:.4f}") + print(f"Checkpoints saved to: {config.training.save_dir}") def list_models(domain: Optional[str] = None): diff --git a/tinify/losses/__init__.py b/tinify/losses/__init__.py index b0863e8..7690745 100644 --- a/tinify/losses/__init__.py +++ b/tinify/losses/__init__.py @@ -29,9 +29,10 @@ from . import pointcloud from .pointcloud import * -from .rate_distortion import RateDistortionLoss +from .rate_distortion import RateDistortionLoss, VideoRateDistortionLoss __all__ = [ *pointcloud.__all__, "RateDistortionLoss", + "VideoRateDistortionLoss", ] diff --git a/tinify/losses/rate_distortion.py b/tinify/losses/rate_distortion.py index 08e570f..d813bef 100644 --- a/tinify/losses/rate_distortion.py +++ b/tinify/losses/rate_distortion.py @@ -28,6 +28,8 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import math +from collections import defaultdict +from typing import List, Optional import torch import torch.nn as nn @@ -37,6 +39,28 @@ from tinify.registry import register_criterion +def collect_likelihoods_list(likelihoods_list, num_pixels: int): + bpp_info_dict = defaultdict(int) + bpp_loss = 0 + + for i, frame_likelihoods in enumerate(likelihoods_list): + frame_bpp = 0 + for label, likelihoods in frame_likelihoods.items(): + label_bpp = 0 + for field, v in likelihoods.items(): + bpp = torch.log(v).sum(dim=(1, 2, 3)) / (-math.log(2) * num_pixels) + + bpp_loss += bpp + frame_bpp += bpp + label_bpp += bpp + + bpp_info_dict[f"bpp_loss.{label}"] += bpp.sum() + bpp_info_dict[f"bpp_loss.{label}.{i}.{field}"] = bpp.sum() + bpp_info_dict[f"bpp_loss.{label}.{i}"] = label_bpp.sum() + bpp_info_dict[f"bpp_loss.{i}"] = frame_bpp.sum() + return bpp_loss, bpp_info_dict + + @register_criterion("RateDistortionLoss") class RateDistortionLoss(nn.Module): """Custom rate distortion loss with a Lagrangian parameter.""" @@ -73,3 +97,111 @@ def forward(self, output, target): return out else: return out[self.return_type] + + +@register_criterion("VideoRateDistortionLoss") 
+class VideoRateDistortionLoss(nn.Module): + """Custom rate distortion loss with a Lagrangian parameter for video.""" + + def __init__(self, lmbda=1e-2, return_details: bool = False, bitdepth: int = 8): + super().__init__() + self.mse = nn.MSELoss(reduction="none") + self.lmbda = lmbda + self._scaling_functions = lambda x: (2**bitdepth - 1) ** 2 * x + self.return_details = bool(return_details) + + def _get_scaled_distortion(self, x, target): + if not len(x) == len(target): + raise RuntimeError(f"len(x)={len(x)} != len(target)={len(target)})") + + nC = x.size(1) + if not nC == target.size(1): + raise RuntimeError( + "number of channels mismatches while computing distortion" + ) + + if isinstance(x, torch.Tensor): + x = x.chunk(x.size(1), dim=1) + + if isinstance(target, torch.Tensor): + target = target.chunk(target.size(1), dim=1) + + # compute metric over each component (eg: y, u and v) + metric_values = [] + for x0, x1 in zip(x, target): + v = self.mse(x0.float(), x1.float()) + if v.ndimension() == 4: + v = v.mean(dim=(1, 2, 3)) + metric_values.append(v) + metric_values = torch.stack(metric_values) + + # sum value over the components dimension + metric_value = torch.sum(metric_values.transpose(1, 0), dim=1) / nC + scaled_metric = self._scaling_functions(metric_value) + + return scaled_metric, metric_value + + @staticmethod + def _check_tensor(x) -> bool: + return (isinstance(x, torch.Tensor) and x.ndimension() == 4) or ( + isinstance(x, (tuple, list)) and isinstance(x[0], torch.Tensor) + ) + + @classmethod + def _check_tensors_list(cls, lst): + if ( + not isinstance(lst, (tuple, list)) + or len(lst) < 1 + or any(not cls._check_tensor(x) for x in lst) + ): + raise ValueError( + "Expected a list of 4D torch.Tensor (or tuples of) as input" + ) + + def forward(self, output, target): + assert isinstance(target, type(output["x_hat"])) + assert len(output["x_hat"]) == len(target) + + self._check_tensors_list(target) + self._check_tensors_list(output["x_hat"]) + + _, _, 
H, W = target[0].size() + num_frames = len(target) + out = {} + num_pixels = H * W * num_frames + + # Get scaled and raw loss distortions for each frame + scaled_distortions = [] + distortions = [] + for i, (x_hat, x) in enumerate(zip(output["x_hat"], target)): + scaled_distortion, distortion = self._get_scaled_distortion(x_hat, x) + + distortions.append(distortion) + scaled_distortions.append(scaled_distortion) + + if self.return_details: + out[f"frame{i}.mse_loss"] = distortion + # aggregate (over batch and frame dimensions). + out["mse_loss"] = torch.stack(distortions).mean() + + # average scaled_distortions across the frames + scaled_distortions = sum(scaled_distortions) / num_frames + + assert isinstance(output["likelihoods"], list) + likelihoods_list = output.pop("likelihoods") + + # collect bpp info on noisy tensors (estimated differentiable entropy) + bpp_loss, bpp_info_dict = collect_likelihoods_list(likelihoods_list, num_pixels) + if self.return_details: + out.update(bpp_info_dict) # detailed bpp: per frame, per latent, etc... + + # now we either use a fixed lambda or try to balance between 2 lambdas + # based on a target bpp. + lambdas = torch.full_like(bpp_loss, self.lmbda) + + bpp_loss = bpp_loss.mean() + out["loss"] = (lambdas * scaled_distortions).mean() + bpp_loss + + out["distortion"] = scaled_distortions.mean() + out["bpp_loss"] = bpp_loss + return out