Skip to content

Commit 68d4f3e

Browse files
Fixed left_pad_sequence - correctly flip dims based on batch_first (#1523)
1 parent 5d5caca commit 68d4f3e

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

tests/torchtune/data/test_collate.py

+6
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,12 @@ def test_left_pad_sequence(self):
5757
expected = torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7], [8, 9, 10, 11, 12]])
5858
assert torch.equal(result, expected)
5959

60+
result = left_pad_sequence([a, b, c], batch_first=False, padding_value=0)
61+
expected = torch.tensor(
62+
[[0, 0, 8], [0, 4, 9], [1, 5, 10], [2, 6, 11], [3, 7, 12]]
63+
)
64+
assert torch.equal(result, expected)
65+
6066

6167
class TestPaddedCollate:
6268
def test_padded_collate_classifier_labels(self):

torchtune/data/_collate.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def left_pad_sequence(
4949
map(lambda x: torch.flip(x, dims=[0]), sequences),
5050
batch_first=batch_first,
5151
padding_value=padding_value,
52-
).flip(dims=[1])
52+
).flip(dims=[int(batch_first)])
5353

5454

5555
def padded_collate(

0 commit comments

Comments
 (0)