.github/workflows/tests.yml (28 additions, 0 deletions)
@@ -0,0 +1,28 @@
+name: Tests
+
+on:
+  pull_request:
+    branches: [main]
+  push:
+    branches: [main]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.12'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e ".[test]"
+
+      - name: Run tests
+        run: pytest tests/ -v -m "not slow" --ignore=tests/test_zoo.py
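
For reference, the CI job above can be reproduced locally from the repository root; this is a sketch of the same steps, assuming Python 3.12 and the package's "[test]" extra:

    python -m pip install --upgrade pip
    pip install -e ".[test]"
    pytest tests/ -v -m "not slow" --ignore=tests/test_zoo.py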
examples/train.py (78 additions, 46 deletions)
@@ -32,6 +32,7 @@
 import shutil
 import sys
 
+import lightning as L
 import torch
 import torch.nn as nn
 import torch.optim as optim
@@ -61,16 +62,6 @@ def update(self, val, n=1):
         self.avg = self.sum / self.count
 
 
-class CustomDataParallel(nn.DataParallel):
-    """Custom DataParallel to access the module methods."""
-
-    def __getattr__(self, key):
-        try:
-            return super().__getattr__(key)
-        except AttributeError:
-            return getattr(self.module, key)
-
-
 def configure_optimizers(net, args):
     """Separate parameters for the main optimizer and the auxiliary optimizer.
     Return two optimizers"""
@@ -83,44 +74,51 @@
 
 
 def train_one_epoch(
-    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
+    fabric,
+    model,
+    criterion,
+    train_dataloader,
+    optimizer,
+    aux_optimizer,
+    epoch,
+    clip_max_norm,
 ):
     model.train()
-    device = next(model.parameters()).device
 
     for i, d in enumerate(train_dataloader):
-        d = d.to(device)
-
         optimizer.zero_grad()
         aux_optimizer.zero_grad()
 
         out_net = model(d)
 
         out_criterion = criterion(out_net, d)
-        out_criterion["loss"].backward()
+        fabric.backward(out_criterion["loss"])
 
         if clip_max_norm > 0:
-            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
+            fabric.clip_gradients(
+                model, optimizer, max_norm=clip_max_norm, error_if_nonfinite=False
+            )
 
         optimizer.step()
 
         aux_loss = model.aux_loss()
-        aux_loss.backward()
+        fabric.backward(aux_loss)
         aux_optimizer.step()
 
-        if i % 10 == 0:
+        if i % 10 == 0 and fabric.is_global_zero:
             print(
                 f"Train epoch {epoch}: ["
-                f"{i*len(d)}/{len(train_dataloader.dataset)}"
-                f" ({100. * i / len(train_dataloader):.0f}%)]"
-                f'\tLoss: {out_criterion["loss"].item():.3f} |'
-                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
-                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
+                f"{i * len(d)}/{len(train_dataloader.dataset)}"
+                f" ({100.0 * i / len(train_dataloader):.0f}%)]"
+                f"\tLoss: {out_criterion['loss'].item():.3f} |"
+                f"\tMSE loss: {out_criterion['mse_loss'].item():.3f} |"
+                f"\tBpp loss: {out_criterion['bpp_loss'].item():.2f} |"
                 f"\tAux loss: {aux_loss.item():.2f}"
             )
 
 
-def test_epoch(epoch, test_dataloader, model, criterion):
+def test_epoch(fabric, epoch, test_dataloader, model, criterion):
     model.eval()
-    device = next(model.parameters()).device
 
     loss = AverageMeter()
     bpp_loss = AverageMeter()
@@ -129,7 +127,6 @@ def test_epoch(epoch, test_dataloader, model, criterion):
 
     with torch.no_grad():
         for d in test_dataloader:
-            d = d.to(device)
             out_net = model(d)
             out_criterion = criterion(out_net, d)
 
@@ -138,13 +135,14 @@ def test_epoch(epoch, test_dataloader, model, criterion):
             loss.update(out_criterion["loss"])
             mse_loss.update(out_criterion["mse_loss"])
 
-    print(
-        f"Test epoch {epoch}: Average losses:"
-        f"\tLoss: {loss.avg:.3f} |"
-        f"\tMSE loss: {mse_loss.avg:.3f} |"
-        f"\tBpp loss: {bpp_loss.avg:.2f} |"
-        f"\tAux loss: {aux_loss.avg:.2f}\n"
-    )
+    if fabric.is_global_zero:
+        print(
+            f"Test epoch {epoch}: Average losses:"
+            f"\tLoss: {loss.avg:.3f} |"
+            f"\tMSE loss: {mse_loss.avg:.3f} |"
+            f"\tBpp loss: {bpp_loss.avg:.2f} |"
+            f"\tAux loss: {aux_loss.avg:.2f}\n"
+        )
 
     return loss.avg
 
@@ -217,7 +215,30 @@ def parse_args(argv):
         default=(256, 256),
         help="Size of the patches to be cropped (default: %(default)s)",
     )
-    parser.add_argument("--cuda", action="store_true", help="Use cuda")
+    parser.add_argument(
+        "--accelerator",
+        type=str,
+        default="auto",
+        help="Accelerator (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--devices",
+        type=str,
+        default="auto",
+        help="Devices (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--strategy",
+        type=str,
+        default="auto",
+        help="Strategy (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="32-true",
+        help="Precision (default: %(default)s)",
+    )
     parser.add_argument(
         "--save", action="store_true", default=True, help="Save model to disk"
     )
@@ -240,6 +261,14 @@ def main(argv):
         torch.manual_seed(args.seed)
         random.seed(args.seed)
 
+    fabric = L.Fabric(
+        accelerator=args.accelerator,
+        devices=args.devices,
+        strategy=args.strategy,
+        precision=args.precision,
+    )
+    fabric.launch()
+
     train_transforms = transforms.Compose(
         [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
     )
@@ -251,38 +280,39 @@ def main(argv):
     train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
     test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
 
-    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
-
     train_dataloader = DataLoader(
         train_dataset,
         batch_size=args.batch_size,
         num_workers=args.num_workers,
         shuffle=True,
-        pin_memory=(device == "cuda"),
+        pin_memory=True,
     )
 
     test_dataloader = DataLoader(
         test_dataset,
         batch_size=args.test_batch_size,
         num_workers=args.num_workers,
         shuffle=False,
-        pin_memory=(device == "cuda"),
+        pin_memory=True,
    )
 
-    net = image_models[args.model](quality=3)
-    net = net.to(device)
+    train_dataloader, test_dataloader = fabric.setup_dataloaders(
+        train_dataloader, test_dataloader
+    )
 
-    if args.cuda and torch.cuda.device_count() > 1:
-        net = CustomDataParallel(net)
+    net = image_models[args.model](quality=3)
 
     optimizer, aux_optimizer = configure_optimizers(net, args)
     lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
     criterion = RateDistortionLoss(lmbda=args.lmbda)
 
+    # Setup model and optimizers
+    net, optimizer, aux_optimizer = fabric.setup(net, optimizer, aux_optimizer)
+
     last_epoch = 0
     if args.checkpoint:  # load from previous checkpoint
         print("Loading", args.checkpoint)
-        checkpoint = torch.load(args.checkpoint, map_location=device)
+        checkpoint = torch.load(args.checkpoint, map_location=fabric.device)
         last_epoch = checkpoint["epoch"] + 1
         net.load_state_dict(checkpoint["state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer"])
@@ -291,8 +321,10 @@ def main(argv):
 
     best_loss = float("inf")
     for epoch in range(last_epoch, args.epochs):
-        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
+        if fabric.is_global_zero:
+            print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
         train_one_epoch(
+            fabric,
             net,
             criterion,
             train_dataloader,
@@ -301,13 +333,13 @@ def main(argv):
             epoch,
             args.clip_max_norm,
         )
-        loss = test_epoch(epoch, test_dataloader, net, criterion)
+        loss = test_epoch(fabric, epoch, test_dataloader, net, criterion)
         lr_scheduler.step(loss)
 
         is_best = loss < best_loss
         best_loss = min(loss, best_loss)
 
-        if args.save:
+        if args.save and fabric.is_global_zero:
             save_checkpoint(
                 {
                     "epoch": epoch,
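
With this change the --cuda flag is gone; device placement and distributed execution are controlled through Fabric's arguments instead. A hypothetical invocation as a sketch (the -d/--dataset flag and the chosen values are assumptions not shown in this diff; fabric.launch() spawns the worker processes itself, so no torchrun wrapper should be needed):

    # single process; the "auto" defaults pick a GPU when one is available
    python examples/train.py -d /path/to/dataset

    # 2 GPUs with DDP and bf16 mixed precision
    python examples/train.py -d /path/to/dataset \
        --accelerator cuda --devices 2 --strategy ddp --precision bf16-mixed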