# linear_quantization.py
from base_model import *
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
import copy
def test_linear_quantize(
        test_tensor=torch.tensor([
            [0.0523, 0.6364, -0.0968, -0.0020, 0.1940],
            [0.7500, 0.5507, 0.6188, -0.1734, 0.4677],
            [-0.0669, 0.3836, 0.4297, 0.6267, -0.0695],
            [0.1536, -0.0038, 0.6075, 0.6817, 0.0601],
            [0.6446, -0.2500, 0.5376, -0.2226, 0.2333]]),
        quantized_test_tensor=torch.tensor([
            [-1, 1, -1, -1, 0],
            [1, 1, 1, -2, 0],
            [-1, 0, 0, 1, -1],
            [-1, -1, 1, 1, -1],
            [1, -2, 1, -2, 0]], dtype=torch.int8),
        real_min=-0.25, real_max=0.75, bitwidth=2, scale=1 / 3, zero_point=-1):
    """
    Check linear_quantize() against a reference result and plot the tensors.

    :param test_tensor: [torch.FloatTensor] floating tensor to quantize
    :param quantized_test_tensor: [torch.CharTensor] expected quantized tensor
    :param real_min: [float] lower colour limit for the floating-point plots
    :param real_max: [float] upper colour limit for the floating-point plots
    :param bitwidth: [int] quantization bit width
    :param scale: [float] scaling factor passed to linear_quantize()
    :param zero_point: [int] zero point passed to linear_quantize()
    :raises AssertionError: if the quantized tensor differs from the reference
    """

    def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=ListedColormap(['white'])):
        # draw the matrix as an image and annotate every cell with its value
        ax.imshow(tensor.cpu().numpy(), vmin=vmin, vmax=vmax, cmap=cmap)
        ax.set_title(title)
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        for i in range(tensor.shape[0]):
            for j in range(tensor.shape[1]):
                datum = tensor[i, j].item()
                # floats are shown with two decimals, integers verbatim
                label = f'{datum:.2f}' if isinstance(datum, float) else f'{datum}'
                ax.text(j, i, label, ha="center", va="center", color="k")

    quantized_min, quantized_max = get_quantized_range(bitwidth)
    fig, axes = plt.subplots(1, 3, figsize=(10, 32))
    plot_matrix(test_tensor, axes[0], 'original tensor', vmin=real_min, vmax=real_max)

    _quantized_test_tensor = linear_quantize(
        test_tensor, bitwidth=bitwidth, scale=scale, zero_point=zero_point)
    _reconstructed_test_tensor = scale * (_quantized_test_tensor.float() - zero_point)

    print('* Test linear_quantize()')
    print(f' target bitwidth: {bitwidth} bits')
    print(f' scale: {scale}')
    print(f' zero point: {zero_point}')
    assert _quantized_test_tensor.equal(quantized_test_tensor)
    print('* Test passed.')

    # the title follows the requested bitwidth instead of a hard-coded "2-bit"
    plot_matrix(_quantized_test_tensor, axes[1], f'{bitwidth}-bit linear quantized tensor',
                vmin=quantized_min, vmax=quantized_max, cmap='tab20c')
    plot_matrix(_reconstructed_test_tensor, axes[2], 'reconstructed tensor',
                vmin=real_min, vmax=real_max, cmap='tab20c')
    fig.tight_layout()
    plt.show()
