-
Notifications
You must be signed in to change notification settings - Fork 526
/
Copy pathpatterns.py
342 lines (273 loc) · 10.1 KB
/
patterns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2025 NXP
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Type, Union
import torch
from executorch.backends.nxp.quantizer.utils import get_bias_qparams
from torch import fx
from torch._ops import OpOverload
from torchao.quantization.pt2e.quantizer import (
DerivedQuantizationSpec,
FixedQParamsQuantizationSpec,
SharedQuantizationSpec,
)
# Element types of the PartitionAnchors lists, spelled out once so the
# dataclass fields below stay readable.
_ArgIndex = Union[int, Tuple[int, int]]
_ArgAnchor = Tuple[fx.Node, int]
_InputAnchor = Union[
    Tuple[fx.Node, _ArgIndex],
    Tuple[fx.Node, _ArgIndex, SharedQuantizationSpec],
]
_BiasAnchor = Union[_ArgAnchor, Tuple[fx.Node, int, DerivedQuantizationSpec]]
_OutputAnchor = Union[Tuple[fx.Node], Tuple[fx.Node, SharedQuantizationSpec]]


@dataclass
class PartitionAnchors:
    """
    Anchor points that a quantization pattern reports for one matched partition.

    Every field except ``output`` is a list of ``(node, args_index)`` pairs, where
    ``node`` belongs to the partition and ``node.args[args_index]`` is an input of
    the partition. The partition is assumed to have a single output.

    The quantizer uses ``inputs``, ``weights`` and ``biases`` for quantization
    annotation. ``others`` holds tensor inputs that are not quantized, and
    ``literals`` is used for other kinds of input values as well as for handling
    default parameters.
    """

    # Activation inputs; an optional third element lets the input share
    # quantization parameters via a SharedQuantizationSpec.
    inputs: List[_InputAnchor] = field(default_factory=list)
    # Weight arguments.
    weights: List[_ArgAnchor] = field(default_factory=list)
    # Bias arguments; an optional third element carries a derived spec.
    biases: List[_BiasAnchor] = field(default_factory=list)
    # Tensor inputs that are left unquantized.
    others: List[_ArgAnchor] = field(default_factory=list)
    # Non-tensor / default-parameter inputs.
    literals: List[_ArgAnchor] = field(default_factory=list)
    # The output node, optionally paired with a spec for its output.
    output: List[_OutputAnchor] = field(default_factory=list)
    # NOTE(review): presumably marks a partition with nothing to annotate;
    # it is not read anywhere in this file, so confirm against the quantizer.
    empty: bool = False
