Create custom op with auxiliary_states

I am writing a custom op to reproduce CurricularFace paper, belove is my implementation for computing moving average t, just like BatchNorm moving_mean and moving_var.

import numpy as np
import mxnet as mx


class CurricularFace(mx.operator.CustomOp):
    def __init__(self, momentum=0.99):
        super(CurricularFace, self).__init__()
        self.momentum = momentum

    def forward(self, is_train, req, in_data, out_data, aux):
        target_logits = in_data[0]
        t = aux[0]
        batch = target_logits.shape[0]
        aux[0][:] = self.momentum * t + (1 - self.momentum) * target_logits.mean()
        self.assign(out_data[0], req[0], aux[0].repeat(repeats=batch).expand_dims(axis=-1))


@mx.operator.register("CurricularFaceT")
class CurricularFaceProp(mx.operator.CustomOpProp):
    def __init__(self, momentum=0.99):
        super(CurricularFaceProp, self).__init__(need_top_grad=False)
        self.momentum = float(momentum)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['CurricularT']

    def list_auxiliary_states(self):
        return ['coef_t_bias']

    def infer_shape(self, in_shapes):
        data_shape = in_shapes[0]
        batch = data_shape[0]
        return [data_shape], [(batch, 1)], [(1,)]

    def infer_type(self, in_type):
        return [np.float32], [np.float32], [np.float32]

    def create_operator(self, ctx, in_shapes, in_dtypes):
        return CurricularFace(momentum=self.momentum)
    
    def declare_backward_dependency(self, out_grad, in_data, out_data):
        return []

The above code will cause the following error:

Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7f3a4764b4cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc) [0x7f3a47998dcc]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd) [0x7f3a4797c6bd]
[bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
[bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
[bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
Traceback (most recent call last):
.......................
 File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py", line 1819, in wait_to_read
   check_call(_LIB.MXNDArrayWaitToRead(self.handle))
 File "/usr/local/lib/python3.6/dist-packages/mxnet/base.py", line 253, in check_call
   raise MXNetError(py_str(_LIB.MXGetLastError()))
 mxnet.base.MXNetError: [22:49:32] src/operator/custom/custom.cc:417: Check failed: reinterpret_cast<CustomOpFBFunc>(params.info->call
backs[kCustomOpBackward])( ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()), reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train), params
.info->contexts[kCustomOpBackward]): 
 Stack trace:
   [bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7f3a4764b4cb]
   [bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc) [0x7f3a47998dcc]
   [bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd) [0x7f3a4797c6bd]
   [bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
   [bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
   [bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
 Error in CustomOp.backward: Traceback (most recent call last):
   File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020, in backward_entry
     stype=stype))
   File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line 1187, in _ndarray_cls
     raise Exception("unknown storage type: %s"%stype)
 Exception: unknown storage type: -1
 Error in CustomOp.backward: Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020, in backward_entry
     stype=stype))
   File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line 1187, in _ndarray_cls
    raise Exception("unknown storage type: %s"%stype)
     Exception: unknown storage type: -1

if we write code like following:

class CurricularFace(mx.operator.CustomOp):
    def __init__(self, ctx, in_shapes, in_dtypes, momentum=0.99):
        super(CurricularFace, self).__init__()
        self.momentum = momentum

    def forward(self, is_train, req, in_data, out_data, aux):
        target_logits = in_data[0]
        t = aux[0]
        batch = target_logits.shape[0]
        aux[0][:] = self.momentum * t + (1 - self.momentum) * target_logits.mean()
        self.assign(out_data[0], req[0], aux[0].repeat(repeats=batch).expand_dims(axis=-1))


@mx.operator.register("CurricularFaceT")
class CurricularFaceProp(mx.operator.CustomOpProp):
    def __init__(self, momentum=0.99):
        super(CurricularFaceProp, self).__init__(need_top_grad=False)
        self.momentum = float(momentum)

    def list_arguments(self):
        return ['data']

    def list_outputs(self):
        return ['CurricularT']

    def list_auxiliary_states(self):
        return ['coef_t_bias']

    def infer_shape(self, in_shapes):
        data_shape = in_shapes[0]
        batch = data_shape[0]
        return [data_shape], [(batch, 1)], [(1,)]

    def infer_type(self, in_type):
        return [np.float32], [np.float32], [np.float32]

    def create_operator(self, ctx, in_shapes, in_dtypes):
        return CurricularFace(ctx, in_shapes, in_dtypes, momentum=self.momentum)

it will cause another error:

   File "/usr/local/lib/python3.6/dist-packages/mxnet/symbol/symbol.py", line 1629, in simple_bind
   [bt] (7) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::Executor::SimpleBind(nnvm::Symbol, mxnet::Context const&, 
std::map<std::string, mxnet::Context, std::less<std::string>, std::allocator<std::pair<std::string const, mxnet::Context> > > const&, std::vector<mxnet::Context, std::allocator<mxnet::
Context> > const&, std::vector<mxnet::Context, std::allocator<mxnet::Context> > const&, std::vector<mxnet::Context, std::allocator<mxnet::Context> > const&, std::unordered_map<std::string, mxnet::TShape, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, mxnet::TShape> > > const&, std::unordered_map<std::string, int, std:
:hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, int> > > const&, std::unordered_map<std::string, int, std::hash<std::string>, std::equal_to<
std::string>, std::allocator<std::pair<std::string const, int> > > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::un
     raise RuntimeError(error_msg)
 ordered_set<std::string, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::string> > const&, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> >*, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> >*, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> >*, std::unordere
d_map<std::string, mxnet::NDArray, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, mxnet::NDArray> > >*, mxnet::Executor*)+0x8a8) [0x7f4
439cc6258]
 RuntimeError: simple_bind error. Arguments:
   [bt] (8) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(MXExecutorSimpleBindEx+0x221b) [0x7f4439c0795b]
 data: (720, 256)
 softmax_label: (720,)
 [23:37:42] src/operator/custom/custom.cc:282: Check failed: reinterpret_cast<CustomOpCreateFunc>( params.info->callbacks[kCustomOpPro
pCreateOperator])( os.str().c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(), op_info, params.info->contexts[kCustomOpPropCreateOperator]): 
 Stack trace:
   [bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7f89b96b14cb]
   [bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7eb5a9) [0x7f89b99ec5a9]
   [bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7df2e5) [0x7f89b99e02e5]
   [bt] (3) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x25d85ad) [0x7f89bb7d95ad]
   [bt] (4) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x25dbc08) [0x7f89bb7dcc08]
   [bt] (5) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::exec::GraphExecutor::FinishInitGraph(nnvm::Symbol, nnvm::G
raph, mxnet::Executor*, std::unordered_map<nnvm::NodeEntry, mxnet::NDArray, nnvm::NodeEntryHash, nnvm::NodeEntryEqual, std::allocator<std::pair<nnvm::NodeEntry const, mxnet::NDArray> >
 > const&)+0x793) [0x7f89bb809bf3]