diff options
| author | Tyge Lovset <[email protected]> | 2023-08-19 08:57:53 +0200 |
|---|---|---|
| committer | Tyge Lovset <[email protected]> | 2023-08-19 08:57:53 +0200 |
| commit | e51376c9b72448dad947c3cd3760ab013ca8e4a5 (patch) | |
| tree | 0651e570f3807affbc7d7f3f6f8270e662ab96ab | |
| parent | 18a0648d8a9f7dcaf3ce145d8d671679bd46d388 (diff) | |
| download | STC-modified-e51376c9b72448dad947c3cd3760ab013ca8e4a5.tar.gz STC-modified-e51376c9b72448dad947c3cd3760ab013ca8e4a5.zip | |
Moved cspan_next() to shared implementation (if chosen).
| -rw-r--r-- | include/stc/cspan.h | 26 | ||||
| -rw-r--r-- | misc/examples/spans/matmult.c | 2 |
2 files changed, 15 insertions, 13 deletions
diff --git a/include/stc/cspan.h b/include/stc/cspan.h index cca5486a..b8b191f1 100644 --- a/include/stc/cspan.h +++ b/include/stc/cspan.h @@ -223,18 +223,7 @@ STC_INLINE intptr_t _cspan_idxN(int rank, const int32_t shape[], const int32_t s return off; } -STC_INLINE intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done) { - int i, inc; - if (stride[0] < stride[rank - 1]) i = rank - 1, inc = -1; else i = 0, inc = 1; - intptr_t off = stride[i]; - ++pos[i]; - for (; --rank && pos[i] == shape[i]; i += inc) { - pos[i] = 0; ++pos[i + inc]; - off += stride[i + inc] - stride[i]*shape[i]; - } - *done = pos[i] == shape[i]; - return off; -} +STC_API intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done); #define _cspan_next1(pos, shape, stride, rank, done) (*done = ++pos[0]==shape[0], stride[0]) #define _cspan_next3 _cspan_next2 #define _cspan_next4 _cspan_next2 @@ -253,6 +242,19 @@ STC_API int32_t* _cspan_shape2stride(char order, int32_t shape[], int rank); /* --------------------- IMPLEMENTATION --------------------- */ #if defined(i_implement) || defined(i_static) +STC_DEF intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done) { + int i, inc; + if (stride[0] < stride[rank - 1]) i = rank - 1, inc = -1; else i = 0, inc = 1; + intptr_t off = stride[i]; + ++pos[i]; + for (; --rank && pos[i] == shape[i]; i += inc) { + pos[i] = 0; ++pos[i + inc]; + off += stride[i + inc] - stride[i]*shape[i]; + } + *done = pos[i] == shape[i]; + return off; +} + STC_DEF int32_t* _cspan_shape2stride(char order, int32_t shape[], int rank) { int32_t k = 1, i, j, inc, s1, s2; if (order == 'F') i = 0, j = rank, inc = 1; diff --git a/misc/examples/spans/matmult.c b/misc/examples/spans/matmult.c index b28e6459..35dad7a9 100644 --- a/misc/examples/spans/matmult.c +++ b/misc/examples/spans/matmult.c @@ -41,7 +41,7 @@ void recursive_matrix_product(Mat2 A, Mat2 B, OutMat C) if (C.shape[0] <= recursion_threshold || C.shape[1] <= recursion_threshold) { base_case_matrix_product(A, B, C); } else { - Partition c = partition(C), + Partition c = partition(C), a = partition(A), b = partition(B); recursive_matrix_product(a.m00, b.m00, c.m00); |
