How to weight terms in softmax cross entropy loss based on value of class label

I am trying to do image classification with an unbalanced data set, and I want to rescale each term of the cross entropy loss to correct for this imbalance. For example, if I have 2 classes with 100 images in class 0 and 200 images in class 1, then I would want to weight the loss terms for examples from class 0 by a factor of 2/3 and those for class 1 by a factor of 1/3 (that is, weights inversely proportional to class size, normalized to sum to 1). In other words, given the softmax output and the one-hot label for a single example, denoted (softmax_output, label), I want to compute the weighted cross entropy loss as:

f(softmax_output, label) = -label[0]*log(softmax_output[0])*(2/3) - label[1]*log(softmax_output[1])*(1/3)
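As a framework-free reference for checking the numbers, here is a minimal pure-Python sketch that derives the weights from the class counts and evaluates f for one example with a one-hot label. The helper names `inverse_frequency_weights` and `weighted_cross_entropy` are hypothetical, written only for illustration.

```python
from fractions import Fraction
import math

def inverse_frequency_weights(counts):
    """Weight each class by 1/count, normalized so the weights sum to 1."""
    inv = [Fraction(1, c) for c in counts]
    total = sum(inv)
    return [w / total for w in inv]

def weighted_cross_entropy(softmax_output, label, weights):
    """f(softmax_output, label) for a one-hot label.

    Terms with label 0 contribute nothing, so they are skipped; this also
    avoids calling math.log(0.0) for a class with zero probability.
    """
    return -sum(l * math.log(p) * float(w)
                for l, p, w in zip(label, softmax_output, weights) if l)

weights = inverse_frequency_weights([100, 200])
print(weights)  # [Fraction(2, 3), Fraction(1, 3)]

# A class-0 example to which the model assigns probability 0.25
example_loss = weighted_cross_entropy([0.25, 0.75], [1, 0], weights)
```

With one-hot labels only the true-class term survives, so f reduces to the weight of the true class times the usual negative log-probability of that class.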

For the sake of definiteness, suppose I want to fine-tune a model pretrained on imagenet1k. My idea so far is to strip off the network's final classification layer, add a new fully connected layer with one output per class followed by a softmax activation, and then wrap the weighted loss in MakeLoss. Unfortunately, I have some holes in my understanding of the API, and it is difficult to find parts of the documentation that address this use case. Any help would be appreciated.
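One design point with this plan: taking the log of a separate softmax output can hit log(0) when one logit dominates. A minimal pure-Python sketch of computing the weighted term directly from the logits with the usual max-subtraction trick is below; the helper names are hypothetical, and whether to do the equivalent in MXNet (for example with a log_softmax operator, if your version provides one) rather than softmax followed by log is a choice to verify for your setup.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax: shift by the max before exponentiating."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def weighted_ce_from_logits(logits, label_index, class_weights):
    """-w[y] * log(softmax(logits)[y]) for a single example with class index y."""
    return -class_weights[label_index] * log_softmax(logits)[label_index]

# A naive math.exp(1000.0) would raise OverflowError; the shifted version is fine
loss_big = weighted_ce_from_logits([1000.0, 0.0], 1, [2 / 3, 1 / 3])
loss_small = weighted_ce_from_logits([0.0, 0.0], 0, [2 / 3, 1 / 3])
```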

Incorrect Code Below to Illustrate Intended Approach

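import mxnet as mx
from common import data, fit, modelzoo  # helper modules from MXNet's example/image-classification

# Download the pretrained ResNeXt-101 checkpoint and load it
(prefix, epoch) = modelzoo.download_model('imagenet1k-resnext-101-64x4d', '/path/to/model/location')
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
# Cut the network at the flatten layer, just before the original 1000-way classifier
all_layers = sym.get_internals()
net = all_layers['flatten0_output']

# New 2-way classifier head followed by a plain softmax activation
num_classes = 2
net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc')
net = mx.symbol.softmax(data=net, name='softmax_activation')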

From here I’m not exactly sure how to proceed, but hopefully this can be corrected.

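label = mx.symbol.Variable("softmax_label")  # default label name used by Module and the record iterator
# Indexing a Symbol (label[0], net[1]) selects one of its outputs, not a column, so this line is likely wrong
ce = -label[0]*(2.0/3)*mx.sym.log(net[0]) - label[1]*(1.0/3)*mx.sym.log(net[1])
loss = mx.sym.MakeLoss(ce, normalization='batch', name='weighted_cross_entropy')
new_args = {k: arg_params[k] for k in arg_params if 'fc' not in k}  # drop the old classifier weights
fit.fit(args=args, network=loss, data_loader=data.get_rec_iter, arg_params=new_args, aux_params=aux_params)  # args: parsed command-line options (not shown)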

In particular, I don’t think I’m using the label Variable correctly, among other things. Any help would be much appreciated.
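On the label question: the label an MXNet data iterator feeds to the network is typically a vector of class indices (one per example, stored as floats), not a one-hot matrix, so the per-class terms cannot be pulled out with label[0] and label[1]. A minimal pure-Python sketch of what the batched loss needs to compute, assuming index-valued labels and the weights (2/3, 1/3) from the example; `pick` and `weighted_ce_batch` are hypothetical helpers written for illustration, not MXNet functions.

```python
import math

def pick(rows, indices):
    """Select rows[i][indices[i]], in the spirit of mx.sym.pick(data, index, axis=1)."""
    return [row[int(i)] for row, i in zip(rows, indices)]

def weighted_ce_batch(probs, labels, class_weights):
    """Return (batch mean, per-example list) of -w[label] * log(p[label])."""
    p_true = pick(probs, labels)                      # probability of the true class
    w = [class_weights[int(y)] for y in labels]       # weight looked up from the label
    per_example = [-wi * math.log(pi) for wi, pi in zip(w, p_true)]
    return sum(per_example) / len(per_example), per_example

probs = [[0.25, 0.75], [0.9, 0.1], [0.5, 0.5]]        # softmax outputs, one row per example
labels = [0.0, 0.0, 1.0]                              # class indices stored as floats
print(pick(probs, labels))  # [0.25, 0.9, 0.5]
mean_loss, per_example = weighted_ce_batch(probs, labels, [2 / 3, 1 / 3])
```

In the symbol API these steps plausibly map onto mx.sym.pick(net, label, axis=1) for the true-class probability, mx.sym.log, scalar arithmetic on the label to form the weight (for two classes, (2.0/3)*(1 - label) + (1.0/3)*label), and MakeLoss(..., normalization='batch') to scale the gradient by the batch size. Treat that mapping as a suggestion to verify rather than tested code.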

Does this custom operator work for you?

The link is broken. Could you please provide a working one?

Files have been moved around in master branch. Here is the link in 1.0.0 version:
