summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorTyge Løvset <[email protected]>2023-08-16 17:15:37 +0200
committerTyge Løvset <[email protected]>2023-08-16 17:15:37 +0200
commit541ce2af6bda0bb21393bdee3fed1e70f9ce40f1 (patch)
tree082e8e490fa039405a9690bee8a17398601b4e61
parent5be09e526bc4ee4d1f586aa906e1f9a9c8e3e165 (diff)
downloadSTC-modified-541ce2af6bda0bb21393bdee3fed1e70f9ce40f1.tar.gz
STC-modified-541ce2af6bda0bb21393bdee3fed1e70f9ce40f1.zip
Added recursive matrix multiplication example for cspan.
-rw-r--r--include/c11/fmt.h6
-rw-r--r--misc/examples/spans/matmult.c90
2 files changed, 93 insertions, 3 deletions
diff --git a/include/c11/fmt.h b/include/c11/fmt.h
index d7c10cbe..df96bae3 100644
--- a/include/c11/fmt.h
+++ b/include/c11/fmt.h
@@ -25,7 +25,7 @@ void fmt_close(fmt_stream* ss);
* C11 or higher required.
* MAX 255 chars fmt string by default. MAX 12 arguments after fmt string.
-* Define FMT_IMPLEMENT or i_implement prior to #include in one translation unit.
+* Define FMT_IMPLEMENT, STC_IMPLEMENT or i_implement prior to #include in one translation unit.
* Define FMT_SHORTS to add print(), println() and printd() macros, without fmt_ prefix.
* (c) operamint, 2022, MIT License.
-----------------------------------------------------------------------------------
@@ -84,7 +84,7 @@ int main(void) {
#define _fmt_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, \
_14, _15, _16, N, ...) N
-#if defined FMT_NDEBUG || defined NDEBUG
+#if defined FMT_NDEBUG || defined STC_NDEBUG || defined NDEBUG
# define fmt_OK(exp) (void)(exp)
#else
# define fmt_OK(exp) assert(exp)
@@ -196,7 +196,7 @@ void _fmt_bprint(fmt_stream*, const char* fmt, ...);
const wchar_t*: "ls", \
const void*: "p")
-#if defined FMT_IMPLEMENT || defined i_implement
+#if defined FMT_IMPLEMENT || defined STC_IMPLEMENT || defined i_implement
#include <stdlib.h>
#include <stdarg.h>
diff --git a/misc/examples/spans/matmult.c b/misc/examples/spans/matmult.c
new file mode 100644
index 00000000..62c0c26b
--- /dev/null
+++ b/misc/examples/spans/matmult.c
@@ -0,0 +1,90 @@
+// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2023/p2642r2.html
+// C99:
+#include <stdio.h>
+#include <time.h>
+#include <stc/cspan.h>
+
+using_cspan3(Mat, double);
+typedef Mat2 OutMat;
+typedef struct { Mat2 m00, m01, m10, m11; } Partition;
+
+Partition partition(Mat2 A)
+{
+ int32_t M = A.shape[0];
+ int32_t N = A.shape[1];
+ return (Partition){
+ .m00 = cspan_slice(Mat2, &A, {0, M/2}, {0, N/2}),
+ .m01 = cspan_slice(Mat2, &A, {0, M/2}, {N/2, N}),
+ .m10 = cspan_slice(Mat2, &A, {M/2, M}, {0, N/2}),
+ .m11 = cspan_slice(Mat2, &A, {M/2, M}, {N/2, N}),
+ };
+}
+
+// Slow generic implementation
+void base_case_matrix_product(Mat2 A, Mat2 B, OutMat C)
+{
+ for (int j = 0; j < C.shape[1]; ++j) {
+ for (int i = 0; i < C.shape[0]; ++i) {
+ Mat2_value C_ij = 0;
+ for (int k = 0; k < A.shape[1]; ++k) {
+ C_ij += *cspan_at(&A, i,k) * *cspan_at(&B, k,j);
+ }
+ *cspan_at(&C, i,j) += C_ij;
+ }
+ }
+}
+
+void recursive_matrix_product(Mat2 A, Mat2 B, OutMat C)
+{
+ // Some hardware-dependent constant
+ enum {recursion_threshold = 16};
+ if (C.shape[0] <= recursion_threshold || C.shape[1] <= recursion_threshold) {
+ base_case_matrix_product(A, B, C);
+ } else {
+ Partition c = partition(C),
+ a = partition(A),
+ b = partition(B);
+ recursive_matrix_product(a.m00, b.m00, c.m00);
+ recursive_matrix_product(a.m01, b.m10, c.m00);
+ recursive_matrix_product(a.m10, b.m00, c.m10);
+ recursive_matrix_product(a.m11, b.m10, c.m10);
+ recursive_matrix_product(a.m00, b.m01, c.m01);
+ recursive_matrix_product(a.m01, b.m11, c.m01);
+ recursive_matrix_product(a.m10, b.m01, c.m11);
+ recursive_matrix_product(a.m11, b.m11, c.m11);
+ }
+}
+
+
+#define i_type Values
+#define i_val double
+#include <stc/cstack.h>
+#include <stc/crand.h>
+
+int main(void)
+{
+ enum {N = 10, D1 = 256, D2 = D1};
+
+ Values values = {0};
+ for (int i=0; i < N*D1*D2; ++i)
+ Values_push(&values, (crandf() - 0.5f)*4.0f);
+
+ double out[D1*D2];
+ Mat3 data = cspan_md_order('C', values.data, N, D1, D2);
+ OutMat c = cspan_md_order('C', out, D1, D2);
+ Mat2 a = cspan_submd3(&data, 0);
+ double sum = 0.0;
+ clock_t t = clock();
+
+ for (int i=1; i<N; ++i) {
+ Mat2 b = cspan_submd3(&data, i);
+ memset(out, 0, sizeof out);
+ recursive_matrix_product(a, b, c);
+ //base_case_matrix_product(a, b, c);
+ sum += *cspan_at(&c, 0, 1);
+ }
+
+ t = clock() - t;
+ printf("%.16g: %f\n", sum, (float)t*1000.0f/CLOCKS_PER_SEC);
+ Values_drop(&values);
+}