|
17 | 17 | # pyre-unsafe
|
18 | 18 |
|
19 | 19 | import math
|
| 20 | +import operator |
20 | 21 | from operator import neg
|
21 |
| -from typing import cast, Dict, Iterable, Sequence, Set, Tuple |
| 22 | +from typing import cast, Dict, Iterable, Optional, Sequence, Set, Tuple |
22 | 23 |
|
23 | 24 | import torch
|
24 | 25 | import torch.fx
|
@@ -2182,6 +2183,79 @@ def call_operator(
|
2182 | 2183 | )
|
2183 | 2184 |
|
2184 | 2185 |
|
# Adapted from fbcode/pyspeech/opt_passes/replace_ops.py
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class ReplaceSplitWithSlicePass(ExportPass):
    """
    split_with_sizes() delegates to slice() op, so perform this replacement here.
    This avoids the expense of delegation from ATen.
    """

    # For split_with_sizes, return the (dim, start, end) slice extent for each
    # split, or None if the input shape could not be determined.
    def get_split_sizes(
        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
    ) -> Optional[list[tuple[int, int, int]]]:
        # Parse the args of the split_with_sizes op: (input, split_sizes, dim=0)
        tensor_arg, split_sizes = node.args[0:2]
        assert isinstance(tensor_arg, torch.fx.Node)
        in_shape = get_shape(graph_module, tensor_arg)
        split_dim = 0 if len(node.args) < 3 else node.args[2]
        if in_shape is None:
            # Shape propagation failed; we cannot compute slice extents.
            return None

        # Canonicalize the split dimension (negative dims count from the end).
        assert isinstance(split_dim, int)
        split_dim = split_dim if split_dim >= 0 else len(in_shape) + split_dim

        # Create the slice op args corresponding to each split: consecutive,
        # non-overlapping [start, end) ranges along split_dim.
        slice_ops: list[tuple[int, int, int]] = []
        split_start = 0
        assert isinstance(split_sizes, Sequence)
        for split_size in split_sizes:
            split_end = split_start + split_size
            slice_ops.append((split_dim, split_start, split_end))
            split_start = split_end

        return slice_ops

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph = graph_module.graph
        # Iterate over a snapshot of the node list: we erase nodes (both the
        # getitem users and the split node itself) while walking the graph.
        for node in list(graph.nodes):
            if not isinstance(node.target, EdgeOpOverload):
                continue
            if (
                get_edge_overload_packet(node.target)
                != exir_ops.edge.aten.split_with_sizes_copy
            ):
                continue
            # All the users of this split_with_sizes op must be getitem ops,
            # otherwise we cannot rewrite each projection as a slice.
            if any(user.target != operator.getitem for user in node.users):
                continue

            # Get the slice dim and extent for each split
            slice_ops = self.get_split_sizes(graph_module, node)
            if slice_ops is None:
                continue

            # Go over each getitem user, and replace it with a slice op
            for user in list(node.users.keys()):
                assert user.target == operator.getitem
                item_idx = user.args[1]
                # Reject negative indices explicitly: a negative item_idx would
                # otherwise silently index slice_ops from the end.
                assert isinstance(item_idx, int) and 0 <= item_idx < len(slice_ops)
                dim, start, end = slice_ops[item_idx]
                with graph.inserting_before(user):
                    cur_slice_node = graph.call_function(
                        exir_ops.edge.aten.slice_copy.Tensor,
                        (node.args[0], dim, start, end, 1),
                    )
                user.replace_all_uses_with(cur_slice_node)
                graph.erase_node(user)

            # Every user has been replaced, so the split node is now dead.
            graph.erase_node(node)

        graph_module.recompile()
        result = super().call(graph_module)
        return result
| 2258 | + |
2185 | 2259 | # This class encapsulates all the functions that replace/switch one op in the
|
2186 | 2260 | # graph with another.
|
2187 | 2261 | class CadenceReplaceOpsInGraph:
|
@@ -2220,5 +2294,6 @@ class CadenceReplaceOpsInGraph:
|
2220 | 2294 | ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
|
2221 | 2295 | ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
|
2222 | 2296 | ReplaceWhereWithFullArgsWithWhereScalar,
|
2223 |
| - # ReplaceGeluWithApproximateGeluPass, |
| 2297 | + ReplaceGeluWithApproximateGeluPass, |
| 2298 | + ReplaceSplitWithSlicePass, |
2224 | 2299 | ]
|
0 commit comments