SSD - MultiBoxTarget returns 0 for everything! What does the function do in detail?

I am trying to train an SSD network to do line segmentation on form data. There are three object classes [Text, Text_field, check_box], so 4 classes in total including background.

Size of each image is 1675x1250.

The data format that I pass as bounding-box coordinates to the network is of the form
[class, X_min, Y_min, X_max, Y_max]
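
For concreteness, one image's label array looks roughly like this (illustrative values only — the padding value for unused rows and whether the coordinates should stay in pixels or be normalised to [0, 1] are the parts I am least sure about):

import mxnet.ndarray as nd

# one image's labels, padded to a fixed number of slots
# each row is [class_id, x_min, y_min, x_max, y_max]
label = nd.array([[0, 120, 340, 980, 410],       # Text
                  [1, 150, 500, 900, 560],       # Text_field
                  [2, 1020, 700, 1080, 760],     # check_box
                  [-1, -1, -1, -1, -1]])         # padding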

The network that I have designed is attached below.

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon.model_zoo.vision import resnet34_v1
from mxnet.contrib.ndarray import MultiBoxPrior, MultiBoxTarget

# model_conf (my own config with SSD_ANCHOR_SIZE, SSD_ANCHOR_RATIO, CTX, ...) is imported elsewhere


class SSDNetwork(gluon.Block):

    def __init__(self, num_classes, **kwargs):
        self.anchor_sizes = model_conf.SSD_ANCHOR_SIZE
        self.anchor_ratios = model_conf.SSD_ANCHOR_RATIO

        self.num_anchors = len(self.anchor_sizes)
        self.num_classes = num_classes
        self.ctx = model_conf.CTX

        super(SSDNetwork, self).__init__(**kwargs)
        with self.name_scope():
            self.body, self.downsamples, self.class_pred, self.box_pred = self.get_ssd_network()
            self.downsamples.initialize(mx.init.Normal(), ctx=self.ctx)
            self.class_pred.initialize(mx.init.Normal(), ctx=self.ctx)
            self.box_pred.initialize(mx.init.Normal(), ctx=self.ctx)

    def down_samplers(self, channels):
        output = gluon.nn.HybridSequential()
        for _ in range(2):
            output.add(gluon.nn.Conv2D(channels, 3, strides=1, padding=1))
            output.add(gluon.nn.BatchNorm(in_channels=channels))
            output.add(gluon.nn.Activation('relu'))
        output.add(gluon.nn.MaxPool2D(2))
        output.hybridize()
        return output

    def predicted_boxes(self, num_anchors_predicted):
        pred_box = gluon.nn.HybridSequential()
        with pred_box.name_scope():
            pred_box.add(gluon.nn.Conv2D(
                channels=num_anchors_predicted*4, kernel_size=3, padding=1))
        return pred_box

    def class_prediction(self, num_anchors_predicted):
        return gluon.nn.Conv2D(num_anchors_predicted*(self.num_classes + 1), kernel_size=3, padding=1)

    def get_resnet_34(self):

        pretrained = resnet34_v1(pretrained=True, ctx=self.ctx)
        pretrained_2 = resnet34_v1(pretrained=True, ctx=self.ctx)
        first_weights = pretrained_2.features[0].weight.data().mean(
            axis=1).expand_dims(axis=1)

        body = gluon.nn.HybridSequential()
        with body.name_scope():
            first_layer = gluon.nn.Conv2D(channels=64, kernel_size=(7, 7), padding=(
                3, 3), strides=(2, 2), in_channels=1, use_bias=False)
            first_layer.initialize(mx.init.Normal(), ctx=self.ctx)
            first_layer.weight.set_data(first_weights)
            body.add(first_layer)
            body.add(*pretrained.features[0:-3])
        return body

    def get_ssd_network(self):
        body = self.get_resnet_34()
        downsamples = gluon.nn.HybridSequential()
        class_preds = gluon.nn.HybridSequential()
        box_preds = gluon.nn.HybridSequential()

        downsamples.add(self.down_samplers(128))
        downsamples.add(self.down_samplers(128))
        downsamples.add(self.down_samplers(128))

        for _ in range(self.num_anchors):
            num_anchors_predicted = len(
                self.anchor_sizes[0]) + len(self.anchor_ratios[0]) - 1
            class_preds.add(self.class_prediction(num_anchors_predicted))
            box_preds.add(self.predicted_boxes(num_anchors_predicted))
        return body, downsamples, class_preds, box_preds

    def ssd_forward(self, x):
        x = self.body(x)
        default_anchors = []
        predicted_boxes = []
        predicted_classes = []

        for i, (box_predictor, class_predictor) in enumerate(zip(self.box_pred, self.class_pred)):
            default_anchors.append(MultiBoxPrior(
                x, sizes=self.anchor_sizes[i], ratios=self.anchor_ratios[i]))
            predicted_boxes.append(self._change_channel_rep(box_predictor(x)))
            predicted_classes.append(
                self._change_channel_rep(class_predictor(x)))
            if i < len(self.downsamples):
                x = self.downsamples[i](x)
            elif i == 3:
                x = nd.Pooling(x, global_pool=True,
                               pool_type='max', kernel=(4, 4))

        return default_anchors, predicted_boxes, predicted_classes

    def forward(self, x):
        default_anchors, predicted_classes, predicted_boxes = self.ssd_forward(
            x)
        # we want to concatenate anchors, class predictions, box predictions from different layers
        anchors = nd.concat(*default_anchors, dim=1)
        box_preds = nd.concat(*predicted_boxes, dim=1)
        class_preds = nd.concat(*predicted_classes, dim=1)
        class_preds = nd.reshape(
            class_preds, shape=(0, -1, self.num_classes + 1))
        return anchors, class_preds, box_preds

    def _change_channel_rep(self, x):
        return nd.flatten(nd.transpose(x, axes=(0, 2, 3, 1)))

    def training_targets(self, default_anchors, class_predicts, labels):
        print("Got till the training targets functions")
        class_predicts = nd.transpose(class_predicts, axes=(0, 2, 1))
        box_target, box_mask, cls_target = MultiBoxTarget(
            default_anchors, labels, class_predicts)
        return box_target, box_mask, cls_target

The anchor sizes and anchor ratios are shown below; I have 7 sets of them, covering the whole page.

SSD_ANCHOR_SIZE = [[.1, .2], [.2, .3], [.2, .4],
                   [.4, .6], [.5, .7], [.6, .8], [.7, .9]]
SSD_ANCHOR_RATIO = [[1, 3, 5], [1, 3, 5], [1, 6, 8],
                    [1, 5, 7], [1, 6, 8], [1, 7, 9], [1, 7, 10]]
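
To sanity-check one of these anchor sets, I probe MultiBoxPrior on a dummy feature map like this (the feature-map size here is made up; only the anchor count and coordinate format matter):

import mxnet.ndarray as nd
from mxnet.contrib.ndarray import MultiBoxPrior

# dummy feature map, just to see what one (sizes, ratios) pair produces
feat = nd.zeros((1, 128, 32, 32))
anchors = MultiBoxPrior(feat, sizes=[.1, .2], ratios=[1, 3, 5])
print(anchors.shape)  # (1, 32*32*4, 4): len(sizes) + len(ratios) - 1 = 4 anchors per position
# each anchor is [x_min, y_min, x_max, y_max] with coordinates normalised to [0, 1]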

So the entire network is basically ResNet34 -> [Downsampler (128), class_prob, box_pred] -> [Downsampler (128), class_prob, box_pred] -> [Downsampler (128), class_prob, box_pred].

As I am working with scanned documents, I have changed the ResNet's first layer to accept black-and-white input (1 channel instead of 3).
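
To double-check this chain of scales, I sometimes run a quick shape probe outside of training, roughly like this (it assumes `net` is an initialised SSDNetwork and X is a greyscale batch on the same context; it ignores the global-pooling step at i == 3):

# net: an initialised SSDNetwork, X: a (1, 1, H, W) greyscale batch
feat = net.body(X)
for i in range(net.num_anchors):
    print(i, feat.shape)          # spatial size of the i-th prediction scale
    if i < len(net.downsamples):
        feat = net.downsamples[i](feat)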

I have created my own FormDataGenerator which parses the data and puts it in a DataLoader. I can post the code if needed. (It returns batches of 5 images, with the bounding boxes padded to 117 entries per image.)
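
So a batch coming out of the DataLoader looks like this in terms of shapes (shapes only; the generator code itself is omitted):

for X, Y in dataloader:
    print(X.shape)  # (5, 1, H, W)  - batch of 5 greyscale pages
    print(Y.shape)  # (5, 117, 5)   - 117 padded label slots, each [class, x_min, y_min, x_max, y_max]
    break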

Here is an example from the dataset, for a better understanding of what we are dealing with:

[screenshot of an example labelled form]

Here is my training code (the run_epoch function),

import os
import datetime

from mxnet import nd, gluon, autograd
from mxboard import SummaryWriter

# cls_loss, box_loss, cls_metric, box_metric and model_conf are defined elsewhere


def run_epoch(e, dataloader, network, data_type, trainer, update_network, update_metric, save_cnn):

    total_loss = []
    for i, (X, Y) in enumerate(dataloader):

        if not isinstance(model_conf.CTX, list):
            X = X.as_in_context(model_conf.CTX)
            Y = Y.as_in_context(model_conf.CTX)
        else:
            total_losses = [nd.zeros(1, ctx_i) for ctx_i in model_conf.CTX]
            X = gluon.utils.split_and_load(X, model_conf.CTX)
            Y = gluon.utils.split_and_load(Y, model_conf.CTX) 

        with autograd.record():
            default_anchors, class_predictions, box_predictions = network(X)
            box_target, box_mask, cls_target = network.training_targets(default_anchors, class_predictions, Y)

            loss_class = cls_loss(class_predictions, cls_target)
            loss_box = box_loss(box_predictions, box_target, box_mask)
            # sum all losses
            loss = loss_class + loss_box

        if update_network:
            loss.backward()
            if isinstance(model_conf.CTX, list):
                # step with the total number of samples across all contexts
                step_size = sum(x.shape[0] for x in X)
                trainer.step(step_size)
            else:    
                step_size = X.shape[0]  # batch size
                trainer.step(step_size)

        if isinstance(model_conf.CTX, list):
            for index, l in enumerate(loss):
                total_losses[index] += l.mean() / len(model_conf.CTX)
        else:
            mean_loss = loss.mean().asnumpy()[0]
            total_loss.append(mean_loss)

        if update_metric:
            cls_metric.update([cls_target], [nd.transpose(class_predictions, (0, 2, 1))])
            box_metric.update([box_target], [box_predictions * box_mask])

        # if i == 0 and e % model_conf.IMAGE_TEST == 0 and e > 0:
        #     cls_probs = nd.SoftmaxActivation(nd.transpose(class_predictions, (0, 2, 1)), mode='channel')
        #     output_image, number_of_bbs = generate_output_image(box_predictions, default_anchors,
        #                                                         cls_probs, box_target, box_mask,
        #                                                         cls_target, x, y)
        #     print("Number of predicted {} BBs = {}".format(data_type, number_of_bbs))
        #     with SummaryWriter(logdir=log_dir, verbose=False, flush_secs=5) as sw:
        #         sw.add_image('bb_{}_image'.format(data_type), output_image, global_step=e)

    if isinstance(model_conf.CTX, list):
        total_loss = 0
        for loss in total_losses:
            total_loss += loss.asscalar()
        epoch_loss = float(total_loss) / len(dataloader)
    else:
        epoch_loss = float(sum(total_loss) / len(total_loss))

    with SummaryWriter(logdir=model_conf.LOG_DIR, verbose=False, flush_secs=5) as sw:
        if update_metric:
            name1, val1 = cls_metric.get()
            name2, val2 = box_metric.get()
            sw.add_scalar(name1, {"test": val1}, global_step=e)
            sw.add_scalar(name2, {"test": val2}, global_step=e)
        sw.add_scalar('loss', {data_type: epoch_loss}, global_step=e)

    if save_cnn and e % model_conf.CHECKPOINT_EPOCH == 0 and e > 0:
        file_name = model_conf.CHECKPOINT_NAME.split(".")
        date_today = datetime.datetime.today().strftime('%Y-%m-%d')
        file_name = file_name[0]+"_"+str(e)+"."+file_name[1]
        file_path = os.path.join(model_conf.CHECKPOINT_DIR, date_today)
        if not os.path.exists(file_path):
            os.makedirs(file_path)
        network.save_parameters(os.path.join(file_path, file_name))

    return epoch_loss

And the problem that I get is here:

with autograd.record():
    default_anchors, class_predictions, box_predictions = network(X)
    box_target, box_mask, cls_target = network.training_targets(default_anchors, class_predictions, Y)

Here box_target, box_mask and cls_target all come back as 0, while the network itself does return default_anchors, class_predictions and box_predictions.

Am I doing something wrong with the data? Am I passing the data in correctly? I only have a small dataset so far, with 30-odd labelled documents, but that should not make the training targets all zero. Is my label format of [class, X_min, Y_min, X_max, Y_max] correct?
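
For reference, this is the kind of minimal experiment I have been running to try to understand what MultiBoxTarget expects (all numbers are made up; I am assuming anchors and labels both use corner coordinates normalised to [0, 1], since that is what MultiBoxPrior produces):

import mxnet.ndarray as nd
from mxnet.contrib.ndarray import MultiBoxTarget

# one anchor, one ground-truth box of class 0, 3 classes + background
anchors   = nd.array([[[0.1, 0.1, 0.4, 0.4]]])     # (1, num_anchors, 4)
labels    = nd.array([[[0, 0.1, 0.1, 0.4, 0.4]]])  # (batch, num_labels, 5)
cls_preds = nd.zeros((1, 4, 1))                    # (batch, num_classes + 1, num_anchors)

box_target, box_mask, cls_target = MultiBoxTarget(anchors, labels, cls_preds)
print(box_target)  # (batch, num_anchors * 4) regression targets
print(box_mask)    # (batch, num_anchors * 4) 1s for matched anchors, 0s otherwise
print(cls_target)  # (batch, num_anchors); as far as I understand, class id + 1, with 0 reserved for background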

And lastly, what exactly does MultiBoxTarget do? I have been searching for resources on it a lot, but I have not found one that gives a good answer.

Thanks