From f30202b705d4071a0ff0aba12e154293c00d162b Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Thu, 5 Mar 2026 13:50:07 +0000 Subject: [PATCH] Accelerate SVE128 SBGEMM/BGEMM This accelerates SBGEMM/BGEMM by extending the existing 8x4 kernel to 8x8 (unrolling N by 8) Not sure if it's a good idea to delete the previous 8x4 kernel? Here are the speedups on single core Neoverse-V2 (SVE128) compared to prev state: Per-shape speedup M=N=K=64: SBGEMM 1.164x (16.42%), BGEMM 1.133x (13.30%) M=N=K=128: SBGEMM 1.220x (22.02%), BGEMM 1.186x (18.56%) M=N=K=256: SBGEMM 1.241x (24.08%), BGEMM 1.235x (23.54%) M=N=K=512: SBGEMM 1.240x (23.95%), BGEMM 1.227x (22.75%) M=N=K=1024: SBGEMM 1.251x (25.11%), BGEMM 1.232x (23.23%) M=N=K=2048: SBGEMM 1.235x (23.47%), BGEMM 1.246x (24.64%) Signed-off-by: Fadi Arafeh --- CONTRIBUTORS.md | 3 + kernel/arm64/KERNEL.NEOVERSEN2 | 12 +- kernel/arm64/sbgemm_kernel_8x8_neoversen2.c | 56 ++ .../arm64/sbgemm_kernel_8x8_neoversen2_impl.c | 763 ++++++++++++++++++ param.h | 6 +- 5 files changed, 833 insertions(+), 7 deletions(-) create mode 100644 kernel/arm64/sbgemm_kernel_8x8_neoversen2.c create mode 100644 kernel/arm64/sbgemm_kernel_8x8_neoversen2_impl.c diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 3dc3d73bbb..d1424b2777 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -272,3 +272,6 @@ In chronological order: * Anna Mayne * [2025-11-19] Update thread throttling profile for SGEMV on NEOVERSEV1 and NEOVERSEV2 + +* Fadi Arafeh + * [2026-03-05] Accelerate SVE128 SBGEMM/BGEMM diff --git a/kernel/arm64/KERNEL.NEOVERSEN2 b/kernel/arm64/KERNEL.NEOVERSEN2 index 6431422faa..8269812347 100644 --- a/kernel/arm64/KERNEL.NEOVERSEN2 +++ b/kernel/arm64/KERNEL.NEOVERSEN2 @@ -191,12 +191,14 @@ ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX) ifeq ($(BUILD_BFLOAT16), 1) BGEMM_BETA = bgemm_beta_neon.c BGEMMKERNEL = sbgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversen2.c +ifneq ($(BGEMM_UNROLL_M), $(BGEMM_UNROLL_N)) BGEMMINCOPY = 
sbgemm_ncopy_$(BGEMM_UNROLL_M)_neoversen2.c BGEMMITCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_M)_neoversen2.c -BGEMMONCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_N)_neoversen2.c -BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversen2.c BGEMMINCOPYOBJ = bgemm_incopy$(TSUFFIX).$(SUFFIX) BGEMMITCOPYOBJ = bgemm_itcopy$(TSUFFIX).$(SUFFIX) +endif +BGEMMONCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_N)_neoversen2.c +BGEMMOTCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_N)_neoversen2.c BGEMMONCOPYOBJ = bgemm_oncopy$(TSUFFIX).$(SUFFIX) BGEMMOTCOPYOBJ = bgemm_otcopy$(TSUFFIX).$(SUFFIX) BGEMVTKERNEL = sbgemv_t_bfdot.c @@ -204,12 +206,14 @@ BGEMVNKERNEL = bgemv_n_sve_v3x4.c SBGEMM_BETA = sbgemm_beta_neoversen2.c SBGEMMKERNEL = sbgemm_kernel_$(SBGEMM_UNROLL_M)x$(SBGEMM_UNROLL_N)_neoversen2.c +ifneq ($(SBGEMM_UNROLL_M), $(SBGEMM_UNROLL_N)) SBGEMMINCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_M)_neoversen2.c SBGEMMITCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_M)_neoversen2.c -SBGEMMONCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_N)_neoversen2.c -SBGEMMOTCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_N)_neoversen2.c SBGEMMINCOPYOBJ = sbgemm_incopy$(TSUFFIX).$(SUFFIX) SBGEMMITCOPYOBJ = sbgemm_itcopy$(TSUFFIX).$(SUFFIX) +endif +SBGEMMONCOPY = sbgemm_ncopy_$(SBGEMM_UNROLL_N)_neoversen2.c +SBGEMMOTCOPY = sbgemm_tcopy_$(SBGEMM_UNROLL_N)_neoversen2.c SBGEMMONCOPYOBJ = sbgemm_oncopy$(TSUFFIX).$(SUFFIX) SBGEMMOTCOPYOBJ = sbgemm_otcopy$(TSUFFIX).$(SUFFIX) SBGEMVTKERNEL = sbgemv_t_bfdot.c diff --git a/kernel/arm64/sbgemm_kernel_8x8_neoversen2.c b/kernel/arm64/sbgemm_kernel_8x8_neoversen2.c new file mode 100644 index 0000000000..1af679a5a1 --- /dev/null +++ b/kernel/arm64/sbgemm_kernel_8x8_neoversen2.c @@ -0,0 +1,56 @@ +/*************************************************************************** + * Copyright (c) 2026 The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. 
Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ * *****************************************************************************/ + +#include <arm_sve.h> +#include <arm_neon.h> + +#include "common.h" + +#define ALPHA_ONE +#include "sbgemm_kernel_8x8_neoversen2_impl.c" +#undef ALPHA_ONE +#undef UPDATE_C +#include "sbgemm_kernel_8x8_neoversen2_impl.c" + +int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT *A, IFLOAT *B, + FLOAT *C, BLASLONG ldc) { +#ifdef BGEMM + bfloat16_t alpha_bf16; + memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); + float alpha_f32 = vcvtah_f32_bf16(alpha_bf16); +#else + float alpha_f32 = alpha; +#endif + + if (alpha_f32 == 1.0f) + return gemm_kernel_neoversen2_alpha_one(m, n, k, alpha, A, B, C, ldc); + else + return gemm_kernel_neoversen2_alpha(m, n, k, alpha, A, B, C, ldc); + + return 0; +} diff --git a/kernel/arm64/sbgemm_kernel_8x8_neoversen2_impl.c b/kernel/arm64/sbgemm_kernel_8x8_neoversen2_impl.c new file mode 100644 index 0000000000..ed230a5ac7 --- /dev/null +++ b/kernel/arm64/sbgemm_kernel_8x8_neoversen2_impl.c @@ -0,0 +1,763 @@ +/*************************************************************************** + * Copyright (c) 2022,2026 The OpenBLAS Project + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * 3. Neither the name of the OpenBLAS project nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. 
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * *****************************************************************************/ + +#include +#include + +#include "common.h" + +#define INIT_C(M, N) mc##M##N = svdup_f32(0); + +#define MATMUL(M, N) mc##M##N = svbfmmla(mc##M##N, ma##M, mb##N); + +#define INIT_C_8x4 \ + do { \ + INIT_C(0, 0); \ + INIT_C(0, 1); \ + INIT_C(1, 0); \ + INIT_C(1, 1); \ + INIT_C(2, 0); \ + INIT_C(2, 1); \ + INIT_C(3, 0); \ + INIT_C(3, 1); \ + } while (0); + +#define INIT_C_8x8 \ + do { \ + INIT_C(0, 0); \ + INIT_C(0, 1); \ + INIT_C(0, 2); \ + INIT_C(0, 3); \ + INIT_C(1, 0); \ + INIT_C(1, 1); \ + INIT_C(1, 2); \ + INIT_C(1, 3); \ + INIT_C(2, 0); \ + INIT_C(2, 1); \ + INIT_C(2, 2); \ + INIT_C(2, 3); \ + INIT_C(3, 0); \ + INIT_C(3, 1); \ + INIT_C(3, 2); \ + INIT_C(3, 3); \ + } while (0); + +#ifdef BGEMM +#ifdef ALPHA_ONE +#define UPDATE_C(PG16, PG32, PTR, SRC) \ + do { \ + tmp16 = svld1_bf16((PG16), (PTR)); \ + tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \ + tmp32 = svadd_z((PG32), SRC, tmp32); \ + tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \ + tmp16 = svuzp1_bf16(tmp16, tmp16); \ + svst1_bf16((PG16), (PTR), tmp16); \ + } while (0) +#else +#define UPDATE_C(PG16, PG32, PTR, SRC) \ + do { \ + tmp16 
= svld1_bf16((PG16), (PTR)); \ + tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \ + tmp32 = svmad_z((PG32), svalpha, SRC, tmp32); \ + tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \ + tmp16 = svuzp1_bf16(tmp16, tmp16); \ + svst1_bf16((PG16), (PTR), tmp16); \ + } while (0) +#endif +#else +#ifdef ALPHA_ONE +#define UPDATE_C(PG16, PG32, PTR, SRC) \ + do { \ + tmp32 = svld1_f32((PG32), (PTR)); \ + tmp32 = svadd_z((PG32), SRC, tmp32); \ + svst1_f32((PG32), (PTR), tmp32); \ + } while (0); +#else +#define UPDATE_C(PG16, PG32, PTR, SRC) \ + do { \ + tmp32 = svld1_f32((PG32), (PTR)); \ + tmp32 = svmad_z((PG32), svalpha, SRC, tmp32); \ + svst1_f32((PG32), (PTR), tmp32); \ + } while (0); + #endif + #endif + +#ifdef BGEMM +#define OUTPUT_FLOAT bfloat16_t +#else +#define OUTPUT_FLOAT float +#endif + +#ifdef ALPHA_ONE +static int gemm_kernel_neoversen2_alpha_one(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +#else +static int gemm_kernel_neoversen2_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT alpha, IFLOAT * A, IFLOAT * B, FLOAT * C, BLASLONG ldc) +#endif +{ + BLASLONG pad_k = (k + 3) & ~3; + + svbfloat16_t ma0, ma1, ma2, ma3, mb0, mb1, mb2, mb3; + svfloat32_t mc00, mc01, mc02, mc03; + svfloat32_t mc10, mc11, mc12, mc13; + svfloat32_t mc20, mc21, mc22, mc23; + svfloat32_t mc30, mc31, mc32, mc33; + svfloat32_t vc0, vc1, vc2, vc3, vc4, vc5, vc6, vc7; + svfloat32_t vc8, vc9, vc10, vc11, vc12, vc13, vc14, vc15; + +#ifndef ALPHA_ONE +#ifdef BGEMM + bfloat16_t alpha_bf16; + memcpy(&alpha_bf16, &alpha, sizeof(bfloat16_t)); + svfloat32_t svalpha = svdup_f32(vcvtah_f32_bf16(alpha_bf16)); +#else + svfloat32_t svalpha = svdup_f32(alpha); +#endif +#endif + + svbool_t pg32_first_4 = svdupq_b32(1, 1, 1, 1); + svbool_t pg32_first_2 = svdupq_b32(1, 1, 0, 0); + svbool_t pg32_first_1 = svdupq_b32(1, 0, 0, 0); + svbool_t pg16_first_8 = svdupq_b16(1, 1, 1, 1, 1, 1, 1, 1); + svbool_t pg16_first_4 = svdupq_b16(1, 1, 1, 1, 0, 0, 0, 0); +#ifdef 
BGEMM + svbool_t pg16_first_2 = svdupq_b16(1, 1, 0, 0, 0, 0, 0, 0); + svbool_t pg16_first_1 = svdupq_b16(1, 0, 0, 0, 0, 0, 0, 0); + svbfloat16_t zeros = svdup_n_bf16(vcvth_bf16_f32(0.0)); +#endif + + bfloat16_t *ptr_a = (bfloat16_t *)A; + bfloat16_t *ptr_b = (bfloat16_t *)B; + OUTPUT_FLOAT *ptr_c = (OUTPUT_FLOAT*)C; + + bfloat16_t *ptr_a0; + bfloat16_t *ptr_b0; + OUTPUT_FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3; + OUTPUT_FLOAT *ptr_c4, *ptr_c5, *ptr_c6, *ptr_c7; + + svfloat32_t tmp32; +#ifdef BGEMM + svbfloat16_t tmp16; +#endif + + for (BLASLONG j = 0; j < n / 8; j++) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c2 = ptr_c1 + ldc; + ptr_c3 = ptr_c2 + ldc; + ptr_c4 = ptr_c3 + ldc; + ptr_c5 = ptr_c4 + ldc; + ptr_c6 = ptr_c5 + ldc; + ptr_c7 = ptr_c6 + ldc; + ptr_c += 8 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C_8x8; + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + ma2 = svld1_bf16(pg16_first_8, ptr_a0 + 16); + ma3 = svld1_bf16(pg16_first_8, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + mb2 = svld1_bf16(pg16_first_8, ptr_b0 + 16); + mb3 = svld1_bf16(pg16_first_8, ptr_b0 + 24); + + MATMUL(0, 0); MATMUL(0, 1); MATMUL(0, 2); MATMUL(0, 3); + MATMUL(1, 0); MATMUL(1, 1); MATMUL(1, 2); MATMUL(1, 3); + MATMUL(2, 0); MATMUL(2, 1); MATMUL(2, 2); MATMUL(2, 3); + MATMUL(3, 0); MATMUL(3, 1); MATMUL(3, 2); MATMUL(3, 3); + + ptr_a0 += 32; + ptr_b0 += 32; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + vc2 = svuzp2(mc00, mc10); + vc3 = svuzp2(mc20, mc30); + vc4 = svuzp1(mc01, mc11); + vc5 = svuzp1(mc21, mc31); + vc6 = svuzp2(mc01, mc11); + vc7 = svuzp2(mc21, mc31); + vc8 = svuzp1(mc02, mc12); + vc9 = svuzp1(mc22, mc32); + vc10 = svuzp2(mc02, mc12); + vc11 = svuzp2(mc22, mc32); + vc12 = svuzp1(mc03, mc13); + vc13 = 
svuzp1(mc23, mc33); + vc14 = svuzp2(mc03, mc13); + vc15 = svuzp2(mc23, mc33); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0 + 4, vc1); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc2); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1 + 4, vc3); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, vc4); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2 + 4, vc5); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, vc6); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3 + 4, vc7); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c4, vc8); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c4 + 4, vc9); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c5, vc10); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c5 + 4, vc11); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c6, vc12); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c6 + 4, vc13); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c7, vc14); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c7 + 4, vc15); + + ptr_c0 += 8; + ptr_c1 += 8; + ptr_c2 += 8; + ptr_c3 += 8; + ptr_c4 += 8; + ptr_c5 += 8; + ptr_c6 += 8; + ptr_c7 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); INIT_C(0, 1); INIT_C(0, 2); INIT_C(0, 3); + INIT_C(1, 0); INIT_C(1, 1); INIT_C(1, 2); INIT_C(1, 3); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + mb2 = svld1_bf16(pg16_first_8, ptr_b0 + 16); + mb3 = svld1_bf16(pg16_first_8, ptr_b0 + 24); + + MATMUL(0, 0); MATMUL(0, 1); MATMUL(0, 2); MATMUL(0, 3); + MATMUL(1, 0); MATMUL(1, 1); MATMUL(1, 2); MATMUL(1, 3); + + ptr_a0 += 16; + ptr_b0 += 32; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp2(mc00, mc10); + vc2 = svuzp1(mc01, mc11); + vc3 = svuzp2(mc01, mc11); + vc4 = svuzp1(mc02, mc12); + vc5 = svuzp2(mc02, mc12); + vc6 = svuzp1(mc03, mc13); + vc7 = svuzp2(mc03, mc13); + + 
UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc1); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, vc2); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, vc3); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c4, vc4); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c5, vc5); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c6, vc6); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c7, vc7); + + ptr_c0 += 4; + ptr_c1 += 4; + ptr_c2 += 4; + ptr_c3 += 4; + ptr_c4 += 4; + ptr_c5 += 4; + ptr_c6 += 4; + ptr_c7 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); INIT_C(0, 1); INIT_C(0, 2); INIT_C(0, 3); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + mb2 = svld1_bf16(pg16_first_8, ptr_b0 + 16); + mb3 = svld1_bf16(pg16_first_8, ptr_b0 + 24); + + MATMUL(0, 0); MATMUL(0, 1); MATMUL(0, 2); MATMUL(0, 3); + + ptr_a0 += 8; + ptr_b0 += 32; + } + + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + vc2 = svuzp1(mc01, mc01); + vc3 = svuzp2(mc01, mc01); + vc4 = svuzp1(mc02, mc02); + vc5 = svuzp2(mc02, mc02); + vc6 = svuzp1(mc03, mc03); + vc7 = svuzp2(mc03, mc03); + + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, vc0); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, vc1); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c2, vc2); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c3, vc3); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c4, vc4); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c5, vc5); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c6, vc6); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c7, vc7); + + ptr_c0 += 2; + ptr_c1 += 2; + ptr_c2 += 2; + ptr_c3 += 2; + ptr_c4 += 2; + ptr_c5 += 2; + ptr_c6 += 2; + ptr_c7 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + + INIT_C(0, 0); INIT_C(0, 1); INIT_C(0, 2); INIT_C(0, 3); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = 
svld1_bf16(pg16_first_4, ptr_a0); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + mb2 = svld1_bf16(pg16_first_8, ptr_b0 + 16); + mb3 = svld1_bf16(pg16_first_8, ptr_b0 + 24); + + MATMUL(0, 0); MATMUL(0, 1); MATMUL(0, 2); MATMUL(0, 3); + + ptr_a0 += 4; + ptr_b0 += 32; + } + + vc1 = svuzp2(mc00, mc00); + vc3 = svuzp2(mc01, mc01); + vc5 = svuzp2(mc02, mc02); + vc7 = svuzp2(mc03, mc03); + + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, mc00); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, vc1); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c2, mc01); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c3, vc3); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c4, mc02); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c5, vc5); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c6, mc03); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c7, vc7); + } + + ptr_b += 8 * pad_k; + } + + if (n & 4) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c2 = ptr_c1 + ldc; + ptr_c3 = ptr_c2 + ldc; + ptr_c += 4 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C_8x4; + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + ma2 = svld1_bf16(pg16_first_8, ptr_a0 + 16); + ma3 = svld1_bf16(pg16_first_8, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + MATMUL(2, 0); MATMUL(2, 1); + MATMUL(3, 0); MATMUL(3, 1); + + ptr_a0 += 32; + ptr_b0 += 16; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + vc2 = svuzp2(mc00, mc10); + vc3 = svuzp2(mc20, mc30); + vc4 = svuzp1(mc01, mc11); + vc5 = svuzp1(mc21, mc31); + vc6 = svuzp2(mc01, mc11); + vc7 = svuzp2(mc21, mc31); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0+4, vc1); + 
UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc2); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1+4, vc3); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, vc4); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2+4, vc5); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, vc6); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3+4, vc7); + + ptr_c0 += 8; + ptr_c1 += 8; + ptr_c2 += 8; + ptr_c3 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); INIT_C(0, 1); + INIT_C(1, 0); INIT_C(1, 1); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + MATMUL(1, 0); MATMUL(1, 1); + + ptr_a0 += 16; + ptr_b0 += 16; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp2(mc00, mc10); + vc2 = svuzp1(mc01, mc11); + vc3 = svuzp2(mc01, mc11); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc1); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c2, vc2); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c3, vc3); + + ptr_c0 += 4; + ptr_c1 += 4; + ptr_c2 += 4; + ptr_c3 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); INIT_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + + ptr_a0 += 8; + ptr_b0 += 16; + } + + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + vc2 = svuzp1(mc01, mc01); + vc3 = svuzp2(mc01, mc01); + + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, vc0); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, vc1); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c2, vc2); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c3, vc3); + + ptr_c0 += 2; + ptr_c1 += 2; + ptr_c2 += 2; + ptr_c3 += 2; + } + + if (m & 1) 
{ + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + + INIT_C(0, 0); INIT_C(0, 1); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_4, ptr_a0); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + mb1 = svld1_bf16(pg16_first_8, ptr_b0 + 8); + + MATMUL(0, 0); MATMUL(0, 1); + + ptr_a0 += 4; + ptr_b0 += 16; + } + + vc1 = svuzp2(mc00, mc00); + vc3 = svuzp2(mc01, mc01); + + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, mc00); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, vc1); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c2, mc01); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c3, vc3); + + } + + ptr_b += 4 * pad_k; + } + + if (n & 2) { + ptr_c0 = ptr_c; + ptr_c1 = ptr_c0 + ldc; + ptr_c += 2 * ldc; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + INIT_C(2, 0); + INIT_C(3, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + ma2 = svld1_bf16(pg16_first_8, ptr_a0 + 16); + ma3 = svld1_bf16(pg16_first_8, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + MATMUL(2, 0); + MATMUL(3, 0); + + ptr_a0 += 32; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + vc2 = svuzp2(mc00, mc10); + vc3 = svuzp2(mc20, mc30); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0 + 4, vc1); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc2); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1 + 4, vc3); + + ptr_c0 += 8; + ptr_c1 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + MATMUL(0, 0); + MATMUL(1, 0); + ptr_a0 += 16; + ptr_b0 += 8; + } + + 
vc0 = svuzp1(mc00, mc10); + vc1 = svuzp2(mc00, mc10); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c1, vc1); + + ptr_c0 += 4; + ptr_c1 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 8; + ptr_b0 += 8; + } + + vc0 = svuzp1(mc00, mc00); + vc1 = svuzp2(mc00, mc00); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, vc0); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c1, vc1); + + ptr_c0 += 2; + ptr_c1 += 2; + + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + INIT_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_4, ptr_a0); + mb0 = svld1_bf16(pg16_first_8, ptr_b0); + MATMUL(0, 0); + ptr_a0 += 4; + ptr_b0 += 8; + } + vc1 = svuzp2(mc00, mc00); + + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, mc00); + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c1, vc1); + } + + ptr_b += 2 * pad_k; + } + + if (n & 1) { + ptr_c0 = ptr_c; + ptr_a = (bfloat16_t *)A; + + for (BLASLONG i = 0; i < m / 8; i++) { + ptr_a0 = ptr_a; + ptr_a += 8 * pad_k; + + ptr_b0 = ptr_b; + + INIT_C(0, 0); + INIT_C(1, 0); + INIT_C(2, 0); + INIT_C(3, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + ma2 = svld1_bf16(pg16_first_8, ptr_a0 + 16); + ma3 = svld1_bf16(pg16_first_8, ptr_a0 + 24); + + mb0 = svld1_bf16(pg16_first_4, ptr_b0); + + MATMUL(0, 0); + MATMUL(1, 0); + MATMUL(2, 0); + MATMUL(3, 0); + + ptr_a0 += 32; + ptr_b0 += 4; + } + + vc0 = svuzp1(mc00, mc10); + vc1 = svuzp1(mc20, mc30); + + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0 + 4, vc1); + + ptr_c0 += 8; + } + + if (m & 4) { + ptr_a0 = ptr_a; + ptr_a += 4 * pad_k; + ptr_b0 = ptr_b; + INIT_C(0, 0); + INIT_C(1, 0); + 
for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + ma1 = svld1_bf16(pg16_first_8, ptr_a0 + 8); + mb0 = svld1_bf16(pg16_first_4, ptr_b0); + MATMUL(0, 0); + MATMUL(1, 0); + ptr_a0 += 16; + ptr_b0 += 4; + } + vc0 = svuzp1(mc00, mc10); + UPDATE_C(pg16_first_4, pg32_first_4, ptr_c0, vc0); + ptr_c0 += 4; + } + + if (m & 2) { + ptr_a0 = ptr_a; + ptr_a += 2 * pad_k; + ptr_b0 = ptr_b; + + INIT_C(0, 0); + + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_8, ptr_a0); + mb0 = svld1_bf16(pg16_first_4, ptr_b0); + + MATMUL(0, 0); + + ptr_a0 += 8; + ptr_b0 += 4; + } + vc0 = svuzp1(mc00, mc00); + UPDATE_C(pg16_first_2, pg32_first_2, ptr_c0, vc0); + ptr_c0 += 2; + } + + if (m & 1) { + ptr_a0 = ptr_a; + ptr_b0 = ptr_b; + INIT_C(0, 0); + for (BLASLONG p = 0; p < pad_k; p += 4) { + ma0 = svld1_bf16(pg16_first_4, ptr_a0); + mb0 = svld1_bf16(pg16_first_4, ptr_b0); + MATMUL(0, 0); + ptr_a0 += 4; + ptr_b0 += 4; + } + UPDATE_C(pg16_first_1, pg32_first_1, ptr_c0, mc00); + } + } + + return 0; +} diff --git a/param.h b/param.h index 8112e8915a..b9cacb5920 100644 --- a/param.h +++ b/param.h @@ -1,5 +1,5 @@ /***************************************************************************** -Copyright (c) 2011-2023, 2025 The OpenBLAS Project +Copyright (c) 2011-2023, 2025-2026 The OpenBLAS Project All rights reserved. Redistribution and use in source and binary forms, with or without @@ -3673,14 +3673,14 @@ is a big desktop or server with abundant cache rather than a phone or embedded d #undef BGEMM_DEFAULT_UNROLL_N #define BGEMM_ALIGN_K 4 #define BGEMM_DEFAULT_UNROLL_M 8 -#define BGEMM_DEFAULT_UNROLL_N 4 +#define BGEMM_DEFAULT_UNROLL_N 8 #undef SBGEMM_ALIGN_K #undef SBGEMM_DEFAULT_UNROLL_M #undef SBGEMM_DEFAULT_UNROLL_N #define SBGEMM_ALIGN_K 4 #define SBGEMM_DEFAULT_UNROLL_M 8 -#define SBGEMM_DEFAULT_UNROLL_N 4 +#define SBGEMM_DEFAULT_UNROLL_N 8 #define SGEMM_DEFAULT_UNROLL_M 16 #define SGEMM_DEFAULT_UNROLL_N 4