Hi!
I am trying to select elements from an NDArray by index, like this:
ix_valid = np.ix_(valid.asnumpy().astype(np.uint8) != 0)
vlabels = labels[mx.nd.array(ix_valid)]
where valid is an MXNet NDArray.
But this approach is slow because of the time spent converting with asnumpy.
Is there any way to implement this using MXNet NDArray operations only?
I am also interested in a cumsum function for MXNet, like the one in NumPy:
numpy.cumsum(a, axis=None, dtype=None, out=None)
You could use mx.nd.where
https://mxnet.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.where
E.g. mx.nd.where((valid != 0), labels, mx.nd.zeros_like(labels)) would return an NDArray with the same shape as labels, in which an element is zero where the corresponding element in valid equals zero and otherwise takes its value from labels. If you need the indices themselves, you would have to find a workaround for that.
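To make the difference between masking and selecting concrete, here is a small sketch in NumPy. NumPy is used only for illustration, on the assumption that its three-argument np.where picks from the second or third argument in the same way as described above; the array values are made up.

```python
import numpy as np

# Made-up data standing in for the NDArrays from the question.
labels = np.array([10, 20, 30, 40, 50])
valid = np.array([1, 0, 1, 0, 1])

# Masking: the result keeps the shape of labels, with zeros where valid == 0.
masked = np.where(valid != 0, labels, np.zeros_like(labels))
# masked holds 10, 0, 30, 0, 50

# Selecting: only the elements where valid != 0 are kept, so the result is shorter.
selected = labels[valid != 0]
# selected holds 10, 30, 50
```

So the where approach is enough if zeros in the invalid positions are acceptable downstream, but it does not replace index-based selection when the result has to be compacted.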
The cumsum function is currently not supported, but there is already a feature request for it: https://github.com/apache/incubator-mxnet/issues/13001
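In the meantime, one possible workaround for short 1-D inputs is to write the cumulative sum as a product with a lower-triangular matrix of ones, which needs nothing more than a matrix multiply. The sketch below shows the idea in NumPy; the helper name cumsum_via_matmul is made up for this example, and porting it to NDArray (for instance with mx.nd.dot) is an assumption that has not been tested here.

```python
import numpy as np

def cumsum_via_matmul(a):
    """Cumulative sum of a 1-D array using only a matrix product (hypothetical helper)."""
    n = a.shape[0]
    # tri[i, j] == 1 when j <= i, so row i adds up a[0] through a[i].
    tri = np.tril(np.ones((n, n), dtype=a.dtype))
    return tri @ a

a = np.array([1.0, 2.0, 3.0, 4.0])
result = cumsum_via_matmul(a)
# result matches np.cumsum(a): 1, 3, 6, 10
```

Note that the triangular matrix takes O(n^2) memory, so this only makes sense for fairly small vectors.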