diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index eb346373b0..9c3b9ffd35 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -3964,11 +3964,7 @@ def aten_ops_linear( def scaled_dot_product_attention_validator( node: Node, settings: Optional[CompilationSettings] = None ) -> bool: - if node.kwargs.get("enable_gqa", False): - _LOGGER.debug( - "enable_gqa is not yet supported by the converter. Please try setting decompose_attention=True in the compilation settings." - ) - return False + enable_gqa = node.kwargs.get("enable_gqa", False) query_shape, key_shape, value_shape = None, None, None if "val" in node.args[0].meta: @@ -3977,15 +3973,51 @@ def scaled_dot_product_attention_validator( key_shape = node.args[1].meta["val"].size() if "val" in node.args[2].meta: value_shape = node.args[2].meta["val"].size() - if ( - query_shape != key_shape - or query_shape != value_shape - or key_shape != value_shape - ): + + if key_shape != value_shape: _LOGGER.debug( - "query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings." + "key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings." ) return False + + if query_shape is not None and key_shape is not None: + if len(query_shape) != len(key_shape): + _LOGGER.debug("query and key have different ranks, which is not supported.") + return False + ndim = len(query_shape) + if enable_gqa: + # IAttentionLayer natively supports GQA: Q and K/V may differ on the + # head dim (dim 1) as long as Hq is divisible by Hkv. + # Check batch (dim 0) and head_dim (last dim) match; skip seq (dim -2) + # and head (dim 1) dims. + head_dim = ndim - 1 + seq_dim = ndim - 2 + heads_dim = 1 + for i in range(ndim): + if i in (seq_dim, heads_dim): + continue + if query_shape[i] != key_shape[i]: + _LOGGER.debug( + f"query and key mismatch on dim {i} with enable_gqa=True." + ) + return False + num_q_heads = query_shape[1] + num_kv_heads = key_shape[1] + if num_q_heads % num_kv_heads != 0: + _LOGGER.debug( + f"enable_gqa=True but num_q_heads={num_q_heads} is not divisible " + f"by num_kv_heads={num_kv_heads}." + ) + return False + else: + # IAttentionLayer supports decode-phase (seq_q != seq_k). + # Check all dims except the seq dim. + seq_dim = ndim - 2 + if any(query_shape[i] != key_shape[i] for i in range(ndim) if i != seq_dim): + _LOGGER.debug( + "query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings." + ) + return False return True @@ -4032,15 +4064,42 @@ def scaled_dot_product_flash_attention_validator( key_shape = node.args[1].meta["val"].size() if "val" in node.args[2].meta: value_shape = node.args[2].meta["val"].size() - if ( - query_shape != key_shape - or query_shape != value_shape - or key_shape != value_shape - ): + if key_shape != value_shape: _LOGGER.debug( - "query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings." + "key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings." 
) return False + if query_shape is not None and key_shape is not None: + if len(query_shape) != len(key_shape): + _LOGGER.debug("query and key have different ranks, which is not supported.") + return False + ndim = len(query_shape) + seq_dim = ndim - 2 + heads_dim = 1 + num_q_heads = query_shape[heads_dim] + num_kv_heads = key_shape[heads_dim] + is_gqa = num_q_heads != num_kv_heads + if is_gqa: + # GQA: IAttentionLayer natively handles Hq != Hkv. + # Require batch/head_dim to match and Hq divisible by Hkv. + for i in range(ndim): + if i in (seq_dim, heads_dim): + continue + if query_shape[i] != key_shape[i]: + _LOGGER.debug(f"GQA: query and key mismatch on dim {i}.") + return False + if num_q_heads % num_kv_heads != 0: + _LOGGER.debug( + f"GQA: num_q_heads={num_q_heads} not divisible by num_kv_heads={num_kv_heads}." + ) + return False + else: + # MHA / decode-phase: seq may differ, all other dims must match. + if any(query_shape[i] != key_shape[i] for i in range(ndim) if i != seq_dim): + _LOGGER.debug( + "query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings." + ) + return False return True @@ -4086,15 +4145,31 @@ def scaled_dot_product_efficient_attention_validator( key_shape = node.args[1].meta["val"].size() if "val" in node.args[2].meta: value_shape = node.args[2].meta["val"].size() - if ( - query_shape != key_shape - or query_shape != value_shape - or key_shape != value_shape - ): + if key_shape != value_shape: _LOGGER.debug( - "query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings." + "key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings." ) return False + # GQA (Hq != Hkv) is intentionally not supported here. + # PyTorch's eager _scaled_dot_product_efficient_attention kernel rejects + # non-equal head counts at runtime, so no valid reference output exists for + # comparison. In practice, GQA models on CUDA dispatch to + # _scaled_dot_product_flash_attention (FP16/BF16) or decompose into + # matmul+_safe_softmax (FP32) — this op never appears with GQA shapes in + # a real FX graph. GQA is handled by the flash attention validator instead. + # + # IAttentionLayer does support decode-phase (seq_q != seq_k), so only the + # sequence dimension is skipped in the shape check below. + if query_shape is not None and key_shape is not None: + if len(query_shape) != len(key_shape) or any( + query_shape[i] != key_shape[i] + for i in range(len(query_shape)) + if i != len(query_shape) - 2 # skip the seq dim + ): + _LOGGER.debug( + "query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings." 
+ ) + return False return True diff --git a/py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py index 254bf1183a..8829044b8b 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py @@ -12,7 +12,17 @@ def force_causal_efficient_attention( gm: torch.fx.GraphModule, settings: CompilationSettings ) -> torch.fx.GraphModule: - """Force efficient-attention calls to causal mode when enabled in settings.""" + """Force efficient-attention calls to causal mode when enabled in settings. + + For square attention (seq_q == seq_k): replaces attn_bias with is_causal=True + so IAttentionLayer can use its native causal path. + + For decode-phase attention (seq_q != seq_k): skip the transformation. + Applying is_causal=True is semantically wrong here — it creates a lower- + triangular mask aligned to position 0, so the query attends only to k[0] + instead of all past keys. The node is left unchanged and passed to + IAttentionLayer, which supports non-square Q/K natively. + """ if not settings.attn_bias_is_causal: return gm @@ -20,24 +30,39 @@ def force_causal_efficient_attention( for node in gm.graph.nodes: if ( node.target - == torch.ops.aten._scaled_dot_product_efficient_attention.default + != torch.ops.aten._scaled_dot_product_efficient_attention.default + ): + continue + + attn_bias = node.args[3] if len(node.args) > 3 else None + if attn_bias is None: + continue + + query_node, key_node = node.args[0], node.args[1] + query_meta = query_node.meta.get("val") if hasattr(query_node, "meta") else None + key_meta = key_node.meta.get("val") if hasattr(key_node, "meta") else None + if ( + query_meta is not None + and key_meta is not None + and query_meta.size(2) != key_meta.size(2) ): - attn_bias = node.args[3] if len(node.args) > 3 else None - if attn_bias is None: - continue - node.args = ( - node.args[0], - node.args[1], - node.args[2], - None, - False, - 0.0, - True, - ) - changed = True logger.debug( - f"The args of node {node} was changed to causal mode. 
Now the node's arguments are: {node.args}" + f"Skipping causal force for node {node}: seq_q={query_meta.size(2)} " + f"!= seq_k={key_meta.size(2)} (decode-phase, IAttentionLayer handles it)" ) + continue + + node.args = ( + node.args[0], + node.args[1], + node.args[2], + None, + False, + 0.0, + True, + ) + changed = True + logger.debug(f"Node {node} changed to causal mode: {node.args}") if changed: gm = clean_up_graph_after_modifications(gm) diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index c1a13f0690..29182980fe 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -6,14 +6,12 @@ from typing import Any, Callable, List, Optional, Sequence, Tuple import torch -import torch_tensorrt from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.passes.shape_prop import ShapeProp from torch.testing._internal.common_utils import TestCase from torch_tensorrt import Input from torch_tensorrt._Device import Device from torch_tensorrt._enums import dtype -from torch_tensorrt.dynamo import _defaults from torch_tensorrt.dynamo._defaults import default_device from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args @@ -109,58 +107,6 @@ def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool: return False -# this method is only used in our converter test to infer the module output dtypes via dummy inference -# which is due to fx.symbolic_trace does not have the meta['val'] info in the node -# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace -def infer_module_output_dtypes_for_test( - module: torch.fx.GraphModule, - inputs: Sequence[Input], - device: Device, - kwarg_inputs: Optional[dict[str, Any]] = None, - truncate_double: bool = False, -) -> List[dtype]: - """ - This function performs model inference to determine the output dtypes - and truncates them accordingly. inputs can be either arg_inputs or flattened input list. - If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input. - """ - # TODO: We can also determine output dtypes from the module.graph based on node metadata. - # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata, - # so we stick to the model inference approach currently. - with unset_fake_temporarily(): - # Get the device on which the model exists - # For large models, this can be done on CPU to save GPU memory allocation for TRT. 
- device = get_model_device(module) - torch_inputs = get_torch_inputs(inputs, device) - if kwarg_inputs is None: - kwarg_inputs = {} - torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) - module_outputs = module(*torch_inputs, **torch_kwarg_inputs) - if not isinstance(module_outputs, (list, tuple)): - module_outputs = [module_outputs] - - # Int64 outputs can sometimes be generated from within other operators - # such as aten.sum - such outputs can be truncated - output_dtypes = [] - for output in module_outputs: - output_ = output - # We don't need to check if output is nested here because the input module will be flattened - if not isinstance(output, torch.Tensor): - if isinstance(output, str): - raise ValueError( - f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)" - ) - else: - output_ = torch.tensor(output) - - if truncate_double and output_.dtype == dtype.float64: - output_dtypes.append(dtype.float32) - else: - output_dtypes.append(dtype._from(output_.dtype)) - - return output_dtypes - - def fetch_attr(mod, target): """ Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. @@ -422,6 +368,7 @@ def run_test( immutable_weights=True, use_explicit_typing=False, decompose_attention=False, + attn_bias_is_causal=True, ): # TODO: lan to remove this and set use_dynamo_traccer to True by default # once all the converter test files are moved to use_dynamo_tracer @@ -434,6 +381,7 @@ def run_test( immutable_weights=immutable_weights, use_explicit_typing=use_explicit_typing, decompose_attention=decompose_attention, + attn_bias_is_causal=attn_bias_is_causal, ) mod = self.generate_graph( diff --git a/tests/py/dynamo/hlo/__init__.py b/tests/py/dynamo/hlo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/py/dynamo/hlo/test_attention.py b/tests/py/dynamo/hlo/test_attention.py new file mode 100644 index 0000000000..06008ad857 --- /dev/null +++ b/tests/py/dynamo/hlo/test_attention.py @@ -0,0 +1,857 @@ +"""Comprehensive attention subgraph tests for TRT converter bug discovery. + +Covers all SDPA kernel variants, MHA/GQA/MQA attention patterns, causal vs +non-causal masking, bool/float/broadcast mask shapes, decode-phase attention +(seq_q=1), non-power-of-2 head dims, LLM-realistic configs, and multiple dtypes. + +Known limitations (decompose_attention=True required) +----------------------------------------------------- + Large causal sequences (seq >= 512, is_causal=True) + IAttentionLayer produces ~80% element mismatch at long sequences. + + FP32 GQA/MQA + PyTorch's core_aten decomposition expands scaled_dot_product_attention + into matmul + _safe_softmax before the TRT converter runs. No converter + is registered for _safe_softmax, so FP32 GQA requires decompose_attention=True. + +Notes on attn_bias_is_causal +----------------------------- + Default True: the force_causal_efficient_attention lowering pass strips + attn_bias and sets is_causal=True before reaching the converter. + This is an HF-model optimization; most production uses keep the default. + Set False: attn_bias is forwarded to IAttentionLayer.mask. Required for + any test that validates actual bias tensor values. 
+ +Test classes +------------ + TestSDPA - aten.scaled_dot_product_attention — all configurations: + test_no_mask no mask; IAttentionLayer native (decompose=True for large causal) + test_decode decode-phase (seq_q=1, seq_k>1); IAttentionLayer native + test_bool_mask bool attention masks (full, broadcast, 2-D, decode) + test_float_mask additive float attention masks (incl. decode-phase) + test_gqa GQA/MQA (Hq != Hkv); IAttentionLayer native + TestFlashAttention - _scaled_dot_product_flash_attention kernel: + test_no_mask no mask; IAttentionLayer native (decompose=True for large causal) + test_decode decode-phase (seq_q=1, seq_k>1); IAttentionLayer native + test_gqa GQA/MQA (Hq != Hkv); IAttentionLayer native + TestEfficientAttention - _scaled_dot_product_efficient_attention: + test_no_bias attn_bias=None; decompose=True + test_with_bias native IAttentionLayer.mask, incl. h=1/b=1 shapes + test_with_bias_decode decode-phase (seq_q=1, seq_k>1) + 4D bias + test_with_bias_causal is_causal=True + attn_bias combined converter path + test_attn_bias_is_causal_opt force_causal_efficient_attention pass +""" + +import unittest + +import torch +import torch.nn as nn +import torch_tensorrt +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests + +from ..conversion.harness import DispatchTestCase + +_BF16_SKIP = unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8, + "BF16 requires Ampere (SM80) or higher", +) + +_FLASH_ATTN_SKIP = unittest.skipIf( + not torch.cuda.is_available() + or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8, + "Flash attention requires Ampere (SM80) or higher", +) + + +def _skip_bf16_on_rtx(test_self, dtype): + """Call at the top of a test to skip BF16 on TensorRT-RTX builds.""" + if dtype == torch.bfloat16 and getattr( + torch_tensorrt.ENABLED_FEATURES, "tensorrt_rtx", False + ): + test_self.skipTest("TensorRT-RTX does not support bfloat16") + + +# --------------------------------------------------------------------------- +# Standard SDPA — all configurations +# --------------------------------------------------------------------------- + + +class TestSDPA(DispatchTestCase): + """aten.scaled_dot_product_attention — all configurations. + + test_no_mask + Standard MHA, no mask. decompose=True for large causal (seq >= 512) + and flash-dispatch configs (large heads+head_dim). + + test_decode + Decode-step (seq_q=1, K/V span full context), no mask. + IAttentionLayer handles non-square Q/K natively. + + test_bool_mask + Bool attention masks (full, broadcast, 2D, decode-phase). + Exercises the bool→float (-inf fill) mask conversion path. + + test_float_mask + Additive float attention masks (added to QK^T before softmax). + Exercises the add-bias path, distinct from the bool mask path. + + test_gqa + GQA/MQA (Hq != Hkv). IAttentionLayer native; no K/V expansion. + FP32 and large causal (seq >= 512) require decompose_attention=True. 
+ """ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq_q, seq_k, head_dim, is_causal, scale, dtype, use_decompose, test_atol) + # --- FP16, varying batch --- + ("b1_h8_s32_d64_nc_fp16", 1, 8, 32, 32, 64, False, None, torch.float16, False, 1e-2), + ("b1_h8_s32_d64_ca_fp16", 1, 8, 32, 32, 64, True, None, torch.float16, False, 1e-2), + ("b2_h8_s128_d64_nc_fp16", 2, 8, 128, 128, 64, False, None, torch.float16, False, 1e-2), + ("b2_h8_s128_d64_ca_fp16", 2, 8, 128, 128, 64, True, None, torch.float16, False, 1e-2), + ("b4_h8_s128_d64_fp16", 4, 8, 128, 128, 64, True, None, torch.float16, False, 1e-2), + # --- FP16, varying num_heads --- + ("h1_fp16", 1, 1, 64, 64, 64, False, None, torch.float16, False, 1e-2), + ("h16_fp16", 2, 16, 64, 64, 64, False, None, torch.float16, False, 1e-2), + ("h32_fp16", 2, 32, 128, 128, 64, True, None, torch.float16, False, 1e-2), + # --- FP16, varying head_dim --- + ("d16_fp16", 2, 8, 64, 64, 16, False, None, torch.float16, False, 1e-2), + ("d32_fp16", 2, 8, 64, 64, 32, False, None, torch.float16, False, 1e-2), + ("d128_fp16", 1, 4, 64, 64, 128, False, None, torch.float16, False, 1e-2), + # Non-power-of-2 head dims + ("d48_fp16", 1, 4, 32, 32, 48, False, None, torch.float16, False, 1e-2), + ("d96_fp16", 1, 4, 32, 32, 96, False, None, torch.float16, False, 1e-2), + # Large causal → decompose; loosen atol for fp16 accumulation at long seq + ("s512_ca_fp16", 1, 8, 512, 512, 64, True, None, torch.float16, True, 0.1), + ("s2048_ca_fp16", 1, 8, 2048, 2048, 64, True, None, torch.float16, True, 0.1), + # --- FP16, custom scale --- + ("scale_0125_fp16", 2, 8, 64, 64, 64, False, 0.125, torch.float16, False, 1e-2), + ("scale_05_ca_fp16", 2, 8, 64, 64, 64, True, 0.5, torch.float16, False, 1e-2), + # scale=2.0 in FP16 causes ~0.5% mismatch due to fp16 overflow; loosen atol + ("scale_2_fp16", 2, 8, 64, 64, 64, False, 2.0, torch.float16, False, 0.1), + # --- FP32 --- + ("b1_h8_s32_d64_nc_fp32", 1, 8, 32, 32, 64, False, None, torch.float32, False, 1e-2), + ("b1_h8_s32_d64_ca_fp32", 1, 8, 32, 32, 64, True, None, torch.float32, False, 1e-2), + ("b2_h8_s128_d64_fp32", 2, 8, 128, 128, 64, False, None, torch.float32, False, 1e-2), + ("scale_05_ca_fp32", 2, 8, 64, 64, 64, True, 0.5, torch.float32, False, 1e-2), + # --- BF16 (Ampere+ only, guarded per-test) --- + ("b1_h8_s32_d64_nc_bf16", 1, 8, 32, 32, 64, False, None, torch.bfloat16, False, 1e-2), + ("b2_h8_s128_d64_ca_bf16", 2, 8, 128, 128, 64, True, None, torch.bfloat16, False, 1e-2), + # LLM-realistic configs + ("llama32_1b_prefill_fp16", 1, 32, 2048, 2048, 64, True, None, torch.float16, True, 0.1), # Llama-3.2-1B, large causal → decompose + ("llama32_3b_prefill_fp16", 1, 24, 2048, 2048, 128, True, None, torch.float16, True, 1e-2), # Llama-3.2-3B + ("qwen25_05b_fp16", 1, 14, 128, 128, 64, True, None, torch.float16, False, 1e-2), # Qwen2.5-0.5B + ("mistral_7b_fp16", 1, 32, 512, 512, 128, True, None, torch.float16, True, 1e-2), # Mistral-7B, flash dispatch → decompose + ] + ) + # fmt: on + def test_no_mask( + self, + name, + batch, + num_heads, + seq_q, + seq_k, + head_dim, + is_causal, + scale, + dtype, + use_decompose, + test_atol, + ): + _skip_bf16_on_rtx(self, dtype) + + class SDPA(nn.Module): + def forward(self, q, k, v): + return torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, None, 0.0, is_causal, scale=scale + ) + + q = torch.randn(batch, num_heads, seq_q, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq_k, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq_k, 
head_dim, dtype=dtype) + self.run_test( + SDPA(), + [q, k, v], + rtol=1e-2, + atol=test_atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=use_decompose, + ) + + # fmt: off + @parameterized.expand( + [ + # (name, batch, num_heads, context_len, head_dim, dtype) + ("b1_h8_ctx128_d64_fp16", 1, 8, 128, 64, torch.float16), + ("b1_h8_ctx2048_d64_fp16", 1, 8, 2048, 64, torch.float16), + ("b2_h8_ctx128_d64_fp16", 2, 8, 128, 64, torch.float16), + ("b1_h8_ctx128_d64_fp32", 1, 8, 128, 64, torch.float32), + ("b1_h8_ctx128_d64_bf16", 1, 8, 128, 64, torch.bfloat16), + # LLM-realistic decode configs + ("llama32_1b_dec_fp16", 1, 32, 2048, 128, torch.float16), + ("qwen25_dec_fp16", 1, 14, 128, 64, torch.float16), + ("mistral_dec_fp16", 1, 32, 512, 128, torch.float16), + # Non-power-of-2 head dim + ("d48_dec_fp16", 1, 8, 128, 48, torch.float16), + ("d96_dec_fp16", 1, 8, 128, 96, torch.float16), + # Long context + ("b1_h32_ctx4096_d128_fp16", 1, 32, 4096, 128, torch.float16), + ] + ) + # fmt: on + def test_decode(self, name, batch, num_heads, context_len, head_dim, dtype): + """Single-token decode: Q has seq_len=1, K/V hold full context.""" + _skip_bf16_on_rtx(self, dtype) + + class DecodeAttention(nn.Module): + def forward(self, q, k, v): + return torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, None, 0.0, False, scale=None + ) + + q = torch.randn(batch, num_heads, 1, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, context_len, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, context_len, head_dim, dtype=dtype) + self.run_test( + DecodeAttention(), + [q, k, v], + rtol=1e-2, + atol=1e-2, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=False, + ) + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq_q, seq_k, head_dim, mask_shape, dtype, use_decompose) + # Full (batch, heads, seq_q, seq_k) masks + ("full_b2_h8_s32_fp16", 2, 8, 32, 32, 64, (2, 8, 32, 32), torch.float16, False), + ("full_b2_h8_s32_fp32", 2, 8, 32, 32, 64, (2, 8, 32, 32), torch.float32, False), + ("full_b4_h8_s64_fp16", 4, 8, 64, 64, 64, (4, 8, 64, 64), torch.float16, False), + # Broadcast: (1, 1, seq_q, seq_k) + ("bcast_1111_fp16", 2, 8, 32, 32, 64, (1, 1, 32, 32), torch.float16, False), + ("bcast_1111_fp32", 2, 8, 32, 32, 64, (1, 1, 32, 32), torch.float32, False), + # Broadcast: (batch, 1, seq_q, seq_k) + ("bcast_b1sk_fp16", 2, 8, 32, 32, 64, (2, 1, 32, 32), torch.float16, False), + # 2D mask (seq_q, seq_k) — broadcastable + ("mask_2d_fp16", 1, 8, 32, 32, 64, (32, 32), torch.float16, False), + ("mask_2d_fp32", 2, 8, 128, 128, 64, (128, 128), torch.float32, False), + # Decode step (seq_q=1): IAttentionLayer handles non-square Q/K natively + ("decode_full_fp16", 2, 8, 1, 32, 64, (2, 8, 1, 32), torch.float16, False), + ("decode_bcast_fp16", 2, 8, 1, 32, 64, (1, 1, 1, 32), torch.float16, False), + # Cross-attention (seq_q != seq_k): non-square → use decompose_attention + ("cross_attn_fp16", 1, 8, 16, 64, 64, (1, 8, 16, 64), torch.float16, True), + ] + ) + # fmt: on + def test_bool_mask( + self, + name, + batch, + num_heads, + seq_q, + seq_k, + head_dim, + mask_shape, + dtype, + use_decompose, + ): + _skip_bf16_on_rtx(self, dtype) + + class SDPABoolMask(nn.Module): + def forward(self, q, k, v, mask): + return torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, mask, 0.0, False, scale=None + ) + + q = torch.randn(batch, num_heads, seq_q, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, 
seq_k, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq_k, head_dim, dtype=dtype) + mask = torch.randint(0, 2, mask_shape, dtype=torch.bool) + self.run_test( + SDPABoolMask(), + [q, k, v, mask], + rtol=1e-2, + atol=1e-2, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=use_decompose, + ) + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq_q, seq_k, head_dim, scale, dtype, use_decompose, test_atol) + ("basic_nc_fp16", 2, 8, 32, 32, 64, None, torch.float16, False, 1e-2), + ("basic_nc_fp32", 2, 8, 32, 32, 64, None, torch.float32, False, 1e-2), + ("basic_nc_bf16", 2, 8, 32, 32, 64, None, torch.bfloat16, False, 1e-2), + # scale causes ~0.2% FP16 mismatch at atol=0.01; loosen to 0.05 + ("scale1_fp16", 2, 8, 128, 128, 64, 1.0, torch.float16, False, 5e-2), + ("large_seq_fp16", 1, 8, 512, 512, 64, None, torch.float16, False, 1e-2), + ("b4_h16_fp16", 4, 16, 64, 64, 64, None, torch.float16, False, 1e-2), + # Decode step (seq_q=1) + ("decode_fp16", 2, 8, 1, 32, 64, None, torch.float16, False, 1e-2), + ("decode_fp32", 2, 8, 1, 64, 64, None, torch.float32, False, 1e-2), + # Non-standard head dim + ("d48_fp16", 1, 4, 32, 32, 48, None, torch.float16, False, 1e-2), + ] + ) + # fmt: on + def test_float_mask( + self, + name, + batch, + num_heads, + seq_q, + seq_k, + head_dim, + scale, + dtype, + use_decompose, + test_atol, + ): + _skip_bf16_on_rtx(self, dtype) + + class SDPAFloatMask(nn.Module): + def forward(self, q, k, v, mask): + return torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, mask, 0.0, False, scale=scale + ) + + q = torch.randn(batch, num_heads, seq_q, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq_k, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq_k, head_dim, dtype=dtype) + mask = torch.randn(batch, num_heads, seq_q, seq_k, dtype=dtype) + self.run_test( + SDPAFloatMask(), + [q, k, v, mask], + rtol=1e-2, + atol=test_atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=use_decompose, + ) + + # fmt: off + @parameterized.expand( + [ + # (name, batch, q_heads, kv_heads, seq_len, head_dim, is_causal, dtype, use_decompose) + ("gqa_32q_8kv_s128_fp16", 1, 32, 8, 128, 128, True, torch.float16, False), + ("gqa_32q_8kv_s2048_fp16", 1, 32, 8, 2048, 128, True, torch.float16, True), # large causal → decompose + ("gqa_16q_4kv_s128_fp16", 2, 16, 4, 128, 64, True, torch.float16, False), + ("gqa_8q_2kv_nc_fp16", 2, 8, 2, 64, 64, False, torch.float16, False), + ("gqa_8q_4kv_fp32", 2, 8, 4, 64, 64, False, torch.float32, True), # FP32: _safe_softmax path → decompose + ("gqa_24q_8kv_fp16", 1, 24, 8, 128, 128, True, torch.float16, False), # Llama-3.2-3B + ("gqa_14q_2kv_fp16", 1, 14, 2, 128, 64, True, torch.float16, False), # Qwen2.5-0.5B + # MQA (kv_heads = 1) + ("mqa_8q_1kv_nc_fp16", 2, 8, 1, 64, 64, False, torch.float16, False), + ("mqa_16q_1kv_ca_fp16", 1, 16, 1, 128, 64, True, torch.float16, False), + # GQA decode (seq_q=1) + ("gqa_decode_32q_8kv_fp16", 2, 32, 8, 1, 128, False, torch.float16, False), + ("mqa_decode_32q_1kv_fp16", 2, 32, 1, 1, 128, False, torch.float16, False), + ] + ) + # fmt: on + def test_gqa( + self, + name, + batch, + num_q_heads, + num_kv_heads, + seq_len, + head_dim, + is_causal, + dtype, + use_decompose, + ): + _skip_bf16_on_rtx(self, dtype) + + class GQA(nn.Module): + def forward(self, q, k, v): + return torch.ops.aten.scaled_dot_product_attention.default( + q, k, v, None, 0.0, is_causal, scale=None, enable_gqa=True + ) + + 
q = torch.randn(batch, num_q_heads, seq_len, head_dim, dtype=dtype) + k = torch.randn(batch, num_kv_heads, seq_len, head_dim, dtype=dtype) + v = torch.randn(batch, num_kv_heads, seq_len, head_dim, dtype=dtype) + self.run_test( + GQA(), + [q, k, v], + rtol=1e-2, + atol=1e-2, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=use_decompose, + ) + + +# --------------------------------------------------------------------------- +# Flash attention kernel +# --------------------------------------------------------------------------- + + +@_FLASH_ATTN_SKIP +class TestFlashAttention(DispatchTestCase): + """_scaled_dot_product_flash_attention kernel (Ampere+ required). + + Mirrors the TestEfficientAttention coverage structure. Flash attention + has no attn_bias parameter so test_with_bias / test_with_bias_causal / + test_attn_bias_is_causal_opt have no equivalent here. + + test_no_mask + Standard MHA, no mask. decompose=True for large causal (seq >= 512). + + test_decode + Decode-phase (seq_q=1, seq_k>1) via IAttentionLayer. + + test_gqa + GQA/MQA (Hq != Hkv). IAttentionLayer native. + Large causal (seq >= 512) uses decompose_attention=True. + """ + + # ------------------------------------------------------------------ + # 1. No mask + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq_len, head_dim, is_causal, scale, dtype, use_decompose, atol) + ("causal_fp16", 2, 8, 128, 64, True, None, torch.float16, False, 1e-2), + ("nc_fp16", 2, 8, 128, 64, False, None, torch.float16, False, 1e-2), + ("scale_025_ca_fp16", 2, 8, 128, 64, True, 0.25, torch.float16, False, 1e-2), + # scale=0.5 causes ~4-element FP16 mismatch; loosen atol + ("scale_05_nc_fp16", 2, 8, 128, 64, False, 0.5, torch.float16, False, 2e-2), + ("b4_h16_s128_fp16", 4, 16, 128, 64, True, None, torch.float16, False, 1e-2), + ("b1_h8_d128_ca_fp16",1, 8, 128, 128, True, None, torch.float16, False, 1e-2), + ("b1_h32_s256_ca_fp16",1, 32, 256, 64, True, None, torch.float16, False, 1e-2), + # Non-power-of-2 head dim + ("d48_fp16", 1, 4, 64, 48, False, None, torch.float16, False, 1e-2), + ("d96_fp16", 1, 4, 64, 96, False, None, torch.float16, False, 1e-2), + # BF16 + ("causal_bf16", 2, 8, 128, 64, True, None, torch.bfloat16, False, 1e-2), + # Large causal → decompose; loosen atol for fp16 accumulation + ("s512_ca_fp16", 1, 8, 512, 64, True, None, torch.float16, True, 0.1), + ("s2048_ca_fp16", 1, 32, 2048, 64, True, None, torch.float16, True, 0.1), + ] + ) + # fmt: on + def test_no_mask( + self, + name, + batch, + num_heads, + seq_len, + head_dim, + is_causal, + scale, + dtype, + use_decompose, + atol, + ): + _skip_bf16_on_rtx(self, dtype) + + class FlashAttn(nn.Module): + def forward(self, q, k, v): + out = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, k, v, 0.0, is_causal, False, scale=scale + ) + return out[0] + + q = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype) + self.run_test( + FlashAttn(), + [q, k, v], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=not use_decompose, + decompose_attention=use_decompose, + ) + + # ------------------------------------------------------------------ + # 2. 
Decode-phase (seq_q=1, seq_k>1) — IAttentionLayer native + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, context_len, head_dim, dtype, atol) + ("b1_h8_ctx128_fp16", 1, 8, 128, 64, torch.float16, 1e-2), + ("b1_h8_ctx512_fp16", 1, 8, 512, 64, torch.float16, 1e-2), + ("b2_h8_ctx128_fp16", 2, 8, 128, 64, torch.float16, 1e-2), + ("b1_h8_d128_ctx128_fp16", 1, 8, 128, 128, torch.float16, 1e-2), + # LLM-realistic decode configs + ("llama_1b_dec_fp16", 1, 32, 2048, 128, torch.float16, 1e-2), + ("qwen_dec_fp16", 1, 14, 128, 64, torch.float16, 1e-2), + ("mistral_dec_fp16", 1, 32, 512, 128, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_decode(self, name, batch, num_heads, context_len, head_dim, dtype, atol): + """Single-token decode: Q has seq_len=1, K/V hold the full context.""" + _skip_bf16_on_rtx(self, dtype) + + class FlashAttnDecode(nn.Module): + def forward(self, q, k, v): + out = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, k, v, 0.0, False, False, scale=None + ) + return out[0] + + q = torch.randn(batch, num_heads, 1, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, context_len, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, context_len, head_dim, dtype=dtype) + self.run_test( + FlashAttnDecode(), + [q, k, v], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=False, + ) + + # ------------------------------------------------------------------ + # 3. GQA / MQA — IAttentionLayer accepts Hq != Hkv natively + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, q_heads, kv_heads, seq_len, head_dim, is_causal, dtype, use_decompose, atol) + ("gqa_32q_8kv_s128_fp16", 1, 32, 8, 128, 128, True, torch.float16, False, 1e-2), + ("gqa_16q_4kv_s128_fp16", 2, 16, 4, 128, 64, True, torch.float16, False, 1e-2), + ("gqa_8q_2kv_nc_fp16", 2, 8, 2, 64, 64, False, torch.float16, False, 1e-2), + ("gqa_24q_8kv_fp16", 1, 24, 8, 128, 128, True, torch.float16, False, 1e-2), + # MQA (kv_heads = 1) + ("mqa_8q_1kv_nc_fp16", 2, 8, 1, 64, 64, False, torch.float16, False, 1e-2), + ("mqa_16q_1kv_ca_fp16", 1, 16, 1, 128, 64, True, torch.float16, False, 1e-2), + # GQA decode (seq_q=1) + ("gqa_decode_32q_8kv_fp16", 2, 32, 8, 1, 128, False, torch.float16, False, 1e-2), + ("mqa_decode_32q_1kv_fp16", 2, 32, 1, 1, 128, False, torch.float16, False, 1e-2), + # Note: large causal GQA (seq >= 512) is untestable via flash attention — + # decompose=False hits IAttentionLayer mismatch, decompose=True fails + # because PyTorch's export-time decomposition does not support Hq != Hkv. 
+ ] + ) + # fmt: on + def test_gqa( + self, + name, + batch, + num_q_heads, + num_kv_heads, + seq_len, + head_dim, + is_causal, + dtype, + use_decompose, + atol, + ): + """GQA/MQA via flash attention: Q has Hq heads, K/V have Hkv heads.""" + _skip_bf16_on_rtx(self, dtype) + + class FlashAttnGQA(nn.Module): + def forward(self, q, k, v): + out = torch.ops.aten._scaled_dot_product_flash_attention.default( + q, k, v, 0.0, is_causal, False, scale=None + ) + return out[0] + + q = torch.randn(batch, num_q_heads, seq_len, head_dim, dtype=dtype) + k = torch.randn(batch, num_kv_heads, seq_len, head_dim, dtype=dtype) + v = torch.randn(batch, num_kv_heads, seq_len, head_dim, dtype=dtype) + self.run_test( + FlashAttnGQA(), + [q, k, v], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=not use_decompose, + decompose_attention=use_decompose, + ) + + +# --------------------------------------------------------------------------- +# Efficient attention kernel +# --------------------------------------------------------------------------- + + +class TestEfficientAttention(DispatchTestCase): + """_scaled_dot_product_efficient_attention kernel — all attn_bias scenarios. + + Five test methods cover the distinct code paths through the converter: + + test_no_bias + attn_bias=None; uses decompose_attention=True to exercise the + matmul+softmax fallback path. + + test_with_bias + attn_bias provided; uses the native IAttentionLayer.mask path + (decompose_attention=False, attn_bias_is_causal=False). + Includes cases with batch=1 or heads=1 to stress-test mask alignment. + + test_with_bias_causal + Both is_causal=True and attn_bias set simultaneously. The converter + materialises a causal tril mask and combines it with the float bias + via additive -inf before passing to IAttentionLayer. + (decompose_attention=False, attn_bias_is_causal=False) + + test_attn_bias_is_causal_opt + Exercises the force_causal_efficient_attention lowering pass + (attn_bias_is_causal=True, default). The pass strips attn_bias and + sets is_causal=True; both TRT and the PyTorch reference see the same + post-lowering graph so the comparison is valid. + (decompose_attention=False, attn_bias_is_causal=True) + + Note: bool attn_bias is not accepted by _scaled_dot_product_efficient_attention + (PyTorch requires bias dtype == query dtype), so the bool+causal combine path + in the converter cannot be exercised through this op. + + Note: GQA/MQA is not testable via _scaled_dot_product_efficient_attention + directly — PyTorch's eager kernel rejects Hq != Hkv at runtime, so no + valid reference exists for output comparison. + """ + + # ------------------------------------------------------------------ + # 1. 
No bias — decompose fallback + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq, head_dim, is_causal, scale, dtype, atol) + ("causal_fp16", 2, 8, 128, 64, True, None, torch.float16, 1e-2), + ("nc_fp16", 2, 8, 128, 64, False, None, torch.float16, 1e-2), + ("causal_fp32", 1, 8, 64, 64, True, None, torch.float32, 1e-2), + # scale=0.5 causes ~3-element FP16 mismatch; loosen atol + ("scale05_ca_fp16",2, 8, 128, 64, True, 0.5, torch.float16, 2e-2), + ("b4_h16_fp16", 4, 16, 128, 64, False, None, torch.float16, 1e-2), + ("s512_ca_fp16", 1, 8, 512, 64, True, None, torch.float16, 1e-2), + ("d128_fp16", 1, 8, 64, 128, True, None, torch.float16, 1e-2), + ("d48_fp16", 1, 4, 32, 48, False, None, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_no_bias( + self, name, batch, num_heads, seq, head_dim, is_causal, scale, dtype, atol + ): + class EfficientAttn(nn.Module): + def forward(self, q, k, v): + out = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, None, False, 0.0, is_causal, scale=scale + ) + return out[0] + + q = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + self.run_test( + EfficientAttn(), + [q, k, v], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + decompose_attention=True, + ) + + # ------------------------------------------------------------------ + # 2. With bias — native IAttentionLayer.mask path + # Includes heads=1 cases to stress mask alignment. + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq, head_dim, scale, dtype, atol) + # Standard shapes + ("nc_b2_h8_fp16", 2, 8, 32, 64, None, torch.float16, 1e-2), + ("nc_b2_h8_fp32", 2, 8, 32, 64, None, torch.float32, 1e-2), + # scale with bias causes borderline FP16 mismatch; loosen slightly + ("scale05_fp16", 2, 8, 32, 64, 0.5, torch.float16, 2e-2), + ("scale2_fp32", 1, 8, 32, 64, 2.0, torch.float32, 2e-2), + ("large_seq_fp16", 1, 8, 128, 64, None, torch.float16, 1e-2), + ("b4_h16_fp16", 4, 16, 64, 64, None, torch.float16, 1e-2), + # heads=1 — alignment stress test for IAttentionLayer.mask + ("h1_b1_fp16", 1, 1, 32, 64, None, torch.float16, 1e-2), + ("h1_b2_fp16", 2, 1, 32, 64, None, torch.float16, 1e-2), + ("h1_d128_fp16", 1, 1, 32, 128, None, torch.float16, 1e-2), + # batch=1 + ("b1_h8_fp16", 1, 8, 32, 64, None, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_with_bias(self, name, batch, num_heads, seq, head_dim, scale, dtype, atol): + class EfficientAttnBias(nn.Module): + def forward(self, q, k, v, bias): + out = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, bias, False, 0.0, False, scale=scale + ) + return out[0] + + q = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + bias = torch.randn(batch, num_heads, seq, seq, dtype=dtype) + self.run_test( + EfficientAttnBias(), + [q, k, v, bias], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=False, + attn_bias_is_causal=False, + ) + + # ------------------------------------------------------------------ + # 3. 
With bias — decode-phase (seq_q=1, seq_k>1) + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq_k, head_dim, scale, dtype, atol) + # seq_k >= 8 required: PyTorch's efficient-attention kernel enforces + # attn_bias.stride(1) = seq_q * seq_k >= 8 for its eager reference run. + ("decode_b4_h8_fp16", 4, 8, 8, 64, None, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_with_bias_decode( + self, name, batch, num_heads, seq_k, head_dim, scale, dtype, atol + ): + """Decode-phase with 4D float bias: q=(B,H,1,D), bias=(B,H,1,Sk).""" + seq_q = 1 + + class EfficientAttnBiasDecode(nn.Module): + def forward(self, q, k, v, bias): + out = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, bias, False, 0.0, False, scale=scale + ) + return out[0] + + q = torch.randn(batch, num_heads, seq_q, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq_k, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq_k, head_dim, dtype=dtype) + bias = torch.randn(batch, num_heads, seq_q, seq_k, dtype=dtype) + self.run_test( + EfficientAttnBiasDecode(), + [q, k, v, bias], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=False, + attn_bias_is_causal=False, + ) + + # ------------------------------------------------------------------ + # 4. With bias + is_causal=True — combined path in converter + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq, head_dim, scale, dtype, atol) + ("ca_b2_h8_fp16", 2, 8, 32, 64, None, torch.float16, 1e-2), + ("ca_b1_h8_fp32", 1, 8, 64, 64, None, torch.float32, 1e-2), + ("ca_scale05_fp16", 2, 8, 32, 64, 0.5, torch.float16, 1e-2), + ("ca_large_fp16", 1, 8, 128, 64, None, torch.float16, 1e-2), + ("ca_d128_fp16", 1, 8, 32, 128, None, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_with_bias_causal( + self, name, batch, num_heads, seq, head_dim, scale, dtype, atol + ): + class EfficientAttnBiasCausal(nn.Module): + def forward(self, q, k, v, bias): + out = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, bias, False, 0.0, True, scale=scale + ) + return out[0] + + q = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + bias = torch.randn(batch, num_heads, seq, seq, dtype=dtype) + self.run_test( + EfficientAttnBiasCausal(), + [q, k, v, bias], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=False, + attn_bias_is_causal=False, + ) + + # ------------------------------------------------------------------ + # 5. 
attn_bias_is_causal=True — force_causal_efficient_attention pass + # ------------------------------------------------------------------ + + # fmt: off + @parameterized.expand( + [ + # (name, batch, heads, seq, head_dim, scale, dtype, atol) + ("opt_b2_h8_fp16", 2, 8, 32, 64, None, torch.float16, 1e-2), + ("opt_large_fp16", 1, 8, 128, 64, None, torch.float16, 1e-2), + ("opt_b4_h16_fp16", 4, 16, 64, 64, None, torch.float16, 1e-2), + ("opt_d128_fp16", 1, 8, 32, 128, None, torch.float16, 1e-2), + ] + ) + # fmt: on + def test_attn_bias_is_causal_opt( + self, name, batch, num_heads, seq, head_dim, scale, dtype, atol + ): + class EfficientAttnBiasOpt(nn.Module): + def forward(self, q, k, v, bias): + out = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, bias, False, 0.0, False, scale=scale + ) + return out[0] + + q = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + k = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + v = torch.randn(batch, num_heads, seq, head_dim, dtype=dtype) + bias = torch.randn( + batch, num_heads, seq, seq, dtype=dtype + ) # values ignored; pass replaces with is_causal=True + self.run_test( + EfficientAttnBiasOpt(), + [q, k, v, bias], + rtol=1e-2, + atol=atol, + precision=dtype, + enable_passes=True, + use_explicit_typing=True, + decompose_attention=False, + attn_bias_is_causal=True, + ) + + +if __name__ == "__main__": + run_tests()
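
A minimal end-to-end sketch of the configuration these tests exercise, for reference only. It assumes that decompose_attention and attn_bias_is_causal are accepted as dynamo compile keyword arguments and forwarded to CompilationSettings, exactly as the test harness above does; the module, dtypes, and head counts are illustrative and mirror the gqa_32q_8kv_s128_fp16 test case.

# Illustrative sketch, not part of the patch. Assumes decompose_attention and
# attn_bias_is_causal are forwarded to CompilationSettings via compile kwargs,
# as the harness above does; shapes mirror the gqa_32q_8kv_s128_fp16 case.
import torch
import torch_tensorrt


class SDPABlock(torch.nn.Module):
    def forward(self, q, k, v):
        # Hq=32, Hkv=8: IAttentionLayer handles the GQA head mismatch natively.
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True, enable_gqa=True
        )


q = torch.randn(1, 32, 128, 128, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 128, 128, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 128, 128, dtype=torch.float16, device="cuda")

exported = torch.export.export(SDPABlock().eval(), (q, k, v))
trt_mod = torch_tensorrt.dynamo.compile(
    exported,
    inputs=[q, k, v],
    use_explicit_typing=True,    # strong typing: FP16 inputs stay FP16
    decompose_attention=False,   # lower to IAttentionLayer, not matmul+softmax
    attn_bias_is_causal=True,    # default; only affects efficient-attention with a bias
)
out = trt_mod(q, k, v)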