I am unable to run the four multi-tensor operators for neural network optimizers, namely mx.nd.multi_mp_sgd_mom_update, mx.nd.multi_mp_sgd_update, mx.nd.multi_sgd_mom_update, and mx.nd.multi_sgd_update.
There is no NDArray API documentation page for these operators, although the operators themselves exist (https://github.com/apache/incubator-mxnet/issues/15643). The docstring can still be read from the interpreter:
>>> help(mx.nd.multi_mp_sgd_mom_update)
However, I cannot find a way to pass the NDArray[] argument.
I get errors like:
AssertionError: Positional arguments must have NDArray type, but got [[1. 2. 3.]]
<NDArray 3 @cpu(0)>]
AssertionError: Positional arguments must have NDArray type, but got (1.0, 2.0)
mxnet.base.MXNetError: Required parameter lrs of tuple of <float> is not presented, in operator multi_mp_sgd_mom_update(name="")
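Taken together, these errors suggest that the arrays must be unpacked as separate positional arguments, while lrs and wds must be passed by keyword. A minimal sketch under that assumption follows. It mirrors what the built-in SGD optimizer in python/mxnet/optimizer/optimizer.py appears to do, which is to interleave the arrays as weight_0, grad_0, weight_1, grad_1, and so on. The helper name interleave and the sample values are hypothetical, and this has not been confirmed against every MXNet version.

```python
from itertools import chain

try:
    import mxnet as mx  # the operator call below is skipped if MXNet is absent
except ImportError:
    mx = None


def interleave(*groups):
    """Hypothetical helper: flatten per-parameter groups into
    w0, g0, w1, g1, ... order (one element of each group per parameter)."""
    return list(chain.from_iterable(zip(*groups)))


if mx is not None:
    # Two parameters of different shapes, each with its own gradient.
    weights = [mx.nd.array([1.0, 2.0, 3.0]), mx.nd.array([4.0, 5.0])]
    grads = [mx.nd.array([0.1, 0.1, 0.1]), mx.nd.array([0.2, 0.2])]

    mx.nd.multi_sgd_update(
        *interleave(weights, grads),   # NDArrays unpacked, not passed as a list
        lrs=(0.1, 0.1),                # keyword argument, one rate per weight
        wds=(0.0, 0.0),                # keyword argument, one decay per weight
        num_weights=len(weights),      # counts weights, not positional arrays
        out=weights,                   # assumed to update the weights in place
    )
    print([w.asnumpy() for w in weights])
```

For the momentum and multi-precision variants, the same interleaving would presumably extend to (weight, grad, mom) and (weight, grad, mom, weight32) groups, with num_weights still counting weights rather than arrays.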
Docstring of multi_mp_sgd_mom_update, as printed by help():
multi_mp_sgd_mom_update(*data, **kwargs)
Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer.
Momentum update has better convergence rates on neural networks. Mathematically it looks
like below:

.. math::

  v_1 = \alpha * \nabla J(W_0)\\
  v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\
  W_t = W_{t-1} + v_t

It updates the weights using::

  v = momentum * v - learning_rate * gradient
  weight += v

Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch.
Defined in src/operator/optimizer_op.cc:L470
Parameters
----------
data : NDArray[]
    Weights
lrs : tuple of <float>, required
    Learning rates.
wds : tuple of <float>, required
    Weight decay augments the objective function with a regularization term that penalizes large weights. The penalty scales with the square of the magnitude of each weight.
momentum : float, optional, default=0
    The decay rate of momentum estimates at each epoch.
rescale_grad : float, optional, default=1
    Rescale gradient to grad = rescale_grad*grad.
clip_gradient : float, optional, default=-1
    Clip gradient to the range of [-clip_gradient, clip_gradient] If clip_gradient <= 0, gradient clipping is turned off. grad = max(min(grad, clip_gradient), -clip_gradient).
num_weights : int, optional, default='1'
    Number of updated weights.
out : NDArray, optional
    The output NDArray to hold the result.
Returns
-------
out : NDArray or list of NDArrays
    The output of this function.
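To make the documented update rule concrete, here is a pure-Python reference sketch of what one momentum step would compute per weight, based only on the formulas in the docstring above. The function names sgd_mom_step and multi_sgd_mom_step are hypothetical. The way weight decay enters the update (as wd * weight added to the gradient) is an assumption, since the docstring does not spell it out, and the multi-precision float32 master copy is not modelled here.

```python
def sgd_mom_step(weight, grad, mom, lr, wd=0.0, momentum=0.0,
                 rescale_grad=1.0, clip_gradient=-1.0):
    """Return (new_weight, new_mom) for one parameter given as lists of floats."""
    new_weight, new_mom = [], []
    for w, g, v in zip(weight, grad, mom):
        g = rescale_grad * g                      # grad = rescale_grad * grad
        if clip_gradient > 0:                     # clipping is off when <= 0
            g = max(min(g, clip_gradient), -clip_gradient)
        # Assumption: weight decay contributes wd * w to the gradient.
        v = momentum * v - lr * (g + wd * w)      # v = momentum * v - lr * grad
        new_mom.append(v)
        new_weight.append(w + v)                  # weight += v
    return new_weight, new_mom


def multi_sgd_mom_step(weights, grads, moms, lrs, wds, **kwargs):
    """Apply sgd_mom_step to each (weight, grad, mom) triple, using the
    matching entries of lrs and wds; shared options go in kwargs."""
    return [
        sgd_mom_step(w, g, m, lr, wd, **kwargs)
        for w, g, m, lr, wd in zip(weights, grads, moms, lrs, wds)
    ]


if __name__ == "__main__":
    results = multi_sgd_mom_step(
        weights=[[1.0, 2.0, 3.0], [4.0, 5.0]],
        grads=[[0.1, 0.1, 0.1], [0.2, 0.2]],
        moms=[[0.0, 0.0, 0.0], [0.0, 0.0]],
        lrs=(0.1, 0.1),
        wds=(0.0, 0.0),
        momentum=0.9,
    )
    for new_weight, new_mom in results:
        print(new_weight, new_mom)
```

A reference like this can be compared against the operator's output on small inputs to check that the arrays were passed in the intended order.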