Gluon: accessing intermediate layers in HybridBlocks

Is it possible to access intermediate layers in HybridBlocks? For example, in the following code (source), is it possible to access the BatchNorm layers? Indexing doesn't work; I get an error stating that LinearBottleneck does not support indexing.

from mxnet.gluon import nn  # RELU6 is a small HybridBlock defined in the same source file


class LinearBottleneck(nn.HybridBlock):
    """LinearBottleneck used in MobileNetV2 model from the
    `"Inverted Residuals and Linear Bottlenecks:
    Mobile Networks for Classification, Detection and Segmentation"
    <https://arxiv.org/abs/1801.04381>`_ paper.
    Parameters
    ----------
    in_channels : int
        Number of input channels.
    channels : int
        Number of output channels.
    t : int
        Layer expansion ratio.
    stride : int
        stride
    """

    def __init__(self, in_channels, channels, t, stride, **kwargs):
        super(LinearBottleneck, self).__init__(**kwargs)
        self.use_shortcut = stride == 1 and in_channels == channels
        with self.name_scope():
            self.out = nn.HybridSequential()

            _add_conv(self.out, in_channels * t, relu6=True)
            _add_conv(self.out, in_channels * t, kernel=3, stride=stride,
                      pad=1, num_group=in_channels * t, relu6=True)
            _add_conv(self.out, channels, active=False, relu6=True)

    def hybrid_forward(self, F, x):
        out = self.out(x)
        if self.use_shortcut:
            out = F.elemwise_add(out, x)
        return out


# pylint: disable= too-many-arguments
def _add_conv(out, channels=1, kernel=1, stride=1, pad=0,
              num_group=1, active=True, relu6=False):
    out.add(nn.Conv2D(channels, kernel, stride, pad, groups=num_group, use_bias=False))
    out.add(nn.BatchNorm(scale=True))
    if active:
        out.add(RELU6() if relu6 else nn.Activation('relu'))

The reason I want to access this is that I want to set use_global_stats to False in the BatchNorm layers while training another task branch. Is it possible to do so in Gluon?

Your model would support indexing if you had defined it using .add().
For example:

from mxnet.gluon import nn

model = nn.HybridSequential()
model.add(nn.Dense(128, 'relu'))  # first layer
model.add(nn.Dense(10))           # second layer

print(model)
# will print
'''
HybridSequential(
  (0): Dense(None -> 128, Activation(relu))
  (1): Dense(None -> 10, linear)
)
'''
# As you can see, the first layer is at index 0,
# so you can access it as below:
print(model[0])
# will print "Dense(None -> 128, Activation(relu))"

So how do we access layers when the model is defined through class inheritance (subclassing nn.HybridBlock, as in your case) instead of a plain nn.Sequential or nn.HybridSequential?

Solution:
If you look at the definition of your model, you'll see that in __init__ you create an nn.HybridSequential attribute, self.out = nn.HybridSequential(), which is then called in hybrid_forward, and you add layers to it with .add() inside the _add_conv function.

So your LinearBottleneck class has an out attribute that contains all the layers you've added.

Now let's print a model built with LinearBottleneck:

model = LinearBottleneck(3, 16, 1, 1)
print(model)
# will print
'''
LinearBottleneck(
  (out): HybridSequential(
    (0): Conv2D(None -> 3, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (2): RELU6(
    
    )
    (3): Conv2D(None -> 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False)
    (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
    (5): RELU6(
    
    )
    (6): Conv2D(None -> 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (7): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
  )
)
'''

As you can see, the layer you want to access is part of (out), at index (1),
so you can access it as below:

print(model.out[1])
# will print
'''
BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=None)
'''

Hope this helps.

Thank you for helping me out! This worked.