diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 867e4ec79c6..820e46652ef 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -17,8 +17,9 @@ # pyre-unsafe import math +import operator from operator import neg -from typing import cast, Dict, Iterable, Sequence, Set, Tuple +from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple import torch import torch.fx @@ -2182,6 +2183,82 @@ def call_operator( ) +# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py +@register_cadence_pass(CadencePassAttribute(opt_level=2)) +class ReplaceSplitWithSlicePass(ExportPass): + """ + split_with_sizes() delegates to slice() op, so perform this replacement here. + This avoids the expense of delegation from ATen. + """ + + # For split_with_sizes, return the slice dim and extent for each split. + def get_split_sizes( + self, graph_module: torch.fx.GraphModule, node: torch.fx.Node + ) -> Optional[list[tuple[int, ...]]]: + # Parse the args of the split_with_sizes op + tensor_arg, split_sizes = node.args[0:2] + assert isinstance(tensor_arg, torch.fx.Node) + in_shape = get_shape(graph_module, tensor_arg) + split_dim = 0 if len(node.args) < 3 else node.args[2] + if in_shape is None: + return None + + # Canonicalize the split dimension + assert isinstance(split_dim, int) + split_dim = split_dim if split_dim >= 0 else len(in_shape) + split_dim + + # Create the slice op args corresponding to each split + slice_ops = [] + split_start = 0 + assert isinstance(split_sizes, list) + for split_size in split_sizes: + split_end = split_start + split_size + slice_args = (split_dim, split_start, split_end) + slice_ops.append(slice_args) + split_start = split_end + + return slice_ops + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + if not isinstance(node.target, EdgeOpOverload): + continue + if ( + get_edge_overload_packet(node.target) + != exir_ops.edge.aten.split_with_sizes_copy + ): + continue + # All the users of this split_with_sizes op must be getitem ops + if any(user.target != operator.getitem for user in node.users): + continue + + # Get the slice dim and extent for each split + slice_ops = self.get_split_sizes(graph_module, node) + if slice_ops is None: + continue + + # Go over each getitem user, and replace it with slice op + for user in list(node.users.keys()): + assert user.target == operator.getitem + item_idx = user.args[1] + assert item_idx < len(slice_ops) + cur_slice = slice_ops[item_idx] + with graph.inserting_before(user): + cur_slice_node = graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, + (node.args[0], cur_slice[0], cur_slice[1], cur_slice[2], 1), + ) + user.replace_all_uses_with(cur_slice_node) + graph.erase_node(user) + + graph.erase_node(node) + + graph_module.recompile() + result = super().call(graph_module) + return result + + # This class encapsulates all the functions that replace/switch one op in the # graph with another. class CadenceReplaceOpsInGraph: @@ -2220,5 +2297,6 @@ class CadenceReplaceOpsInGraph: ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceAtenAvgPoolWithJarvisAvgPoolPass, ReplaceWhereWithFullArgsWithWhereScalar, - # ReplaceGeluWithApproximateGeluPass, + ReplaceGeluWithApproximateGeluPass, + ReplaceSplitWithSlicePass, ] diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 886550772b5..848eba4d854 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -6,6 +6,7 @@ # pyre-unsafe +import operator import unittest from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union @@ -40,6 +41,7 @@ ReplaceScalarWithTensorArgPass, ReplaceSelectWithViewOpPass, ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, + ReplaceSplitWithSlicePass, ReplaceSqueezeAndUnsqueezeWithViewPass, ReplaceTCopyWithTransposePass, ReplaceTransposedConvWithLinearPass, @@ -1306,6 +1308,32 @@ def forward(self, input): 6, ) + def test_replace_split_with_sizes_with_slice(self): + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 16, 8, 4)) + split = builder.call_operator( + exir_ops.edge.aten.split_with_sizes_copy.default, (x, [8, 8], 1) + ) + # We need the outputs to be gathered by getitem ops + out0 = builder.call_operator(operator.getitem, (split, 0)) + out1 = builder.call_operator(operator.getitem, (split, 1)) + builder.output([out0, out1]) + graph_module = builder.get_graph_module() + + p = ReplaceSplitWithSlicePass() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default + ), + 0, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), + 2, + ) + class TestReplaceIm2rowWithViewPass(unittest.TestCase): def test_no_replacement_for_conv(self):