Can you please let me know how to get the predicted mask during validation (that is, the output after the ASPP module)?
I used the code below (in test_frame.py) to get the predicted mask, but it always shows the ground truth (GT) for me.
for data in val_dataloader:
    begin_time = time.time()
    it = it+1
    query_img, query_mask, support_img, support_mask, idx, size = data
    query_img, query_mask, support_img, support_mask, idx \
        = query_img.cuda(), query_mask.cuda(), support_img.cuda(), support_mask.cuda(), idx.cuda()
    with torch.no_grad():
        logits = model(query_img, support_img, support_mask)
        query_img = F.upsample(query_img, size=(size[0], size[1]), mode='bilinear')
        query_mask = F.upsample(query_mask, size=(size[0], size[1]), mode='nearest')
        print(query_mask.size())
        values, pred = model.get_pred(logits, query_img)
        evaluations.update_evl(idx, query_mask, pred, 0)
    plt.figure()
    plt.subplot(2,2,1)
    plt.imshow(np.array(query_mask.squeeze().cpu()), cmap=cm.tab10_r)
    plt.subplot(2,2,2)
    plt.imshow(np.array(query_img.squeeze().permute(1,2,0).cpu()), cmap=cm.tab10_r)
    plt.axis('off')
    # plt.show()
    print(cnt)
    cnt = cnt + 1
    plt.savefig("result/"+str(cnt)+".png")
    time.sleep(0.1)
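One thing that stands out in the loop above: the two subplots draw `query_mask` and `query_img`, and `pred` is never passed to `plt.imshow`, so the saved figure can only contain the ground truth. The sketch below shows one way to plot the prediction next to the GT. It is not the repository's code: `logits_to_mask` and `save_comparison` are hypothetical helpers, the arrays are synthetic stand-ins, and it assumes the predicted mask is the per-pixel argmax over class scores, which may or may not match what `model.get_pred` does internally.

```python
import numpy as np


def logits_to_mask(logits):
    """Turn per-class scores of shape (C, H, W) into an (H, W) label map.

    Hypothetical helper: assumes the prediction is the per-pixel argmax.
    """
    return np.argmax(logits, axis=0)


def save_comparison(query_img, gt_mask, pred_mask, path):
    """Save query image, ground truth and prediction side by side.

    Hypothetical helper; expects NumPy arrays (H, W, 3), (H, W), (H, W).
    """
    import matplotlib
    matplotlib.use("Agg")  # render to a file without needing a display
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    panels = [(query_img, "query image"),
              (gt_mask, "ground truth"),
              (pred_mask, "prediction")]
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis("off")
    fig.savefig(path)
    plt.close(fig)  # avoid accumulating open figures inside a loop


# Synthetic stand-in for a 2-class output on a 2x2 image.
demo_logits = np.array([[[0.9, 0.2], [0.6, 0.1]],
                        [[0.1, 0.8], [0.4, 0.9]]])
demo_pred = logits_to_mask(demo_logits)
```

In the validation loop, and assuming `pred` comes back from `model.get_pred` as an (H, W) tensor, the call would look something like `save_comparison(query_img.squeeze().permute(1,2,0).cpu().numpy(), query_mask.squeeze().cpu().numpy(), pred.cpu().numpy(), "result/" + str(cnt) + ".png")`.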