def test_quantized_fc(
        input=torch.tensor([
            [0.6118, 0.7288, 0.8511, 0.2849, 0.8427, 0.7435, 0.4014, 0.2794],
            [0.3676, 0.2426, 0.1612, 0.7684, 0.6038, 0.0400, 0.2240, 0.4237],
            [0.6565, 0.6878, 0.4670, 0.3470, 0.2281, 0.8074, 0.0178, 0.3999],
            [0.1863, 0.3567, 0.6104, 0.0497, 0.0577, 0.2990, 0.6687, 0.8626]]),
        weight=torch.tensor([
            [1.2626e-01, -1.4752e-01, 8.1910e-02, 2.4982e-01, -1.0495e-01,
             -1.9227e-01, -1.8550e-01, -1.5700e-01],
            [2.7624e-01, -4.3835e-01, 5.1010e-02, -1.2020e-01, -2.0344e-01,
             1.0202e-01, -2.0799e-01, 2.4112e-01],
            [-3.8216e-01, -2.8047e-01, 8.5238e-02, -4.2504e-01, -2.0952e-01,
             3.2018e-01, -3.3619e-01, 2.0219e-01],
            [8.9233e-02, -1.0124e-01, 1.1467e-01, 2.0091e-01, 1.1438e-01,
             -4.2427e-01, 1.0178e-01, -3.0941e-04],
            [-1.8837e-02, -2.1256e-01, -4.5285e-01, 2.0949e-01, -3.8684e-01,
             -1.7100e-01, -4.5331e-01, -2.0433e-01],
            [-2.0038e-01, -5.3757e-02, 1.8997e-01, -3.6866e-01, 5.5484e-02,
             1.5643e-01, -2.3538e-01, 2.1103e-01],
            [-2.6875e-01, 2.4984e-01, -2.3514e-01, 2.5527e-01, 2.0322e-01,
             3.7675e-01, 6.1563e-02, 1.7201e-01],
            [3.3541e-01, -3.3555e-01, -4.3349e-01, 4.3043e-01, -2.0498e-01,
             -1.8366e-01, -9.1553e-02, -4.1168e-01]]),
        bias=torch.tensor([0.1954, -0.2756, 0.3113, 0.1149, 0.4274, 0.2429, -0.1721, -0.2502]),
        quantized_bias=torch.tensor([3, -2, 3, 1, 3, 2, -2, -2], dtype=torch.int32),
        shifted_quantized_bias=torch.tensor([-1, 0, -3, -1, -3, 0, 2, -4], dtype=torch.int32),
        calc_quantized_output=torch.tensor([
            [0, -1, 0, -1, -1, 0, 1, -2],
            [0, 0, -1, 0, 0, 0, 0, -1],
            [0, 0, 0, -1, 0, 0, 0, -1],
            [0, 0, 0, 0, 0, 1, -1, -2]], dtype=torch.int8),
        bitwidth=2, batch_size=4, in_channels=8, out_channels=8):
    """
    Check the quantized fully-connected pipeline against reference results.

    Quantizes weight, input and bias, shifts the bias by the input zero point,
    runs quantized_linear() and compares every intermediate result with the
    expected tensors passed in; finally plots original, quantized and
    reconstructed tensors side by side.

    :param input: [torch.FloatTensor] floating input feature
    :param weight: [torch.FloatTensor] floating weight
    :param bias: [torch.FloatTensor] floating bias
    :param quantized_bias: [torch.IntTensor] expected quantized bias
    :param shifted_quantized_bias: [torch.IntTensor] expected shifted quantized bias
    :param calc_quantized_output: [torch.CharTensor] expected quantized output
    :param bitwidth: [int] quantization bit width for weight and features
    :param batch_size: [int] only printed in the report
    :param in_channels: [int] only printed in the report
    :param out_channels: [int] only printed in the report
    :raises AssertionError: if any computed tensor differs from its reference
    """

    def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=ListedColormap(['white'])):
        # draw the matrix as an image and annotate every cell with its value
        ax.imshow(tensor.cpu().numpy(), vmin=vmin, vmax=vmax, cmap=cmap)
        ax.set_title(title)
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        for i in range(tensor.shape[0]):
            for j in range(tensor.shape[1]):
                datum = tensor[i, j].item()
                # floats are shown with two decimals, integers verbatim
                label = f'{datum:.2f}' if isinstance(datum, float) else f'{datum}'
                ax.text(j, i, label, ha="center", va="center", color="k")

    output = torch.nn.functional.linear(input, weight, bias)

    quantized_weight, weight_scale, weight_zero_point = \
        linear_quantize_weight_per_channel(weight, bitwidth)
    quantized_input, input_scale, input_zero_point = \
        linear_quantize_feature(input, bitwidth)
    _quantized_bias, bias_scale, bias_zero_point = \
        linear_quantize_bias_per_output_channel(bias, weight_scale, input_scale)
    # compare the computed bias with the reference one; comparing the computed
    # tensor with itself would always pass and verify nothing
    assert _quantized_bias.equal(quantized_bias)
    _shifted_quantized_bias = \
        shift_quantized_linear_bias(quantized_bias, quantized_weight, input_zero_point)
    assert _shifted_quantized_bias.equal(shifted_quantized_bias)
    quantized_output, output_scale, output_zero_point = \
        linear_quantize_feature(output, bitwidth)

    _calc_quantized_output = quantized_linear(
        quantized_input, quantized_weight, shifted_quantized_bias,
        bitwidth, bitwidth,
        input_zero_point, output_zero_point,
        input_scale, weight_scale, output_scale)
    assert _calc_quantized_output.equal(calc_quantized_output)

    # de-quantize everything for the "reconstructed" column of the figure
    reconstructed_weight = weight_scale * (quantized_weight.float() - weight_zero_point)
    reconstructed_input = input_scale * (quantized_input.float() - input_zero_point)
    reconstructed_bias = bias_scale * (quantized_bias.float() - bias_zero_point)
    reconstructed_calc_output = output_scale * (calc_quantized_output.float() - output_zero_point)

    fig, axes = plt.subplots(3, 3, figsize=(15, 12))
    quantized_min, quantized_max = get_quantized_range(bitwidth)
    plot_matrix(weight, axes[0, 0], 'original weight', vmin=-0.5, vmax=0.5)
    plot_matrix(input.t(), axes[1, 0], 'original input', vmin=0, vmax=1)
    plot_matrix(output.t(), axes[2, 0], 'original output', vmin=-1.5, vmax=1.5)
    plot_matrix(quantized_weight, axes[0, 1], f'{bitwidth}-bit linear quantized weight',
                vmin=quantized_min, vmax=quantized_max, cmap='tab20c')
    plot_matrix(quantized_input.t(), axes[1, 1], f'{bitwidth}-bit linear quantized input',
                vmin=quantized_min, vmax=quantized_max, cmap='tab20c')
    plot_matrix(calc_quantized_output.t(), axes[2, 1], 'quantized output from quantized_linear()',
                vmin=quantized_min, vmax=quantized_max, cmap='tab20c')
    plot_matrix(reconstructed_weight, axes[0, 2], 'reconstructed weight',
                vmin=-0.5, vmax=0.5, cmap='tab20c')
    plot_matrix(reconstructed_input.t(), axes[1, 2], 'reconstructed input',
                vmin=0, vmax=1, cmap='tab20c')
    plot_matrix(reconstructed_calc_output.t(), axes[2, 2], 'reconstructed output',
                vmin=-1.5, vmax=1.5, cmap='tab20c')

    print('* Test quantized_fc()')
    print(f' target bitwidth: {bitwidth} bits')
    print(f' batch size: {batch_size}')
    print(f' input channels: {in_channels}')
    print(f' output channels: {out_channels}')
    print('* Test passed.')
    fig.tight_layout()
    plt.show()
