Custom loss + custom metric in R

Hi all and thanks in advance!

I would like to use a custom loss and a custom metric at the same time. I know that the custom loss output is the gradient of the loss function, so I developed the following code example:

require(mxnet)

X = cbind(1, matrix(rnorm(1000)))
Y = X %*% c(-3, 3)

data = mx.symbol.Variable('data')
label = mx.symbol.Variable('label')
mdl = mx.symbol.FullyConnected(data, num_hidden = 1)
# Prediction head: BlockGrad exposes the FC output without passing gradients
mdlOut = mx.symbol.BlockGrad(mdl, name = 'mdlOut')
# Loss head: MakeLoss is what gets minimised
mdlLoss = mx.symbol.MakeLoss(mx.symbol.square(label - mx.symbol.reshape(mdl, shape = 0)), name = 'mdlLoss')

mdl_symb = mx.symbol.Group(mdlOut, mdlLoss)

metric = mx.metric.custom(feval = function(label, mdlOut_output) mean((as.array(label) - as.array(mdlOut_output))^2), name = 'ssq')

mdl_trained = mx.model.FeedForward.create(symbol = mdl_symb, X = X, y = Y, eval.metric = metric)

It seems to work, but I do not know whether the metric takes the output from mdlOut or from mdlLoss. I tried to specify "output.names" in mx.model.FeedForward.create but I always got errors.
Can you give some insight about the correctness of this code?
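For reference, the order of the group's outputs (which I suspect determines what the metric sees) can be inspected directly; with the names above, I believe it should list something like this:

outputs(mdl_symb)
# expected, given the names above: "mdlOut_output" "mdlLoss_output"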

Nobody on this? Some insights would be really appreciated.

TL;DR: From what I understand, you cannot use eval.metric with a custom loss function.

There is a tutorial, which you have probably seen, that explains how to create a custom loss: https://mxnet.incubator.apache.org/versions/master/tutorials/r/CustomLossFunction.html It never uses even a predefined metric like MSE or MAE. It explains that if you use a custom loss, then the output of the model is the gradient of the loss with respect to the input data. So, to do a real prediction, they have to extract the last FC2 layer manually.
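To make that concrete, here is a minimal sketch of the extraction pattern (assuming a trained model whose network contains a layer named "fc2", as in the tutorial):

internals = internals(model$symbol)
fc_symbol = internals[[match("fc2_output", outputs(internals))]]

# Wrap the fc2 output in a new model so predict() returns predictions,
# not the gradient of the loss
model_pred <- list(symbol = fc_symbol,
                   arg.params = model$arg.params,
                   aux.params = model$aux.params)
class(model_pred) <- "MXFeedForwardModel"
pred <- predict(model_pred, train.x)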

I tried to play with their code to understand whether a custom or predefined metric can work with a custom loss. To do so, I used the last part of the tutorial, where they use either the predefined loss MAERegressionOutput or a custom loss with the same logic as the MAE loss: lro_abs <- mx.symbol.MakeLoss(mx.symbol.abs(mx.symbol.Reshape(fc2, shape = 0) - label)). Then I tried to use the predefined MAE eval metric and compare the results.

Test 1: Using predefined loss and metric

My expectation is that if I log eval.metric each epoch and then manually calculate the MAE on the training data, I should get similar results. Here is the code I use:

data(BostonHousing, package = "mlbench")
BostonHousing[, sapply(BostonHousing, is.factor)] <-
  as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)]))
BostonHousing <- data.frame(scale(BostonHousing))

test.ind = seq(1, 506, 5)    # 1 pt in 5 used for testing
train.x = data.matrix(BostonHousing[-test.ind,-14])
train.y = BostonHousing[-test.ind, 14]
test.x = data.matrix(BostonHousing[test.ind,-14])
test.y = BostonHousing[test.ind, 14]

require(mxnet)

data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2")
lro_mae <- mx.symbol.MAERegressionOutput(fc2, name = "lro")
mx.set.seed(0)

# Defined here but not actually used below; eval.metric is mx.metric.mae
metric = mx.metric.custom(feval = function(label, out) sum((as.array(label) - as.array(out))^2) / length(as.array(label)), name = 'ssq')

model2 <- mx.model.FeedForward.create(lro_mae, X = train.x, y = train.y,
                                      ctx = mx.cpu(),
                                      num.round = 5,
                                      array.batch.size = 80,
                                      optimizer = "rmsprop",
                                      verbose = TRUE,
                                      array.layout = "rowmajor",
                                      eval.metric = mx.metric.mae,
                                      batch.end.callback = NULL,
                                      epoch.end.callback = NULL)

internals = internals(model2$symbol)
fc_symbol = internals[[match("fc2_output", outputs(internals))]]

model3 <- list(symbol = fc_symbol,
               arg.params = model2$arg.params,
               aux.params = model2$aux.params)

class(model3) <- "MXFeedForwardModel"

pred2 <- predict(model2, train.x)
pred3 <- predict(model3, train.x)

sum(abs(train.y - pred2[1,])) / length(train.y)
sum(abs(train.y - pred3[1,])) / length(train.y)

If I run this code, I get the following output:

Start training with 1 devices
[1] Train-mae=0.712698568900426
[2] Train-mae=0.600305815537771
[3] Train-mae=0.450728197892507
[4] Train-mae=0.40242209037145
[5] Train-mae=0.395647222797076

...

> sum(abs(train.y - pred2[1,])) / length(train.y)
[1] 0.3761493
> sum(abs(train.y - pred3[1,])) / length(train.y)
[1] 0.3761493

As you can see, the last two numbers are the same, meaning that the output of the network and the output of FC2 are the same. They are also quite well aligned with the eval.metric output, though not exactly the same, presumably because the training metric is averaged over mini-batches while the weights are still being updated, whereas the manual calculation uses the final weights.

Test 2: Using manually created loss

My expectation is that if I change the loss to a custom one that behaves the same, I should still get a similar result. Here is my code and the result (the only difference is using lro_abs):

data(BostonHousing, package = "mlbench")
BostonHousing[, sapply(BostonHousing, is.factor)] <-
  as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)]))
BostonHousing <- data.frame(scale(BostonHousing))

test.ind = seq(1, 506, 5)    # 1 pt in 5 used for testing
train.x = data.matrix(BostonHousing[-test.ind,-14])
train.y = BostonHousing[-test.ind, 14]
test.x = data.matrix(BostonHousing[test.ind,-14])
test.y = BostonHousing[test.ind, 14]

require(mxnet)

data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2")
lro_abs <- mx.symbol.MakeLoss(mx.symbol.abs(mx.symbol.Reshape(fc2, shape = 0) - label))
mx.set.seed(0)

# Defined here but not actually used below; eval.metric is mx.metric.mae
metric = mx.metric.custom(feval = function(label, out) sum((as.array(label) - as.array(out))^2) / length(as.array(label)), name = 'ssq')

model2 <- mx.model.FeedForward.create(lro_abs, X = train.x, y = train.y,
                                      ctx = mx.cpu(),
                                      num.round = 5,
                                      array.batch.size = 80,
                                      optimizer = "rmsprop",
                                      verbose = TRUE,
                                      array.layout = "rowmajor",
                                      eval.metric = mx.metric.mae,
                                      batch.end.callback = NULL,
                                      epoch.end.callback = NULL)

internals = internals(model2$symbol)
fc_symbol = internals[[match("fc2_output", outputs(internals))]]

model3 <- list(symbol = fc_symbol,
               arg.params = model2$arg.params,
               aux.params = model2$aux.params)

class(model3) <- "MXFeedForwardModel"

pred2 <- predict(model2, train.x)
pred3 <- predict(model3, train.x)

sum(abs(train.y - pred2[1,])) / length(train.y)
sum(abs(train.y - pred3[1,])) / length(train.y)
Start training with 1 devices
[1] Train-mae=0.696901251872381
[2] Train-mae=0.669727434714635
[3] Train-mae=0.780241707960765
[4] Train-mae=0.781373461087545
[5] Train-mae=0.788071354230245

...

> sum(abs(train.y - pred2[1,])) / length(train.y)
Error in pred2[1, ] : incorrect number of dimensions
> sum(abs(train.y - pred3[1,])) / length(train.y)
[1] 0.3761493

As you can see, the results for eval.metric are different. If you try to calculate the metric manually from the model output, it actually fails with "incorrect number of dimensions", presumably because with a custom loss predict() returns the output of the loss head rather than an N x 1 prediction matrix. The approach of extracting FC2 manually and calculating the metric still works and gives the same result as in the previous test.

From that I had to conclude that eval.metric (custom or not) does not work the same with a custom loss as with a predefined loss. I am not sure whether there is a way to make a custom metric work directly, but if you can get to the FC2 output, then it should be possible; see the sketch below.
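For instance, once you have the model3 wrapper from the script above, any metric can be computed by hand on its predictions:

# Compute metrics manually from the extracted fc2 predictions
pred <- predict(model3, train.x)
mae <- sum(abs(train.y - pred[1, ])) / length(train.y)   # mean absolute error
mse <- sum((train.y - pred[1, ])^2) / length(train.y)    # mean squared error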

Hey @scotty3005,
I implemented a custom loss as well as a custom metric for the CrazyAra project.
The code is written in Python and Gluon but is hopefully adaptable to R using MXNet's symbol API.
The main difference is that I use mx.metric.create() instead of mx.metric.custom().

Here’s the code snippet for creating the custom loss: https://github.com/QueensGambit/CrazyAra/blob/master/DeepCrazyhouse/src/training/trainer_agent.py#L207

value_loss = self._l2_loss(value_out, value_label)
policy_loss = self._softmax_cross_entropy(policy_out, policy_label)
# weight the components of the combined loss
combined_loss = (
    self._val_loss_factor * value_loss.sum() + self._policy_loss_factor * policy_loss.sum()
)
combined_loss.backward()

I created a custom metric called acc_sign, which measures how often the sign of a value in the range [-1, +1] is predicted correctly.

import numpy as np

def acc_sign(y_true, y_pred):
    """
    Custom metric which is used to predict the winner of a game
    :param y_true: Ground truth values (np-array with values between -1 and +1)
    :param y_pred: Predicted labels as numpy array
    :return: Fraction of predictions whose sign matches the ground truth
    """
    return (np.sign(y_pred).flatten() == y_true).sum() / len(y_true)

Then I add my custom metric to the list of metrics.

metrics = {
    # ...
    'value_acc_sign': mx.metric.create(acc_sign, name='value_acc_sign',
                                       output_names=['value_output'],
                                       label_names=['value_label']),
}

In order to update the custom metric I provide the appropriate data.

 metrics["value_acc_sign"].update(preds=value_out, labels=value_label)

Thanks a lot for your feedback.
Starting from the code proposed by Sergey, I developed the following, which seems to work in R with the symbol API.

Script:

data(BostonHousing, package = "mlbench")
BostonHousing[, sapply(BostonHousing, is.factor)] <- as.numeric(as.character(BostonHousing[, sapply(BostonHousing, is.factor)]))
BostonHousing <- data.frame(scale(BostonHousing))
test.ind = seq(1, 506, 5) # 1 pt in 5 used for testing
train.x = data.matrix(BostonHousing[-test.ind,-14])
train.y = BostonHousing[-test.ind, 14]
test.x = data.matrix(BostonHousing[test.ind,-14])
test.y = BostonHousing[test.ind, 14]

require(mxnet)

data <- mx.symbol.Variable("data")
label <- mx.symbol.Variable("label")
fc1 <- mx.symbol.FullyConnected(data, num_hidden = 14, name = "fc1")
tanh1 <- mx.symbol.Activation(fc1, act_type = "tanh", name = "tanh1")
fc2 <- mx.symbol.FullyConnected(tanh1, num_hidden = 1, name = "fc2")
# Prediction head: BlockGrad exposes fc2 without contributing gradients
pred <- mx.symbol.BlockGrad(fc2, name = 'pred')
# Loss head: what the optimizer actually minimizes
lro_abs <- mx.symbol.MakeLoss(mx.symbol.abs(mx.symbol.Reshape(fc2, shape = 0) - label))

mx.set.seed(0)

# Prediction head first, loss head second
mdl = mx.symbol.Group(pred, lro_abs)

metric = mx.metric.custom(feval = function(label, out) sum((as.array(label) - as.array(out))^2) / length(as.array(label)), name = 'ssq')

mdl <- mx.model.FeedForward.create(mdl, X = train.x, y = train.y, ctx = mx.cpu(),
                                   num.round = 5, array.batch.size = 80,
                                   optimizer = "rmsprop", verbose = TRUE,
                                   array.layout = "rowmajor", eval.metric = mx.metric.mae,
                                   batch.end.callback = NULL, epoch.end.callback = NULL)

# 1) predict directly from the grouped model (uses the first output)
pred1 <- predict(mdl, train.x)

# 2) predict from the pred_output internal
internals = internals(mdl$symbol)
mdl_symbol = internals[[match("pred_output", outputs(internals))]]
mdl2 <- list(symbol = mdl_symbol, arg.params = mdl$arg.params, aux.params = mdl$aux.params)
class(mdl2) <- "MXFeedForwardModel"
pred2 <- predict(mdl2, train.x)

# 3) predict from the fc2_output internal
internals = internals(mdl$symbol)
mdl_symbol = internals[[match("fc2_output", outputs(internals))]]
mdl3 <- list(symbol = mdl_symbol, arg.params = mdl$arg.params, aux.params = mdl$aux.params)
class(mdl3) <- "MXFeedForwardModel"
pred3 <- predict(mdl3, train.x)

sum(abs(train.y - pred1[1,])) / length(train.y)
sum(abs(train.y - pred2[1,])) / length(train.y)
sum(abs(train.y - pred3[1,])) / length(train.y)

Output:

Start training with 1 devices
[1] Train-mae=0.712698568900426
[2] Train-mae=0.600305805603663
[3] Train-mae=0.450728153189023
[4] Train-mae=0.402422085404396
[5] Train-mae=0.395647197961807

[1] 0.3761493
[1] 0.3761493
[1] 0.3761493

As you can see, the three outputs are the same and quite close to the last value returned by the metric during training.
What do you think?

In particular, I do not understand whether this is generally applicable or whether it works only in this simple case.

Well, it seems to work for me - I can run your code and get similar results.

It doesn't seem too specific to me, but I don't entirely understand how the optimizer finds the values of the loss function.

I mean, the symbol you pass is the group of pred and lro_abs. It seems that the custom metric receives just the first output (if you change the order from Group(pred, lro_abs) to Group(lro_abs, pred), you will see behavior similar to what I described in my previous comment). But how the optimization happens is something I don't entirely understand...
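To make the ordering point concrete, using the symbols from the script above:

# The first output of the Group is what predict() and eval.metric see;
# the MakeLoss head seems to drive the gradients wherever it sits.
sym_ok <- mx.symbol.Group(pred, lro_abs)   # metric computed on predictions
sym_odd <- mx.symbol.Group(lro_abs, pred)  # metric computed on loss values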

Hi Sergey, I completely agree: changing the order of the group does matter. It looks like the optimisation is able to find the loss, while the model output is always the first element of the group. Maybe someone more familiar with the source code could give us more insights. For the moment, thanks for your feedback!!!