diff --git a/scripts/trainer.py b/scripts/trainer.py index 92f1b91..de1a3f3 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -1142,6 +1142,10 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True): send_telegram_message(f"Generating samples for {step} {context}", args.telegram_chat_id, args.telegram_token) except: pass + n_sample = args.n_save_sample + if args.save_sample_controlled_seed: + n_sample += len(args.save_sample_controlled_seed) + progress_bar_sample = tqdm(range(len(prompts)*n_sample),desc="Generating samples") for samplePrompt in prompts: sampleIndex = prompts.index(samplePrompt) #convert sampleIndex to number in words @@ -1172,7 +1176,7 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True): depth = depth.astype(np.float32) / 255.0 depth = depth[None, None] depth = torch.from_numpy(depth) - for i in tqdm(range(args.n_save_sample) if not args.save_sample_controlled_seed else range(args.n_save_sample+len(args.save_sample_controlled_seed)), desc="Generating samples"): + for i in range(n_sample): #check if the sample is controlled by a seed if i < args.n_save_sample: if args.model_variant == 'inpainting': @@ -1201,6 +1205,7 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True): images[0].save(os.path.join(sample_dir,sampleName, f"{sampleName}_controlled_seed_{str(seed)}.png")) else: images[0].save(os.path.join(sample_dir, f"{sampleName}_controlled_seed_{str(seed)}.png")) + progress_bar_sample.update(1) if args.send_telegram_updates: imgs = []