-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlinear_quantization.py
More file actions
103 lines (72 loc) · 3.2 KB
/
linear_quantization.py
File metadata and controls
103 lines (72 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
### Quantizing and Dequantizing a Tensor with Torch ###
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import torch
## Linear Quantization uses a Scale and Zero Point ##
# r = s(q-z)
def linear_q_with_scale_and_zero_point(tensor, scale, zero_point, dtype = torch.int8):
    """
    Linearly quantize ``tensor`` into the integer type ``dtype``.

    Inverts r = s(q - z): every value is divided by ``scale``, shifted by
    ``zero_point`` and rounded with ``torch.round``. The result is clamped
    to the representable range of ``dtype`` and then cast to it.
    """
    rounded = torch.round(tensor / scale + zero_point)
    limits = torch.iinfo(dtype)
    # Clamp before the cast so out-of-range values saturate at the dtype bounds.
    return rounded.clamp(limits.min, limits.max).to(dtype)
### a dummy tensor to test the implementation
test_tensor = torch.tensor(
    [[191.6, -13.5, 728.6],
     [92.14, 295.5, -184],
     [0, 684.6, 245.5]]
)

### these are arbitrary values for "scale" and "zero_point"
### to test the implementation
scale = 3.5
zero_point = -70

quantized_tensor = linear_q_with_scale_and_zero_point(test_tensor, scale, zero_point)
print(quantized_tensor)

## Dequantization is the Inverse Function ##
# r = s(q - z); cast to float first so the subtraction is not done in int8
dequantized_tensor = scale * (quantized_tensor.float() - zero_point)

# this was the original tensor
# [[191.6, -13.5, 728.6],
#  [92.14, 295.5, -184],
#  [0, 684.6, 245.5]]
print(dequantized_tensor)

### without casting to float
# The subtraction stays in int8 here, so entries such as 127 - (-70)
# overflow and the dequantized values come out wrong. The result is printed
# because a bare expression is silently discarded when this file runs as a
# script, which would hide the effect this step is meant to show.
print(scale * (quantized_tensor - zero_point))
def linear_dequantization(quantized_tensor, scale, zero_point):
    """Map quantized values back to floating point via r = s(q - z)."""
    # Cast first so the zero-point subtraction happens in floating point
    # rather than in the tensor's own dtype.
    shifted = quantized_tensor.float() - zero_point
    return scale * shifted
# Recompute the dequantized tensor through the helper function.
dequantized_tensor = linear_dequantization(
    quantized_tensor, scale=scale, zero_point=zero_point
)
print(dequantized_tensor)
## Quantization Error Calculation ##
def plot_matrix(tensor, ax, title, vmin=0, vmax=1, cmap=None):
    """
    Draw ``tensor`` as an annotated seaborn heatmap on the axes ``ax``.

    Each cell is labelled with its value to two decimals. The colorbar is
    disabled and both axes' tick labels are cleared, leaving only the title.
    """
    # Seaborn is handed a NumPy array taken from a CPU copy of the tensor.
    values = tensor.cpu().numpy()
    sns.heatmap(
        values,
        ax=ax,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        annot=True,
        fmt=".2f",
        cbar=False,
    )
    ax.set_title(title)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
def plot_quantization_errors(original_tensor, quantized_tensor, dequantized_tensor, dtype = torch.int8, n_bits = 8):
    """
    Show the original, quantized, de-quantized and absolute-error tensors
    side by side as four heatmaps in one figure.

    ``dtype`` supplies the color range of the quantized panel; ``n_bits`` is
    only used in that panel's title.
    """
    # One row of four panels.
    fig, panels = plt.subplots(1, 4, figsize=(15, 4))

    # Panel 1: the original values on a plain white background.
    plot_matrix(original_tensor, panels[0], 'Original Tensor', cmap=ListedColormap(['white']))

    # Panel 2: quantized values, colored over the full integer range of dtype.
    int_range = torch.iinfo(dtype)
    plot_matrix(
        quantized_tensor,
        panels[1],
        f'{n_bits}-bit Linear Quantized Tensor',
        vmin=int_range.min,
        vmax=int_range.max,
        cmap='coolwarm',
    )

    # Panel 3: values recovered by de-quantization.
    plot_matrix(dequantized_tensor, panels[2], 'Dequantized Tensor', cmap='coolwarm')

    # Panel 4: element-wise absolute difference between original and recovered values.
    abs_error = abs(original_tensor - dequantized_tensor)
    plot_matrix(abs_error, panels[3], 'Quantization Error Tensor', cmap=ListedColormap(['white']))

    fig.tight_layout()
    plt.show()
plot_quantization_errors(test_tensor, quantized_tensor, dequantized_tensor)

# Signed per-element error, computed once and reused for the three reports.
_error_tensor = dequantized_tensor - test_tensor
print("Error Tensor:", _error_tensor)
print("Squared Error Tensor:", _error_tensor.square())
print("MSE:", _error_tensor.square().mean())