-
Notifications
You must be signed in to change notification settings - Fork 526
/
Copy pathtest_remove_ops_passes.py
795 lines (670 loc) · 27.9 KB
/
test_remove_ops_passes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import unittest
from typing import cast, Tuple
import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
from executorch.backends.cadence.aot.remove_ops import (
RemoveAliasCopyOpPass,
RemoveBranchedQuantDequant,
RemoveCatFromSliceCopyPass,
RemoveCloneOpPass,
RemoveContiguousOpPass,
RemoveDetachCopyPass,
RemoveNopAddOpPass,
RemoveNopExpandOpPass,
RemoveNopLinalgVectorNormOpPass,
RemoveNopMulOpPass,
RemoveNopSelectOpPass,
RemoveNopSliceOrViewOpPass,
RemovePermutesAroundElementwiseOps,
RemoveToOpsPass,
RemoveZeroSizedCatArgsPass,
RemoveZeroSizedConstantPadNd,
)
from executorch.exir.dialects._ops import ops as exir_ops
from parameterized.parameterized import parameterized
from pyre_extensions import none_throws
from torch.export import export_for_training
from torch.fx.passes.infra.pass_base import PassResult
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
class TestRemoveOpsPasses(unittest.TestCase):
    """Tests for the op-removal passes imported from ``remove_ops``.

    Most tests export a tiny ``torch.nn.Module`` to the edge dialect with
    ``export_to_edge``, run a single pass on the resulting graph module, and
    assert on op counts via ``count_node`` / ``op_counts_match``. A few tests
    go through ``compiler.export_to_cadence`` instead and check the graph
    produced by that whole pipeline.

    Op counts are always checked against ``exir_ops.edge.*`` targets: the
    graphs are in the edge dialect, so counting plain ``torch.ops.aten.*``
    overloads would always yield 0 and make the assertion vacuous.
    """

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_to_ops(self, shape: Tuple[int, ...]):
        """Check that no ``to.dtype`` / ``to.dtype_layout`` op survives RemoveToOpsPass."""

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.to(x, dtype=torch.float32)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveToOpsPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.to.dtype),
            0,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.to.dtype_layout),
            0,
        )

    @parameterized.expand(
        [
            [(7, 6, 5)],
            [(7, 6)],
            [(7,)],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_add_op_pass(self, shape: Tuple[int, ...]):
        """Check that adding an all-zeros tensor (as either operand) is removed."""

        class FullX(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.add(torch.full(shape, 0), t)

        class FullY(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.add(t, torch.full(shape, 0))

        t = torch.full(shape, 3)
        p = RemoveNopAddOpPass()

        # Zeros on the left-hand side (FullX) and on the right-hand side (FullY).
        for model_cls in (FullX, FullY):
            with self.subTest(model=model_cls.__name__):
                graph_module = (
                    export_to_edge(model_cls(), (t,)).exported_program().graph_module
                )
                graph_after_passes = cast(PassResult, p(graph_module)).graph_module
                self.assertEqual(
                    count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
                    0,
                )

    @parameterized.expand(
        [
            [(7, 6, 5)],
            [(7, 6)],
            [(7,)],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_mul_op_pass(self, shape: Tuple[int, ...]):
        """Check that multiplying by an all-zeros tensor (either operand) leaves no mul."""

        class FullX(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.mul(torch.full(shape, 0), t)

        class FullY(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.mul(t, torch.full(shape, 0))

        t = torch.full(shape, 3)
        p = RemoveNopMulOpPass()

        # Zeros on the left-hand side (FullX) and on the right-hand side (FullY).
        for model_cls in (FullX, FullY):
            with self.subTest(model=model_cls.__name__):
                graph_module = (
                    export_to_edge(model_cls(), (t,)).exported_program().graph_module
                )
                graph_after_passes = cast(PassResult, p(graph_module)).graph_module
                self.assertEqual(
                    count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
                    0,
                )

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_alias_copy(self, shape: Tuple[int, ...]):
        """Check that RemoveAliasCopyOpPass leaves no ``alias_copy`` op."""

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.alias_copy(x)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveAliasCopyOpPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default),
            0,
        )

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_detach_copy(self, shape: Tuple[int, ...]):
        """Check that RemoveDetachCopyPass leaves no ``detach_copy`` op."""

        # The module calls the edge-dialect detach_copy op directly so the
        # exported graph is guaranteed to contain it; a plain aten::detach may
        # instead be lowered to aten::alias_copy by functionalization and
        # decomposition.
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.detach_copy(x)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveDetachCopyPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default),
            0,
        )

    @parameterized.expand(
        [
            [(1, 2, 3), (0, 0)],
        ]
    )
    @torch.no_grad()
    def test_remove_zero_sized_constant_pad_nd(
        self, shape: Tuple[int, ...], padding: Tuple[int, ...]
    ):
        """Check that an all-zero padding ``constant_pad_nd`` is removed."""

        # F.pad is converted to aten::constant_pad_nd after functionalization
        # & decomposition, which is the op counted below.
        class Padding(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.padding = padding

            def forward(self, x: torch.Tensor):
                return F.pad(x, self.padding)

        model = Padding()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveZeroSizedConstantPadNd()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
            0,
        )

    def test_remove_expand(self):
        """Check that an ``expand_copy`` to the input's own shape is removed."""

        class Expand(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.expand_copy(x, [2, 3, 5])

        x = torch.ones(2, 3, 5)
        p = RemoveNopExpandOpPass()
        graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module
        graph_module = p(graph_module).graph_module

        # Assert that expand op is optimized away, since it is a nop
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0
        )

    def test_remove_zero_arg_cat(self):
        """Check that the Cadence pipeline drops a cat of only zero-sized tensors."""

        class Cat(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat((x, y), 0)

        x = torch.ones(1, 0, 3, 5)
        y = torch.ones(2, 0, 3, 5)
        graph_module = (
            compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module
        )

        # Assert that cat op is optimized away, since it concatenates
        # two zero-sized tensors
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

    def test_remove_single_arg_cat(self):
        """Check that a cat whose only non-empty operand is ``x`` is removed."""

        class Cat(torch.nn.Module):
            def forward(self, x, y):
                z = torch.ones(0, 5)
                # z is an empty tensor, and concatenation of x with z will
                # be x. So we can safely eliminate the following cat op.
                x1 = torch.ops.aten.cat((x, z))
                x2 = torch.add(x1, 2.4, 3.1)
                y1 = torch.add(y, 1, 2)
                return torch.add(x2, y1)

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module
        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()

        # Assert that x1 is optimized away. Count the edge-dialect op: the
        # graph never contains torch.ops.aten.cat.out, so counting that would
        # always be 0.
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.cat.default), 0
        )

    def test_remove_zero_sized_cat(self):
        """Check that a cat whose inputs are all zero-sized is removed."""

        class Cat(torch.nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.dim = dim

            def forward(self, tensors):
                return torch.cat(tensors, self.dim)

        shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127
        in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes]

        model = Cat(dim)
        graph_module = (
            export_to_edge(model, (in_tensors,)).exported_program().graph_module
        )
        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()

        # Count the edge-dialect op (torch.ops.aten.cat.out never appears here).
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.cat.default), 0
        )

    def test_remove_clone(self):
        """Check that RemoveCloneOpPass leaves no ``clone`` op."""

        class Clone(torch.nn.Module):
            def forward(self, x, y):
                t1 = x.clone()
                t2 = y.clone()
                return t1 + t2

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module
        new_graph_module = RemoveCloneOpPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()

        # Assert that t1 and t2 are optimized away. Count the edge-dialect op
        # (torch.ops.aten.clone.out never appears in this graph).
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.clone.default), 0
        )

    def test_remove_contiguous(self):
        """Check that RemoveContiguousOpPass leaves no ``contiguous`` op."""

        class Contiguous(torch.nn.Module):
            def forward(self, x, y):
                t1 = x.contiguous()
                t2 = y.contiguous()
                return t1 + t2

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = (
            export_to_edge(Contiguous(), (x, y)).exported_program().graph_module
        )
        new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()

        # Assert that t1 and t2 are optimized away. Count the edge-dialect op
        # (torch.ops.aten.contiguous.out never appears in this graph).
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.contiguous.default), 0
        )

    @parameterized.expand(
        [
            [(3, 5), [3, 5]],
            [(1,), [-1]],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_view(self, shape, new_shape):
        """Check that a view to the input's own shape is removed."""

        class View(torch.nn.Module):
            def __init__(self, new_shape):
                super().__init__()
                self.new_shape = new_shape

            def forward(self, x: torch.Tensor):
                return x.view(self.new_shape)

        model = View(new_shape)
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveNopSliceOrViewOpPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_after_passes.graph.eliminate_dead_code()

        # Assert that view op was removed
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
        )

    def test_remove_nop_slice(self):
        """Check that a full-range, unit-step slice along dim 0 is removed."""

        class Slice(torch.nn.Module):
            def forward(self, x):
                return torch.slice_copy(x, dim=0, start=0, step=1)

        x = torch.ones(3, 5)
        model = Slice()
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveNopSliceOrViewOpPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_after_passes.graph.eliminate_dead_code()

        # Assert that slice op was removed
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
        )

    def test_remove_nop_select(self):
        """Check that selecting index 0 of a size-1 leading dim is removed.

        Covers the select feeding a view, an add, a mul and a div.
        """

        class SelectFeasible1(torch.nn.Module):
            def forward(self, x):
                y = x.select(0, 0)
                z = y.view([1, 5, 6])
                return z

        class SelectFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x + y
                return z

        class SelectFeasible3(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x * y
                return z

        class SelectFeasible4(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x / y
                return z

        cases = (
            (SelectFeasible1(), (torch.ones(1, 5, 6),)),
            (SelectFeasible2(), (torch.ones(1, 5, 6), torch.ones(1, 5, 6))),
            (SelectFeasible3(), (torch.ones(1, 5, 6), torch.ones(1, 5, 6))),
            (SelectFeasible4(), (torch.ones(1, 5, 6), torch.ones(1, 5, 6))),
        )
        for model, inputs in cases:
            with self.subTest(model=type(model).__name__):
                graph_module = (
                    export_to_edge(model, inputs).exported_program().graph_module
                )
                # Sanity check: the exported graph does contain the select.
                self.assertEqual(
                    count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
                )
                graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
                # Assert that select op was removed
                self.assertEqual(
                    count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
                )

    def test_remove_nop_quant_dequant(self):
        """Check quant/dequant counts after exporting a quantized Linear to Cadence."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x):
                x = self.linear(x)
                return x

        inp = torch.randn(2, 8, 1, 6)

        # Run the standard quant/convert steps, but without fusing
        # this leaves two redundant quant/dequant pairs to test with
        quantizer = CadenceDefaultQuantizer()
        model_exp = export_for_training(M(), (inp,), strict=True).module()
        prepared_model = prepare_pt2e(model_exp, quantizer)
        prepared_model(inp)  # calibration run
        converted_model = convert_pt2e(prepared_model)

        graph_module = (
            compiler.export_to_cadence(
                converted_model,
                (inp,),
            )
            .exported_program()
            .graph_module
        )

        # Expect all quantize ops to be removed by the pass
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
            0,
        )

        # Expect 1 dequantize op for the weights
        self.assertEqual(
            count_node(
                graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
            ),
            1,
        )

    def test_remove_nop_aten_linalg_vector_norm(self):
        """Check that a vector norm over size-1 dims (keepdim=True) is removed."""

        class LinalgVectorNorm(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.linalg.vector_norm(x, 2, [0, 1], True)

        model = LinalgVectorNorm()
        x = torch.randn([1, 1, 128])
        inputs = (x,)

        graph_module = export_to_edge(model, inputs).exported_program().graph_module

        graph_module = none_throws(
            RemoveNopLinalgVectorNormOpPass()(graph_module)
        ).graph_module

        # Expect the linalg_vector_norm op to be removed by the pass
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default)
            + count_node(
                graph_module, exir_ops.edge.cadence.linalg_vector_norm.default
            ),
            0,
        )

    def test_remove_permutes_around_elemwise_ops_add(self) -> None:
        """Check that inverse permutes around an add are both removed."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)

            def forward(self, x):
                x = self.conv(x)
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.add(x, x)
                x = torch.permute(x, [0, 2, 3, 1])
                x = self.conv(x)
                return x

        inputs = (torch.randn(1, 8, 4, 4),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

    def test_remove_permutes_around_elemwise_ops_add_mean(self) -> None:
        """Check permute removal across add + mean, and that mean's dims are remapped."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = nn.Conv2d(8, 8, 1)

            def forward(self, x, y):
                x = self.conv2d(x)
                y = self.conv2d(y)
                x = torch.permute(x, [0, 3, 1, 2])
                y = torch.permute(y, [0, 3, 1, 2])
                z = torch.add(x, y)
                z = torch.mean(z, dim=[-1, -3], keepdim=True)
                z = torch.permute(z, [0, 2, 3, 1])
                z = self.conv2d(z)
                return z

        inputs = (torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

        # verify that mean was updated correctly
        mean = [
            n
            for n in graph_module.graph.nodes
            if n.target == exir_ops.edge.aten.mean.dim
        ][0]
        self.assertEqual(mean.args[1], [2, 3])

    def test_remove_permutes_around_elemwise_ops_mul(self) -> None:
        """Check permute removal across dequantize -> mul -> quantize."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.slice_copy(x, 0, 0, 1)
                x = torch.permute(x, [0, 3, 1, 2])
                y = torch.permute(y, [0, 3, 1, 2])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                z = x * y
                z = torch.ops.quantized_decomposed.quantize_per_tensor(
                    z, 2.5, 0, 0, 255, torch.uint8
                )
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.unsqueeze_copy(z, 0)
                return z

        inputs = (torch.randn(2, 4, 4, 8), torch.randn(2, 4, 4, 8))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

    def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None:
        """Check that stacked permutes are only partially removed and cat's dim is remapped."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.slice_copy(x, 0, 0, 1)
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                y = torch.permute(y, [0, 3, 1, 2])
                y = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    y, 1.5, 0, 0, 255, torch.uint8
                )
                z = torch.cat((x, y), 1)
                z = torch.ops.quantized_decomposed.quantize_per_tensor(
                    z, 2.5, 0, 0, 255, torch.uint8
                )
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.unsqueeze_copy(z, 0)
                return z

        inputs = (torch.randn(2, 4, 4, 8), torch.randn(1, 8, 4, 4))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Expect 2 permutes to remain, one on input x and one on output z
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
        )

        # verify that cat was updated correctly
        cat = [
            n
            for n in graph_module.graph.nodes
            if n.target == exir_ops.edge.aten.cat.default
        ][0]
        self.assertEqual(cat.args[1], 3)

    def test_remove_permutes_around_concat_with_views(self) -> None:
        """Check that permute-equivalent views and real permutes around a cat are removed."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                # Mix and match views that are permutes and actual permutes. Both
                # should be removed.
                x = x.view(1, 1, 4, 4)
                y = torch.permute(y, [0, 3, 1, 2])
                z = torch.cat((x, y), 1)
                return z.view(1, 4, 4, 8)

        inputs = (torch.randn(1, 4, 4, 1), torch.randn(1, 4, 4, 7))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Expect 0 permutes and views to remain.
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 0
        )

        # verify that cat was updated correctly
        cat = [
            n
            for n in graph_module.graph.nodes
            if n.target == exir_ops.edge.aten.cat.default
        ][0]
        self.assertEqual(cat.args[1], 3)

    def test_remove_permutes_around_elemwise_ops_noop(self) -> None:
        """Check that permutes in the opposite order to the other tests are kept."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)

            def forward(self, x):
                x = self.conv(x)
                x = torch.permute(x, [0, 2, 3, 1])
                x = torch.add(x, x)
                x = torch.permute(x, [0, 3, 1, 2])
                x = self.conv(x)
                return x

        inputs = (torch.randn(1, 8, 4, 4),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Ensure no permutes were removed, since the dimensions don't fit the expected pattern
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
        )

    def test_remove_dequant_on_branch(self):
        """Check that a dequantize on a side branch of a quantize is removed."""

        class M(torch.nn.Module):
            def forward(self, x):
                x = torch.abs(x)
                x0 = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x1 = torch.abs(x0)
                y0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x0, 1.2, 3, 0, 127, torch.int8
                )
                y1 = y0.view(-1)
                return x1, y1

        inputs = torch.rand(1, 8, 4, 6)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module

        graph_module = RemoveBranchedQuantDequant()(graph_module).graph_module
        self.assertTrue(
            op_counts_match(
                graph_module,
                expected_op_counts={
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
                    # we expect the pass to remove the dequantize node
                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                    exir_ops.edge.aten.abs.default: 2,
                },
            )
        )

    def test_remove_cat_from_slice_copy_all_removal(self) -> None:
        """Check that a cat is removed when the slice reads from only one input."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.cat((x, y), 0)  # (4, 4)
                # Rows [0, 1) lie entirely inside x.
                return torch.slice_copy(x1, dim=0, start=0, end=1)

        inputs = tuple(torch.randn(2, 4) for _ in range(2))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemoveCatFromSliceCopyPass()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Ensure the cat node was removed
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

    def test_remove_cat_from_slice_copy_no_removal(self) -> None:
        """Check that a cat is kept when the slice spans both inputs."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.cat((x, y), 0)  # (4, 4)
                # Rows [0, 3) span both x (rows 0-1) and y (row 2).
                return torch.slice_copy(x1, dim=0, start=0, end=3)

        inputs = tuple(torch.randn(2, 4) for _ in range(2))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemoveCatFromSliceCopyPass()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Ensure the cat node was kept
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)

    def test_remove_cat_from_slice_copy_zero_range(self) -> None:
        """Check that a cat is removed when the slice selects an empty range."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.cat((x, y), 0)  # (4, 4)
                # Empty range: rows [0, 0).
                return torch.slice_copy(x1, dim=0, start=0, end=0)

        inputs = tuple(torch.randn(2, 4) for _ in range(2))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemoveCatFromSliceCopyPass()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Ensure the cat node was removed
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)