summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author Tyge Lovset <[email protected]> 2023-08-19 08:57:53 +0200
committer Tyge Lovset <[email protected]> 2023-08-19 08:57:53 +0200
commit e51376c9b72448dad947c3cd3760ab013ca8e4a5 (patch)
tree 0651e570f3807affbc7d7f3f6f8270e662ab96ab
parent 18a0648d8a9f7dcaf3ce145d8d671679bd46d388 (diff)
download STC-modified-e51376c9b72448dad947c3cd3760ab013ca8e4a5.tar.gz
STC-modified-e51376c9b72448dad947c3cd3760ab013ca8e4a5.zip
Moved cspan_next() to shared implementation (if chosen).
-rw-r--r-- include/stc/cspan.h 26
-rw-r--r-- misc/examples/spans/matmult.c 2
2 files changed, 15 insertions, 13 deletions
diff --git a/include/stc/cspan.h b/include/stc/cspan.h
index cca5486a..b8b191f1 100644
--- a/include/stc/cspan.h
+++ b/include/stc/cspan.h
@@ -223,18 +223,7 @@ STC_INLINE intptr_t _cspan_idxN(int rank, const int32_t shape[], const int32_t s
return off;
}
-STC_INLINE intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done) {
- int i, inc;
- if (stride[0] < stride[rank - 1]) i = rank - 1, inc = -1; else i = 0, inc = 1;
- intptr_t off = stride[i];
- ++pos[i];
- for (; --rank && pos[i] == shape[i]; i += inc) {
- pos[i] = 0; ++pos[i + inc];
- off += stride[i + inc] - stride[i]*shape[i];
- }
- *done = pos[i] == shape[i];
- return off;
-}
+STC_API intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done);
#define _cspan_next1(pos, shape, stride, rank, done) (*done = ++pos[0]==shape[0], stride[0])
#define _cspan_next3 _cspan_next2
#define _cspan_next4 _cspan_next2
@@ -253,6 +242,19 @@ STC_API int32_t* _cspan_shape2stride(char order, int32_t shape[], int rank);
/* --------------------- IMPLEMENTATION --------------------- */
#if defined(i_implement) || defined(i_static)
+STC_DEF intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done) {
+ int i, inc;
+ if (stride[0] < stride[rank - 1]) i = rank - 1, inc = -1; else i = 0, inc = 1;
+ intptr_t off = stride[i];
+ ++pos[i];
+ for (; --rank && pos[i] == shape[i]; i += inc) {
+ pos[i] = 0; ++pos[i + inc];
+ off += stride[i + inc] - stride[i]*shape[i];
+ }
+ *done = pos[i] == shape[i];
+ return off;
+}
+
STC_DEF int32_t* _cspan_shape2stride(char order, int32_t shape[], int rank) {
int32_t k = 1, i, j, inc, s1, s2;
if (order == 'F') i = 0, j = rank, inc = 1;
diff --git a/misc/examples/spans/matmult.c b/misc/examples/spans/matmult.c
index b28e6459..35dad7a9 100644
--- a/misc/examples/spans/matmult.c
+++ b/misc/examples/spans/matmult.c
@@ -41,7 +41,7 @@ void recursive_matrix_product(Mat2 A, Mat2 B, OutMat C)
if (C.shape[0] <= recursion_threshold || C.shape[1] <= recursion_threshold) {
base_case_matrix_product(A, B, C);
} else {
- Partition c = partition(C),
+ Partition c = partition(C),
a = partition(A),
b = partition(B);
recursive_matrix_product(a.m00, b.m00, c.m00);