def get_quantized_range(bitwidth):
    """
    get the value range of a signed integer with the given bit width
    :param bitwidth: [int] quantization bit width
    :return:
        [int] quantized_min, i.e. -2^(bitwidth-1)
        [int] quantized_max, i.e. 2^(bitwidth-1) - 1
    """
    half_span = 1 << (bitwidth - 1)
    return -half_span, half_span - 1
def linear_quantize(fp_tensor, bitwidth, scale, zero_point, dtype=torch.int8) -> torch.Tensor:
    """
    linear quantization for single fp_tensor
      from
        fp_tensor = (quantized_tensor - zero_point) * scale
      we have,
        quantized_tensor = int(round(fp_tensor / scale)) + zero_point
    :param fp_tensor: [torch.(cuda.)FloatTensor] floating tensor to be quantized
    :param bitwidth: [int] quantization bit width
    :param scale: [float or torch.(cuda.)FloatTensor] scaling factor; a tensor
        must have as many dimensions as fp_tensor
    :param zero_point: [int or torch.(cuda.)Tensor of `dtype`] the desired
        centroid of tensor values; a tensor must have as many dimensions as
        fp_tensor
    :param dtype: [torch.dtype] integer dtype of the returned tensor
    :return:
        [torch.(cuda.)Tensor] quantized tensor of type `dtype`, clamped to the
        signed bitwidth-bit range
    """
    assert (fp_tensor.dtype == torch.float)
    assert (isinstance(scale, float) or
            (scale.dtype == torch.float and scale.dim() == fp_tensor.dim()))
    assert (isinstance(zero_point, int) or
            (zero_point.dtype == dtype and zero_point.dim() == fp_tensor.dim()))

    # Step 1: scale the fp_tensor (div() allocates, so the caller's tensor is untouched)
    scaled_tensor = fp_tensor.div(scale)
    # Step 2: round the floating value to integer value
    rounded_tensor = scaled_tensor.round_()

    # Step 3: shift the rounded_tensor by zero_point.
    # The shift and the clamp are done in int64: casting to the narrow target
    # dtype first would overflow for values outside its range, and the clamp
    # below could then no longer saturate them.
    shifted_tensor = rounded_tensor.to(torch.int64).add_(zero_point)

    # Step 4: clamp the shifted_tensor to lie in bitwidth-bit range
    quantized_min, quantized_max = get_quantized_range(bitwidth)
    quantized_tensor = shifted_tensor.clamp_(quantized_min, quantized_max)

    # only now narrow to the requested integer type
    return quantized_tensor.to(dtype)
def get_quantization_scale_and_zero_point(fp_tensor, bitwidth):
    """
    get quantization scale and zero point for single tensor
      scale      = (fp_max - fp_min) / (quantized_max - quantized_min)
      zero_point = quantized_min - fp_min / scale
    :param fp_tensor: [torch.(cuda.)Tensor] floating tensor to be quantized
    :param bitwidth: [int] quantization bit width
    :return:
        [float] scale
        [int] zero_point
    """
    quantized_min, quantized_max = get_quantized_range(bitwidth)
    fp_max = fp_tensor.max().item()
    fp_min = fp_tensor.min().item()

    if fp_max == fp_min:
        # A constant tensor has an empty value range, which would give
        # scale == 0 and a ZeroDivisionError below. Widen the range so that it
        # also contains 0; the single value then still maps onto an end of the
        # quantized range. An all-zero tensor falls back to the range [0, 1].
        fp_min = min(fp_min, 0.0)
        fp_max = max(fp_max, 0.0)
        if fp_max == fp_min:
            fp_max = 1.0

    scale = (fp_max - fp_min) / (quantized_max - quantized_min)
    zero_point = quantized_min - fp_min / scale

    # clip the zero_point to fall in [quantized_min, quantized_max]
    if zero_point < quantized_min:
        zero_point = quantized_min
    elif zero_point > quantized_max:
        zero_point = quantized_max
    else:  # convert from float to int using round()
        zero_point = round(zero_point)
    return scale, int(zero_point)
def linear_quantize_feature(fp_tensor, bitwidth):
    """
    quantize a feature tensor with a single scale / zero point pair that is
    derived from the tensor itself
    :param fp_tensor: [torch.(cuda.)Tensor] floating feature to be quantized
    :param bitwidth: [int] quantization bit width
    :return:
        [torch.(cuda.)Tensor] quantized tensor
        [float] scale
        [int] zero_point
    """
    # (scale, zero_point) as computed from the value range of fp_tensor
    qparams = get_quantization_scale_and_zero_point(fp_tensor, bitwidth)
    return (linear_quantize(fp_tensor, bitwidth, *qparams), *qparams)
