
How to get predicted mask #21

@itsss


Could you let me know how to get the predicted mask during validation (i.e., the output after the ASPP module)?

I used the following code (in test_frame.py) to get the predicted mask, but it always gives me the ground truth (GT).


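```python
import time

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

# model, val_dataloader and evaluations are defined elsewhere in test_frame.py (not shown)
it = 0
cnt = 0
for data in val_dataloader:
    begin_time = time.time()
    it = it + 1
    query_img, query_mask, support_img, support_mask, idx, size = data

    query_img, query_mask, support_img, support_mask, idx \
        = query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda()

    with torch.no_grad():
        logits = model(query_img, support_img, support_mask)

        # Resize the query image and its GT mask back to the original size
        # (F.upsample is deprecated in favour of F.interpolate).
        query_img = F.interpolate(query_img, size=(size[0], size[1]), mode='bilinear', align_corners=False)
        query_mask = F.interpolate(query_mask, size=(size[0], size[1]), mode='nearest')
        print(query_mask.size())

        values, pred = model.get_pred(logits, query_img)
        evaluations.update_evl(idx, query_mask, pred, 0)

        plt.figure()
        # Subplot 1 draws query_mask, i.e. the ground-truth mask.
        plt.subplot(2, 2, 1)
        plt.imshow(np.array(query_mask.squeeze().cpu()), cmap=cm.tab10_r)
        # Subplot 2 draws the query image as (H, W, C).
        plt.subplot(2, 2, 2)
        plt.imshow(np.array(query_img.squeeze().permute(1, 2, 0).cpu()), cmap=cm.tab10_r)
        plt.axis('off')
        # plt.show()
        print(cnt)
        cnt = cnt + 1
        plt.savefig("result/" + str(cnt) + ".png")
        plt.close()  # release the figure so figures don't pile up across iterations
        time.sleep(0.1)
```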
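In the snippet above, the figure only draws `query_mask` (the ground truth) and the query image, so the predicted mask never appears in the saved PNGs. Below is a minimal sketch for saving GT and prediction side by side. It assumes `pred` is a binary label map of shape `(1, H, W)` or `(H, W)`. `logits_to_mask` and `save_gt_vs_pred` are hypothetical helpers, not part of this repository; `logits_to_mask` is only needed if `model.get_pred` does not already return such a map, and it assumes the logits are `(N, C, h, w)` with one channel per class (e.g. background/foreground).

```python
import matplotlib
matplotlib.use("Agg")  # render straight to files, no display needed
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


def logits_to_mask(logits: torch.Tensor, size) -> torch.Tensor:
    """Hypothetical helper. Assumes logits of shape (N, C, h, w), one channel
    per class. Upsamples to `size` and returns an (N, H, W) label map."""
    up = F.interpolate(logits, size=size, mode="bilinear", align_corners=False)
    return up.argmax(dim=1)  # per-pixel index of the highest-scoring class


def save_gt_vs_pred(query_mask: torch.Tensor, pred: torch.Tensor, path: str) -> None:
    """Hypothetical helper. Saves GT (left) and prediction (right) for one
    sample; assumes both are binary masks (values 0/1)."""
    gt = query_mask.squeeze().detach().cpu().numpy()
    pr = pred.squeeze().detach().cpu().numpy()
    fig, (ax_gt, ax_pr) = plt.subplots(1, 2, figsize=(8, 4))
    ax_gt.imshow(gt, cmap="gray", vmin=0, vmax=1)
    ax_gt.set_title("ground truth")
    ax_pr.imshow(pr, cmap="gray", vmin=0, vmax=1)
    ax_pr.set_title("prediction")
    for ax in (ax_gt, ax_pr):
        ax.axis("off")
    fig.savefig(path)
    plt.close(fig)  # free the figure before the next iteration


if __name__ == "__main__":
    # Random stand-ins for the model output and the GT mask.
    logits = torch.randn(1, 2, 40, 40)
    pred = logits_to_mask(logits, (320, 320))
    gt = torch.zeros(1, 1, 320, 320)
    save_gt_vs_pred(gt, pred, "example.png")
```

Inside the validation loop, the plotting lines could then be replaced by something like `save_gt_vs_pred(query_mask, pred, "result/" + str(cnt) + ".png")`, so each saved image shows the prediction next to the ground truth.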
