Open
Labels: bug
Description
The plot function doesn't work for receiver_rnn because it returns n images, where n == len(tokens), so its output cannot be viewed as a single 28x28 image.
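A minimal sketch of the shape mismatch, using a dummy tensor instead of a real EGG game. It assumes, for illustration only, that one example's receiver output holds one flattened 28x28 reconstruction per message position, with a hypothetical `max_len = 4`. Such a tensor has `max_len * 784` elements, so `.view(28, 28)` fails, while `.view(-1, 28, 28)[-1]` yields the last image.

```python
import torch

max_len = 4  # hypothetical message length, for illustration only

# Dummy receiver output for one example: one flattened 28x28 image per position.
receiver_output_z = torch.rand(max_len, 28 * 28)

# Viewing it as a single image fails: there are max_len * 784 elements, not 784.
try:
    receiver_output_z.view(28, 28)
    single_view_failed = False
except RuntimeError:
    single_view_failed = True

# Reshape into one image per position, then keep only the last one.
per_step = receiver_output_z.view(-1, 28, 28)
last_image = per_step[-1]

print(single_view_failed, tuple(per_step.shape), tuple(last_image.shape))
```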
Steps to Reproduce
- Run the EGG walkthrough with a MNIST autoencoder until cell #26.
Possible Implementation
Taking the last image in plot would help:
import matplotlib.pyplot as plt
import torch
import egg.core as core


def plot(game, test_dataset, is_gs, variable_length):
    interaction = core.dump_interactions(game, test_dataset, is_gs, variable_length)
    for z in range(10):
        src = interaction.sender_input[z].squeeze(0)
        if variable_length:
            # the RNN receiver returns one reconstruction per message position;
            # reshape to (n, 28, 28) and keep only the last one
            dst = interaction.receiver_output[z].view(-1, 28, 28)[-1]
        else:
            dst = interaction.receiver_output[z].view(28, 28)
        # we'll plot two images side-by-side: the original (left) and the reconstruction
        image = torch.cat([src, dst], dim=1).cpu().numpy()
        plt.title(f"Input: digit {z}, channel message {interaction.message[z]}")
        plt.imshow(image, cmap='gray')
        plt.show()
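A hedged variant, not part of the proposal above: when a message is shorter than the maximum length, the last position may come after the EOS symbol. The hypothetical helper `select_reconstruction` below instead picks each example's reconstruction at the final position of its own message. It assumes `receiver_output` has shape [batch, max_len, 784] and that `message_length` counts tokens including EOS. Both assumptions should be checked against the EGG version in use.

```python
import torch


def select_reconstruction(receiver_output: torch.Tensor,
                          message_length: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each example, return the 28x28 reconstruction
    produced at the last position of its message.

    Assumes receiver_output has shape [batch, max_len, 784] and that
    message_length[i] includes EOS, so the last position is length - 1.
    """
    batch, max_len = receiver_output.size(0), receiver_output.size(1)
    device = receiver_output.device
    # Clamp so out-of-range lengths still give a valid position index.
    last_pos = (message_length.to(device).long() - 1).clamp(min=0, max=max_len - 1)
    rows = torch.arange(batch, device=device)
    # Advanced indexing picks one 784-vector per example -> [batch, 784].
    return receiver_output[rows, last_pos].view(batch, 28, 28)


# Usage with dummy data: two examples, max_len = 3, lengths 2 and 3.
outputs = torch.arange(2 * 3 * 784, dtype=torch.float32).view(2, 3, 784)
lengths = torch.tensor([2, 3])
images = select_reconstruction(outputs, lengths)
print(tuple(images.shape))  # (2, 28, 28)
```

Inside plot, this would replace the [-1] indexing with a lookup by message length. Whether that reconstruction is the more meaningful one depends on how the receiver's per-step outputs are trained.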