Get list of indices of max value occurrences

I need the list of indices at which the maximum value occurs in a Symbol/NDArray.

Suppose I have x = [1,0,0,1]; I want to get [0,3]. I want to be able to hybridize my network, so I cannot use any NumPy functionality. Unfortunately, MXNet's argmax function only returns the index of the first occurrence.
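For reference, the desired semantics can be written eagerly with NumPy. This is exactly the kind of code that cannot go inside a hybridized block: it relies on NumPy, and its output length depends on the data, which a static graph generally cannot express. The helper name `max_indices` is hypothetical.

```python
import numpy as np

def max_indices(x):
    """Return every index where x attains its maximum (eager NumPy reference, not hybridizable)."""
    x = np.asarray(x)
    # Boolean mask of positions equal to the max, then the positions of the True entries.
    return np.flatnonzero(x == x.max())

print(max_indices([1, 0, 0, 1]))  # -> [0 3]
```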

What is the best and fastest way to achieve this? I am a bit stumped: I haven't found a satisfying solution so far. I want to use this index list for F.take().

Thanks for any help!

@ifeherva is the expert in residence for these kinds of questions.
@ifeherva any suggestions?

@adrian I don’t think it is possible with hybridization.

I could be wrong, but I don’t think you can right now (using symbols alone).

However, can I ask how crucial it is that you need ints? Things you might need ints for can often be done with binary masks instead, and you can do:

mask = x == x.max(axis=-1, keepdims=True)

to get a mask that marks every position where the maximum occurs (including ties).
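A small sketch of that mask on a batch, using NumPy as a stand-in for NDArray since the expression is written the same way. Note that, as far as I know, NDArray comparisons return 0/1 values rather than booleans, so the `astype(int)` below only mimics that. Every tied maximum in a row gets a 1.

```python
import numpy as np

# Batch of two rows; each row has tied maxima.
x = np.array([[1, 0, 0, 1],
              [3, 3, 2, 0]])

# Compare each row with its own max. keepdims=True keeps the max
# with shape (2, 1), so it broadcasts against x of shape (2, 4).
mask = x == x.max(axis=-1, keepdims=True)

print(mask.astype(int))
# [[1 0 0 1]
#  [1 1 0 0]]
```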

After looking around, I also think it’s not possible right now. I opened a feature request issue on GitHub.

I need this list of indices for F.take(), which takes a list of integer indices, not a mask.
I work with an architecture that splits the mini-batch (of images) and sends the parts to different branches. I have a binary mask array indicating which branch each image in the mini-batch should be sent to. So I want to use take() to split the mini-batch according to that mask, but it only accepts indices, and apparently I can’t retrieve them if I want to be able to hybridize.
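The split being described can be sketched eagerly in NumPy (so, again, not hybridizable). The batch contents and the `branch_mask` name are made up for illustration. `np.take(..., axis=0)` gathers whole images along the batch axis, which is what `F.take(batch, idx)` does with its default `axis=0`; the obstacle is only the data-dependent index extraction.

```python
import numpy as np

# Hypothetical mini-batch of 4 "images", flattened to 3 features each for brevity.
batch = np.arange(12, dtype=np.float32).reshape(4, 3)
# Hypothetical binary branch mask: 1 -> branch A, 0 -> branch B.
branch_mask = np.array([1, 0, 0, 1])

# Eager index extraction; how many indices come out depends on the mask values.
idx_a = np.flatnonzero(branch_mask == 1)
idx_b = np.flatnonzero(branch_mask == 0)

# Gather the selected images along the batch axis.
batch_a = np.take(batch, idx_a, axis=0)
batch_b = np.take(batch, idx_b, axis=0)

print(idx_a, idx_b)                  # [0 3] [1 2]
print(batch_a.shape, batch_b.shape)  # (2, 3) (2, 3)
```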

It’s not such a big problem for me: I can just send the whole mini-batch to both branches and then use the mask to merge the results, which gives the same result as if I had split the mini-batch. But it would have been nice to avoid the unnecessary operations…
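A minimal sketch of that workaround, with NumPy standing in for `F` and two hypothetical branch functions. Both branches see the full batch, so no shape depends on the data; a per-sample select then merges the outputs. In MXNet the select would presumably be `F.where(condition, x, y)`, whose documentation says it also accepts a 1-D condition whose length equals the first dimension of `x`, so the explicit reshape below may not be needed there.

```python
import numpy as np

def branch_a(x):
    # Hypothetical stand-in for the first branch.
    return x * 2.0

def branch_b(x):
    # Hypothetical stand-in for the second branch.
    return x - 1.0

batch = np.arange(12, dtype=np.float32).reshape(4, 3)
branch_mask = np.array([1, 0, 0, 1], dtype=np.float32)

# Run BOTH branches on the full mini-batch (static shapes throughout).
out_a = branch_a(batch)
out_b = branch_b(batch)

# Per-sample select: reshape the (4,) mask to (4, 1) so it broadcasts over each row.
merged = np.where(branch_mask[:, None] == 1, out_a, out_b)
```

The cost is the extra compute on samples that each branch will discard, which is the trade-off mentioned above.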

I feel like the MXNet API is missing some functions for retrieving index lists. I don’t see any function (for either NDArray or Symbol) that returns something usable as the index argument of F.take(). There is only topk(), I guess, but since I don’t know in advance how many 1s are in my indicator array, I cannot use it.

Anyway, thanks for the replies!