class QuantizationPattern(ABC):
    """
    Interface implemented by every quantization pattern.

    A pattern declares which operators it matches and, for each matched
    partition, which node arguments the quantizer should annotate.
    """

    @abstractmethod
    def partition_types(self) -> list[OpOverload]:
        """
        List of types to be passed to find_sequential_partitions_aten.
        """

    @abstractmethod
    def get_anchors(
        self, gm: torch.fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> Optional[PartitionAnchors]:
        """
        Describe which arguments of ``fused_partition`` (a match inside ``gm``)
        should be annotated, or return ``None`` if the pattern declines to
        annotate it.
        """
class SharedSpecPattern(QuantizationPattern):
    """
    Quantization pattern for shared quantization.

    The quantization is derived from the previous node's quantization: the input
    and the output share the same quantization parameters (scale and zero-point).
    Concrete subclasses only override ``partition_types``.
    """

    def partition_types(self) -> List[OpOverload]:
        # Subclasses return the aten operator(s) they match; this base
        # implementation matches nothing and (as before) returns None.
        pass

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors | None:
        """
        Anchor the last node of the partition so that its output reuses the
        quantization parameters of the partition's single input node.

        Returns ``None`` when that input node carries no quantization annotation,
        because there are then no parameters to share.
        """
        partition = fused_partition[0]
        node = partition.nodes[-1]
        assert len(partition.input_nodes) == 1
        prev_node = partition.input_nodes[0]

        # Previous node was not quantized => we are not able to share q-params.
        if "quantization_annotation" not in prev_node.meta:
            return None

        qspec = SharedQuantizationSpec(prev_node)
        return PartitionAnchors(
            inputs=[(node, 0)],
            weights=[],
            biases=[],
            output=[
                (node, qspec),
            ],
        )
class AddmmPattern(QuantizationPattern):
    """
    Quantization pattern for ``aten.addmm``.

    Argument 1 is annotated as the input, argument 2 as the weight and
    argument 0 as an int32 bias whose quantization parameters are derived
    from the input and the weight.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.addmm.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]
        bias_idx, input_idx, weight_idx = 0, 1, 2

        int32_range = torch.iinfo(torch.int32)
        derived_bias = DerivedQuantizationSpec(
            derived_from=[
                (node.args[input_idx], node),
                (node.args[weight_idx], node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=int32_range.min,
            quant_max=int32_range.max,
            qscheme=torch.per_tensor_affine,
        )

        anchors = PartitionAnchors(
            inputs=[(node, input_idx)],
            weights=[(node, weight_idx)],
            biases=[(node, bias_idx, derived_bias)],
            output=[(node,)],
        )
        return anchors
class AvgPoolPattern(SharedSpecPattern):
    """
    Shared-spec quantization pattern for 2D average pooling
    (``aten.avg_pool2d``).
    """

    def partition_types(self) -> List[OpOverload]:
        avg_pool2d = torch.ops.aten.avg_pool2d.default
        return [avg_pool2d]
class Conv1dPattern(QuantizationPattern):
    """
    Quantization pattern for ``aten.conv1d``.

    Argument 0 is annotated as the input and argument 1 as the weight. Argument 2,
    when present and not ``None``, is annotated as an int32 bias whose
    quantization parameters are derived from the input and the weight.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv1d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        conv = fused_partition[0].nodes[-1]
        activation, weight = conv.args[0], conv.args[1]

        int32_range = torch.iinfo(torch.int32)
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[(activation, conv), (weight, conv)],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=int32_range.min,
            quant_max=int32_range.max,
            qscheme=torch.per_tensor_affine,
        )

        # An omitted or explicit-None bias gets no anchor.
        has_bias = len(conv.args) > 2 and conv.args[2] is not None
        biases = [(conv, 2, bias_qspec)] if has_bias else []

        return PartitionAnchors(
            inputs=[(conv, 0)],
            weights=[(conv, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=biases,
            output=[(conv,)],
        )
class Conv2dPattern(QuantizationPattern):
    """
    Quantization pattern for ``aten.conv2d``.

    Argument 0 is the input and argument 1 the weight. Argument 2 is annotated as
    an int32 bias (parameters derived from input and weight) only when it is
    supplied and not ``None``.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.conv2d.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        node = fused_partition[0].nodes[-1]
        args = node.args

        int32_range = torch.iinfo(torch.int32)
        bias_spec = DerivedQuantizationSpec(
            derived_from=[(args[0], node), (args[1], node)],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=int32_range.min,
            quant_max=int32_range.max,
            qscheme=torch.per_tensor_affine,
        )

        bias_anchors = []
        bias_missing = len(args) <= 2 or args[2] is None
        if not bias_missing:
            bias_anchors.append((node, 2, bias_spec))

        return PartitionAnchors(
            inputs=[(node, 0)],
            weights=[(node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias_anchors,
            output=[(node,)],
        )
class LinearPattern(QuantizationPattern):
    """
    Quantization pattern for ``aten.linear``.

    Argument 0 is annotated as the input and argument 1 as the weight. Argument 2,
    when supplied and not ``None``, is annotated as an int32 bias whose
    quantization parameters are derived from the input and the weight.
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.linear.default]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C.TensorBase.__ge...
        linear_node = fused_partition[0].nodes[-1]
        bias_qspec = DerivedQuantizationSpec(
            derived_from=[
                (linear_node.args[0], linear_node),
                (linear_node.args[1], linear_node),
            ],
            derive_qparams_fn=get_bias_qparams,
            dtype=torch.int32,
            quant_min=-(2**31),
            quant_max=2**31 - 1,
            qscheme=torch.per_tensor_affine,
        )

        # Keep bias empty if not supplied. An explicit ``None`` bias must be
        # skipped as well (matching the conv patterns); otherwise a non-tensor
        # argument would be anchored for int32 bias quantization.
        bias = []
        if len(linear_node.args) > 2 and linear_node.args[2] is not None:
            bias = [(linear_node, 2, bias_qspec)]

        return PartitionAnchors(
            inputs=[(linear_node, 0)],
            weights=[(linear_node, 1)],
            # pyre-fixme[6]: Incompatible parameter type
            biases=bias,
            output=[(linear_node,)],
        )
class MaxPoolPattern(SharedSpecPattern):
    """
    Shared-spec quantization pattern for 2D max pooling (``aten.max_pool2d``).
    """

    def partition_types(self) -> List[OpOverload]:
        max_pool2d = torch.ops.aten.max_pool2d.default
        return [max_pool2d]
class PadPattern(SharedSpecPattern):
    """
    Shared-spec quantization pattern for the pad operator (``aten.pad``).
    """

    def partition_types(self) -> List[OpOverload]:
        pad = torch.ops.aten.pad.default
        return [pad]
class PermutePattern(SharedSpecPattern):
    """
    Shared-spec quantization pattern for the permute operator (``aten.permute``).
    """

    def partition_types(self) -> List[OpOverload]:
        permute = torch.ops.aten.permute.default
        return [permute]
class ReluPattern(SharedSpecPattern):
    """
    Shared-spec quantization pattern for ReLU (``aten.relu``).

    ReLU usually follows a computation layer, so it reuses the quantization
    parameters of its producer instead of getting its own.
    """

    def partition_types(self) -> List[OpOverload]:
        relu = torch.ops.aten.relu.default
        return [relu]
class ReluInPlacePattern(SharedSpecPattern):
    """
    Shared-spec quantization pattern for in-place ReLU (``aten.relu_``, i.e.
    ReLU with ``inplace=True``).

    ReLU usually follows a computation layer, so it reuses the quantization
    parameters of its producer instead of getting its own.
    """

    def partition_types(self) -> List[OpOverload]:
        relu_inplace = torch.ops.aten.relu_.default
        return [relu_inplace]
class ReshapePattern(SharedSpecPattern):
    """
    Shared-spec quantization pattern for the reshape operator (``aten.reshape``).
    """

    def partition_types(self) -> List[OpOverload]:
        reshape = torch.ops.aten.reshape.default
        return [reshape]
class SoftMaxPattern(QuantizationPattern):
    """
    Quantization pattern for the Softmax operator (``aten.softmax.int``).

    The output quantization is fixed: int8 with scale 1/256 and zero point -128,
    so the quantized range [-128, 127] covers real values [0, 255/256].
    """

    def partition_types(self) -> List[OpOverload]:
        return [torch.ops.aten.softmax.int]

    def get_anchors(
        self, gm: fx.GraphModule, fused_partition: List[fx.GraphModule]
    ) -> PartitionAnchors:
        partition = fused_partition[0]
        softmax = partition.nodes[-1]
        assert len(partition.input_nodes) == 1

        output_qspec = FixedQParamsQuantizationSpec(
            dtype=torch.int8,
            scale=1.0 / 256.0,
            zero_point=-128,
            quant_min=-128,
            quant_max=127,
            qscheme=torch.per_tensor_affine,
        )

        anchors = PartitionAnchors(
            inputs=[(softmax, 0)],
            weights=[],
            biases=[],
            output=[(softmax, output_qspec)],
        )
        return anchors