# Create custom op with auxiliary_states

I am writing a custom op to reproduce CurricularFace paper, belove is my implementation for computing moving average t, just like BatchNorm moving_mean and moving_var.

``````import numpy as np
import mxnet as mx

class CurricularFace(mx.operator.CustomOp):
def __init__(self, momentum=0.99):
super(CurricularFace, self).__init__()
self.momentum = momentum

def forward(self, is_train, req, in_data, out_data, aux):
target_logits = in_data[0]
t = aux[0]
batch = target_logits.shape[0]
aux[0][:] = self.momentum * t + (1 - self.momentum) * target_logits.mean()
self.assign(out_data[0], req[0], aux[0].repeat(repeats=batch).expand_dims(axis=-1))

@mx.operator.register("CurricularFaceT")
class CurricularFaceProp(mx.operator.CustomOpProp):
def __init__(self, momentum=0.99):
super(CurricularFaceProp, self).__init__(need_top_grad=False)
self.momentum = float(momentum)

def list_arguments(self):
return ['data']

def list_outputs(self):
return ['CurricularT']

def list_auxiliary_states(self):
return ['coef_t_bias']

def infer_shape(self, in_shapes):
data_shape = in_shapes[0]
batch = data_shape[0]
return [data_shape], [(batch, 1)], [(1,)]

def infer_type(self, in_type):
return [np.float32], [np.float32], [np.float32]

def create_operator(self, ctx, in_shapes, in_dtypes):
return CurricularFace(momentum=self.momentum)

def declare_backward_dependency(self, out_grad, in_data, out_data):
return []
``````

The above code will cause the following error:

``````Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7f3a4764b4cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc) [0x7f3a47998dcc]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd) [0x7f3a4797c6bd]
[bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
[bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
[bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
Traceback (most recent call last):
.......................
File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py", line 1819, in wait_to_read
check_call(_LIB.MXNDArrayWaitToRead(self.handle))
File "/usr/local/lib/python3.6/dist-packages/mxnet/base.py", line 253, in check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [22:49:32] src/operator/custom/custom.cc:417: Check failed: reinterpret_cast<CustomOpFBFunc>(params.info->call
backs[kCustomOpBackward])( ptrs.size(), const_cast<void**>(ptrs.data()), const_cast<int*>(tags.data()), reinterpret_cast<const int*>(req.data()), static_cast<int>(ctx.is_train), params
.info->contexts[kCustomOpBackward]):
Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7f3a4764b4cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc) [0x7f3a47998dcc]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd) [0x7f3a4797c6bd]
[bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
[bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
[bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
Error in CustomOp.backward: Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020, in backward_entry
stype=stype))
File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line 1187, in _ndarray_cls
raise Exception("unknown storage type: %s"%stype)
Exception: unknown storage type: -1
Error in CustomOp.backward: Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020, in backward_entry
stype=stype))
File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line 1187, in _ndarray_cls
raise Exception("unknown storage type: %s"%stype)
Exception: unknown storage type: -1
``````

if we write code like following:

``````class CurricularFace(mx.operator.CustomOp):
def __init__(self, ctx, in_shapes, in_dtypes, momentum=0.99):
super(CurricularFace, self).__init__()
self.momentum = momentum

def forward(self, is_train, req, in_data, out_data, aux):
target_logits = in_data[0]
t = aux[0]
batch = target_logits.shape[0]
aux[0][:] = self.momentum * t + (1 - self.momentum) * target_logits.mean()
self.assign(out_data[0], req[0], aux[0].repeat(repeats=batch).expand_dims(axis=-1))

@mx.operator.register("CurricularFaceT")
class CurricularFaceProp(mx.operator.CustomOpProp):
def __init__(self, momentum=0.99):
super(CurricularFaceProp, self).__init__(need_top_grad=False)
self.momentum = float(momentum)

def list_arguments(self):
return ['data']

def list_outputs(self):
return ['CurricularT']

def list_auxiliary_states(self):
return ['coef_t_bias']

def infer_shape(self, in_shapes):
data_shape = in_shapes[0]
batch = data_shape[0]
return [data_shape], [(batch, 1)], [(1,)]

def infer_type(self, in_type):
return [np.float32], [np.float32], [np.float32]

def create_operator(self, ctx, in_shapes, in_dtypes):
return CurricularFace(ctx, in_shapes, in_dtypes, momentum=self.momentum)
``````

it will cause another error:

``````   File "/usr/local/lib/python3.6/dist-packages/mxnet/symbol/symbol.py", line 1629, in simple_bind
[bt] (7) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::Executor::SimpleBind(nnvm::Symbol, mxnet::Context const&,
std::map<std::string, mxnet::Context, std::less<std::string>, std::allocator<std::pair<std::string const, mxnet::Context> > > const&, std::vector<mxnet::Context, std::allocator<mxnet::
Context> > const&, std::vector<mxnet::Context, std::allocator<mxnet::Context> > const&, std::vector<mxnet::Context, std::allocator<mxnet::Context> > const&, std::unordered_map<std::string, mxnet::TShape, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, mxnet::TShape> > > const&, std::unordered_map<std::string, int, std:
:hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, int> > > const&, std::unordered_map<std::string, int, std::hash<std::string>, std::equal_to<
std::string>, std::allocator<std::pair<std::string const, int> > > const&, std::vector<mxnet::OpReqType, std::allocator<mxnet::OpReqType> > const&, std::un
raise RuntimeError(error_msg)
ordered_set<std::string, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::string> > const&, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> >*, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> >*, std::vector<mxnet::NDArray, std::allocator<mxnet::NDArray> >*, std::unordere
d_map<std::string, mxnet::NDArray, std::hash<std::string>, std::equal_to<std::string>, std::allocator<std::pair<std::string const, mxnet::NDArray> > >*, mxnet::Executor*)+0x8a8) [0x7f4
439cc6258]
RuntimeError: simple_bind error. Arguments:
[bt] (8) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(MXExecutorSimpleBindEx+0x221b) [0x7f4439c0795b]
data: (720, 256)
softmax_label: (720,)
[23:37:42] src/operator/custom/custom.cc:282: Check failed: reinterpret_cast<CustomOpCreateFunc>( params.info->callbacks[kCustomOpPro
pCreateOperator])( os.str().c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(), op_info, params.info->contexts[kCustomOpPropCreateOperator]):
Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7f89b96b14cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7eb5a9) [0x7f89b99ec5a9]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7df2e5) [0x7f89b99e02e5]
[bt] (3) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x25d85ad) [0x7f89bb7d95ad]
[bt] (4) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x25dbc08) [0x7f89bb7dcc08]
[bt] (5) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::exec::GraphExecutor::FinishInitGraph(nnvm::Symbol, nnvm::G
raph, mxnet::Executor*, std::unordered_map<nnvm::NodeEntry, mxnet::NDArray, nnvm::NodeEntryHash, nnvm::NodeEntryEqual, std::allocator<std::pair<nnvm::NodeEntry const, mxnet::NDArray> >
> const&)+0x793) [0x7f89bb809bf3]``````