Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions common_level3.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
float * A, BLASLONG strideA,
float * B, BLASLONG strideB);

/*
 * Direct single-precision SYRK kernels taking both alpha and beta.
 * The two-letter suffix selects the variant: the first letter is the
 * triangle (U = upper, L = lower), the second is the transpose mode of A
 * (N = no transpose, T = transpose). All four share the same signature:
 * dimensions N and K, scalar alpha, matrix A with its stride, scalar beta,
 * and matrix C with its stride.
 */
void ssyrk_direct_alpha_betaUN(BLASLONG N, BLASLONG K, float alpha,
                               float * A, BLASLONG strideA,
                               float beta,
                               float * C, BLASLONG strideC);

void ssyrk_direct_alpha_betaUT(BLASLONG N, BLASLONG K, float alpha,
                               float * A, BLASLONG strideA,
                               float beta,
                               float * C, BLASLONG strideC);

void ssyrk_direct_alpha_betaLN(BLASLONG N, BLASLONG K, float alpha,
                               float * A, BLASLONG strideA,
                               float beta,
                               float * C, BLASLONG strideC);

void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K, float alpha,
                               float * A, BLASLONG strideA,
                               float beta,
                               float * C, BLASLONG strideC);

int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);

int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,
Expand Down
4 changes: 4 additions & 0 deletions common_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
/* Function-pointer slots for the four direct SSYRK alpha/beta kernels
 * (UN, UT, LN, LT). Argument order: two BLASLONG dimensions, float scalar,
 * float matrix pointer + BLASLONG stride, float scalar, float matrix
 * pointer + BLASLONG stride. */
void (*ssyrk_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
#endif


Expand Down
8 changes: 8 additions & 0 deletions common_s.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
/* Direct SSYRK alpha/beta kernels: in this branch each macro names the
 * kernel function itself. */
#define SSYRK_DIRECT_ALPHA_BETA_UN ssyrk_direct_alpha_betaUN
#define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
#define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
#define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT

#define SGEMM_ONCOPY sgemm_oncopy
#define SGEMM_OTCOPY sgemm_otcopy
Expand Down Expand Up @@ -232,6 +236,10 @@
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
/* Direct SSYRK alpha/beta kernels: in this branch each macro resolves to
 * the matching function-pointer member reached through `gotoblas`. */
#define SSYRK_DIRECT_ALPHA_BETA_UN gotoblas -> ssyrk_direct_alpha_betaUN
#define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
#define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
#define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
#endif

#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
Expand Down
17 changes: 17 additions & 0 deletions interface/syrk.c
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,23 @@ double NNK;
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
return;
}
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
#if defined(ARCH_ARM64) && (defined(USE_SSYRK_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
/*
 * Single-precision real fast path: hand the call straight to a direct SSYRK
 * kernel when the operands are row-major and their leading dimensions match
 * the dimensions tested below (ldc == n, and lda == k for NoTrans or
 * lda == n for Trans).
 *
 * In DYNAMIC_ARCH builds the entire fast path must sit under the
 * support_sme1() check. It is therefore wrapped in one braced block: an
 * unbraced `if` would only govern the first statement after it (the n == 0
 * early return) and the direct kernels would be dispatched regardless of
 * the check's result.
 */
#if defined(DYNAMIC_ARCH)
if (support_sme1())
#endif
{
  if (args.n == 0) return;
  if (order == CblasRowMajor && n == ldc) {
    if (Trans == CblasNoTrans && k == lda) {
      /* Uplo picks the upper- or lower-triangle kernel. */
      (Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UN : SSYRK_DIRECT_ALPHA_BETA_LN)(n, k, alpha, a, lda, beta, c, ldc);
      return;
    } else if (Trans == CblasTrans && n == lda){
      (Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UT : SSYRK_DIRECT_ALPHA_BETA_LT)(n, k, alpha, a, lda, beta, c, ldc);
      return;
    }
  }
  /* Otherwise fall through to the code that follows this block. */
}
#endif
#endif

#endif

