-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSparseConv.py
More file actions
171 lines (131 loc) · 6.41 KB
/
SparseConv.py
File metadata and controls
171 lines (131 loc) · 6.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
import torch.nn as nn
import torch.sparse as sparse
from typing import Union, Any, Tuple
import model_utils
import gc
class SparseConv(nn.Conv2d):
    """nn.Conv2d drop-in that runs the convolution as an unfold (im2col)
    step, a 2:4 semi-structured sparse matrix multiply, and a fold back
    into an activation map.

    NOTE(review): both forward paths allocate directly on ``'cuda'`` and
    call ``torch.sparse.to_sparse_semi_structured``, so this layer
    requires a CUDA device with semi-structured-sparsity support.
    """

    def __init__(self, *args, **kwargs):
        # No extra state: construction is exactly nn.Conv2d's.
        super(SparseConv, self).__init__(*args, **kwargs)

    def forward(self, input_act: torch.Tensor) -> torch.Tensor:
        """Memory-conscious sparse forward pass.

        Unfolds the input into column form, multiplies one batch item at
        a time against the semi-structured sparse weight matrix, then
        folds the result back. Large intermediates are deleted eagerly to
        keep peak GPU memory down.

        Args:
            input_act: activation tensor, presumably (N, C, H, W) as for
                nn.Conv2d — TODO confirm against callers.

        Returns:
            Convolution output as float32, detached from the autograd graph.

        Raises:
            NotImplementedError: if ``self.groups != 1``. The original
                code only built the sparse operands for groups == 1 and
                would reference undefined locals otherwise; failing fast
                gives a clear error instead.
        """
        if self.groups != 1:
            raise NotImplementedError("SparseConv.forward supports groups == 1 only")
        unfold = nn.Unfold(kernel_size=self.kernel_size, dilation=self.dilation,
                           padding=self.padding, stride=self.stride)
        # Work in fp16: the semi-structured sparse mm operates on half precision.
        unfolded_input = unfold(input_act.to(torch.float16).detach())
        unfolded_input_size = unfolded_input.size()
        batch_size = unfolded_input_size[0]
        out_channels = self.weight.size(0)
        # (out_channels, C*kh*kw) weight matrix, zero-padded by _complete to
        # multiples of 64 so it satisfies the sparse kernel's size rules.
        # (The original computed weight.view(o, -1).t() and then took .T,
        # which cancels out — this is the same matrix, built directly.)
        fw_complete = self._complete(
            self.weight.half().detach().view(out_channels, -1)
        ).half().cuda().contiguous()
        fw_sparse = sparse.to_sparse_semi_structured(fw_complete)
        # Zero-pad each batch item of the unfolded input the same way.
        ui_trans = self._complete3D(unfolded_input)
        fw_sparse_size = fw_sparse.size()
        # Release the dense intermediates before the big output allocation.
        del unfolded_input
        del fw_complete
        out_unf = torch.zeros((batch_size, fw_sparse_size[0], ui_trans.size(2)),
                              dtype=torch.float16, device='cuda')
        # The sparse mm is 2-D only, so iterate over the batch dimension.
        for ind in range(batch_size):
            out_unf[ind] = torch.mm(fw_sparse, ui_trans[ind]).detach()
        if self.bias is not None:
            out_unf = out_unf + self.bias.view(1, -1, 1)
        del fw_sparse
        del ui_trans
        # Slice away the zero padding added by _complete/_complete3D.
        # (The original computed this slice twice; the first was a dead store.)
        out_unf_sliced = out_unf[:, :out_channels, :unfolded_input_size[2]]
        del out_unf
        # NOTE(review): assumes a square spatial output (H_out == W_out).
        output_size = torch.sqrt(torch.tensor([out_unf_sliced.size(2)])).to(torch.int)
        fold = nn.Fold(output_size=(output_size, output_size), kernel_size=(1, 1))
        output_act = fold(out_unf_sliced)
        del out_unf_sliced
        torch.cuda.empty_cache()
        return output_act.to(torch.float).detach()

    def forward_old(self, input_act: torch.Tensor) -> torch.Tensor:
        """Previous (reference) implementation of the sparse forward pass.

        Same algorithm as forward(), but unfolds in the input's original
        dtype before casting to half and keeps every intermediate alive
        until the end, so its peak memory use is higher.
        """
        unfold = nn.Unfold(kernel_size=self.kernel_size, dilation=self.dilation,
                           padding=self.padding, stride=self.stride)
        unfolded_input = unfold(input_act).cuda()
        # (C*kh*kw, out_channels); transposed back (.T) when padding below.
        flattened_weight = self.weight.view(self.weight.size(0), -1).t()
        fw_complete = self._complete(flattened_weight.T).half().cuda().contiguous()
        fw_sparse = sparse.to_sparse_semi_structured(fw_complete)
        ui_trans = self._complete3D(unfolded_input).half().cuda()
        out_unf = torch.zeros((unfolded_input.size(0), fw_sparse.size(0),
                               ui_trans.size(2))).half().cuda()
        # Batched 2-D sparse matmul, one batch item at a time.
        for ind in range(unfolded_input.size(0)):
            out_unf[ind] = torch.mm(fw_sparse, ui_trans[ind])
        if self.bias is not None:
            out_unf = out_unf + self.bias.view(1, -1, 1)
        # Slice away the zero padding added by _complete/_complete3D.
        out_unf_sliced = out_unf[:, :flattened_weight.size(1), :unfolded_input.size(2)]
        # NOTE(review): assumes a square spatial output (H_out == W_out).
        output_size = torch.sqrt(torch.tensor([out_unf_sliced.size(2)])).to(torch.int)
        fold = nn.Fold(output_size=(output_size, output_size), kernel_size=(1, 1))
        output_act = fold(out_unf_sliced)
        del unfolded_input
        del flattened_weight
        del fw_complete
        del fw_sparse
        del out_unf
        del out_unf_sliced
        del ui_trans
        torch.cuda.empty_cache()
        return output_act.float().detach()

    def _complete(self, tensor: torch.Tensor, MS: int = 64, NS: int = 64,
                  KS: int = 64) -> torch.Tensor:
        """Zero-pad a 2-D tensor so its dims are multiples of MS and NS.

        Returns a fresh half-precision CUDA tensor with ``tensor`` copied
        into its top-left corner. KS is unused; it is kept so the
        signature stays compatible with existing callers.
        """
        shape = tensor.size()
        # Round each dimension up to the next multiple of its tile size.
        if (shape[0] % MS) > 0:
            new_shape_a = MS * (shape[0] // MS) + MS
        else:
            new_shape_a = shape[0]
        if (shape[1] % NS) > 0:
            new_shape_b = NS * (shape[1] // NS) + NS
        else:
            new_shape_b = shape[1]
        new_tensor = torch.zeros((new_shape_a, new_shape_b),
                                 dtype=torch.half, device='cuda')
        new_tensor[: shape[0], : shape[1]] = tensor.half()
        del tensor
        torch.cuda.empty_cache()
        return new_tensor

    def _complete3D(self, tensor: torch.Tensor, MS: int = 64, NS: int = 64,
                    KS: int = 64) -> torch.Tensor:
        """Zero-pad the last two dims of a 3-D tensor to multiples of MS / NS.

        The leading (batch) dimension is preserved; the result is a fresh
        half-precision CUDA tensor. KS is unused; kept for signature
        compatibility with existing callers.
        """
        shape = tensor.size()
        # Round the two trailing dimensions up to their tile multiples.
        if (shape[1] % MS) > 0:
            new_shape_a = MS * (shape[1] // MS) + MS
        else:
            new_shape_a = shape[1]
        if (shape[2] % NS) > 0:
            new_shape_b = NS * (shape[2] // NS) + NS
        else:
            new_shape_b = shape[2]
        new_tensor = torch.zeros((shape[0], new_shape_a, new_shape_b),
                                 dtype=torch.half, device='cuda')
        new_tensor[:, : shape[1], : shape[2]] = tensor.half()
        del tensor
        torch.cuda.empty_cache()
        return new_tensor

    def _prune(self, tensor: torch.Tensor, group_size: int = 4, k: int = 2) -> torch.Tensor:
        """For each group of `group_size` elements in the last dimension of
        `tensor`, keep the top-k values (by magnitude) and zero out the rest.

        NOTE(review): if the last dimension is not divisible by group_size,
        the view below raises a RuntimeError (the original assert was
        commented out; behavior preserved).
        """
        *leading_dims, last_dim = tensor.shape
        num_groups = last_dim // group_size
        # Reshape so each group of `group_size` elements is its own row.
        tensor_grouped = tensor.view(*leading_dims, num_groups, group_size)
        # Top-k by absolute value within each group.
        _, topk_idx = torch.topk(torch.abs(tensor_grouped), k=k, dim=-1)
        # Boolean mask marking the surviving elements.
        mask = torch.zeros_like(tensor_grouped, dtype=torch.bool)
        mask = mask.scatter_(-1, topk_idx, True)
        # Zero out everything the mask doesn't keep, then restore the shape.
        tensor_filtered = tensor_grouped * mask
        return tensor_filtered.view(*tensor.shape)