From e8e4803a4bfb09e07cde1d98162ddbe108fcf249 Mon Sep 17 00:00:00 2001 From: tahle Date: Wed, 26 Oct 2022 20:09:39 -0700 Subject: [PATCH] Extracted tests and updated Pytorch version --- pytorch/structure/LDR.py | 82 +- pytorch/structure/circulant.py | 40 +- pytorch/structure/complex_utils.py | 19 - .../diag_mult_cuda/diag_mult_cuda.cpp | 71 +- .../diag_mult_cuda/diag_mult_cuda_kernel.cu | 64 +- pytorch/structure/diag_mult_cuda/setup.py | 21 - pytorch/structure/fastfood.py | 58 +- pytorch/structure/hadamard.py | 87 +- .../structure/hadamard_cuda/hadamard_cuda.cpp | 23 +- .../hadamard_cuda/hadamard_cuda_kernel.cu | 228 ++--- pytorch/structure/hadamard_cuda/setup.py | 21 - pytorch/structure/krylov.py | 920 ++++++++---------- pytorch/structure/layer.py | 174 ++-- pytorch/structure/scratch/fft.py | 91 -- pytorch/structure/scratch/krylovfast.py | 375 ------- pytorch/structure/scratch/krylovslow.py | 215 ---- pytorch/structure/scratch/tests_snippets.py | 52 - pytorch/structure/tests/test_circulant.py | 20 + pytorch/structure/tests/test_fastfood.py | 34 + pytorch/structure/tests/test_hadamard.py | 49 + pytorch/structure/tests/test_krylov.py | 336 +++++++ pytorch/structure/tests/test_toeplitz.py | 120 +++ pytorch/structure/toeplitz.py | 218 ++--- pytorch/structure/toeplitz_cpu.py | 168 ++-- 24 files changed, 1590 insertions(+), 1896 deletions(-) delete mode 100644 pytorch/structure/complex_utils.py delete mode 100644 pytorch/structure/diag_mult_cuda/setup.py delete mode 100644 pytorch/structure/hadamard_cuda/setup.py delete mode 100644 pytorch/structure/scratch/fft.py delete mode 100644 pytorch/structure/scratch/krylovfast.py delete mode 100644 pytorch/structure/scratch/krylovslow.py delete mode 100644 pytorch/structure/scratch/tests_snippets.py create mode 100644 pytorch/structure/tests/test_circulant.py create mode 100644 pytorch/structure/tests/test_fastfood.py create mode 100644 pytorch/structure/tests/test_hadamard.py create mode 100644 pytorch/structure/tests/test_krylov.py create mode 100644 pytorch/structure/tests/test_toeplitz.py diff --git a/pytorch/structure/LDR.py b/pytorch/structure/LDR.py index 7937d8e..b2a19d4 100644 --- a/pytorch/structure/LDR.py +++ b/pytorch/structure/LDR.py @@ -1,20 +1,41 @@ +# Copyright 2018 HazyResearch +# https://github.com/HazyResearch/structured-nets +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import torch -from torch.autograd import Variable import torch.nn as nn +from torch.autograd import Variable from torch.nn.parameter import Parameter -from . import toeplitz as toep -from . import krylov as kry +from . import krylov as kry, toeplitz as toep # TODO: rewrite with structure.layer # TODO: subclass with each DR type class LDR(nn.Module): - def name(self): - return str(self.in_channels) + str(self.out_channels) + self.displacement + str(self.r) + return ( + str(self.in_channels) + + str(self.out_channels) + + self.displacement + + str(self.r) + ) # TODO: support non-square multiplications - def __init__(self, displacement, in_channels, out_channels, rank, layer_size, bias=True): + def __init__( + self, displacement, in_channels, out_channels, rank, layer_size, bias=True + ): super(LDR, self).__init__() self.displacement = displacement self.in_channels = in_channels @@ -23,20 +44,27 @@ def __init__(self, displacement, in_channels, out_channels, rank, layer_size, bi self.n = layer_size self.bias = None - self.G = Parameter(torch.Tensor(self.in_channels, self.out_channels, self.r, self.n)) - self.H = Parameter(torch.Tensor(self.in_channels, self.out_channels, self.r, self.n)) - torch.nn.init.normal_(self.G, std=0.01) #TODO + self.G = Parameter( + torch.Tensor(self.in_channels, self.out_channels, self.r, self.n) + ) + self.H = Parameter( + torch.Tensor(self.in_channels, self.out_channels, self.r, self.n) + ) + torch.nn.init.normal_(self.G, std=0.01) # TODO torch.nn.init.normal_(self.H, std=0.01) if bias: self.bias = Parameter(torch.zeros(self.out_channels, 1, self.n)) - if self.displacement == 'toeplitz_corner' or self.displacement == 'tc': + if self.displacement == "toeplitz_corner" or self.displacement == "tc": self.corner = True - elif self.displacement == 'toeplitz' or self.displacement == 't': + elif self.displacement == "toeplitz" or self.displacement == "t": self.corner = False - elif self.displacement == 'subdiagonal' or self.displacement == 'sd': - self.subd_A = Parameter(torch.ones((self.in_channels, self.out_channels, self.n-1))) - self.subd_B = Parameter(torch.ones((self.in_channels, self.out_channels, self.n-1))) - + elif self.displacement == "subdiagonal" or self.displacement == "sd": + self.subd_A = Parameter( + torch.ones((self.in_channels, self.out_channels, self.n - 1)) + ) + self.subd_B = Parameter( + torch.ones((self.in_channels, self.out_channels, self.n - 1)) + ) def forward(self, x): """ @@ -47,15 +75,23 @@ def forward(self, x): assert n == self.n # print("shapes ", self.G[0,0].shape, self.H[0,0].shape, x[0].shape) - comps = Variable(torch.Tensor(self.in_channels, self.out_channels, b, self.n)).cuda() + comps = Variable( + torch.Tensor(self.in_channels, self.out_channels, b, self.n) + ).cuda() for i in range(self.in_channels): for j in range(self.out_channels): - if self.displacement in ['toeplitz_corner', 'toeplitz', 'tc', 't']: - g = self.G[i,j] - h = self.H[i,j] - comps[i,j] = toep.toeplitz_mult(self.G[i,j], self.H[i,j], x[i], self.corner) - elif self.displacement == 'subdiagonal' or self.displacement == 'sd': - comps[i,j] = kry.subdiag_mult_conv(self.subd_A[i,j], self.subd_B[i,j], self.G[i,j], self.H[i,j], x[i]) + if self.displacement in ["toeplitz_corner", "toeplitz", "tc", "t"]: + comps[i, j] = toep.toeplitz_mult( + self.G[i, j], self.H[i, j], x[i], self.corner + ) + elif self.displacement == "subdiagonal" or self.displacement == "sd": + comps[i, j] = kry.subdiag_mult_conv( + self.subd_A[i, j], + self.subd_B[i, j], + self.G[i, j], + self.H[i, j], + x[i], + ) out = torch.sum(comps, dim=0) if self.bias is not None: out += self.bias @@ -64,4 +100,4 @@ def forward(self, x): def loss(self): lamb = 0.0001 # lamb = 0 - return lamb*torch.sum(torch.abs(self.G)) + lamb*torch.sum(torch.abs(self.H)) + return lamb * torch.sum(torch.abs(self.G)) + lamb * torch.sum(torch.abs(self.H)) diff --git a/pytorch/structure/circulant.py b/pytorch/structure/circulant.py index ea30ff8..5f04102 100644 --- a/pytorch/structure/circulant.py +++ b/pytorch/structure/circulant.py @@ -1,27 +1,33 @@ +# Copyright 2018 HazyResearch +# https://github.com/HazyResearch/structured-nets +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import torch -from scipy.linalg import circulant -from .complex_utils import complex_mult -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") def circulant_multiply(c, x): - """ Multiply circulant matrix with first column c by x + """Multiply circulant matrix with first column c by x. + E.g. if the matrix is + [1 2 3] + [2 3 1] + [3 1 2] + c should be [1,2,3] Parameters: c: (n, ) x: (batch_size, n) or (n, ) Return: prod: (batch_size, n) or (n, ) """ - return torch.irfft(complex_mult(torch.rfft(c, 1), torch.rfft(x, 1)), 1, signal_sizes=(c.shape[-1], )) - -def test_circulant_multiply(n): - c = torch.rand(n, device=device) - x = torch.rand((3, n), device=device) - C = torch.tensor(circulant(c.detach().cpu().numpy()), dtype=c.dtype, device=c.device) - slow = x @ C.t() - fast = circulant_multiply(c, x) - print('Error compared to slow multiply: ', (slow - fast).abs().max().item()) - -# TODO: move test into subpackage -if __name__ == '__main__': - test_circulant_multiply(100) + return torch.fft.irfft(torch.fft.rfft(c) * torch.fft.rfft(x)) diff --git a/pytorch/structure/complex_utils.py b/pytorch/structure/complex_utils.py deleted file mode 100644 index 4f21baa..0000000 --- a/pytorch/structure/complex_utils.py +++ /dev/null @@ -1,19 +0,0 @@ -''' Utility functions for handling complex tensors: conjugate and complex_mult. -Pytorch (as of 0.4.0) does not support complex tensors, so we store them as -float tensors where the last dimension is 2 (real and imaginary parts). -''' - -import torch - - -def conjugate(X): - assert X.shape[-1] == 2, 'Last dimension must be 2' - return X * torch.tensor((1, -1), dtype=X.dtype, device=X.device) - - -def complex_mult(X, Y): - assert X.shape[-1] == 2 and Y.shape[-1] == 2, 'Last dimension must be 2' - return torch.stack( - (X[..., 0] * Y[..., 0] - X[..., 1] * Y[..., 1], - X[..., 0] * Y[..., 1] + X[..., 1] * Y[..., 0]), - dim=-1) diff --git a/pytorch/structure/diag_mult_cuda/diag_mult_cuda.cpp b/pytorch/structure/diag_mult_cuda/diag_mult_cuda.cpp index ce0cf47..57dc343 100644 --- a/pytorch/structure/diag_mult_cuda/diag_mult_cuda.cpp +++ b/pytorch/structure/diag_mult_cuda/diag_mult_cuda.cpp @@ -1,8 +1,37 @@ +/* +Copyright 2018 HazyResearch +https://github.com/HazyResearch/structured-nets + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + #include -void subdiagMultGPU(float *d_Subdiag, float *d_Data, float *d_Output, int shiftSubdiag, int shiftV, int batchSize, int N, bool batchedSubdiag); +void subdiagMultGPU( + float* d_Subdiag, + float* d_Data, + float* d_Output, + int shiftSubdiag, + int shiftV, + int batchSize, + int N, + bool batchedSubdiag); -torch::Tensor cycle_mult(torch::Tensor subdiag, torch::Tensor v, int shiftSubdiag, int shiftV) { +torch::Tensor cycle_mult( + torch::Tensor subdiag, + torch::Tensor v, + int64_t shiftSubdiag, + int64_t shiftV) { TORCH_CHECK(subdiag.is_cuda(), "subdiag must be a CUDA tensor"); TORCH_CHECK(v.is_cuda(), "v must be a CUDA tensor"); // Need to make tensors contiguous before passing to CUDA @@ -12,11 +41,19 @@ torch::Tensor cycle_mult(torch::Tensor subdiag, torch::Tensor v, int shiftSubdia auto batchSize = v.numel() / n; auto output = torch::empty_like(v); bool batchedSubdiag = subdiag.numel() == v.numel(); - subdiagMultGPU(subdiag.data_ptr(), v.data_ptr(), output.data_ptr(), shiftSubdiag, shiftV, batchSize, n, batchedSubdiag); + subdiagMultGPU( + subdiag.data_ptr(), + v.data_ptr(), + output.data_ptr(), + shiftSubdiag, + shiftV, + batchSize, + n, + batchedSubdiag); return output; } -torch::Tensor subdiagKrylov(torch::Tensor subdiag, torch::Tensor v, int m) { +torch::Tensor subdiagKrylov(torch::Tensor subdiag, torch::Tensor v, int64_t m) { TORCH_CHECK(subdiag.is_cuda(), "subdiag must be a CUDA tensor"); TORCH_CHECK(v.is_cuda(), "v must be a CUDA tensor"); // Need to make tensors contiguous before passing to CUDA @@ -24,16 +61,34 @@ torch::Tensor subdiagKrylov(torch::Tensor subdiag, torch::Tensor v, int m) { v = v.contiguous(); auto n = v.sizes().back(); auto batchSize = v.numel() / n; - auto output = torch::empty({m, batchSize, n}, torch::dtype(v.dtype()).device(v.device())); - // subdiagKrylovGPU(subdiag.data_ptr(), v.data_ptr(), output.data_ptr(), shiftSubdiag, shiftV, batchSize, n); + auto output = torch::empty( + {m, batchSize, n}, torch::dtype(v.dtype()).device(v.device())); + // subdiagKrylovGPU(subdiag.data_ptr(), v.data_ptr(), + // output.data_ptr(), shiftSubdiag, shiftV, batchSize, n); output[0] = v; for (int i = 1; i < m; ++i) { - subdiagMultGPU(subdiag.data_ptr(), output[i - 1].data_ptr(), output[i].data_ptr(), 0, -1, batchSize, n, false); + subdiagMultGPU( + subdiag.data_ptr(), + output[i - 1].data_ptr(), + output[i].data_ptr(), + 0, + -1, + batchSize, + n, + false); } return output; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cycle_mult", &cycle_mult, "Cycle the vector and then do a pointwise multiplication. Shift should be between -n and n - 1."); + m.def( + "cycle_mult", + &cycle_mult, + "Cycle the vector and then do a pointwise multiplication. Shift should be between -n and n - 1."); m.def("subdiagKrylov", &subdiagKrylov, "Subdiag Krylov"); } + +TORCH_LIBRARY(TORCH_EXTENSION_NAME, m) { + m.def("cycle_mult", &cycle_mult); + m.def("subdiagKrylov", &subdiagKrylov); +} diff --git a/pytorch/structure/diag_mult_cuda/diag_mult_cuda_kernel.cu b/pytorch/structure/diag_mult_cuda/diag_mult_cuda_kernel.cu index cb739f0..f80363f 100644 --- a/pytorch/structure/diag_mult_cuda/diag_mult_cuda_kernel.cu +++ b/pytorch/structure/diag_mult_cuda/diag_mult_cuda_kernel.cu @@ -1,18 +1,52 @@ -__global__ void subdiagMult(float *d_Subdiag, float *d_Data, float *d_Output, int shiftSubdiag, int shiftV, int N, int subdiagOffset) { - const int pos = blockIdx.x * blockDim.x + threadIdx.x; +/* +Copyright 2018 HazyResearch +https://github.com/HazyResearch/structured-nets - float *d_Src = d_Data + blockIdx.y * N; - float *d_Dst = d_Output + blockIdx.y * N; - float *d_Sub = d_Subdiag + blockIdx.y * subdiagOffset; - // for (int pos = tid; pos < N; pos += numThreads) { - if (pos < N) { - d_Dst[pos] = d_Sub[(pos + shiftSubdiag + N) % N] * d_Src[(pos + shiftV + N) % N]; - } +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +__global__ void subdiagMult( + float* d_Subdiag, + float* d_Data, + float* d_Output, + int shiftSubdiag, + int shiftV, + int N, + int subdiagOffset) { + const int pos = blockIdx.x * blockDim.x + threadIdx.x; + + float* d_Src = d_Data + blockIdx.y * N; + float* d_Dst = d_Output + blockIdx.y * N; + float* d_Sub = d_Subdiag + blockIdx.y * subdiagOffset; + // for (int pos = tid; pos < N; pos += numThreads) { + if (pos < N) { + d_Dst[pos] = + d_Sub[(pos + shiftSubdiag + N) % N] * d_Src[(pos + shiftV + N) % N]; + } } -void subdiagMultGPU(float *d_Subdiag, float *d_Data, float *d_Output, int shiftSubdiag, int shiftV, int batchSize, int N, bool batchedSubdiag) { - const int THREAD_N = 256; - dim3 grid((N + THREAD_N - 1) / THREAD_N, batchSize); - int subdiagOffset = batchedSubdiag ? N : 0; - subdiagMult<<>>(d_Subdiag, d_Data, d_Output, shiftSubdiag, shiftV, N, subdiagOffset); -} \ No newline at end of file +void subdiagMultGPU( + float* d_Subdiag, + float* d_Data, + float* d_Output, + int shiftSubdiag, + int shiftV, + int batchSize, + int N, + bool batchedSubdiag) { + const int THREAD_N = 256; + dim3 grid((N + THREAD_N - 1) / THREAD_N, batchSize); + int subdiagOffset = batchedSubdiag ? N : 0; + subdiagMult<<>>( + d_Subdiag, d_Data, d_Output, shiftSubdiag, shiftV, N, subdiagOffset); +} diff --git a/pytorch/structure/diag_mult_cuda/setup.py b/pytorch/structure/diag_mult_cuda/setup.py deleted file mode 100644 index f7a81b1..0000000 --- a/pytorch/structure/diag_mult_cuda/setup.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch.cuda -from setuptools import setup -from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension -from torch.utils.cpp_extension import CUDA_HOME - -ext_modules = [] - -if torch.cuda.is_available() and CUDA_HOME is not None: - extension = CUDAExtension( - 'diag_mult_cuda', [ - 'diag_mult_cuda.cpp', - 'diag_mult_cuda_kernel.cu' - ], - extra_compile_args={'cxx': ['-g'], - 'nvcc': ['-O2']}) - ext_modules.append(extension) - -setup( - name='diag_mult_cuda', - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension}) diff --git a/pytorch/structure/fastfood.py b/pytorch/structure/fastfood.py index 3edb1de..e0a1176 100644 --- a/pytorch/structure/fastfood.py +++ b/pytorch/structure/fastfood.py @@ -1,41 +1,33 @@ -from .hadamard import hadamard_transform -import torch -import numpy as np -from scipy.linalg import hadamard +# Copyright 2018 HazyResearch +# https://github.com/HazyResearch/structured-nets +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +from .hadamard import hadamard_transform_cuda, hadamard_transform_torch # S,G,B: diagonal # P: permutation # x: batch_size x n_features -def fastfood_multiply(S,G,B,P,x): - HBx = hadamard_transform(B*x) +def fastfood_multiply(S, G, B, P, x): + HBx = hadamard_transform_torch(B * x) PHBx = HBx[:, P] - HGPHBx = hadamard_transform(G*PHBx) - return S*HGPHBx - -def test_fastfood_multiply(n, batch_size): - S = np.random.randn(n) - G = np.random.randn(n) - B = np.random.randn(n) - P = np.random.permutation(n) - x = np.random.randn(batch_size,n) - H = hadamard(n) - HBx = np.dot(H,(B*x).T).T - PHBx = HBx[:,P] - HGPHBx = np.dot(H,(G*PHBx).T).T - output_explicit = S*HGPHBx - - S = torch.tensor(S, dtype=torch.float, device=device) - G = torch.tensor(G, dtype=torch.float, device=device) - B = torch.tensor(B, dtype=torch.float, device=device) - P = torch.tensor(P, dtype=torch.long, device=device) - x = torch.tensor(x, dtype=torch.float, device=device) + HGPHBx = hadamard_transform_torch(G * PHBx) + return S * HGPHBx - output = fastfood_multiply(S,G,B,P,x) - print(np.linalg.norm(output_explicit - output)) - -# TODO: move test into subpackage -if __name__ == '__main__': - test_fastfood_multiply(128,50) +def fastfood_multiply_cuda(S, G, B, P, x): + HBx = hadamard_transform_cuda(B * x) + PHBx = HBx[:, P] + HGPHBx = hadamard_transform_cuda(G * PHBx) + return S * HGPHBx diff --git a/pytorch/structure/hadamard.py b/pytorch/structure/hadamard.py index 4e7948d..db832f4 100644 --- a/pytorch/structure/hadamard.py +++ b/pytorch/structure/hadamard.py @@ -1,26 +1,23 @@ -import numpy as np -import torch +# Copyright 2018 HazyResearch +# https://github.com/HazyResearch/structured-nets +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -use_hadamard_transform_cuda = True -try: - import hadamard_cuda - # import torch.utils.cpp_extension - # hadamard_cuda = torch.utils.cpp_extension.load( - # name='hadamard_cuda', - # sources=[ - # 'hadamard_cuda/hadamard_cuda.cpp', - # 'hadamard_cuda/hadamard_cuda_kernel.cu', - # ], - # extra_cuda_cflags=['-O2'], - # verbose=False - # ) -except (ImportError, RuntimeError) as e: - print("CUDA version of Hadamard transform isn't installed. Will use Pytorch's version, which is much slower.") - use_hadamard_transform_cuda = False -from scipy.linalg import hadamard +from math import log2 -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +import hadamard_cuda +import torch def hadamard_transform_torch(u, normalize=False): @@ -33,17 +30,19 @@ def hadamard_transform_torch(u, normalize=False): product: Tensor of shape (..., n) """ batch_size, n = u.shape - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' - x = u[..., np.newaxis] - for d in range(m)[::-1]: - x = torch.cat((x[..., ::2, :] + x[..., 1::2, :], x[..., ::2, :] - x[..., 1::2, :]), dim=-1) - return x.squeeze(-2) / 2**(m / 2) if normalize else x.squeeze(-2) + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" + x = u[..., None] + for _ in range(m): + top = x[..., ::2, :] + x[..., 1::2, :] + bot = x[..., ::2, :] - x[..., 1::2, :] + x = torch.cat((top, bot), dim=-1) + return x.squeeze(-2) / 2 ** (m / 2) if normalize else x.squeeze(-2) class HadamardTransformCuda(torch.autograd.Function): - '''The unnormalized Hadamard transform (i.e. without dividing by sqrt(2)) - ''' + """The unnormalized Hadamard transform (i.e. without dividing by sqrt(2))""" + @staticmethod def forward(ctx, u): return hadamard_cuda.hadamard_transform(u) @@ -63,33 +62,7 @@ def hadamard_transform_cuda(u, normalize=False): product: Tensor of shape (..., n) """ _, n = u.shape - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" output = HadamardTransformCuda.apply(u) - return output / 2**(m / 2) if normalize else output - - -def test_hadamard_transform(): - m = 15 - n = 1 << m - batch_size = 50 - u = torch.rand((batch_size, n), requires_grad=True, device=device) - result_cuda = hadamard_transform_cuda(u) - grad_cuda, = torch.autograd.grad(result_cuda.sum(), u, retain_graph=True) - result_torch = hadamard_transform_torch(u) - grad_torch, = torch.autograd.grad(result_torch.sum(), u, retain_graph=True) - # Explicit construction from scipy - H = torch.tensor(hadamard(n), dtype=torch.float, device=device) - result_explicit = u @ H.t() - print((result_cuda - result_explicit).abs().max().item()) - print((result_cuda - result_explicit).abs().mean().item()) - print((result_torch - result_explicit).abs().max().item()) - print((result_torch - result_explicit).abs().mean().item()) - print((grad_cuda - grad_torch).abs().max().item()) - print((grad_cuda - grad_torch).abs().mean().item()) - - -hadamard_transform = hadamard_transform_cuda if use_hadamard_transform_cuda else hadamard_transform_torch - -if __name__ == '__main__': - test_hadamard_transform() + return output / 2 ** (m / 2) if normalize else output diff --git a/pytorch/structure/hadamard_cuda/hadamard_cuda.cpp b/pytorch/structure/hadamard_cuda/hadamard_cuda.cpp index ef20097..7185c78 100644 --- a/pytorch/structure/hadamard_cuda/hadamard_cuda.cpp +++ b/pytorch/structure/hadamard_cuda/hadamard_cuda.cpp @@ -1,3 +1,20 @@ +/* +Copyright 2018 HazyResearch +https://github.com/HazyResearch/structured-nets + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + #include void fwtBatchGPU(float* x, int batchSize, int log2N); @@ -13,6 +30,10 @@ torch::Tensor hadamard_transform(torch::Tensor x) { return output; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +PYBIND11_MODULE(HADAMARD_EXTENSION_NAME, m) { m.def("hadamard_transform", &hadamard_transform, "Fast Hadamard transform"); } + +TORCH_LIBRARY(HADAMARD_EXTENSION_NAME, m) { + m.def("hadamard_transform", &hadamard_transform); +} diff --git a/pytorch/structure/hadamard_cuda/hadamard_cuda_kernel.cu b/pytorch/structure/hadamard_cuda/hadamard_cuda_kernel.cu index 5ebbb4c..26a7dbb 100644 --- a/pytorch/structure/hadamard_cuda/hadamard_cuda_kernel.cu +++ b/pytorch/structure/hadamard_cuda/hadamard_cuda_kernel.cu @@ -1,7 +1,7 @@ -/* Adated from the CUDA samples https://docs.nvidia.com/cuda/cuda-samples/index.html. - Changed from "natural order" Hadamard transform (larger strides before - smaller strides) to the standard Hadamard transform (smaller strides before - larger strides). +/* Adated from the CUDA samples + https://docs.nvidia.com/cuda/cuda-samples/index.html. Changed from "natural + order" Hadamard transform (larger strides before smaller strides) to the + standard Hadamard transform (smaller strides before larger strides). */ /* @@ -19,153 +19,137 @@ namespace cg = cooperative_groups; - /////////////////////////////////////////////////////////////////////////////// // Elementary(for vectors less than elementary size) in-shared memory // combined radix-2 + radix-4 Fast Walsh Transform /////////////////////////////////////////////////////////////////////////////// #define ELEMENTARY_LOG2SIZE 11 -__global__ void fwtBatch1Kernel(float *d_Output, float *d_Input, int log2N) -{ - // Handle to thread block group - cg::thread_block cta = cg::this_thread_block(); - const int N = 1 << log2N; - const int base = blockIdx.x << log2N; - - //(2 ** 11) * 4 bytes == 8KB -- maximum s_data[] size for G80 - extern __shared__ float s_data[]; - float *d_Src = d_Input + base; - float *d_Dst = d_Output + base; - - for (int pos = threadIdx.x; pos < N; pos += blockDim.x) - { - s_data[pos] = d_Src[pos]; - } +__global__ void fwtBatch1Kernel(float* d_Output, float* d_Input, int log2N) { + // Handle to thread block group + cg::thread_block cta = cg::this_thread_block(); + const int N = 1 << log2N; + const int base = blockIdx.x << log2N; - int stride = 1; - //Do single radix-2 stage for odd power of two - if (log2N & 1) - { - cg::sync(cta); - - for (int pos = threadIdx.x; pos < N / 2; pos += blockDim.x) - { - int i0 = pos << 1; - int i1 = i0 + 1; - - float D0 = s_data[i0]; - float D1 = s_data[i1]; - s_data[i0] = D0 + D1; - s_data[i1] = D0 - D1; - } - stride <<= 1; - } + //(2 ** 11) * 4 bytes == 8KB -- maximum s_data[] size for G80 + extern __shared__ float s_data[]; + float* d_Src = d_Input + base; + float* d_Dst = d_Output + base; - //Main radix-4 stages - const int pos = threadIdx.x; - - for (; stride <= N >> 2; stride <<= 2) - { - int lo = pos & (stride - 1); - int i0 = ((pos - lo) << 2) + lo; - int i1 = i0 + stride; - int i2 = i1 + stride; - int i3 = i2 + stride; - - cg::sync(cta); - float D0 = s_data[i0]; - float D1 = s_data[i1]; - float D2 = s_data[i2]; - float D3 = s_data[i3]; - - float T; - T = D0; - D0 = D0 + D2; - D2 = T - D2; - T = D1; - D1 = D1 + D3; - D3 = T - D3; - T = D0; - s_data[i0] = D0 + D1; - s_data[i1] = T - D1; - T = D2; - s_data[i2] = D2 + D3; - s_data[i3] = T - D3; - } + for (int pos = threadIdx.x; pos < N; pos += blockDim.x) { + s_data[pos] = d_Src[pos]; + } + int stride = 1; + // Do single radix-2 stage for odd power of two + if (log2N & 1) { cg::sync(cta); - for (int pos = threadIdx.x; pos < N; pos += blockDim.x) - { - d_Dst[pos] = s_data[pos]; + for (int pos = threadIdx.x; pos < N / 2; pos += blockDim.x) { + int i0 = pos << 1; + int i1 = i0 + 1; + + float D0 = s_data[i0]; + float D1 = s_data[i1]; + s_data[i0] = D0 + D1; + s_data[i1] = D0 - D1; } -} + stride <<= 1; + } -//////////////////////////////////////////////////////////////////////////////// -// Single in-global memory radix-4 Fast Walsh Transform pass -// (for strides exceeding elementary vector size) -//////////////////////////////////////////////////////////////////////////////// -__global__ void fwtBatch2Kernel( - float *d_Output, - float *d_Input, - int stride -) -{ - const int pos = blockIdx.x * blockDim.x + threadIdx.x; - const int N = blockDim.x * gridDim.x * 4; - - float *d_Src = d_Input + blockIdx.y * N; - float *d_Dst = d_Output + blockIdx.y * N; + // Main radix-4 stages + const int pos = threadIdx.x; + for (; stride <= N >> 2; stride <<= 2) { int lo = pos & (stride - 1); int i0 = ((pos - lo) << 2) + lo; int i1 = i0 + stride; int i2 = i1 + stride; int i3 = i2 + stride; - float D0 = d_Src[i0]; - float D1 = d_Src[i1]; - float D2 = d_Src[i2]; - float D3 = d_Src[i3]; + cg::sync(cta); + float D0 = s_data[i0]; + float D1 = s_data[i1]; + float D2 = s_data[i2]; + float D3 = s_data[i3]; float T; T = D0; - D0 = D0 + D2; - D2 = T - D2; + D0 = D0 + D2; + D2 = T - D2; T = D1; - D1 = D1 + D3; - D3 = T - D3; + D1 = D1 + D3; + D3 = T - D3; T = D0; - d_Dst[i0] = D0 + D1; - d_Dst[i1] = T - D1; + s_data[i0] = D0 + D1; + s_data[i1] = T - D1; T = D2; - d_Dst[i2] = D2 + D3; - d_Dst[i3] = T - D3; + s_data[i2] = D2 + D3; + s_data[i3] = T - D3; + } + + cg::sync(cta); + + for (int pos = threadIdx.x; pos < N; pos += blockDim.x) { + d_Dst[pos] = s_data[pos]; + } } //////////////////////////////////////////////////////////////////////////////// -// Put everything together: batched Fast Walsh Transform CPU front-end +// Single in-global memory radix-4 Fast Walsh Transform pass +// (for strides exceeding elementary vector size) //////////////////////////////////////////////////////////////////////////////// -void fwtBatchGPU(float *d_Data, int batchSize, int log2N) -{ - int nMixedRadixPasses = log2N > ELEMENTARY_LOG2SIZE ? ELEMENTARY_LOG2SIZE - (log2N - ELEMENTARY_LOG2SIZE) % 2 : log2N; - int N = 1 << nMixedRadixPasses; - int curBatchSize = batchSize << (log2N - nMixedRadixPasses); - - // (N + 3) / 4 to handle the case of N == 2 - fwtBatch1Kernel<<>>( - d_Data, - d_Data, - nMixedRadixPasses - ); - - const int THREAD_N = 256; - dim3 grid((1 << log2N) / (4 * THREAD_N), batchSize, 1); - - for (int logSize = nMixedRadixPasses + 2; logSize <= log2N; logSize += 2) - { - fwtBatch2Kernel<<>>(d_Data, d_Data, (1 << logSize) / 4); - } +__global__ void fwtBatch2Kernel(float* d_Output, float* d_Input, int stride) { + const int pos = blockIdx.x * blockDim.x + threadIdx.x; + const int N = blockDim.x * gridDim.x * 4; + + float* d_Src = d_Input + blockIdx.y * N; + float* d_Dst = d_Output + blockIdx.y * N; + + int lo = pos & (stride - 1); + int i0 = ((pos - lo) << 2) + lo; + int i1 = i0 + stride; + int i2 = i1 + stride; + int i3 = i2 + stride; + + float D0 = d_Src[i0]; + float D1 = d_Src[i1]; + float D2 = d_Src[i2]; + float D3 = d_Src[i3]; + + float T; + T = D0; + D0 = D0 + D2; + D2 = T - D2; + T = D1; + D1 = D1 + D3; + D3 = T - D3; + T = D0; + d_Dst[i0] = D0 + D1; + d_Dst[i1] = T - D1; + T = D2; + d_Dst[i2] = D2 + D3; + d_Dst[i3] = T - D3; +} +//////////////////////////////////////////////////////////////////////////////// +// Put everything together: batched Fast Walsh Transform CPU front-end +//////////////////////////////////////////////////////////////////////////////// +void fwtBatchGPU(float* d_Data, int batchSize, int log2N) { + int nMixedRadixPasses = log2N > ELEMENTARY_LOG2SIZE + ? ELEMENTARY_LOG2SIZE - (log2N - ELEMENTARY_LOG2SIZE) % 2 + : log2N; + int N = 1 << nMixedRadixPasses; + int curBatchSize = batchSize << (log2N - nMixedRadixPasses); + + // (N + 3) / 4 to handle the case of N == 2 + fwtBatch1Kernel<<>>( + d_Data, d_Data, nMixedRadixPasses); + + const int THREAD_N = 256; + dim3 grid((1 << log2N) / (4 * THREAD_N), batchSize, 1); + + for (int logSize = nMixedRadixPasses + 2; logSize <= log2N; logSize += 2) { + fwtBatch2Kernel<<>>(d_Data, d_Data, (1 << logSize) / 4); + } } diff --git a/pytorch/structure/hadamard_cuda/setup.py b/pytorch/structure/hadamard_cuda/setup.py deleted file mode 100644 index f6cafda..0000000 --- a/pytorch/structure/hadamard_cuda/setup.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch.cuda -from setuptools import setup -from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension -from torch.utils.cpp_extension import CUDA_HOME - -ext_modules = [] - -if torch.cuda.is_available() and CUDA_HOME is not None: - extension = CUDAExtension( - 'hadamard_cuda', [ - 'hadamard_cuda.cpp', - 'hadamard_cuda_kernel.cu' - ], - extra_compile_args={'cxx': ['-g'], - 'nvcc': ['-O2']}) - ext_modules.append(extension) - -setup( - name='hadamard_cuda', - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension}) diff --git a/pytorch/structure/krylov.py b/pytorch/structure/krylov.py index fff34e8..70088da 100644 --- a/pytorch/structure/krylov.py +++ b/pytorch/structure/krylov.py @@ -1,4 +1,20 @@ -'''Functions to multiply by an LDR matrix with subdiagonal and tridiagonal +# Copyright 2018 HazyResearch +# https://github.com/HazyResearch/structured-nets +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Functions to multiply by an LDR matrix with subdiagonal and tridiagonal operator matrices. We implement the fast multiplication for the subdiagonal case. @@ -7,116 +23,16 @@ For tridiagonal case, we implement the slow multiplication algorithm: construct the Krylov matrix then call regular matrix multiply. -''' +""" import functools -import numpy as np +from math import ceil, log2 + +import diag_mult_cuda import torch from torch.nn import functional as F -from .scratch.krylovslow import krylov_construct -from .complex_utils import complex_mult, conjugate - -try: - import diag_mult_cuda - # import torch.utils.cpp_extension - # diag_mult_cuda = torch.utils.cpp_extension.load( - # name='diag_mult_cuda', - # sources=[ - # 'diag_mult_cuda/diag_mult_cuda.cpp', - # 'diag_mult_cuda/diag_mult_cuda_kernel.cu', - # ], - # extra_cuda_cflags=['-O2'], - # verbose=False - # ) -except (ImportError, RuntimeError) as e: - print("CUDA version of slow Krylov multiply isn't installed.") - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - -##### Fast multiplication for the subdiagonal case - -def poly_mult_sum_benchmark(p, q): - """Multiply and sum two sets of polynomials. - Parameters: - p: (batch_size, n1, n2) - q: (rank, n1, n2) - Output: - o: (batch_size, rank, 2 * n2 - 1) - """ - print(p.shape[2]) - - import time - - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(100): - y = F.conv1d(p, q.flip(q.dim() - 1), padding=p.shape[-1] -1) - g = torch.autograd.grad(y.sum(), (p, q), retain_graph=True) - torch.cuda.synchronize() - end = time.perf_counter() - print(f'Elapsed time conv1d: {end - start}s.') - - batch_size, rank = p.shape[0], q.shape[0] - n1, n2 = p.shape[1], p.shape[2] - start = time.perf_counter() - for _ in range(100): - S = torch.cat((torch.cat((q, p)), - torch.zeros((rank + batch_size, p.shape[1], p.shape[2]), dtype=q.dtype, device=q.device)), dim=-1) - S_f = torch.rfft(S, 1) - S0_10_f, S1_01_f = S_f[:rank], S_f[rank:rank+batch_size] - prod = (S1_01_f[:, np.newaxis, ..., np.newaxis] * S0_10_f[np.newaxis, ..., np.newaxis, :]).sum(dim=2) - T_00_f_sum = torch.stack((prod[..., 0, 0] - prod[..., 1, 1], prod[..., 0, 1] + prod[..., 1, 0]), dim=-1) - T_00_sum = torch.irfft(T_00_f_sum, 1, signal_sizes=(2 * n2, ))[..., :-1] - g = torch.autograd.grad(T_00_sum.sum(), (p, q), retain_graph=True) - torch.cuda.synchronize() - end = time.perf_counter() - print(f'Elapsed time FFT: {end - start}s.\n') - - return F.conv1d(p, q.flip(q.dim() - 1), padding=p.shape[-1] - 1) - - -def poly_mult_sum_backward_benchmark(grad, q): - """Backward pass of multiplying and summing two sets of polynomials. - Parameters: - grad: (batch_size, rank, 2 * n2 - 1) - q: (rank, n1, n2) - Output: - dp: (batch_size, n1, n2) - """ - print(q.shape[2]) - - import time - - torch.cuda.synchronize() - start = time.perf_counter() - for _ in range(100): - dp = F.conv_transpose1d(grad, q.flip(2), padding=q.shape[-1] - 1) - g = torch.autograd.grad(dp.sum(), (grad, q), retain_graph=True) - torch.cuda.synchronize() - end = time.perf_counter() - print(f'Elapsed time conv1d: {end - start}s.') - - batch_size, rank = grad.shape[0], q.shape[0] - n1, n2 = q.shape[1], q.shape[2] - start = time.perf_counter() - for _ in range(100): - dT_00_sum = torch.cat((grad, torch.zeros((batch_size, rank, 1), dtype=grad.dtype, device=grad.device)), dim=-1) - dT_00_sum_f = torch.rfft(dT_00_sum, 1) - S0_10_f = torch.rfft(torch.cat((q, torch.zeros_like(q)), dim=-1), 1) - # dS1_01_f = complex_mult(conjugate(S0_10_f), dT_00_sum_f[:, :, np.newaxis]).sum(dim=1) - # Manually doing complex multiply - prod = (S0_10_f[..., np.newaxis] * dT_00_sum_f[:, :, np.newaxis, :, np.newaxis, :]).sum(dim=1) - dS1_01_f = torch.stack((prod[..., 0, 0] + prod[..., 1, 1], prod[..., 0, 1] - prod[..., 1, 0]), dim=-1) - dp = torch.irfft(dS1_01_f, 1, signal_sizes=(2 * n2, ))[:, :, :n2] - g = torch.autograd.grad(dp.sum(), (grad, q), retain_graph=True) - torch.cuda.synchronize() - end = time.perf_counter() - print(f'Elapsed time FFT: {end - start}s.\n') - - return F.conv_transpose1d(grad, q.flip(2), padding=q.shape[-1] - 1) - def krylov_transpose_multiply_conv(subdiag, v, u): """Multiply Krylov(A, v_i)^T @ u when A is zero except on the subdiagonal. @@ -131,47 +47,51 @@ def krylov_transpose_multiply_conv(subdiag, v, u): """ batch_size, n = u.shape rank, n_ = v.shape - assert n == n_, 'u and v must have the same last dimension' - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + assert n == n_, "u and v must have the same last dimension" + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" result = torch.zeros((batch_size, rank, n), dtype=u.dtype, device=u.device) T_00_sum = u @ v.t() result[:, :, 0] += T_00_sum - T_01 = u[..., np.newaxis] - T_10 = v[..., np.newaxis] + T_01 = u[..., None] + T_10 = v[..., None] T_11 = torch.ones(n, device=T_00_sum.device) for d in range(m)[::-1]: n1, n2 = 1 << d, 1 << (m - d - 1) - S_00_sum, S_01, S_10, S_11 = T_00_sum, T_01, T_10, T_11 - S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1)::(2 * n2), np.newaxis] + _, S_01, S_10, S_11 = T_00_sum, T_01, T_10, T_11 + S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1) :: (2 * n2), None] # polynomial multiplication - # T_00_sum = poly_mult_sum_benchmark(S_01[:, 1::2], S0_10_mult_subdiag) if n2 <= 128: # Pick between 2 implementations based on polynomial degree n2 - T_00_sum = F.conv1d(S_01[:, 1::2], S0_10_mult_subdiag.flip(2), padding=n2 - 1) + T_00_sum = F.conv1d( + S_01[:, 1::2], S0_10_mult_subdiag.flip(2), padding=n2 - 1 + ) else: - S = torch.cat((torch.cat((S0_10_mult_subdiag, S_01[:, 1::2])), - torch.zeros((rank + batch_size, n1, n2), dtype=S_10.dtype, device=S_10.device)), dim=-1) - S_f = torch.rfft(S, 1) - S0_10_f, S1_01_f = S_f[:rank], S_f[rank:rank+batch_size] - # Different ways to compute the same expression, for speed vs readability - # Option 1: call complex_mult, slowest - # T_00_f_sum = complex_mult(S1_01_f[:, np.newaxis], S0_10_f[np.newaxis]).sum(dim=2) - # Option 2: multiply and sum - # prod = (S1_01_f[:, np.newaxis, ..., np.newaxis] * S0_10_f[np.newaxis, ..., np.newaxis, :]).sum(dim=2) - # Option 3: einsum - prod = torch.einsum('bnmo,rnmp->brmop', S1_01_f, S0_10_f) - # Option 4: manually doing permute and reshape and bmm, only 3% faster than einsum. - # temp1 = S1_01_f.permute(2, 0, 3, 1).reshape((-1, batch_size * 2, n1)) - # temp2 = S0_10_f.permute(2, 1, 0, 3).reshape((-1, n1, rank * 2)) - # prod = (temp1 @ temp2).reshape((-1, batch_size, 2, rank, 2)).permute(1, 3, 0, 2, 4) - T_00_f_sum = torch.stack((prod[..., 0, 0] - prod[..., 1, 1], prod[..., 0, 1] + prod[..., 1, 0]), dim=-1) - T_00_sum = torch.irfft(T_00_f_sum, 1, signal_sizes=(2 * n2, ))[..., :-1] + S = torch.cat( + ( + torch.cat((S0_10_mult_subdiag, S_01[:, 1::2])), + torch.zeros( + (rank + batch_size, n1, n2), + dtype=S_10.dtype, + device=S_10.device, + ), + ), + dim=-1, + ) + # S = torch.cat((S0_10_mult_subdiag, S_01[:, 1::2])) + S_f = torch.fft.rfft(S) + S0_10_f, S1_01_f = S_f[:rank], S_f[rank : rank + batch_size] + T_00_f_sum = (S1_01_f[:, None] * S0_10_f[None]).sum(dim=2) + T_00_sum = torch.fft.irfft(T_00_f_sum, n=2 * n2)[..., :-1] # polynomial additions - result[:, :, 1:2*n2] += T_00_sum - S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1)::(2 * n2)] - T_01 = torch.cat((S_01[:, ::2], S_01[:, 1::2] * S0_11_mult_subdiag[:, np.newaxis]), dim=-1) - T_10 = torch.cat((S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, np.newaxis]), dim=-1) + result[:, :, 1 : 2 * n2] += T_00_sum + S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1) :: (2 * n2)] + T_01 = torch.cat( + (S_01[:, ::2], S_01[:, 1::2] * S0_11_mult_subdiag[:, None]), dim=-1 + ) + T_10 = torch.cat( + (S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, None]), dim=-1 + ) T_11 = S0_11_mult_subdiag * S_11[1::2] return result @@ -187,51 +107,46 @@ def krylov_transpose_multiply(subdiag, v, u): """ batch_size, n = u.shape rank, n_ = v.shape - assert n == n_, 'u and v must have the same last dimension' - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + assert n == n_, "u and v must have the same last dimension" + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" result = torch.zeros((batch_size, rank, n), dtype=u.dtype, device=u.device) - # T_00_sum = (u[:, np.newaxis, ..., np.newaxis] * v[np.newaxis, ..., np.newaxis]).sum(dim=2) T_00_sum = u @ v.t() result[:, :, 0] = T_00_sum - T_01 = u[..., np.newaxis] - T_10 = v[..., np.newaxis] + T_01 = u[..., None] + T_10 = v[..., None] T_11 = torch.ones(n, device=T_00_sum.device) for d in range(m)[::-1]: n1, n2 = 1 << d, 1 << (m - d - 1) S_01, S_10, S_11 = T_01, T_10, T_11 - # S0_10 = torch.cat((S_10[:, ::2], torch.zeros_like(S_10[:, ::2])), dim=-1) - # S1_01 = torch.cat((S_01[:, 1::2], torch.zeros_like(S_01[:, 1::2])), dim=-1) - # S = torch.cat((S0_10, S1_01)) - S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1)::(2 * n2), np.newaxis] - S = torch.cat((torch.cat((S0_10_mult_subdiag, S_01[:, 1::2])), - torch.zeros((rank + batch_size, n1, n2), dtype=S_10.dtype, device=S_10.device)), dim=-1) + S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1) :: (2 * n2), None] + S = torch.cat( + ( + torch.cat((S0_10_mult_subdiag, S_01[:, 1::2])), + torch.zeros( + (rank + batch_size, n1, n2), dtype=S_10.dtype, device=S_10.device + ), + ), + dim=-1, + ) + + S_f = torch.fft.rfft(S) # polynomial multiplications - S_f = torch.rfft(S, 1) - S0_10_f, S1_01_f = S_f[:rank], S_f[rank:rank+batch_size] - # Different ways to compute the same expression, for speed vs readability - # Option 1: call complex_mult, slowest - # T_00_f_sum = complex_mult(S1_01_f[:, np.newaxis], S0_10_f[np.newaxis]).sum(dim=2) - # Option 2: multiply and sum - # Manually doing complex multiply, somehow this is faster than Cupy's complex mult - # prod = (S1_01_f[:, np.newaxis, ..., np.newaxis] * S0_10_f[np.newaxis, ..., np.newaxis, :]).sum(dim=2) - # Option 3: einsum - prod = torch.einsum('bnmo,rnmp->brmop', S1_01_f, S0_10_f) - # Option 4: manually doing permute and reshape and bmm, only 3% faster than einsum. - # temp1 = S1_01_f.permute(2, 0, 3, 1).reshape((-1, batch_size * 2, n1)) - # temp2 = S0_10_f.permute(2, 1, 0, 3).reshape((-1, n1, rank * 2)) - # prod = (temp1 @ temp2).reshape((-1, batch_size, 2, rank, 2)).permute(1, 3, 0, 2, 4) - # prod = (S1_01_f[:, np.newaxis, ..., np.newaxis] * S0_10_f[np.newaxis, ..., np.newaxis, :]).sum(dim=2) - T_00_f_sum = torch.stack((prod[..., 0, 0] - prod[..., 1, 1], prod[..., 0, 1] + prod[..., 1, 0]), dim=-1) - T_00_sum = torch.irfft(T_00_f_sum, 1, signal_sizes=(2 * n2, ))[..., :-1] + S0_10_f, S1_01_f = S_f[:rank], S_f[rank : rank + batch_size] + T_00_f_sum = (S1_01_f[:, None] * S0_10_f[None]).sum(dim=2) + T_00_sum = torch.fft.irfft(T_00_f_sum, n=2 * n2)[..., :-1] # polynomial additions - result[:, :, 1:2*n2] += T_00_sum - S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1)::(2 * n2)] - T_01 = torch.cat((S_01[:, ::2], S_01[:, 1::2] * S0_11_mult_subdiag[:, np.newaxis]), dim=-1) - T_10 = torch.cat((S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, np.newaxis]), dim=-1) + result[:, :, 1 : 2 * n2] += T_00_sum + S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1) :: (2 * n2)] + T_01 = torch.cat( + (S_01[:, ::2], S_01[:, 1::2] * S0_11_mult_subdiag[:, None]), dim=-1 + ) + T_10 = torch.cat( + (S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, None]), dim=-1 + ) T_11 = S0_11_mult_subdiag * S_11[1::2] return result @@ -249,50 +164,40 @@ def KTu_traceable(subdiag, v, u): """ batch_size, n = u.shape rank, n_ = v.shape - # assert n == n_, 'u and v must have the same last dimension' - m = int(np.log2(n)) - # assert n == 1 << m, 'n must be a power of 2' + m = int(log2(n)) - # T_00_sum = (u[:, np.newaxis, ..., np.newaxis] * v[np.newaxis, ..., np.newaxis]).sum(dim=2) T_00_sum = u @ v.t() result = T_00_sum.unsqueeze(-1) - T_01 = u[..., np.newaxis] - T_10 = v[..., np.newaxis] + T_01 = u[..., None] + T_10 = v[..., None] T_11 = torch.ones(n, device=T_00_sum.device) for d in range(m)[::-1]: - n1, n2 = 1 << d, 1 << (m - d - 1) + n2 = 1 << (m - d - 1) S_01, S_10, S_11 = T_01, T_10, T_11 - # S0_10 = torch.cat((S_10[:, ::2], torch.zeros_like(S_10[:, ::2])), dim=-1) - # S1_01 = torch.cat((S_01[:, 1::2], torch.zeros_like(S_01[:, 1::2])), dim=-1) - # S = torch.cat((S0_10, S1_01)) - S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1)::(2 * n2), np.newaxis] - S = torch.cat((torch.cat((S0_10_mult_subdiag, S_01[:, 1::2])), - torch.zeros((rank + batch_size, n1, n2), dtype=S_10.dtype, device=S_10.device)), dim=-1) - + S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1) :: (2 * n2), None] + S = torch.cat((S0_10_mult_subdiag, S_01[:, 1::2])) # polynomial multiplications - S_f = torch.rfft(S, 1) - S0_10_f, S1_01_f = S_f[:rank], S_f[rank:rank+batch_size] - # Different ways to compute the same expression, for speed vs readability - # Option 1: call complex_mult, slowest - # T_00_f_sum = complex_mult(S1_01_f[:, np.newaxis], S0_10_f[np.newaxis]).sum(dim=2) - # Option 2: multiply and sum - # Manually doing complex multiply, somehow this is faster than Cupy's complex mult - # prod = (S1_01_f[:, np.newaxis, ..., np.newaxis] * S0_10_f[np.newaxis, ..., np.newaxis, :]).sum(dim=2) - # Option 3: einsum - prod = torch.einsum('bnmo,rnmp->brmop', S1_01_f, S0_10_f) - # Option 4: manually doing permute and reshape and bmm, only 3% faster than einsum. - # temp1 = S1_01_f.permute(2, 0, 3, 1).reshape((-1, batch_size * 2, n1)) - # temp2 = S0_10_f.permute(2, 1, 0, 3).reshape((-1, n1, rank * 2)) - # prod = (temp1 @ temp2).reshape((-1, batch_size, 2, rank, 2)).permute(1, 3, 0, 2, 4) - # prod = (S1_01_f[:, np.newaxis, ..., np.newaxis] * S0_10_f[np.newaxis, ..., np.newaxis, :]).sum(dim=2) - T_00_f_sum = torch.stack((prod[..., 0, 0] - prod[..., 1, 1], prod[..., 0, 1] + prod[..., 1, 0]), dim=-1) - T_00_sum = torch.irfft(T_00_f_sum, 1, signal_sizes=(2 * n2, ))[..., :-1] + S_f = torch.fft.rfft(S) + S0_10_f, S1_01_f = S_f[:rank], S_f[rank : rank + batch_size] + T_00_f_sum = (S1_01_f[:, None] * S0_10_f[None]).sum(dim=2) + T_00_sum = torch.fft.irfft(T_00_f_sum, n=2 * n2)[..., :-1] # polynomial additions - result = torch.cat((result[:, :, :1], result[:, :, 1:] + T_00_sum[:, :, :n2 - 1], T_00_sum[:, :, n2 - 1:]), dim=-1) - S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1)::(2 * n2)] - T_01 = torch.cat((S_01[:, ::2], S_01[:, 1::2] * S0_11_mult_subdiag[:, np.newaxis]), dim=-1) - T_10 = torch.cat((S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, np.newaxis]), dim=-1) + result = torch.cat( + ( + result[:, :, :1], + result[:, :, 1:] + T_00_sum[:, :, : n2 - 1], + T_00_sum[:, :, n2 - 1 :], + ), + dim=-1, + ) + S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1) :: (2 * n2)] + T_01 = torch.cat( + (S_01[:, ::2], S_01[:, 1::2] * S0_11_mult_subdiag[:, None]), dim=-1 + ) + T_10 = torch.cat( + (S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, None]), dim=-1 + ) T_11 = S0_11_mult_subdiag * S_11[1::2] return result @@ -310,42 +215,44 @@ def krylov_transpose_multiply_old(subdiag, v, u): """ batch_size, n = u.shape rank, n_ = v.shape - assert n == n_, 'u and v must have the same last dimension' - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + assert n == n_, "u and v must have the same last dimension" + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" - T_00 = u[:, np.newaxis, ..., np.newaxis] * v[np.newaxis, ..., np.newaxis] - T_01 = u[..., np.newaxis] - T_10 = v[..., np.newaxis] + T_00 = u[:, None, ..., None] * v[None, ..., None] + T_01 = u[..., None] + T_10 = v[..., None] T_11 = torch.ones((n, 1), device=T_00.device) for d in range(m)[::-1]: - n1, n2 = 1 << d, 1 << (m - d - 1) + _, n2 = 1 << d, 1 << (m - d - 1) S_00, S_01, S_10, S_11 = T_00, T_01, T_10, T_11 S0_10 = torch.cat((S_10[:, ::2], torch.zeros_like(S_10[:, ::2])), dim=-1) S1_01 = torch.cat((S_01[:, 1::2], torch.zeros_like(S_01[:, 1::2])), dim=-1) S0_11 = torch.cat((S_11[::2], torch.zeros_like(S_11[::2])), dim=-1) S1_11 = torch.cat((S_11[1::2], torch.zeros_like(S_11[1::2])), dim=-1) - S = torch.cat((S0_10, S0_11[np.newaxis], S1_01, S1_11[np.newaxis])) + S = torch.cat((S0_10, S0_11[None], S1_01, S1_11[None])) # polynomial multiplications - S_f = torch.rfft(S, 1) - # S0_10_f, S0_11_f, S1_01_f, S1_11_f = S_f[:rank], S_f[rank], S_f[rank+1:rank+1+batch_size], S_f[-1] - # T_00_f = complex_mult(S1_01_f[:, np.newaxis], S0_10_f[np.newaxis]) - # T_01_f = complex_mult(S1_01_f, S0_11_f) - # T_10_f = complex_mult(S1_11_f, S0_10_f) - # T_11_f = complex_mult(S1_11_f, S0_11_f) - - # T_f = torch.cat((torch.cat((T_00_f, T_01_f[:, np.newaxis]), dim=1), - # torch.cat((T_10_f[np.newaxis], T_11_f[np.newaxis, np.newaxis]), dim=1))) - - # I didn't realize you could just batch all 4 multiplications like this - T_f = complex_mult(S_f[rank+1:, np.newaxis], S_f[:rank+1]) - - T = torch.irfft(T_f, 1, signal_sizes=(2 * n2, )) * subdiag[(n2 - 1)::(2 * n2), np.newaxis] - T_00, T_01, T_10, T_11 = T[:batch_size, :rank], T[:batch_size, -1], T[-1, :rank], T[-1, -1] + S_f = torch.fft.rfft(S, 1) + # T_f = complex_mult(S_f[rank + 1 :, None], S_f[: rank + 1]) + T_f = S_f[rank + 1 :, None] * S_f[: rank + 1] + + T = torch.fft.irfft(T_f, n=2 * n2) * subdiag[(n2 - 1) :: (2 * n2), None] + T_00, T_01, T_10, T_11 = ( + T[:batch_size, :rank], + T[:batch_size, -1], + T[-1, :rank], + T[-1, -1], + ) # polynomial additions - T_00 = torch.cat((T_00[:, :, :, :n2], T_00[:, :, :, n2:] + S_00[:, :, ::2] + S_00[:, :, 1::2]), dim=-1) + T_00 = torch.cat( + ( + T_00[:, :, :, :n2], + T_00[:, :, :, n2:] + S_00[:, :, ::2] + S_00[:, :, 1::2], + ), + dim=-1, + ) T_01 = torch.cat((T_01[:, :, :n2], T_01[:, :, n2:] + S_01[:, ::2]), dim=-1) T_10 = torch.cat((T_10[:, :, :n2], T_10[:, :, n2:] + S_10[:, 1::2]), dim=-1) @@ -353,7 +260,7 @@ def krylov_transpose_multiply_old(subdiag, v, u): def krylov_multiply_conv(subdiag, v, w): - """Multiply \sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal. + """Multiply sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal. Since K @ w can be computed by autodiffing K^T @ u, the algorithm is just hand-differentiating the code of @krylov_transpose_multiply. Use either Pytorch's conv1d or FFT for polynomial multiplication, depending @@ -367,24 +274,26 @@ def krylov_multiply_conv(subdiag, v, w): """ batch_size, rank, n = w.shape rank_, n_ = v.shape - assert n == n_, 'w and v must have the same last dimension' - assert rank == rank_, 'w and v must have the same rank' - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + assert n == n_, "w and v must have the same last dimension" + assert rank == rank_, "w and v must have the same rank" + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" # Forward pass. Since K @ w can be computed by autodiffing K^T @ u, we # carry out the forward pass K^T @ u for u = 0 here to save the # intermediate values. This code is exactly the same as the function # @krylov_transpose_multiply, specialized to the case where u = 0. save_for_backward = [None] * m - T_10 = v[..., np.newaxis] + T_10 = v[..., None] T_11 = torch.ones((n), device=T_10.device) for d in range(m)[::-1]: n1, n2 = 1 << d, 1 << (m - d - 1) S_10, S_11 = T_10, T_11 - S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1)::(2 * n2), np.newaxis] - T_10 = torch.cat((S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, np.newaxis]), dim=-1) - S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1)::(2 * n2)] + S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1) :: (2 * n2), None] + T_10 = torch.cat( + (S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, None]), dim=-1 + ) + S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1) :: (2 * n2)] save_for_backward[d] = S0_10_mult_subdiag, S0_11_mult_subdiag T_11 = S0_11_mult_subdiag * S_11[1::2] @@ -396,20 +305,27 @@ def krylov_multiply_conv(subdiag, v, w): S0_10_mult_subdiag, S0_11_mult_subdiag = save_for_backward[d] dS_01 = torch.empty((batch_size, 2 * n1, n2), device=w.device) dS_01[:, ::2] = dT_01[:, :, :n2] - # dS1_01 = poly_mult_sum_backward_benchmark(w[:, :, 1:2*n2], S0_10_mult_subdiag) if n2 <= 128: - dS1_01 = F.conv_transpose1d(w[:, :, 1:2*n2], S0_10_mult_subdiag.flip(2), padding=n2 - 1) + dS1_01 = F.conv_transpose1d( + w[:, :, 1 : 2 * n2], S0_10_mult_subdiag.flip(2), padding=n2 - 1 + ) else: - dT_00_sum = torch.cat((w[:, :, 1:2*n2], torch.zeros((batch_size, rank, 1), dtype=w.dtype, device=w.device)), dim=-1) - dT_00_sum_f = torch.rfft(dT_00_sum, 1) - S0_10_f = torch.rfft(torch.cat((S0_10_mult_subdiag, torch.zeros_like(S0_10_mult_subdiag)), dim=-1), 1) - # dS1_01_f = complex_mult(conjugate(S0_10_f), dT_00_sum_f[:, :, np.newaxis]).sum(dim=1) - # Manually doing complex multiply - # prod = (S0_10_f[..., np.newaxis] * dT_00_sum_f[:, :, np.newaxis, :, np.newaxis, :]).sum(dim=1) - prod = torch.einsum('rnmo,brmp->bnmop', S0_10_f, dT_00_sum_f) - dS1_01_f = torch.stack((prod[..., 0, 0] + prod[..., 1, 1], prod[..., 0, 1] - prod[..., 1, 0]), dim=-1) - dS1_01 = torch.irfft(dS1_01_f, 1, signal_sizes=(2 * n2, ))[:, :, :n2] - dS_01[:, 1::2] = dT_01[:, :, n2:] * S0_11_mult_subdiag[:, np.newaxis] + dS1_01 + dT_00_sum = torch.cat( + ( + w[:, :, 1 : 2 * n2], + torch.zeros((batch_size, rank, 1), dtype=w.dtype, device=w.device), + ), + dim=-1, + ) + dT_00_sum_f = torch.fft.rfft(dT_00_sum) + S0_10_f = torch.fft.rfft( + torch.cat( + (S0_10_mult_subdiag, torch.zeros_like(S0_10_mult_subdiag)), dim=-1 + ), + ) + dS1_01_f = (S0_10_f.conj() * dT_00_sum_f[:, :, None]).sum(dim=1) + dS1_01 = torch.fft.irfft(dS1_01_f, n=2 * n2)[:, :, :n2] + dS_01[:, 1::2] = dT_01[:, :, n2:] * S0_11_mult_subdiag[:, None] + dS1_01 dT_01 = dS_01 @@ -417,8 +333,9 @@ def krylov_multiply_conv(subdiag, v, w): du = w[:, :, 0] @ v + dT_01.squeeze(dim=-1) return du + def krylov_multiply(subdiag, v, w): - """Multiply \sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal. + """Multiply sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal. Since K @ w can be computed by autodiffing K^T @ u, the algorithm is just hand-differentiating the code of @krylov_transpose_multiply. Parameters: @@ -430,24 +347,26 @@ def krylov_multiply(subdiag, v, w): """ batch_size, rank, n = w.shape rank_, n_ = v.shape - assert n == n_, 'w and v must have the same last dimension' - assert rank == rank_, 'w and v must have the same rank' - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + assert n == n_, "w and v must have the same last dimension" + assert rank == rank_, "w and v must have the same rank" + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" # Forward pass. Since K @ w can be computed by autodiffing K^T @ u, we # carry out the forward pass K^T @ u for u = 0 here to save the # intermediate values. This code is exactly the same as the function # @krylov_transpose_multiply, specialized to the case where u = 0. save_for_backward = [None] * m - T_10 = v[..., np.newaxis] + T_10 = v[..., None] T_11 = torch.ones((n), device=T_10.device) for d in range(m)[::-1]: n1, n2 = 1 << d, 1 << (m - d - 1) S_10, S_11 = T_10, T_11 - S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1)::(2 * n2), np.newaxis] - T_10 = torch.cat((S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, np.newaxis]), dim=-1) - S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1)::(2 * n2)] + S0_10_mult_subdiag = S_10[:, ::2] * subdiag[(n2 - 1) :: (2 * n2), None] + T_10 = torch.cat( + (S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, None]), dim=-1 + ) + S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1) :: (2 * n2)] save_for_backward[d] = S0_10_mult_subdiag, S0_11_mult_subdiag T_11 = S0_11_mult_subdiag * S_11[1::2] @@ -459,17 +378,23 @@ def krylov_multiply(subdiag, v, w): S0_10_mult_subdiag, S0_11_mult_subdiag = save_for_backward[d] dS_01 = torch.empty((batch_size, 2 * n1, n2), device=w.device) dS_01[:, ::2] = dT_01[:, :, :n2] - dT_00_sum = torch.cat((w[:, :, 1:2*n2], torch.zeros((batch_size, rank, 1), dtype=w.dtype, device=w.device)), dim=-1) - - dT_00_sum_f = torch.rfft(dT_00_sum, 1) - S0_10_f = torch.rfft(torch.cat((S0_10_mult_subdiag, torch.zeros_like(S0_10_mult_subdiag)), dim=-1), 1) - # dS1_01_f = complex_mult(conjugate(S0_10_f), dT_00_sum_f[:, :, np.newaxis]).sum(dim=1) - # Manually doing complex multiply - # prod = (S0_10_f[..., np.newaxis] * dT_00_sum_f[:, :, np.newaxis, :, np.newaxis, :]).sum(dim=1) - prod = torch.einsum('rnmo,brmp->bnmop', S0_10_f, dT_00_sum_f) - dS1_01_f = torch.stack((prod[..., 0, 0] + prod[..., 1, 1], prod[..., 0, 1] - prod[..., 1, 0]), dim=-1) - dS1_01 = torch.irfft(dS1_01_f, 1, signal_sizes=(2 * n2, ))[:, :, :n2] - dS_01[:, 1::2] = dT_01[:, :, n2:] * S0_11_mult_subdiag[:, np.newaxis] + dS1_01 + dT_00_sum = torch.cat( + ( + w[:, :, 1 : 2 * n2], + torch.zeros((batch_size, rank, 1), dtype=w.dtype, device=w.device), + ), + dim=-1, + ) + + dT_00_sum_f = torch.fft.rfft(dT_00_sum) + S0_10_f = torch.fft.rfft( + torch.cat( + (S0_10_mult_subdiag, torch.zeros_like(S0_10_mult_subdiag)), dim=-1 + ) + ) + dS1_01_f = (S0_10_f.conj() * dT_00_sum_f[:, :, None]).sum(dim=1) + dS1_01 = torch.fft.irfft(dS1_01_f, n=2 * n2)[:, :, :n2] + dS_01[:, 1::2] = dT_01[:, :, n2:] * S0_11_mult_subdiag[:, None] + dS1_01 dT_01 = dS_01 @@ -477,8 +402,9 @@ def krylov_multiply(subdiag, v, w): du = w[:, :, 0] @ v + dT_01.squeeze(dim=-1) return du + def krylov_multiply_by_autodiff(subdiag, v, w): - """Multiply \sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal, using Pytorch's autodiff. + """Multiply sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal, using Pytorch's autodiff. Parameters: subdiag: Tensor of shape (n - 1, ) v: Tensor of shape (rank, n) @@ -488,14 +414,14 @@ def krylov_multiply_by_autodiff(subdiag, v, w): """ batch_size, rank, n = w.shape rank_, n_ = v.shape - assert n == n_, 'w and v must have the same last dimension' - assert rank == rank_, 'w and v must have the same rank' - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + assert n == n_, "w and v must have the same last dimension" + assert rank == rank_, "w and v must have the same rank" + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" u = torch.zeros((batch_size, n), dtype=v.dtype, device=v.device, requires_grad=True) prod = krylov_transpose_multiply(subdiag, v, u) - result, = torch.autograd.grad(prod, u, grad_outputs=w, create_graph=True) + (result,) = torch.autograd.grad(prod, u, grad_outputs=w, create_graph=True) return result @@ -513,34 +439,30 @@ def krylov_multiply_forward_old_(subdiag, v): necessary for the backward pass K @ w. """ rank, n = v.shape - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" save_for_backward = [None] * m - T_10 = v[..., np.newaxis] + T_10 = v[..., None] T_11 = torch.ones((n, 1), device=T_10.device) for d in range(m)[::-1]: - n1, n2 = 1 << d, 1 << (m - d - 1) + _, n2 = 1 << d, 1 << (m - d - 1) S_10, S_11 = T_10, T_11 S0_10 = torch.cat((S_10[:, ::2], torch.zeros_like(S_10[:, ::2])), dim=-1) S0_11 = torch.cat((S_11[::2], torch.zeros_like(S_11[::2])), dim=-1) S1_11 = torch.cat((S_11[1::2], torch.zeros_like(S_11[1::2])), dim=-1) - S = torch.cat((S0_10, S0_11[np.newaxis], S1_11[np.newaxis])) + S = torch.cat((S0_10, S0_11[None], S1_11[None])) # polynomial multiplications - S_f = torch.rfft(S, 1) - # S0_10_f, S0_11_f, S1_11_f = S_f[:rank], S_f[-2], S_f[-1] - # save_for_backward[d] = (S0_10_f, S0_11_f) - - # T_10_f = complex_mult(S1_11_f, S0_10_f) - # T_11_f = complex_mult(S1_11_f, S0_11_f) - - # T_f = torch.cat((T_10_f, T_11_f[np.newaxis])) - - save_for_backward[d] = S_f[:rank+1] - T_f = complex_mult(S_f[-1], S_f[:rank+1]) - - T = torch.irfft(T_f, 1, signal_sizes=(2 * n2, )) * subdiag[(n2 - 1)::(2 * n2), np.newaxis] + S_f = torch.fft.rfft(S, 1) + save_for_backward[d] = S_f[: rank + 1] + # T_f = complex_mult(S_f[-1], S_f[: rank + 1]) + T_f = S_f[-1] * S_f[: rank + 1] + + T = ( + torch.irfft(T_f, 1, signal_sizes=(2 * n2,)) + * subdiag[(n2 - 1) :: (2 * n2), None] + ) T_10, T_11 = T[:rank], T[-1] # polynomial additions @@ -548,8 +470,9 @@ def krylov_multiply_forward_old_(subdiag, v): return save_for_backward + def krylov_multiply_old(subdiag, v, w): - """Multiply \sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal. + """Multiply sum_i Krylov(A, v_i) @ w_i when A is zero except on the subdiagonal. Since K @ w can be computed by autodiffing K^T @ u, the algorithm is just hand-differentiating the code of @krylov_transpose_multiply. Uses the old algorithm that scales worse when batching. @@ -562,14 +485,16 @@ def krylov_multiply_old(subdiag, v, w): """ batch_size, rank, n = w.shape rank_, n_ = v.shape - assert n == n_, 'w and v must have the same last dimension' - assert rank == rank_, 'w and v must have the same rank' - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + assert n == n_, "w and v must have the same last dimension" + assert rank == rank_, "w and v must have the same rank" + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" save_for_backward = krylov_multiply_forward_old_(subdiag, v) - w = w[:, :, np.newaxis, :] - dT_00, dT_01 = w.flip(w.dim() - 1), torch.zeros((batch_size, 1, n), dtype=w.dtype, device=w.device) + w = w[:, :, None, :] + dT_00, dT_01 = w.flip(w.dim() - 1), torch.zeros( + (batch_size, 1, n), dtype=w.dtype, device=w.device + ) for d in range(m): n1, n2 = 1 << d, 1 << (m - d - 1) @@ -579,27 +504,24 @@ def krylov_multiply_old(subdiag, v, w): dS_01 = torch.empty((batch_size, 2 * n1, n2), device=w.device) dS_01[:, ::2] = dT_01[:, :, n2:] - dT = torch.cat((dT_00, dT_01[:, np.newaxis]), dim=1) - dT = dT * subdiag[(n2 - 1)::(2 * n2), np.newaxis] + dT = torch.cat((dT_00, dT_01[:, None]), dim=1) + dT = dT * subdiag[(n2 - 1) :: (2 * n2), None] - dT_f = torch.rfft(dT, 1) / (2 * n2) - # dT_00_f, dT_01_f = dT_f[:, :rank], dT_f[:, -1] + dT_f = torch.fft.rfft(dT, 1) / (2 * n2) - # S0_10_f, S0_11_f = save_for_backward[d] - # dS1_01_f = complex_mult(conjugate(S0_10_f)[np.newaxis], dT_00_f).sum(dim=1) + complex_mult(conjugate(S0_11_f), dT_01_f) + dS1_01_f = (save_for_backward[d].conj() * dT_f).sum(dim=1) - dS1_01_f = complex_mult(conjugate(save_for_backward[d]), dT_f).sum(dim=1) - - dS1_01 = torch.irfft(dS1_01_f, 1, signal_sizes=(2 * n2, )) * (2 * n2) + dS1_01 = torch.irfft(dS1_01_f, 1, signal_sizes=(2 * n2,)) * (2 * n2) dS_01[:, 1::2] = dS1_01[:, :, :n2] dT_00, dT_01 = dS_00, dS_01 - du = ((dT_00 * v[np.newaxis, :, :, np.newaxis]).sum(dim=1) + dT_01).squeeze(dim=-1) + du = ((dT_00 * v[None, :, :, None]).sum(dim=1) + dT_01).squeeze(dim=-1) return du + def subdiag_mult_conv(subdiag_A, subdiag_B, G, H, x): - """Multiply \sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. + """Multiply sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. Uses the fast algorithm. Use either Pytorch's conv1d or FFT for polynomial multiplication, depending on polynomial degree. This is the fastest implementation. @@ -616,21 +538,47 @@ def subdiag_mult_conv(subdiag_A, subdiag_B, G, H, x): batch_size = x.shape[0] # if not power of 2, round everything up # TODO: this can maybe be handled better. also should benchmark how much speed non-po2 FFT loses - m = int(np.ceil(np.log2(n))) + m = int(ceil(log2(n))) n_extended = 1 << m if n != n_extended: - x = torch.cat((x, torch.zeros(batch_size, n_extended - n, dtype=x.dtype, device=x.device)), dim=-1) - G = torch.cat((G, torch.zeros(rank, n_extended - n, dtype=G.dtype, device=G.device)), dim=-1) - H = torch.cat((H, torch.zeros(rank, n_extended - n, dtype=H.dtype, device=H.device)), dim=-1) - subdiag_A = torch.cat((subdiag_A, torch.zeros(n_extended - n, dtype=subdiag_A.dtype, device=subdiag_A.device))) - subdiag_B = torch.cat((subdiag_B, torch.zeros(n_extended - n, dtype=subdiag_B.dtype, device=subdiag_B.device))) + x = torch.cat( + ( + x, + torch.zeros(batch_size, n_extended - n, dtype=x.dtype, device=x.device), + ), + dim=-1, + ) + G = torch.cat( + (G, torch.zeros(rank, n_extended - n, dtype=G.dtype, device=G.device)), + dim=-1, + ) + H = torch.cat( + (H, torch.zeros(rank, n_extended - n, dtype=H.dtype, device=H.device)), + dim=-1, + ) + subdiag_A = torch.cat( + ( + subdiag_A, + torch.zeros( + n_extended - n, dtype=subdiag_A.dtype, device=subdiag_A.device + ), + ) + ) + subdiag_B = torch.cat( + ( + subdiag_B, + torch.zeros( + n_extended - n, dtype=subdiag_B.dtype, device=subdiag_B.device + ), + ) + ) KT_out = krylov_transpose_multiply_conv(subdiag_B, H, x) K_out = krylov_multiply_conv(subdiag_A, G, KT_out) return K_out[:, :n] if n != n_extended else K_out def subdiag_mult(subdiag_A, subdiag_B, G, H, x): - """Multiply \sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. + """Multiply sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. Uses the fast algorithm. Parameters: subdiag_A: Tensor of shape (n - 1, ) @@ -645,20 +593,48 @@ def subdiag_mult(subdiag_A, subdiag_B, G, H, x): batch_size = x.shape[0] # if not power of 2, round everything up # TODO: this can maybe be handled better. also should benchmark how much speed non-po2 FFT loses - m = int(np.ceil(np.log2(n))) + m = int(ceil(log2(n))) n_extended = 1 << m if n != n_extended: - x = torch.cat((x, torch.zeros(batch_size, n_extended - n, dtype=x.dtype, device=x.device)), dim=-1) - G = torch.cat((G, torch.zeros(rank, n_extended - n, dtype=G.dtype, device=G.device)), dim=-1) - H = torch.cat((H, torch.zeros(rank, n_extended - n, dtype=H.dtype, device=H.device)), dim=-1) - subdiag_A = torch.cat((subdiag_A, torch.zeros(n_extended - n, dtype=subdiag_A.dtype, device=subdiag_A.device))) - subdiag_B = torch.cat((subdiag_B, torch.zeros(n_extended - n, dtype=subdiag_B.dtype, device=subdiag_B.device))) + x = torch.cat( + ( + x, + torch.zeros(batch_size, n_extended - n, dtype=x.dtype, device=x.device), + ), + dim=-1, + ) + G = torch.cat( + (G, torch.zeros(rank, n_extended - n, dtype=G.dtype, device=G.device)), + dim=-1, + ) + H = torch.cat( + (H, torch.zeros(rank, n_extended - n, dtype=H.dtype, device=H.device)), + dim=-1, + ) + subdiag_A = torch.cat( + ( + subdiag_A, + torch.zeros( + n_extended - n, dtype=subdiag_A.dtype, device=subdiag_A.device + ), + ) + ) + subdiag_B = torch.cat( + ( + subdiag_B, + torch.zeros( + n_extended - n, dtype=subdiag_B.dtype, device=subdiag_B.device + ), + ) + ) KT_out = krylov_transpose_multiply(subdiag_B, H, x) K_out = krylov_multiply(subdiag_A, G, KT_out) return K_out[:, :n] if n != n_extended else K_out + ##### Slow multiplication for the subdiagonal case + def Krylov(linear_map, v, m=None): """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{m-1} @ v]. Parameters: @@ -705,7 +681,14 @@ def subdiag_linear_map(subdiag, upper_right_corner=0.0): """ n = subdiag.size(0) + 1 shift_down = torch.arange(-1, n - 1, device=subdiag.device) - subdiag_extended = torch.cat((torch.tensor([upper_right_corner], dtype=subdiag.dtype, device=subdiag.device), subdiag)) + subdiag_extended = torch.cat( + ( + torch.tensor( + [upper_right_corner], dtype=subdiag.dtype, device=subdiag.device + ), + subdiag, + ) + ) # Pytorch 1.0 has torch.roll that should be much faster # return lambda v: subdiag_extended * v.roll(1, dims=-1) return lambda v: subdiag_extended * v[..., shift_down] @@ -729,9 +712,16 @@ def krylov_subdiag_fast(subdiag, v, upper_right_corner=0.0): rank, n = v.shape a = torch.arange(n, dtype=torch.long, device=v.device) b = -a - indices = a[:, np.newaxis] + b[np.newaxis] + indices = a[:, None] + b[None] v_circulant = v[:, indices] - subdiag_extended = torch.cat((torch.tensor([upper_right_corner], dtype=subdiag.dtype, device=subdiag.device), subdiag)) + subdiag_extended = torch.cat( + ( + torch.tensor( + [upper_right_corner], dtype=subdiag.dtype, device=subdiag.device + ), + subdiag, + ) + ) subdiag_circulant = subdiag_extended[indices] subdiag_cumprod = subdiag_circulant.cumprod(dim=1) K = v_circulant @@ -740,7 +730,7 @@ def krylov_subdiag_fast(subdiag, v, upper_right_corner=0.0): def subdiag_mult_slow_old(subdiag_A, subdiag_B, G, H, x): - """Multiply \sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. + """Multiply sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. Uses the explicit Krylov construction with slow (and easy to understand) linear map. Parameters: @@ -755,13 +745,16 @@ def subdiag_mult_slow_old(subdiag_A, subdiag_B, G, H, x): rank, n = G.shape linear_map_A = functools.partial(shift_subdiag, subdiag_A) linear_map_B = functools.partial(shift_subdiag, subdiag_B) - krylovs = [(Krylov(linear_map_A, G[i]), Krylov(linear_map_B, H[i]).t()) for i in range(rank)] + krylovs = [ + (Krylov(linear_map_A, G[i]), Krylov(linear_map_B, H[i]).t()) + for i in range(rank) + ] prods = [K[0] @ (K[1] @ x.t()) for K in krylovs] return sum(prods).t() def subdiag_mult_slow(subdiag_A, subdiag_B, G, H, x, corner_A=0.0, corner_B=0.0): - """Multiply \sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. + """Multiply sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. Uses the explicit Krylov construction with the more careful implementation of linear map. Parameters: subdiag_A: Tensor of shape (n - 1, ) @@ -783,7 +776,7 @@ def subdiag_mult_slow(subdiag_A, subdiag_B, G, H, x, corner_A=0.0, corner_B=0.0) def subdiag_mult_slow_fast(subdiag_A, subdiag_B, G, H, x): - """Multiply \sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. + """Multiply sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. Uses the fast construction of Krylov matrix. Parameters: subdiag_A: Tensor of shape (n - 1, ) @@ -799,8 +792,8 @@ def subdiag_mult_slow_fast(subdiag_A, subdiag_B, G, H, x): class CycleDownMultCuda(torch.autograd.Function): - '''Cycle v down and do pointwise multiplication with subdiag. - ''' + """Cycle v down and do pointwise multiplication with subdiag.""" + @staticmethod def forward(ctx, subdiag, v): ctx.save_for_backward(subdiag, v) @@ -809,24 +802,12 @@ def forward(ctx, subdiag, v): @staticmethod def backward(ctx, grad): subdiag, v = ctx.saved_tensors - return diag_mult_cuda.cycle_mult(grad, v, 0, -1).sum(dim=0), diag_mult_cuda.cycle_mult(subdiag, grad, 1, 1) + return diag_mult_cuda.cycle_mult(grad, v, 0, -1).sum( + dim=0 + ), diag_mult_cuda.cycle_mult(subdiag, grad, 1, 1) -cycle_down_mult = CycleDownMultCuda.apply -def test_cycle_down_mult(): - n = 1 << 10 - rank = 16 - subdiag = torch.rand(n, requires_grad=True, device=device) - v = torch.rand((rank, n), requires_grad=True, device=device) - z = cycle_down_mult(subdiag, v) - y = torch.cat((subdiag[0] * v[..., -1:], subdiag[1:] * v[..., :-1]), dim=-1) - print((z - y).abs().max().item()) - - grad_output = torch.rand_like(y) - gs, gv = torch.autograd.grad(y, (subdiag, v), grad_output, retain_graph=True) - zs, zv = torch.autograd.grad(z.sum(), (subdiag, v), grad_output, retain_graph=True) - print((zs - gs).abs().max().item()) - print((zv - gv).abs().max().item()) +cycle_down_mult = CycleDownMultCuda.apply def subdiag_linear_map_cuda(subdiag, upper_right_corner=0.0): @@ -838,12 +819,19 @@ def subdiag_linear_map_cuda(subdiag, upper_right_corner=0.0): Returns: linear_map: v -> product, with v of shape either (n, ) or (rank, n) """ - subdiag_extended = torch.cat((torch.tensor([upper_right_corner], dtype=subdiag.dtype, device=subdiag.device), subdiag)) + subdiag_extended = torch.cat( + ( + torch.tensor( + [upper_right_corner], dtype=subdiag.dtype, device=subdiag.device + ), + subdiag, + ) + ) return lambda v: cycle_down_mult(subdiag_extended, v) def subdiag_mult_cuda(subdiag_A, subdiag_B, G, H, x, corner_A=0.0, corner_B=0.0): - """Multiply \sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. + """Multiply sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. Uses the explicit Krylov construction in CUDA. Parameters: subdiag_A: Tensor of shape (n - 1, ) @@ -858,9 +846,13 @@ def subdiag_mult_cuda(subdiag_A, subdiag_B, G, H, x, corner_A=0.0, corner_B=0.0) K_H = Krylov(subdiag_linear_map_cuda(subdiag_B, corner_B), H) return ((x @ K_H) @ K_G.transpose(1, 2)).sum(dim=0) + ##### Slow multiplication for the tridiagonal case -def tridiag_linear_map(subdiag, diag, superdiag, upper_right_corner=0.0, lower_left_corner=0.0): + +def tridiag_linear_map( + subdiag, diag, superdiag, upper_right_corner=0.0, lower_left_corner=0.0 +): """Construct the linear map for multiplying with a tridiagonal matrix (possibly with upper right and lower left corners). Similar to subdiag_linear_map, we want to reduce the number of CUDA @@ -880,13 +872,29 @@ def tridiag_linear_map(subdiag, diag, superdiag, upper_right_corner=0.0, lower_l shift_down = shift_none - 1 shift_up = (shift_none + 1) % n shifts = torch.stack((shift_down, shift_none, shift_up)) - subdiag_extended = torch.cat((torch.tensor([upper_right_corner], dtype=subdiag.dtype, device=subdiag.device), subdiag)) - superdiag_extended = torch.cat((superdiag, torch.tensor([lower_left_corner], dtype=superdiag.dtype, device=superdiag.device))) + subdiag_extended = torch.cat( + ( + torch.tensor( + [upper_right_corner], dtype=subdiag.dtype, device=subdiag.device + ), + subdiag, + ) + ) + superdiag_extended = torch.cat( + ( + superdiag, + torch.tensor( + [lower_left_corner], dtype=superdiag.dtype, device=superdiag.device + ), + ) + ) diags = torch.stack((subdiag_extended, diag, superdiag_extended)) return lambda v: (diags * v[..., shifts]).sum(dim=-2) -def tridiag_linear_map_slow(subdiag, diag, superdiag, upper_right_corner=0.0, lower_left_corner=0.0): +def tridiag_linear_map_slow( + subdiag, diag, superdiag, upper_right_corner=0.0, lower_left_corner=0.0 +): """The linear map for multiplying with a tridiagonal matrix (possibly with upper right and lower left corner). This implementation is slow, but easy to understand. @@ -899,11 +907,29 @@ def tridiag_linear_map_slow(subdiag, diag, superdiag, upper_right_corner=0.0, lo Returns: linear_map: v -> product, with v of shape either (n, ) or (rank, n) """ - return lambda v: torch.cat((upper_right_corner * v[..., -1:], subdiag * v[..., :-1]), dim=-1) + diag * v + torch.cat((superdiag * v[..., 1:], lower_left_corner * v[..., :1]), dim=-1) - - -def tridiag_mult_slow(subdiag_A, diag_A, superdiag_A, subdiag_B, diag_B, superdiag_B, G, H, x, corners_A=(0.0, 0.0), corners_B=(0.0, 0.0)): - """Multiply \sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. + return ( + lambda v: torch.cat( + (upper_right_corner * v[..., -1:], subdiag * v[..., :-1]), dim=-1 + ) + + diag * v + + torch.cat((superdiag * v[..., 1:], lower_left_corner * v[..., :1]), dim=-1) + ) + + +def tridiag_mult_slow( + subdiag_A, + diag_A, + superdiag_A, + subdiag_B, + diag_B, + superdiag_B, + G, + H, + x, + corners_A=(0.0, 0.0), + corners_B=(0.0, 0.0), +): + """Multiply sum_i Krylov(A, G_i) @ Krylov(B, H_i) @ x when A and B are zero except on the subdiagonal. Uses the explicit Krylov construction with the more careful implementation of linear map. Parameters: subdiag_A: Tensor of shape (n - 1, ) @@ -921,178 +947,14 @@ def tridiag_mult_slow(subdiag_A, diag_A, superdiag_A, subdiag_B, diag_B, superdi product: Tensor of shape (batch_size, n) """ if G.shape[0] == 1: # specialized code for rank=1, giving 2x speedup. - K_G = Krylov(tridiag_linear_map(subdiag_A, diag_A, superdiag_A, *corners_A), G[0]) - K_H = Krylov(tridiag_linear_map(subdiag_B, diag_B, superdiag_B, *corners_B), H[0]) + K_G = Krylov( + tridiag_linear_map(subdiag_A, diag_A, superdiag_A, *corners_A), G[0] + ) + K_H = Krylov( + tridiag_linear_map(subdiag_B, diag_B, superdiag_B, *corners_B), H[0] + ) return (x @ K_H) @ K_G.t() else: K_G = Krylov(tridiag_linear_map(subdiag_A, diag_A, superdiag_A, *corners_A), G) K_H = Krylov(tridiag_linear_map(subdiag_B, diag_B, superdiag_B, *corners_B), H) return ((x @ K_H) @ K_G.transpose(1, 2)).sum(dim=0) - - -def test_krylov_transpose_multiply(): - m = 10 - n = 1 << m - batch_size = 50 - rank = 16 - subdiag = torch.rand(n-1, requires_grad=True, device=device) - A = np.diag(subdiag.data.cpu().numpy(), -1) - u = torch.rand((batch_size, n), requires_grad=True, device=device) - v = torch.rand((rank, n), requires_grad=True, device=device) - # Fast algorithm on GPU - # KTu_traced = torch.jit.trace(KTu_traceable, (subdiag, v, u)) - result = krylov_transpose_multiply(subdiag, v, u) - # result = krylov_transpose_multiply_conv(subdiag, v, u) - # result = krylov_transpose_multiply_old(subdiag, v, u) - grad, = torch.autograd.grad(result.sum(), subdiag, retain_graph=True) - # CPU dense multiply - Ks = [krylov_construct(A, v.data.cpu().numpy()[i], n) for i in range(rank)] - u_cpu = u.data.cpu().numpy() - result_cpu = np.stack([u_cpu @ K.T for K in Ks]) - result_cpu = result_cpu.swapaxes(0, 1).squeeze() - result_cpu = torch.tensor(result_cpu, dtype=torch.float, device=device) - # GPU dense multiply - Ks_gpu_dense = [torch.tensor(K, dtype=torch.float, device=device) for K in Ks] - result_gpu_dense = torch.stack([u @ K.t() for K in Ks_gpu_dense]) - result_gpu_dense = result_gpu_dense.transpose(0, 1).squeeze() - # Explicit construction on GPU - Ks_gpu = Krylov(subdiag_linear_map(subdiag), v) - result_gpu = (u @ Ks_gpu).transpose(0, 1) - grad_gpu, = torch.autograd.grad(result_gpu.sum(), subdiag, retain_graph=True) - # Explicit construction on GPU, but faster - Ks_gpu_fast = krylov_subdiag_fast(subdiag, v) - result_gpu_fast = (u @ Ks_gpu_fast).transpose(0, 1) - grad_gpu_fast, = torch.autograd.grad(result_gpu_fast.sum(), subdiag, retain_graph=True) - # These max and mean differences should be small - print((result - result_cpu).abs().max().item()) - print((result - result_cpu).abs().mean().item()) - print((result - result_gpu_dense).abs().max().item()) - print((result - result_gpu_dense).abs().mean().item()) - print((result - result_gpu).abs().max().item()) - print((result - result_gpu).abs().mean().item()) - print((grad - grad_gpu).abs().max().item()) - print((grad - grad_gpu).abs().mean().item()) - print((result - result_gpu_fast).abs().max().item()) - print((result - result_gpu_fast).abs().mean().item()) - print((grad - grad_gpu_fast).abs().max().item()) - print((grad - grad_gpu_fast).abs().mean().item()) - - # with torch.autograd.profiler.profile(use_cuda=True) as prof: - # result = krylov_transpose_multiply_conv(subdiag, v, u) - # grad, = torch.autograd.grad(result.sum(), subdiag, retain_graph=True) - -def test_krylov_multiply(): - m = 10 - n = 1 << m - batch_size = 50 - rank = 16 - subdiag = torch.rand(n-1, requires_grad=True, device=device) - A = np.diag(subdiag.data.cpu().numpy(), -1) - u = torch.rand((batch_size, n), requires_grad=True, device=device) - v = torch.rand((rank, n), requires_grad=True, device=device) - w = torch.rand((batch_size, rank, n), requires_grad=True, device=device) - # Fast algorithm on GPU - # result = krylov_multiply_conv(subdiag, v, w) - result = krylov_multiply(subdiag, v, w) - # result = krylov_multiply_old(subdiag, v, w) - grad, = torch.autograd.grad(result.sum(), subdiag, retain_graph=True) - # Using autodiff - result_autodiff = krylov_multiply_by_autodiff(subdiag, v, w) - grad_autodiff, = torch.autograd.grad(result_autodiff.sum(), subdiag, retain_graph=True) - # CPU dense multiply - Ks = [krylov_construct(A, v.data.cpu().numpy()[i], n) for i in range(rank)] - w_cpu = w.data.cpu().numpy() - result_cpu = np.stack([w_cpu[:, i] @ Ks[i] for i in range(rank)]).sum(axis=0).squeeze() - result_cpu = torch.tensor(result_cpu, dtype=torch.float, device=device) - # Explicit construction on GPU - Ks_gpu = Krylov(subdiag_linear_map(subdiag), v) - result_gpu = (w.transpose(0, 1) @ Ks_gpu.transpose(1, 2)).sum(dim=0) - grad_gpu, = torch.autograd.grad(result_gpu.sum(), subdiag, retain_graph=True) - # Explicit construction on GPU, but faster - Ks_gpu_fast = krylov_subdiag_fast(subdiag, v) - result_gpu_fast = (w.transpose(0, 1) @ Ks_gpu_fast.transpose(1, 2)).sum(dim=0) - grad_gpu_fast, = torch.autograd.grad(result_gpu_fast.sum(), subdiag, retain_graph=True) - # These max and mean differences should be small - print((result - result_autodiff).abs().max().item()) - print((result - result_autodiff).abs().mean().item()) - print((grad - grad_autodiff).abs().max().item()) - print((grad - grad_autodiff).abs().mean().item()) - print((result - result_cpu).abs().max().item()) - print((result - result_cpu).abs().mean().item()) - print((result - result_gpu).abs().max().item()) - print((result - result_gpu).abs().mean().item()) - print((grad - grad_gpu).abs().max().item()) - print((grad - grad_gpu).abs().mean().item()) - print((result - result_gpu_fast).abs().max().item()) - print((result - result_gpu_fast).abs().mean().item()) - print((grad - grad_gpu_fast).abs().max().item()) - print((grad - grad_gpu_fast).abs().mean().item()) - - -def test_subdiag_mult(): - m = 10 - n = 1 << m - batch_size = 50 - rank = 16 - subdiag = torch.rand(n-1, requires_grad=True, device=device) - diag = torch.rand(n, requires_grad=True, device=device) - superdiag = torch.rand(n-1, requires_grad=True, device=device) - u = torch.rand((batch_size, n), requires_grad=True, device=device) - v = torch.rand((rank, n), requires_grad=True, device=device) - - K = Krylov(subdiag_linear_map(subdiag, 1.0), v) - K_fast = krylov_subdiag_fast(subdiag, v, upper_right_corner=1.0) - print((K - K_fast).abs().max().item()) - - result = subdiag_mult_conv(subdiag, subdiag, v, v, u) - # result = subdiag_mult(subdiag, subdiag, v, v, u) - grad, = torch.autograd.grad(result.sum(), subdiag, retain_graph=True) - result_slow_old = subdiag_mult_slow_old(subdiag, subdiag, v, v, u) - grad_slow_old, = torch.autograd.grad(result_slow_old.sum(), subdiag, retain_graph=True) - result_slow = subdiag_mult_slow(subdiag, subdiag, v, v, u) - grad_slow, = torch.autograd.grad(result_slow.sum(), subdiag, retain_graph=True) - result_slow_fast = subdiag_mult_slow_fast(subdiag, subdiag, v, v, u) - grad_slow_fast, = torch.autograd.grad(result_slow_fast.sum(), subdiag, retain_graph=True) - result_cuda = subdiag_mult_cuda(subdiag, subdiag, v, v, u) - grad_cuda, = torch.autograd.grad(result_cuda.sum(), subdiag, retain_graph=True) - # These max and mean differences should be small - print((result - result_slow_old).abs().max().item()) - print((result - result_slow_old).abs().mean().item()) - print((grad - grad_slow_old).abs().max().item()) - print((grad - grad_slow_old).abs().mean().item()) - print((result - result_slow).abs().max().item()) - print((result - result_slow).abs().mean().item()) - print((grad - grad_slow).abs().max().item()) - print((grad - grad_slow).abs().mean().item()) - print((result - result_slow_fast).abs().max().item()) - print((result - result_slow_fast).abs().mean().item()) - print((grad - grad_slow_fast).abs().max().item()) - print((grad - grad_slow_fast).abs().mean().item()) - print((result - result_cuda).abs().max().item()) - print((result - result_cuda).abs().mean().item()) - print((grad - grad_cuda).abs().max().item()) - print((grad - grad_cuda).abs().mean().item()) - - -def test_tridiag_mult(): - m = 10 - n = 1 << m - batch_size = 50 - rank = 16 - subdiag = torch.rand(n-1, requires_grad=True, device=device) / 2 - diag = torch.rand(n, requires_grad=True, device=device) / 2 - superdiag = torch.rand(n-1, requires_grad=True, device=device) / 2 - u = torch.rand((batch_size, n), requires_grad=True, device=device) - v = torch.rand((rank, n), requires_grad=True, device=device) - K = Krylov(tridiag_linear_map(subdiag, diag, superdiag, 0.5, 0.5), v) - K_old = Krylov(tridiag_linear_map_slow(subdiag, diag, superdiag, 0.5, 0.5), v) - print((K - K_old).abs().max().item()) - trid_slow = tridiag_mult_slow(subdiag, diag, superdiag, subdiag, diag, superdiag, v, v, u) - - -# TODO: broken, move test into subpackage -if __name__ == "__main__": - test_krylov_transpose_multiply() - test_krylov_multiply() - test_subdiag_mult() - test_tridiag_mult() diff --git a/pytorch/structure/layer.py b/pytorch/structure/layer.py index 193e26e..935a4ed 100644 --- a/pytorch/structure/layer.py +++ b/pytorch/structure/layer.py @@ -1,15 +1,26 @@ -import numpy as np +# Copyright 2018 HazyResearch +# https://github.com/HazyResearch/structured-nets +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import torch import torch.nn as nn -from torch.nn.parameter import Parameter from torch.autograd import Variable +from torch.nn.parameter import Parameter -from . import toeplitz as toep -from . import krylov as kry -from . import circulant as circ -from . import fastfood as ff +from . import circulant as circ, fastfood as ff, krylov as kry, toeplitz as toep -from utils import descendants class Layer(nn.Module): class_type = None @@ -40,9 +51,10 @@ def apply_bias(self, out): def loss(self): return 0 + class Unconstrained(Layer): - class_type = 'unconstrained' - abbrev = 'u' + class_type = "unconstrained" + abbrev = "u" def name(self): return self.__class__.abbrev + str(self.hidden_size) @@ -55,7 +67,7 @@ def __init__(self, layer_size, hidden_size=None, **kwargs): def reset_parameters(self): super().reset_parameters() self.W = Parameter(torch.Tensor(self.layer_size, self.hidden_size)) - self.init_stddev = np.sqrt(1./self.layer_size) + self.init_stddev = torch.sqrt(1.0 / self.layer_size) torch.nn.init.normal_(self.W, std=self.init_stddev) self.mask = None if self.bias: @@ -64,28 +76,27 @@ def reset_parameters(self): def set_mask(self, mask, device): self.mask = Variable(torch.FloatTensor(mask).to(device), requires_grad=False) self.W.data *= self.mask.data - print('Num. nonzero entries after pruning: ', torch.nonzero(self.W).size(0)) + print("Num. nonzero entries after pruning: ", torch.nonzero(self.W).size(0)) def forward(self, x): if self.mask is not None: - masked_W = self.W*self.mask - #print('NNZ, mask: ', torch.nonzero(self.mask).size(0)) - #print('NNZ, masked_W: ', torch.nonzero(masked_W).size(0)) + masked_W = self.W * self.mask + # print('NNZ, mask: ', torch.nonzero(self.mask).size(0)) + # print('NNZ, masked_W: ', torch.nonzero(masked_W).size(0)) out = torch.matmul(x, masked_W) else: out = torch.matmul(x, self.W) return self.apply_bias(out) - class Circulant(Layer): - class_type = 'circulant' - abbrev = 'c' + class_type = "circulant" + abbrev = "c" def reset_parameters(self): super().reset_parameters() self.c = Parameter(torch.Tensor(self.layer_size)) - self.init_stddev = np.sqrt(1./self.layer_size) + self.init_stddev = torch.sqrt(1.0 / self.layer_size) torch.nn.init.normal_(self.c, std=self.init_stddev) def forward(self, x): @@ -93,8 +104,8 @@ def forward(self, x): class FastFood(Layer): - class_type = 'fastfood' - abbrev = 'f' + class_type = "fastfood" + abbrev = "f" def reset_parameters(self): super().reset_parameters() @@ -102,25 +113,22 @@ def reset_parameters(self): # TODO: check initialization of S (scaling matrix) is correct # S,G,B: diagonal, learnable parameters # P: permutation, fixed - S = np.sqrt(np.random.chisquare(self.layer_size, size=self.layer_size)) - G = np.random.randn(self.layer_size) - S /= np.linalg.norm(G) - B = np.random.choice((-1, 1), size=self.layer_size) + S = torch.sqrt(torch.random.chisquare(self.layer_size, size=self.layer_size)) + G = torch.random.randn(self.layer_size) + S /= torch.linalg.norm(G) + B = torch.random.choice((-1, 1), size=self.layer_size) self.S = Parameter(torch.FloatTensor(S)) self.G = Parameter(torch.FloatTensor(G)) self.B = Parameter(torch.FloatTensor(B)) - self.P = torch.LongTensor(np.random.permutation(self.layer_size)) - #self.init_stddev = np.sqrt(1./self.layer_size) - #torch.nn.init.normal_(self.S, std=self.init_stddev) - #torch.nn.init.normal_(self.G, std=self.init_stddev) - #torch.nn.init.normal_(self.B, std=self.init_stddev) + self.P = torch.LongTensor(torch.random.permutation(self.layer_size)) def forward(self, x): return self.apply_bias(ff.fastfood_multiply(self.S, self.G, self.B, self.P, x)) + class LowRank(Layer): - class_type = 'low_rank' - abbrev = 'lr' + class_type = "low_rank" + abbrev = "lr" def name(self): return self.__class__.abbrev + str(self.r) @@ -133,7 +141,7 @@ def reset_parameters(self): self.G = Parameter(torch.Tensor(self.r, self.layer_size)) self.H = Parameter(torch.Tensor(self.r, self.layer_size)) # self.init_stddev = 0.01 - self.init_stddev = np.power(1. / (self.r * self.layer_size), 1/2) + self.init_stddev = torch.power(1.0 / (self.r * self.layer_size), 1 / 2) torch.nn.init.normal_(self.G, std=self.init_stddev) torch.nn.init.normal_(self.H, std=self.init_stddev) @@ -149,8 +157,8 @@ def loss(self): class ToeplitzLike(LowRank): - class_type = 'toeplitz' - abbrev = 't' + class_type = "toeplitz" + abbrev = "t" def reset_parameters(self): super().reset_parameters() @@ -160,25 +168,28 @@ def forward(self, x): out = toep.toeplitz_mult(self.G, self.H, x, self.corner) return self.apply_bias(out) + class ToeplitzLikeC(ToeplitzLike): - class_type = 'toeplitz_corner' - abbrev = 'tc' + class_type = "toeplitz_corner" + abbrev = "tc" def reset_parameters(self): super().reset_parameters() self.corner = True + class HankelLike(LowRank): - class_type = 'hankel' - abbrev = 'h' + class_type = "hankel" + abbrev = "h" def forward(self, x): out = toep.toeplitz_mult(self.G, self.H, x, True) return self.apply_bias(out.flip(out.dim() - 1)) + class VandermondeLike(LowRank): - class_type = 'vandermonde' - abbrev = 'v' + class_type = "vandermonde" + abbrev = "v" def reset_parameters(self): super().reset_parameters() @@ -196,7 +207,7 @@ def forward(self, x): # out = (x @ K_B) @ K_A.transpose(1,2) out = toep.toeplitz_krylov_transpose_multiply(self.H, x) - out = out.transpose(0,1) @ K_A.transpose(1,2) + out = out.transpose(0, 1) @ K_A.transpose(1, 2) out = torch.sum(out, dim=0) return self.apply_bias(out) @@ -210,32 +221,35 @@ class LearnedOperator(LowRank): Abstract class for learned displacement operators Contains parameters such as tie_operators """ - class_type = None # abstract + + class_type = None # abstract abbrev = None def __init__(self, tie_operators=False, corner=False, **kwargs): super().__init__(tie_operators=tie_operators, corner=corner, **kwargs) + class LDRSubdiagonal(LearnedOperator): - class_type = 'subdiagonal' - abbrev = 'sd' + class_type = "subdiagonal" + abbrev = "sd" def reset_parameters(self): super().reset_parameters() - self.subd_A = Parameter(torch.ones(self.layer_size-1)) + self.subd_A = Parameter(torch.ones(self.layer_size - 1)) if self.tie_operators: self.subd_B = self.subd_A else: - self.subd_B = Parameter(torch.ones(self.layer_size-1)) + self.subd_B = Parameter(torch.ones(self.layer_size - 1)) def forward(self, x): out = kry.subdiag_mult(self.subd_A, self.subd_B, self.G, self.H, x) - #out = kry.subdiag_mult_conv(self.subd_A, self.subd_B, self.G, self.H, x) + # out = kry.subdiag_mult_conv(self.subd_A, self.subd_B, self.G, self.H, x) return self.apply_bias(out) + class LDRSubdiagonalC(LDRSubdiagonal): - class_type = 'subdiagonal_corner' - abbrev = 'sdc' + class_type = "subdiagonal_corner" + abbrev = "sdc" def reset_parameters(self): super().reset_parameters() @@ -243,36 +257,58 @@ def reset_parameters(self): self.corner_B = Parameter(torch.tensor(0.0)) def forward(self, x): - out = kry.subdiag_mult_cuda(self.subd_A, self.subd_B, self.G, self.H, x, corner_A=self.corner_A, corner_B=self.corner_B) + out = kry.subdiag_mult_cuda( + self.subd_A, + self.subd_B, + self.G, + self.H, + x, + corner_A=self.corner_A, + corner_B=self.corner_B, + ) return self.apply_bias(out) + class LDRTridiagonal(LearnedOperator): - class_type = 'tridiagonal' - abbrev = 'td' + class_type = "tridiagonal" + abbrev = "td" def reset_parameters(self): super().reset_parameters() - self.subd_A = Parameter(torch.ones(self.layer_size-1)) + self.subd_A = Parameter(torch.ones(self.layer_size - 1)) self.diag_A = Parameter(torch.zeros(self.layer_size)) - self.supd_A = Parameter(torch.zeros(self.layer_size-1)) + self.supd_A = Parameter(torch.zeros(self.layer_size - 1)) if self.tie_operators: self.subd_B = self.subd_A self.diag_B = self.diag_A self.supd_B = self.supd_A else: - self.subd_B = Parameter(torch.ones(self.layer_size-1)) + self.subd_B = Parameter(torch.ones(self.layer_size - 1)) self.diag_B = Parameter(torch.zeros(self.layer_size)) - self.supd_B = Parameter(torch.zeros(self.layer_size-1)) - self.corners_A = (0.0,0.0) - self.corners_B = (0.0,0.0) + self.supd_B = Parameter(torch.zeros(self.layer_size - 1)) + self.corners_A = (0.0, 0.0) + self.corners_B = (0.0, 0.0) def forward(self, x): - out = kry.tridiag_mult_slow(self.subd_A, self.diag_A, self.supd_A, self.subd_B, self.diag_B, self.supd_B, self.G, self.H, x, corners_A=self.corners_A, corners_B=self.corners_B) + out = kry.tridiag_mult_slow( + self.subd_A, + self.diag_A, + self.supd_A, + self.subd_B, + self.diag_B, + self.supd_B, + self.G, + self.H, + x, + corners_A=self.corners_A, + corners_B=self.corners_B, + ) return self.apply_bias(out) + class LDRTridiagonalC(LDRTridiagonal): - class_type = 'tridiagonal_corner' - abbrev = 'tdc' + class_type = "tridiagonal_corner" + abbrev = "tdc" def reset_parameters(self): super().reset_parameters() @@ -281,11 +317,25 @@ def reset_parameters(self): # create a map from class names to the Python class +def descendants(cls): + """ + Get all subclasses (recursively) of class cls, not including itself + Assumes no multiple inheritance + """ + desc = [] + for subcls in cls.__subclasses__(): + desc.append(subcls) + desc.extend(descendants(subcls)) + return desc + + class_map = {} for cls in descendants(Layer): - if cls.class_type is None: continue + if cls.class_type is None: + continue class_map[cls.class_type] = cls class_map[cls.abbrev] = cls + def StructuredLinear(class_type, **kwargs): return class_map[class_type](**kwargs) diff --git a/pytorch/structure/scratch/fft.py b/pytorch/structure/scratch/fft.py deleted file mode 100644 index ead4789..0000000 --- a/pytorch/structure/scratch/fft.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np -import itertools - -p=2 -d=3 -N=p << (d-1) - -f = np.arange(N) -print(np.fft.fft(f)) - - -def init(f): - x = np.zeros(d*[p], dtype=np.complex_) - idx = [list(range(p)) for i in range(d)] - powers = np.array([p**i for i in range(d)]) - for t in itertools.product(*idx): - x[t] = f[np.sum(powers*np.array(t))] - return x -x = init(f) -print(x.shape) - -def unshape(x): - f = np.zeros(p**d, dtype=np.complex_) - idx = [list(range(p)) for i in range(d)] - powers = np.array([p**i for i in range(d)]) - for t in itertools.product(*idx): - f[np.sum(powers*np.array(t))] = x[t] - return f - - -# x = f.reshape([[p]*d]).astype(np.complex_) - - -# At pass r the layout is -# -# x_0,..., x_{d-r-1}, y_{r-1}, ..., y_{0} -# So at -# r = 0 => x_0,..., x_{d-1} -# r = 1 => x_0,..., x_{d-2}, y_0 -# r = 2 => x_0,..., x_{d-3}, y_0, y_1 -# r = d => y_0,..., y_{d-1} -# -def pass_it_(x,x_new,r, verbose=False): - # The index ranges - # (x_0,...,x_{d-r-2},x_{d-r-1}, y_{0}, .., y_{r-1}, y_r) - idx = [list(range(p)) for i in range(d+1)] - omega = -2*np.complex(0,1)*np.pi/(p**d) - powers = np.array([p**i for i in range(r+1)]) - # powers = np.array([p**i for i in range(r,-1,-1)]) - for t in itertools.product(*idx): - # The last index are the ys - x_base = t[0:d-r-1] - x_last = t[d-r-1] # this is xm - y_base = t[d-r:d] - y_last = t[d] - # marginalize out over xm, but keep the ys in the same order? - new_t = x_base + y_base + (y_last,) - old_t = x_base + (x_last,) + y_base - y_sum = np.sum(np.array(t[d-r:d+1]) * powers) * p**(d-r-1) - if verbose: - print(f"x={x_base},{x_last} -> y={y_base},{y_last} :: new={new_t} += old={old_t} y_sum={y_sum} {y_sum*x_last}") - q = omega*x_last*y_sum - x_new[new_t] += x[old_t]*np.exp(q) - if verbose: print("**") - return x_new - -def pass_it(x,r,verbose=False): - x_new = np.zeros(d*[p], dtype=np.complex_) - return pass_it_(x,x_new,r,verbose=verbose) - -def fft_pass(x): - _x = np.copy(x) - x_new = np.zeros(d*[p], dtype=np.complex_) - for r in range(d): - pass_it_(_x,x_new,r) - (_x,x_new) = (x_new,_x) - x_new[:] = 0 - return _x - -def slow_fft(x): - y = np.zeros(x.shape, dtype=np.complex_) - idx = [list(range(p)) for i in range(d)] - omega = -2*np.complex(0,1)*np.pi/(p**d) - powers = np.array([p**i for i in range(d)]) - # powers = np.array([p**i for i in range(d-1,-1,-1)]) - for t in itertools.product(*idx): - y_t = np.sum(powers*np.array(t)) - for u in itertools.product(*idx): - x_t = np.sum(powers*np.array(u)) - y[t] += x[u]*np.exp(omega*y_t*x_t) - return y diff --git a/pytorch/structure/scratch/krylovfast.py b/pytorch/structure/scratch/krylovfast.py deleted file mode 100644 index b357c22..0000000 --- a/pytorch/structure/scratch/krylovfast.py +++ /dev/null @@ -1,375 +0,0 @@ -import numpy as np -import itertools - -import pyfftw -import sys -sys.path.insert(0,'../../../pytorch/') -from structure.scratch.krylovslow import krylov_construct - - -# define fft calls -def _plan_ffts(in_shape, lib='numpy'): - out_shape = in_shape[:-1] + (in_shape[-1]//2 + 1,) - if lib == 'numpy': - x_for = np.zeros(shape=in_shape) - fft = lambda: np.fft.rfft(x_for) - - y_bak = np.empty(shape=out_shape, dtype='complex128') - ifft = lambda: np.fft.irfft(y_bak) - return ((x_for, fft), (y_bak, ifft)) - if lib == 'scipy': - pass - if lib == 'fftw': - out_shape = in_shape[:-1] + (in_shape[-1]//2 + 1,) - x_for = pyfftw.empty_aligned(in_shape, dtype='float64') - y_for = pyfftw.empty_aligned(out_shape, dtype='complex128') - fft_for = pyfftw.FFTW(x_for, y_for, direction='FFTW_FORWARD', flags=['FFTW_MEASURE']) # don't destroy input so 0s are preserved - x_for[:] = 0 - - x_bak = pyfftw.empty_aligned(in_shape, dtype='float64') - y_bak = pyfftw.empty_aligned(out_shape, dtype='complex128') - fft_bak = pyfftw.FFTW(y_bak, x_bak, direction='FFTW_BACKWARD', flags=['FFTW_MEASURE', 'FFTW_DESTROY_INPUT']) - return ((x_for, fft_for), (y_bak, fft_bak)) - - -def plan_ffts(m, lib='numpy'): - fft_plans = [None] * m - for d in range(m-1,-1,-1): - n1, n2 = 1<brm", S1_01_f, S0_10_f, out=T_00_f_sum) - T = fft_freq2time(T_f, output_array=T) - T_00_sum = T - - # polynomial additions - result[:, :, 1:2*n2] += T_00_sum[..., :-1] - S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1)::(2 * n2)] - # T_01 = np.concatenate((S_01[:, ::2], S_01[:, 1::2] * S0_11_mult_subdiag[:, np.newaxis]), axis=-1) - T_01 = S_01.reshape(batch_size, n1, 2 * n2) - T_01[:, :, n2:] *= S0_11_mult_subdiag[:, np.newaxis] - T_10 = np.concatenate((S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, np.newaxis]), axis=-1) - T_11 = S0_11_mult_subdiag * S_11[1::2] - - return result - - -class KrylovMultiply(): - """Multiply Krylov(A, v) @ w when A is zero except on the subdiagonal. - """ - - def __init__(self, n, batch_size=1, rank=1): - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' - self.n = n - self.m = m - self.batch_size = batch_size - self.rank = rank - self.plan_ffts_forward_pass_u_zero() - self.plan_ffts_backward_pass() - - def plan_ffts_forward_pass_u_zero(self): - n, m, batch_size, rank = self.n, self.m, self.batch_size, self.rank - self.S_storage = [np.empty((rank, n))] * m - self.S_f_storage = [np.empty((rank, 1 << d, (1 << (m - d - 1)) + 1), dtype='complex128') for d in range(m)] - self.ffts_forward_pass = [] - for d, (S, S_f) in enumerate(zip(self.S_storage, self.S_f_storage)): - S = S.reshape((rank, 1 << d, 1 << (m - d))) - fft_time2freq = pyfftw.FFTW(S, S_f, direction='FFTW_FORWARD', flags=['FFTW_MEASURE', 'FFTW_DESTROY_INPUT'], threads=1) - self.ffts_forward_pass.append(fft_time2freq) - - def plan_ffts_backward_pass(self): - n, m, batch_size, rank = self.n, self.m, self.batch_size, self.rank - self.dT_storage = [np.empty((batch_size, rank, 1 << (m - d))) for d in range(m)] - self.dT_f_storage = [np.empty((batch_size, rank, (1 << (m - d - 1)) + 1), dtype='complex128') for d in range(m)] - self.dS_f_storage = [np.empty((batch_size, 1 << d, (1 << (m - d - 1)) + 1), dtype='complex128') for d in range(m)] - self.dS_storage = [np.empty((batch_size, n))] * m - self.ffts_backward_pass = [] - for d, (dT, dT_f, dS_f, dS) in enumerate(zip(self.dT_storage, self.dT_f_storage, self.dS_f_storage, self.dS_storage)): - dS = dS.reshape((batch_size, 1 << d, 1 << (m - d))) - fft_time2freq = pyfftw.FFTW(dT, dT_f, direction='FFTW_FORWARD', flags=['FFTW_MEASURE', 'FFTW_DESTROY_INPUT'], threads=1) - fft_freq2time = pyfftw.FFTW(dS_f, dS, direction='FFTW_BACKWARD', flags=['FFTW_MEASURE', 'FFTW_DESTROY_INPUT'], threads=1) - self.ffts_backward_pass.append((fft_time2freq, fft_freq2time)) - - def __call__(self, subdiag, v, w): - n, m, batch_size, rank = self.n, self.m, self.batch_size, self.rank - # Forward pass. Since K @ w can be computed by autodiffing K^T @ u, we - # carry out the forward pass K^T @ u for u = 0 here to save the - # intermediate values. This code is exactly the same as the function - # @krylov_transpose_multiply, specialized to the case where u = 0. - save_for_backward = [None] * m - T_10 = v.reshape(rank, n, 1) - T_11 = np.ones(n) - for d in range(m)[::-1]: - n1, n2 = 1 << d, 1 << (m - d - 1) - S = self.S_storage[d].reshape((rank, n1, 2 * n2)) - S_f = self.S_f_storage[d] - fft_time2freq = self.ffts_forward_pass[d] - S_10, S_11 = T_10, T_11 - S0_10_mult_subdiag = S[:, :, :n2] - S0_10_mult_subdiag[:] = S_10[:, ::2] * subdiag[(n2 - 1)::(2 * n2), np.newaxis] - S[:, :, n2:] = 0.0 - S0_10_mult_subdiag_f = fft_time2freq(S, output_array=S_f) - T_10 = np.concatenate((S_10[:, 1::2], S0_10_mult_subdiag * S_11[1::2][:, np.newaxis]), axis=-1) - S0_11_mult_subdiag = S_11[::2] * subdiag[(n2 - 1)::(2 * n2)] - save_for_backward[d] = S0_10_mult_subdiag_f, S0_11_mult_subdiag - T_11 = S0_11_mult_subdiag * S_11[1::2] - - # Backward pass - w, v = w.reshape(batch_size, rank, n), v.reshape((rank, n)) - dT_01 = np.zeros((batch_size, 1, n), dtype=w.dtype) - - for d in range(m): - n1, n2 = 1 << d, 1 << (m - d - 1) - dT = self.dT_storage[d] - dT_f = self.dT_f_storage[d] - dS_f = self.dS_f_storage[d] - dS = self.dS_storage[d].reshape((batch_size, n1, 2 * n2)) - fft_time2freq, fft_freq2time = self.ffts_backward_pass[d] - - S0_10_mult_subdiag_f, S0_11_mult_subdiag = save_for_backward[d] - dS_01 = np.empty((batch_size, 2 * n1, n2), dtype=w.dtype) - dS_01[:, ::2] = dT_01[:, :, :n2] - dT_00_sum = dT - dT_00_sum[:, :, :2*n2 - 1] = w[:, :, 1:2*n2] - dT_00_sum[:, :, -1] = 0.0 - - dT_00_sum_f = fft_time2freq(dT, output_array=dT_f) - dS1_01_f = dS_f - # dS1_01_f[:] = (np.conjugate(S0_10_mult_subdiag_f, out=S0_10_mult_subdiag_f) * dT_00_sum_f[:, :, np.newaxis]).sum(axis=1) - np.einsum("brm,rnm->bnm", dT_00_sum_f, np.conjugate(S0_10_mult_subdiag_f, out=S0_10_mult_subdiag_f), out=dS1_01_f) - - dS1_01 = fft_freq2time(dS_f, output_array=dS) - dS_01[:, 1::2] = dT_01[:, :, n2:] * S0_11_mult_subdiag[:, np.newaxis] + dS1_01[:, :, :n2] - dT_01 = dS_01 - - # du = ((dT_00_sum[:, :, np.newaxis] * v[np.newaxis, :, :, np.newaxis]).sum(dim=1) + dT_01).squeeze(axis=-1) - du = w[:, :, 0] @ v + dT_01.squeeze(axis=-1) - return du - - -def test_krylov_transpose_multiply(): - m = 14 - n = 1<= 128: - prod = signal.fftconvolve(p1, p2, mode='full') - else: - prod = np.convolve(p1, p2) - # prod = np.convolve(p1, p2) - # if prod.shape[0] != n+1: - # print(d1, d2, p1.shape, p2.shape, prod.shape) - # assert false - # assert prod.shape[0] == n+1 - - return prod - -def poly_inv(p, n): - """ - invert p mod x^n - """ - assert n >= 1 - if n == 1: - return np.array([1 / p[0]]) - - # represent p = p_low + x^k p_high, and its inverse q similarly - d = p.shape[0] - k = (n+1)//2 - - # invert the lower order terms - q_low = poly_inv(p[:min(d,k)], k) - # print(q_low) - - # since 2k >= n, p q_l + x^k p_l q_h = 1 (mod x^n) - # so p_l q_h = (1 - p q_l)/x^k (mod x^{n-k}) - r = poly_mult(p, q_low) - r[0] -= 1 - # assert np.all(r[:min(r.shape[0],k)] == 0) - # but we know p_l^{-1} mod x^{n-k} since we already know it mod x^k - q_high = poly_mult(-r[k:min(r.shape[0],n)], q_low) - - # q_low = np.pad(q_low, (0,k-q_low.shape[0]), 'constant') - q = np.concatenate((q_low, q_high))[:n] - # q = np.trim_zeros(q, 'b') - return q - - - -def resolvent_bilinear(A, v, u, n): - """ - Compute [u e_n]^T * (I-Ax)^{-1} * [v e_1] - (2x2 matrix of rational fractions) - output: array of shape (2, 2, n), array shape (n) - (numerator, denominator) - - invariants: - numerator has degree n-1 - denominator degree n - """ - if n == 1: - # don't know how write outer product in numpy - return (np.array([[[ u[0]*v[0] ], [ u[0]*1 ]], [[ 1*v[0] ], [ 1*1 ]]]), np.array([1,-A[0,0]])) - - k = n//2 - # Let M00 = M[0:k, 0:k], M10 = M[k:n, 0:k], M11 = M[k:n,k:n] - # i.e. M = [M00 0 ; M10 M11] (where M = I-Ax) - # then M^{-1} = [M00^{-1} 0 ; -M11^{-1} M_10^{-1} M_00^{-1}] - S0, d0 = resolvent_bilinear(A[:k,:k], v[:k], u[:k], k) - S1, d1 = resolvent_bilinear(A[k:,k:], v[k:], u[k:], n-k) - - # the part corresponding to bottom left corner is - # -A[k, k-1]x * u_1^T M_11^{-1} e_1 * e_k^T M_00^{-1} v_0 - # or S1[:,1] * S0[1,:] - L = np.array([[poly_mult(S1[0,1], S0[1,0]), poly_mult(S1[0,1], S0[1,1])], [poly_mult( S1[1,1], S0[1,0] ), poly_mult( S1[1,1], S0[1,1] )]]) - # print(L) - L = A[k,k-1] * np.pad(L, ((0,0),(0,0),(1,0)), 'constant') # multiply by X - # TODO: above padding should be able to be optimized away; when we allocate memory properly can store the coefficients directly in the right place - # print(L) - - # clear denominators - # S0 = np.array([[ poly_mult(s, d1) for s in r ] for r in S0]) - # S1 = np.array([[ poly_mult(s, d0) for s in r ] for r in S1]) - # print(S0) - - # really need to define poly matrix operations - # S = np.array([[poly_add(S0[i,j],S1[i,j]) for j in range(2)] for i in range(2)]) - # S = np.array([[poly_add(S[i,j],L[i,j]) for j in range(2)] for i in range(2)]) - # L[0,0] = poly_add(L[0,0], poly_mult(S0[0,0], d1), n) - # L[0,1] = poly_add(L[0,1], poly_mult(S0[0,1], d1), n) - # L[0,0] = poly_add(L[0,0], poly_mult(S1[0,0], d0), n) - # L[1,0] = poly_add(L[1,0], poly_mult(S1[1,0], d0), n) - L[0,0] += poly_mult(S0[0,0], d1) + poly_mult(S1[0,0], d0) - L[0,1] += poly_mult(S0[0,1], d1) - L[1,0] += poly_mult(S1[1,0], d0) - return (L, poly_mult(d0,d1)) - -def krylov_mult(A, v, u, m): - """ - Compute the matrix-vector product Kry(A, v)^T * u - A: R^{n \times n}, lower triangular and 2-banded - u: R^n - v: R^n - m: output dimension (i.e. width of K) - """ - - n = v.shape[0] - assert A.shape == (n,n) - - R, d = resolvent_bilinear(A,v,u,n) - ans = poly_mult(R[0,0], poly_inv(d, m)) - return ans[:m] - -def Amult(d, subd, v): - ans = d*v - ans[1:] += subd*v[:-1] - return ans - -def krylov_mult_slow(A, v, u, m): - n = v.shape[0] - assert A.shape == (n,n) - cols = [v] - d = np.diagonal(A, 0) - subd = np.diagonal(A, -1) - for i in range(1,m): - cols.append(Amult(d, subd, cols[-1])) - K = np.stack(cols, axis=1) - return K.T @ u - -def krylov_mult_slow_allocated(A, v, u, m): - n = v.shape[0] - assert A.shape == (n,n) - d = np.diagonal(A, 0) - subd = np.diagonal(A, -1) - # Allocate memory at once to K - K_T = np.empty((m, n)) - K_T[0] = v - for i in range(1,m): - K_T[i] = Amult(d, subd, K_T[i-1]) - return K_T @ u - -def krylov_construct(A, v, m): - n = v.shape[0] - assert A.shape == (n,n) - d = np.diagonal(A, 0) - subd = np.diagonal(A, -1) - - K = np.zeros(shape=(m,n)) - K[0,:] = v - for i in range(1,m): - K[i,1:] = subd*K[i-1,:-1] - return K - -def krylov_mult_slow_faster(A, v, u, m): - K = krylov_construct(A, v, m) - return K @ u diff --git a/pytorch/structure/scratch/tests_snippets.py b/pytorch/structure/scratch/tests_snippets.py deleted file mode 100644 index 7c40dcb..0000000 --- a/pytorch/structure/scratch/tests_snippets.py +++ /dev/null @@ -1,52 +0,0 @@ -import os -os.environ['MKL_NUM_THREADS'] = '1' -os.environ['OMP_NUM_THREADS'] = '1' -os.environ['NUMEXPR_NUM_THREADS'] = '1' -os.environ['VECLIB_MAXIMUM_THREADS'] = '1' -import numpy as np -from krylovfast import * -from krylovslow import * - -np.random.seed(0) - - -# n, m = 2, 1 -# A = np.array([[0,0],[1,0]]) -# u = np.array([1,1]) -# v = np.array([1,1]) -# print(resolvent_bilinear(A,v,u,n)) -# ans: [2 1], [1, 1], [1 1], [0 1] - - -# n, m = 4, 2 -# A = np.diag(np.arange(1,4),-1) -# u = np.ones(4) -# v = np.ones(4) -# print(resolvent_bilinear(A,v,u,4)) -# print(krylov_mult(A,v,u,4)) -# print(krylov_mult_slow(A,v,u,4)) -# print(krylov_mult_slow_faster(A,v,u,4)) -# print(resolvent_bilinear_flattened(A,v,u,4,2)) -# ans: [4 6 8 6], [1 1 2 6], [1 3 6 6], [0 0 0 6] - - -m = 14 -n = 1< 0: - arg = torch.stack((torch.ones(n, dtype=u.dtype, device=u.device), - torch.zeros(n, dtype=u.dtype, device=u.device)), dim=-1) - else: # Find primitive roots of -1 - angles = torch.arange(n, dtype=u.dtype, device=u.device) / n * np.pi - arg = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) - eta = mod[:, np.newaxis] * arg - eta_inverse = (1.0 / mod)[:, np.newaxis] * conjugate(arg) - u_f = torch.ifft(eta_inverse * u[..., np.newaxis], 1) - v_f = torch.fft(eta * v[..., np.newaxis], 1) - uv_f = complex_mult(u_f[:, np.newaxis], v_f[np.newaxis]) - uv = torch.fft(uv_f, 1) - # We only need the real part of complex_mult(eta, uv) - return eta[..., 0] * uv[..., 0] - eta[..., 1] * uv[..., 1] + eta = torch.tensor(f, dtype=torch.complex64) ** ( + torch.arange(n, dtype=u.dtype, device=u.device) / n + ) + u_f = torch.fft.ifft(1 / eta * u) + v_f = torch.fft.fft(eta * v) + uv = torch.fft.fft(u_f[:, None] * v_f[None]) + return (eta * uv).real else: - u_f = torch.rfft(torch.cat((u.flip(1), torch.zeros_like(u)), dim=-1), 1) - v_f = torch.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1), 1) - uv_f = complex_mult(u_f[:, np.newaxis], v_f[np.newaxis]) - return torch.irfft(uv_f, 1, signal_sizes=(2 * n, ))[..., :n].flip(2) + u_f = torch.fft.rfft(torch.cat((u.flip(1), torch.zeros_like(u)), dim=-1)) + v_f = torch.fft.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1)) + uv_f = u_f[:, None] * v_f[None] + return torch.fft.irfft(uv_f)[..., :n].flip(2) -def toeplitz_krylov_multiply_by_autodiff(v, w, f=0.0): - """Multiply \sum_i Krylov(Z_f, v_i) @ w_i, using Pytorch's autodiff. - This function is just to check the result of toeplitz_krylov_multiply. +def toeplitz_krylov_multiply(v, w, f=0.0): + """Multiply sum_i Krylov(Z_f, v_i) @ w_i. Parameters: v: (rank, n) w: (batch_size, rank, n) @@ -57,19 +62,30 @@ def toeplitz_krylov_multiply_by_autodiff(v, w, f=0.0): Returns: product: (batch, n) """ - batch_size, rank, n = w.shape + _, rank, n = w.shape rank_, n_ = v.shape - assert n == n_, 'w and v must have the same last dimension' - assert rank == rank_, 'w and v must have the same rank' - - u = torch.zeros((batch_size, n), dtype=v.dtype, device=v.device, requires_grad=True) - prod = toeplitz_krylov_transpose_multiply(v, u, f) - result, = torch.autograd.grad(prod, u, grad_outputs=w, create_graph=True) - return result + assert n == n_, "w and v must have the same last dimension" + assert rank == rank_, "w and v must have the same rank" + if f != 0.0: # cycle version + eta = torch.tensor(f, dtype=torch.complex64) ** ( + torch.arange(n, dtype=v.dtype, device=v.device) / n + ) + w_f = torch.fft.fft(1 / eta * w) + v_f = torch.fft.fft(eta * v) + wv_sum_f = (w_f * v_f).sum(dim=1) # Does this happen in the right space? + wv_sum = torch.fft.ifft(wv_sum_f, 1) + return (1 / eta * wv_sum).real + else: + w_f = torch.fft.rfft(torch.cat((w, torch.zeros_like(w)), dim=-1)) + v_f = torch.fft.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1)) + wv_sum_f = (w_f * v_f).sum(dim=1) + # return torch.fft.irfft(wv_sum_f, 1, signal_sizes=(2 * n,))[..., :n] + return torch.fft.irfft(wv_sum_f)[..., :n] -def toeplitz_krylov_multiply(v, w, f=0.0): - """Multiply \sum_i Krylov(Z_f, v_i) @ w_i. +def toeplitz_krylov_multiply_by_autodiff(v, w, f=0.0): + """Multiply sum_i Krylov(Z_f, v_i) @ w_i, using Pytorch's autodiff. + This function is just to check the result of toeplitz_krylov_multiply. Parameters: v: (rank, n) w: (batch_size, rank, n) @@ -77,36 +93,19 @@ def toeplitz_krylov_multiply(v, w, f=0.0): Returns: product: (batch, n) """ - _, rank, n = w.shape + batch_size, rank, n = w.shape rank_, n_ = v.shape - assert n == n_, 'w and v must have the same last dimension' - assert rank == rank_, 'w and v must have the same rank' - if f != 0.0: # cycle version - # Computing the roots of f - mod = abs(f) ** (torch.arange(n, dtype=w.dtype, device=w.device) / n) - if f > 0: - arg = torch.stack((torch.ones(n, dtype=w.dtype, device=w.device), - torch.zeros(n, dtype=w.dtype, device=w.device)), dim=-1) - else: # Find primitive roots of -1 - angles = torch.arange(n, dtype=w.dtype, device=w.device) / n * np.pi - arg = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1) - eta = mod[:, np.newaxis] * arg - eta_inverse = (1.0 / mod)[:, np.newaxis] * conjugate(arg) - w_f = torch.fft(eta * w[..., np.newaxis], 1) - v_f = torch.fft(eta * v[..., np.newaxis], 1) - wv_sum_f = complex_mult(w_f, v_f).sum(dim=1) - wv_sum = torch.ifft(wv_sum_f, 1) - # We only need the real part of complex_mult(eta_inverse, wv_sum) - return eta_inverse[..., 0] * wv_sum[..., 0] - eta_inverse[..., 1] - wv_sum[..., 1] - else: - w_f = torch.rfft(torch.cat((w, torch.zeros_like(w)), dim=-1), 1) - v_f = torch.rfft(torch.cat((v, torch.zeros_like(v)), dim=-1), 1) - wv_sum_f = complex_mult(w_f, v_f).sum(dim=1) - return torch.irfft(wv_sum_f, 1, signal_sizes=(2 * n, ))[..., :n] + assert n == n_, "w and v must have the same last dimension" + assert rank == rank_, "w and v must have the same rank" + + u = torch.zeros((batch_size, n), dtype=v.dtype, device=v.device, requires_grad=True) + prod = toeplitz_krylov_transpose_multiply(v, u, f) + (result,) = torch.autograd.grad(prod, u, grad_outputs=w, create_graph=True) + return result def toeplitz_mult(G, H, x, cycle=True): - """Multiply \sum_i Krylov(Z_f, G_i) @ Krylov(Z_f, H_i) @ x. + """Multiply sum_i Krylov(Z_f, G_i) @ Krylov(Z_f, H_i) @ x. Parameters: G: Tensor of shape (rank, n) H: Tensor of shape (rank, n) @@ -115,7 +114,6 @@ def toeplitz_mult(G, H, x, cycle=True): Returns: product: Tensor of shape (batch_size, n) """ - # f = (1,-1) if cycle else (1,1) f = (1, -1) if cycle else (0, 0) transpose_out = toeplitz_krylov_transpose_multiply(H, x, f[1]) return toeplitz_krylov_multiply(G, transpose_out, f[0]) @@ -123,6 +121,7 @@ def toeplitz_mult(G, H, x, cycle=True): ##### Slow multiplication for the Toeplitz-like case + def toeplitz_Z_f_linear_map(f=0.0): """The linear map for multiplying by Z_f. This implementation is slow and not batched wrt rank, but easy to understand. @@ -144,17 +143,17 @@ def krylov_toeplitz_fast(v, f=0.0): Returns: K: Krylov matrix of size (n, n) or (rank, n, n). """ - rank, n = v.shape + rank, n = v.shape a = torch.arange(n, device=v.device) b = -a - indices = a[:, np.newaxis] + b[np.newaxis] + indices = a[:, None] + b[None] K = v[:, indices] K[:, indices < 0] *= f return K def toeplitz_mult_slow(G, H, x, cycle=True): - """Multiply \sum_i Krylov(Z_f, G_i) @ Krylov(Z_f, H_i) @ x. + """Multiply sum_i Krylov(Z_f, G_i) @ Krylov(Z_f, H_i) @ x. Uses the explicit Krylov construction with slow (and easy to understand) linear map. Parameters: @@ -165,16 +164,22 @@ def toeplitz_mult_slow(G, H, x, cycle=True): Returns: product: Tensor of shape (batch_size, n) """ - assert G.shape == H.shape, 'G and H must have the same shape' + assert G.shape == H.shape, "G and H must have the same shape" rank, n = G.shape f = (1, -1) if cycle else (0, 0) - krylovs = [(Krylov(toeplitz_Z_f_linear_map(f[0]), G[i]), Krylov(toeplitz_Z_f_linear_map(f[1]), H[i]).t()) for i in range(rank)] + krylovs = [ + ( + Krylov(toeplitz_Z_f_linear_map(f[0]), G[i]), + Krylov(toeplitz_Z_f_linear_map(f[1]), H[i]).t(), + ) + for i in range(rank) + ] prods = [K[0] @ (K[1] @ x.t()) for K in krylovs] return sum(prods).t() def toeplitz_mult_slow_fast(G, H, x, cycle=True): - """Multiply \sum_i Krylov(Z_f, G_i) @ Krylov(Z_f, H_i) @ x. + """Multiply sum_i Krylov(Z_f, G_i) @ Krylov(Z_f, H_i) @ x. Uses the fast construction of Krylov matrix. Parameters: G: Tensor of shape (rank, n) @@ -188,66 +193,3 @@ def toeplitz_mult_slow_fast(G, H, x, cycle=True): f_G, f_H = (1, -1) if cycle else (0, 0) K_G, K_H = krylov_toeplitz_fast(G, f_G), krylov_toeplitz_fast(H, f_H) return ((x @ K_H) @ K_G.transpose(1, 2)).sum(dim=0) - - -def test_toeplitz_mult(): - v = torch.tensor([[0,1,0,-1],[0,1,2,3]], dtype=torch.float, device=device, requires_grad=True) - u = torch.tensor([[1,1,1,1],[0,1,2,3]], dtype=torch.float, device=device, requires_grad=True) - - w = toeplitz_krylov_transpose_multiply(v, u, f=-1) - # output: - # [[[ 0 2 2 0] - # [ 6 0 -4 -6]] - - # [[ -2 2 4 2] - # [ 14 8 0 -8]]] - - toeplitz_mult(v, v, u) - toeplitz_mult_slow(v, v, u) - # output: - # array([[-16., -20., -4., 16.], - # [ 16., -8., 12., 64.]]) - - toeplitz_mult(v, v, u, cycle=False) - toeplitz_mult_slow(v, v, u, cycle=False) - # output: - # array([[ 0., 6., 16., 26.], - # [ 0., 12., 38., 66.]]) - - m = 10 - n = 1< 0: - arg = np.ones(n) - else: - arg = np.fft.fft(np.eye(1,2*n,2*n-1))[0,:n] - self.eta = mod * arg - + self.eta = f ** (torch.arange(n).to(torch.complex64) / n) if f != 0 else None def __call__(self, v, u): """ @@ -35,45 +38,40 @@ def __call__(self, v, u): u: (batch, n) out: (batch, rank, n) """ - n, m, batch_size, rank = self.n, self.m, self.batch_size, self.rank - - if self.eta is not None: # cycle version - u_ = np.fft.ifft(1/self.eta * u) - v_ = np.fft.fft(self.eta * v) - uv_ = u_.reshape(batch_size, 1, n) * v_.reshape(1, rank, n) - ans = self.eta * np.fft.fft(uv_) - return np.real(ans) + n, _, batch_size, rank = self.n, self.m, self.batch_size, self.rank + + if self.eta is not None: # cycle version + eta = self.eta.to(u.device) + u_ = torch.fft.ifft(1 / eta * u) + v_ = torch.fft.fft(eta * v) + # uv_ = u_.reshape(batch_size, 1, n) * v_.reshape(1, rank, n) + # ans = eta * torch.fft.fft(uv_) + # return torch.real(ans) + uv = torch.fft.fft(u_[:, None] * v_[None]) + return (eta * uv).real else: - u_ = np.fft.rfft(np.concatenate((u[...,::-1], np.zeros_like(u)), axis=-1)) - v_ = np.fft.rfft(np.concatenate((v, np.zeros_like(v)), axis=-1)) + u_ = torch.fft.rfft( + torch.concatenate((u.flip(-1), torch.zeros_like(u)), dim=-1) + ) + v_ = torch.fft.rfft(torch.concatenate((v, torch.zeros_like(v)), dim=-1)) uv_ = u_.reshape(batch_size, 1, -1) * v_.reshape(1, rank, -1) - ans = np.fft.irfft(uv_)[..., n-1::-1] + # ans = torch.fft.irfft(uv_)[..., n - 1 :: -1] + ans = torch.fft.irfft(uv_)[..., :n].flip(-1) return ans -class K_Toeplitz(): - """Multiply Krylov(A, v) @ w when A is zero except on the subdiagonal. - """ +class K_Toeplitz: + """Multiply Krylov(A, v) @ w when A is zero except on the subdiagonal.""" def __init__(self, n, f, batch_size=1, rank=1): - m = int(np.log2(n)) - assert n == 1 << m, 'n must be a power of 2' + m = int(log2(n)) + assert n == 1 << m, "n must be a power of 2" self.n = n self.m = m self.batch_size = batch_size self.rank = rank - self.eta = None - if f == 0: - pass - else: - mod = np.power(np.abs(f), np.arange(n)/n) - if f > 0: - arg = np.ones(n) - else: - arg = np.fft.fft(np.eye(1,2*n,2*n-1))[0,:n] - # arg = np.exp(np.arange(n) * 1j * np.pi / n) - self.eta = mod * arg + self.eta = f ** (torch.arange(n).to(torch.complex64) / n) if f != 0 else None def __call__(self, v, w): """ @@ -81,75 +79,51 @@ def __call__(self, v, w): w: (batch_size, rank, n) out: (batch_size, n) """ - n, m, batch_size, rank = self.n, self.m, self.batch_size, self.rank + n, _, _, rank = self.n, self.m, self.batch_size, self.rank if self.eta is not None: - w_ = np.fft.fft(self.eta * w) - v_ = np.fft.fft(self.eta * v) - wv_ = w_ * v_.reshape((1, rank, n)) - ans = 1/self.eta * np.fft.ifft(np.sum(wv_, axis=1)) - ans = np.real(ans) + eta = self.eta.to(v.device) + w_ = torch.fft.fft(eta * w) + v_ = torch.fft.fft(eta * v) + wv_ = w_ * v_.reshape((1, rank, n)) + ans = 1 / eta * torch.fft.ifft(torch.sum(wv_, dim=1)) + ans = torch.real(ans) else: - w_ = np.fft.rfft(np.concatenate((w, np.zeros_like(w)), axis=-1)) - v_ = np.fft.rfft(np.concatenate((v, np.zeros_like(v)), axis=-1)) - wv_ = w_ * v_.reshape((1, rank, -1)) - ans = np.fft.irfft(np.sum(wv_, axis=1))[..., :n] + w_ = torch.fft.rfft(torch.concatenate((w, torch.zeros_like(w)), dim=-1)) + v_ = torch.fft.rfft(torch.concatenate((v, torch.zeros_like(v)), dim=-1)) + wv_ = w_ * v_.reshape((1, rank, -1)) + ans = torch.fft.irfft(torch.sum(wv_, dim=1))[..., :n] return ans def toeplitz_mult(G, H, x, cycle=True): rank, n = G.shape batch_size = x.shape[0] - f = (1,-1) if cycle else (0,0) + f = (1, -1) if cycle else (0, 0) transpose_out = KT_Toeplitz(n, f[1], batch_size, rank)(H, x) krylov_out = K_Toeplitz(n, f[0], batch_size, rank)(G, transpose_out) - return krylov_out/2 if cycle else krylov_out + return krylov_out ##### Slow mult + def krylov_construct(f, v, m): n = v.shape[0] - K = np.zeros(shape=(m,n)) - K[0,:] = v - for i in range(1,m): - K[i,1:] = K[i-1,:-1] - K[i,0] = f*K[i-1,-1] + K = torch.zeros(size=(m, n), device=f.device) + K[0, :] = v + for i in range(1, m): + K[i, 1:] = K[i - 1, :-1] + K[i, 0] = f * K[i - 1, -1] return K.T + def toeplitz_mult_slow(G, H, x, cycle=True): assert G.shape == H.shape rank, n = G.shape - f = (1,-1) if cycle else (0,0) - krylovs = [(krylov_construct(f[0], G[i], n), krylov_construct(f[1], H[i], n).T) for i in range(rank)] + f = (1, -1) if cycle else (0, 0) + krylovs = [ + (krylov_construct(f[0], G[i], n), krylov_construct(f[1], H[i], n).T) + for i in range(rank) + ] prods = [K[0] @ K[1] @ x.T for K in krylovs] - return np.sum(np.array(prods), axis=0).T - -if __name__ == '__main__': - v = np.array([[0,1,0,-1],[0,1,2,3]]) - u = np.array([[1,1,1,1],[0,1,2,3]]) - - w = KT_Toeplitz(4, -1, 2, 2)(v, u) - # output: - # [[[ 0 2 2 0] - # [ 6 0 -4 -6]] - - # [[ -2 2 4 2] - # [ 14 8 0 -8]]] - - w = KT_Toeplitz(4, 0, 2, 2)(v, u) - # [[[ 0 1 1 0] - # [ 6 3 1 0]] - # [[ -2 2 3 0] - # [ 14 8 3 0]]] - - print(toeplitz_mult(v, v, u)) - print(toeplitz_mult_slow(v, v, u)) - # output: - # array([[-16., -20., -4., 16.], - # [ 16., -8., 12., 64.]]) - - print(toeplitz_mult(v, v, u, cycle=False)) - print(toeplitz_mult_slow(v, v, u, cycle=False)) - # output: - # array([[ 0., 6., 16., 26.], - # [ 0., 12., 38., 66.]]) + return torch.sum(torch.tensor(prods), dim=0).T