27 commits
bf5b217
Changed VERSION to 2.6.0
KshitijLakhani Jul 20, 2025
c7d0271
[PyTorch] Remove GH pinned deps (#1961)
ksivaman Jul 21, 2025
787acff
[PyTorch] Reset FP8 weight workspace if usages are invalid (#1972)
timmoon10 Jul 21, 2025
9926245
Fix the condition error when checking fp8 attn in `get_attention_back…
yuzhongw-nvidia Jul 21, 2025
4b537aa
[Common] Skip cuDNN 9.10.0/9.10.1 due to bugs (#1937)
cyanguwa Jul 21, 2025
7ba6cd5
[PyTorch] Debug linear layer when saving original input and using deb…
timmoon10 Jul 22, 2025
b97c2bf
[Common] Improved performance of mxfp8 cast kernels (#1628)
Oleg-Goncharov Jul 22, 2025
a593092
Fix the device for cuDNN/cuBLAS handles (#1974)
cyanguwa Jul 23, 2025
928dfa8
[JAX] Fix current scaling test_helper.py and enable test_helper.py in…
jberchtold-nvidia Jul 23, 2025
13f5796
[JAX] Helper to disable TE custom calls + disable GemmPrimitive for n…
phu0ngng Jul 24, 2025
e02e289
Fix runtime lib loading for cuDNN (#1989)
ksivaman Jul 24, 2025
21d7410
Fix cudnn versioning support in PyTorch DPA and Fused attn (#1991)
KshitijLakhani Jul 24, 2025
0f585e8
[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parall…
denera Jul 24, 2025
5f1142e
[PyTorch] Optimize cudagraph static_grad_outputs reuse (#1992)
buptzyb Jul 25, 2025
c90a720
Fix the use-after-free bug in unfused normalization (#2002)
ptrendx Jul 29, 2025
966a4ac
Merge remote-tracking branch 'upstream/release_v2.6' into yewang12/if…
wangye805 Jan 3, 2026
97556c6
[ROCm] Resolve conflicts
wangye805 Dec 17, 2025
3bf9150
[ROCm] address reviewer comments
wangye805 Jan 31, 2026
fdac02a
CI: Serialize core sgpu test (#426)
leo-amd Jan 26, 2026
bf75169
[ROCm] add EnvVarCleaner definition and update copyright years
wangye805 Feb 2, 2026
9b12e69
Enable gfx950 CI on dev branch (#401)
VeeraRajasekhar Jan 9, 2026
a512352
[ROCm] cherrypick the timeout setting for jax distributed pytests hang
wangye805 Feb 3, 2026
4998d8f
Do not fail CI on known failed JAX test (#421)
ipanfilo Jan 19, 2026
e4a079b
Upcoming ROCm and JAX 0.8 support (#403)
ipanfilo Feb 7, 2026
9d233ad
Revert "CI: Serialize core sgpu test (#426)", move gemm_sm_count to m…
ipanfilo Feb 7, 2026
65cf94a
Cerry pick ROCm 7.2 w/a (#404)
ipanfilo Feb 7, 2026
9adaffa
Re-enable supported GEMM configs (#378)
ipanfilo Feb 7, 2026
12 changes: 8 additions & 4 deletions .github/workflows/rocm-ci.yml
@@ -1,4 +1,4 @@
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

@@ -40,9 +40,13 @@ concurrency:

jobs:
build_and_test:
name: Build and Test on GPU
name: Build and Test on GPU (${{ matrix.runner }})
timeout-minutes: 720
runs-on: linux-mi325-8
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
runner: [linux-mi325-8, linux-mi355-8]
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -422,7 +426,7 @@ jobs:
if: always()
uses: actions/upload-artifact@v4
with:
name: logs-and-reports
name: logs-and-reports-${{ matrix.runner }}
path: |
*.log
if-no-files-found: ignore
8 changes: 4 additions & 4 deletions benchmarks/attention/benchmark_attention.py
@@ -9,11 +9,11 @@
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
from tests.pytorch.utils import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention

pd.set_option("display.precision", 4)

@@ -197,7 +197,7 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
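With this change the benchmark pulls ModelConfig and get_available_attention_backends from tests/pytorch/utils.py and unpacks a three-value return. A minimal sketch of the new call pattern (the ModelConfig arguments and qkv_layout value are illustrative assumptions, not taken from the benchmark itself):

import torch
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

# Illustrative config: batch 2, seq 2048, 16 heads, head dim 128, causal mask
config = ModelConfig(2, 2048, 16, 128, attn_mask_type="causal")
# The helper now returns three values; the middle one is unused here
available_backends, _, fused_attn_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,
    qkv_layout="bshd_bshd_bshd",
)
print(available_backends, fused_attn_backends)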
32 changes: 18 additions & 14 deletions benchmarks/attention/benchmark_attention_rocm.py
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -13,15 +13,19 @@
import transformer_engine
from transformer_engine_torch import NVTE_Fused_Attn_Backend

# Add test_fused_attn to the sys path
# Add paths tests/pytorch/ and tests/pytorch/attention to the sys path
tests_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn")
os.path.join(os.path.dirname(__file__), "../../tests")
)
sys.path.append(tests_path)
sys.path.append(tests_path + "/pytorch")
sys.path.append(tests_path + "/pytorch/attention")

from test_fused_attn import (
# Add tests/pytorch/utils.py path into sys path
from utils import (
ModelConfig,
_get_attention_backends,
get_available_attention_backends,
)
from test_attention import (
_run_dot_product_attention,
)

@@ -46,12 +50,12 @@
is_training = True

model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
"test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias")
# test: b, sq, h, d
"test_0": ModelConfig(2, 512, 16, 64), # short seq
"test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # longer seq, mask
"test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"), # bias
"test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA
"test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=8, attn_mask_type="causal_bottom_right")
}

# DataFrame indices and columns for results
@@ -303,7 +307,7 @@ def sanity_checks(
}

for model, cfg in model_configs.items():
avail, _, fused_bes = _get_attention_backends(
avail, _, fused_bes = get_available_attention_backends(
cfg,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
@@ -364,7 +368,7 @@ def main(args):
# Benchmarking starts..
for model in model_configs.keys():
config = model_configs[model]
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
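The model-config table above moves from the old nine-argument positional form (b, h, hg, d, sq, skv, p, mask, bias) to the new (b, sq, h, d) form with keyword arguments. A before/after sketch for the GQA entry (keyword names follow the hunk; any defaults not shown are assumptions):

from tests.pytorch.utils import ModelConfig

# old: ModelConfig(b, h, hg, d, sq, skv, p, mask, bias)
# old: ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias")
# new: ModelConfig(b, sq, h, d, ...) with keywords for GQA groups and mask type
gqa_config = ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal")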
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
2.6.0.dev0
2.6.0
15 changes: 1 addition & 14 deletions build_tools/pytorch.py
@@ -27,20 +27,7 @@

def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
reqs = ["einops"]
if not rocm_build():
reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/[email protected]#egg=nvdlfw-inspect"
)
reqs.extend(
[
"torch>=2.1",
"onnx",
"onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
]
)
return reqs
return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]


def test_requirements() -> List[str]:
10 changes: 9 additions & 1 deletion ci/_utils.sh
@@ -1,4 +1,4 @@
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

@@ -25,20 +25,28 @@ export CI=1

_script_error_count=0
_run_error_count=0
_ignored_error_count=0
TEST_ERROR_IGNORE=""

script_error() {
_script_error_count=$((_script_error_count+1))
test "$@" && echo $@ >&2
}

test_run_error() {
if [ -n "$TEST_ERROR_IGNORE" ]; then
_ignored_error_count=$((_ignored_error_count+1))
test -n "$@" && echo "Ignore error in test $@" >&2
return
fi
_run_error_count=$((_run_error_count+1))
test -n "$@" && echo "Error in test $@" >&2
}

return_run_results() {
test $_script_error_count -ne 0 && echo Detected $_script_error_count script errors during tests run at level $TEST_LEVEL >&2
test $_run_error_count -ne 0 && echo Got $_run_error_count test errors during run at level $TEST_LEVEL >&2
test $_ignored_error_count -ne 0 && echo Ignored $_ignored_error_count test errors during run at level $TEST_LEVEL >&2
test $_run_error_count -eq 0 -a $_script_error_count -eq 0
}

2 changes: 1 addition & 1 deletion ci/ci_config.json
@@ -1,6 +1,6 @@
{
"docker_images": {
"default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.0.2_ubuntu22.04_py3.10_pytorch_release-2.7_9015dfdf_jax_v0.6.0_fa-v2.8.0",
"default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0",
"release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273",
"release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273"
}
2 changes: 1 addition & 1 deletion ci/core.sh
@@ -1,5 +1,5 @@
#!/bin/sh
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

25 changes: 21 additions & 4 deletions ci/jax.sh
@@ -1,5 +1,5 @@
#!/bin/sh
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

@@ -21,6 +21,12 @@ install_prerequisites() {
script_error "Failed to install Flax and dependencies"
return $rc
fi
pip install pytest-timeout
rc=$?
if [ $rc -ne 0 ]; then
script_error "Failed to install test prerequisites"
exit $rc
fi
}

TEST_DIR=${TE_PATH}tests/jax
@@ -62,15 +68,26 @@ run_test_config() {
run_test_config_mgpu() {
echo ==== Run mGPU with Fused attention backend: $_fus_attn ====
configure_omp_threads 8

# Mitigate distributed test hangs by adding a 5 min timeout
_timeout_args="--timeout 300 --timeout-method thread"
# Workaround for some distributed tests hanging/aborting
export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false"

if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then
_dfa_level=2
else
_dfa_level=3
fi
# Workaround for distributed tests hang with xla_flag
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run $_dfa_level test_distributed_fused_attn.py
# Do not fail automated CI if test_distributed_fused_attn hangs
# If the script runs w/o TEST_LEVEL the test error will be honored
if [ "$TEST_LEVEL" -le 3 ]; then
TEST_ERROR_IGNORE="1"
fi
run $_dfa_level test_distributed_fused_attn.py $_timeout_args
TEST_ERROR_IGNORE=""
run_default_fa 3 test_distributed_layernorm.py
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run_default_fa 2 test_distributed_layernorm_mlp.py
run_default_fa 2 test_distributed_layernorm_mlp.py $_timeout_args
run_default_fa 3 test_distributed_softmax.py

run_default_fa 3 test_sanity_import.py
13 changes: 8 additions & 5 deletions ci/pytorch.sh
@@ -1,5 +1,5 @@
#!/bin/sh
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

@@ -56,7 +56,6 @@ run_test_config(){
run_default_fa 1 test_fused_router.py
run_default_fa 1 test_fusible_ops.py
run_default_fa 1 test_gemm_autotune.py
run_default_fa 1 test_gemm_sm_count.py
run 1 test_gqa.py
run 1 test_jit.py
run_default_fa 1 test_multi_tensor.py
@@ -65,7 +64,7 @@
run_default_fa 1 test_recipe.py
run 1 test_sanity.py
run_default_fa 1 test_sanity_import.py
run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test
run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test
run_default_fa 1 triton_kernels/test_cast.py
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
@@ -83,13 +82,17 @@ run_test_config_mgpu(){
echo ==== Run mGPU with Fused attention backend: $_fus_attn ====
configure_omp_threads 8
run_default_fa 1 test_fused_optimizer.py
# This test is not really mGPU, but it is time sensitive, so run it here: sGPU tests
# run in parallel on CI and that affects timing
run_default_fa 1 test_gemm_sm_count.py
run_default_fa 3 test_sanity_import.py
run_default_fa 1 distributed/test_sanity.py
run_default_fa 2 distributed/test_fusible_ops.py
run_default_fa 2 distributed/test_numerics.py
run_default_fa 1 distributed/test_torch_fsdp2.py
run_default_fa 2 distributed/test_torch_fsdp2_fp8.py
run_default_fa_lbl "flash" 3 fused_attn/test_fused_attn_with_cp.py -k "with_flash"
run_default_fa_lbl "fused" 2 fused_attn/test_fused_attn_with_cp.py -k "with_fused"
run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash"
run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused"
}

run_benchmark() {
4 changes: 2 additions & 2 deletions docs/debug/1_getting_started.rst
@@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
There are 4 things one needs to do to use Transformer Engine debug features:

1. Create a configuration YAML file to configure the desired features.
2. Import, and initialize the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool, which is installed as the dependency of the Transformer Engine.
2. Import, initialize, and install the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool.
3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically.
4. Invoke ``debug_api.step()`` at the end of one forward-backward pass.
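A minimal sketch of these four steps (the nvdlfw_inspect import path and initialize() keyword names are assumptions; only debug_api.step() and the name="..." argument come from the text above):

import torch
import transformer_engine.pytorch as te
import nvdlfw_inspect.api as debug_api                       # assumed import path

debug_api.initialize(                                        # steps 1-2: point the tool at the YAML config
    config_file="./debug_config.yaml",                       # assumed keyword names
    feature_dirs=["transformer_engine/debug/features"],
)

layer = te.Linear(1024, 1024, name="mlp_fc1")                # step 3: explicit layer name
out = layer(torch.randn(32, 1024, device="cuda", requires_grad=True))
out.sum().backward()
debug_api.step()                                             # step 4: once per forward-backward pass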

@@ -238,4 +238,4 @@ Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_
.. figure:: ./img/tensorboard.png
:align: center

Fig 2: TensorBoard with plotted stats.
Fig 2: TensorBoard with plotted stats.
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from tests.pytorch.utils import ModelConfig
from transformer_engine.pytorch.attention import DotProductAttention

# Initialize RNG state
18 changes: 9 additions & 9 deletions docs/examples/attention/attention.ipynb
@@ -375,7 +375,7 @@
"\n",
"Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
"\n",
"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
]
},
{
@@ -394,10 +394,10 @@
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)"
]
},
{
@@ -458,7 +458,7 @@
" </tr>\n",
"</table>\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
@@ -548,7 +548,7 @@
"id": "dda4a589",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n",
"\n",
"### 3.3 Attention Bias\n",
"\n",
@@ -594,7 +594,7 @@
"\n",
"The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
"\n",
"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)."
"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)."
]
},
{
@@ -612,7 +612,7 @@
"\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n",
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
]
}
],
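The FP8 attention knobs named in the cell above (DelayedScaling.fp8_dpa, DelayedScaling.fp8_mha, NVTE_FP8_DPA_BWD) can be exercised roughly as follows; the DotProductAttention constructor arguments and tensor shapes are illustrative assumptions, not taken from this PR:

import os
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

os.environ["NVTE_FP8_DPA_BWD"] = "0"          # optional: FP8 attention in forward only

# Assumed layer/tensor setup: 16 heads, head dim 64, sbhd layout (sq=512, b=2)
dpa = te.DotProductAttention(num_attention_heads=16, kv_channels=64)
q = torch.randn(512, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

recipe = DelayedScaling(fp8_dpa=True, fp8_mha=False)  # flags described in the notebook text
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = dpa(q, k, v)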