diff --git a/inference.py b/inference.py index a2f6c44..218e584 100644 --- a/inference.py +++ b/inference.py @@ -50,9 +50,15 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames, waveglow.eval() # load flowtron - model = Flowtron(**model_config).cuda() - state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict'] - model.load_state_dict(state_dict) + try: + model = Flowtron(**model_config).cuda() + state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict'] + model.load_state_dict(state_dict) + except KeyError: + # model saved by train.py module + # do not need to load state dict + # and can be used directly + model = torch.load(flowtron_path)['model'] model.eval() print("Loaded checkpoint '{}')" .format(flowtron_path))