32 commits
d10fa92
Initial commit
Micky774 Oct 24, 2025
eef7dc0
Updated to build from source by default
Micky774 Oct 24, 2025
cc68ab7
Updated for V3 API
Micky774 Oct 31, 2025
4455361
Fixed build, reverted AOTriton bwd changes (now V2)
Micky774 Nov 3, 2025
2586b18
Removed alterations
Micky774 Nov 3, 2025
aa80f81
Removed lazy tensor wrapper
Micky774 Nov 3, 2025
9a91b9e
Streamlined cmakelist, other PR review feedback addressed
Micky774 Nov 4, 2025
023deb4
Removed `pad_between_seqs`
Micky774 Nov 4, 2025
6b8dbe5
Updated typing to be more explicit
Micky774 Nov 4, 2025
68303d0
Minor streamlining and formatting
Micky774 Nov 4, 2025
8181972
Initial implementation
Micky774 Nov 6, 2025
6788a16
Simplified window size func for current non-SWA support
Micky774 Nov 6, 2025
182101a
Removed accidental include
Micky774 Nov 6, 2025
19a9c0f
Merge branch 'zain/aotriton' into zain/aotriton-bwd
Micky774 Nov 6, 2025
fef6baa
Corrected bwd args
Micky774 Nov 6, 2025
3a4fab8
Updated causal window default
Micky774 Nov 10, 2025
917e3c3
Updated window values for causal
Micky774 Nov 10, 2025
ce32e3b
Merge branch 'zain/aotriton' into zain/aotriton-bwd
Micky774 Nov 10, 2025
36045c8
Corrected DQ_ACC buffer, added env var for GPU kernel building
Micky774 Nov 12, 2025
d6e46c1
Update AOTriton to 0.11.1b
Micky774 Nov 12, 2025
1349a48
Merge branch 'dev' into zain/aotriton
Micky774 Nov 24, 2025
8ed0009
Merge branch 'zain/aotriton' into zain/aotriton-bwd
Micky774 Nov 24, 2025
2bd9006
Added AOTriton commit SHA
Micky774 Nov 25, 2025
a9bef37
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Nov 25, 2025
0fdff86
Moved handling of env variable to makefile
Micky774 Nov 26, 2025
3f6e054
Simplified lazy tensor implementation
Micky774 Dec 1, 2025
2246da4
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Dec 10, 2025
2a17f7b
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Jan 29, 2026
1a267cd
Update AOTriton version
Micky774 Jan 30, 2026
51da203
Improved tests
Micky774 Feb 4, 2026
945c8b2
Fix dq_acc stride. AITER ASM expects BHS.
xinyazhang Feb 5, 2026
e478e1a
Merge branch 'dev' into zain/aotriton-bwd
Micky774 Feb 5, 2026
75 changes: 64 additions & 11 deletions tests/pytorch/attention/test_attention.py
@@ -723,6 +723,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
"layout_2_1": ModelConfig(
2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
),
"layout_3_0": ModelConfig(1, 2048, 12, 64, attn_mask_type="causal"),
}


@@ -1281,17 +1282,60 @@ def test_transformer_layer(

# FusedAttention backend
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
if len(fused_attn_backends) == 1:
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
elif len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_CK"] = "0"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)
os.environ["NVTE_FUSED_ATTN_CK"] = "1"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)

os.environ["NVTE_CK_USES_FWD_V3"] = "1"
os.environ["NVTE_CK_USES_BWD_V3"] = "1"
fused_attn_fwd_2, fused_attn_bwd_2 = _run_transformer_layer(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_format,
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)


# FlashAttention backend
if flash_attn_supported:
@@ -1320,6 +1364,15 @@ def test_transformer_layer(
logging.info("[test_transformer_layer]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
if IS_HIP_EXTENSION and fused_attn_supported and len(fused_attn_backends) == 2:
logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)


@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
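The new branch above runs the same TransformerLayer once per available ROCm fused-attention backend by toggling `NVTE_FUSED_ATTN_CK` / `NVTE_FUSED_ATTN_AOTRITON` (and additionally enabling the CK V3 kernels via `NVTE_CK_USES_FWD_V3` / `NVTE_CK_USES_BWD_V3`), then cross-checks the forward and backward results. Below is a minimal sketch of that toggle-and-compare pattern; the `attention_backend` context manager and the `run_layer` callable are illustrative stand-ins and not part of this PR, which sets the variables directly on `os.environ`.

```python
import os
from contextlib import contextmanager

import torch


@contextmanager
def attention_backend(**env):
    """Temporarily set backend-selection env vars, restoring prior values on exit."""
    saved = {name: os.environ.get(name) for name in env}
    os.environ.update(env)
    try:
        yield
    finally:
        for name, old in saved.items():
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old


def compare_backends(run_layer, tols):
    """Run the same layer under each backend and cross-check fwd/bwd outputs.

    `run_layer` stands in for a zero-argument wrapper around
    _run_transformer_layer(...) with all other arguments fixed.
    """
    with attention_backend(NVTE_FUSED_ATTN_CK="0", NVTE_FUSED_ATTN_AOTRITON="1"):
        fwd_aotriton, bwd_aotriton = run_layer()
    with attention_backend(NVTE_FUSED_ATTN_CK="1", NVTE_FUSED_ATTN_AOTRITON="0"):
        fwd_ck, bwd_ck = run_layer()
    torch.testing.assert_close(fwd_aotriton, fwd_ck, **tols)
    for grad_a, grad_b in zip(bwd_aotriton, bwd_ck):
        torch.testing.assert_close(grad_a, grad_b, **tols)
```

Restoring the variables after each run keeps later parametrizations from inheriting a stale backend selection; the test itself flips them in place, which works because every branch sets both variables explicitly.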
4 changes: 2 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -1212,7 +1212,6 @@ def _test_dpa_accuracy(block, bs, dtype, config):
query.retain_grad()
key.retain_grad()
value.retain_grad()

out = block(query, key, value, attention_mask=mask)
loss = out.sum()
loss.backward()
@@ -1256,7 +1255,8 @@ def test_dpa_accuracy(dtype, bs, model):
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)

for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
for idx, outs in enumerate(zip(te_outputs[1:], torch_outputs[1:])):
te_output, torch_output = outs
assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)


1 change: 0 additions & 1 deletion transformer_engine/common/CMakeLists.txt
@@ -8,7 +8,6 @@ cmake_minimum_required(VERSION 3.21)

option(USE_ROCM "Use ROCm" ON)
option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
option(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS "Build AOTriton GPU kernels" OFF)
option(USE_FUSED_ATTN_CK "Use ck backend" ON)
set(USE_CUDA OFF)

2 changes: 1 addition & 1 deletion transformer_engine/common/aotriton/CMakeLists.txt
@@ -20,7 +20,7 @@ if(NOT DEFINED AOTRITON_PATH)
set(AOTRITON_NOIMAGE_MODE ON)
endif()

set(__AOTRITON_VER "0.11.1b")
set(__AOTRITON_VER "0.11.2b")
set(__AOTRITON_IMAGE_LIST
"amd-gfx942"
"amd-gfx950"
3 changes: 3 additions & 0 deletions transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -492,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
fused_attn_aotriton_bwd_qkvpacked(
b, h, max_seqlen, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_QKV, input_O, input_dO, output_S,
output_dQKV,
@@ -678,6 +679,7 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_aotriton_bwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_KV, input_O, input_dO,
output_S,
@@ -858,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_aotriton_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
attn_scale, dropout,
window_size_left, window_size_right,
qkv_layout, bias_type, attn_mask_type,
input_Q, input_K, input_V, input_O, input_dO,
output_S,
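The three backward entry points above now thread `window_size_left` / `window_size_right` into the AOTriton backward path, using the usual sliding-window convention (a negative value means unbounded on that side; per the "Updated window values for causal" commits, causal attention is commonly expressed as the window `(-1, 0)`). The sketch below only illustrates how that `(left, right)` pair maps to a mask; it is not the kernel logic added in this PR, and it assumes aligned query/key positions of equal length.

```python
import torch


def window_mask(seqlen, window_left, window_right):
    """Boolean keep-mask for the (left, right) sliding-window convention.

    A negative value means unbounded on that side; causal attention is
    commonly written as (-1, 0): unlimited history, no future positions.
    Assumes self-attention with aligned query/key positions.
    """
    q = torch.arange(seqlen).unsqueeze(1)   # query positions (column vector)
    k = torch.arange(seqlen).unsqueeze(0)   # key positions (row vector)
    rel = k - q                             # how far ahead (+) or behind (-) each key is
    keep = torch.ones(seqlen, seqlen, dtype=torch.bool)
    if window_left >= 0:
        keep &= rel >= -window_left
    if window_right >= 0:
        keep &= rel <= window_right
    return keep


# Causal masking expressed as a window: (-1, 0) keeps only keys at or before
# each query position, i.e. a lower-triangular mask.
assert torch.equal(window_mask(4, -1, 0), torch.ones(4, 4, dtype=torch.bool).tril())
```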