Fine-tuning ConvNet blocks in MXNet

I'm new to MXNet and I was wondering if anyone knows how to fine-tune more layers of a CNN than just the FC layers. All the examples I'm looking at fine-tune only the FC layers. In Keras this can be done easily, and ConvNet blocks beyond the FC block can be fine-tuned: https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/10_Fine-Tuning.ipynb

If we want to fine-tune only the FC block, we set the trainability of all the other layers to false: layer.trainable = False

If we want to fine-tune more ConvNet blocks in addition to the FC layers, we set layer.trainable = True for those layers.
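For reference, a minimal sketch of that Keras pattern (using keras.applications.VGG16 as an assumed example; the linked tutorial may use a different base model):

from tensorflow import keras

# Load a pre-trained base model without its FC head
base_model = keras.applications.VGG16(weights='imagenet', include_top=False)

# Freeze everything ...
for layer in base_model.layers:
    layer.trainable = False

# ... then unfreeze the last conv block so it is fine-tuned along with the new FC head
for layer in base_model.layers:
    if layer.name.startswith('block5'):
        layer.trainable = True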

My question is: how do I do the same in MXNet?

Hi,

I guess you should wait for input from MXNet experts, but this will get you started (with Gluon).

Every variable in a network is stored in a gluon.Parameter object. This has a property grad_req with three possible (string) values: 'write' for trainable parameters (the gradient is computed and written on each backward pass), 'null' for non-trainable (frozen) parameters (no gradient is kept), and 'add' for accumulating gradients in place across backward passes (you then need to zero them manually). So in your case, you can explicitly inspect which layers of a pre-trained network have which grad_req value and change it for the layers you want to freeze or train. I am under the impression, but I may be wrong, that pre-trained networks are not non-trainable by default (even when you load them with pretrained=True). Let's look at an example I just ran:

# import essentials

import mxnet as mx
from mxnet import gluon

# This loads a pre-trained network
net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)

If you run net in the shell, you'll see the layers of the network:

net

# prints
ResNetV2(
  (features): HybridSequential(
    (0): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=True, use_global_stats=False, in_channels=3)
    (1): Conv2D(3 -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
    (3): Activation(relu)
    (4): MaxPool2D(size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=False)
    (5): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv1): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv1): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv2): Conv2D(64 -> 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (6): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=64)
        (conv1): Conv2D(64 -> 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (downsample): Conv2D(64 -> 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv1): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv2): Conv2D(128 -> 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (7): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=128)
        (conv1): Conv2D(128 -> 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (downsample): Conv2D(128 -> 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv1): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv2): Conv2D(256 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (8): HybridSequential(
      (0): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
        (conv1): Conv2D(256 -> 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
        (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (downsample): Conv2D(256 -> 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
      )
      (1): BasicBlockV2(
        (bn1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
        (conv1): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
        (conv2): Conv2D(512 -> 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
    )
    (9): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=512)
    (10): Activation(relu)
    (11): GlobalAvgPool2D(size=(1, 1), stride=(1, 1), padding=(0, 0), ceil_mode=True)
    (12): Flatten
  )
  (output): Dense(512 -> 1000, linear)
)

Now, in this pre-trained network you can see there exists net.features, which is a HybridSequential container whose layers you can access with simple indexing (just like a Python list). For example, say I want to see the grad_req value (a string) of the 2nd layer of features (1st: BatchNorm, 2nd: Conv2D):

net.features[1] # Note that index 0 refers to the BatchNorm layer

# prints
Conv2D(3 -> 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

Let’s see the grad_req:

for param in net.features[1].collect_params().values():
    print(param.name, ",", param.grad_req)

# prints; observe there is no bias parameter, since the layer follows a BatchNorm (bias=False)
resnetv20_conv0_weight ,  write

So the convolutional weight is writable (i.e. trainable). If you want to freeze some of the variables/layers of the pre-trained network (so they are not updated), run the same loop but with an assignment, setting grad_req to 'null' for the parameters that currently have grad_req == 'write':

for param in net.features[1].collect_params().values():  # or whichever other layers you want to freeze
    param.grad_req = 'null'

The above is my summary of how I understand pre-trained networks; I may be wrong, but you can experiment. Hope this helps.
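Putting it together, here is a minimal sketch of fine-tuning more than the FC layer with this approach. Unfreezing features[8] (the last residual stage) together with the output layer is just an illustrative assumption; pick whichever blocks you want to train:

import mxnet as mx
from mxnet import gluon

# Load the pre-trained network
net = gluon.model_zoo.vision.resnet18_v2(pretrained=True)

# Freeze every parameter first ...
for param in net.collect_params().values():
    param.grad_req = 'null'

# ... then unfreeze the last residual stage and the output (FC) layer,
# so more than just the FC block gets fine-tuned
for param in net.features[8].collect_params().values():
    param.grad_req = 'write'
for param in net.output.collect_params().values():
    param.grad_req = 'write'

# Pass only the trainable parameters to the Trainer
trainable_params = [p for p in net.collect_params().values() if p.grad_req != 'null']
trainer = gluon.Trainer(trainable_params, 'sgd', {'learning_rate': 0.01})

For a new task you would typically also replace net.output with a fresh Dense layer sized for your number of classes and initialize it, but the grad_req mechanics stay the same.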


Very helpful to me! Thanks a lot.


If you also want to set differential learning rates, you can check the answer here:
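As a quick sketch of one way to do this in Gluon: each Parameter has an lr_mult attribute that scales the Trainer's base learning rate for that parameter (the factors below are hypothetical):

# Per-layer (differential) learning rates via lr_mult
for param in net.features[8].collect_params().values():
    param.lr_mult = 0.1   # last residual stage learns at 10% of the base rate
for param in net.output.collect_params().values():
    param.lr_mult = 1.0   # output layer uses the full base rate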