Skip to content
5 changes: 3 additions & 2 deletions src/twinkle/processor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def _create_4d_attention_mask(attention_mask):
import torch
seq_lens = [s.shape[0] for s in attention_mask]
max_len = max(seq_lens)
attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len),
dtype=torch.bool)).view(len(seq_lens), 1, max_len, max_len)
device = attention_mask[0].device
attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len), dtype=torch.bool,
device=device)).view(len(seq_lens), 1, max_len, max_len)
assert attention_mask.dtype is torch.bool, f'attention_mask.dtype: {attention_mask.dtype}'
for i, seq_len in enumerate(seq_lens):
attention_mask[i, :, :, seq_len:] = 0
Expand Down
Loading