-
Notifications
You must be signed in to change notification settings - Fork 526
/
Copy pathexample_partitioner.py
88 lines (74 loc) · 3.47 KB
/
example_partitioner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, final
import torch
from executorch.backends.example.example_backend import ExampleBackend
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_partitions_from_list_of_nodes,
)
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase
from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions
@final
class ExamplePartitioner(Partitioner):
    """
    Partitions all add/mul nodes regardless of order.

    Every pattern key of ``module_to_annotator`` is matched against the graph
    with ``find_sequential_partitions``. The matched node groups are merged
    into partitions by ``generate_partitions_from_list_of_nodes``, and every
    node of each partition is tagged for delegation to ``ExampleBackend``.

    NOTE(review): the "add/mul" summary comes from the original docstring;
    confirm it still matches the keys of ``module_to_annotator``.
    """

    def __init__(self) -> None:
        # Pattern keys to search for (a live view of module_to_annotator's keys).
        self.patterns = module_to_annotator.keys()
        # Single spec shared by every partition: target ExampleBackend, with an
        # empty spec list as the second argument.
        self.delegation_spec = DelegationSpec(ExampleBackend.__name__, [])

        class DequantQuantOperatorSupport(OperatorSupportBase):
            """Accept only edge-dialect per-tensor quantize/dequantize calls."""

            def is_node_supported(self, _submodules, node: torch.fx.Node) -> bool:
                return node.op == "call_function" and node.target in [
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                ]

        # Handed to generate_partitions_from_list_of_nodes together with the
        # matched pattern nodes; presumably this lets q/dq ops be merged into
        # the surrounding partitions -- confirm against that helper.
        self.dequant_quant_support = DequantQuantOperatorSupport()

    def _partition_graph_module(
        self, edge_graph_module: torch.fx.GraphModule
    ) -> Dict[str, DelegationSpec]:
        """
        Tag delegable nodes of ``edge_graph_module`` and its control-flow submodules.

        Args:
            edge_graph_module: Graph module to partition. ``node.meta`` of the
                selected nodes is mutated in place (``"delegation_tag"`` key).

        Returns:
            Mapping from every ``delegation_tag`` written into node metadata
            (including those from recursed control-flow submodules) to the
            ``DelegationSpec`` it should be lowered with.
        """
        partition_tags: Dict[str, DelegationSpec] = {}

        # Gather the node groups of every sequential match of every pattern.
        partition_nodes = []
        for pattern in self.patterns:
            for fused_partition in find_sequential_partitions(
                edge_graph_module,
                pattern,
            ):
                partition_nodes.extend(
                    partition.nodes for partition in fused_partition
                )

        partitions = generate_partitions_from_list_of_nodes(
            edge_graph_module, partition_nodes, self.dequant_quant_support
        )

        for partition in partitions:
            if not partition.nodes:
                # No node would carry this tag, so don't register it.
                continue
            # The tag depends only on the partition, so compute it once.
            delegation_tag = f"tag{partition.id}"
            for node in partition.nodes:
                node.meta["delegation_tag"] = delegation_tag
                if node.op != "call_function":
                    continue
                # get_attr inputs of a tagged call share its tag, presumably so
                # the attributes they load are lowered along with the delegate.
                for arg_node in node.args:
                    if (
                        isinstance(arg_node, torch.fx.Node)
                        and arg_node.op == "get_attr"
                    ):
                        arg_node.meta["delegation_tag"] = delegation_tag
            partition_tags[delegation_tag] = self.delegation_spec

        # Recurse into control-flow submodules and merge their tags.
        # NOTE(review): tags are built from partition ids alone; confirm ids
        # cannot collide between a parent graph and its submodules.
        for _, submodule, _ in get_control_flow_submodules(edge_graph_module):
            partition_tags.update(self._partition_graph_module(submodule))

        return partition_tags

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Tag ``exported_program``'s graph in place for delegation to ExampleBackend.

        Args:
            exported_program: Program whose ``graph_module`` (and control-flow
                submodules) will have delegation tags written into node meta.

        Returns:
            A ``PartitionResult`` wrapping the same, now tagged, program and the
            tag -> ``DelegationSpec`` mapping.
        """
        partition_tags = self._partition_graph_module(exported_program.graph_module)
        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )