Implementing lovasz_loss for keras-mxnet

Hi,

I’ve been trying to port an implementation of the lovasz_loss, but I’ve run into a few issues.

Here is the original:

# code downloaded from: https://github.com/bermanmaxim/LovaszSoftmax
def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    gts = tf.reduce_sum(gt_sorted)
    intersection = gts - tf.cumsum(gt_sorted)
    union = gts + tf.cumsum(1. - gt_sorted)
    jaccard = 1. - intersection / union
    jaccard = tf.concat((jaccard[0:1], jaccard[1:] - jaccard[:-1]), 0)
    return jaccard
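
For reference, here is a quick numpy check of what lovasz_grad computes (illustrative only; it just mirrors the TF code above and is handy for sanity-checking a port):

import numpy as np

def lovasz_grad_np(gt_sorted):
    # numpy mirror of the TF lovasz_grad above
    gts = gt_sorted.sum()
    intersection = gts - np.cumsum(gt_sorted)
    union = gts + np.cumsum(1. - gt_sorted)
    jaccard = 1. - intersection / union
    jaccard[1:] = jaccard[1:] - jaccard[:-1]
    return jaccard

print(lovasz_grad_np(np.array([1., 1., 0., 1.])))
# -> [0.3333 0.3333 0.0833 0.25]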

# --------------------------- BINARY LOSSES ---------------------------

def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        def treat_image(log_lab):
            log, lab = log_lab
            log, lab = tf.expand_dims(log, 0), tf.expand_dims(lab, 0)
            log, lab = flatten_binary_scores(log, lab, ignore)
            return lovasz_hinge_flat(log, lab)
        losses = tf.map_fn(treat_image, (logits, labels), dtype=tf.float32)
        loss = tf.reduce_mean(losses)
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return loss

def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
    logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    ignore: label to ignore
    """

    def compute_loss():
        labelsf = tf.cast(labels, logits.dtype)
        signs = 2. * labelsf - 1.
        errors = 1. - logits * tf.stop_gradient(signs)
        errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], name="descending_sort")
        gt_sorted = tf.gather(labelsf, perm)
        grad = lovasz_grad(gt_sorted)
        loss = tf.tensordot(tf.nn.relu(errors_sorted), tf.stop_gradient(grad), 1, name="loss_non_void")
        return loss

    # deal with the void prediction case (only void pixels)
    loss = tf.cond(tf.equal(tf.shape(logits)[0], 0),
                   lambda: tf.reduce_sum(logits) * 0.,
                   compute_loss,
                   strict=True,
                   name="loss"
                   )
    return loss

def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = tf.reshape(scores, (-1,))
    labels = tf.reshape(labels, (-1,))
    if ignore is None:
        return scores, labels
    valid = tf.not_equal(labels, ignore)
    vscores = tf.boolean_mask(scores, valid, name='valid_scores')
    vlabels = tf.boolean_mask(labels, valid, name='valid_labels')
    return vscores, vlabels

def lovasz_loss(y_true, y_pred):
    y_true, y_pred = K.cast(K.squeeze(y_true, -1), 'int32'), K.cast(K.squeeze(y_pred, -1), 'float32')
    #logits = K.log(y_pred / (1. - y_pred))
    logits = y_pred #Jiaxin
    loss = lovasz_hinge(logits, y_true, per_image=True, ignore=None)
    return loss
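
For context, this is how I'm wiring the wrapper into a model (a minimal sketch; `model` stands in for a binary segmentation network with output shape (B, H, W, 1)):

# If the final layer is a sigmoid, true logits can be recovered with the
# commented-out line in lovasz_loss instead of passing y_pred straight through:
# logits = K.log(y_pred / (1. - y_pred))
model.compile(optimizer='adam', loss=lovasz_loss)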


And here is what I’ve added to keras/backend/mxnet_backend.py:

def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    gts = mx.sym.sum(gt_sorted)  # tf.reduce_sum(gt_sorted)
    # np.cumsum can't actually run on a Symbol -- see the cumsum issue below
    intersection = gts - np.cumsum(gt_sorted)  # tf.cumsum(gt_sorted)
    union = gts + np.cumsum(1. - gt_sorted)  # tf.cumsum(1. - gt_sorted)
    jaccard = 1. - intersection / union
    # Symbol indexing selects outputs rather than slicing tensors, so use
    # slice_axis; note mx.sym.concat takes `dim`, not `axis`
    # tf.concat((jaccard[0:1], jaccard[1:] - jaccard[:-1]), 0)
    jaccard = mx.sym.concat(
        mx.sym.slice_axis(jaccard, axis=0, begin=0, end=1),
        mx.sym.slice_axis(jaccard, axis=0, begin=1, end=None)
        - mx.sym.slice_axis(jaccard, axis=0, begin=0, end=-1),
        dim=0)
    return jaccard

def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
    logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
    labels: [P] Tensor, binary ground truth labels (0 or 1)
    ignore: label to ignore
    """

    def compute_loss():
        labelsf = mx.sym.cast(labels, logits.dtype)  # tf.cast(labels, logits.dtype)
        signs = 2. * labelsf - 1.
        errors = 1. - logits * mx.sym.BlockGrad(signs)  # tf.stop_gradient(signs)
        # tf.nn.top_k(errors, k=tf.shape(errors)[0], name="descending_sort")
        # sorts the whole vector, so sort/argsort avoid needing the dynamic size k:
        errors_sorted = mx.sym.sort(errors, is_ascend=False, name="descending_sort")
        perm = mx.sym.argsort(errors, is_ascend=False)
        gt_sorted = mx.sym.take(labelsf, perm)  # tf.gather(labelsf, perm)
        grad = lovasz_grad(gt_sorted)
        # tf.tensordot(tf.nn.relu(errors_sorted), tf.stop_gradient(grad), 1)
        loss = mx.sym.dot(mx.sym.relu(errors_sorted), mx.sym.BlockGrad(grad))
        return loss

    # deal with the void prediction case (only void pixels)
    # original TF version:
    # loss = tf.cond(tf.equal(tf.shape(logits)[0], 0),
    #                lambda: tf.reduce_sum(logits) * 0.,
    #                compute_loss,
    #                strict=True,
    #                name="loss")
    # tentative port via the control-flow ops in MXNet >= 1.3:
    is_void = mx.sym.broadcast_equal(
        mx.sym.slice_axis(mx.sym.shape_array(logits), axis=0, begin=0, end=1),
        mx.sym.zeros((1,), dtype='int64'))
    loss = mx.sym.contrib.cond(is_void,
                               lambda: mx.sym.sum(logits) * 0.,
                               compute_loss,
                               name="loss")
    return loss

def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    Remove labels equal to 'ignore'
    """
    scores = mx.sym.reshape(scores, (-1,))  # tf.reshape(scores, (-1,))
    labels = mx.sym.reshape(labels, (-1,))  # tf.reshape(labels, (-1,))
    if ignore is None:
        return scores, labels
    valid = mx.sym.broadcast_not_equal(labels, ignore)  # tf.not_equal(labels, ignore)
    # placeholders: applying `valid` needs a boolean_mask equivalent, which
    # MXNet doesn't expose yet -- see the boolean_mask issue below
    vscores = scores
    vlabels = labels
    return vscores, vlabels

@keras_mxnet_symbol
def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
    logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
    labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
    per_image: compute the loss per image instead of per batch
    ignore: void class id
    """
    if per_image:
        # mx.sym.contrib.foreach expects body(data, states) -> (output, states)
        def _step(log_lab, states):
            log, lab = log_lab
            log, lab = mx.sym.expand_dims(log, 0), mx.sym.expand_dims(lab, 0)  # tf.expand_dims
            log, lab = flatten_binary_scores(log, lab, ignore)
            return lovasz_hinge_flat(log, lab), states
        # tf.map_fn(treat_image, (logits, labels), dtype=tf.float32)
        losses, _ = mx.sym.contrib.foreach(_step, [logits, labels], [])
        loss = mx.sym.mean(losses)  # tf.reduce_mean(losses)
    else:
        loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
    return KerasSymbol(loss)


Here are the issues I’ve encountered:

• cumsum doesn’t seem to be implemented yet, so I’m using np.cumsum for now; that can’t actually run on a symbol, so I may end up implementing the operator myself (see the workaround sketch after this list).
• boolean_mask will hopefully be added soon (https://github.com/apache/incubator-mxnet/pull/12400/commits/0081ecb2f91875e5ae029b99027df96056016311), but it seems a Python binding still needs to be added (a shape-preserving workaround is sketched below).
• in lovasz_hinge_flat, tf.nn.top_k needs the dimension size; I’ve sidestepped it with sort/argsort above, but in general I’m not sure I can rely on infer_shape since I don’t always know the input size (see the infer_shape example below).
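
For the cumsum workaround, here is the idea I have in mind (a sketch, shown with NDArray for clarity): the cumulative sum of a 1-D vector equals a dot product with a lower-triangular ones matrix, which only needs ops that already exist, but requires the length to be known up front:

import numpy as np
import mxnet as mx

def cumsum_nd(x):
    # cumsum(x) == tril(ones) @ x for a 1-D array of known length n.
    # A symbolic version would have to feed the triangular matrix in
    # as an extra input, since mx.sym has no cumsum operator yet.
    n = x.shape[0]
    tri = mx.nd.array(np.tril(np.ones((n, n))))
    return mx.nd.dot(tri, x)

print(cumsum_nd(mx.nd.array([1., 2., 3., 4.])))  # [ 1.  3.  6. 10.]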
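
For boolean_mask, until that PR lands, one shape-preserving workaround is to multiply by the validity mask instead of dropping entries (again a sketch with NDArray; note it changes the semantics slightly, since ignored positions become zeros rather than disappearing, so the loss would have to account for that):

import mxnet as mx

def flatten_binary_scores_masked(scores, labels, ignore=None):
    # zero out ignored entries instead of removing them; downstream
    # code must then treat the zeroed entries accordingly
    scores = mx.nd.reshape(scores, (-1,))
    labels = mx.nd.reshape(labels, (-1,))
    if ignore is None:
        return scores, labels
    valid = labels != ignore
    return scores * valid, labels * valid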
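
And on the dimension-size point, infer_shape does work whenever a concrete input shape can be supplied at graph-construction time, so it covers the static-shape case; truly dynamic shapes are the part that is still missing:

import mxnet as mx

data = mx.sym.Variable('data')
errors = 1. - data
# works when the caller can pin down the input shape:
_, out_shapes, _ = errors.infer_shape(data=(4,))
k = out_shapes[0][0]  # 4 -- usable as k for mx.sym.topk in the static case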

If anyone has suggestions on these or spots obvious errors in my code, please let me know!

As discussed offline, the current workaround for cumsum is to use numpy. Dynamic shapes, however, need support in the MXNet symbol interface, which may come at a later time. We will add that support once it’s available; it should be a major improvement for keras-mxnet, especially for RNN use cases.

Again, thanks for trying out keras-mxnet!

Thanks again for looking into it!