HybridBlock: how do I add more inputs to the hybrid_forward function?

Hi @DarkWings,

I see you are trying to use a UNet, and this is why you need more inputs in hybrid_forward, is this correct?
In general it is straightforward to add additional arguments and use them; see this link, and this one for a UNet. However, for a UNet in particular I don't see a benefit, since it is just a concat operation you are after. So the UNet below should be a good starting point for you (I haven't debugged the code, but it looks OK, at least as a starting point; you may want to add BatchNorm after the Conv2DTranspose layers too).
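To make the "additional arguments" point concrete, here is a minimal sketch (the block name and shapes are mine, purely for illustration): any positional arguments you declare after x in hybrid_forward are simply passed when you call the block.

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import HybridBlock

class TwoInputBlock(HybridBlock):
    def __init__(self, channels, **kwargs):
        super(TwoInputBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.conv = gluon.nn.Conv2D(channels, kernel_size=3, padding=1)

    def hybrid_forward(self, F, x, y):
        # `y` arrives here because the caller passed it: net(a, b)
        return self.conv(F.concat(x, y, dim=1))

net = TwoInputBlock(16)
net.initialize()
net.hybridize()
a = mx.nd.random.uniform(shape=(1, 3, 32, 32))
b = mx.nd.random.uniform(shape=(1, 3, 32, 32))
print(net(a, b).shape)  # (1, 16, 32, 32)

And the UNet itself: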

import mxnet as mx
from mxnet import gluon 
from mxnet.gluon import HybridBlock # HybridBlock supports hybridization; use Block instead for purely imperative code


class UNetBlock(HybridBlock):
    def __init__(self, Nfilters, **kwargs):
        super(UNetBlock,self).__init__(**kwargs)


        with self.name_scope():
            self.act = gluon.nn.Activation('relu')


            self.conv1 = gluon.nn.Conv2D(Nfilters,kernel_size=3,padding=1,use_bias=False)
            self.bn1 = gluon.nn.BatchNorm(axis=1)
            self.conv2 = gluon.nn.Conv2D(Nfilters,kernel_size=3,padding=1,use_bias=False)
            self.bn2 = gluon.nn.BatchNorm(axis=1)


    def hybrid_forward(self,F,x):
        # Standard UNet "double conv": two Conv-BatchNorm-ReLU stages.

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)

        return x



class UNet(HybridBlock):
    def __init__(self,  NClasses, Nfilters_init=32, **kwargs):
        super(UNet, self).__init__(**kwargs)
        
        
        with self.name_scope():

            # A single pooling layer can be reused everywhere, since it has no parameters.
            self.pool = gluon.nn.MaxPool2D(pool_size=2,strides=2)

            # 32
            self.block1 = UNetBlock(Nfilters_init)
            
            # 64
            self.block2 = UNetBlock(Nfilters_init*2)
            
            # 128 
            self.block3 = UNetBlock(Nfilters_init*2**2)
            
            # 256
            self.block4 = UNetBlock(Nfilters_init*2**3)
            
            # 512
            self.block5 = UNetBlock(Nfilters_init*2**4)

            
            
            # 256
            self.up6 = gluon.nn.Conv2DTranspose(Nfilters_init*2**3,kernel_size=(2,2), strides=(2,2),activation='relu') 
            self.block6 = UNetBlock(Nfilters_init*2**3)
            
            
            # 128 
            self.up7 = gluon.nn.Conv2DTranspose(Nfilters_init*2**2,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block7 = UNetBlock(Nfilters_init*2**2)
            
            # 64 
            self.up8 = gluon.nn.Conv2DTranspose(Nfilters_init*2**1,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block8 = UNetBlock(Nfilters_init*2)
            
            

            # 32 
            self.up9 = gluon.nn.Conv2DTranspose(Nfilters_init*2**0,kernel_size=(2,2), strides=(2,2),activation='relu')
            self.block9 = UNetBlock(Nfilters_init)
            
            self.convLast = gluon.nn.Conv2D(NClasses,kernel_size=(1,1),padding=0)
            
            
            
    def hybrid_forward(self,F,x):
        
        conv1 = self.block1(x)
        pool1 = self.pool(conv1)
            
        conv2 = self.block2(pool1)
        pool2 = self.pool(conv2)
 

        conv3 = self.block3(pool2)
        pool3 = self.pool(conv3)


        conv4 = self.block4(pool3)
        pool4 = self.pool(conv4)

        conv5 = self.block5(pool4)

       
        # Upsample with a transposed convolution, then concatenate the
        # skip connection along the channel axis (F.concat defaults to dim=1)
        conv6 = self.up6(conv5)
        conv6 = F.concat(conv6,conv4)
        conv6 = self.block6(conv6)


        # UpSampling with transposed Convolution
        conv7 = self.up7(conv6)
        conv7 = F.concat(conv7,conv3)
        conv7 = self.block7(conv7)


        # UpSampling with transposed Convolution
        conv8 = self.up8(conv7)
        conv8 = F.concat(conv8,conv2)
        conv8 = self.block8(conv8)

        # UpSampling with transposed Convolution
        conv9 = self.up9(conv8)
        conv9 = F.concat(conv9,conv1)
        conv9 = self.block9(conv9)

        
        # 1x1 conv maps features to NClasses channels, then channel-wise softmax.
        # Note: if you train with gluon.loss.SoftmaxCrossEntropyLoss, drop this
        # softmax and return the raw logits, since the loss applies its own.
        final_layer = self.convLast(conv9)
        final_layer = F.softmax(final_layer,axis=1)
        
        return final_layer
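
As a quick sanity check (my own snippet; choose an input size divisible by 16 so the four pool/upsample stages produce matching shapes for the skip concats):

net = UNet(NClasses=4)
net.initialize()
net.hybridize()
x = mx.nd.random.uniform(shape=(1, 3, 256, 256))
print(net(x).shape)  # (1, 4, 256, 256)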

