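import mxnet as mx

# Case 1: 1D label and prediction arrays
mse = mx.metric.create('mse')
x = mx.nd.array([1, 2, 3, 4])  # 1D array, shape (4,)
mse.update(x, x + 0.1)
print(mse.get())  # 0.01, as expected

# Case 2: the same values wrapped in a 2D array
mse = mx.metric.create('mse')
x = mx.nd.array([[1, 2, 3, 4]])  # 2D array, shape (1, 4)
mse.update(x, x + 0.1)
print(mse.get())  # 2.509999..., why?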
At first I thought that when 2D arrays are given, each row represents one output and the set of rows is treated as a batch. However, that does not seem to be the case. Where does the second result, 2.509999, come from?
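For reference, here is a quick sanity check of the arithmetic in plain Python, without MXNet. The names `elementwise` and `all_pairs` are only for illustration. The element-wise MSE is the value I expected for a single row of four outputs. The all-pairs value is what you would get if a column of shape (4, 1) were broadcast against a row of shape (4,), so every label is compared with every prediction. I have not confirmed that `update` does this internally; it is only a guess based on the numbers matching.

```python
# Plain-Python check of the two candidate computations (no MXNet needed).
label = [1.0, 2.0, 3.0, 4.0]
pred = [v + 0.1 for v in label]

# Element-wise MSE: each label compared with its own prediction.
elementwise = sum((l - p) ** 2 for l, p in zip(label, pred)) / len(label)

# All-pairs MSE: every label compared with every prediction, which is
# what broadcasting a (4, 1) column against a (4,) row would compute.
# (Assumption: this is a hypothesis about the cause, not verified
# against the MXNet source.)
all_pairs = sum((l - p) ** 2 for l in label for p in pred) / (len(label) * len(pred))

print(round(elementwise, 6))  # 0.01
print(round(all_pairs, 6))    # 2.51
```

The second number matches the 2.509999 printed for the 2D case up to floating-point precision, which is why I suspect a shape mismatch rather than a different definition of MSE.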