Skip to content

Commit 96bef6b

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Replace split_with_sizes_copy with slice_copy (#10318)
Summary: This is useful for a few reasons, which are all related to the fact that we have a lot of things in place for slice ops, so the compiler can optimize much better if we replace `split_with_sizes`. That includes: - swapping with q/dq nodes to do it in lower bitwidth - becoming slice_nop when possible (through memory planning constraints) - a few simplification passes on slices - possibly finding patterns with slice/cat being a nop - we have an optimized slice op on HiFi already. We will get the split op soon to be fair, we can do some profiling then, but the above points will still stand Reviewed By: zonglinpeng Differential Revision: D73312379
1 parent 21adbe2 commit 96bef6b

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

backends/cadence/aot/replace_ops.py

+80-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
# pyre-unsafe
1818

1919
import math
20+
import operator
2021
from operator import neg
21-
from typing import cast, Dict, Iterable, Sequence, Set, Tuple
22+
from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple
2223

2324
import torch
2425
import torch.fx
@@ -2182,6 +2183,82 @@ def call_operator(
21822183
)
21832184

21842185

2186+
# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
2187+
@register_cadence_pass(CadencePassAttribute(opt_level=2))
2188+
class ReplaceSplitWithSlicePass(ExportPass):
2189+
"""
2190+
split_with_sizes() delegates to slice() op, so perform this replacement here.
2191+
This avoids the expense of delegation from ATen.
2192+
"""
2193+
2194+
# For split_with_sizes, return the slice dim and extent for each split.
2195+
def get_split_sizes(
2196+
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
2197+
) -> Optional[list[tuple[int, ...]]]:
2198+
# Parse the args of the split_with_sizes op
2199+
tensor_arg, split_sizes = node.args[0:2]
2200+
assert isinstance(tensor_arg, torch.fx.Node)
2201+
in_shape = get_shape(graph_module, tensor_arg)
2202+
split_dim = 0 if len(node.args) < 3 else node.args[2]
2203+
if in_shape is None:
2204+
return None
2205+
2206+
# Canonicalize the split dimension
2207+
assert isinstance(split_dim, int)
2208+
split_dim = split_dim if split_dim >= 0 else len(in_shape) + split_dim
2209+
2210+
# Create the slice op args corresponding to each split
2211+
slice_ops = []
2212+
split_start = 0
2213+
assert isinstance(split_sizes, list)
2214+
for split_size in split_sizes:
2215+
split_end = split_start + split_size
2216+
slice_args = (split_dim, split_start, split_end)
2217+
slice_ops.append(slice_args)
2218+
split_start = split_end
2219+
2220+
return slice_ops
2221+
2222+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2223+
graph = graph_module.graph
2224+
for node in graph.nodes:
2225+
if not isinstance(node.target, EdgeOpOverload):
2226+
continue
2227+
if (
2228+
get_edge_overload_packet(node.target)
2229+
!= exir_ops.edge.aten.split_with_sizes_copy
2230+
):
2231+
continue
2232+
# All the users of this split_with_sizes op must be getitem ops
2233+
if any(user.target != operator.getitem for user in node.users):
2234+
continue
2235+
2236+
# Get the slice dim and extent for each split
2237+
slice_ops = self.get_split_sizes(graph_module, node)
2238+
if slice_ops is None:
2239+
continue
2240+
2241+
# Go over each getitem user, and replace it with slice op
2242+
for user in list(node.users.keys()):
2243+
assert user.target == operator.getitem
2244+
item_idx = user.args[1]
2245+
assert item_idx < len(slice_ops)
2246+
cur_slice = slice_ops[item_idx]
2247+
with graph.inserting_before(user):
2248+
cur_slice_node = graph.call_function(
2249+
exir_ops.edge.aten.slice_copy.Tensor,
2250+
(node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1),
2251+
)
2252+
user.replace_all_uses_with(cur_slice_node)
2253+
graph.erase_node(user)
2254+
2255+
graph.erase_node(node)
2256+
2257+
graph_module.recompile()
2258+
result = super().call(graph_module)
2259+
return result
2260+
2261+
21852262
# This class encapsulates all the functions that replace/switch one op in the
21862263
# graph with another.
21872264
class CadenceReplaceOpsInGraph:
@@ -2220,5 +2297,6 @@ class CadenceReplaceOpsInGraph:
22202297
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
22212298
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
22222299
ReplaceWhereWithFullArgsWithWhereScalar,
2223-
# ReplaceGeluWithApproximateGeluPass,
2300+
ReplaceGeluWithApproximateGeluPass,
2301+
ReplaceSplitWithSlicePass,
22242302
]

backends/cadence/aot/tests/test_replace_ops_passes.py

+27
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-unsafe
88

9+
import operator
910
import unittest
1011
from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
1112

@@ -40,6 +41,7 @@
4041
ReplaceScalarWithTensorArgPass,
4142
ReplaceSelectWithViewOpPass,
4243
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
44+
ReplaceSplitWithSlicePass,
4345
ReplaceSqueezeAndUnsqueezeWithViewPass,
4446
ReplaceTCopyWithTransposePass,
4547
ReplaceTransposedConvWithLinearPass,
@@ -1306,6 +1308,31 @@ def forward(self, input):
13061308
6,
13071309
)
13081310

1311+
def test_replace_split_with_sizes_with_slice(self):
1312+
builder = GraphBuilder()
1313+
x = builder.placeholder("x", torch.randn(1, 16, 8, 4))
1314+
split = builder.call_operator(
1315+
exir_ops.edge.aten.split_with_sizes_copy.default, (x, [8, 8], 1)
1316+
)
1317+
# We need the outputs to be gathered by getitem ops
1318+
out0 = builder.call_operator(operator.getitem, (split, 0))
1319+
out1 = builder.call_operator(operator.getitem, (split, 1))
1320+
builder.output([out0, out1])
1321+
graph_module = builder.get_graph_module()
1322+
1323+
p = ReplaceSplitWithSlicePass()
1324+
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1325+
1326+
self.assertEqual(
1327+
count_node(
1328+
graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
1329+
),
1330+
0,
1331+
)
1332+
self.assertEqual(
1333+
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor),
1334+
2,
1335+
)
13091336

13101337
class TestReplaceIm2rowWithViewPass(unittest.TestCase):
13111338
def test_no_replacement_for_conv(self):

0 commit comments

Comments
 (0)