diff --git a/torchbenchmark/models/drq/drq.py b/torchbenchmark/models/drq/drq.py index a1ce6f241f..74a0ff4455 100644 --- a/torchbenchmark/models/drq/drq.py +++ b/torchbenchmark/models/drq/drq.py @@ -43,7 +43,9 @@ def forward_conv(self, obs): conv = torch.relu(self.convs[i](conv)) self.outputs['conv%s' % (i + 1)] = conv - h = conv.view(conv.size(0), -1) + # Changed view to reshape here to support channels last input + # TODO: upstream this change to https://github.com/denisyarats/drq/blob/master/drq.py#L48 + h = conv.reshape(conv.size(0), -1) return h def forward(self, obs, detach=False):