Change Conv2D layer on pretrained network for Image Segmentation

Hi,

I want to fine-tune a pretrained image segmentation network on a new dataset. Specifically, I want to fine-tune DeepLabV3 with a resnet101 backbone pretrained on the ADE20K dataset. Following demo_deeplab.ipynb from gluon-cv, I download the model:

model = gluoncv.model_zoo.get_model('deeplab_resnet101_ade', pretrained=True)

I see that the model predicts 150 classes in its semantic segmentation. We can see this in the final layer of the auxiliary head:

model.auxlayer

output

_FCNHead(
  (block): HybridSequential(
    (0): Conv2D(1024 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm(axis=1, eps=1e-05, momentum=0.9, fix_gamma=False, use_global_stats=False, in_channels=256)
    (2): Activation(relu)
    (3): Dropout(p = 0.1, axes=())
    (4): Conv2D(256 -> 150, kernel_size=(1, 1), stride=(1, 1))
  )
)

I would like to retain the pretrained weights, but would like to change

(4): Conv2D(256 -> 150, kernel_size=(1, 1), stride=(1, 1))

to

(4): Conv2D(256 -> 1, kernel_size=(1, 1), stride=(1, 1))

Is there a simple way of doing this (similar to the image classification examples)?

For some reason my version of gluoncv does not have this model; however, in general you can change the number of classes with the classes keyword, e.g.:

model = gluoncv.model_zoo.get_model('resnet101_v1b',pretrained=False,classes=1)

I don’t know if there is anything more specific in the model you are trying to use.

Thanks for your input. Unfortunately, this flag doesn't seem to work with segmentation networks (I have tried).

FYI: to get the model, follow the tutorial here:

https://gluon-cv.mxnet.io/build/examples_segmentation/demo_deeplab.html

Looking at the source code, it seems that get_model presets the number of classes from the dataset, because I get this error:

In [23]: model = gluoncv.model_zoo.get_model('deeplab_resnet101_ade', nclass=2)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-23-8a1eb9de3344> in <module>()
----> 1 model = gluoncv.model_zoo.get_model('deeplab_resnet101_ade', nclass=2)

/usr/local/lib/python3.5/dist-packages/gluoncv/model_zoo/model_zoo.py in get_model(name, **kwargs)
    184         err_str += '%s' % ('\n\t'.join(sorted(_models.keys())))
    185         raise ValueError(err_str)
--> 186     net = _models[name](**kwargs)
    187     return net
    188 

/usr/local/lib/python3.5/dist-packages/gluoncv/model_zoo/deeplabv3.py in get_deeplab_resnet101_ade(**kwargs)
    294     >>> print(model)
    295     """
--> 296     return get_deeplab('ade20k', 'resnet101', **kwargs)

/usr/local/lib/python3.5/dist-packages/gluoncv/model_zoo/deeplabv3.py in get_deeplab(dataset, backbone, pretrained, root, ctx, **kwargs)
    175     from ..data import datasets
    176     # infer number of classes
--> 177     model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, ctx=ctx, **kwargs)
    178     if pretrained:
    179         from .model_store import get_model_file

TypeError: __init__() got multiple values for argument 'nclass'
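The TypeError itself comes from ordinary Python argument binding: get_deeplab passes datasets[dataset].NUM_CLASS to DeepLabV3 as its first positional argument (nclass), so an nclass=... forwarded through **kwargs binds that same parameter a second time. A minimal reproduction with hypothetical stand-in functions (FakeDeepLabV3 and fake_get_deeplab are illustrative, not the gluoncv code):

```python
# Hypothetical stand-ins that mimic the call chain shown in the traceback.
NUM_CLASS = {'ade20k': 150, 'pascal_voc': 21}

class FakeDeepLabV3:
    def __init__(self, nclass, backbone='resnet50', **kwargs):
        self.nclass = nclass
        self.backbone = backbone

def fake_get_deeplab(dataset, backbone, **kwargs):
    # nclass is inferred from the dataset and passed positionally...
    return FakeDeepLabV3(NUM_CLASS[dataset], backbone=backbone, **kwargs)

try:
    # ...so supplying nclass again through **kwargs binds it twice.
    fake_get_deeplab('ade20k', 'resnet101', nclass=2)
except TypeError as err:
    # The message ends with: got multiple values for argument 'nclass'
    print(err)
```

This is why calling the DeepLabV3 class directly, or a copy of get_deeplab that takes nclass as a parameter, avoids the error.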

The parameter that defines the number of classes for this family of models is nclass, as described here and here. So what you can do is use an alternative definition of the model you want (use nclass=2 for a binary segmentation scheme; the default loss functions in mxnet/gluon, such as SoftmaxCrossEntropyLoss, work out of the box with one output channel per class):

model = gluoncv.model_zoo.DeepLabV3(nclass=2, backbone='resnet101', pretrained_base=True)

I tested it with pretrained_base=False; then:

In [20]: model.auxlayer
Out[20]: 
_FCNHead(
  (block): HybridSequential(
    (0): Conv2D(1024 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm(fix_gamma=False, eps=1e-05, axis=1, momentum=0.9, use_global_stats=False, in_channels=256)
    (2): Activation(relu)
    (3): Dropout(p = 0.1, axes=())
    (4): Conv2D(256 -> 2, kernel_size=(1, 1), stride=(1, 1))
  )
)

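With this definition, I don't know which dataset the model is trained on (as far as I can tell, pretrained_base=True only loads an ImageNet-pretrained backbone, so the segmentation heads start from scratch).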
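One way to read the nclass=2 suggestion above: with two channels the head emits one logit per class, which is what a softmax cross-entropy over the channel axis expects, and a 2-way softmax is mathematically the same as a sigmoid applied to the difference of the two logits. A small pure-Python check of that identity (the logit values are arbitrary, chosen only for illustration):

```python
import math

def softmax2(z0, z1):
    """Two-class softmax; returns the probability of class 1."""
    m = max(z0, z1)  # subtract the max for numerical stability
    e0, e1 = math.exp(z0 - m), math.exp(z1 - m)
    return e1 / (e0 + e1)

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

for z0, z1 in [(0.3, -1.2), (2.0, 2.5), (-4.0, 3.0)]:
    p_softmax = softmax2(z0, z1)
    p_sigmoid = sigmoid(z1 - z0)
    print(round(p_softmax, 6), round(p_sigmoid, 6))
    # A pixel is labelled foreground when channel 1 wins the argmax,
    # which happens exactly when this probability exceeds 0.5.
    assert math.isclose(p_softmax, p_sigmoid, rel_tol=1e-12)
```

So a 2-channel head loses nothing compared with a 1-channel sigmoid head; it just fits the default multi-class tooling.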

Edit: alternatively, hack the function that builds the model:

import mxnet as mx
import gluoncv
from gluoncv.model_zoo.model_store import get_model_file

def get_deeplab(nclass, dataset='pascal_voc', backbone='resnet50', pretrained=False,
                root='~/.mxnet/models', ctx=mx.cpu(0), **kwargs):
    r"""DeepLabV3 with a user-defined number of classes.

    Parameters
    ----------
    nclass : int
        Number of output classes (channels of the final 1x1 convolutions).
    dataset : str, default pascal_voc
        The dataset that the model was pretrained on. (pascal_voc, ade20k)
    backbone : str, default resnet50
        Backbone network, e.g. resnet50 or resnet101.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Examples
    --------
    >>> model = get_deeplab(2, dataset='ade20k', backbone='resnet101', pretrained=False)
    >>> print(model)
    """
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
    }
    # Use the nclass argument instead of inferring datasets[dataset].NUM_CLASS
    model = gluoncv.model_zoo.DeepLabV3(nclass, backbone=backbone, ctx=ctx, **kwargs)
    if pretrained:
        # Caveat: the saved weights use the dataset's own class count (e.g. 150
        # for ade20k), so the final 1x1 convolutions will likely fail to load
        # with a shape mismatch whenever nclass differs from it.
        model.load_parameters(get_model_file('deeplab_%s_%s' % (backbone, acronyms[dataset]),
                                             tag=pretrained, root=root), ctx=ctx)
    return model

In [29]: model = get_deeplab(2,'ade20k', 'resnet101')

In [30]: model.auxlayer
Out[30]: 
_FCNHead(
  (block): HybridSequential(
    (0): Conv2D(1024 -> 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm(fix_gamma=False, eps=1e-05, axis=1, momentum=0.9, use_global_stats=False, in_channels=256)
    (2): Activation(relu)
    (3): Dropout(p = 0.1, axes=())
    (4): Conv2D(256 -> 2, kernel_size=(1, 1), stride=(1, 1))
  )
)
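One caveat: get_deeplab(2, 'ade20k', 'resnet101') leaves pretrained at its default of False, so no ADE20K weights are loaded in this session. With pretrained=True, the saved 150-class final convolutions would not match the 2-class ones, and loading would most likely fail on a shape mismatch. A common workaround is to copy every pretrained parameter whose name and shape match and re-initialize only the rest. The sketch below shows that logic on plain dictionaries of NumPy arrays; the parameter names and shrunken shapes are hypothetical, and mapping them onto real Gluon parameters (whose collect_params() keys carry per-instance prefixes) is left out.

```python
import numpy as np

def transfer_matching(pretrained, target):
    """Copy arrays from `pretrained` into `target` where name and shape match.

    Both arguments map parameter names to NumPy arrays. `target` is updated
    in place; the names that could not be copied are returned so they can be
    re-initialized (e.g. the new 1x1 convolution producing 2 channels).
    """
    not_copied = []
    for name, value in list(target.items()):
        src = pretrained.get(name)
        if src is not None and src.shape == value.shape:
            target[name] = src.copy()
        else:
            not_copied.append(name)
    return not_copied

# Hypothetical, shrunken shapes standing in for the auxlayer parameters.
pretrained = {
    'auxlayer.block.0.weight': np.ones((8, 16, 3, 3)),
    'auxlayer.block.4.weight': np.ones((150, 8, 1, 1)),
    'auxlayer.block.4.bias': np.ones(150),
}
target = {
    'auxlayer.block.0.weight': np.zeros((8, 16, 3, 3)),
    'auxlayer.block.4.weight': np.zeros((2, 8, 1, 1)),
    'auxlayer.block.4.bias': np.zeros(2),
}
print(transfer_matching(pretrained, target))
# ['auxlayer.block.4.weight', 'auxlayer.block.4.bias']
```

Only the returned parameters then need a fresh initializer; everything else keeps its pretrained values.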



@feevos Thanks, this worked perfectly. I ran into the same bug as well trying to define nclass in the more general function.

Let’s see if it trains well.

Cheers!


@feevos, just while I have your attention: To initialize that new layer, would you do the following…

model.auxlayer.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)

Hi, I don't know for sure; you'll have to experiment (also look at the source code to see how they use the pretrained weights). I hope someone with more expertise can address this.
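For reference, mx.init.Xavier (defaults rnd_type='uniform', factor_type='avg', magnitude=3) draws weights from U(-s, s) with s = sqrt(magnitude / ((fan_in + fan_out) / 2)), where for a convolution fan_in = in_channels * kernel_h * kernel_w and fan_out = out_channels * kernel_h * kernel_w. The snippet below is a pure-Python re-derivation of that bound for the new head layer, not MXNet code. Note also that calling initialize(..., force_reinit=True) on model.auxlayer re-initializes the whole auxiliary head, including the 3x3 convolution and BatchNorm in front of the new layer; indexing into the block (e.g. model.auxlayer.block[4]) should target only the new 1x1 convolution.

```python
import math

def xavier_bound(shape, magnitude=3.0, factor_type='avg'):
    """Half-width s of the uniform range U(-s, s) used by Xavier initialization.

    `shape` is a weight shape in (out_channels, in_channels, kh, kw) layout,
    mirroring how fan_in / fan_out are computed for convolution weights.
    """
    receptive = 1
    for k in shape[2:]:
        receptive *= k
    fan_in = shape[1] * receptive
    fan_out = shape[0] * receptive
    factor = {'avg': (fan_in + fan_out) / 2.0, 'in': fan_in, 'out': fan_out}[factor_type]
    return math.sqrt(magnitude / factor)

# New 1x1 head convolutions from this thread (weight shape: out, in, kh, kw).
print(xavier_bound((2, 256, 1, 1), magnitude=2.24))  # 2-class auxlayer.block[4]
print(xavier_bound((1, 256, 1, 1), magnitude=2.24))  # the 1-class variant asked about
```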

Yeah, I ran it with and without the initialization and got the following error:

AttributeError: 'Symbol' object has no attribute 'shape'

stemming (via the stack trace) from

     18 
     19             with mx.autograd.record():
---> 20                 output = net(x)
     21                 loss = lossfunc(output, y)
     22             loss.backward()

So performing a forward pass breaks it.

I can do a forward pass, like this:

In [20]: model = get_deeplab(2,'ade20k', 'resnet101')

In [21]: out1, out2 = model(nd.random.uniform(shape=[2,3,256,256]))

In [22]: out1.shape
Out[22]: (2, 2, 480, 480)

In [23]: out2.shape
Out[23]: (2, 2, 480, 480)

Does this help?

It hasn't solved it, but it narrows down the possibilities. Thanks again for the help!


@feevos

I've found the problem now. It happens when I call model.hybridize() and then try to train the model on my data; I get the following error:

AttributeError: 'Symbol' object has no attribute 'shape'

Ironically, this is similar to your discussion on GitHub.

Any idea on how to fix this?


Yeah, but the solution I have in mind ain't pretty. You'll need to hack the code.

The relevant definitions are at line 107 of deeplabv3.py. The problem is that after you call hybridize, the input x inside the hybrid_forward of _AsppPooling is a Symbol rather than an NDArray, so you cannot call x.shape any more (I have suffered from this in the past, a lot!). The _ASPP layer declares _AsppPooling at line 126. So you basically need to know in advance the height and width (h, w) of the feature map that _AsppPooling receives. One way to find them is to modify the definition of _AsppPooling:

from mxnet.gluon import nn

class _AsppPooling(nn.HybridBlock):
    def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs):
        super(_AsppPooling, self).__init__()
        self.gap = nn.HybridSequential()
        with self.gap.name_scope():
            self.gap.add(nn.GlobalAvgPool2D())
            self.gap.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
                                   kernel_size=1, use_bias=False))
            self.gap.add(norm_layer(in_channels=out_channels, **norm_kwargs))
            self.gap.add(nn.Activation("relu"))

    def hybrid_forward(self, F, x):
        # Only works without hybridize(): here x is an NDArray, so x.shape exists
        _, _, h, w = x.shape
        # @@@@@@@@@@@ MODIFICATION HERE @@@@@@@@@@@@
        print(h, w)
        # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
        pool = self.gap(x)
        return F.contrib.BilinearResize2D(pool, height=h, width=w)

Then declare your model as usual and call it once (before hybridizing) on a random input image of the correct size. The dimensions will be printed on screen. Now you can modify this layer again, as follows:

from mxnet.gluon import nn

class _AsppPooling(nn.HybridBlock):
    def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs, h, w):
        super(_AsppPooling, self).__init__()
        # @@@@@@@@@@@ MOD HERE and above, see h,w arguments @@@@@@@@@@@@@
        # Fixed spatial size of the incoming feature map (found with the print above)
        self.h = h
        self.w = w
        # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

        self.gap = nn.HybridSequential()
        with self.gap.name_scope():
            self.gap.add(nn.GlobalAvgPool2D())
            self.gap.add(nn.Conv2D(in_channels=in_channels, channels=out_channels,
                                   kernel_size=1, use_bias=False))
            self.gap.add(norm_layer(in_channels=out_channels, **norm_kwargs))
            self.gap.add(nn.Activation("relu"))

    def hybrid_forward(self, F, x):
        # @@@@@@@ MOD HERE @@@@@@@@@@@@@
        # No call to x.shape any more, so this also works on Symbols after hybridize()
        pool = self.gap(x)
        return F.contrib.BilinearResize2D(pool, height=self.h, width=self.w)

And now you have to modify accordingly all DeepLabV3 definitions where this layer is declared (if you can't manage it, let me know). Basically, you need to create a copy of the file deeplabv3.py with the new definitions and call the function get_deeplab(…) from the new file. Then it will work. It's not tragically difficult, but it's not the kind of thing one would like to be doing. On the plus side, you'll learn more about the models :).
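Instead of printing h and w, they could also be computed from the training crop size, provided the backbone's downsampling is known. The sketch below assumes the dilated ResNet used here reduces resolution by 8 through three stride-2 stages, each with a 3x3 kernel and padding 1 (so each stage rounds up when halving); that is an assumption to check against the printed values before relying on it. Either way, hard-coding h and w means every input fed to the hybridized model must have that same spatial size.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output length of a convolution/pooling along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

def aspp_feature_size(height, width, num_stride2_stages=3):
    """Spatial size of the feature map reaching _AsppPooling (assumed output stride 8)."""
    h, w = height, width
    for _ in range(num_stride2_stages):
        h, w = conv_out(h), conv_out(w)
    return h, w

print(aspp_feature_size(480, 480))  # (60, 60)
print(aspp_feature_size(256, 256))  # (32, 32)
```

The resulting pair would then be passed down through _ASPP as the h, w arguments of the modified _AsppPooling.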

Let me know how it goes (if I have time, I'll post a complete hack too),
Cheers.


@feevos, thanks for taking the time to answer this. I will see what I can do.


To wrap my head around this: in its present state, the definition of _AsppPooling essentially makes this structure non-hybridizable. Is that understanding correct?


Yes, because it reads the shape of its input, and Symbols do not have a shape property.

Basically, whenever a block calls x.shape, it is no longer compatible with Symbol. I think the easiest way to go is to run a single forward pass (not hybridized) through your model and print the h, w values in this layer. Then define them there directly as constants and remove the calls to x.shape. If you're into semantic segmentation, you must also have come across the PSP pooling layer. See this issue for a way I found to make it hybridizable. It also has a way to solve your problem (final answer).
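A further observation, offered as an untested alternative rather than the fix from that issue: the tensor being resized in _AsppPooling comes out of GlobalAvgPool2D, so it is 1x1 spatially, and bilinearly upsampling a 1x1 map only repeats each channel's value. In principle the resize could therefore be replaced by a broadcast against x, which never needs x.shape (in MXNet, F.broadcast_like with lhs_axes/rhs_axes, if your version supports those arguments; check its documentation). The NumPy sketch below checks the equivalence using a small align-corners bilinear resize written only for illustration.

```python
import numpy as np

def bilinear_resize(x, out_h, out_w):
    """Bilinear resize of a (C, H, W) array using the align_corners convention."""
    _, in_h, in_w = x.shape
    # Source coordinate of every output row and column.
    ys = np.linspace(0, in_h - 1, out_h)
    xs = np.linspace(0, in_w - 1, out_w)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None]  # vertical interpolation weights, shape (out_h, 1)
    wx = (xs - x0)[None, :]  # horizontal interpolation weights, shape (1, out_w)
    top = x[:, y0][:, :, x0] * (1 - wx) + x[:, y0][:, :, x1] * wx
    bottom = x[:, y1][:, :, x0] * (1 - wx) + x[:, y1][:, :, x1] * wx
    return top * (1 - wy) + bottom * wy

# A pooled ASPP branch: 4 channels, spatial size 1x1.
pool = np.arange(4, dtype=float).reshape(4, 1, 1)
resized = bilinear_resize(pool, 60, 60)
broadcast = np.broadcast_to(pool, (4, 60, 60))
print(np.allclose(resized, broadcast))  # True
```

Because the broadcast takes its target size from x itself, this variant would not tie the hybridized model to a single input resolution, unlike hard-coded h, w.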


Okay great. Will check out that github issue.

Thanks for the help!
