How to reset gradients on module?

Hi,

I am wondering how to reset gradients to 0 in a module when it is bound with grad_req = 'add' (example below). In the R interface, there is an mx.exec.update.grad.arrays function, but I can’t find an equivalent in the Python API documentation.

import mxnet as mx
from collections import namedtuple
# Minimal reproduction: a single fully connected unit (no bias) feeding a
# softmax output, bound with grad_req='add', then updated twice after one
# backward pass to show that the gradient arrays survive update().

# NOTE: Batch is defined but not used anywhere in this snippet.
Batch = namedtuple('Batch', ['data', 'label'])
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
net = mx.sym.FullyConnected(data = data, num_hidden=1, no_bias=True)
net = mx.sym.SoftmaxOutput(data = net, label = label)
mod = mx.mod.Module(net, label_names=['label'], data_names=['data'])
# grad_req='add' is the setting this question is about: as observed below,
# update() does not clear the gradient arrays when the module is bound this way.
# inputs_need_grad=True presumably also requests gradients w.r.t. 'data' --
# confirm against the Module.bind documentation.
mod.bind(data_shapes = [('data', (1,1))], 
         label_shapes= [('label', (1,1))], 
         force_rebind=True, 
         grad_req='add', 
         inputs_need_grad=True)
mod.init_params()
mod.init_optimizer(optimizer = 'rmsprop', optimizer_params=(('learning_rate', 0.001), ('gamma2', 0), ('centered', True)))

# One 1x1 sample: data is all ones, label is all zeros.
inputData = mx.io.DataBatch(data = [mx.nd.ones((1,1))], label = [mx.nd.zeros((1,1))])
mod.forward(data_batch=inputData)
mod.backward()
mod.update() ## does not clear gradient arrays
# Bare expression: only displays a value when run in an interactive session.
mod._exec_group.grad_arrays ## are not 0
mod.update() ## updates again with the same gradients
## how can I get the gradients back to 0 so that a further call to mod.update() has no effect?

Thank you for your help.

You can loop through your gradient arrays and set them to zero:

import mxnet as mx
from collections import namedtuple
# Demonstrates gradient accumulation with grad_req='add' and how to reset the
# accumulated gradients by hand between update cycles.
#
# Model: a single fully connected unit (no bias) feeding a softmax output, so
# there is exactly one weight and one data input.

# NOTE: Batch is not used below; kept so the snippet matches the question.
Batch = namedtuple('Batch', ['data', 'label'])
data = mx.sym.Variable('data')
label = mx.sym.Variable('label')
net = mx.sym.FullyConnected(data=data, num_hidden=1, no_bias=True)
net = mx.sym.SoftmaxOutput(data=net, label=label)
mod = mx.mod.Module(net, label_names=['label'], data_names=['data'])
mod.bind(data_shapes=[('data', (1, 1))],
         label_shapes=[('label', (1, 1))],
         force_rebind=True,
         grad_req='add',
         inputs_need_grad=True)
mod.init_params()
mod.init_optimizer(optimizer='rmsprop',
                   optimizer_params=(('learning_rate', 0.001),
                                     ('gamma2', 0),
                                     ('centered', True)))

# One 1x1 sample: data is all ones, label is all zeros.
inputData = mx.io.DataBatch(data=[mx.nd.ones((1, 1))],
                            label=[mx.nd.zeros((1, 1))])

# First cycle: two forward/backward passes, so the second print shows the
# gradients summed over both passes, followed by a single update.
mod.forward(data_batch=inputData)
mod.backward()
print(mod._exec_group.grad_arrays, mod._exec_group.input_grad_arrays)

mod.forward(data_batch=inputData)
mod.backward()
print(mod._exec_group.grad_arrays, mod._exec_group.input_grad_arrays)
mod.update()

# update() leaves the accumulated gradients in place, so zero them explicitly.
# Both attributes are nested lists (outer: one entry per parameter / data
# input, inner: one buffer per device). Walk every entry rather than only
# index 0, so the reset also works for models with more than one parameter.
for grad_groups in (mod._exec_group.grad_arrays,
                    mod._exec_group.input_grad_arrays):
    for device_grads in grad_groups:
        for grad in device_grads:
            # Entries can be None for arguments bound without a gradient
            # buffer (e.g. fixed parameters); there is nothing to reset then.
            if grad is not None:
                # Write the zeros into the existing buffer in place.
                mx.nd.zeros_like(grad, out=grad)

# Second cycle: accumulation starts again from zero.
mod.forward(data_batch=inputData)
mod.backward()
print(mod._exec_group.grad_arrays, mod._exec_group.input_grad_arrays)

mod.forward(data_batch=inputData)
mod.backward()
print(mod._exec_group.grad_arrays, mod._exec_group.input_grad_arrays)
mod.update()
[[
[[1.]]
<NDArray 1x1 @cpu(0)>]] [[
[[0.00838965]]
<NDArray 1x1 @cpu(0)>]]
[[
[[2.]]
<NDArray 1x1 @cpu(0)>]] [[
[[0.0167793]]
<NDArray 1x1 @cpu(0)>]]
[[
[[1.]]
<NDArray 1x1 @cpu(0)>]] [[
[[0.00505632]]
<NDArray 1x1 @cpu(0)>]]
[[
[[2.]]
<NDArray 1x1 @cpu(0)>]] [[
[[0.01011264]]
<NDArray 1x1 @cpu(0)>]]

Great,

thank you very much for your reply. Helped a lot.

1 Like