Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_args_parser():
parser.add_argument('--clip_max_norm', default=0.1, type=float,
help='gradient clipping max norm')

parser.add_argument('--sgd', action='store_true')
parser.add_argument('--sgd', action='store_true')

# Variants of Deformable DETR
parser.add_argument('--with_box_refine', default=False, action='store_true')
Expand Down Expand Up @@ -149,7 +149,7 @@ def main(args):
random.seed(seed)

model, criterion, postprocessors = build_model(args)

model.to(device)

model_without_ddp = model
Expand Down Expand Up @@ -206,7 +206,7 @@ def match_name_keywords(n, name_keywords):
"lr": args.lr * args.lr_linear_proj_mult,
}
]

if args.sgd:
optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9,
weight_decay=args.weight_decay)
Expand All @@ -216,11 +216,11 @@ def match_name_keywords(n, name_keywords):
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop )

if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module

base_ds = {}

if args.dataset_file == 'YoutubeVIS' or args.dataset_file == 'jointcoco' or args.dataset_file == 'Seq_coco':
base_ds['ytvos'] = get_coco_api_from_dataset(dataset_val)
else:
Expand Down Expand Up @@ -248,7 +248,7 @@ def match_name_keywords(n, name_keywords):
lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
lr_scheduler.step(lr_scheduler.last_epoch)
args.start_epoch = checkpoint['epoch'] + 1

elif args.pretrain_weights is not None:
print('load weigth from pretrain weight:',args.pretrain_weights)
checkpoint = torch.load(args.pretrain_weights, map_location='cpu')['model']
Expand Down Expand Up @@ -292,7 +292,7 @@ def match_name_keywords(n, name_keywords):
'args': args,
}, checkpoint_path)



if (epoch + 1) % 1 == 0 and args.eval_types == 'coco':
test_stats, coco_evaluator = evaluate(model, criterion, postprocessors,
Expand Down
18 changes: 10 additions & 8 deletions models/ops/modules/ms_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, mode='encode'
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
self.output_proj_box = nn.Linear(d_model, d_model)
if self.mode == 'decode':
self.output_proj_box = nn.Linear(d_model, d_model)

self._reset_parameters()

Expand Down Expand Up @@ -105,7 +107,7 @@ def encode_forward(self, query, reference_points, input_flatten, input_spatial_s
for i in range(nf):
value_list.append(value[:,i].contiguous())
for idx_f in range(nf):
sampling_offsets_i = sampling_offsets[:,idx_f]
sampling_offsets_i = sampling_offsets[:,idx_f]
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations_i = reference_points[:, :, None, :, None, :] \
Expand Down Expand Up @@ -147,14 +149,14 @@ def decode_forward(self, query, query_box, reference_points, input_flatten, inpu
point_list.append(reference_points[:,i].contiguous() )

result_idx_f = []

for samp_i in range(nf): # perform deformable attention per frame

reference_points_i = point_list[samp_i]
if reference_points_i.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points_i[:, :, None, :, None, :] + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points_i.shape[-1] == 4:
elif reference_points_i.shape[-1] == 4:
sampling_locations = reference_points_i[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points_i[:, :, None, :, None, 2:] * 0.5
else:
Expand All @@ -177,12 +179,12 @@ def decode_forward(self, query, query_box, reference_points, input_flatten, inpu
value = self.value_proj(input_flatten)
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
#
#
value = value.view(N, nf, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query_box).view(N, nf,Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query_box).view(N, nf, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, nf, Len_q, self.n_heads, self.n_levels, self.n_points)

value_list = []
point_list = []
sampling_offsets_list = []
Expand All @@ -199,7 +201,7 @@ def decode_forward(self, query, query_box, reference_points, input_flatten, inpu
if reference_points_i.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
sampling_locations = reference_points_i[:, :, None, :, None, :] + sampling_offsets_list[samp_i] / offset_normalizer[None, None, None, :, None, :]
elif reference_points_i.shape[-1] == 4:
elif reference_points_i.shape[-1] == 4:
sampling_locations = reference_points_i[:, :, None, :, None, :2] \
+ sampling_offsets_list[samp_i] / self.n_points * reference_points_i[:, :, None, :, None, 2:] * 0.5
else:
Expand All @@ -208,7 +210,7 @@ def decode_forward(self, query, query_box, reference_points, input_flatten, inpu
output_samp_i = MSDeformAttnFunction.apply(
value_list[samp_i], input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights_list[samp_i], self.im2col_step)
result_idx_f.append(output_samp_i.unsqueeze(1))
result_idx_f = torch.cat(result_idx_f,dim=1)
result_idx_f = torch.cat(result_idx_f,dim=1)
result_sum = result_idx_f
output = self.output_proj(result_sum)
output_box = self.output_proj_box(result_idx_f)
Expand Down