Problem with Sequential.add(): unpacking a Sequential makes a difference

I want to put a custom layer right after a residual block but before its relu, so I made a small tweak to resnet18_v1 in gluon.model_zoo that removes the relu from the residual block (BasicBlockV1 here). The features of the modified resnet18_v1 is a HybridSequential that looks like:

HybridSequential(
  (0): Conv2D(3 -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
  (2): Activation(relu)
  (3): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False)
  (4): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (2): Activation(relu)
      (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
    )
  )
  (5): Activation(relu)
  (6): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
      (2): Activation(relu)
      (3): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
    )
  )
  (7): Activation(relu)
  (8): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (2): Activation(relu)
      (3): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
    )
    (downsample): HybridSequential(
      (0): Conv2D(64 -> 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
    )
  )
  (9): Activation(relu)
  (10): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
      (2): Activation(relu)
      (3): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
    )
  )
  (11): Activation(relu)
  (12): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(128 -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (2): Activation(relu)
      (3): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
    )
    (downsample): HybridSequential(
      (0): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
    )
  )
  (13): Activation(relu)
  (14): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
      (2): Activation(relu)
      (3): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
    )
  )
  (15): Activation(relu)
  (16): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(256 -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (2): Activation(relu)
      (3): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
    )
    (downsample): HybridSequential(
      (0): Conv2D(256 -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
    )
  )
  (17): Activation(relu)
  (18): BasicBlockV1(
    (body): HybridSequential(
      (0): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
      (2): Activation(relu)
      (3): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
    )
  )
  (19): Activation(relu)
  (20): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True)
)

I want to insert a custom layer after each of the 12th, 14th, 16th and 18th entries (all BasicBlockV1 blocks), and add dropout after the activation:

cnn1 = resnet18_v1().features
self.net = nn.Sequential()
last_pos = 0
for pos in (13, 15, 17, 19):
    self.net.add(cnn1[last_pos:pos])
    self.net.add(myLayer())
    self.net.add(nn.Activation('relu'))
    self.net.add(nn.Dropout(0.5))
    last_pos = pos + 1  # plus 1 to skip the activation layer in cnn1
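
(As far as I understand, slicing a HybridSequential returns a new HybridSequential wrapping the sliced layers, so each self.net.add(cnn1[last_pos:pos]) call adds a single nested child rather than the individual layers. A minimal sketch with plain Dense layers instead of the ResNet features:)

from mxnet.gluon import nn

seq = nn.HybridSequential()
seq.add(nn.Dense(4), nn.Dense(8), nn.Dense(16))

part = seq[0:2]              # a new HybridSequential wrapping Dense(4) and Dense(8)
net = nn.Sequential()
net.add(part)                # adds ONE nested child, not two layers
print(len(net), len(part))   # 1 2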

Everything looked fine, but I got a warning during training:

UserWarning: Gradient of Parameter `mynet_resnet180_resnetv10_conv0_weight` on context gpu(0) has not been updated by backward since last `step`. This could mean a bug in your model that made it only use a subset of the Parameters (Blocks) for this iteration. If you are intention

Then I made a small modification to the code that unpacks the HybridSequential:

cnn1 = resnet18_v1().features
self.net = nn.Sequential()
last_pos = 0
for pos in (13, 15, 17, 19):
    self.net.add(*cnn1[last_pos:pos])  # unpacking the slice
    self.net.add(myLayer())
    self.net.add(nn.Activation('relu'))
    self.net.add(nn.Dropout(0.5))
    last_pos = pos + 1  # plus 1 to skip the activation layer in cnn1

This time, no warning showed up during training.

But this still confuses me, because the networks built with and without unpacking should behave the same.
Did I miss anything, or is this just a bug in Gluon?
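
For reference, here is a small sketch comparing the two constructions on the stock resnet18_v1 from mxnet.gluon.model_zoo.vision (an arbitrary slice, not my modified model), checking both the top-level children and whether the two variants expose the same parameter names to the Trainer:

from mxnet.gluon import nn
from mxnet.gluon.model_zoo.vision import resnet18_v1

cnn1 = resnet18_v1().features

nested = nn.Sequential()
nested.add(cnn1[0:5])             # one nested HybridSequential child
nested.add(nn.Activation('relu'))

flat = nn.Sequential()
flat.add(*cnn1[0:5])              # the five sliced blocks added directly
flat.add(nn.Activation('relu'))

print(len(nested), len(flat))     # 2 vs 6 top-level children
print(set(nested.collect_params().keys()) ==
      set(flat.collect_params().keys()))   # do both expose the same parameter names?

Since both variants reuse the same blocks from cnn1, I would expect the parameter sets to match.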

Which version of GluonCV and MXNet do you use?

I tried to reproduce the warning using your code with a dummy training loop, but on the latest versions of MXNet and GluonCV I cannot reproduce it (MXNet 1.4.1, GluonCV 0.4.0.post0, both installed with pip: pip install mxnet and pip install gluoncv).

Here is the code:

import numpy as np
import mxnet as mx
from gluoncv.model_zoo import resnet18_v1
from mxnet import autograd
from mxnet.gluon import nn, Trainer
from mxnet.gluon.loss import L2Loss

cnn1 = resnet18_v1().features
net = nn.Sequential()
last_pos = 0

np.random.seed(42)
mx.random.seed(42)

for pos in (13, 15, 17, 19):
    net.add(cnn1[last_pos:pos])
#   net.add(myLayer())
    net.add(nn.Activation('relu'))
    net.add(nn.Dropout(0.5))
    last_pos = pos + 1  # plus 1 to skip the activation layer in cnn1

x = mx.random.uniform(shape=(1, 214, 214, 3))
label = mx.random.uniform(shape=(1, 512, 1, 1))

net.initialize()
net.summary(x)

l2_loss = L2Loss()
trainer = Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1, 'wd': 1})

with autograd.record():
    out = net(x)
    loss = l2_loss(out, label)

loss.backward()
trainer.step(1)

The output has no warning:

--------------------------------------------------------------------------------
        Layer (type)                                Output Shape         Param #
================================================================================
               Input                            (1, 214, 214, 3)               0
            Conv2D-1                             (1, 64, 107, 2)          671104
         BatchNorm-2                             (1, 64, 107, 2)             256
        Activation-3                             (1, 64, 107, 2)               0
         MaxPool2D-4                              (1, 64, 54, 1)               0
            Conv2D-5                              (1, 64, 54, 1)           36864
         BatchNorm-6                              (1, 64, 54, 1)             256
        Activation-7                              (1, 64, 54, 1)               0
            Conv2D-8                              (1, 64, 54, 1)           36864
         BatchNorm-9                              (1, 64, 54, 1)             256
     BasicBlockV1-10                              (1, 64, 54, 1)               0
           Conv2D-11                              (1, 64, 54, 1)           36864
        BatchNorm-12                              (1, 64, 54, 1)             256
       Activation-13                              (1, 64, 54, 1)               0
           Conv2D-14                              (1, 64, 54, 1)           36864
        BatchNorm-15                              (1, 64, 54, 1)             256
     BasicBlockV1-16                              (1, 64, 54, 1)               0
           Conv2D-17                             (1, 128, 27, 1)           73728
        BatchNorm-18                             (1, 128, 27, 1)             512
       Activation-19                             (1, 128, 27, 1)               0
           Conv2D-20                             (1, 128, 27, 1)          147456
        BatchNorm-21                             (1, 128, 27, 1)             512
           Conv2D-22                             (1, 128, 27, 1)            8192
        BatchNorm-23                             (1, 128, 27, 1)             512
     BasicBlockV1-24                             (1, 128, 27, 1)               0
           Conv2D-25                             (1, 128, 27, 1)          147456
        BatchNorm-26                             (1, 128, 27, 1)             512
       Activation-27                             (1, 128, 27, 1)               0
           Conv2D-28                             (1, 128, 27, 1)          147456
        BatchNorm-29                             (1, 128, 27, 1)             512
     BasicBlockV1-30                             (1, 128, 27, 1)               0
           Conv2D-31                             (1, 256, 14, 1)          294912
        BatchNorm-32                             (1, 256, 14, 1)            1024
       Activation-33                             (1, 256, 14, 1)               0
           Conv2D-34                             (1, 256, 14, 1)          589824
        BatchNorm-35                             (1, 256, 14, 1)            1024
           Conv2D-36                             (1, 256, 14, 1)           32768
        BatchNorm-37                             (1, 256, 14, 1)            1024
     BasicBlockV1-38                             (1, 256, 14, 1)               0
           Conv2D-39                             (1, 256, 14, 1)          589824
        BatchNorm-40                             (1, 256, 14, 1)            1024
       Activation-41                             (1, 256, 14, 1)               0
           Conv2D-42                             (1, 256, 14, 1)          589824
        BatchNorm-43                             (1, 256, 14, 1)            1024
     BasicBlockV1-44                             (1, 256, 14, 1)               0
           Conv2D-45                              (1, 512, 7, 1)         1179648
        BatchNorm-46                              (1, 512, 7, 1)            2048
       Activation-47                              (1, 512, 7, 1)               0
           Conv2D-48                              (1, 512, 7, 1)         2359296
        BatchNorm-49                              (1, 512, 7, 1)            2048
           Conv2D-50                              (1, 512, 7, 1)          131072
        BatchNorm-51                              (1, 512, 7, 1)            2048
     BasicBlockV1-52                              (1, 512, 7, 1)               0
           Conv2D-53                              (1, 512, 7, 1)         2359296
        BatchNorm-54                              (1, 512, 7, 1)            2048
       Activation-55                              (1, 512, 7, 1)               0
           Conv2D-56                              (1, 512, 7, 1)         2359296
        BatchNorm-57                              (1, 512, 7, 1)            2048
     BasicBlockV1-58                              (1, 512, 7, 1)               0
  GlobalAvgPool2D-59                              (1, 512, 1, 1)               0
       Activation-60                              (1, 512, 1, 1)               0
          Dropout-61                              (1, 512, 1, 1)               0
       Activation-62                              (1, 512, 1, 1)               0
          Dropout-63                              (1, 512, 1, 1)               0
       Activation-64                              (1, 512, 1, 1)               0
          Dropout-65                              (1, 512, 1, 1)               0
       Activation-66                              (1, 512, 1, 1)               0
          Dropout-67                              (1, 512, 1, 1)               0
================================================================================
Parameters in forward computation graph, duplicate included
   Total params: 11847808
   Trainable params: 11838208
   Non-trainable params: 9600
Shared params in forward computation graph: 0
Unique parameters in model: 11847808
--------------------------------------------------------------------------------

Thanks for your reply.
I was using mxnet-cu92 1.4.1, and GluonCV wasn't used in my code; I modified the resnet18_v1 in mxnet.gluon.model_zoo.vision.
Here is the modified code:

class BasicBlockV1(HybridBlock):

    def __init__(self, channels, stride, downsample=False, in_channels=0, **kwargs):
        super(BasicBlockV1, self).__init__(**kwargs)
        ........

    def hybrid_forward(self, F, x):
        residual = x

        x = self.body(x)

        if self.downsample:
            residual = self.downsample(residual)
        # relu removed here; it is applied outside the block instead
        # x = F.Activation(residual + x, act_type='relu')
        return residual + x


def resnet18_v1(**kwargs):
    block_pos = [4, 5, 6, 7]  # positions of the four residual stages in the stock features
    resnet = get_resnet(1, 18, **kwargs)
    with resnet.name_scope():
        features = nn.HybridSequential()
        last_pos = 0
        for pos in block_pos:
            features.add(*resnet.features[last_pos:pos])
            for i in range(2):  # each stage contains two BasicBlockV1 blocks
                features.add(resnet.features[pos][i])
                features.add(nn.Activation('relu'))
            last_pos = pos + 1
        features.add(*resnet.features[last_pos:])
        resnet.features = features
    return resnet

What the modification does is move the relu out of BasicBlockV1 without changing the parameters' name scope.
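
One way to check the "name scope is unchanged" part is to capture the parameter names before and after the rebuild and compare them. A small verification sketch (it reuses the stock blocks from get_resnet just to check the naming, not the modified BasicBlockV1):

from mxnet.gluon import nn
from mxnet.gluon.model_zoo.vision import get_resnet

resnet = get_resnet(1, 18)
before = set(resnet.features.collect_params().keys())

# rebuild features exactly as in the modified resnet18_v1 above
with resnet.name_scope():
    features = nn.HybridSequential()
    last_pos = 0
    for pos in [4, 5, 6, 7]:
        features.add(*resnet.features[last_pos:pos])
        for i in range(2):
            features.add(resnet.features[pos][i])
            features.add(nn.Activation('relu'))
        last_pos = pos + 1
    features.add(*resnet.features[last_pos:])
    resnet.features = features

after = set(resnet.features.collect_params().keys())
print(before == after)   # expected True: the added relu layers have no parameters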

You couldn't reproduce the warning because you used the original resnet18_v1, whose features is a HybridSequential with only 9 entries, so slicing at positions 13, 15, 17 and 19 does not do the same thing as it does with my modified model.
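
A quick way to see this (a small sketch against the stock mxnet.gluon.model_zoo.vision.resnet18_v1 that I started from):

from mxnet.gluon.model_zoo.vision import resnet18_v1

features = resnet18_v1().features
print(len(features))         # 9: conv, bn, relu, maxpool, four residual stages, global pool
print(len(features[0:13]))   # still 9: the slice just stops at the end
print(len(features[14:15]))  # 0: an empty HybridSequential

So in your loop, the first slice [0:13] already grabs the whole stock network and the remaining slices are empty, which is consistent with your summary output showing GlobalAvgPool2D-59 followed only by the added Activation/Dropout pairs.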