diff --git a/3rdparty/aotriton b/3rdparty/aotriton
index 98371989e..dd1b68b60 160000
--- a/3rdparty/aotriton
+++ b/3rdparty/aotriton
@@ -1 +1 @@
-Subproject commit 98371989e8a23267e284c94e95156a139e4b33c4
+Subproject commit dd1b68b604b5258ee7a9f7b66ad95e7a82c18065
diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index 337bc1646..4cd57b945 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -723,6 +723,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
     "layout_2_1": ModelConfig(
         2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias"
     ),
+    "layout_3_0": ModelConfig(1, 2048, 12, 64, attn_mask_type="causal"),
 }
 
 
@@ -1281,17 +1282,60 @@ def test_transformer_layer(
 
     # FusedAttention backend
     if fused_attn_supported:
-        fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
-            dtype,
-            config,
-            "FusedAttention",
-            ckpt_attn,
-            qkv_format,
-            workspace_opt,
-            fused_qkv_params,
-            RoPE,
-            is_training,
-        )
+        if len(fused_attn_backends) == 1:
+            fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_format,
+                workspace_opt,
+                fused_qkv_params,
+                RoPE,
+                is_training,
+            )
+        elif len(fused_attn_backends) == 2:
+            os.environ["NVTE_FUSED_ATTN_CK"] = "0"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
+            fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_format,
+                workspace_opt,
+                fused_qkv_params,
+                RoPE,
+                is_training,
+            )
+            os.environ["NVTE_FUSED_ATTN_CK"] = "1"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
+            fused_attn_fwd_1, fused_attn_bwd_1 = _run_transformer_layer(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_format,
+                workspace_opt,
+                fused_qkv_params,
+                RoPE,
+                is_training,
+            )
+
+            os.environ["NVTE_CK_USES_FWD_V3"] = "1"
+            os.environ["NVTE_CK_USES_BWD_V3"] = "1"
+            fused_attn_fwd_2, fused_attn_bwd_2 = _run_transformer_layer(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_format,
+                workspace_opt,
+                fused_qkv_params,
+                RoPE,
+                is_training,
+            )
+
     # FlashAttention backend
     if flash_attn_supported:
@@ -1320,6 +1364,15 @@
         logging.info("[test_transformer_layer]: fused attn vs flash attn")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
+    if IS_HIP_EXTENSION and fused_attn_supported and len(fused_attn_backends) == 2:
+        logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
+        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
+        for i, _ in enumerate(fused_attn_bwd):
+            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
+        logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
+        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
+        for i, _ in enumerate(fused_attn_bwd):
+            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
 
 
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
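Note on the test change above: with both ROCm fused-attention backends compiled in, the test now runs FusedAttention once per backend (toggled via `NVTE_FUSED_ATTN_CK` / `NVTE_FUSED_ATTN_AOTRITON`), once more with CK's V3 kernels (`NVTE_CK_USES_FWD_V3` / `NVTE_CK_USES_BWD_V3`), and cross-checks the three results. A minimal sketch of env-gated backend selection in the same spirit; the `env_enabled` helper and the dispatch policy are illustrative assumptions, not the library's actual code:

```cpp
#include <cstdlib>
#include <string>

// Hypothetical helper: treat an env var as a boolean switch with a default.
static bool env_enabled(const char* name, bool default_value) {
  const char* v = std::getenv(name);
  if (v == nullptr) return default_value;  // unset -> keep default
  return std::string(v) != "0";            // "0" disables, anything else enables
}

enum class FusedAttnBackend { CK, AOTriton, None };

// Illustrative dispatch only: prefer CK when both backends are enabled.
static FusedAttnBackend pick_backend() {
  if (env_enabled("NVTE_FUSED_ATTN_CK", /*default=*/true)) return FusedAttnBackend::CK;
  if (env_enabled("NVTE_FUSED_ATTN_AOTRITON", /*default=*/true)) return FusedAttnBackend::AOTriton;
  return FusedAttnBackend::None;
}
```

Because the test flips these variables between `_run_transformer_layer` calls, any such policy must re-read the environment per call rather than cache the choice at initialization.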
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index bc29d29e3..e0537a818 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -1212,7 +1212,6 @@ def _test_dpa_accuracy(block, bs, dtype, config):
     query.retain_grad()
     key.retain_grad()
     value.retain_grad()
-
     out = block(query, key, value, attention_mask=mask)
     loss = out.sum()
     loss.backward()
@@ -1256,7 +1255,8 @@ def test_dpa_accuracy(dtype, bs, model):
     else:
        assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)
 
-    for te_output, torch_output in zip(te_outputs[1:], torch_outputs[1:]):
+    for idx, outs in enumerate(zip(te_outputs[1:], torch_outputs[1:])):
+        te_output, torch_output = outs
         assert_allclose(te_output, torch_output, atol=5e-2, rtol=1e-2)
 
 
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index cefec6d06..50dcf90a0 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -8,7 +8,6 @@ cmake_minimum_required(VERSION 3.21)
 
 option(USE_ROCM "Use ROCm" ON)
 option(USE_FUSED_ATTN_AOTRITON "Use aotriton backend" ON)
-option(USE_FUSED_ATTN_AOTRITON_BUILD_GPU_KERNELS "Build AOTriton GPU kernels" OFF)
 option(USE_FUSED_ATTN_CK "Use ck backend" ON)
 
 set(USE_CUDA OFF)
diff --git a/transformer_engine/common/aotriton/CMakeLists.txt b/transformer_engine/common/aotriton/CMakeLists.txt
index 7656f9a66..4f823dc2e 100644
--- a/transformer_engine/common/aotriton/CMakeLists.txt
+++ b/transformer_engine/common/aotriton/CMakeLists.txt
@@ -20,7 +20,7 @@ if(NOT DEFINED AOTRITON_PATH)
   set(AOTRITON_NOIMAGE_MODE ON)
 endif()
 
-set(__AOTRITON_VER "0.11.1b")
+set(__AOTRITON_VER "0.11.2b")
 set(__AOTRITON_IMAGE_LIST
     "amd-gfx942"
     "amd-gfx950"
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
index bb5e22887..d8969b2ab 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp
@@ -492,6 +492,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
       fused_attn_aotriton_bwd_qkvpacked(
         b, h, max_seqlen, d,
         attn_scale, dropout,
+        window_size_left, window_size_right,
         qkv_layout, bias_type, attn_mask_type,
         input_QKV, input_O, input_dO,
         output_S,
         output_dQKV,
@@ -678,6 +679,7 @@ void nvte_fused_attn_bwd_kvpacked(
       fused_attn_aotriton_bwd_kvpacked(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
         attn_scale, dropout,
+        window_size_left, window_size_right,
         qkv_layout, bias_type, attn_mask_type,
         input_Q, input_KV, input_O, input_dO,
         output_S,
@@ -858,6 +860,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
       fused_attn_aotriton_bwd(
         b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk,
         attn_scale, dropout,
+        window_size_left, window_size_right,
         qkv_layout, bias_type, attn_mask_type,
         input_Q, input_K, input_V, input_O, input_dO,
         output_S,
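Note: the three `nvte_fused_attn_bwd*` entry points above now forward `window_size_left` / `window_size_right` into the AOTriton backward path, where they are normalized by `get_window_sizes` (used in the implementation below). That helper's exact semantics are not shown in this diff; the sketch assumes one common sliding-window convention (negative size means unbounded on that side, and a causal mask forces zero right context) purely for illustration:

```cpp
#include <cstdint>
#include <utility>

// Assumed convention, for illustration only: negative window size means
// "unbounded on that side"; a causal mask implies no right context.
static std::pair<int32_t, int32_t> sketch_window_sizes(int32_t left, int32_t right,
                                                       bool is_causal) {
  const int32_t kUnbounded = INT32_MAX;  // stand-in for "full sequence"
  int32_t window_left = (left < 0) ? kUnbounded : left;
  int32_t window_right = (right < 0) ? kUnbounded : right;
  if (is_causal) window_right = 0;       // causal: a token never attends to its future
  return {window_left, window_right};
}
```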
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
index a8a151b40..f3badae2c 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp
@@ -208,7 +208,6 @@ void fused_attn_aotriton_fwd_impl(
       nvte_log_aotriton_config = true;
   }
   aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
-  using aotriton::v3::flash::attn_fwd;
   auto seed = mk_aoscalartensor(devPtrDropoutSeed);
   auto offset1 = mk_aoscalartensor(devPtrDropoutOffset);
   auto seed_output = mk_aoscalartensor(nullptr);
@@ -293,18 +292,30 @@ void fused_attn_aotriton_fwd_impl(
   fwd_params.window_right = window_right;
 
   NVTE_CHECK_CUDA(hipMemsetAsync(workspace, 0, sizeof(int32_t), stream));
+  using aotriton::v3::flash::attn_fwd;
   NVTE_CHECK_CUDA(attn_fwd(fwd_params, fwd_params.kVersion, stream));
 }
 
+// A thin conversion wrapper from eager tensor views to lazy tensors
+template <int Rank>
+struct LazyTensorFunctions {
+  static aotriton::TensorView<Rank> acquire(void* cookie) {
+    return *static_cast<aotriton::TensorView<Rank>*>(cookie);
+  }
+  static void dispose(void* cookie) {
+  }
+};
+
 void fused_attn_aotriton_bwd_impl(
   uint64_t b, uint64_t h, uint64_t hg, uint64_t s_q, uint64_t s_kv, uint64_t d,
   float scaling_factor, float dropout_probability,
-  NVTE_QKV_Layout layout,
+  int32_t window_size_left, int32_t window_size_right, NVTE_QKV_Layout layout,
   NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
   void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrO,
   void* devPtrSoftmaxAux,
   void* devPtrdQ, void* devPtrdK, void* devPtrdV, void* devPtrdO,
+  void* devPtrCuSeqlensQ, void* devPtrCuSeqlensKV,
  const uint64_t* devPtrDropoutSeed, const uint64_t* devPtrDropoutOffset,
   aotriton::DType dtype,
@@ -312,12 +323,20 @@ void fused_attn_aotriton_bwd_impl(
   size_t *workspace_size,
   cudaStream_t stream)
 {
+  const uint64_t dq_acc_size = b*s_q*h*d*sizeof(float);
+
   // Exit to request upper level API to allocate memory if needed
   if(workspace==nullptr){
-    // CK only requires workspace for lse softmax
+    // AOTriton requires workspace for lse softmax
     *workspace_size = b*h*s_q*sizeof(float);
+    // AOTriton requires workspace for DQ_ACC
+    *workspace_size += dq_acc_size;
     return;
   }
+  void * delta = workspace;
+  workspace = static_cast<void*>(static_cast<char*>(workspace) + b*h*s_q*sizeof(float));
+  void * dq_acc_ptr = workspace;
+
   std::array<uint64_t, 4> q_stride;
   std::array<uint64_t, 4> k_stride;
   std::array<uint64_t, 4> v_stride;
@@ -330,6 +349,8 @@ void fused_attn_aotriton_bwd_impl(
                         layout, NVTE_QKV_Matrix::NVTE_V_Matrix);
   generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(),
                         layout, NVTE_QKV_Matrix::NVTE_O_Matrix);
+  // AOTriton expects a **BHSD** layout DQ_ACC matrix
+  std::array<uint64_t, 4> dq_acc_stride {h * s_q * d, s_q * d, d, 1};
 
   //q and o are having the same shape
   //k and v are having the same shape
@@ -337,7 +358,7 @@ void fused_attn_aotriton_bwd_impl(
   std::array<uint64_t, 4> q_shape{b, h, s_q, d};
   std::array<uint64_t, 4> kv_shape{b, hg, s_kv, d};
 
-  // m and workspace are of the same shape and stride
+  // m and softmax_lse are of the same shape and stride
   std::array<uint64_t, 2> m_shape{b * h, s_q};
   std::array<uint64_t, 2> m_stride{s_q, 1};
@@ -355,13 +376,44 @@ void fused_attn_aotriton_bwd_impl(
   // auxilary tensors
   auto M_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(devPtrSoftmaxAux), m_shape, m_stride, aotriton::DType::kFloat32);
-  auto wkspace_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(workspace), m_shape, m_stride, aotriton::DType::kFloat32);
+  auto delta_tensor = aotriton::TensorView<2>(reinterpret_cast<intptr_t>(delta), m_shape, m_stride, aotriton::DType::kFloat32);
+  auto dq_acc_tensor = aotriton::TensorView<4>(reinterpret_cast<intptr_t>(dq_acc_ptr), q_shape, dq_acc_stride, aotriton::DType::kFloat32);
+  NVTE_CHECK_CUDA(hipMemsetAsync(dq_acc_ptr, 0, dq_acc_size, stream));
+
+  auto dq_acc_lazy = aotriton::LazyTensor<4> {
+    .cookie = &dq_acc_tensor,
+    .acquire = &LazyTensorFunctions<4>::acquire,
+    .dispose = &LazyTensorFunctions<4>::dispose
+  };
+  auto delta_lazy = aotriton::LazyTensor<2> {
+    .cookie = &delta_tensor,
+    .acquire = &LazyTensorFunctions<2>::acquire,
+    .dispose = &LazyTensorFunctions<2>::dispose
+  };
+
+  // Cumulative seqlen tensors
+  std::array<uint64_t, 1> cu_seqlens_shape{b+1};
+  std::array<uint64_t, 1> cu_seqlens_stride{1};
+  auto cu_seqlens_q = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensQ), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
+  auto cu_seqlens_k = aotriton::TensorView<1>(reinterpret_cast<intptr_t>(devPtrCuSeqlensKV), cu_seqlens_shape, cu_seqlens_stride, aotriton::DType::kInt32);
 
   bool nvte_log_aotriton_config = false;
   if (const char* env_p = std::getenv("NVTE_LOG_AOTRITON_CONFIG") ) {
     if (env_p != nullptr && std::string(env_p) == "1")
       nvte_log_aotriton_config = true;
   }
+  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
+  auto seed = mk_aoscalartensor(devPtrDropoutSeed);
+  auto offset = mk_aoscalartensor(devPtrDropoutOffset);
+  const auto is_causal = mask_type == NVTE_CAUSAL_MASK;
+
+  using aotriton::v3::flash::VarlenType;
+  int8_t varlen_type = VarlenType::None;
+
+  auto [window_left, window_right] = get_window_sizes(window_size_left, window_size_right, is_causal);
+  using aotriton::v3::flash::CausalType;
+  int8_t causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None;
+
   if (nvte_log_aotriton_config) {
     std::cout << …;
   }
-  aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, dtype);
-  using aotriton::v2::flash::attn_bwd;
-  auto seed = mk_aoscalartensor(devPtrDropoutSeed);
-  auto offset = mk_aoscalartensor(devPtrDropoutOffset);
-  const auto is_causal = mask_type == NVTE_CAUSAL_MASK;
-  NVTE_CHECK_CUDA(attn_bwd(q_tensor,
-                           k_tensor,
-                           v_tensor,
-                           empty_bias,
-                           scaling_factor,
-                           o_tensor,
-                           do_tensor,
-                           dq_tensor,
-                           dk_tensor,
-                           dv_tensor,
-                           empty_bias,
-                           M_tensor,
-                           wkspace_tensor,
-                           dropout_probability,
-                           seed,
-                           offset,
-                           0,
-                           is_causal,
-                           stream));
+  aotriton::v3::flash::attn_bwd_params bwd_params{};
+  bwd_params.Q = q_tensor;
+  bwd_params.K = k_tensor;
+  bwd_params.V = v_tensor;
+  bwd_params.B = empty_bias;
+  bwd_params.Sm_scale = scaling_factor;
+  bwd_params.Out = o_tensor;
+  if(varlen_type){
+    bwd_params.cu_seqlens_q = cu_seqlens_q;
+    bwd_params.cu_seqlens_k = cu_seqlens_k;
+    bwd_params.Max_seqlen_q = s_q;
+    bwd_params.Max_seqlen_k = s_kv;
+  }
+  bwd_params.DO = do_tensor;
+  bwd_params.DK = dk_tensor;
+  bwd_params.DV = dv_tensor;
+  bwd_params.DQ = dq_tensor;
+  bwd_params.DB = empty_bias;
+  bwd_params.L = M_tensor;
+  bwd_params.D = delta_lazy;
+  bwd_params.dropout_p = dropout_probability;
+  bwd_params.philox_seed_ptr = seed;
+  bwd_params.philox_offset1 = offset;
+  bwd_params.philox_offset2 = 0;
+  bwd_params.causal_type = causal_type;
+  bwd_params.varlen_type = varlen_type;
+  bwd_params.window_left = window_left;
+  bwd_params.window_right = window_right;
+  bwd_params.DQ_ACC = dq_acc_lazy;
+
+  using aotriton::v3::flash::attn_bwd;
+  NVTE_CHECK_CUDA(attn_bwd(bwd_params, bwd_params.kVersion, stream));
 }
 #endif // USE_FUSED_ATTN_AOTRITON
 } // namespace fused_attn_rocm
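Note on the `LazyTensorFunctions` / `aotriton::LazyTensor` pattern introduced above: AOTriton v3 accepts some tensors as an opaque cookie plus acquire/dispose callbacks, so a tensor only has to be materialized if the selected kernel actually needs it. A self-contained sketch of that contract; `TensorView` and `LazyTensor` here are simplified stand-ins, not AOTriton's real declarations:

```cpp
#include <cstdio>

// Simplified stand-ins for aotriton::TensorView / aotriton::LazyTensor,
// just to illustrate the cookie/acquire/dispose contract.
template <int Rank>
struct TensorView { void* base; };

template <int Rank>
struct LazyTensor {
  void* cookie;                               // opaque state owned by the caller
  TensorView<Rank> (*acquire)(void* cookie);  // called if the kernel needs the tensor
  void (*dispose)(void* cookie);              // called when the kernel is done with it
};

template <int Rank>
struct LazyTensorFunctions {
  static TensorView<Rank> acquire(void* cookie) {
    return *static_cast<TensorView<Rank>*>(cookie);
  }
  static void dispose(void*) {}               // nothing to free: the view outlives the call
};

int main() {
  TensorView<4> dq_acc{nullptr};
  LazyTensor<4> lazy{&dq_acc, &LazyTensorFunctions<4>::acquire,
                     &LazyTensorFunctions<4>::dispose};
  TensorView<4> resolved = lazy.acquire(lazy.cookie);  // what the kernel would do on demand
  std::printf("resolved base: %p\n", resolved.base);
  lazy.dispose(lazy.cookie);
  return 0;
}
```

Since the eager views in `fused_attn_aotriton_bwd_impl` are stack locals that outlive the `attn_bwd` call, a no-op `dispose` is sufficient there.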
@@ -495,6 +580,7 @@ void fused_attn_aotriton_fwd_qkvpacked(
 void fused_attn_aotriton_bwd_qkvpacked(
   size_t b, size_t h, size_t max_seqlen, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -532,12 +618,14 @@ void fused_attn_aotriton_bwd_qkvpacked(
   fused_attn_aotriton_bwd_impl(
     b, h, h, max_seqlen, max_seqlen, d,
     attn_scale, dropout,
+    window_size_left, window_size_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrO, devPtrSoftmaxStats,
     devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
+    input_cu_seqlens->data.dptr, input_cu_seqlens->data.dptr,
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr) + 1,
     nvte_to_aotriton_dtype(QKV_type),
@@ -652,6 +740,7 @@ void fused_attn_aotriton_fwd_kvpacked(
 void fused_attn_aotriton_bwd_kvpacked(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -692,12 +781,14 @@ void fused_attn_aotriton_bwd_kvpacked(
   fused_attn_aotriton_bwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     attn_scale, dropout,
+    window_size_left, window_size_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrO, devPtrSoftmaxStats,
     devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
+    input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr) + 1,
     nvte_to_aotriton_dtype(QKV_type),
@@ -803,6 +894,7 @@ void fused_attn_aotriton_fwd(
 void fused_attn_aotriton_bwd(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -831,12 +923,14 @@ void fused_attn_aotriton_bwd(
   fused_attn_aotriton_bwd_impl(
     b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
     attn_scale, dropout,
+    window_size_left, window_size_right,
     qkv_layout, bias_type, attn_mask_type,
     devPtrQ, devPtrK, devPtrV,
     devPtrO, devPtrSoftmaxStats,
     devPtrdQ, devPtrdK, devPtrdV, devPtrdO,
+    input_cu_seqlens_q->data.dptr, input_cu_seqlens_kv->data.dptr,
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr),
     reinterpret_cast<const uint64_t*>(rng_state->data.dptr) + 1,
     nvte_to_aotriton_dtype(QKV_type),
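Note on the workspace handling in `fused_attn_aotriton_bwd_impl` above: the existing two-phase protocol is kept. Called with `workspace == nullptr`, the function only reports how many bytes it needs (softmax LSE plus the new float32 DQ_ACC accumulator); on the real call it carves the buffer into those two regions. A small host-side sketch of the same arithmetic; `query_bwd_workspace` is an illustrative name, not the library's API:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mirrors the sizing logic above: LSE is b*h*s_q floats,
// DQ_ACC is a float32 tensor with the same b/h/s_q/d extents as Q.
static size_t query_bwd_workspace(uint64_t b, uint64_t h, uint64_t s_q, uint64_t d) {
  const size_t lse_bytes = b * h * s_q * sizeof(float);
  const size_t dq_acc_bytes = b * s_q * h * d * sizeof(float);
  return lse_bytes + dq_acc_bytes;
}

int main() {
  // Example: b=2, h=16, s_q=2048, d=128
  const uint64_t b = 2, h = 16, s_q = 2048, d = 128;
  const size_t lse = b * h * s_q * sizeof(float);    // 262144 bytes (256 KiB)
  const size_t total = query_bwd_workspace(b, h, s_q, d);
  std::printf("lse = %zu bytes, dq_acc = %zu bytes, total = %zu bytes\n",
              lse, total - lse, total);
  // delta (the D tensor) sits at offset 0; DQ_ACC starts at offset lse.
  return 0;
}
```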
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
index b016acc67..3fdb359d1 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h
@@ -47,6 +47,7 @@ void fused_attn_aotriton_fwd_qkvpacked(
 void fused_attn_aotriton_bwd_qkvpacked(
   size_t b, size_t h, size_t max_seqlen, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -72,6 +73,7 @@ void fused_attn_aotriton_fwd_kvpacked(
 void fused_attn_aotriton_bwd_kvpacked(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
@@ -98,6 +100,7 @@ void fused_attn_aotriton_fwd(
 void fused_attn_aotriton_bwd(
   size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
   float attn_scale, float dropout,
+  int32_t window_size_left, int32_t window_size_right,
   NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
   const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO,
   const Tensor* output_S,
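For reference on the DQ_ACC comment in the implementation: the accumulator is addressed as BHSD regardless of the user's QKV layout, so its strides follow directly from (b, h, s, d). A quick self-contained check of the stride formula; `bhsd_strides` is illustrative only:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// BHSD strides for a contiguous (b, h, s, d) tensor, matching
// the dq_acc_stride initializer in fused_attn_aotriton_bwd_impl.
static std::array<uint64_t, 4> bhsd_strides(uint64_t h, uint64_t s, uint64_t d) {
  return {h * s * d, s * d, d, 1};
}

int main() {
  // Example: h=12, s=2048, d=64 (the new layout_3_0 test config)
  const auto st = bhsd_strides(12, 2048, 64);
  std::printf("strides = {%llu, %llu, %llu, %llu}\n",
              (unsigned long long)st[0], (unsigned long long)st[1],
              (unsigned long long)st[2], (unsigned long long)st[3]);
  // Flat index of element (b0, h0, s0, d0) is
  // b0*st[0] + h0*st[1] + s0*st[2] + d0*st[3].
  return 0;
}
```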