Custom layer - infer shape after first forward pass

Dear all,

I am designing a custom convolution layer (HybridBlock), and I cannot figure out how to have the shape of the weight parameter (specifically the number of input channels) inferred after the first forward pass. I am looking at the source code of the private _Conv class, but it's a bit tricky to follow. Any ideas?

My custom convolution operator (currently deriving from gluon.Block) is something like this:

class Conv2DS(Block):
    
    def __init__(self, nchannels, nfilters, kernel_size = 3, kernel_effective_size = 5, degree = 2, pad = None, dilation_rate=[1,1],**kwards):
        Block.__init__(self,**kwards)
        
        self.nchannels = nchannels
        self.nfilters = nfilters
        
        self.kernel_eff = kernel_effective_size
        self.dilation_rate = dilation_rate
        self.Bijkl = ...  # This is some custom 4D matrix I use in the convolution.
         
        # Ensures padding = 'SAME' for ODD kernel selection 
        if pad is None:
            p0 = self.dilation_rate[0] * (self.kernel_eff - 1) // 2
            p1 = self.dilation_rate[1] * (self.kernel_eff - 1) // 2
            pad = (p0, p1)

    
        self.pad = pad
        with self.name_scope():
            
            # This is where I define the custom weight variable
            self.weight = self.params.get(
                'weight',
                shape=[nfilters,self.nchannels,kernel_size,kernel_size])
            
    def forward(self,_x):
         # I would like here the shape self.nchannels to be inferred from the input _x
         # Any pointers / ideas / easy small example?  
        weight = nd.sum(nd.dot(self.weight.data() , self.Bijkl),axis=[2,3])
        conv = nd.Convolution(data=_x,
                             weight=weight,
                             no_bias=True,
                             num_filter=self.nfilters,
                             kernel=[self.kernel_eff,self.kernel_eff],
                             pad=self.pad)
        
        return conv

Thank you very much.

Your sample code is using Block, not HybridBlock. The necessary steps are a bit different between the two, and I'd recommend that you stick with HybridBlock. In __init__(), in both cases, you'd want to call self.params.get() with any unknown dimension of the shape argument set to 0. For HybridBlock, everything is done for you: by the time hybrid_forward is called with data, the necessary shapes have already been inferred. For Block, you then need to set the shape of the weight (self.weight.shape = (...)) from the shape of the passed-in data (in your case _x) and then call self.weight._finish_deferred_init() to initialize it.

Under the hood, HybridBlock will construct a symbolic graph of the block in order to infer the shape of the unspecified dimensions the first time data is passed into the block.
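
For reference, here is a minimal sketch of the Block path described above (the class name and sizes are placeholders, and your Bijkl projection is left out):

import mxnet as mx
from mxnet import nd
from mxnet.gluon import Block

class MinimalConv(Block):
    def __init__(self, nfilters, kernel_size=3, **kwargs):
        super(MinimalConv, self).__init__(**kwargs)
        self.nfilters = nfilters
        self.kernel_size = kernel_size
        with self.name_scope():
            # Unknown input-channel dimension set to 0; initialization is deferred.
            self.weight = self.params.get(
                'weight', shape=(nfilters, 0, kernel_size, kernel_size),
                allow_deferred_init=True)

    def forward(self, x):
        # Fill in the missing dimension from the data, then finish the deferred init.
        self.weight.shape = (self.nfilters, x.shape[1], self.kernel_size, self.kernel_size)
        self.weight._finish_deferred_init()
        return nd.Convolution(data=x, weight=self.weight.data(), no_bias=True,
                              num_filter=self.nfilters,
                              kernel=(self.kernel_size, self.kernel_size))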

Hi @safrooze, thank you very much for your reply, extremely appreciated!

Based on your suggestions, I’ve tried the following things (I’ve simplified the example):

Block version (works)

import mxnet as mx
from mxnet import nd, gluon
from mxnet.gluon import HybridBlock, Block

class Conv2DS(Block):
    # Now the nchannels variable has initial value zero, this is the variable I need to be inferred
    def __init__(self,  nfilters, nchannels=0, kernel_size = 3, kernel_effective_size = 5,**kwards):
        Block.__init__(self,**kwards)
        
        self.nchannels = nchannels
        self.nfilters = nfilters
        self.kernel_size = kernel_size
        self.kernel_eff = kernel_effective_size
        # Some custom operation that creates a "deprojection" matrix, for now a simple random NDArray
        self.Bijkl = nd.random_uniform(shape=[kernel_size,kernel_size,kernel_effective_size,kernel_effective_size])
        
        with self.name_scope():
            

            self.weight = self.params.get(
                'weight',allow_deferred_init=True,#  init=mx.init.Xavier(magnitude=2.24),
                shape=(nfilters,nchannels,kernel_size,kernel_size))
            
    
    def forward(self,_x):
        self.weight.shape = (self.nfilters,_x.shape[1],self.kernel_size,self.kernel_size)
        self.weight._finish_deferred_init()
        
        weight = nd.sum(nd.dot(self.weight.data() , self.Bijkl),axis=[2,3])
        #print weight.shape
        conv = nd.Convolution(data=_x,
                             weight=weight,
#                             bias=self.bias.data(),
                             no_bias=True,
                             num_filter=self.nfilters,
                             kernel=[self.kernel_eff,self.kernel_eff])
        
        return conv


nbatch=25
nfilters=12
nchannels=7


myConv = Conv2DS(nfilters, kernel_size=3, kernel_effective_size=5)
myConv.initialize(mx.initializer.Xavier())

so far so good, and the forward pass works as expected:

xx = nd.random_uniform(shape=[nbatch,nchannels,128,128])
temp1= myConv(xx)
print (temp1.shape)
Output
(25L, 12L, 124L, 124L)

HybridBlock version (doesn’t work)

import mxnet as mx
from mxnet import nd, gluon
from mxnet.gluon import HybridBlock, Block


class Conv2DS(HybridBlock):
    # Now the nchannels variable has initial value zero, this is the variable I need to be inferred
    def __init__(self,  nfilters, nchannels=0, kernel_size = 3, kernel_effective_size = 5,**kwards):
        HybridBlock.__init__(self,**kwards)
        
        self.nchannels = nchannels
        self.nfilters = nfilters
        self.kernel_eff = kernel_effective_size
        
        # Some custom operation that creates a "deprojection" kernel, for now a simple random NDArray
        self.Bijkl = nd.random_uniform(shape=[kernel_size,kernel_size,kernel_effective_size,kernel_effective_size])
        
        with self.name_scope():
            

            self.weight = self.params.get(
                'weight',allow_deferred_init=True,
                shape=(nfilters,nchannels,kernel_size,kernel_size))
            
    
    def hybrid_forward(self,F,_x):
                    
        weight = F.sum(F.dot(self.weight.data() , self.Bijkl),axis=[2,3])
        #print weight.shape
        conv = F.Convolution(data=_x,
                             weight=weight,
#                             bias=self.bias.data(),
                             no_bias=True,
                             num_filter=self.nfilters,
                             kernel=[self.kernel_eff,self.kernel_eff])
        
        return conv

then I can initialize the Conv2DS layer:


nbatch=25
nfilters=12
nchannels=7


myConv = Conv2DS(nfilters, kernel_size=3, kernel_effective_size=5)
myConv.initialize(mx.initializer.Xavier())

so far so good, but when I try to do a forward pass:

xx = nd.random_uniform(shape=[nbatch,nchannels,128,128])
temp1= myConv(xx)

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-75-6a0caa7e4241> in <module>()
----> 1 temp1= myConv(xx)
      2 #temp2 = myConv_std(xx)

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in __call__(self, *args)
    302     def __call__(self, *args):
    303         """Calls forward. Only accepts positional arguments."""
--> 304         return self.forward(*args)
    305 
    306     def forward(self, *args):

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in forward(self, x, *args)
    507                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
    508                 except DeferredInitializationError:
--> 509                     self._finish_deferred_init(self._active, x, *args)
    510 
    511                 if self._active:

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _finish_deferred_init(self, hybrid, *args)
    401 
    402     def _finish_deferred_init(self, hybrid, *args):
--> 403         self.infer_shape(*args)
    404         if hybrid:
    405             for is_arg, i in self._cached_op_args:

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in infer_shape(self, *args)
    460     def infer_shape(self, *args):
    461         """Infers shape of Parameters from inputs."""
--> 462         self._infer_attrs('infer_shape', 'shape', *args)
    463 
    464     def infer_type(self, *args):

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _infer_attrs(self, infer_fn, attr, *args)
    448     def _infer_attrs(self, infer_fn, attr, *args):
    449         """Generic infer attributes."""
--> 450         inputs, out = self._get_graph(*args)
    451         args, _ = _flatten(args)
    452         arg_attrs, _, aux_attrs = getattr(out, infer_fn)(

/home/dia021/anaconda2/lib/python2.7/site-packages/mxnet/gluon/block.pyc in _get_graph(self, *args)
    369             params = {i: j.var() for i, j in self._reg_params.items()}
    370             with self.name_scope():
--> 371                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
    372             out, self._out_format = _flatten(out)
    373 

TypeError: hybrid_forward() got an unexpected keyword argument 'weight'

Any ideas what I am doing wrong and how to fix it? The input to Conv2DS will be another convolution/image operator of size (nfilters, nchannels, height, width); the dimension I need to infer at run time is nchannels.

Thanks!

This is because hybrid_forward passes the data as well as all the parameters to your block. The reason for this is that hybrid_forward may be called with F as NDArray or as Symbol, and if F is a Symbol, you must have access to a symbol that represents your parameter (which you cannot get through self.weight).

So the solution to your problem is to either add **kwargs to your hybrid_forward function signature, or simply add a weight argument to the function signature.
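
Concretely, a minimal sketch of the changed signature, keeping the rest of your class as-is (note that self.Bijkl still needs the treatment described in the next paragraph before the block can actually be hybridized):

    def hybrid_forward(self, F, _x, weight):
        # 'weight' is supplied by Gluon: an NDArray in imperative mode,
        # a Symbol once the block is hybridized.
        w = F.sum(F.dot(weight, self.Bijkl), axis=[2, 3])
        return F.Convolution(data=_x, weight=w, no_bias=True,
                             num_filter=self.nfilters,
                             kernel=[self.kernel_eff, self.kernel_eff])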

Another problem with your HybridBlock code is self.Bijkl. This is an NDArray instance, which cannot be used in hybrid_forward(). Remember that hybrid_forward may be called with Symbol or NDArray, so you cannot have a dependency on one or the other (and that's why F is passed in). The operation that creates self.Bijkl must be moved into hybrid_forward() and changed to use F instead of nd.

It seems like in your case Bijkl might be a constant parameter. In that case, you can add it as a Parameter, set differentiable to False, and use a Constant initializer for it.

Hi @safrooze, again many thanks for your answer. I've been playing around with your suggestions, but haven't had much luck hybridizing my custom layer.

Would it please be possible to provide a simple example of a HybridBlock wrapper around an nd.array object? I need to create a matrix in numpy (it would be very time-consuming to create it from scratch in nd.array; it's basically B-spline definitions which already exist in Python) and then transfer it to nd.array. Something like

class ndarray_wrap(HybridBlock):
    def __init__(self, const_ndarray, **kwards):
        HybridBlock.__init__(self,**kwards)

        # Some operations that take the constant const_ndarray and
        # transform it into a layer with no differentiation
        self.constant_layer = ...


    def hybrid_forward(self,F,x):
        return self.constant_layer

such that it can be used in combination with other HybridBlocks and eventually allow hybridizing the whole network?
Again, many thanks for your time.

I apologize @feevos for the delayed response. I hope the following code can help with understanding how to use parameters (constant or not) with a custom HybridBlock:

class CustomConv(HybridBlock):
    def __init__(self, const_ndarray, **kwargs):
        super(CustomConv, self).__init__(**kwargs)
        self.weight = self.params.get('weight', shape=(100, 100, 3, 3))
        self.bijkl = self.params.get(
            'bijkl', shape=const_ndarray.shape,
            init=mx.init.Constant(const_ndarray.asnumpy().tolist()), differentiable=False)

    def hybrid_forward(self, F, x, weight, bijkl):
        proj_weight = F.sum(F.dot(weight, bijkl), axis=[2, 3])
        return F.Convolution(data=x, weight=proj_weight, no_bias=True, num_filter=100, kernel=(5, 5))


bijkl_const = nd.ones(shape=(3, 3, 5, 5)) * 5
net = CustomConv(bijkl_const)

net.collect_params().initialize()

x = nd.random.uniform(shape=(16, 100, 128, 128))
net.hybridize()  # Remove for normal execution, keep for hybridized execution
out = net(x)

print(out.shape)

I recommend trying both normal and hybridized versions and putting a breakpoint in hybrid_forward() call to fully understand what gets passed into the function in both modes.

Hi @safrooze, thank you very much; your solution is very informative and works wonders (apologies for the late reply as well, I was on leave). I've modified your example to include an optional bias term:

class CustomConv(HybridBlock):
    def __init__(self, const_ndarray, use_bias = True, **kwargs):
        super(CustomConv, self).__init__(**kwargs)
        self.use_bias = use_bias 
        
        with self.name_scope():
            self.weight = self.params.get('weight', 
                                          shape=(100, 100, 3, 3),
                                          allow_deferred_init=True)
            self.bijkl = self.params.get(
                'bijkl', 
                shape=const_ndarray.shape,
                init=mx.init.Constant(const_ndarray.asnumpy().tolist()), 
                differentiable=False)

        if self.use_bias:
            self.bias = self.params.get(
                'bias',
                allow_deferred_init=True,
                init=mx.init.Zero(),
                shape=(100,))
        
    def hybrid_forward(self, F, x, weight, bijkl, bias=None):
        proj_weight = F.sum(F.dot(weight, bijkl), axis=[2, 3])
        
        if self.use_bias:
            return F.Convolution(data=x, weight=proj_weight, bias=bias, num_filter=100, kernel=(5, 5))
        else:
            return F.Convolution(data=x, weight=proj_weight, no_bias=True, num_filter=100, kernel=(5, 5))

The only minor issue I still have is how to infer the shape during the first run. If I set a dimension (in shape) to zero, I get an error. Is it possible to infer the shape (of some dimension, not all) when passing arguments into the hybrid_forward function?

Again many thanks, extremely appreciate your help.

edit: it seems the problem is in the line:

proj_weight = F.sum(F.dot(weight, bijkl), axis=[2, 3])

if I remove this, the shape of weight is identified automatically.

The problem with shape inference in your code is that the missing dimension in the shape of weight cannot be inferred through the F.dot operation, even though it could later be inferred from the convolution. The way shape inference is done in MXNet, each operator tries to infer the shapes of its input parameters as well as its outputs. However, the mapping between missing input dimensions and missing output dimensions isn't maintained anywhere, and I don't see any mechanism to "back-propagate" shapes inferred in downstream operators back to earlier ones.

One option for you would be to do this manually by setting the shape attribute of Parameter based on the shape of your data before calling the network.
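
For example, a minimal sketch using the CustomConv block from above, assuming its weight was declared with shape=(100, 0, 3, 3) and allow_deferred_init=True:

bijkl_const = nd.ones(shape=(3, 3, 5, 5)) * 5
net = CustomConv(bijkl_const)

x = nd.random.uniform(shape=(16, 100, 128, 128))
# Fill in the unknown input-channel dimension of the weight from the data
# before initializing and calling the network.
net.weight.shape = (100, x.shape[1], 3, 3)
net.collect_params().initialize()
out = net(x)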

Hi @safrooze, thank you very much. I've learned so many things in this thread from your contributions. I'll use what you suggested.

Thanks!

Glad I could help an MXNet fan!

Hi to all,

@safrooze gave the solution in this topic on the discussion forum. The trick is to override the forward function and get the layer shape there. Example:

import mxnet as mx
from mxnet import nd, gluon

class GetShape(gluon.HybridBlock):
    def __init__(self,nchannels=0, kernel_size=(3,3), **kwards):
        gluon.HybridBlock.__init__(self,**kwards)
        
        self.layer_shape = None
        
        with self.name_scope():
            self.conv = gluon.nn.Conv2D(nchannels,kernel_size=kernel_size)
            
            
            
    def forward(self,x):
        self.layer_shape = x.shape
        
        return gluon.HybridBlock.forward(self,x)
    
    def hybrid_forward(self,F,x):
        print (self.layer_shape)
        out = self.conv(x)
        return out

mynet = GetShape(nchannels=12)
mynet.hybridize()

mynet.initialize(mx.init.Xavier(), ctx=mx.cpu())
xx = nd.random.uniform(shape=[32,8,128,128])
out = mynet(xx)
# prints (32, 8, 128, 128)

Thank you @safrooze !!

Just to clarify, this trick doesn't work if the block is a child of another HybridBlock. In that case, forward() is also called with x of type Symbol.

and I was sooooooo excited about this! That’s OK, getting there … :slight_smile:

Hello @safrooze, I’d appreciate your help, sorry if I missed something.

I am trying to write my own Dense layer as a HybridBlock, and I want to learn how to leave in_units unspecified. I tried your advice to (a) set the unknown dim of the shape to 0 and (b) pass weight and bias to hybrid_forward, but it does not work. My code:

class MyDense(gluon.HybridBlock):
    def __init__(self, units, activation=None, use_bias=True, in_units=None, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        if in_units is None:
            in_units = 0
        w_shape = (in_units, units)
        b_shape = (1, units)
        with self.name_scope():
            self.weight = self.params.get(
                'weight', init=mx.init.Normal(0.01), shape=w_shape)
            self.bias = self.params.get(
                'bias', init=mx.init.Zero(), shape=b_shape)
        if activation is not None:
            self.activation = gluon.nn.Activation(activation)
        else:
            self.activation = None

    def hybrid_forward(self, F, x, weight, bias):
        a = F.broadcast_add(F.dot(x, weight), bias)
        if self.activation is not None:
            y = self.activation(a)
        else:
            y = a
        return y

Then, when I run:

net = MyDense(64, activation='relu'); net.initialize()

I am getting an error:

ValueError: Cannot initialize Parameter 'mydense2_weight' because it has invalid shape: (0, 64)

Hi @mseeger,

you need to use F.FullyConnected instead of F.dot for the shape to be inferred (I don't recall why, but it has come up before; search the mxnet GitHub issues as well). This example works:

class MyDense(gluon.HybridBlock):
    def __init__(self, units, activation=None, use_bias=True, in_units=None, **kwargs):
        super().__init__(**kwargs)
        if in_units is None:
            in_units = 0
            
        self.units = units
        #w_shape = (in_units, units)
        w_shape = ( units, in_units)
        #b_shape = (1,units)
        b_shape = (units,1)
        
        with self.name_scope():
            self.weight = self.params.get(
                'weight', init=mx.init.Normal(0.01), 
                allow_deferred_init=True, 
                shape=w_shape)
            self.bias = self.params.get(
                'bias', init=mx.init.Zero(), shape=b_shape)
        if activation is not None:
            self.activation = gluon.nn.Activation(activation)
        else:
            self.activation = None

    def hybrid_forward(self, F, x, weight, bias):
        #a = F.broadcast_add(F.dot(x, weight), bias) # You need to use F.FullyConnected,
        a = F.FullyConnected(data=x,weight=weight, bias=bias, num_hidden=self.units)
        if self.activation is not None:
            y = self.activation(a)
        else:
            y = a
        return y

net = MyDense(64, activation='relu'); net.initialize()
xx = nd.random.uniform(shape=[16,32])
net.summary(xx)

output:
[screenshot of net.summary output showing the inferred parameter shapes]

You need to add an if statement inside the hybrid_forward to take care of the case where you have no bias.
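
For completeness, a minimal sketch of that branch, assuming a self.use_bias flag is stored in __init__ and the bias parameter is only registered when use_bias is True (so Gluon passes bias=None otherwise):

    def hybrid_forward(self, F, x, weight, bias=None):
        if bias is None:
            # No bias parameter registered: tell FullyConnected to skip the bias term.
            a = F.FullyConnected(data=x, weight=weight, no_bias=True, num_hidden=self.units)
        else:
            a = F.FullyConnected(data=x, weight=weight, bias=bias, num_hidden=self.units)
        return a if self.activation is None else self.activation(a)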

Thanks! Wow that is super-odd.

Somehow I expected that MXNet would deal with the missing shape when weight is passed in as an argument to hybrid_forward, but it appears this is not the case. I have to say I still don't really understand how this works internally.