From 09c18779bb865e8253341fdd0cd09f188b453e33 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Fri, 10 Oct 2025 01:40:28 -0700 Subject: [PATCH 1/7] Add test for SHGEMM --- test/Makefile | 12 ++- test/compare_sgemm_shgemm.c | 148 ++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 3 deletions(-) create mode 100644 test/compare_sgemm_shgemm.c diff --git a/test/Makefile b/test/Makefile index f29bd35471..62585b29c2 100644 --- a/test/Makefile +++ b/test/Makefile @@ -234,6 +234,9 @@ ifeq ($(BUILD_BFLOAT16),1) BF3= test_bgemm B3 = test_sbgemm endif +ifeq ($(BUILD_HFLOAT16),1) +H3 = test_shgemm +endif ifeq ($(BUILD_SINGLE),1) S3=sblat3 endif @@ -257,9 +260,9 @@ endif ifeq ($(SUPPORT_GEMM3M),1) -level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m +level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) level3_3m else -level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) +level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) endif ifneq ($(CROSS), 1) @@ -454,6 +457,9 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME) endif ifeq ($(BUILD_HFLOAT16),1) +test_shgemm : compare_sgemm_shgemm.c test_helpers.h ../$(LIBNAME) + $(CC) $(CLDFLAGS) -DIHFLOAT16 -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif @@ -475,7 +481,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemm test_shgemv sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c new file mode 100644 index 0000000000..c03e4e5b66 --- /dev/null 
+++ b/test/compare_sgemm_shgemm.c @@ -0,0 +1,148 @@ +/*************************************************************************** +Copyright (c) 2020,2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ +#include <stdio.h> +#include <stdint.h> +#define __USE_POSIX199309 +#include "../common.h" + +#include "test_helpers.h" + +#define SGEMM BLASFUNC(sgemm) +#define SHGEMM BLASFUNC(shgemm) +#define SGEMV BLASFUNC(sgemv) +#define SHGEMV BLASFUNC(shgemv) +#define SHGEMM_LARGEST 256 + +int +main (int argc, char *argv[]) +{ + blasint m, n, k; + int i, j, l; + blasint x, y; + int ret = 0; + int loop = SHGEMM_LARGEST; + char transA = 'N', transB = 'N'; + float alpha = 1.0, beta = 0.0; + + for (x = 0; x <= loop; x++) + { + if ((x > 100) && (x != SHGEMM_LARGEST)) continue; + m = k = n = x; + float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); + float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); + float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); + hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(bfloat16)); + hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(bfloat16)); + float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (DD == NULL) || (CC == NULL)) + return 1; + + for (j = 0; j < m; j++) + { + for (i = 0; i < k; i++) + { + A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + AA[j * k + i] = (hfloat16) A[j * k + i]; + } + } + for (j = 0; j < n; j++) + { + for (i = 0; i < k; i++) + { + B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + BB[j * k + i] = (hfloat16) A[j * k + i]; + } + } + for (y = 0; y < 4; y++) + { + if ((y == 0) || (y == 2)) { + transA = 'N'; + } else { + transA = 'T'; + } + if ((y == 0) || (y == 1)) { + transB = 'N'; + } else { + transB = 'T'; + } + + memset(CC, 0, m * n * sizeof(FLOAT)); + memset(DD, 0, m * n * sizeof(FLOAT)); + memset(C, 0, m * n * sizeof(FLOAT)); + + SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, + &m, B, &k, &beta, C, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, + &m, BB, &k, 
&beta, CC, &m); + + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + { + for (l = 0; l < k; l++) + if (transA == 'N' && transB == 'N') + { + DD[i * m + j] += + (float) AA[l * m + j] * (float)BB[l + k * i]; + } else if (transA == 'T' && transB == 'N') + { + DD[i * m + j] += + (float)AA[k * j + l] * (float)BB[l + k * i]; + } else if (transA == 'N' && transB == 'T') + { + DD[i * m + j] += + (float)AA[l * m + j] * (float)BB[i + l * n]; + } else if (transA == 'T' && transB == 'T') + { + DD[i * m + j] += + (float)AA[k * j + l] * (float)BB[i + l * n]; + } + if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { + ret++; + } + if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) { + ret++; + } + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(DD); + free(CC); + } + + if (ret != 0) { + fprintf(stderr, "SHGEMM FAILURES: %d\n", ret); + return 1; + } + + return ret; +} From fba2014239b745022be1590fe316d5ae63847b80 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Fri, 10 Oct 2025 14:36:33 +0200 Subject: [PATCH 2/7] remove spurious POSIX define --- test/compare_sgemm_shgemm.c | 1 - 1 file changed, 1 deletion(-) diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c index c03e4e5b66..163d95234c 100644 --- a/test/compare_sgemm_shgemm.c +++ b/test/compare_sgemm_shgemm.c @@ -26,7 +26,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*****************************************************************************/ #include <stdio.h> #include <stdint.h> -#define __USE_POSIX199309 #include "../common.h" #include "test_helpers.h" From a5fda2e2c300b52a9a6e27d190af5f6def0295b9 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Fri, 10 Oct 2025 23:38:43 +0200 Subject: [PATCH 3/7] fix missed bfloat/hfloat edit Co-authored-by: Christopher Sidebottom --- test/compare_sgemm_shgemm.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c index 163d95234c..c85c856946 100644 --- a/test/compare_sgemm_shgemm.c +++ b/test/compare_sgemm_shgemm.c @@ -54,8 +54,8 @@ main (int argc, char *argv[]) float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); - hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(bfloat16)); - hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(bfloat16)); + hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(hfloat16)); + hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(hfloat16)); float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || From 87470a3b184dd1fbc1682677d9a9b4a9f3a7d511 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Sun, 12 Oct 2025 14:21:33 -0700 Subject: [PATCH 4/7] remove unused definitions --- test/compare_sgemm_shgemm.c | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c index c85c856946..11b6f39f59 100644 --- a/test/compare_sgemm_shgemm.c +++ b/test/compare_sgemm_shgemm.c @@ -32,8 +32,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#define SGEMM BLASFUNC(sgemm) #define SHGEMM BLASFUNC(shgemm) -#define SGEMV BLASFUNC(sgemv) -#define SHGEMV BLASFUNC(shgemv) #define SHGEMM_LARGEST 256 int From 05adb52353e15d2b28addb6b4332322063aff059 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 15 Oct 2025 00:29:26 +0200 Subject: [PATCH 5/7] copypaste fix --- test/compare_sgemm_shgemm.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c index 11b6f39f59..d235ea44e0 100644 --- a/test/compare_sgemm_shgemm.c +++ b/test/compare_sgemm_shgemm.c @@ -73,7 +73,7 @@ main (int argc, char *argv[]) for (i = 0; i < k; i++) { B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - BB[j * k + i] = (hfloat16) A[j * k + i]; + BB[j * k + i] = (hfloat16) B[j * k + i]; } } for (y = 0; y < 4; y++) From 19be504cd04ccd9f76ac9b7eeb94d6a0442509d1 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Wed, 15 Oct 2025 14:00:58 -0700 Subject: [PATCH 6/7] Add tests varying alpha and beta --- test/compare_sgemm_shgemm.c | 107 +++++++++++++++++++++++++++++++++--- 1 file changed, 98 insertions(+), 9 deletions(-) diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c index d235ea44e0..7a97a06697 100644 --- a/test/compare_sgemm_shgemm.c +++ b/test/compare_sgemm_shgemm.c @@ -41,9 +41,11 @@ main (int argc, char *argv[]) int i, j, l; blasint x, y; int ret = 0; + int rret = 0; int loop = SHGEMM_LARGEST; char transA = 'N', transB = 'N'; float alpha = 1.0, beta = 0.0; + int xvals[6]={3,24,55,71,SHGEMM_LARGEST/2,SHGEMM_LARGEST}; for (x = 0; x <= loop; x++) { @@ -52,8 +54,8 @@ main (int argc, char *argv[]) float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); - hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(hfloat16)); - hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(hfloat16)); + _Float16 *AA = (_Float16 *)malloc_safe(m * k * 
sizeof(_Float16)); + _Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16)); float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || @@ -65,7 +67,7 @@ main (int argc, char *argv[]) for (i = 0; i < k; i++) { A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - AA[j * k + i] = (hfloat16) A[j * k + i]; + AA[j * k + i] = (_Float16) A[j * k + i]; } } for (j = 0; j < n; j++) @@ -73,7 +75,7 @@ main (int argc, char *argv[]) for (i = 0; i < k; i++) { B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - BB[j * k + i] = (hfloat16) B[j * k + i]; + BB[j * k + i] = (_Float16) B[j * k + i]; } } for (y = 0; y < 4; y++) @@ -95,8 +97,8 @@ main (int argc, char *argv[]) SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, &m, B, &k, &beta, C, &m); - SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA, - &m, BB, &k, &beta, CC, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA, + &m, (_Float16*)BB, &k, &beta, CC, &m); for (i = 0; i < n; i++) for (j = 0; j < m; j++) @@ -120,9 +122,11 @@ main (int argc, char *argv[]) (float)AA[k * j + l] * (float)BB[i + l * n]; } if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { + fprintf(stderr,"CC %f C %f \n",(float)CC[i*m+j],C[i*m+j]); ret++; } if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) { + fprintf(stderr,"CC %f DD %f \n",(float)CC[i*m+j],(float)DD[i*m+j]); ret++; } } @@ -135,11 +139,96 @@ main (int argc, char *argv[]) free(DD); free(CC); } - if (ret != 0) { - fprintf(stderr, "SHGEMM FAILURES: %d\n", ret); + fprintf(stderr, "SHGEMM FAILURES: %d!!!\n", ret); return 1; } - return ret; + + for (loop = 0; loop<6; loop++) { + x=xvals[loop]; + for (alpha=0.;alpha<=1.;alpha+=0.5) + { + for (beta = 0.0; beta <=1.; beta+=0.5) { + + m = k = n = x; + float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); + float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); + 
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); + _Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16)); + _Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16)); + float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (CC == NULL)) + return 1; + + for (j = 0; j < m; j++) + { + for (i = 0; i < k; i++) + { + A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + AA[j * k + i] = (_Float16) A[j * k + i]; + } + } + for (j = 0; j < n; j++) + { + for (i = 0; i < k; i++) + { + B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + BB[j * k + i] = (_Float16) B[j * k + i]; + } + } + + for (y = 0; y < 4; y++) + { + if ((y == 0) || (y == 2)) { + transA = 'N'; + } else { + transA = 'T'; + } + if ((y == 0) || (y == 1)) { + transB = 'N'; + } else { + transB = 'T'; + } + + memset(CC, 0, m * n * sizeof(FLOAT)); + memset(C, 0, m * n * sizeof(FLOAT)); + + SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, + &m, B, &k, &beta, C, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA, + &m, (_Float16*)BB, &k, &beta, CC, &m); + + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + { + if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { + ret++; + } + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(CC); + + if (ret != 0) { +/* + * fprintf(stderr, "SHGEMM FAILURES FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret); + */ + rret++; + ret=0; +/* } else { + fprintf(stderr, "SHGEMM SUCCEEDED FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret); +*/ + } + } + + } + } + if (rret > 0) return(1); + return(0); } From ee6aa89fb0d512e1148fe40f625433d169a81f99 Mon Sep 17 00:00:00 2001 From: Martin Kroeker Date: Thu, 16 Oct 2025 03:56:43 -0700 Subject: [PATCH 7/7] Add BFLOAT16 and HFLOAT16 tests --- test/CMakeLists.txt | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git 
a/test/CMakeLists.txt b/test/CMakeLists.txt index f874fa5eaa..e3491d7f11 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,6 +36,28 @@ foreach(test_bin ${OpenBLAS_Tests}) target_link_libraries(${test_bin} ${OpenBLAS_LIBNAME}) endforeach() +if (BUILD_BFLOAT16) + add_executable(test_bgemm compare_sgemm_bgemm.c) + target_compile_definitions(test_bgemm PUBLIC -DIBFLOAT16 -DOBFLOAT16) + target_link_libraries(test_bgemm ${OpenBLAS_LIBNAME}) + add_executable(test_bgemv compare_sgemv_bgemv.c) + target_compile_definitions(test_bgemv PUBLIC -DIBFLOAT16 -DOBFLOAT16) + target_link_libraries(test_bgemv ${OpenBLAS_LIBNAME}) + add_executable(test_sbgemm compare_sgemm_sbgemm.c) + target_compile_definitions(test_sbgemm PUBLIC -DIBFLOAT16) + target_link_libraries(test_sbgemm ${OpenBLAS_LIBNAME}) + add_executable(test_sbgemv compare_sgemv_sbgemv.c) + target_compile_definitions(test_sbgemv PUBLIC -DIBFLOAT16) + target_link_libraries(test_sbgemv ${OpenBLAS_LIBNAME}) +endif() + +if (BUILD_HFLOAT16) + add_executable(test_shgemm compare_sgemm_shgemm.c) + target_link_libraries(test_shgemm ${OpenBLAS_LIBNAME}) + add_executable(test_shgemv compare_sgemv_shgemv.c) + target_link_libraries(test_shgemv ${OpenBLAS_LIBNAME}) +endif() + # $1 exec, $2 input, $3 output_result if(WIN32) FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/test_helper.ps1 @@ -94,3 +116,21 @@ add_test(NAME "${float_type}blas3_3m" endif() endif() endforeach() + +if (BUILD_BFLOAT16) + add_test(NAME "bgemm" + COMMAND $<TARGET_FILE:test_bgemm>) + add_test(NAME "bgemv" + COMMAND $<TARGET_FILE:test_bgemv>) + add_test(NAME "sbgemm" + COMMAND $<TARGET_FILE:test_sbgemm>) + add_test(NAME "sbgemv" + COMMAND $<TARGET_FILE:test_sbgemv>) +endif() + +if (BUILD_HFLOAT16) + add_test(NAME "shgemm" + COMMAND $<TARGET_FILE:test_shgemm>) + add_test(NAME "shgemv" + COMMAND $<TARGET_FILE:test_shgemv>) +endif()