diff --git a/run_train.py b/run_train.py index 9d50f94..647a0fb 100755 --- a/run_train.py +++ b/run_train.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from math import dist +# from math import dist import sys import os import click diff --git a/training/loss.py b/training/loss.py index 0eb2ebf..12be14c 100755 --- a/training/loss.py +++ b/training/loss.py @@ -111,7 +111,7 @@ def run_G(self, z, c, sync, img=None, mode=None, get_loss=True): elif (generator_mode == 'random_z_random_c') or (generator_mode == 'random_z_image_c'): with misc.ddp_sync(self.G_mapping, sync): - ws = self.G_mapping(z, c) + ws = self.G_mapping(z, c).clone() if self.style_mixing_prob > 0: with torch.autograd.profiler.record_function('style_mixing'): cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])