# How to select specific rows according to a 0/1 mask?

As shown in the following code, I want to select the first and third rows because their corresponding mask entries are “1”, while the second row’s entry is “0”. However, if I use a[b], I do not get what I want.
How can I achieve this?

``````
import mxnet as mx

a = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])
b = mx.nd.array([1,0,1])
a[b]
[[4. 5. 6.]
 [1. 2. 3.]
 [4. 5. 6.]]
<NDArray 3x3 @cpu(0)>
# I want to get:
# [[1. 2. 3.]
#  [7. 8. 9.]]
``````

Hi @zhoulukuan,

I believe this is what you’re trying to achieve here:

``````
import mxnet as mx

a = mx.nd.array([[1,2,3],[4,5,6],[7,8,9]])

# b should hold the row indices you want to take (not a 0/1 mask).
b = mx.nd.array([0,2])

a.take(b, axis=0)

[[1. 2. 3.]
[7. 8. 9.]]
<NDArray 2x3 @cpu(0)>
``````
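
The answer above starts from the indices, while the question starts from a 0/1 mask. The missing step is turning the mask into indices. Below is a minimal sketch of that step, written in NumPy rather than MXNet so it runs without an MXNet install; the variable names `mask`, `indices`, and `selected` are illustrative and not from the thread.

```python
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)
mask = np.array([1, 0, 1])  # the 0/1 mask from the question

# Positions of the nonzero mask entries are the row indices to keep.
indices = np.flatnonzero(mask)

# Same kind of call as the take() in the answer, now driven by the mask.
selected = np.take(a, indices, axis=0)

print(indices)
print(selected)
```

With MXNet arrays, one option is to pull the mask to the host with `b.asnumpy()`, compute the indices as above, and pass them back with `a.take(mx.nd.array(indices), axis=0)`.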

I don’t think this solves the problem the OP was asking about. This suggestion, while it achieves the same end for a single array, requires knowing the indices in advance and is limited in some key situations. For example, suppose I want to do this type of masked selection in batches, where each batch item keeps a different number of rows:

``````
import mxnet as mx

x = mx.nd.array([[[1,2,3], [4,5,6], [7,8,9]],
                 [[-1,-2,-3], [-4,-5,-6], [-7,-8,-9]]], ctx=mx.cpu())  # shape (2, 3, 3)
y = mx.nd.array([[1,0,1], [0,1,0]], ctx=mx.cpu())  # one 0/1 mask per batch item
# Error due to incompatible shapes: the per-item index lists have different lengths
# z = mx.nd.array( [[0, 2], [1]])
# mx.nd.take(x, z)
# desired output:
# [ [1, 2, 3], [7, 8, 9], [-4, -5, -6] ]
``````
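
One possible way around the ragged index problem is to merge the batch and row axes first, so that a single flat mask covers every row. The sketch below illustrates the idea in NumPy rather than MXNet, as a stand-in and not a confirmed MXNet recipe; `flat_rows`, `flat_mask`, and `indices` are hypothetical names.

```python
import numpy as np

x = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
              [[-1, -2, -3], [-4, -5, -6], [-7, -8, -9]]])  # shape (2, 3, 3)
y = np.array([[1, 0, 1], [0, 1, 0]])                        # shape (2, 3)

# Merge the batch and row axes so each row has exactly one mask entry.
flat_rows = x.reshape(-1, x.shape[-1])  # shape (6, 3)
flat_mask = y.reshape(-1)               # shape (6,)

# Row indices into flat_rows where the mask is nonzero.
indices = np.flatnonzero(flat_mask)

z = np.take(flat_rows, indices, axis=0)
print(z)
```

Flattening discards which batch item each selected row came from. If that matters, `indices // x.shape[1]` recovers the batch index of every kept row.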

@ThomasDelteil Oh, it’s good. Thanks for your help.