summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--include/stc/cspan.h14
1 files changed, 9 insertions, 5 deletions
diff --git a/include/stc/cspan.h b/include/stc/cspan.h
index 1bc57e2b..0875ed92 100644
--- a/include/stc/cspan.h
+++ b/include/stc/cspan.h
@@ -82,7 +82,7 @@ int demo2() {
typedef struct { Self##_value *ref; int32_t pos[RANK]; const Self *_s; } Self##_iter; \
\
STC_INLINE Self Self##_slice_(Self##_value* d, const int32_t shape[], const intptr_t stri[], \
- const int rank, const int32_t a[][2]) { \
+ const int rank, const int32_t a[][3]) { \
Self s; int outrank; \
s.data = d + _cspan_slice(s.shape, s.stride.d, &outrank, shape, stri, rank, a); \
c_assert(outrank == RANK); \
@@ -180,8 +180,8 @@ typedef enum {c_ROWMAJOR, c_COLMAJOR} cspan_layout;
// General slicing function;
#define cspan_slice(OutSpan, parent, ...) \
OutSpan##_slice_((parent)->data, (parent)->shape, (parent)->stride.d, cspan_rank(parent) + \
- c_static_assert(cspan_rank(parent) == sizeof((int32_t[][2]){__VA_ARGS__})/sizeof(int32_t[2])), \
- (const int32_t[][2]){__VA_ARGS__})
+ c_static_assert(cspan_rank(parent) == sizeof((int32_t[][3]){__VA_ARGS__})/sizeof(int32_t[3])), \
+ (const int32_t[][3]){__VA_ARGS__})
/* ------------------- PRIVAT DEFINITIONS ------------------- */
@@ -219,7 +219,7 @@ STC_API intptr_t
STC_API intptr_t _cspan_slice(int32_t oshape[], intptr_t ostride[], int* orank,
const int32_t shape[], const intptr_t stride[],
- int rank, const int32_t a[][2]);
+ int rank, const int32_t a[][3]);
STC_API intptr_t* _cspan_shape2stride(cspan_layout layout, intptr_t shape[], int rank);
#endif // STC_CSPAN_H_INCLUDED
@@ -257,7 +257,7 @@ STC_DEF intptr_t* _cspan_shape2stride(cspan_layout layout, intptr_t stride[], in
STC_DEF intptr_t _cspan_slice(int32_t oshape[], intptr_t ostride[], int* orank,
const int32_t shape[], const intptr_t stride[],
- int rank, const int32_t a[][2]) {
+ int rank, const int32_t a[][3]) {
intptr_t off = 0;
int i = 0, oi = 0;
int32_t end;
@@ -272,6 +272,10 @@ STC_DEF intptr_t _cspan_slice(int32_t oshape[], intptr_t ostride[], int* orank,
oshape[oi] = end - a[i][0];
ostride[oi] = stride[i];
c_assert((oshape[oi] > 0) & !c_less_unsigned(shape[i], end));
+ if (a[i][2]) {
+ ostride[oi] *= a[i][2];
+ oshape[oi] = (oshape[oi] - 1)/a[i][2] + 1;
+ }
++oi;
}
*orank = oi;