issue in MNIST tutorial #237

@dmitrySorokin

Description

The plot function doesn't work with receiver_rnn because the receiver returns n images, where n == len(tokens), instead of a single image.

Steps to Reproduce

  1. Run the EGG walkthrough with an MNIST autoencoder until cell #26.

Possible Implementation

Taking the last image in plot would help:

# Assumed to be imported earlier in the walkthrough notebook
import matplotlib.pyplot as plt
import torch
import egg.core as core

def plot(game, test_dataset, is_gs, variable_length):
    interaction = core.dump_interactions(
        game, test_dataset, is_gs, variable_length)

    for z in range(10):
        src = interaction.sender_input[z].squeeze(0)
        if variable_length:
            # receiver_rnn returns one image per message token: keep only the last one
            dst = interaction.receiver_output[z].view(-1, 28, 28)[-1]
        else:
            dst = interaction.receiver_output[z].view(28, 28)
        # we'll plot two images side-by-side: the original (left) and the reconstruction (right)
        image = torch.cat([src, dst], dim=1).cpu().numpy()

        plt.title(f"Input: digit {z}, channel message {interaction.message[z]}")
        plt.imshow(image, cmap='gray')
        plt.show()
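
The shape mismatch can probably be reproduced without EGG on a dummy tensor. The snippet below is only a sketch: `receiver_output_z` is a hypothetical stand-in for `interaction.receiver_output[z]`, and both `n_steps = 3` and the flattened layout are assumptions, not values taken from the walkthrough.

```python
import torch

# Hypothetical stand-in for interaction.receiver_output[z]: one flattened
# 28x28 image per message step (n_steps is an assumed value).
n_steps = 3
receiver_output_z = torch.arange(n_steps * 28 * 28, dtype=torch.float32)

# The single-image reshape fails because there are n_steps * 784 elements.
try:
    receiver_output_z.view(28, 28)
    single_image_view_failed = False
except RuntimeError:
    single_image_view_failed = True

# Reshaping to (n_steps, 28, 28) and indexing [-1] keeps only the last step.
per_step = receiver_output_z.view(-1, 28, 28)
dst = per_step[-1]

print(single_image_view_failed, tuple(per_step.shape), tuple(dst.shape))
```

Note that `[-1]` always picks the final step of the output, which is what the proposal above relies on. Whether that step is the most meaningful reconstruction for messages that end early is not checked here.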

Labels: bug