How to implement a shared layer?

How do I feed two different inputs to a single layer (e.g. FullyConnected) and get two corresponding outputs?
For example:

[figure omitted: a module with a shared layer, from an ECCV 2018 paper]

Hi,

I don’t think it makes much sense to feed two independent inputs and get two independent outputs (you can just apply the network twice), but if I understand what you need, it should be something like:

from mxnet import gluon
from mxnet.gluon import HybridBlock


class YourNet(HybridBlock):
    def __init__(self, some_params, **kwargs):
        super(YourNet, self).__init__(**kwargs)

        with self.name_scope():
            self.FC = gluon.nn.Dense(some_params)

    def hybrid_forward(self, F, input1, input2):
        # the same Dense layer (same weights) is applied to both inputs
        out1 = self.FC(input1)
        out2 = self.FC(input2)

        return out1, out2

Alternatively, you could do:

class YourNet(HybridBlock):
    def __init__(self, some_params, **kwargs):
        super(YourNet, self).__init__(**kwargs)

        with self.name_scope():
            self.FC = gluon.nn.Dense(some_params)

    def hybrid_forward(self, F, x):
        out = self.FC(x)
        return out


with autograd.record():
    pred1 = mynet(input1)
    pred2 = mynet(input2)
    # and then combine both in some loss function
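For instance, continuing the snippet above (loss_fn, label1, label2, trainer and batch_size are placeholder names, not part of the original answer), the combination could be a summed loss computed inside the same record scope, followed by a single backward pass:

with autograd.record():
    pred1 = mynet(input1)
    pred2 = mynet(input2)
    # combine both outputs into a single scalar loss; both branches share mynet's weights
    loss = loss_fn(pred1, label1) + loss_fn(pred2, label2)

loss.backward()
trainer.step(batch_size)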

The situation is different if you want to have a shared layer within the same network. Say, for example, you have an image whose channels differ significantly, and it doesn’t make sense to apply a single convolution layer to all channels, so you split them into two inputs. However, you want to apply the same convolutional layers to both inputs, then concatenate the results and feed them into a classifier (FC).

Something like:

class YourNet(HybridBlock):
    def __init__(self, some_params, **kwargs):
        super(YourNet, self).__init__(**kwargs)

        with self.name_scope():
            self.features = gluon.nn.Conv2D(some_params)  # add more layers here; this could, e.g., be a resnet18_v2 feature extractor
            self.FC = gluon.nn.Dense(some_params)  # here you can use a more robust classifier etc.

    def hybrid_forward(self, F, input1, input2):
        # out1 and out2 are produced by the same weights / feature extractor
        out1 = self.features(input1)
        out2 = self.features(input2)

        out = F.concat(out1, out2, dim=1)  # concat on channel axis
        out = self.FC(out)
        return out
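A minimal usage sketch (the batch size, input shapes and initializer below are made-up values, just to show the call pattern with two inputs):

import mxnet as mx
from mxnet import nd

net = YourNet(some_params)   # some_params is the same placeholder as in the class above
net.initialize(mx.init.Xavier())

# two batches holding the two channel groups mentioned above (shapes are made up)
x1 = nd.random.uniform(shape=(8, 3, 32, 32))
x2 = nd.random.uniform(shape=(8, 3, 32, 32))

out = net(x1, x2)  # both inputs go through the same self.features weights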

Hope this helps.

Thank you very much.

Is there a way to do that without Gluon?

What I want to do is build a single block of the whole architecture:

avg = mx.sym.Pooling(data=data, global_pool=True, pool_type='avg')
max = mx.sym.Pooling(data=data, global_pool=True, pool_type='max')

# assume FC is a shared fc, HOWTO do this...
fc_avg = FC(data=avg)
fc_max = FC(data=max)

# do element-wise add or something else
out = fc_avg + fc_max

Hi @ZhouJ, I am sorry, but I am using only the Gluon API; someone more experienced will definitely reply to you.

You have two options:

  1. Concatenate the avg and max outputs on the batch axis, pass them through your FC layer, and then split the result back into two (see the sketch after this list).
  2. Create two FC layers that share their weight and bias, by creating weight and bias symbol Variables and passing them explicitly to both FullyConnected calls.
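A minimal sketch of option 1 (the num_hidden value and names are made-up; data stands for your input symbol):

import mxnet as mx

data = mx.sym.Variable('data')
pool_avg = mx.sym.Pooling(data=data, global_pool=True, pool_type='avg')
pool_max = mx.sym.Pooling(data=data, global_pool=True, pool_type='max')

# stack both pooled tensors along the batch axis so they go through the same FC weights
both = mx.sym.concat(pool_avg, pool_max, dim=0)
fc = mx.sym.FullyConnected(data=both, num_hidden=64, name='shared_fc')

# split the batch back into the avg half and the max half
halves = mx.sym.split(fc, num_outputs=2, axis=0)
fc_avg, fc_max = halves[0], halves[1]

out = fc_avg + fc_max

Option 2 is what the snippet in the next reply does, by passing the same weight Variables to two FullyConnected calls.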

I reproduced it in this way:

### channel attention module (F is the input feature map symbol; name and num_filter come from the enclosing code)
pool_c_max = mx.sym.Pooling(data=F, kernel=(3,3), pool_type='max', global_pool=1, name=name + '_pool_c_max')
pool_c_avg = mx.sym.Pooling(data=F, kernel=(3,3), pool_type='avg', global_pool=1, name=name + '_pool_c_avg')
# shared weights: the same two Variables are passed to both the max and avg branches
mlp_1_weight = mx.sym.Variable(name=name + '_mlp_1_weight')
mlp_2_weight = mx.sym.Variable(name=name + '_mlp_2_weight')
fc_mlp_max_1 = mx.sym.FullyConnected(data=pool_c_max, weight=mlp_1_weight, num_hidden=num_filter // 4, no_bias=True, flatten=True, name=name + '_fc_mlp_max_1')
fc_mlp_max_2 = mx.sym.FullyConnected(data=fc_mlp_max_1, weight=mlp_2_weight, num_hidden=num_filter, no_bias=True, flatten=True, name=name + '_fc_mlp_max_2')
fc_mlp_avg_1 = mx.sym.FullyConnected(data=pool_c_avg, weight=mlp_1_weight, num_hidden=num_filter // 4, no_bias=True, flatten=True, name=name + '_fc_mlp_avg_1')
fc_mlp_avg_2 = mx.sym.FullyConnected(data=fc_mlp_avg_1, weight=mlp_2_weight, num_hidden=num_filter, no_bias=True, flatten=True, name=name + '_fc_mlp_avg_2')
fc_mlp = mx.sym.add_n(*[fc_mlp_max_2, fc_mlp_avg_2])
atten_c = mx.sym.sigmoid(fc_mlp)

atten_c = mx.sym.reshape(data=atten_c, shape=(0, 0, 1, 1))
F1 = mx.sym.broadcast_mul(lhs=F, rhs=atten_c)
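As a quick sanity check (the print is just illustrative, not part of the original snippet), the shared weights should each appear only once in the resulting symbol's argument list, even though each is used by both the max and avg branches:

# '..._mlp_1_weight' and '..._mlp_2_weight' should each show up exactly once
print(F1.list_arguments())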

Thank you. :grinning:
All the methods are very good.

