diff --git a/inference.py b/inference.py
index 791b09a..5c8df2c 100644
--- a/inference.py
+++ b/inference.py
@@ -51,9 +51,9 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
     # load flowtron
     model = Flowtron(**model_config).cuda()
-    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
-    model.load_state_dict(state_dict)
+    model = torch.load(flowtron_path, map_location='cpu')['model']
     model.eval()
+    model.cuda()
     print("Loaded checkpoint '{}')" .format(flowtron_path))
 
     ignore_keys = ['training_files', 'validation_files']
@@ -73,8 +73,8 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
     for k in range(len(attentions)):
         attention = torch.cat(attentions[k]).cpu().numpy()
         fig, axes = plt.subplots(1, 2, figsize=(16, 4))
-        axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto')
-        axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto')
+        axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
+        axes[1].imshow(attention[:, 0].transpose(), origin='lower', aspect='auto')
         fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k)))
         plt.close("all")
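
Note: with this change the checkpoint is expected to hold the full pickled Flowtron module under the 'model' key rather than a 'state_dict', since the loaded object is used directly with .eval() and .cuda(). A minimal sketch of a saving side compatible with the new loading path (the filename and variable names here are illustrative, not taken from the repo):

    import torch

    # 'model' is assumed to be a trained Flowtron instance; saving the whole
    # module lets inference.py recover it via
    # torch.load(flowtron_path, map_location='cpu')['model']
    torch.save({'model': model}, 'flowtron_checkpoint.pt')

The origin='bottom' -> origin='lower' change in the imshow calls follows matplotlib's documented values for the origin argument ('upper'/'lower'); recent matplotlib releases reject 'bottom'.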