82 changes: 59 additions & 23 deletions pytorch/structure/LDR.py
@@ -1,20 +1,41 @@
# Copyright 2018 HazyResearch
# https://github.com/HazyResearch/structured-nets
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from . import toeplitz as toep
from . import krylov as kry
from . import krylov as kry, toeplitz as toep

# TODO: rewrite with structure.layer
# TODO: subclass with each DR type
class LDR(nn.Module):

def name(self):
return str(self.in_channels) + str(self.out_channels) + self.displacement + str(self.r)
return (
str(self.in_channels)
+ str(self.out_channels)
+ self.displacement
+ str(self.r)
)

# TODO: support non-square multiplications
def __init__(self, displacement, in_channels, out_channels, rank, layer_size, bias=True):
def __init__(
self, displacement, in_channels, out_channels, rank, layer_size, bias=True
):
super(LDR, self).__init__()
self.displacement = displacement
self.in_channels = in_channels
@@ -23,20 +44,27 @@ def __init__(self, displacement, in_channels, out_channels, rank, layer_size, bias=True):
self.n = layer_size
self.bias = None

self.G = Parameter(torch.Tensor(self.in_channels, self.out_channels, self.r, self.n))
self.H = Parameter(torch.Tensor(self.in_channels, self.out_channels, self.r, self.n))
torch.nn.init.normal_(self.G, std=0.01) #TODO
self.G = Parameter(
torch.Tensor(self.in_channels, self.out_channels, self.r, self.n)
)
self.H = Parameter(
torch.Tensor(self.in_channels, self.out_channels, self.r, self.n)
)
torch.nn.init.normal_(self.G, std=0.01) # TODO
torch.nn.init.normal_(self.H, std=0.01)
if bias:
self.bias = Parameter(torch.zeros(self.out_channels, 1, self.n))
if self.displacement == 'toeplitz_corner' or self.displacement == 'tc':
if self.displacement == "toeplitz_corner" or self.displacement == "tc":
self.corner = True
elif self.displacement == 'toeplitz' or self.displacement == 't':
elif self.displacement == "toeplitz" or self.displacement == "t":
self.corner = False
elif self.displacement == 'subdiagonal' or self.displacement == 'sd':
self.subd_A = Parameter(torch.ones((self.in_channels, self.out_channels, self.n-1)))
self.subd_B = Parameter(torch.ones((self.in_channels, self.out_channels, self.n-1)))

elif self.displacement == "subdiagonal" or self.displacement == "sd":
self.subd_A = Parameter(
torch.ones((self.in_channels, self.out_channels, self.n - 1))
)
self.subd_B = Parameter(
torch.ones((self.in_channels, self.out_channels, self.n - 1))
)

def forward(self, x):
"""
@@ -47,15 +75,23 @@ def forward(self, x):
assert n == self.n

# print("shapes ", self.G[0,0].shape, self.H[0,0].shape, x[0].shape)
comps = Variable(torch.Tensor(self.in_channels, self.out_channels, b, self.n)).cuda()
comps = Variable(
torch.Tensor(self.in_channels, self.out_channels, b, self.n)
).cuda()
for i in range(self.in_channels):
for j in range(self.out_channels):
if self.displacement in ['toeplitz_corner', 'toeplitz', 'tc', 't']:
g = self.G[i,j]
h = self.H[i,j]
comps[i,j] = toep.toeplitz_mult(self.G[i,j], self.H[i,j], x[i], self.corner)
elif self.displacement == 'subdiagonal' or self.displacement == 'sd':
comps[i,j] = kry.subdiag_mult_conv(self.subd_A[i,j], self.subd_B[i,j], self.G[i,j], self.H[i,j], x[i])
if self.displacement in ["toeplitz_corner", "toeplitz", "tc", "t"]:
comps[i, j] = toep.toeplitz_mult(
self.G[i, j], self.H[i, j], x[i], self.corner
)
elif self.displacement == "subdiagonal" or self.displacement == "sd":
comps[i, j] = kry.subdiag_mult_conv(
self.subd_A[i, j],
self.subd_B[i, j],
self.G[i, j],
self.H[i, j],
x[i],
)
out = torch.sum(comps, dim=0)
if self.bias is not None:
out += self.bias
@@ -64,4 +100,4 @@ def forward(self, x):
def loss(self):
lamb = 0.0001
# lamb = 0
return lamb*torch.sum(torch.abs(self.G)) + lamb*torch.sum(torch.abs(self.H))
return lamb * torch.sum(torch.abs(self.G)) + lamb * torch.sum(torch.abs(self.H))
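A minimal usage sketch of the LDR layer as it stands after this change (not part of the PR): the channel, rank, and layer sizes below are made up, the import path is assumed from the repository layout, and a GPU is assumed because forward() still allocates its work buffer with .cuda().

```python
import torch
from pytorch.structure.LDR import LDR  # import path assumed from the repo layout

# Hypothetical sizes: 3 input channels, 4 output channels, rank 2, layer size 32.
layer = LDR("toeplitz_corner", 3, 4, rank=2, layer_size=32).cuda()

# forward expects x of shape (in_channels, batch, n) and returns (out_channels, batch, n).
x = torch.randn(3, 8, 32, device="cuda")
y = layer(x)
print(y.shape)       # torch.Size([4, 8, 32])
print(layer.loss())  # L1 penalty on the low-rank factors G and H
```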
40 changes: 23 additions & 17 deletions pytorch/structure/circulant.py
@@ -1,27 +1,33 @@
# Copyright 2018 HazyResearch
# https://github.com/HazyResearch/structured-nets
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from scipy.linalg import circulant
from .complex_utils import complex_mult

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def circulant_multiply(c, x):
""" Multiply circulant matrix with first column c by x
"""Multiply circulant matrix with first column c by x.
E.g. if the matrix is
[1 2 3]
[2 3 1]
[3 1 2]
c should be [1,2,3]
Parameters:
c: (n, )
x: (batch_size, n) or (n, )
Return:
prod: (batch_size, n) or (n, )
"""
return torch.irfft(complex_mult(torch.rfft(c, 1), torch.rfft(x, 1)), 1, signal_sizes=(c.shape[-1], ))

def test_circulant_multiply(n):
c = torch.rand(n, device=device)
x = torch.rand((3, n), device=device)
C = torch.tensor(circulant(c.detach().cpu().numpy()), dtype=c.dtype, device=c.device)
slow = x @ C.t()
fast = circulant_multiply(c, x)
print('Error compared to slow multiply: ', (slow - fast).abs().max().item())

# TODO: move test into subpackage
if __name__ == '__main__':
test_circulant_multiply(100)
return torch.fft.irfft(torch.fft.rfft(c) * torch.fft.rfft(x))
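A quick sanity check of the torch.fft based circulant_multiply against a dense multiply, in the spirit of the old test_circulant_multiply; this sketch is not part of the PR, and it passes n= to irfft explicitly so that odd-length signals come back at their original size.

```python
import torch
from scipy.linalg import circulant

def circulant_multiply_checked(c, x):
    # Same FFT identity as above, with the output length pinned to n.
    n = c.shape[-1]
    return torch.fft.irfft(torch.fft.rfft(c) * torch.fft.rfft(x), n=n)

n = 5  # odd on purpose
c = torch.rand(n)
x = torch.rand(3, n)
C = torch.tensor(circulant(c.numpy()), dtype=c.dtype)
dense = x @ C.t()  # each row of x multiplied by the circulant matrix
fast = circulant_multiply_checked(c, x)
print((dense - fast).abs().max().item())  # should be ~1e-7 for float32
```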
19 changes: 0 additions & 19 deletions pytorch/structure/complex_utils.py

This file was deleted.
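The deleted helper is presumably no longer needed because torch.fft.rfft returns native complex tensors, so the plain elementwise product in circulant_multiply replaces complex_mult. A small sketch of that equivalence, assuming the old helper operated on the legacy real-tensor layout with a trailing (real, imag) dimension of size 2:

```python
import torch

def complex_mult_old_layout(a, b):
    # Assumed behaviour of the deleted helper: last dimension holds (real, imag).
    ar, ai = a[..., 0], a[..., 1]
    br, bi = b[..., 0], b[..., 1]
    return torch.stack((ar * br - ai * bi, ar * bi + ai * br), dim=-1)

a = torch.randn(4, 2)
b = torch.randn(4, 2)
old = complex_mult_old_layout(a, b)
new = torch.view_as_real(torch.view_as_complex(a) * torch.view_as_complex(b))
print(torch.allclose(old, new))  # True
```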

71 changes: 63 additions & 8 deletions pytorch/structure/diag_mult_cuda/diag_mult_cuda.cpp
@@ -1,8 +1,37 @@
/*
Copyright 2018 HazyResearch
https://github.com/HazyResearch/structured-nets

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include <torch/extension.h>

void subdiagMultGPU(float *d_Subdiag, float *d_Data, float *d_Output, int shiftSubdiag, int shiftV, int batchSize, int N, bool batchedSubdiag);
void subdiagMultGPU(
float* d_Subdiag,
float* d_Data,
float* d_Output,
int shiftSubdiag,
int shiftV,
int batchSize,
int N,
bool batchedSubdiag);

torch::Tensor cycle_mult(torch::Tensor subdiag, torch::Tensor v, int shiftSubdiag, int shiftV) {
torch::Tensor cycle_mult(
torch::Tensor subdiag,
torch::Tensor v,
int64_t shiftSubdiag,
int64_t shiftV) {
TORCH_CHECK(subdiag.is_cuda(), "subdiag must be a CUDA tensor");
TORCH_CHECK(v.is_cuda(), "v must be a CUDA tensor");
// Need to make tensors contiguous before passing to CUDA
@@ -12,28 +41,54 @@ torch::Tensor cycle_mult(torch::Tensor subdiag, torch::Tensor v, int shiftSubdiag, int shiftV) {
auto batchSize = v.numel() / n;
auto output = torch::empty_like(v);
bool batchedSubdiag = subdiag.numel() == v.numel();
subdiagMultGPU(subdiag.data_ptr<float>(), v.data_ptr<float>(), output.data_ptr<float>(), shiftSubdiag, shiftV, batchSize, n, batchedSubdiag);
subdiagMultGPU(
subdiag.data_ptr<float>(),
v.data_ptr<float>(),
output.data_ptr<float>(),
shiftSubdiag,
shiftV,
batchSize,
n,
batchedSubdiag);
return output;
}

torch::Tensor subdiagKrylov(torch::Tensor subdiag, torch::Tensor v, int m) {
torch::Tensor subdiagKrylov(torch::Tensor subdiag, torch::Tensor v, int64_t m) {
TORCH_CHECK(subdiag.is_cuda(), "subdiag must be a CUDA tensor");
TORCH_CHECK(v.is_cuda(), "v must be a CUDA tensor");
// Need to make tensors contiguous before passing to CUDA
subdiag = subdiag.contiguous();
v = v.contiguous();
auto n = v.sizes().back();
auto batchSize = v.numel() / n;
auto output = torch::empty({m, batchSize, n}, torch::dtype(v.dtype()).device(v.device()));
// subdiagKrylovGPU(subdiag.data_ptr<float>(), v.data_ptr<float>(), output.data_ptr<float>(), shiftSubdiag, shiftV, batchSize, n);
auto output = torch::empty(
{m, batchSize, n}, torch::dtype(v.dtype()).device(v.device()));
// subdiagKrylovGPU(subdiag.data_ptr<float>(), v.data_ptr<float>(),
// output.data_ptr<float>(), shiftSubdiag, shiftV, batchSize, n);
output[0] = v;
for (int i = 1; i < m; ++i) {
subdiagMultGPU(subdiag.data_ptr<float>(), output[i - 1].data_ptr<float>(), output[i].data_ptr<float>(), 0, -1, batchSize, n, false);
subdiagMultGPU(
subdiag.data_ptr<float>(),
output[i - 1].data_ptr<float>(),
output[i].data_ptr<float>(),
0,
-1,
batchSize,
n,
false);
}
return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cycle_mult", &cycle_mult, "Cycle the vector and then do a pointwise multiplication. Shift should be between -n and n - 1.");
m.def(
"cycle_mult",
&cycle_mult,
"Cycle the vector and then do a pointwise multiplication. Shift should be between -n and n - 1.");
m.def("subdiagKrylov", &subdiagKrylov, "Subdiag Krylov");
}

TORCH_LIBRARY(TORCH_EXTENSION_NAME, m) {
m.def("cycle_mult", &cycle_mult);
m.def("subdiagKrylov", &subdiagKrylov);
}
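A sketch of how the extension might be built and called from Python with torch.utils.cpp_extension.load (not part of the PR); the module name is made up, the source paths are taken from this diff, and the kernels expect float32 CUDA tensors.

```python
import torch
from torch.utils.cpp_extension import load

diag_mult = load(
    name="diag_mult_cuda",  # hypothetical extension name
    sources=[
        "pytorch/structure/diag_mult_cuda/diag_mult_cuda.cpp",
        "pytorch/structure/diag_mult_cuda/diag_mult_cuda_kernel.cu",
    ],
)

n, batch = 8, 3
subdiag = torch.rand(n, device="cuda")
v = torch.rand(batch, n, device="cuda")

# Cyclically shift v by -1 and multiply pointwise by the subdiagonal.
out = diag_mult.cycle_mult(subdiag, v, 0, -1)

# Krylov-style stack: K[0] = v, K[i] = shift-and-multiply applied to K[i - 1].
K = diag_mult.subdiagKrylov(subdiag, v, 4)
print(out.shape, K.shape)  # torch.Size([3, 8]) torch.Size([4, 3, 8])
```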
64 changes: 49 additions & 15 deletions pytorch/structure/diag_mult_cuda/diag_mult_cuda_kernel.cu
@@ -1,18 +1,52 @@
__global__ void subdiagMult(float *d_Subdiag, float *d_Data, float *d_Output, int shiftSubdiag, int shiftV, int N, int subdiagOffset) {
const int pos = blockIdx.x * blockDim.x + threadIdx.x;
/*
Copyright 2018 HazyResearch
https://github.com/HazyResearch/structured-nets

float *d_Src = d_Data + blockIdx.y * N;
float *d_Dst = d_Output + blockIdx.y * N;
float *d_Sub = d_Subdiag + blockIdx.y * subdiagOffset;
// for (int pos = tid; pos < N; pos += numThreads) {
if (pos < N) {
d_Dst[pos] = d_Sub[(pos + shiftSubdiag + N) % N] * d_Src[(pos + shiftV + N) % N];
}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

__global__ void subdiagMult(
float* d_Subdiag,
float* d_Data,
float* d_Output,
int shiftSubdiag,
int shiftV,
int N,
int subdiagOffset) {
const int pos = blockIdx.x * blockDim.x + threadIdx.x;

float* d_Src = d_Data + blockIdx.y * N;
float* d_Dst = d_Output + blockIdx.y * N;
float* d_Sub = d_Subdiag + blockIdx.y * subdiagOffset;
// for (int pos = tid; pos < N; pos += numThreads) {
if (pos < N) {
d_Dst[pos] =
d_Sub[(pos + shiftSubdiag + N) % N] * d_Src[(pos + shiftV + N) % N];
}
}

void subdiagMultGPU(float *d_Subdiag, float *d_Data, float *d_Output, int shiftSubdiag, int shiftV, int batchSize, int N, bool batchedSubdiag) {
const int THREAD_N = 256;
dim3 grid((N + THREAD_N - 1) / THREAD_N, batchSize);
int subdiagOffset = batchedSubdiag ? N : 0;
subdiagMult<<<grid, THREAD_N>>>(d_Subdiag, d_Data, d_Output, shiftSubdiag, shiftV, N, subdiagOffset);
}
void subdiagMultGPU(
float* d_Subdiag,
float* d_Data,
float* d_Output,
int shiftSubdiag,
int shiftV,
int batchSize,
int N,
bool batchedSubdiag) {
const int THREAD_N = 256;
dim3 grid((N + THREAD_N - 1) / THREAD_N, batchSize);
int subdiagOffset = batchedSubdiag ? N : 0;
subdiagMult<<<grid, THREAD_N>>>(
d_Subdiag, d_Data, d_Output, shiftSubdiag, shiftV, N, subdiagOffset);
}
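For reference, the semantics of subdiagMult can be written with two torch.roll calls; this pure PyTorch sketch only documents what the kernel computes and is not part of the PR.

```python
import torch

def subdiag_mult_reference(subdiag, v, shift_subdiag, shift_v):
    # Per row b: out[b, pos] = subdiag[(pos + shift_subdiag) % N] * v[b, (pos + shift_v) % N],
    # matching the indexing in the CUDA kernel; torch.roll(x, -s) puts x[(pos + s) % N] at pos.
    return torch.roll(subdiag, -shift_subdiag, dims=-1) * torch.roll(v, -shift_v, dims=-1)

N = 8
subdiag = torch.rand(N)
v = torch.rand(3, N)
out = subdiag_mult_reference(subdiag, v, 0, -1)  # the shifts used by subdiagKrylov
print(out.shape)  # torch.Size([3, 8])
```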
21 changes: 0 additions & 21 deletions pytorch/structure/diag_mult_cuda/setup.py

This file was deleted.
