diff options
| -rw-r--r-- | include/stc/cspan.h | 26 |
| -rw-r--r-- | misc/examples/spans/matmult.c | 2 |
2 files changed, 15 insertions, 13 deletions
diff --git a/include/stc/cspan.h b/include/stc/cspan.h
index cca5486a..b8b191f1 100644
--- a/include/stc/cspan.h
+++ b/include/stc/cspan.h
@@ -223,18 +223,7 @@ STC_INLINE intptr_t _cspan_idxN(int rank, const int32_t shape[], const int32_t s
     return off;
 }
 
-STC_INLINE intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done) {
-    int i, inc;
-    if (stride[0] < stride[rank - 1]) i = rank - 1, inc = -1; else i = 0, inc = 1;
-    intptr_t off = stride[i];
-    ++pos[i];
-    for (; --rank && pos[i] == shape[i]; i += inc) {
-        pos[i] = 0; ++pos[i + inc];
-        off += stride[i + inc] - stride[i]*shape[i];
-    }
-    *done = pos[i] == shape[i];
-    return off;
-}
+STC_API intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done);
 #define _cspan_next1(pos, shape, stride, rank, done) (*done = ++pos[0]==shape[0], stride[0])
 #define _cspan_next3 _cspan_next2
 #define _cspan_next4 _cspan_next2
@@ -253,6 +242,19 @@ STC_API int32_t* _cspan_shape2stride(char order, int32_t shape[], int rank);
 
 /* --------------------- IMPLEMENTATION --------------------- */
 #if defined(i_implement) || defined(i_static)
+STC_DEF intptr_t _cspan_next2(int32_t pos[], const int32_t shape[], const int32_t stride[], int rank, int* done) {
+    int i, inc;
+    if (stride[0] < stride[rank - 1]) i = rank - 1, inc = -1; else i = 0, inc = 1;
+    intptr_t off = stride[i];
+    ++pos[i];
+    for (; --rank && pos[i] == shape[i]; i += inc) {
+        pos[i] = 0; ++pos[i + inc];
+        off += stride[i + inc] - stride[i]*shape[i];
+    }
+    *done = pos[i] == shape[i];
+    return off;
+}
+
 STC_DEF int32_t* _cspan_shape2stride(char order, int32_t shape[], int rank) {
     int32_t k = 1, i, j, inc, s1, s2;
     if (order == 'F') i = 0, j = rank, inc = 1;
diff --git a/misc/examples/spans/matmult.c b/misc/examples/spans/matmult.c
index b28e6459..35dad7a9 100644
--- a/misc/examples/spans/matmult.c
+++ b/misc/examples/spans/matmult.c
@@ -41,7 +41,7 @@ void recursive_matrix_product(Mat2 A, Mat2 B, OutMat C)
     if (C.shape[0] <= recursion_threshold || C.shape[1] <= recursion_threshold) {
         base_case_matrix_product(A, B, C);
     } else {
-        Partition c = partition(C),
+        Partition c = partition(C),
                   a = partition(A),
                   b = partition(B);
         recursive_matrix_product(a.m00, b.m00, c.m00);
