Workaround for the take operation on other axes

Since the Symbol API has limited indexing operations, some tensor operations get tricky when the data is batched and 4-D.

For my problem, the take() operation does exactly what I want. But when the data comes in batches, I can no longer use it, because take only works along axis=0.

The following snippet shows an example using the NDArray API:

import mxnet as mx
import numpy as np

img = mx.nd.random.uniform(0, 1, shape=(16, 32, 32))  # shape (16, 32, 32)
indices = mx.nd.array(np.random.choice(16, 128))      # shape (128,)
result = mx.nd.take(img, indices)                     # shape (128, 32, 32)

Since the following snippet doesn't work, is there a workaround (in the Symbol API) with similar performance?

img = mx.nd.random.uniform(0, 1, shape=(1, 16, 32, 32))  # shape (1, 16, 32, 32)
indices = mx.nd.array(np.random.choice(16, 128))         # shape (128,)
result = mx.nd.take(img, indices, axis=1)                # desired shape (1, 128, 32, 32)

I have, of course, looked at batch_take, but it is deprecated and appears to be equivalent to sym.pick() rather than take().
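The difference between the two gather styles is easiest to see on a small 2-D array. The sketch below uses NumPy as a stand-in for the MXNet operators (an assumption made so it runs without MXNet): the "pick"-style gather uses one index per batch row, matching how batch_take is described for 2-D input (output[i] = input[i, indices[i]]), while the "take"-style gather applies the same index list to every row.

```python
import numpy as np

# NumPy stand-in for the MXNet operators (assumption: semantics only, not the MXNet API).
rng = np.random.default_rng(0)
data = rng.uniform(0, 1, size=(4, 16))          # (batch, channels)

# pick / batch_take style: ONE index per batch row -> out[b] = data[b, idx[b]]
idx_per_row = rng.integers(0, 16, size=4)       # shape (batch,)
picked = data[np.arange(4), idx_per_row]        # shape (4,)

# take style along axis 1: the SAME index list for every row -> out[b, n] = data[b, indices[n]]
indices = rng.choice(16, 128)                   # shape (128,)
taken = np.take(data, indices, axis=1)          # shape (4, 128)

print(picked.shape, taken.shape)  # (4,) (4, 128)
```

So a pick-style operator shrinks the indexed axis to one value per row, whereas the question needs every row to gather the full 128-entry index list.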

I could use slice_axis to split the batch, loop over the slices applying take, and then stack the results, but I wanted to check whether there is a more efficient or elegant way that I am overlooking.
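The slice/loop/take/stack idea can be sketched as follows. NumPy stands in for the MXNet operators here, and the helper name take_axis1_by_loop is hypothetical; indexing batch[b] plays the role of slice_axis, and np.stack plays the role of stacking the per-sample results.

```python
import numpy as np

def take_axis1_by_loop(batch, indices):
    """Hypothetical helper: apply an axis-0 take to each batch element, then stack.

    batch has shape (B, C, H, W), indices has shape (N,);
    the result has shape (B, N, H, W).
    """
    # batch[b] has shape (C, H, W), so an axis-0 take selects channels.
    per_sample = [np.take(batch[b], indices, axis=0) for b in range(batch.shape[0])]
    return np.stack(per_sample, axis=0)

rng = np.random.default_rng(0)
img = rng.uniform(0, 1, size=(2, 16, 32, 32))
indices = rng.choice(16, 128)
result = take_axis1_by_loop(img, indices)
print(result.shape)  # (2, 128, 32, 32)
```

One likely drawback in a symbolic graph is that a Python loop like this has to be unrolled with a batch size known when the graph is built, which is part of why it feels clumsy.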

Transpose the batch axis to the last position, apply take (which then operates along the former axis 1), and transpose back?
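A minimal sketch of the transpose/take/transpose suggestion, again using NumPy as a stand-in (assumption) and a hypothetical helper name. Moving the batch axis to the end makes the channel axis axis 0, so a plain axis-0 take does the gather, and a second transpose restores the (B, N, H, W) layout.

```python
import numpy as np

def take_axis1_via_transpose(batch, indices):
    """Hypothetical helper: gather batch[:, indices] using only an axis-0 take.

    (B, C, H, W) --transpose (1, 2, 3, 0)--> (C, H, W, B)
                 --take along axis 0------> (N, H, W, B)
                 --transpose (3, 0, 1, 2)--> (B, N, H, W)
    """
    moved = np.transpose(batch, (1, 2, 3, 0))
    gathered = np.take(moved, indices, axis=0)
    return np.transpose(gathered, (3, 0, 1, 2))

rng = np.random.default_rng(0)
img = rng.uniform(0, 1, size=(1, 16, 32, 32))
indices = rng.choice(16, 128)
result = take_axis1_via_transpose(img, indices)
print(result.shape)  # (1, 128, 32, 32)
```

In MXNet the same steps should map onto transpose with an axes argument followed by take, but that mapping is untested here. Each transpose probably materializes a copy, so it is worth timing this against the loop-and-stack version.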
