Slices / indexing / __getitem__ when hybridizing a model

Hello,

I am doing some experiments with gradient descent, and as my toy model I am using a one-dimensional mixture of Gaussians with only 2 components. I created the model via the gluon/ndarray API, but now I want to hybridize it. I finally got it working, but it took a lot of experimentation and I wonder if I did it correctly.

In gluon/ndarray I can simply use the [] operator to access parameters at a given index. This operator does not seem to work for symbols. The only way I got it working was to use the slice() method. I also tried the pick() and take() methods, but while they did not throw any concrete errors, gradient descent did not work with them.
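
For example, outside the model this is the difference I mean (a minimal sketch; the names are just for illustration):

import mxnet as mx

m_nd = mx.nd.array([1.0, 2.0])
m0_nd = m_nd[0]  # NDArray: element indexing with [] works

m_sym = mx.sym.var('m', shape=(2,))
m0_sym = m_sym.slice(begin=(0,), end=(1,))  # Symbol: slice() instead of []
# m_sym[0] does not index elements; on a Symbol it selects the 0-th output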

My question is: what is the correct way to do slices / indexing / __getitem__ when hybridizing a model?
I looked at all the methods in the API documentation here: http://beta.mxnet.io/r/api/symbol.html#indexing

Here is what I currently have; I left the comments with my failed attempts in the code to show what I tried:

import mxnet as mx
import mxfusion as mxf  # assumption: mxf is MXFusion, which provides the Normal.log_pdf_impl used below

class MixtureLogLikelihood(mx.gluon.block.HybridBlock):
    def __init__(self, dtype='float32', negative=False, **kwargs):
        super().__init__(**kwargs)
        self.negative = negative

        with self.name_scope():
            self.p = self.params.get('p', shape=(1,), init=mx.initializer.Constant(0.5), dtype=dtype, allow_deferred_init=True)
            self.m = self.params.get('m', shape=(2,), init=mx.initializer.Uniform(10.0), dtype=dtype, allow_deferred_init=True)
            self.v = self.params.get('v', shape=(2,), init=mx.initializer.One(), dtype=dtype, allow_deferred_init=True)

    # generic signature: hybrid_forward(self, F, x1, x2, *args, **kwargs)
    def hybrid_forward(self, F, x, p, m, v):
        # F is the mx.ndarray module in imperative mode,
        # or the mx.symbol module once the block is hybridized

        # index0 = F.full(1,0, dtype='int32') #,np.int8 # self.params.get_constant(name='i0', value=nd.array([0], dtype=np.int))
        # index1 = F.full(1,1, dtype='int32') # self.params.get_constant(name='i1', value=nd.array([1], dtype=np.int))

        n = mxf.components.distributions.normal.Normal(0.0, 1.0)  # helper only; mean/variance are passed explicitly to log_pdf_impl below
        # m0 = m[0] # F.pick(m, index0, 0)
        # v0 = v[0] # F.pick(v, index0, 0)
        # m0 = F.pick(m, index0, 0)
        # v0 = F.pick(v, index0, 0)
        m0 = m.slice(begin=(0,), end=(1,))
        v0 = v.slice(begin=(0,), end=(1,))
        n0 = n.log_pdf_impl(mean=m0, variance=v0, random_variable=x, F=F)
        # m1 = m[1] # F.pick(m, index0, 0)
        # v1 = v[1] # F.pick(v, index1, 0)
        # m1 = F.pick(m, index0, 0)
        # v1 = F.pick(v, index1, 0)
        m1 = m.slice(begin=(1,), end=(2,))
        v1 = v.slice(begin=(1,), end=(2,))
        n1 = n.log_pdf_impl(mean=m1, variance=v1, random_variable=x, F=F)
        # ns = F.stack(n0, n1, axis=1)
        nc = F.concat(n0, n1, dim=1)
        # note: on a Symbol, p[0] selects output 0 (which is p itself, shape (1,)), so this happens to work
        ps = F.stack(F.log(p[0]), F.log(1 - p[0]), axis=1)

        nsp = F.broadcast_add(nc, ps) # nd.array([nd.log(p[0]), nd.log(1-p[0])])

        # lse = log_sum_exp(nsp, axis=1)
        # "log_sum_exp" is a custom operator registered elsewhere in my script (via mx.operator.register)
        lse = F.Custom(nsp, name="log_sum_exp", op_type="log_sum_exp", axis=1, keepdims=False)

        # N = x.shape[0]
        #
        # loss_list = []
        # for i in range(N):
        #     loss_list += [-nd.log(p[0] * nd.exp(-(x[i] - m[0]) ** 2 / (2 * s[0] ** 2)) / nd.sqrt(2 * np.pi * s[0] ** 2) + (1 - p[0]) * nd.exp(-(x[i] - m[1]) ** 2 / (2 * s[1] ** 2)) / nd.sqrt(2 * np.pi * s[1] ** 2))]
        #
        # loss = nd.add_n(*loss_list)
        if self.negative:
            lse = -lse

        return lse

    def __repr__(self):
        s = '{name}'
        return s.format(name=self.__class__.__name__)
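
For completeness, this is roughly how I drive the block (a minimal sketch; it assumes the custom log_sum_exp operator is registered, and the data here is just a placeholder):

model = MixtureLogLikelihood(negative=True)
model.initialize()
model.hybridize()

x = mx.nd.random.normal(loc=2.0, scale=1.0, shape=(100, 1))  # placeholder data
with mx.autograd.record():
    loss = model(x)  # per-sample negative log-likelihood, shape (100,)
loss.backward()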

Yes, the correct way to do indexing is to use slice(), take(), or pick(). The Symbol's __getitem__ semantics are not consistent with NDArray's. There is currently an open PR that tries to add support for basic slicing in Symbol (see https://github.com/apache/incubator-mxnet/pull/15905).
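
For illustration, a minimal sketch of the slice() and take() variants on a plain Symbol (the names here are just for illustration; pick() follows the same pattern with an explicit axis):

import mxnet as mx

m = mx.sym.var('m', shape=(2,))

# slice: element 0 as a shape-(1,) symbol
m0_slice = m.slice(begin=(0,), end=(1,))

# take: gather element 0 along axis 0, with the index itself a symbol
idx0 = mx.sym.full(1, 0, dtype='int32')
m0_take = mx.sym.take(m, idx0)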

Thank you very much for the clarification and the pointer to the PR!

When I used the take/pick approach, gradient descent no longer worked, and I am not sure why. To create index 0 I used F.full(1, 0, dtype='int32'), and to create index 1 I used F.full(1, 1, dtype='int32'). With those indices the code above does not throw any errors, but gradient descent simply stays at the initial parameter values.
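
Concretely, the variant that runs but does not learn replaces the slice() calls in hybrid_forward with this (reconstructed from the commented-out lines in my code above, using index1 for the second component):

        index0 = F.full(1, 0, dtype='int32')
        index1 = F.full(1, 1, dtype='int32')
        m0 = F.pick(m, index0, 0)
        v0 = F.pick(v, index0, 0)
        m1 = F.pick(m, index1, 0)
        v1 = F.pick(v, index1, 0)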