
Commit 4cdeed7

upd Bloom _prepare_attn_mask()
1 parent 4295ee8

File tree

1 file changed (+8 -1)

src/petals/models/bloom/block.py

+8 -1
@@ -6,6 +6,7 @@
 from typing import Optional, Tuple
 
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
 
@@ -26,7 +27,13 @@ def forward(
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=fake_inputs_embeds,
+            past_key_values_length=past_length,
+        )
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
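Note: `BloomModel._prepare_attn_mask` was removed in newer transformers releases, where causal-mask construction moved to `transformers.modeling_attn_mask_utils`; this commit switches the wrapped Bloom block to the replacement helper. Below is a minimal standalone sketch of what the new helper returns (not part of this commit; the sizes are made up for illustration):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Hypothetical sizes for illustration only
batch_size, seq_length, past_length = 2, 5, 3
# 2D padding mask over past + current tokens (all ones = nothing padded)
attention_mask = torch.ones((batch_size, seq_length + past_length))
# Metadata donor only: the helper reads this tensor's dtype (and, in some
# code paths, its device), never its values
fake_inputs_embeds = torch.tensor([42], dtype=torch.float32)

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask=attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=fake_inputs_embeds,
    past_key_values_length=past_length,
)
print(mask_4d.shape)  # torch.Size([2, 1, 5, 8]): (batch, 1, query_len, key_len)

This is why the `42` in `fake_inputs_embeds` appears arbitrary: any tensor with the right dtype would serve, since the helper never looks at the values when expanding the 2D mask into the 4D additive causal mask.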

0 commit comments