def plot_weight_distribution(model, bitwidth=32):
    """
    plot one histogram per multi-dimensional parameter of the model
    :param model: [nn.Module] model whose parameters with dim() > 1 are plotted
    :param bitwidth: [int] bit width of the stored values; for bitwidth <= 8 one
        bin per integer level is used, otherwise 256 bins
    """
    if bitwidth <= 8:
        # one bin per integer level; 'left' alignment puts each bar on its level
        qmin, qmax = get_quantized_range(bitwidth)
        bins = np.arange(qmin, qmax + 2)
        align = 'left'
    else:
        bins = 256
        align = 'mid'

    weights = [(name, param) for name, param in model.named_parameters()
               if param.dim() > 1]

    # Keep the 3x3 layout for up to nine tensors and add rows beyond that, so
    # that larger models no longer run past the end of the axes array.
    num_cols = 3
    num_rows = max(3, -(-len(weights) // num_cols))
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(10, 2 * num_rows))
    axes = axes.ravel()

    for ax, (name, param) in zip(axes, weights):
        ax.hist(param.detach().view(-1).cpu(), bins=bins, density=True,
                align=align, color='blue', alpha=0.5,
                edgecolor='black' if bitwidth <= 4 else None)
        if bitwidth <= 4:
            # few enough levels to label every one of them
            quantized_min, quantized_max = get_quantized_range(bitwidth)
            ax.set_xticks(np.arange(start=quantized_min, stop=quantized_max + 1))
        ax.set_xlabel(name)
        ax.set_ylabel('density')

    fig.suptitle(f'Histogram of Weights (bitwidth={bitwidth} bits)')
    fig.tight_layout()
    fig.subplots_adjust(top=0.925)
    plt.show()
def get_quantization_scale_for_weight(weight, bitwidth):
    """
    compute the symmetric quantization scale of a weight tensor
    :param weight: [torch.(cuda.)Tensor] floating weight to be quantized
    :param bitwidth: [integer] quantization bit width
    :return:
        [floating scalar] scale
    """
    # weights are treated as symmetric around 0 (the zero point is always 0),
    # so only the largest magnitude matters; it is floored at 5e-7
    largest_magnitude = weight.abs().max().item()
    floor = 5e-7
    fp_max = floor if floor > largest_magnitude else largest_magnitude
    quantized_max = get_quantized_range(bitwidth)[1]
    return fp_max / quantized_max
def linear_quantize_weight_per_channel(tensor, bitwidth):
    """
    quantize a weight tensor with one scale per output channel (dimension 0)
    :param tensor: [torch.(cuda.)Tensor] floating weight to be quantized
    :param bitwidth: [int] quantization bit width
    :return:
        [torch.(cuda.)Tensor] quantized tensor
        [torch.(cuda.)Tensor] scale tensor
        [int] zero point (which is always 0)
    """
    channel_axis = 0

    # one scale per slice along the output-channel axis
    per_channel_scales = [
        get_quantization_scale_for_weight(channel, bitwidth)
        for channel in tensor.unbind(channel_axis)
    ]
    scale = torch.tensor(per_channel_scales, device=tensor.device)

    # reshape to [-1, 1, ..., 1] so the scales broadcast against the weight
    broadcast_shape = [1] * tensor.dim()
    broadcast_shape[channel_axis] = -1
    scale = scale.view(broadcast_shape)

    return linear_quantize(tensor, bitwidth, scale, zero_point=0), scale, 0
@torch.no_grad()
def peek_linear_quantization():
    """
    For 4-bit and 2-bit quantization in turn: overwrite every multi-dimensional
    parameter of the global `model` with its per-channel quantized values,
    plot the weight histograms, then restore the model via `recover_model()`.
    """
    for bitwidth in (4, 2):
        for _, param in model.named_parameters():
            if param.dim() <= 1:
                # one-dimensional parameters are left untouched
                continue
            quantized_param, _, _ = linear_quantize_weight_per_channel(param, bitwidth)
            param.copy_(quantized_param)
        plot_weight_distribution(model, bitwidth)
        recover_model()
def linear_quantize_bias_per_output_channel(bias, weight_scale, input_scale):
    """
    quantize a bias tensor to 32-bit integers with zero point 0, using
      bias_scale = weight_scale * input_scale
      quantized_bias = fp_bias / bias_scale
    :param bias: [torch.FloatTensor] bias weight to be quantized
    :param weight_scale: [float or torch.FloatTensor] weight scale tensor
    :param input_scale: [float] input scale
    :return:
        [torch.IntTensor] quantized bias tensor
        [float or torch.FloatTensor] bias scale
        [int] zero point (which is always 0)
    """
    assert bias.dim() == 1
    assert bias.dtype == torch.float
    assert isinstance(input_scale, float)
    if isinstance(weight_scale, torch.Tensor):
        assert weight_scale.dtype == torch.float
        # flatten the per-channel scales so they line up with the 1-D bias
        weight_scale = weight_scale.view(-1)
        assert bias.numel() == weight_scale.numel()

    bias_scale = weight_scale * input_scale
    bias_q = linear_quantize(bias, 32, bias_scale,
                             zero_point=0, dtype=torch.int32)
    return bias_q, bias_scale, 0
def shift_quantized_linear_bias(quantized_bias, quantized_weight, input_zero_point):
    """
    fold the input zero point into the quantized bias of an nn.Linear layer
      shifted_quantized_bias = quantized_bias - Linear(input_zero_point, quantized_weight)
    :param quantized_bias: [torch.IntTensor] quantized bias (torch.int32)
    :param quantized_weight: [torch.CharTensor] quantized weight (torch.int8)
    :param input_zero_point: [int] input zero point
    :return:
        [torch.IntTensor] shifted quantized bias tensor
    """
    assert quantized_bias.dtype == torch.int32
    assert isinstance(input_zero_point, int)
    # a constant input equal to the zero point contributes
    # (sum of the weights of each output row) * zero_point
    row_sums = torch.sum(quantized_weight, 1).to(torch.int32)
    zero_point_term = row_sums * input_zero_point
    return quantized_bias - zero_point_term
def quantized_linear(input, weight, bias, feature_bitwidth, weight_bitwidth,
                     input_zero_point, output_zero_point,
                     input_scale, weight_scale, output_scale):
    """
    quantized fully-connected layer
    :param input: [torch.CharTensor] quantized input (torch.int8)
    :param weight: [torch.CharTensor] quantized weight (torch.int8)
    :param bias: [torch.IntTensor] shifted quantized bias or None (torch.int32)
    :param feature_bitwidth: [int] quantization bit width of input and output
    :param weight_bitwidth: [int] quantization bit width of weight
    :param input_zero_point: [int] input zero point
    :param output_zero_point: [int] output zero point
    :param input_scale: [float] input feature scale
    :param weight_scale: [torch.FloatTensor] weight per-channel scale
    :param output_scale: [float] output feature scale
    :return:
        [torch.CharTensor] quantized output feature (torch.int8)
    """
    assert (input.dtype == torch.int8)
    assert (weight.dtype == input.dtype)
    assert (bias is None or bias.dtype == torch.int32)
    assert (isinstance(input_zero_point, int))
    assert (isinstance(output_zero_point, int))
    assert (isinstance(input_scale, float))
    assert (isinstance(output_scale, float))
    assert (weight_scale.dtype == torch.float)

    # Step 1: integer-based fully-connected (8-bit multiplication with 32-bit accumulation)
    if 'cpu' in input.device.type:
        # use 32-b MAC for simplicity
        output = torch.nn.functional.linear(input.to(torch.int32), weight.to(torch.int32), bias)
    else:
        # current version pytorch does not yet support integer-based linear() on GPUs;
        # bias is optional, so only convert it when one was given
        float_bias = None if bias is None else bias.float()
        output = torch.nn.functional.linear(input.float(), weight.float(), float_bias)

    # Step 2: scale the output
    #   scales are floating numbers, so the accumulator is converted to float;
    #   the per-channel weight scale is flattened to [1, oc] to broadcast
    #   against the [batch_size, oc] output
    output = output.float() * (input_scale * weight_scale / output_scale).view(1, -1)

    # Step 3: shift output by output_zero_point
    output = output + output_zero_point

    # Make sure all value lies in the bitwidth-bit range
    output = output.round().clamp(*get_quantized_range(feature_bitwidth)).to(torch.int8)
    return output
def shift_quantized_conv2d_bias(quantized_bias, quantized_weight, input_zero_point):
    """
    fold the input zero point into the quantized bias of an nn.Conv2d layer
      shifted_quantized_bias = quantized_bias - Conv(input_zero_point, quantized_weight)
    :param quantized_bias: [torch.IntTensor] quantized bias (torch.int32)
    :param quantized_weight: [torch.CharTensor] quantized weight (torch.int8)
    :param input_zero_point: [int] input zero point
    :return:
        [torch.IntTensor] shifted quantized bias tensor
    """
    assert quantized_bias.dtype == torch.int32
    assert isinstance(input_zero_point, int)
    # sum every kernel over all dimensions except the output-channel one
    kernel_sums = torch.sum(quantized_weight, (1, 2, 3)).to(torch.int32)
    zero_point_term = kernel_sums * input_zero_point
    return quantized_bias - zero_point_term
def quantized_conv2d(input, weight, bias, feature_bitwidth, weight_bitwidth,
                     input_zero_point, output_zero_point,
                     input_scale, weight_scale, output_scale,
                     stride, padding, dilation, groups):
    """
    quantized 2d convolution
    :param input: [torch.CharTensor] quantized input (torch.int8)
    :param weight: [torch.CharTensor] quantized weight (torch.int8)
    :param bias: [torch.IntTensor] shifted quantized bias or None (torch.int32)
    :param feature_bitwidth: [int] quantization bit width of input and output
    :param weight_bitwidth: [int] quantization bit width of weight
    :param input_zero_point: [int] input zero point
    :param output_zero_point: [int] output zero point
    :param input_scale: [float] input feature scale
    :param weight_scale: [torch.FloatTensor] weight per-channel scale
    :param output_scale: [float] output feature scale
    :param stride: stride forwarded to torch.nn.functional.conv2d
    :param padding: 4-element padding forwarded to torch.nn.functional.pad
    :param dilation: dilation forwarded to torch.nn.functional.conv2d
    :param groups: groups forwarded to torch.nn.functional.conv2d
    :return:
        [torch.(cuda.)CharTensor] quantized output feature
    """
    assert len(padding) == 4
    assert input.dtype == torch.int8
    assert weight.dtype == input.dtype
    assert bias is None or bias.dtype == torch.int32
    assert isinstance(input_zero_point, int)
    assert isinstance(output_zero_point, int)
    assert isinstance(input_scale, float)
    assert isinstance(output_scale, float)
    assert weight_scale.dtype == torch.float

    functional = torch.nn.functional

    # Step 1: integer convolution (8-bit multiplication with 32-bit accumulation).
    # Padding is applied up front with the input zero point as fill value, so
    # the convolution itself runs with padding 0.
    input = functional.pad(input, padding, 'constant', input_zero_point)
    if 'cpu' in input.device.type:
        # use 32-b MAC for simplicity
        accumulator = functional.conv2d(
            input.to(torch.int32), weight.to(torch.int32), None,
            stride, 0, dilation, groups)
    else:
        # current version pytorch does not yet support integer-based conv2d() on GPUs
        accumulator = functional.conv2d(
            input.float(), weight.float(), None, stride, 0, dilation, groups)
        accumulator = accumulator.round().to(torch.int32)
    if bias is not None:
        accumulator = accumulator + bias.view(1, -1, 1, 1)

    # Step 2: rescale; the per-channel factor is viewed as [1, oc, 1, 1] so it
    # broadcasts over the [batch_size, oc, height, width] accumulator
    rescale = (input_scale * weight_scale / output_scale).view(1, -1, 1, 1)
    rescaled = accumulator.float() * rescale

    # Step 3: shift by the output zero point
    shifted = rescaled + output_zero_point

    # Make sure all value lies in the bitwidth-bit range
    lower, upper = get_quantized_range(feature_bitwidth)
    return shifted.round().clamp(lower, upper).to(torch.int8)
def fuse_conv_bn(conv, bn):
    """
    fold a batch-norm layer into the preceding bias-free convolution (in place)
    modified from https://mmcv.readthedocs.io/en/latest/_modules/mmcv/cnn/utils/fuse_conv_bn.html
    :param conv: convolution module without bias; its weight is rescaled and
        a bias parameter is attached
    :param bn: batch-norm module providing weight, bias, running_mean,
        running_var and eps
    :return: the modified `conv`
    """
    assert conv.bias is None
    # per-channel multiplier gamma / sqrt(running_var + eps)
    std = torch.sqrt(bn.running_var.data + bn.eps)
    factor = bn.weight.data / std
    conv.weight.data = conv.weight.data * factor.reshape(-1, 1, 1, 1)
    # beta - running_mean * factor becomes the bias of the fused convolution
    fused_bias = bn.bias.data - bn.running_mean.data * factor
    conv.bias = nn.Parameter(fused_bias)
    return conv
def add_range_recoder_hook(model):
    """
    register forward hooks on every nn.Conv2d, nn.Linear and nn.ReLU of the
    model; on each forward pass a hook stores the detached first input and the
    detached output of its module in the global dicts `input_activation` and
    `output_activation`, keyed by the module's name
    :param model: [nn.Module] model to instrument
    :return: list of hook handles (call .remove() on them to detach the hooks)
    """
    def _make_recorder(module_name):
        # bind the module name at creation time for the hook below
        def _record_range(module, inputs, output):
            input_activation[module_name] = inputs[0].detach()
            output_activation[module_name] = output.detach()
        return _record_range

    tracked_types = (nn.Conv2d, nn.Linear, nn.ReLU)
    return [
        module.register_forward_hook(_make_recorder(name))
        for name, module in model.named_modules()
        if isinstance(module, tracked_types)
    ]
class QuantizedConv2d(nn.Module):
    """Module wrapper that stores quantization parameters and runs quantized_conv2d()."""

    def __init__(self, weight, bias,
                 input_zero_point, output_zero_point,
                 input_scale, weight_scale, output_scale,
                 stride, padding, dilation, groups,
                 feature_bitwidth=8, weight_bitwidth=8):
        super().__init__()
        # tensors are kept as buffers because current version Pytorch does not
        # support IntTensor as nn.Parameter
        for buffer_name, buffer_value in (('weight', weight),
                                          ('bias', bias),
                                          ('weight_scale', weight_scale)):
            self.register_buffer(buffer_name, buffer_value)

        # scalar quantization parameters
        self.input_zero_point, self.output_zero_point = input_zero_point, output_zero_point
        self.input_scale, self.output_scale = input_scale, output_scale

        # convolution geometry; padding[1] fills the first two slots and
        # padding[0] the last two of the 4-tuple that quantized_conv2d expects
        self.stride = stride
        self.padding = (padding[1],) * 2 + (padding[0],) * 2
        self.dilation = dilation
        self.groups = groups

        self.feature_bitwidth = feature_bitwidth
        self.weight_bitwidth = weight_bitwidth

    def forward(self, x):
        return quantized_conv2d(
            x, self.weight, self.bias,
            feature_bitwidth=self.feature_bitwidth,
            weight_bitwidth=self.weight_bitwidth,
            input_zero_point=self.input_zero_point,
            output_zero_point=self.output_zero_point,
            input_scale=self.input_scale,
            weight_scale=self.weight_scale,
            output_scale=self.output_scale,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
class QuantizedLinear(nn.Module):
    """Module wrapper that stores quantization parameters and runs quantized_linear()."""

    def __init__(self, weight, bias,
                 input_zero_point, output_zero_point,
                 input_scale, weight_scale, output_scale,
                 feature_bitwidth=8, weight_bitwidth=8):
        super().__init__()
        # tensors are kept as buffers because current version Pytorch does not
        # support IntTensor as nn.Parameter
        for buffer_name, buffer_value in (('weight', weight),
                                          ('bias', bias),
                                          ('weight_scale', weight_scale)):
            self.register_buffer(buffer_name, buffer_value)

        # scalar quantization parameters
        self.input_zero_point, self.output_zero_point = input_zero_point, output_zero_point
        self.input_scale, self.output_scale = input_scale, output_scale

        self.feature_bitwidth = feature_bitwidth
        self.weight_bitwidth = weight_bitwidth

    def forward(self, x):
        return quantized_linear(
            x, self.weight, self.bias,
            feature_bitwidth=self.feature_bitwidth,
            weight_bitwidth=self.weight_bitwidth,
            input_zero_point=self.input_zero_point,
            output_zero_point=self.output_zero_point,
            input_scale=self.input_scale,
            weight_scale=self.weight_scale,
            output_scale=self.output_scale,
        )
class QuantizedMaxPool2d(nn.MaxPool2d):
    """nn.MaxPool2d whose forward pass returns torch.int8 features."""

    def forward(self, x):
        # current version PyTorch does not support integer-based MaxPool,
        # so pool in float32 and convert the result back to int8
        pooled = super().forward(x.to(torch.float32))
        return pooled.to(torch.int8)
class QuantizedAvgPool2d(nn.AvgPool2d):
    """nn.AvgPool2d whose forward pass returns torch.int8 features."""

    def forward(self, x):
        # current version PyTorch does not support integer-based AvgPool,
        # so pool in float32 and convert the result back to int8
        pooled = super().forward(x.to(torch.float32))
        return pooled.to(torch.int8)
def extra_preprocess(x):
    """
    convert an fp32 input of range (0, 1) into int8 format of range (-128, 127)
    :param x: [torch.(cuda.)FloatTensor] input tensor
    :return:
        [torch.(cuda.)CharTensor] round(x * 255 - 128), clamped to [-128, 127]
    """
    # Round before the cast: a plain float -> int8 conversion truncates toward
    # zero, which disagrees with the round-to-nearest used by linear_quantize()
    # and turns values such as 2.9999 into 2 instead of 3.
    return (x * 255 - 128).round().clamp(-128, 127).to(torch.int8)
if __name__ == "__main__":
    # Script entry point: load (or snapshot) a VGG model, evaluate it in fp32,
    # fuse conv + batch-norm, calibrate activation ranges on one training
    # batch, build the int8 model and evaluate it.

    # test_linear_quantize()
    model_name = 'vgg'
    model = VGG().cuda()
    checkpoint_path = f"./checkpoints/best_{model_name}.pth.tar"
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        print(f"=> loading checkpoint '{checkpoint_path}'")
        model.load_state_dict(checkpoint)
    else:
        train_flag = True
        # No checkpoint on disk: snapshot the current weights so that
        # recover_model() below still has a state to restore instead of
        # failing with a NameError on `checkpoint`.
        print(f"=> no checkpoint found at '{checkpoint_path}', "
              f"using the freshly initialised weights")
        checkpoint = copy.deepcopy(model.state_dict())

    def recover_model():
        # restore the model weights from the loaded / snapshotted state
        return model.load_state_dict(checkpoint)

    transforms = {
        "train": Compose([
            RandomCrop(32, padding=4),
            RandomHorizontalFlip(),
            ToTensor(),
        ]),
        "test": ToTensor(),
    }
    dataset = {}
    for split in ["train", "test"]:
        dataset[split] = CIFAR10(
            root="data/cifar10",
            train=(split == "train"),
            download=True,
            transform=transforms[split],
        )
    dataloader = {}
    for split in ['train', 'test']:
        dataloader[split] = DataLoader(
            dataset[split],
            batch_size=512,
            shuffle=(split == 'train'),
            num_workers=0,
            pin_memory=True,
        )

    fp32_model_accuracy = evaluate(model, dataloader['test'])
    fp32_model_size = get_model_size(model)
    print(f"fp32 model has accuracy={fp32_model_accuracy:.2f}%")
    print(f"fp32 model has size={fp32_model_size / MiB:.2f} MiB")

    use_amp = True
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # recover_model()
    # plot_weight_distribution(model)
    # peek_linear_quantization()
    # test_quantized_fc()

    print('Before conv-bn fusion: backbone length', len(model.backbone))
    # fuse the batchnorm into conv layers
    recover_model()
    model_fused = copy.deepcopy(model)
    fused_backbone = []
    ptr = 0
    while ptr < len(model_fused.backbone):
        if isinstance(model_fused.backbone[ptr], nn.Conv2d) and \
                isinstance(model_fused.backbone[ptr + 1], nn.BatchNorm2d):
            fused_backbone.append(fuse_conv_bn(
                model_fused.backbone[ptr], model_fused.backbone[ptr + 1]))
            ptr += 2
        else:
            fused_backbone.append(model_fused.backbone[ptr])
            ptr += 1
    model_fused.backbone = nn.Sequential(*fused_backbone)
    print('After conv-bn fusion: backbone length', len(model_fused.backbone))

    # sanity check, no BN anymore
    for m in model_fused.modules():
        assert not isinstance(m, nn.BatchNorm2d)

    # the accuracy will remain the same after fusion
    fused_acc = evaluate(model_fused, dataloader['test'])
    print(f'Accuracy of the fused model={fused_acc:.2f}%')

    # add hook to record the min max value of the activation
    input_activation = {}
    output_activation = {}
    hooks = add_range_recoder_hook(model_fused)
    sample_data = next(iter(dataloader['train']))[0]
    # the hooks only keep detached tensors, so no autograd graph is needed
    with torch.no_grad():
        model_fused(sample_data.cuda())
    # remove hooks
    for h in hooks:
        h.remove()

    # we use int8 quantization, which is quite popular
    feature_bitwidth = weight_bitwidth = 8
    quantized_model = copy.deepcopy(model_fused)
    quantized_backbone = []
    ptr = 0
    while ptr < len(quantized_model.backbone):
        if isinstance(quantized_model.backbone[ptr], nn.Conv2d) and \
                isinstance(quantized_model.backbone[ptr + 1], nn.ReLU):
            # a conv + ReLU pair becomes one QuantizedConv2d: the input range
            # comes from the conv input, the output range from the ReLU output
            conv = quantized_model.backbone[ptr]
            conv_name = f'backbone.{ptr}'
            relu = quantized_model.backbone[ptr + 1]
            relu_name = f'backbone.{ptr + 1}'

            input_scale, input_zero_point = \
                get_quantization_scale_and_zero_point(
                    input_activation[conv_name], feature_bitwidth)
            output_scale, output_zero_point = \
                get_quantization_scale_and_zero_point(
                    output_activation[relu_name], feature_bitwidth)

            quantized_weight, weight_scale, weight_zero_point = \
                linear_quantize_weight_per_channel(conv.weight.data, weight_bitwidth)
            quantized_bias, bias_scale, bias_zero_point = \
                linear_quantize_bias_per_output_channel(
                    conv.bias.data, weight_scale, input_scale)
            shifted_quantized_bias = \
                shift_quantized_conv2d_bias(quantized_bias, quantized_weight,
                                            input_zero_point)

            quantized_conv = QuantizedConv2d(
                quantized_weight, shifted_quantized_bias,
                input_zero_point, output_zero_point,
                input_scale, weight_scale, output_scale,
                conv.stride, conv.padding, conv.dilation, conv.groups,
                feature_bitwidth=feature_bitwidth, weight_bitwidth=weight_bitwidth
            )

            quantized_backbone.append(quantized_conv)
            ptr += 2
        elif isinstance(quantized_model.backbone[ptr], nn.MaxPool2d):
            quantized_backbone.append(QuantizedMaxPool2d(
                kernel_size=quantized_model.backbone[ptr].kernel_size,
                stride=quantized_model.backbone[ptr].stride
            ))
            ptr += 1
        elif isinstance(quantized_model.backbone[ptr], nn.AvgPool2d):
            quantized_backbone.append(QuantizedAvgPool2d(
                kernel_size=quantized_model.backbone[ptr].kernel_size,
                stride=quantized_model.backbone[ptr].stride
            ))
            ptr += 1
        else:
            raise NotImplementedError(type(quantized_model.backbone[ptr]))  # should not happen
    quantized_model.backbone = nn.Sequential(*quantized_backbone)

    # finally, quantized the classifier
    fc_name = 'classifier'
    # take the layer from the model the activations were recorded on
    fc = model_fused.classifier
    input_scale, input_zero_point = \
        get_quantization_scale_and_zero_point(
            input_activation[fc_name], feature_bitwidth)
    output_scale, output_zero_point = \
        get_quantization_scale_and_zero_point(
            output_activation[fc_name], feature_bitwidth)

    quantized_weight, weight_scale, weight_zero_point = \
        linear_quantize_weight_per_channel(fc.weight.data, weight_bitwidth)
    quantized_bias, bias_scale, bias_zero_point = \
        linear_quantize_bias_per_output_channel(
            fc.bias.data, weight_scale, input_scale)
    shifted_quantized_bias = \
        shift_quantized_linear_bias(quantized_bias, quantized_weight,
                                    input_zero_point)

    quantized_model.classifier = QuantizedLinear(
        quantized_weight, shifted_quantized_bias,
        input_zero_point, output_zero_point,
        input_scale, weight_scale, output_scale,
        feature_bitwidth=feature_bitwidth, weight_bitwidth=weight_bitwidth
    )

    print(quantized_model)

    int8_model_accuracy = evaluate(quantized_model, dataloader['test'],
                                   extra_preprocess=[extra_preprocess])
    print(f"int8 model has accuracy={int8_model_accuracy:.2f}%")