diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index ff0fbabf..6f10ff63 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -246,8 +246,9 @@ def _create_4d_attention_mask(attention_mask): import torch seq_lens = [s.shape[0] for s in attention_mask] max_len = max(seq_lens) - attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len), - dtype=torch.bool)).view(len(seq_lens), 1, max_len, max_len) + device = attention_mask[0].device + attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len), dtype=torch.bool, + device=device)).view(len(seq_lens), 1, max_len, max_len) assert attention_mask.dtype is torch.bool, f'attention_mask.dtype: {attention_mask.dtype}' for i, seq_len in enumerate(seq_lens): attention_mask[i, :, :, seq_len:] = 0