Error when iterating HybridSequential: control dep not found in graph

Dear all,

I have a weird error and I cannot find its source. Perhaps someone can help? In particular, I am trying to implement the following simple operation: iterating over the elements of a HybridSequential “container”:

class SomeModel(HybridBlock):
    def __init__(self, **kwards):
        super().__init__(**kwards)

        with self.name_scope():
            self.layers = gluon.nn.HybridSequential()
            for idx in range(4):
                self.layers.add(SomeOtherHybridBlock(idx))  # different idx results in a different layer

    def hybrid_forward(self, F, input):
        x = input
        for conv in self.layers:
            x = x + conv(input)
        return x

It seems simple, but I am getting the following error:

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-2-e5c3f04ecacb> in <module>
      8 net.hybridize()
      9 xx = nd.random.uniform(shape=[7,nfilters,F,F])
---> 10 out = net(xx)

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
    691             hook(self, args)
    692 
--> 693         out = self.forward(*args)
    694 
    695         for hook in self._forward_hooks.values():

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in forward(self, x, *args)
   1146                                      'Find all contexts = {}'.format(ctx_set))
   1147                 with ctx:
-> 1148                     return self._call_cached_op(x, *args)
   1149             with ctx:
   1150                 try:

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _call_cached_op(self, *args)
    979     def _call_cached_op(self, *args):
    980         if self._cached_op is None:
--> 981             self._build_cache(*args)
    982         assert self._cached_op, "cached op is not None"
    983         if self._callback:

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _build_cache(self, *args)
    967         flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
    968                 self._flags
--> 969         self._cached_op = ndarray.CachedOp(out, flags)
    970 
    971     def _deferred_infer_shape(self, *args):

~/.local/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py in __init__(self, sym, flags)
    134             c_str_array([key for key, _ in flags]),
    135             c_str_array([str(val) for _, val in flags]),
--> 136             ctypes.byref(self.handle)))
    137 
    138     def __del__(self):

~/.local/lib/python3.6/site-packages/mxnet/base.py in check_call(ret)
    253     """
    254     if ret != 0:
--> 255         raise MXNetError(py_str(_LIB.MXGetLastError()))
    256 
    257 

MXNetError: [18:27:16] src/core/graph.cc:110: Check failed: it != node2index_.end(): control dep not found in graph
Stack trace:
  [bt] (0) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x65928b) [0x7fa68af2828b]
  [bt] (1) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x8385e68) [0x7fa692c54e68]
  [bt] (2) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x8386fe8) [0x7fa692c55fe8]
  [bt] (3) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x8387b70) [0x7fa692c56b70]
  [bt] (4) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x389a709) [0x7fa68e169709]
  [bt] (5) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(MXCreateCachedOpEx+0x2cf) [0x7fa68e07672f]
  [bt] (6) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7fa6f55f3dae]
  [bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x22f) [0x7fa6f55f371f]
  [bt] (8) /usr/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2b4) [0x7fa6f58075c4]

Complete code example (I am trying to implement an attention layer in 2D and add it in parallel to the input, in the spirit of residual units/blocks):

# Definitions

import mxnet as mx
from mxnet.gluon import HybridBlock
from mxnet import nd, gluon

class Conv2DNormed(HybridBlock):
    """
        Convenience wrapper layer for 2D convolution followed by a normalization layer 
        (either BatchNorm or InstanceNorm). 
        norm_type: Either BatchNorm (default) or InstanceNorm strings. 
        axis : axis in normalization (exists only in BatchNorm). 
        All other keywords are the same as gluon.nn.Conv2D 
    """

    def __init__(self, channels, kernel_size, strides=(1, 1),
                 padding=(0, 0), dilation=(1, 1), activation=None,
                 weight_initializer=None, in_channels=0,
                 _norm_type='BatchNorm', axis=1, groups=1, **kwards):
        super().__init__(**kwards)

        if (_norm_type == 'BatchNorm'):
            self.norm = gluon.nn.BatchNorm
        elif (_norm_type == 'SyncBatchNorm'):
            self.norm = gluon.contrib.nn.SyncBatchNorm
            _prefix = "_SyncBN"
        elif (_norm_type == 'InstanceNorm'):
            self.norm = gluon.nn.InstanceNorm

        elif (_norm_type == 'LayerNorm'):
            self.norm = gluon.nn.LayerNorm
        else:
            raise NotImplementedError


        with self.name_scope():
            self.conv2d = gluon.nn.Conv2D(channels, kernel_size = kernel_size,
                                          strides= strides,
                                          padding=padding,
                                          dilation= dilation,
                                          activation=activation,
                                          use_bias=False,
                                          weight_initializer = weight_initializer,
                                          groups=groups,
                                          in_channels=0)

            self.norm_layer = self.norm(axis=axis)


    def hybrid_forward(self,F,_x):

        x = self.conv2d(_x)
        x = self.norm_layer(x)

        return x


    
def get_norm(name, prefix, ngroups=None):
    if (name == 'BatchNorm'):
        return gluon.nn.BatchNorm(axis=1,prefix = prefix)
    elif (name == 'InstanceNorm'):
        return gluon.nn.InstanceNorm(axis=1,prefix = prefix)
    elif (name == 'LayerNorm'):
        return gluon.nn.LayerNorm(axis=1,prefix = prefix)
    elif (name == 'GroupNorm'):
        return gluon.nn.GroupNorm(num_groups=ngroups, prefix = prefix)
    else:
        raise NotImplementedError

class ResNet_v2_block(HybridBlock):
    """
    ResNet v2 building block. It is built upon the assumption of ODD kernel 
    """
    def __init__(self, _nfilters,_kernel_size=(3,3),_dilation_rate=(1,1),
                 _norm_type='BatchNorm', ngroups=1, **kwards):
        super().__init__(**kwards)

        self.nfilters = _nfilters
        self.kernel_size = _kernel_size
        self.dilation_rate = _dilation_rate


        with self.name_scope():

            # Ensures padding = 'SAME' for ODD kernel selection 
            p0 = self.dilation_rate[0] * (self.kernel_size[0] - 1)/2
            p1 = self.dilation_rate[1] * (self.kernel_size[1] - 1)/2
            p = (int(p0),int(p1))


            self.BN1 = get_norm(_norm_type, prefix = "1_",ngroups=ngroups )
            self.conv1 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=False,prefix="_conv1_",groups=ngroups)
            self.BN2 = get_norm(_norm_type, prefix = "2_" ,ngroups=ngroups)
            self.conv2 = gluon.nn.Conv2D(self.nfilters,kernel_size = self.kernel_size,padding=p,dilation=self.dilation_rate,use_bias=True,prefix="_conv2_",groups=ngroups)


    def hybrid_forward(self,F,_input_layer):


        x = self.BN1(_input_layer)
        x = F.relu(x)
        x = self.conv1(x)

        x = self.BN2(x)
        x = F.relu(x)
        x = self.conv2(x)

        return x

class ResNet_atrous_unit(HybridBlock):
    def __init__(self, _nfilters, _kernel_size=(3,3), _dilation_rates=[3,15,31], _norm_type = 'BatchNorm', **kwards):
        super().__init__(**kwards)
        with self.name_scope():
            self.ResBlocks = gluon.nn.HybridSequential()
            self.ResBlocks.add(ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(1,1), _norm_type = _norm_type))
            for idx,d in enumerate(_dilation_rates):
                self.ResBlocks.add(ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), _norm_type = _norm_type))


    def hybrid_forward(self,F,_xl):
        x = _xl
        for conv in self.ResBlocks:
            x = x + conv(_xl)

        return x



    
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ SOME WEIRD PROBLEM @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
class Attention2D_v3(HybridBlock):
    """
    Self spatial Attention layer.  
    """

    def __init__(self, nkeys, nheads=1, norm = 'BatchNorm',**kwards):
        super().__init__(**kwards)

        with self.name_scope():
            
            self.query  = Conv2DNormed(nkeys,kernel_size=1, _norm_type= norm, groups=nheads)
            self.key    = Conv2DNormed(nkeys,kernel_size=1, _norm_type= norm, groups=nheads)
            self.value  = Conv2DNormed(nkeys,kernel_size=1, _norm_type= norm, groups=nheads)
                    
            self.act_q = gluon.nn.Swish()
            self.act_k = gluon.nn.SELU()

    def hybrid_forward(self, F, input):

        
        
        q = self.act_q(self.query(input))
        q = q.reshape([0,0,-1])
        
        k = self.act_k(self.key(input))
        k = k.reshape([0,0,-1])
        
        v = F.relu(self.value(input))
        v = v.reshape([0,0,-1])

        # Spatial attention
        att = F.batch_dot(q,k,transpose_b=True)
        # Is this scaling good enough? 
        scale = F.reciprocal(F.norm(att,ord=1,axis=(-1,-2),keepdims=True))
        att = F.softmax(F.broadcast_mul(att,scale),axis=-1)

        att = F.batch_dot(att,v)
        att = att.reshape_like(input)

        return att
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@



class AResNet_v2_block(HybridBlock):
    def __init__(self, _nfilters, _kernel_size=(3,3), _dilation_rate=(1,1), _norm_type='BatchNorm', _nkeys=None, nheads=4, ngroups=1, **kwards):
        super().__init__(**kwards)
        
        if _nkeys is None:
            _nkeys = _nfilters
        with self.name_scope():
            self.block = ResNet_v2_block(_nfilters,_kernel_size,_dilation_rate= _dilation_rate, _norm_type = _norm_type)
            self.att = Attention2D_v3( _nkeys, nheads=nheads, norm= _norm_type)
            self.tgamma = self.params.get('tgamma',shape=(1,),init=mx.init.One())

                        
    def hybrid_forward(self, F,input, tgamma):
        block = self.block(input) 
        att = self.att(input)
        out = F.broadcast_add(block , F.broadcast_mul(att,tgamma))
        return out
    

class AResNet_atrous_unit(HybridBlock):
    # PROBLEM HERE 
    """
    This layer is as Diakogiannis et al 2019 (ResUNet-a) with the addition of an attention layer applied on each dilated conv block. 
    """
    def __init__(self, _nfilters, _kernel_size=(3,3), _dilation_rates=[1,3,15,31], _nkeys=None, nheads=4, _norm_type = 'BatchNorm', **kwards):
        super().__init__(**kwards)

        # Note: do not keep the blocks in a plain Python list; they would not be registered as children of the HybridBlock, so use a HybridSequential instead.
        with self.name_scope():
            self.AResBlocks = gluon.nn.HybridSequential()
            for d in _dilation_rates:
                self.AResBlocks.add( AResNet_v2_block(_nfilters,_kernel_size,_dilation_rate=(d,d), nheads=nheads, _nkeys=_nkeys, _norm_type = _norm_type) )
            
    def hybrid_forward(self, F, input):
        x = input
        for block in self.AResBlocks:
            x = x + block(input)

        return x

This reproduces the error:

from mxnet import nd 
nfilters = 1024
F=8
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
net = AResNet_atrous_unit(nfilters,_dilation_rates=[1,3],nheads=nfilters//8) # Doesn't work
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
#net = ResNet_atrous_unit(nfilters) # Works fine
#net = AResNet_v2_block(nfilters) # Works fine
net.initialize()
net.hybridize()
xx = nd.random.uniform(shape=[7,nfilters,F,F])
out = net(xx)

Any ideas what I am doing wrong? Thank you in advance.

mxnet version == 1.6.0

So, I’ve redesigned my initial AResNet_atrous_unit, completely removing the AResNet_v2_block (which, in the form presented here, was redundant anyway since it evaluated the same layer many times), and the problem is solved. Thanks!
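For reference, I have not pasted the exact redesign, but a rough sketch of the idea (this is only an illustrative sketch, assuming the ResNet_v2_block and Attention2D_v3 definitions above; it uses a single attention layer per unit and feeds each dilated block the running activation, i.e. the pattern that also works in the Edit 2 demo further down) looks something like this:

# Illustrative sketch only, not the exact code I ended up with.
# Assumes the imports plus ResNet_v2_block / Attention2D_v3 defined above.
class AResNet_atrous_unit_v2(HybridBlock):
    def __init__(self, _nfilters, _kernel_size=(3,3), _dilation_rates=[1,3,15,31],
                 _nkeys=None, nheads=4, _norm_type='BatchNorm', **kwards):
        super().__init__(**kwards)

        if _nkeys is None:
            _nkeys = _nfilters

        with self.name_scope():
            # one shared attention layer for the whole unit, instead of one per dilated block
            self.att = Attention2D_v3(_nkeys, nheads=nheads, norm=_norm_type)
            self.tgamma = self.params.get('tgamma', shape=(1,), init=mx.init.One())

            self.ResBlocks = gluon.nn.HybridSequential()
            for d in _dilation_rates:
                self.ResBlocks.add(ResNet_v2_block(_nfilters, _kernel_size,
                                                   _dilation_rate=(d, d),
                                                   _norm_type=_norm_type))

    def hybrid_forward(self, F, input, tgamma):
        # residual chain over the dilated blocks: each block sees the running x,
        # not the raw input, so the input itself is consumed only a couple of times
        x = input
        for block in self.ResBlocks:
            x = x + block(x)
        # add the (learnably scaled) attention contribution once
        att = self.att(input)
        x = F.broadcast_add(x, F.broadcast_mul(att, tgamma))
        return x

This keeps the attention in parallel to the residual path but avoids stacking one attention block per dilation rate, which is roughly where the graph construction was failing for me.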

Actually, this looks like a bug in mxnet, since it seems to be related (somehow) to the length of the HybridSequential container.

Demo:

class Demo(HybridBlock):
    def __init__(self, kernel_sizes = [3,3,3,3],**kwards):
        super().__init__(**kwards)
        
        with self.name_scope():
            self.net = gluon.nn.HybridSequential()
            for k in kernel_sizes:
                self.net.add(gluon.nn.Conv2D(32,kernel_size=k,padding=1))
                
    def hybrid_forward(self, F, input):
        x = input
        for conv in self.net:
            x = x + conv(input)
            
            
        return x

As long as the length of kernel_sizes is smaller than 7 (for this particular example), this works:

nfilters=32
F = 256

net = Demo(kernel_sizes=[3]*6)
net.initialize()
net.hybridize()
xx = nd.random.uniform(shape=[7,nfilters,F,F])
out = net(xx)

This produces the error:

nfilters=32
F = 256

net = Demo(kernel_sizes=[3]*7) # <=== CHANGE HERE
net.initialize()
net.hybridize()
xx = nd.random.uniform(shape=[7,nfilters,F,F])
out = net(xx)

Error output:

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-39-0824e4e4d6c4> in <module>
      6 net.hybridize()
      7 xx = nd.random.uniform(shape=[7,nfilters,F,F])
----> 8 out = net(xx)

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
    691             hook(self, args)
    692 
--> 693         out = self.forward(*args)
    694 
    695         for hook in self._forward_hooks.values():

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in forward(self, x, *args)
   1146                                      'Find all contexts = {}'.format(ctx_set))
   1147                 with ctx:
-> 1148                     return self._call_cached_op(x, *args)
   1149             with ctx:
   1150                 try:

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _call_cached_op(self, *args)
    979     def _call_cached_op(self, *args):
    980         if self._cached_op is None:
--> 981             self._build_cache(*args)
    982         assert self._cached_op, "cached op is not None"
    983         if self._callback:

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _build_cache(self, *args)
    967         flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
    968                 self._flags
--> 969         self._cached_op = ndarray.CachedOp(out, flags)
    970 
    971     def _deferred_infer_shape(self, *args):

~/.local/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py in __init__(self, sym, flags)
    134             c_str_array([key for key, _ in flags]),
    135             c_str_array([str(val) for _, val in flags]),
--> 136             ctypes.byref(self.handle)))
    137 
    138     def __del__(self):

~/.local/lib/python3.6/site-packages/mxnet/base.py in check_call(ret)
    253     """
    254     if ret != 0:
--> 255         raise MXNetError(py_str(_LIB.MXGetLastError()))
    256 
    257 

MXNetError: [14:27:32] src/core/graph.cc:110: Check failed: it != node2index_.end(): control dep not found in graph
Stack trace:
  [bt] (0) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x65928b) [0x7f6e5b82f28b]
  [bt] (1) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x8385e68) [0x7f6e6355be68]
  [bt] (2) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x8386fe8) [0x7f6e6355cfe8]
  [bt] (3) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x8387b70) [0x7f6e6355db70]
  [bt] (4) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(+0x389a709) [0x7f6e5ea70709]
  [bt] (5) /home/dia021/.local/lib/python3.6/site-packages/mxnet/libmxnet.so(MXCreateCachedOpEx+0x2cf) [0x7f6e5e97d72f]
  [bt] (6) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f6edc06fdae]
  [bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x22f) [0x7f6edc06f71f]
  [bt] (8) /usr/lib/python3.6/lib-dynload/_ctypes.cpython-36m-x86_64-linux-gnu.so(_ctypes_callproc+0x2b4) [0x7f6edc2835c4]

I’ll submit an issue.
Edit: issue #16736

Edit2:
Some additional information: it seems the error relates to how many times the initial input is passed through the conv layers, not directly to the iteration over the HybridSequential container.

This works irrespective of the length of kernel_sizes:

class Demo(HybridBlock):
    def __init__(self, kernel_sizes = [3,3,3,3],**kwards):
        super().__init__(**kwards)
        
        with self.name_scope():
            self.net = gluon.nn.HybridSequential()
            for k in kernel_sizes:
                tnet = gluon.nn.HybridSequential()
                for _ in range(3):
                    tnet.add(gluon.nn.Conv2D(32,kernel_size=k,padding=1))
                self.net.add(tnet)
    def hybrid_forward(self, F, input):
        x = input
        for conv in self.net:
            #x = x + conv(input)
            x = x + conv(x) ## <===== CHANGE HERE 
            
        return x

Runs fine:

nfilters=32
F = 256

net = Demo(kernel_sizes=[3]*100)
net.initialize()
net.hybridize()
xx = nd.random.uniform(shape=[7,nfilters,F,F])
out = net(xx)
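
For what it’s worth, the traceback shows the failure happens while building the CachedOp (ndarray.CachedOp inside _build_cache), so running the same failing pattern without hybridize() is a useful sanity check; imperatively no CachedOp graph is built at all. A minimal sketch of that check (this just restates the first Demo class from above, the one using conv(input), with the hybridize() call omitted):

# Same layout as the first Demo above (the one with x = x + conv(input)).
class DemoImperative(HybridBlock):
    def __init__(self, kernel_sizes=[3]*7, **kwards):
        super().__init__(**kwards)
        with self.name_scope():
            self.net = gluon.nn.HybridSequential()
            for k in kernel_sizes:
                self.net.add(gluon.nn.Conv2D(32, kernel_size=k, padding=1))

    def hybrid_forward(self, F, input):
        x = input
        for conv in self.net:
            x = x + conv(input)   # the pattern that breaks the CachedOp above
        return x

nfilters = 32
F = 256

net = DemoImperative(kernel_sizes=[3]*7)
net.initialize()
# note: no net.hybridize() here, so execution stays imperative
xx = nd.random.uniform(shape=[7, nfilters, F, F])
out = net(xx)   # executes imperatively, no CachedOp graph is built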