Gluon Multi GPU Out of Memory Issues

Hi, the problem is not going multi-GPU. Even a single GPU has a large memory footprint with this model. Your network (I removed the bias where unnecessary to reduce the parameter count):

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class VGG19(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(VGG19, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')

            # Block 1
            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,
                                        weight_initializer=mx.init.Xavier(rnd_type='gaussian',
                                                                          factor_type='out',
                                                                          magnitude=2),use_bias=False
                                        ))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 2
            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 3
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 4
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 5
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            self.features.add(nn.Flatten())
            
            # Block 6
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            # Block 7
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            self.output = nn.Dense(2, weight_initializer="normal", bias_initializer="zeros")

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x

has ~600M parameters, the majority of which come from the first Dense layer (its huge flattened input is your problem):

net = VGG19()
ctx=mx.cpu()
net.initialize(ctx=ctx)
#net.hybridize(static_shape=True, static_alloc=True)
net.summary(mx.nd.ones((1, 3, 480, 640)))
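
To see where those parameters come from, here is a quick back-of-the-envelope check (a sketch only; it assumes the default pool_size=2 for each MaxPool2D, i.e. five halvings of the 480x640 input):

h, w = 480, 640
for _ in range(5):           # five MaxPool2D(strides=2) stages in the feature extractor
    h, w = h // 2, w // 2    # 480x640 -> 15x20 after the last pool
flat = 512 * h * w           # 512 channels * 15 * 20 = 153,600 inputs to the first Dense layer
params = flat * 4096         # no bias -> ~629M weights for Dense(4096) alone
print(flat, params)          # 153600 629145600, i.e. ~2.4 GB just for this weight in float32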

Even if you use MobileNet-style convolutions, the problem will remain. What you can try is to further reduce the size of the input to the first dense layer. You can do so either by adding more convolution layers (a deeper network, and deeper is usually better) that summarize the last conv feature map further, or by using more aggressive max pooling after the last conv layer (e.g. pool_size=4, strides=4 or higher). For example, this modified architecture:

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class VGG19(gluon.HybridBlock):
    def __init__(self, **kwargs):
        super(VGG19, self).__init__(**kwargs)
        with self.name_scope():
            self.features = nn.HybridSequential(prefix='')

            # Block 1
            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,
                                        weight_initializer=mx.init.Xavier(rnd_type='gaussian',
                                                                          factor_type='out',
                                                                          magnitude=2),use_bias=False
                                        ))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(64, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 2
            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(128, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 3
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(256, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 4
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # Block 5
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.Conv2D(512, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            # @@@@@@@@@@@@@@@@@@@@@@@@ MOD here @@@@@@@@@@@@@@@@@@@@@@@@@@@@
            self.features.add(nn.Conv2D(1024, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))

            self.features.add(nn.Conv2D(1024, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))
            
            self.features.add(nn.Conv2D(2048, kernel_size=3, padding=1,use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))

            self.features.add(nn.MaxPool2D(strides=2))
            self.features.add(nn.Dropout(rate=0.25))
            # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

            self.features.add(nn.Flatten())
            
            # Block 6
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            # Block 7
            self.features.add(nn.Dense(4096, activation="relu", weight_initializer="normal",use_bias=False))
            self.features.add(nn.BatchNorm())
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.Dropout(rate=0.5))

            self.output = nn.Dense(2, weight_initializer="normal", bias_initializer="zeros")

    def hybrid_forward(self, F, x):
        x = self.features(x)
        x = self.output(x)
        return x

has “only” ~87M params. Play with this idea, and you can reduce the memory footprint even further.
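
As a rough illustration of the max-pooling route (a sketch only, not tested on your data): adding one coarser pooling stage, pool_size=4 / strides=4, on top of Block 5's output (512 x 15 x 20 for a 480x640 input) shrinks the input of the first Dense layer from ~154K to ~7.7K values:

import mxnet as mx
from mxnet.gluon import nn

tail = nn.HybridSequential()
tail.add(nn.MaxPool2D(pool_size=4, strides=4))  # 15x20 -> 3x5
tail.add(nn.Flatten())                          # 512 * 3 * 5 = 7,680 features
x = mx.nd.ones((1, 512, 15, 20))                # shape of Block 5's output for a 480x640 image
print(tail(x).shape)                            # (1, 7680) -> Dense(4096) needs ~31M weights instead of ~629M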

All the best,
Foivos

edit: I strongly recommend, as @ThomasDelteil suggested, moving to a better architecture (ResNet, DenseNet, etc.). Even implementing them yourself from scratch (based on the papers), if you don't want the pre-defined models, is relatively easy.
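
For the pre-defined models, a minimal sketch (the specific model, resnet18_v2, is just an example; classes=2 matches your output layer):

import mxnet as mx
from mxnet.gluon.model_zoo import vision

ctx = mx.cpu()
net = vision.resnet18_v2(classes=2)   # any ResNet/DenseNet variant from the model zoo works the same way
net.initialize(mx.init.Xavier(), ctx=ctx)
net.hybridize()
print(net(mx.nd.ones((1, 3, 480, 640), ctx=ctx)).shape)  # (1, 2)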