Expand Down
14 changes: 14 additions & 0 deletions kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
if (ARM64)
set(USE_DIRECT_STRMM true)
endif()
# The direct SSYRK kernels are enabled only when ARM64 is set.
if (ARM64)
  set(USE_DIRECT_SSYRK true)
else ()
  set(USE_DIRECT_SSYRK false)
endif ()
set(USE_DIRECT_SGEMM false)
if (X86_64 OR ARM64)
set(USE_DIRECT_SGEMM true)
Expand Down Expand Up @@ -293,6 +297,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
endif ()
endif ()

if (USE_DIRECT_SSYRK)
  if (ARM64)
    set (SSYRKDIRECTKERNEL_ALPHA_BETA ssyrk_direct_alpha_beta_arm64_sme1.c)
    # All four variants are compiled from the same source file, so each object
    # needs its own preprocessor defines to select the triangle and transpose
    # mode, mirroring the -DUPPER / -DTRANSA flags used by kernel/Makefile.L3:
    #   UN -> UPPER          UT -> UPPER;TRANSA
    #   LN -> (none)         LT -> TRANSA
    GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "UPPER" "syrk_direct_alpha_betaUN" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "UPPER;TRANSA" "syrk_direct_alpha_betaUT" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLN" false "" "" false SINGLE)
    GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "TRANSA" "syrk_direct_alpha_betaLT" false "" "" false SINGLE)
  endif ()
endif()

foreach (float_type SINGLE DOUBLE)
string(SUBSTRING ${float_type} 0 1 float_char)
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})
Expand Down
34 changes: 34 additions & 0 deletions kernel/Makefile.L3
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ USE_TRMM = 1
USE_DIRECT_SGEMM = 1
USE_DIRECT_SSYMM = 1
USE_DIRECT_STRMM = 1
USE_DIRECT_SSYRK = 1
endif

ifeq ($(ARCH), riscv64)
Expand Down Expand Up @@ -161,6 +162,16 @@ endif
endif
endif

# Choose the default source file for the direct SSYRK alpha/beta kernels.
# Only applies when USE_DIRECT_SSYRK is defined and
# SSYRKDIRECTKERNEL_ALPHA_BETA has not already been set.
ifdef USE_DIRECT_SSYRK
ifndef SSYRKDIRECTKERNEL_ALPHA_BETA
ifeq ($(ARCH), arm64)
# ARMV9SME targets additionally set HAVE_SME.
ifeq ($(TARGET_CORE), ARMV9SME)
HAVE_SME = 1
endif
SSYRKDIRECTKERNEL_ALPHA_BETA = ssyrk_direct_alpha_beta_arm64_sme1.c
endif
endif
endif

ifeq ($(BUILD_BFLOAT16), 1)
ifndef BGEMMKERNEL
Expand Down Expand Up @@ -261,6 +272,14 @@ SKERNELOBJS += \
endif
endif

# On arm64 with USE_DIRECT_SSYRK defined, add the four direct SSYRK
# alpha/beta kernel objects (UN, UT, LN, LT) to SKERNELOBJS.
ifdef USE_DIRECT_SSYRK
ifeq ($(ARCH), arm64)
SKERNELOBJS += \
ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
endif
endif

ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
DKERNELOBJS += \
dgemm_beta$(TSUFFIX).$(SUFFIX) \
Expand Down Expand Up @@ -1158,6 +1177,21 @@ $(KDIR)xgemm_kernel_r$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMD
$(KDIR)xgemm_kernel_b$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMDEPEND)
$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX -DCC $< -o $@

# Compile rules for the four direct SSYRK alpha/beta kernel objects.
# Every rule builds the same source, $(SSYRKDIRECTKERNEL_ALPHA_BETA); the
# variants differ only in whether UPPER and TRANSA are defined:
#   UN: -DUPPER -UTRANSA    UT: -DUPPER -DTRANSA
#   LN: -UUPPER -UTRANSA    LT: -UUPPER -DTRANSA
ifdef USE_DIRECT_SSYRK
ifeq ($(ARCH), arm64)
$(KDIR)ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@

$(KDIR)ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@

$(KDIR)ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@

$(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@
endif
endif

ifdef USE_TRMM
$(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)
Expand Down
Loading
Loading