Conversation
Code Review
This pull request introduces several new sparse attention mechanisms and operators, including FlashAttention v4 and SageAttention v2/v3, along with corresponding configuration files and utility functions. Key changes include the addition of SparseFlashAttn4Weight, SparseSageAttn2Weight, and SparseSageAttn3Weight, as well as a new SpargeMaskGenerator and a comprehensive sparge_util.py containing Triton kernels for block map generation. Feedback focuses on improving error handling by raising exceptions for unsupported sparse modes instead of just logging them, and refining assertion messages for better clarity and professionalism.
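For context on the pattern the diffs below rely on: `get_block_map_meansim` subtracts the mean key vector along the sequence dimension before scoring blocks (the `smooth_k = k - k.mean(dim=-2, keepdim=True)` step visible in each hunk). The following is a minimal dense PyTorch sketch of that idea; the shapes, pooling, and top-k selection here are assumptions for illustration only, while the shipped implementation lives in `sparge_util.py` as Triton kernels.

```python
import torch

def block_map_meansim_sketch(q, k, topk, BLKQ=128, BLKK=64):
    """Dense reference sketch of mean-similarity block-map selection.

    Assumes q, k have shape (batch, heads, seq, dim) and that the
    sequence lengths are divisible by BLKQ / BLKK.
    """
    b, h, sq, d = q.shape
    sk = k.shape[2]
    # Subtract the mean key vector along the sequence dimension (per head).
    smooth_k = k - k.mean(dim=-2, keepdim=True)
    # Pool tokens into block-level mean representatives.
    q_blk = q.reshape(b, h, sq // BLKQ, BLKQ, d).mean(dim=3)
    k_blk = smooth_k.reshape(b, h, sk // BLKK, BLKK, d).mean(dim=3)
    # Block-level similarity: (b, h, num_q_blocks, num_k_blocks).
    sim = q_blk @ k_blk.transpose(-1, -2)
    # Keep the top-k most similar key blocks for every query block.
    idx = sim.topk(topk, dim=-1).indices
    return torch.zeros_like(sim, dtype=torch.bool).scatter_(-1, idx, True)
```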
```python
            smooth_k = kt - kt.mean(dim=-2, keepdim=True)
            sparse_map = get_block_map_meansim(qt, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK)
        else:
            logger.info(f"spas_flash_attn4 sparse_mode only support sla_mode and sparge_mode now.")
```
Using logger.info for an unsupported sparse_mode might lead to silent failures or misconfigurations being overlooked. It would be more robust to either raise a ValueError or use logger.error to clearly indicate an invalid state.
Suggested change:

```diff
-            logger.info(f"spas_flash_attn4 sparse_mode only support sla_mode and sparge_mode now.")
+            raise ValueError(f"Unsupported sparse_mode: {self.sparse_mode}. spas_flash_attn4 sparse_mode only supports 'sla_mode' and 'sparge_mode'.")
```
```python
            smooth_k = k - k.mean(dim=-2, keepdim=True)
            sparse_map = get_block_map_meansim(q, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK)
        else:
            logger.info(f"spas_sage_attn2 sparse_mode only support sla_mode and sparge_mode now.")
```
Using logger.info for an unsupported sparse_mode might lead to silent failures or misconfigurations being overlooked. It would be more robust to either raise a ValueError or use logger.error to clearly indicate an invalid state.
Suggested change:

```diff
-            logger.info(f"spas_sage_attn2 sparse_mode only support sla_mode and sparge_mode now.")
+            raise ValueError(f"Unsupported sparse_mode: {self.sparse_mode}. spas_sage_attn2 sparse_mode only supports 'sla_mode' and 'sparge_mode'.")
```
```python
            smooth_k = k - k.mean(dim=-2, keepdim=True)
            sparse_map = get_block_map_meansim(q, smooth_k, cdfthreshd=None, topk=self.topk, return_lut=False, BLKQ=self.BLKQ, BLKK=self.BLKK)
        else:
            logger.info(f"spas_sage_attn3 sparse_mode only support sla_mode and sparge_mode now.")
```
Using logger.info for an unsupported sparse_mode might lead to silent failures or misconfigurations being overlooked. It would be more robust to either raise a ValueError or use logger.error to clearly indicate an invalid state.
Suggested change:

```diff
-            logger.info(f"spas_sage_attn3 sparse_mode only support sla_mode and sparge_mode now.")
+            raise ValueError(f"Unsupported sparse_mode: {self.sparse_mode}. spas_sage_attn3 sparse_mode only supports 'sla_mode' and 'sparge_mode'.")
```
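Since the same mode check is duplicated across all three operators, one option (a hypothetical refactor, not part of this PR) would be to centralize the validation in a small shared helper that each operator calls before branching:

```python
SUPPORTED_SPARSE_MODES = ("sla_mode", "sparge_mode")  # mode names taken from the diffs above

def validate_sparse_mode(op_name: str, sparse_mode: str) -> None:
    """Fail fast on an unsupported sparse_mode instead of logging and continuing."""
    if sparse_mode not in SUPPORTED_SPARSE_MODES:
        raise ValueError(
            f"Unsupported sparse_mode: {sparse_mode!r}. "
            f"{op_name} currently supports only: {', '.join(SUPPORTED_SPARSE_MODES)}."
        )
```

Each operator would then call, for example, `validate_sparse_mode("spas_sage_attn3", self.sparse_mode)`, keeping the error message consistent across all three code paths.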
```python
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
            assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
```
The assertion message here is a bit informal. Consider making it more professional to clearly communicate the limitation to users or developers.
Suggested change:

```diff
-            assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
+            assert bs == 1, "FlashAttention v4 currently only supports batch size of 1 for this function."
```
```python
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
            assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
```
The assertion message here is a bit informal. Consider making it more professional to clearly communicate the limitation to users or developers.
Suggested change:

```diff
-            assert bs == 1, "flash_attn4 doesn't support flash_attn_varlen_func now. Just use it for batchsize = 1 for sure."
+            assert bs == 1, "FlashAttention v4 currently only supports batch size of 1 for this function."
```
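Both occurrences guard the same shape-normalization step. A minimal sketch of that logic with the suggested wording applied (the helper name is hypothetical; the branch structure mirrors the diff context above):

```python
import torch

def normalize_qkv_shapes(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    # 3-D inputs (seq, heads, dim) get a singleton batch dimension prepended.
    if q.dim() == 3:
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
    # 4-D inputs (batch, seq, heads, dim) must be single-batch, since
    # flash_attn4 does not yet support flash_attn_varlen_func.
    elif q.dim() == 4:
        bs = q.shape[0]
        assert bs == 1, "FlashAttention v4 currently only supports batch size of 1 for this function."
    return q, k, v
```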