diff --git a/cliffordlayers/nn/functional/utils.py b/cliffordlayers/nn/functional/utils.py index 1300af9..7acb8b7 100644 --- a/cliffordlayers/nn/functional/utils.py +++ b/cliffordlayers/nn/functional/utils.py @@ -32,15 +32,35 @@ def clifford_convnd( Returns: torch.Tensor: Convolved output tensor. """ - # Reshape x such that the convolution function can be applied. + # Reshape x such that the convolution function with grouping can be applied. B, *_ = x.shape + groups = kwargs['groups'] B_dim, C_dim, *D_dims, I_dim = range(len(x.shape)) x = x.permute(B_dim, -1, C_dim, *D_dims) + B_dim, I_dim, C_dim, *D_dims = range(len(x.shape)) + x = x.chunk(groups, C_dim) + x = torch.cat(x, dim=I_dim) x = x.reshape(B, -1, *x.shape[3:]) + # Reshape weight and bias such that the convolution function with grouping can be applied. + ICO, CI, *K = weight.shape + weight = weight.reshape(output_blades, ICO // output_blades, *weight.shape[1:]) + I_dim, CO_dim, *_ = range(len(weight.shape)) + weight = weight.chunk(groups, CO_dim) + weight = torch.cat(weight, dim=I_dim) + weight = weight.reshape(-1, CI, *K) + bias = bias.reshape(output_blades, ICO // output_blades) + bias = bias.chunk(groups, CO_dim) + bias = torch.cat(bias, dim=I_dim) + bias = bias.reshape(-1) # Apply convolution function output = conv_fn(x, weight, bias=bias, **kwargs) # Reshape back. - output = output.view(B, output_blades, -1, *output.shape[2:]) + output = output.view(B, groups, -1, *output.shape[2:]) + B_dim, G_dim, C_dim, *D_dims = range(len(output.shape)) + output = output.chunk(output_blades, dim=C_dim) + output = torch.cat(output, dim=G_dim) + B, IG, CO_G, *D = output.shape + output = output.reshape(B, IG // groups, CO_G * groups, *D) B_dim, I_dim, C_dim, *D_dims = range(len(output.shape)) output = output.permute(B_dim, C_dim, *D_dims, I_dim) return output diff --git a/tests/test_clifford_convolution.py b/tests/test_clifford_convolution.py index bb6c6f8..fd0dfa0 100644 --- a/tests/test_clifford_convolution.py +++ b/tests/test_clifford_convolution.py @@ -27,6 +27,18 @@ def test_complex_convolution(): output_c = F.conv1d(input_c, w_c, b_c) torch.testing.assert_close(output_clifford_conv, torch.view_as_real(output_c)) +def test_complex_grouped_convolution(): + """Test Clifford1d grouped convolution module against complex convolution module using g = [-1].""" + in_channels = 8 + out_channels = 16 + x = torch.randn(1, in_channels, 128, 2) + clifford_conv = CliffordConv1d(g=[-1], in_channels=in_channels, out_channels=out_channels, kernel_size=3, groups=4) + output_clifford_conv = clifford_conv(x) + w_c = torch.view_as_complex(torch.stack((clifford_conv.weight[0], clifford_conv.weight[1]), -1)) + b_c = torch.view_as_complex(clifford_conv.bias.permute(1, 0).contiguous()) + input_c = torch.view_as_complex(x) + output_c = F.conv1d(input_c, w_c, b_c, groups=4) + torch.testing.assert_close(output_clifford_conv, torch.view_as_real(output_c)) def test_Clifford1d_conv_shapes(): """Test shapes of Clifford1d convolution module."""