From e825c286877f5b0157fe9cbaa6d50bd96157b6ce Mon Sep 17 00:00:00 2001 From: "Y. Frisch" <39304298+YFrisch@users.noreply.github.com> Date: Wed, 21 Jan 2026 14:41:04 +0100 Subject: [PATCH] Update checkpoint saving to use raw_model Update checkpoint saving to use raw_model. Allows the use of multi-GPU training of the VAE with PyTorch DataParallel. --- src/python/training/training_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/python/training/training_functions.py b/src/python/training/training_functions.py index 3c3e08b..e322414 100644 --- a/src/python/training/training_functions.py +++ b/src/python/training/training_functions.py @@ -132,13 +132,13 @@ def train_vqgan( # Save checkpoint checkpoint = { "epoch": epoch + 1, - "state_dict": model.state_dict(), + "state_dict": raw_model.state_dict(), "discriminator": discriminator.state_dict(), "optimizer_g": optimizer_g.state_dict(), "optimizer_d": optimizer_d.state_dict(), "best_loss": best_loss, - "ema_cluster_size": model.quantizer.quantizer.ema_cluster_size, # Add this - "ema_w": model.quantizer.quantizer.ema_w, + "ema_cluster_size": raw_model.quantizer.quantizer.ema_cluster_size, # Add this + "ema_w": raw_model.quantizer.quantizer.ema_w, } torch.save(checkpoint, str(run_dir / f"checkpoint_epoch_{epoch + 1}.pth"))