diff --git a/test.py b/test.py index 2729657..348fa0b 100755 --- a/test.py +++ b/test.py @@ -24,18 +24,18 @@ def detect(net,img): olist = net(img) bboxlist = [] - for i in range(len(olist)/2): olist[i*2] = F.softmax(olist[i*2]) - for i in range(len(olist)/2): - ocls,oreg = olist[i*2].data.cpu(),olist[i*2+1].data.cpu() + for i in range(len(olist)//2): olist[i*2] = F.softmax(olist[i*2]) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist)//2): + ocls,oreg = olist[i*2],olist[i*2+1] FB,FC,FH,FW = ocls.size() # feature map size stride = 2**(i+2) # 4,8,16,32,64,128 anchor = stride*4 - for Findex in range(FH*FW): - windex,hindex = Findex%FW,Findex//FW + poss = zip(*np.where(ocls[:,1,:,:]>0.05)) + for Iindex, hindex, windex in poss: axc,ayc = stride/2+windex*stride,stride/2+hindex*stride score = ocls[0,1,hindex,windex] loc = oreg[0,:,hindex,windex].contiguous().view(1,4) - if score<0.05: continue priors = torch.Tensor([[axc/1.0,ayc/1.0,stride*4/1.0,stride*4/1.0]]) variances = [0.1,0.2] box = decode(loc,priors,variances)