Skip to content

Commit 2600cc8

Browse files
authored
Fix list assumption about permute arguments
Differential Revision: D67765363 Pull Request resolved: #7472
1 parent a861294 commit 2600cc8

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

Diff for: backends/cadence/aot/compiler_utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -129,16 +129,16 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
129129

130130

131131
# Capture the effect of permute op on incoming dimension order
132-
def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[int]:
132+
def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
133133
"""
134134
Given a permute node, and the incoming dimension ordering of the input
135135
tensor to the permute node, return the net effect of permute op on the
136136
dimension order.
137137
"""
138138
assert node.target == exir_ops.edge.aten.permute_copy.default
139139
# Permute each index of the dimension ordering (dims)
140-
permute_dims = node.args[1]
141-
assert isinstance(permute_dims, List)
140+
# pyre-fixme[6]: This combined typecheck isn't supported yet.
141+
permute_dims: List[int] = list(node.args[1])
142142
assert all(isinstance(x, int) for x in permute_dims)
143143
# If the dims is empty, we can simply return the permute order
144144
if not dims:

Diff for: backends/cadence/aot/reorder_ops.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,9 @@ def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
438438
args=(user, *node.args[1:]),
439439
)
440440
dequant_node.meta = user.meta.copy()
441-
# Remove meta["debug_handle"] on new node. Reassign it at the
442-
# caller level by calling generate_missing_debug_handles
443-
dequant_node.meta.pop("debug_handle")
441+
# Remove meta["debug_handle"] on new node if it exists.
442+
# Reassign it at the caller level by calling generate_missing_debug_handles
443+
dequant_node.meta.pop("debug_handle", None)
444444
user.replace_all_uses_with(dequant_node)
445445
dequant_node.args = (user, *node.args[1:])
446446

0 commit comments

Comments
 (0)