From 541ce2af6bda0bb21393bdee3fed1e70f9ce40f1 Mon Sep 17 00:00:00 2001 From: Tyge Løvset Date: Wed, 16 Aug 2023 17:15:37 +0200 Subject: Added recursive matrix multiplication example for cspan. --- include/c11/fmt.h | 6 +-- misc/examples/spans/matmult.c | 90 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 misc/examples/spans/matmult.c diff --git a/include/c11/fmt.h b/include/c11/fmt.h index d7c10cbe..df96bae3 100644 --- a/include/c11/fmt.h +++ b/include/c11/fmt.h @@ -25,7 +25,7 @@ void fmt_close(fmt_stream* ss); * C11 or higher required. * MAX 255 chars fmt string by default. MAX 12 arguments after fmt string. -* Define FMT_IMPLEMENT or i_implement prior to #include in one translation unit. +* Define FMT_IMPLEMENT, STC_IMPLEMENT or i_implement prior to #include in one translation unit. * Define FMT_SHORTS to add print(), println() and printd() macros, without fmt_ prefix. * (c) operamint, 2022, MIT License. ----------------------------------------------------------------------------------- @@ -84,7 +84,7 @@ int main(void) { #define _fmt_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \ _14, _15, _16, N, ...) N -#if defined FMT_NDEBUG || defined NDEBUG +#if defined FMT_NDEBUG || defined STC_NDEBUG || defined NDEBUG # define fmt_OK(exp) (void)(exp) #else # define fmt_OK(exp) assert(exp) @@ -196,7 +196,7 @@ void _fmt_bprint(fmt_stream*, const char* fmt, ...); const wchar_t*: "ls", \ const void*: "p") -#if defined FMT_IMPLEMENT || defined i_implement +#if defined FMT_IMPLEMENT || defined STC_IMPLEMENT || defined i_implement #include #include diff --git a/misc/examples/spans/matmult.c b/misc/examples/spans/matmult.c new file mode 100644 index 00000000..62c0c26b --- /dev/null +++ b/misc/examples/spans/matmult.c @@ -0,0 +1,90 @@ +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p2642r2.html +// C99: +#include +#include +#include + +using_cspan3(Mat, double); +typedef Mat2 OutMat; +typedef struct { Mat2 m00, m01, m10, m11; } Partition; + +Partition partition(Mat2 A) +{ + int32_t M = A.shape[0]; + int32_t N = A.shape[1]; + return (Partition){ + .m00 = cspan_slice(Mat2, &A, {0, M/2}, {0, N/2}), + .m01 = cspan_slice(Mat2, &A, {0, M/2}, {N/2, N}), + .m10 = cspan_slice(Mat2, &A, {M/2, M}, {0, N/2}), + .m11 = cspan_slice(Mat2, &A, {M/2, M}, {N/2, N}), + }; +} + +// Slow generic implementation +void base_case_matrix_product(Mat2 A, Mat2 B, OutMat C) +{ + for (int j = 0; j < C.shape[1]; ++j) { + for (int i = 0; i < C.shape[0]; ++i) { + Mat2_value C_ij = 0; + for (int k = 0; k < A.shape[1]; ++k) { + C_ij += *cspan_at(&A, i,k) * *cspan_at(&B, k,j); + } + *cspan_at(&C, i,j) += C_ij; + } + } +} + +void recursive_matrix_product(Mat2 A, Mat2 B, OutMat C) +{ + // Some hardware-dependent constant + enum {recursion_threshold = 16}; + if (C.shape[0] <= recursion_threshold || C.shape[1] <= recursion_threshold) { + base_case_matrix_product(A, B, C); + } else { + Partition c = partition(C), + a = partition(A), + b = partition(B); + recursive_matrix_product(a.m00, b.m00, c.m00); + recursive_matrix_product(a.m01, b.m10, c.m00); + recursive_matrix_product(a.m10, b.m00, c.m10); + recursive_matrix_product(a.m11, b.m10, c.m10); + recursive_matrix_product(a.m00, b.m01, c.m01); + recursive_matrix_product(a.m01, b.m11, c.m01); + recursive_matrix_product(a.m10, b.m01, c.m11); + recursive_matrix_product(a.m11, b.m11, c.m11); + } +} + + +#define i_type Values +#define i_val double +#include +#include + +int main(void) +{ + enum {N = 10, D1 = 256, D2 = D1}; + + Values values = {0}; + for (int i=0; i < N*D1*D2; ++i) + Values_push(&values, (crandf() - 0.5f)*4.0f); + + double out[D1*D2]; + Mat3 data = cspan_md_order('C', values.data, N, D1, D2); + OutMat c = cspan_md_order('C', out, D1, D2); + Mat2 a = cspan_submd3(&data, 0); + double sum = 0.0; + clock_t t = clock(); + + for (int i=1; i