Commit c4bac31

[TorchFX] SharedQuantizationSpec support
1 parent: 6d0aacd

30 files changed: +16589 -8746 lines

nncf/common/quantization/quantizer_setup.py (+2 -1)
@@ -247,12 +247,13 @@ def __init__(self) -> None:
         self._next_unified_scale_gid = 0
         self._next_shared_inputs_gid = 0
 
-    def add_independent_quantization_point(self, qp: QuantizationPointBase) -> None:
+    def add_independent_quantization_point(self, qp: QuantizationPointBase) -> int:
         if self.quantization_points.keys():
             new_id = max(self.quantization_points.keys()) + 1
         else:
             new_id = 0
         self.quantization_points[new_id] = qp
+        return new_id
 
     def register_unified_scale_group(self, qp_group: list[QuantizationPointId]) -> int:
         for qp_id in qp_group:
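
The new `int` return value is what the TorchFX adapter change below relies on: callers can now collect the ids of freshly added points and unify their scales. A minimal sketch of that pattern, assuming `qps` is a list of already built quantization points; it mirrors how the adapter consumes the returned ids:

setup = SingleConfigQuantizerSetup()
# Collect the id of each point as it is registered ...
qp_ids = [setup.add_independent_quantization_point(qp) for qp in qps]
# ... and let all of them share one scale.
if len(qp_ids) > 1:
    setup.register_unified_scale_group(qp_ids)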

nncf/experimental/torch/fx/quantization/quantizer/torch_ao_adapter.py (+94 -61)
@@ -11,18 +11,18 @@
 
 
 from collections import defaultdict
-from typing import Union
+from typing import Any, Dict, List, Tuple, Union
 
 import torch
 import torch.fx
+from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
+from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
 from torch.ao.quantization.quantizer import Quantizer as TorchAOQuantizer
 from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase
 from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec
 
 import nncf
 from nncf.common.graph.graph import NNCFGraph
-from nncf.common.logging import nncf_logger
 from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint
 from nncf.common.quantization.quantizer_setup import QuantizationPointBase
 from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint
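
The rewrite below leans on the two newly imported private torch.ao helpers: `_get_edge_or_node_to_qspec` collects the annotations, and `_get_edge_or_node_to_group_id` assigns a common group id to every edge or node tied together via `SharedQuantizationSpec`. A rough sketch of the expected grouping; the node names and specs are hypothetical, not from this commit:

# Suppose relu's input edge shares the annotation of conv's edge:
# edge_or_node_to_qspec = {
#     ("conv", "add"): int8_spec,
#     ("relu", "add"): SharedQuantizationSpec(("conv", "add")),
# }
# _get_edge_or_node_to_group_id should then map both edges to one id,
# e.g. {("conv", "add"): 0, ("relu", "add"): 0}, which the adapter
# turns into a single unified scale group further down.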
@@ -72,7 +72,16 @@ def _get_quantization_points(
         to_nodes: list[torch.fx.Node],
         annotated_model: torch.fx.GraphModule,
         qconfig: QuantizerConfig,
-    ) -> list[QuantizationPointBase]:
+    ) -> List[QuantizationPointBase]:
+        """
+        Creates quantization points based on the nodes and edges.
+
+        :param from_node: The originating node in the computation graph.
+        :param to_nodes: The list of destination nodes of the from_node.
+        :param annotated_model: The torch.fx.GraphModule instance.
+        :param qconfig: The torch.ao quantization configuration.
+        :return: A list of NNCF quantization points.
+        """
         to_n = to_nodes[0]
         if from_node.op == "get_attr":
             _, metatype = GraphConverter.get_node_type_and_metatype(to_n, annotated_model)
@@ -95,78 +104,102 @@ def _get_quantization_points(
         return qps
 
     @staticmethod
-    def _get_node_args(node: torch.fx.Node):
+    def _get_node_args(node: torch.fx.Node) -> Tuple[Any, ...]:
+        """
+        Correctly retrieves arguments of the given node.
+
+        :param node: The given node.
+        :return: The arguments of the given node.
+        """
         if node.target == torch.ops.aten.cat.default:
             return node.args[0]
         return node.args
 
     @staticmethod
-    def get_quantizer_config_from_annotated_model(annotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
-        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated_model)
-
-        q_map = defaultdict(list)
-        for edge, qspec in edge_or_node_to_qspec.items():
-            if not isinstance(edge, tuple):
-                continue
-            from_n, to_n = edge
-            q_map[from_n].append(to_n)
+    def get_quantizer_config_from_annotated_model(annotated: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
+        edge_or_node_to_qspec = _get_edge_or_node_to_qspec(annotated)
+        # Node means all output edges should be quantized.
+        # Edge means only one edge should be quantized.
+        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
+
+        group_id_vs_edges = defaultdict(set)
+        group_id_vs_qspec = {}
+        for edge_or_node, group_id in edge_or_node_to_group_id.items():
+            target_edges = [edge_or_node]
+            if isinstance(edge_or_node, torch.fx.Node):
+                target_edges = []
+                for user in edge_or_node.users:
+                    target_edges.append((edge_or_node, user))
+            group_id_vs_edges[group_id].update(target_edges)
+            # All qspecs should be aligned after the _get_edge_or_node_to_group_id call
+            group_id_vs_qspec[group_id] = _unwrap_shared_qspec_safe(
+                edge_or_node_to_qspec[edge_or_node], edge_or_node_to_qspec
+            )
 
         q_setup = SingleConfigQuantizerSetup()
-        for from_n, to_nodes in q_map.items():
-            to_n = to_nodes[0]
-            qspec = edge_or_node_to_qspec[(from_n, to_n)]
+        for group_id, edges in group_id_vs_edges.items():
+            qspec = group_id_vs_qspec[group_id]
             if qspec is None:
                 continue
-            if isinstance(qspec, QuantizationSpec):
-                if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
-                    per_channel = True
-                elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
-                    per_channel = False
-                else:
-                    msg = f"Unknown qscheme: {qspec.qscheme}"
-                    raise nncf.InternalError(msg)
-                signed = qspec.dtype is torch.int8
-                mode = (
-                    QuantizationMode.SYMMETRIC
-                    if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
-                    else QuantizationMode.ASYMMETRIC
-                )
-                qconfig = QuantizerConfig(mode=mode, signedness_to_force=signed, per_channel=per_channel)
-
-                qps = TorchAOQuantizerAdapter._get_quantization_points(from_n, to_nodes, annotated_model, qconfig)
-                for qp in qps:
-                    q_setup.add_independent_quantization_point(qp)
-
-            elif isinstance(qspec, SharedQuantizationSpec):
-                # TODO(dlyakhov): Support SharedQuantizationSpec
-                nncf_logger.warning(
-                    f"SharedQuantizationSpec is not supported yet; edges {from_n} -> {to_nodes} won't be quantized."
-                )
-            else:
+            if not isinstance(qspec, QuantizationSpec):
                 msg = f"Unknown torch.ao quantization spec: {qspec}"
                 raise nncf.InternalError(msg)
 
+            if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]:
+                per_channel = True
+            elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
+                per_channel = False
+            else:
+                msg = f"Unknown qscheme: {qspec.qscheme}"
+                raise nncf.InternalError(msg)
+
+            signed = qspec.dtype is torch.int8
+            mode = (
+                QuantizationMode.SYMMETRIC
+                if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric]
+                else QuantizationMode.ASYMMETRIC
+            )
+            narrow_range = qspec.quant_min % 2 != 0
+            qconfig = QuantizerConfig(
+                mode=mode, signedness_to_force=signed, per_channel=per_channel, narrow_range=narrow_range
+            )
+
+            joined_edges = defaultdict(list)
+            for edge in edges:
+                joined_edges[edge[0]].append(edge[1])
+
+            qps = []
+            for from_node, to_nodes in joined_edges.items():
+                qps.extend(TorchAOQuantizerAdapter._get_quantization_points(from_node, to_nodes, annotated, qconfig))
+            qp_ids = []
+            for qp in qps:
+                qp_ids.append(q_setup.add_independent_quantization_point(qp))
+            if len(qp_ids) > 1:
+                q_setup.register_unified_scale_group(qp_ids)
+
         return q_setup
 
 
-def _get_edge_or_node_to_qspec(
-    model: torch.fx.GraphModule,
-) -> dict[EdgeOrNode, QuantizationSpecBase]:
+def _unwrap_shared_qspec_safe(qspec: QuantizationSpec, edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpec]):
     """
-    Get a map from EdgeOrNode to quantization spec based on annotations on the nodes.
+    Iteratively unwraps a given SharedQuantizationSpec to retrieve its actual QuantizationSpec.
+    It detects cyclic dependencies and enforces a maximum depth limit to prevent infinite recursion.
 
-    :param model: torch.fx.GraphModule instance.
-    :return: A map from EdgeOrNode to quantization spec based on annotations on the nodes.
+    :param qspec: The quantization specification to unwrap.
+    :param edge_or_node_to_qspec: A dictionary mapping EdgeOrNode instances to their respective QuantizationSpec.
+    :return: The resolved QuantizationSpec.
     """
-    edge_or_node_to_qspec: dict[EdgeOrNode, QuantizationSpecBase] = {}
-    for n in model.graph.nodes:
-        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
-            qa = n.meta["quantization_annotation"]
-            for input_to_n, qspec in qa.input_qspec_map.items():
-                input_edge = (input_to_n, n)
-                edge_or_node_to_qspec[input_edge] = qspec
-            if qa.output_qspec is not None:
-                output_node = n
-                qspec = qa.output_qspec
-                edge_or_node_to_qspec[output_node] = qspec
-    return edge_or_node_to_qspec
+    MAX_DEPTH = 1000
+    i = 0
+    visited = []
+    while i < MAX_DEPTH and isinstance(qspec, SharedQuantizationSpec):
+        if qspec.edge_or_node in visited:
+            msg = f"A cycled dependency of the quantization spec is detected {visited + [qspec.edge_or_node]}"
+            raise RuntimeError(msg)
+        visited.append(qspec.edge_or_node)
+        qspec = edge_or_node_to_qspec[qspec.edge_or_node]
+        i += 1
+    if i == MAX_DEPTH:
+        msg = f"Shared qspecs referenced to each other more than the limit: {MAX_DEPTH}"
+        raise RuntimeError(msg)
+    return qspec
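
A toy usage sketch of the unwrapping: the string keys below stand in for `EdgeOrNode` instances and are purely illustrative, and `QuantizationSpec` / `SharedQuantizationSpec` are the imports from the module header above:

from torch.ao.quantization.observer import MinMaxObserver

# "a" carries a concrete spec; "b" shares with "a"; "c" shares with "b".
concrete = QuantizationSpec(
    dtype=torch.int8, observer_or_fake_quant_ctr=MinMaxObserver, qscheme=torch.per_tensor_symmetric
)
qspec_map = {
    "a": concrete,
    "b": SharedQuantizationSpec("a"),
    "c": SharedQuantizationSpec("b"),
}
assert _unwrap_shared_qspec_safe(qspec_map["c"], qspec_map) is concrete
# A self-referencing entry, e.g. {"d": SharedQuantizationSpec("d")}, raises RuntimeError.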
