From 2a64a74d4d0cc29a30c7a2e8e701e0979eb9726e Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 12 Jan 2023 15:35:47 +0800 Subject: [PATCH 1/2] drq: change view to reshape to support channels last --- torchbenchmark/models/drq/drq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchbenchmark/models/drq/drq.py b/torchbenchmark/models/drq/drq.py index a1ce6f241f..9ed69842b5 100644 --- a/torchbenchmark/models/drq/drq.py +++ b/torchbenchmark/models/drq/drq.py @@ -43,7 +43,7 @@ def forward_conv(self, obs): conv = torch.relu(self.convs[i](conv)) self.outputs['conv%s' % (i + 1)] = conv - h = conv.view(conv.size(0), -1) + h = conv.reshape(conv.size(0), -1) return h def forward(self, obs, detach=False): From 9ae34ad097b940c175f4b6b0a36356d13eec3b5e Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 13 Jan 2023 09:37:43 +0800 Subject: [PATCH 2/2] add comment --- torchbenchmark/models/drq/drq.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchbenchmark/models/drq/drq.py b/torchbenchmark/models/drq/drq.py index 9ed69842b5..74a0ff4455 100644 --- a/torchbenchmark/models/drq/drq.py +++ b/torchbenchmark/models/drq/drq.py @@ -43,6 +43,8 @@ def forward_conv(self, obs): conv = torch.relu(self.convs[i](conv)) self.outputs['conv%s' % (i + 1)] = conv + # Changed view to reshape here to support channels last input + # TODO: upstream this change to https://github.com/denisyarats/drq/blob/master/drq.py#L48 h = conv.reshape(conv.size(0), -1) return h