diff options
Diffstat (limited to 'ndarray.c')
-rw-r--r-- | ndarray.c | 32 |
1 files changed, 32 insertions, 0 deletions
@@ -185,3 +185,35 @@ void mg2di_ndarray_free(NDArray **pa) free(a); *pa = NULL; } + +static int copy_axis(NDArray *dst, const NDArray *src, unsigned int axis, + ptrdiff_t offset_dst, ptrdiff_t offset_src) +{ + if (dst->shape[axis] != src->shape[axis]) + return -EINVAL; + + if (axis == dst->dims - 1) { + if (dst->stride[axis] == 1 && src->stride[axis] == 1) + memcpy(dst->data + offset_dst, src->data + offset_src, sizeof(*dst->data) * dst->shape[axis]); + else { + for (size_t idx = 0; idx < dst->shape[axis]; idx++) + dst->data[offset_dst + idx * dst->stride[axis]] = src->data[offset_src + idx * src->stride[axis]]; + } + return 0; + } + + for (size_t idx = 0; idx < dst->shape[axis]; idx++) + copy_axis(dst, src, axis + 1, offset_dst + idx * dst->stride[axis], offset_src + idx * src->stride[axis]); + + return 0; +} + +int mg2di_ndarray_copy(NDArray *dst, const NDArray *src) +{ + const unsigned int dims = src->dims; + + if (dims != dst->dims) + return -EINVAL; + + return copy_axis(dst, src, 0, 0, 0); +} |