# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import copy
from enum import IntEnum, unique

import torch

from torchao.quantization.pt2e.fake_quantize import FakeQuantize
from torchao.quantization.pt2e.observer import MinMaxObserver, PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec


@unique
class Precision(IntEnum):
    """Supported activation/weight bit-width combinations (AxWy)."""

    A16W16 = 0
    A16W8 = 1
    A16W4 = 2
    A8W8 = 3
    A8W4 = 4


class QuantizationConfig:
    """Pairs an activation QuantizationSpec with a weight QuantizationSpec."""

    def __init__(
        self, activation_spec: QuantizationSpec, weight_spec: QuantizationSpec
    ):
        self._activation_spec = activation_spec
        self._weight_spec = weight_spec

    @property
    def activation(self):
        # Return a deep copy so callers cannot mutate the stored spec.
        return copy.deepcopy(self._activation_spec)

    @property
    def weight(self):
        return copy.deepcopy(self._weight_spec)


def get_quant_config(
    precision: Precision,
    is_per_channel: bool = False,
    is_qat: bool = False,
) -> QuantizationConfig:
    """Return the QuantizationConfig for the requested precision."""

    precision_mappings = {
        Precision.A16W16: get_a16w16_quant_config,
        Precision.A16W8: get_a16w8_quant_config,
        Precision.A16W4: get_a16w4_quant_config,
        Precision.A8W8: get_a8w8_quant_config,
        Precision.A8W4: get_a8w4_quant_config,
    }
    if precision not in precision_mappings:
        raise RuntimeError("Unrecognized precision setting.")

    qconfig_fn = precision_mappings[precision]
    return qconfig_fn(is_per_channel, is_qat)


def _get_activation_qspec(
    dtype,
    is_symmetric,
    is_qat,
    observer_cls=MinMaxObserver,
    quant_min=None,
    quant_max=None,
):
    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_tensor_symmetric if is_symmetric else torch.per_tensor_affine

    # QAT wraps the observer in a FakeQuantize module; PTQ uses the observer directly.
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(observer=observer_cls, eps=1e-6)
    else:
        observer_or_fake_quant = observer_cls.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )


def _get_weight_qspec(
    dtype, is_symmetric, is_per_channel, is_qat, quant_min=None, quant_max=None
):
    # Per-tensor weights reuse the activation qspec helper.
    if not is_per_channel:
        return _get_activation_qspec(
            dtype, is_symmetric, is_qat, observer_cls=MinMaxObserver
        )

    if quant_max is None:
        quant_max = torch.iinfo(dtype).max
    if quant_min is None:
        # quant_min = torch.iinfo(dtype).min + 1 if is_symmetric else torch.iinfo(dtype).min
        quant_min = torch.iinfo(dtype).min

    qscheme = torch.per_channel_symmetric if is_symmetric else torch.per_channel_affine
    if is_qat:
        observer_or_fake_quant = FakeQuantize.with_args(
            observer=PerChannelMinMaxObserver, eps=1e-6
        )
    else:
        observer_or_fake_quant = PerChannelMinMaxObserver.with_args(eps=1e-6)

    return QuantizationSpec(
        dtype=dtype,
        quant_min=quant_min,
        quant_max=quant_max,
        qscheme=qscheme,
        ch_axis=0,
        observer_or_fake_quant_ctr=observer_or_fake_quant,
    )


def get_a16w16_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int16, True, is_per_channel, is_qat)
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a16w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, True, is_per_channel, is_qat)
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a16w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int16, True, is_qat)
    # 4-bit weights are carried in an int8 container, clamped to [-8, 7].
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a8w8_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    wgt_quantization_spec = _get_weight_qspec(torch.int8, False, is_per_channel, is_qat)
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config


def get_a8w4_quant_config(is_per_channel, is_qat) -> QuantizationConfig:
    act_quantization_spec = _get_activation_qspec(torch.int8, False, is_qat)
    wgt_quantization_spec = _get_weight_qspec(
        torch.int8, False, is_per_channel, is_qat, quant_min=-8, quant_max=7
    )
    quantization_config = QuantizationConfig(
        act_quantization_spec, wgt_quantization_spec
    )
    return quantization_config
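

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows how a caller might
# request a config from get_quant_config and inspect the resulting specs.
# The precision and flag values below are arbitrary example choices, not
# defaults mandated by this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 8-bit activations with 4-bit-range weights, per-channel PTQ observers.
    qconfig = get_quant_config(Precision.A8W4, is_per_channel=True, is_qat=False)

    # The properties return deep copies, so inspecting or tweaking them does
    # not mutate the shared QuantizationConfig instance.
    act_spec = qconfig.activation
    wgt_spec = qconfig.weight
    print("activation:", act_spec.dtype, act_spec.quant_min, act_spec.quant_max)
    print("weight:    ", wgt_spec.dtype, wgt_spec.quant_min, wgt_spec.quant_max)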