Remove padding from scales for hipBLASLt calls #442
Conversation
if (params.m % 16 || params.n % 16) {
  GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
}
if (params.k % 128) {
Is it a hipBLASLt limitation?
Yes, these are the values that the hipBLASLt team provided to us. I tested just in case, but nothing smaller than 128 works for k.
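For context, a minimal sketch of the kind of dimension guard being discussed, assuming a hypothetical GemmShape struct standing in for the test fixture's params (the real fixture is not shown in this thread); the 16-element M/N and 128-element K constraints are the hipBLASLt requirements mentioned above.

#include <gtest/gtest.h>

// Hypothetical shape parameters standing in for the fixture's params.
struct GemmShape { int m = 256, n = 256, k = 256; };

TEST(MXFP8GemmSketch, SkipsUnsupportedShapes) {
  GemmShape params;
  // hipBLASLt MXFP8 constraints discussed above: M and N must be multiples
  // of 16, and K must be a multiple of 128.
  if (params.m % 16 || params.n % 16) {
    GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
  }
  if (params.k % 128) {
    GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
  }
  // ... run the MXFP8 GEMM for this shape ...
}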
NVTE_DIM_CHECK(chunk_height > 0 && chunk_width > 0, "Attempted to get empty tensor chunk");
NVTE_DIM_CHECK(chunk_height <= height && chunk_width <= width,
               "Attempted to get out-of-bounds tensor chunk");
#ifndef __HIP_PLATFORM_AMD__
I think this file is not currently compiled for ROCm - it is for UB
Yes, I can move it to the UB PR if you prefer?
unpad_mxfp8_scales_kernel<<<blocks, threads, 0, stream>>>
    (scale_dptr, unpadded_rows, unpadded_cols, padded_cols);

NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
If we're using the same stream for the GEMM, the synchronize creates an unnecessary bubble.
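A minimal sketch of the stream-ordering point: when the unpad kernel and the GEMM are enqueued on the same HIP stream, submission order already guarantees the GEMM sees the unpadded scales, so the host-side hipStreamSynchronize is not needed. The kernel signature and the GEMM launcher below are stand-ins, not the actual code from this PR.

#include <cstdint>
#include <hip/hip_runtime.h>

// Hypothetical stand-ins for the real unpad kernel and the hipBLASLt GEMM call.
__global__ void unpad_mxfp8_scales_kernel(uint8_t *scales, int rows, int cols,
                                          int padded_cols) { /* ... */ }
void launch_hipblaslt_gemm(hipStream_t stream) { /* enqueue GEMM on stream */ }

void unpad_then_gemm(uint8_t *scale_dptr, int rows, int cols, int padded_cols,
                     hipStream_t stream) {
  dim3 threads(256), blocks((rows * cols + 255) / 256);
  // Enqueue the unpad kernel on the same stream the GEMM will use.
  unpad_mxfp8_scales_kernel<<<blocks, threads, 0, stream>>>(
      scale_dptr, rows, cols, padded_cols);
  // No hipStreamSynchronize here: work on a single stream executes in
  // submission order, so the GEMM below only starts after the unpad finishes.
  launch_hipblaslt_gemm(stream);
}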
// dimensions (with matrix interpreted in row-major order).

unpad_mxfp8_checkpoint(A, is_A_transposed, m, n, k, stream);
It is probably not the best place for this call: it does not update the input parameters. Also, what will the code behaviour be if we call GEMM twice with the same input? The first call will unpad the scale tensor and the second will do the same, right?
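To make the idempotency concern concrete, one option at the C++ level would be a shape-based guard that only unpads when the scale tensor still carries padded columns. This is a rough sketch with hypothetical names and metadata, not the actual Tensor API; as noted below, the eventual fix moves the logic to the Python side instead.

// Hypothetical metadata for a scale tensor; the real Tensor class is not
// shown in this thread.
struct ScaleInfo {
  int rows = 0;
  int cols = 0;         // current logical column count
  int padded_cols = 0;  // column count as allocated (with padding)
};

// Only unpad once: if cols already equals the unpadded width, a second GEMM
// call with the same input becomes a no-op instead of re-unpadding.
bool maybe_unpad(ScaleInfo &scale, int unpadded_cols) {
  if (scale.cols == unpadded_cols) {
    return false;  // already unpadded by an earlier GEMM call
  }
  // ... launch unpad_mxfp8_scales_kernel here ...
  scale.cols = unpadded_cols;
  return true;
}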
We talked about this during standup, along with the issues it would create in Python; I am currently working on moving the logic over there.
I have moved the check and unpadding to generic_gemm in PyTorch. This provides a more robust way to check before calling the GEMMs, and directly modifies the PyTorch tensors.
Removes padding for scale vectors that are used mainly for MXFP8.