Commit c3918b8

mcremon-meta authored and facebook-github-bot committed
Replace split_with_sizes_copy with slice_copy
Summary: Replacing `split_with_sizes` with `slice_copy` is useful for a few reasons, all related to the fact that we already have a lot of infrastructure in place for slice ops, so the compiler can optimize much better after the replacement. That includes:
- swapping with q/dq nodes to do the slice in lower bitwidth
- becoming a slice_nop when possible (through memory planning constraints)
- a few simplification passes on slices
- possibly detecting slice/cat patterns that are a nop
- an already-optimized slice op on HiFi

To be fair, we will get the split op soon and can do some profiling then, but the points above will still stand.

Differential Revision: D73312379
1 parent 06f912d commit c3918b8
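A minimal sketch (not part of the commit) of the equivalence the pass exploits, using the public `torch.split_with_sizes` and `torch.slice_copy` bindings rather than the edge ops the pass actually targets:

```python
import torch

x = torch.arange(24.0).reshape(2, 12)
split_sizes = [3, 4, 5]
dim = 1

# split_with_sizes yields one output tensor per requested size...
splits = torch.split_with_sizes(x, split_sizes, dim=dim)

# ...and each output is exactly a step-1 slice of the input along `dim`,
# which is what the pass emits as slice_copy.Tensor nodes.
start = 0
for size, out in zip(split_sizes, splits):
    end = start + size
    assert torch.equal(out, torch.slice_copy(x, dim, start, end, 1))
    start = end
```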

File tree

1 file changed: +77, -2 lines

Diff for: backends/cadence/aot/replace_ops.py

```diff
@@ -17,8 +17,9 @@
 # pyre-unsafe
 
 import math
+import operator
 from operator import neg
-from typing import cast, Dict, Iterable, Sequence, Set, Tuple
+from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple
 
 import torch
 import torch.fx
@@ -2182,6 +2183,79 @@ def call_operator(
         )
 
 
+# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
+@register_cadence_pass(CadencePassAttribute(opt_level=2))
+class ReplaceSplitWithSlicePass(ExportPass):
+    """
+    split_with_sizes() delegates to slice() op, so perform this replacement here.
+    This avoids the expense of delegation from ATen.
+    """
+
+    # For split_with_sizes, return the slice dim and extent for each split.
+    def get_split_sizes(
+        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
+    ) -> Optional[list[tuple[int, ...]]]:
+        # Parse the args of the split_with_sizes op
+        tensor_arg, split_sizes = node.args[0:2]
+        assert isinstance(tensor_arg, torch.fx.Node)
+        in_shape = get_shape(graph_module, tensor_arg)
+        split_dim = 0 if len(node.args) < 3 else node.args[2]
+        if in_shape is None:
+            return None
+
+        # Canonicalize the split dimension
+        assert isinstance(split_dim, int)
+        split_dim = split_dim if split_dim >= 0 else len(in_shape) + split_dim
+
+        # Create the slice op args corresponding to each split
+        slice_ops = []
+        split_start = 0
+        assert isinstance(split_sizes, list)
+        for split_size in split_sizes:
+            split_end = split_start + split_size
+            slice_args = (split_dim, split_start, split_end)
+            slice_ops.append(slice_args)
+            split_start = split_end
+
+        return slice_ops
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        graph = graph_module.graph
+        for node in graph.nodes:
+            if not isinstance(node.target, EdgeOpOverload):
+                continue
+            if get_edge_overload_packet(node.target) != exir_ops.edge.aten.split_with_sizes_copy:
+                continue
+            # All the users of this split_with_sizes op must be getitem ops
+            if any(user.target != operator.getitem for user in node.users):
+                continue
+
+            # Get the slice dim and extent for each split
+            slice_ops = self.get_split_sizes(graph_module, node)
+            if slice_ops is None:
+                continue
+
+            # Go over each getitem user, and replace it with slice op
+            for user in list(node.users.keys()):
+                assert user.target == operator.getitem
+                item_idx = user.args[1]
+                assert item_idx < len(slice_ops)
+                cur_slice = slice_ops[item_idx]
+                with graph.inserting_before(user):
+                    cur_slice_node = graph.call_function(
+                        exir_ops.edge.aten.slice_copy.Tensor,
+                        (node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1),
+                    )
+                user.replace_all_uses_with(cur_slice_node)
+                graph.erase_node(user)
+
+            graph.erase_node(node)
+
+        graph_module.recompile()
+        result = super().call(graph_module)
+        return result
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2220,5 +2294,6 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
-        # ReplaceGeluWithApproximateGeluPass,
+        ReplaceGeluWithApproximateGeluPass,
+        ReplaceSplitWithSlicePass,
     ]
```
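For context outside ExecuTorch, here is a hedged, self-contained sketch of the same getitem-to-slice rewrite on a plain `torch.fx` graph. Assumptions: it matches the public `torch.split_with_sizes` function instead of the edge `split_with_sizes_copy` packet, emits `torch.slice_copy` instead of `exir_ops.edge.aten.slice_copy.Tensor`, and `SplitModel` is a made-up example module:

```python
import operator

import torch
import torch.fx


class SplitModel(torch.nn.Module):
    # Hypothetical module: two getitem consumers of a split_with_sizes.
    def forward(self, x):
        parts = torch.split_with_sizes(x, [2, 3], dim=1)
        return parts[0] + 1.0, parts[1] * 2.0


gm = torch.fx.symbolic_trace(SplitModel())

for node in list(gm.graph.nodes):
    if node.target is not torch.split_with_sizes:
        continue
    # Only rewrite when every consumer is a getitem, mirroring the pass.
    if any(user.target is not operator.getitem for user in node.users):
        continue
    x_arg, sizes = node.args[0], node.args[1]
    dim = node.kwargs.get("dim", 0)
    # Compute the (start, end) extent of each split output.
    bounds, start = [], 0
    for size in sizes:
        bounds.append((start, start + size))
        start += size
    # Replace each getitem with an equivalent step-1 slice_copy.
    for user in list(node.users):
        lo, hi = bounds[user.args[1]]
        with gm.graph.inserting_before(user):
            sliced = gm.graph.call_function(
                torch.slice_copy, (x_arg, dim, lo, hi, 1)
            )
        user.replace_all_uses_with(sliced)
        gm.graph.erase_node(user)
    gm.graph.erase_node(node)

gm.recompile()

x = torch.randn(2, 5)
expected = SplitModel()(x)
actual = gm(x)
assert all(torch.equal(a, e) for a, e in zip(actual, expected))
```

Requiring every user to be a `getitem` (as the real pass does) is what makes erasing the split node safe: once each indexed output has been rerouted to its own slice, the split has no remaining consumers.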
