diff --git a/common_level3.h b/common_level3.h
index ba368bf7d5..39abe3016c 100644
--- a/common_level3.h
+++ b/common_level3.h
@@ -89,6 +89,27 @@ void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
                        float * A, BLASLONG strideA,
                        float * B, BLASLONG strideB);
 
+void ssyrk_direct_alpha_betaUN(BLASLONG N, BLASLONG K,
+                       float alpha,
+                       float * A, BLASLONG strideA,
+                       float beta,
+                       float * C, BLASLONG strideC);
+void ssyrk_direct_alpha_betaUT(BLASLONG N, BLASLONG K,
+                       float alpha,
+                       float * A, BLASLONG strideA,
+                       float beta,
+                       float * C, BLASLONG strideC);
+void ssyrk_direct_alpha_betaLN(BLASLONG N, BLASLONG K,
+                       float alpha,
+                       float * A, BLASLONG strideA,
+                       float beta,
+                       float * C, BLASLONG strideC);
+void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
+                       float alpha,
+                       float * A, BLASLONG strideA,
+                       float beta,
+                       float * C, BLASLONG strideC);
+
 int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
 
 int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
diff --git a/common_param.h b/common_param.h
index 8f75d0f890..06a04137a5 100644
--- a/common_param.h
+++ b/common_param.h
@@ -263,6 +263,10 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
   void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
   void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
   void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
+  void (*ssyrk_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
+  void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
+  void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
+  void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
 #endif
diff --git a/common_s.h b/common_s.h
index 8f55180559..df61125f6e 100644
--- a/common_s.h
+++ b/common_s.h
@@ -56,6 +56,10 @@
 #define STRMM_DIRECT_LNLN strmm_direct_LNLN
 #define STRMM_DIRECT_LTUN strmm_direct_LTUN
 #define STRMM_DIRECT_LTLN strmm_direct_LTLN
+#define SSYRK_DIRECT_ALPHA_BETA_UN ssyrk_direct_alpha_betaUN
+#define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
+#define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
+#define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT
 
 #define SGEMM_ONCOPY sgemm_oncopy
 #define SGEMM_OTCOPY sgemm_otcopy
@@ -232,6 +236,10 @@
 #define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
 #define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
 #define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
+#define SSYRK_DIRECT_ALPHA_BETA_UN gotoblas -> ssyrk_direct_alpha_betaUN
+#define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
+#define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
+#define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
 #endif
 
 #define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
diff --git a/interface/syrk.c b/interface/syrk.c
index 69f2328a44..9e493b00fe 100644
--- a/interface/syrk.c
+++ b/interface/syrk.c
@@ -338,6 +338,25 @@ double NNK;
     BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
     return;
   }
+#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
+#if defined(ARCH_ARM64) && (defined(USE_SSYRK_KERNEL_DIRECT) || defined(DYNAMIC_ARCH))
+#if defined(DYNAMIC_ARCH)
+  if (support_sme1())
+#endif
+  {
+    if (args.n == 0) return;
+    if (order == CblasRowMajor && n == ldc) {
+      if (Trans == CblasNoTrans && k == lda) {
+        (Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UN : SSYRK_DIRECT_ALPHA_BETA_LN)(n, k, alpha, a, lda, beta, c, ldc);
+        return;
+      } else if (Trans == CblasTrans && n == lda) {
+        (Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UT : SSYRK_DIRECT_ALPHA_BETA_LT)(n, k, alpha, a, lda, beta, c, ldc);
+        return;
+      }
+    }
+  }
+#endif
+#endif
 #endif
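Note that the braces around the guarded block are needed here: without them, under DYNAMIC_ARCH the `if (support_sme1())` would only guard the early return and the SME dispatch would run on non-SME cores. For context, the direct path above fires only for row-major calls with densely packed matrices: n == ldc, plus k == lda for the NoTrans case or n == lda for the Trans case. A minimal caller sketch (illustrative only, not part of the patch; assumes a build where the SME path is enabled):

```c
#include <cblas.h>

/* Hypothetical usage: row-major, C is n x n with ldc == n, A is n x k
 * with lda == k, so the CblasNoTrans conditions above hold and the SME
 * direct kernel computes the upper triangle of C = 2.0*A*A^T + 0.5*C. */
void example(const float *a, float *c, int n, int k) {
    cblas_ssyrk(CblasRowMajor, CblasUpper, CblasNoTrans,
                n, k, 2.0f, a, k, 0.5f, c, n);
}
```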
diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt
index 30f2af867a..985e4d9fa0 100644
--- a/kernel/CMakeLists.txt
+++ b/kernel/CMakeLists.txt
@@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
   if (ARM64)
     set(USE_DIRECT_STRMM true)
   endif()
+  set(USE_DIRECT_SSYRK false)
+  if (ARM64)
+    set(USE_DIRECT_SSYRK true)
+  endif()
   set(USE_DIRECT_SGEMM false)
   if (X86_64 OR ARM64)
     set(USE_DIRECT_SGEMM true)
   endif()
@@ -293,6 +297,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
     endif ()
   endif ()
 
+  if (USE_DIRECT_SSYRK)
+    if (ARM64)
+      set (SSYRKDIRECTKERNEL_ALPHA_BETA ssyrk_direct_alpha_beta_arm64_sme1.c)
+      GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "UPPER" "syrk_direct_alpha_betaUN" false "" "" false SINGLE)
+      GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "UPPER;TRANSA" "syrk_direct_alpha_betaUT" false "" "" false SINGLE)
+      GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLN" false "" "" false SINGLE)
+      GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "TRANSA" "syrk_direct_alpha_betaLT" false "" "" false SINGLE)
+    endif ()
+  endif()
+
   foreach (float_type SINGLE DOUBLE)
     string(SUBSTRING ${float_type} 0 1 float_char)
     GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3
index 8fe5051839..6df9d78b1a 100644
--- a/kernel/Makefile.L3
+++ b/kernel/Makefile.L3
@@ -54,6 +54,7 @@ USE_TRMM = 1
 USE_DIRECT_SGEMM = 1
 USE_DIRECT_SSYMM = 1
 USE_DIRECT_STRMM = 1
+USE_DIRECT_SSYRK = 1
 endif
 
 ifeq ($(ARCH), riscv64)
@@ -161,6 +162,16 @@ endif
 endif
 endif
 
+ifdef USE_DIRECT_SSYRK
+ifndef SSYRKDIRECTKERNEL_ALPHA_BETA
+ifeq ($(ARCH), arm64)
+ifeq ($(TARGET_CORE), ARMV9SME)
+HAVE_SME = 1
+endif
+SSYRKDIRECTKERNEL_ALPHA_BETA = ssyrk_direct_alpha_beta_arm64_sme1.c
+endif
+endif
+endif
 
 ifeq ($(BUILD_BFLOAT16), 1)
 ifndef BGEMMKERNEL
@@ -261,6 +272,14 @@ SKERNELOBJS += \
 endif
 endif
 
+ifdef USE_DIRECT_SSYRK
+ifeq ($(ARCH), arm64)
+SKERNELOBJS += \
+	ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
+	ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
+endif
+endif
+
 ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
 DKERNELOBJS += \
 	dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1158,6 +1177,21 @@ $(KDIR)xgemm_kernel_r$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMM
 $(KDIR)xgemm_kernel_b$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMDEPEND)
 	$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX -DCC $< -o $@
 
+ifdef USE_DIRECT_SSYRK
+ifeq ($(ARCH), arm64)
+$(KDIR)ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@
+
+$(KDIR)ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@
+
+$(KDIR)ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@
+
+$(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
+	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@
+endif
+endif
 
 ifdef USE_TRMM
 $(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)
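The four objects are compiled from one source file and differ only in their -DUPPER / -DTRANSA combination (the CMake path above now passes the matching defines through GenerateNamedObjects). As a plain-C statement of what each variant is expected to compute, here is a hedged reference sketch (ssyrk_ref is illustrative, not part of the patch): NoTrans computes C = alpha*A*A^T + beta*C, Trans computes C = alpha*A^T*A + beta*C, both row-major and touching only the stored triangle.

```c
/* Scalar reference for the four UN/UT/LN/LT variants; useful for
 * cross-checking the SME kernels. Illustrative only. */
static void ssyrk_ref(int uplo_upper, int transa, int n, int k,
                      float alpha, const float *A, int lda,
                      float beta, float *C, int ldc) {
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++) {
            if (uplo_upper ? (j < i) : (j > i)) continue; /* outside stored triangle */
            float acc = 0.0f;
            for (int p = 0; p < k; p++) {
                /* NoTrans: A is n x k with lda == k; Trans: A is k x n with lda == n. */
                float ai = transa ? A[p * lda + i] : A[i * lda + p];
                float aj = transa ? A[p * lda + j] : A[j * lda + p];
                acc += ai * aj;
            }
            C[i * ldc + j] = alpha * acc + beta * C[i * ldc + j];
        }
}
```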
diff --git a/kernel/arm64/ssyrk_direct_alpha_beta_arm64_sme1.c b/kernel/arm64/ssyrk_direct_alpha_beta_arm64_sme1.c
new file mode 100644
index 0000000000..f137767aaa
--- /dev/null
+++ b/kernel/arm64/ssyrk_direct_alpha_beta_arm64_sme1.c
@@ -0,0 +1,250 @@
+/*
+   Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+#include "common.h"
+#include <stdlib.h>
+#include <stdbool.h>
+#include <math.h>
+#if defined(HAVE_SME)
+
+#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
+#include <arm_sme.h>
+#endif
+
+/* Function prototypes */
+extern void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
+                    const float * restrict a, float * a_mod) __asm__("sgemm_direct_sme1_preprocess");
+
+/* Function Definitions */
+static uint64_t sve_cntw() {
+  uint64_t cnt;
+  asm volatile(
+    "rdsvl  %[res], #1\n"
+    "lsr    %[res], %[res], #2\n"
+    : [res] "=r" (cnt) ::
+  );
+  return cnt;
+}
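+
+/* Note (explanatory, with illustrative numbers): RDSVL reads the
+ * streaming vector length in bytes even when the core is not in
+ * streaming mode, and the right shift by 2 converts bytes to FP32
+ * lane count, so sve_cntw() here agrees with svcntw() inside the
+ * streaming functions below. With a 512-bit streaming VL this is 16
+ * lanes, each ZA tile holds 16x16 FP32 accumulators, and one
+ * kernel_2x2 call covers a 2*SVL x 2*SVL = 32x32 block of C.
+ */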
+
+#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
+// Outer product kernel.
+// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
+__attribute__((always_inline)) inline void
+kernel_2x2(const float *A, float *B, float *C, size_t shared_dim,
+           size_t ldc, size_t block_rows, size_t block_cols, float alpha,
+           float beta, uint64_t row_idx, uint64_t col_idx)
+    __arm_out("za") __arm_streaming {
+
+  const uint64_t svl = svcntw();
+  size_t ldb = ldc;
+  // Predicate set-up
+  svbool_t pg = svptrue_b32();
+  svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
+  svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
+
+  svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
+  svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
+
+#define pg_c_0 pg_b_0
+#define pg_c_1 pg_b_1
+
+  svzero_za();
+  svfloat32_t beta_vec = svdup_f32(beta);
+
+  // Load C to ZA, scaling by beta on the way in
+  for (size_t i = 0; i < MIN(svl, block_rows); i++) {
+    svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
+    row_c_0 = svmul_x(pg, beta_vec, row_c_0);
+    svwrite_hor_za32_f32_m(/*tile*/0, /*slice*/i, pg_c_0, row_c_0);
+
+    svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
+    row_c_1 = svmul_x(pg, beta_vec, row_c_1);
+    svwrite_hor_za32_f32_m(/*tile*/1, /*slice*/i, pg_c_1, row_c_1);
+  }
+  for (size_t i = svl; i < block_rows; i++) {
+    svfloat32_t row_c_0 = svld1(pg_c_0, &C[i * ldc]);
+    row_c_0 = svmul_x(pg, beta_vec, row_c_0);
+    svwrite_hor_za32_f32_m(/*tile*/2, /*slice*/i, pg_c_0, row_c_0);
+
+    svfloat32_t row_c_1 = svld1(pg_c_1, &C[i * ldc + svl]);
+    row_c_1 = svmul_x(pg, beta_vec, row_c_1);
+    svwrite_hor_za32_f32_m(/*tile*/3, /*slice*/i, pg_c_1, row_c_1);
+  }
+
+  svfloat32_t alpha_vec = svdup_f32(alpha);
+  // Iterate through shared dimension (K)
+  for (size_t k = 0; k < shared_dim; k++) {
+#if !defined(TRANSA)
+    // Load column of A
+    svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
+    col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
+    svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
+    col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
+
+    // Load row of A**T
+    svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * svl]);
+    svfloat32_t row_b_1 = svld1(pg_b_1, &B[(k + shared_dim) * svl]);
+#else
+    // Load column of A**T
+    svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * ldb]);
+    col_a_0 = svmul_x(pg, alpha_vec, col_a_0);
+
+    svfloat32_t col_a_1 = svld1(pg_a_1, &A[k * ldb + svl]);
+    col_a_1 = svmul_x(pg, alpha_vec, col_a_1);
+
+    // Load row of A
+    svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]);
+    svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]);
+#endif
+    // Perform outer product
+    svmopa_za32_m(/*tile*/0, pg, pg, col_a_0, row_b_0);
+    svmopa_za32_m(/*tile*/1, pg, pg, col_a_0, row_b_1);
+    svmopa_za32_m(/*tile*/2, pg, pg, col_a_1, row_b_0);
+    svmopa_za32_m(/*tile*/3, pg, pg, col_a_1, row_b_1);
+  }
+
+#if defined(UPPER)
+#define pg_c_0_full pg_c_0
+#define pg_c_1_full pg_c_1
+
+  bool need_update_pg_b = true;
+  size_t last_invalid_index = col_idx - row_idx;
+  // For UPPER: if col_idx - row_idx >= 2*svl, the whole block lies above
+  // the diagonal, so the store predicates never need updating.
+  if (col_idx - row_idx >= 2*svl) {
+    need_update_pg_b = false;
+  }
+  // Store to C from ZA
+  for (size_t i = 0; i < MIN(svl, block_rows); i++, last_invalid_index++) {
+    if (need_update_pg_b) {
+      pg_c_0 = svnot_b_z(pg_c_0_full, svwhilelt_b32_u64(0, last_invalid_index));
+      pg_c_1 = svnot_b_z(pg_c_1_full, svwhilelt_b32_u64(svl, last_invalid_index));
+    }
+
+    svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
+    svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+  }
+  for (size_t i = svl; i < block_rows; i++, last_invalid_index++) {
+    if (need_update_pg_b) {
+      pg_c_0 = svnot_b_z(pg_c_0_full, svwhilelt_b32_u64(0, last_invalid_index));
+      pg_c_1 = svnot_b_z(pg_c_1_full, svwhilelt_b32_u64(svl, last_invalid_index));
+    }
+    svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
+    svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+  }
+#else
+  // Store to C from ZA
+  size_t valid_index = row_idx - col_idx + 1;
+  for (size_t i = 0; i < MIN(svl, block_rows); i++, valid_index++) {
+    pg_c_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_cols));
+    pg_c_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_cols));
+    svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
+    svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+  }
+  for (size_t i = svl; i < block_rows; i++, valid_index++) {
+    pg_c_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_cols));
+    pg_c_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_cols));
+    svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
+    svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
+  }
+#endif
+}
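+
+/* Note on the predicated stores in kernel_2x2 above: the K-loop always
+ * accumulates the full 2SVL x 2SVL block in ZA; only the stores are
+ * masked to the stored triangle. In scalar terms, block element (i, j)
+ * is written iff its global indices satisfy
+ *   col_idx + j >= row_idx + i   (UPPER)
+ *   col_idx + j <= row_idx + i   (LOWER)
+ * Blocks lying strictly inside the triangle keep full predicates,
+ * which is what the need_update_pg_b early-out (UPPER) and the
+ * MIN(valid_index, block_cols) clamp (LOWER) achieve.
+ */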
+
+__arm_new("za") __arm_locally_streaming
+static void ssyrk_direct_sme1_2VLx2VL(uint64_t n, uint64_t k, const float* alpha,\
+                    const float *ba, const float* beta, float *restrict bc) {
+  const uint64_t num_rows = n;
+  const uint64_t num_cols = n;
+
+  const float *restrict a_ptr = ba;
+  const float *restrict b_ptr = ba;
+  float *restrict c_ptr = bc;
+
+  const uint64_t svl = svcntw();
+  const uint64_t ldc = n;
+
+  // Block over rows of C (panels of A)
+  uint64_t row_idx = 0;
+
+  // 2x2 loop
+  uint64_t row_batch = 2*svl;
+
+  // Block over row dimension of C
+  for (; row_idx < num_rows; row_idx += row_batch) {
+    row_batch = MIN(row_batch, num_rows - row_idx);
+    uint64_t col_batch = 2*svl;
+#if defined(UPPER)
+    // When UPLO is upper, start from col_idx = row_idx so that only blocks
+    // in the upper triangle (col_idx >= row_idx) are processed
+    for (uint64_t col_idx = row_idx; col_idx < num_cols; col_idx += col_batch) {
+#else
+    // When UPLO is lower, only blocks in the lower triangle
+    // (col_idx <= row_idx) are processed
+    for (uint64_t col_idx = 0; col_idx <= row_idx; col_idx += col_batch) {
+#endif
+      col_batch = MIN(col_batch, num_cols - col_idx);
+#if !defined(TRANSA)
+      kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx * k],
+                 &c_ptr[row_idx * ldc + col_idx], k,
+                 ldc, row_batch, col_batch, *alpha, *beta, row_idx, col_idx);
+#else
+      kernel_2x2(&a_ptr[row_idx], &b_ptr[col_idx],
+                 &c_ptr[row_idx * ldc + col_idx], k,
+                 ldc, row_batch, col_batch, *alpha, *beta, row_idx, col_idx);
+#endif
+    }
+  }
+  return;
+}
+
+#else
+static void ssyrk_direct_sme1_2VLx2VL(uint64_t n, uint64_t k, const float* alpha,\
+                    const float *ba, const float* beta, float *restrict bc) {}
+#endif
+
+void CNAME (BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
+            BLASLONG strideA, float beta, float * __restrict C, BLASLONG strideC) {
+#if !defined(TRANSA)
+  uint64_t n_mod, vl_elms;
+
+  vl_elms = sve_cntw();
+
+  // Round N up to a multiple of the vector length for the packed copy of A
+  n_mod = ceil((double)N/(double)vl_elms) * vl_elms;
+
+  float *A_mod = (float *) malloc(n_mod*K*sizeof(float));
+
+  /* Prevent compiler optimization by reading from memory instead
+   * of reading directly from vector (z) registers.
+   */
+  asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
+               "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
+               "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
+               "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
+               "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
+               "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
+
+  /* Pre-process the left matrix to make it suitable for
+     matrix sum of outer-product calculation
+  */
+  sgemm_direct_sme1_preprocess(N, K, A, A_mod);
+  asm volatile("" : : :"p0", "p1", "p2", "p3", "p4", "p5", "p6", "p7",
+               "p8", "p9", "p10", "p11", "p12", "p13", "p14", "p15",
+               "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7",
+               "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15",
+               "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23",
+               "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
+
+  ssyrk_direct_sme1_2VLx2VL(N, K, &alpha, A_mod, &beta, C);
+  free(A_mod);
+#else
+  ssyrk_direct_sme1_2VLx2VL(N, K, &alpha, A, &beta, C);
+#endif
+}
+
+#else
+
+void CNAME (BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
+            BLASLONG strideA, float beta, float * __restrict C, BLASLONG strideC) {}
+
+#endif
diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c
index 1164523ff5..32a1384104 100644
--- a/kernel/setparam-ref.c
+++ b/kernel/setparam-ref.c
@@ -222,6 +222,10 @@ gotoblas_t TABLE_NAME = {
   strmm_direct_LNLNTS,
   strmm_direct_LTUNTS,
   strmm_direct_LTLNTS,
+  ssyrk_direct_alpha_betaUNTS,
+  ssyrk_direct_alpha_betaUTTS,
+  ssyrk_direct_alpha_betaLNTS,
+  ssyrk_direct_alpha_betaLTTS,
 #endif
 
   sgemm_kernelTS,
   sgemm_betaTS,
diff --git a/param.h b/param.h
index c912c6eda1..e030647d1a 100644
--- a/param.h
+++ b/param.h
@@ -3846,6 +3846,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout
 #define USE_SGEMM_KERNEL_DIRECT 1
 #define USE_SSYMM_KERNEL_DIRECT 1
 #define USE_STRMM_KERNEL_DIRECT 1
+#define USE_SSYRK_KERNEL_DIRECT 1
 #endif /* ARMv9 SME */
 
 #if defined(ARMV5)
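Finally, a hypothetical cross-check for the new kernels, assuming a non-DYNAMIC_ARCH ARMV9SME build that exports the UN symbol directly (BLASLONG is taken as long here purely for illustration; the inline loop mirrors the ssyrk_ref sketch above):

```c
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/* Assumed kernel symbol; signature mirrors the common_level3.h prototype
 * with BLASLONG approximated as long. */
void ssyrk_direct_alpha_betaUN(long n, long k, float alpha, float *a,
                               long lda, float beta, float *c, long ldc);

int main(void) {
    long n = 33, k = 17;   /* deliberately not a multiple of the SVL */
    float *a = malloc((size_t)(n * k) * sizeof(float));
    float *c = malloc((size_t)(n * n) * sizeof(float));
    float *c_ref = malloc((size_t)(n * n) * sizeof(float));
    for (long i = 0; i < n * k; i++) a[i] = (float)rand() / (float)RAND_MAX;
    for (long i = 0; i < n * n; i++) c[i] = c_ref[i] = (float)rand() / (float)RAND_MAX;

    /* Direct kernel: upper triangle of C = 1.5*A*A^T + 0.5*C, row-major. */
    ssyrk_direct_alpha_betaUN(n, k, 1.5f, a, k, 0.5f, c, n);

    /* Scalar reference over the same triangle. */
    for (long i = 0; i < n; i++)
        for (long j = i; j < n; j++) {
            float acc = 0.0f;
            for (long p = 0; p < k; p++) acc += a[i * k + p] * a[j * k + p];
            c_ref[i * n + j] = 1.5f * acc + 0.5f * c_ref[i * n + j];
        }

    double max_err = 0.0;
    for (long i = 0; i < n; i++)
        for (long j = i; j < n; j++)
            max_err = fmax(max_err, fabs((double)c[i * n + j] - (double)c_ref[i * n + j]));
    printf("max abs error (upper triangle): %g\n", max_err);

    free(a); free(c); free(c_ref);
    return 0;
}
```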