How to resume training with optimizer status

I want to implement a “resume training” feature in my training program, but I don’t know how to correctly restore the optimizer state.

My program looks like this:

opt = mx.optimizer.create(name, learning_rate=lr, ...)  # name: placeholder for the optimizer, e.g. 'sgd'

ctx = [...]
sym = get_symbol()  # the function that defines the network
model = mx.mod.Module(symbol=sym, context=ctx)

model.fit(...)

Now I want to save the model after 1k training steps and then resume from that checkpoint. Since the optimizer state also has to be restored (e.g. the momentum of each parameter for a momentum optimizer), I use the mxnet.Module API. The code for saving and loading is:

##### save #####
def batch_callback(params):
    if global_step == 1000:
        model.save_checkpoint(prefix, 0, save_optimizer_states=True)
        sys.exit(0)

batch_callback is passed to model.fit() as its batch_end_callback.
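Note that global_step is not defined in the snippet above. A minimal sketch of a callback that keeps its own step counter, assuming MXNet calls batch_end_callback once per batch with a BatchEndParam whose nbatch field resets every epoch (make_checkpoint_callback and stop_after are hypothetical names, not MXNet API):

```python
def make_checkpoint_callback(module, prefix, every_n_steps, stop_after=None):
    """Return a batch_end_callback that checkpoints every `every_n_steps` batches.

    Hypothetical helper: the global step lives in a closure instead of an
    undefined module-level `global_step` variable.
    """
    state = {"global_step": 0}

    def callback(param):
        # MXNet passes a BatchEndParam (epoch, nbatch, eval_metric, locals);
        # nbatch restarts at 0 every epoch, so we count steps ourselves.
        state["global_step"] += 1
        step = state["global_step"]
        if step % every_n_steps == 0:
            # The step count is used as the checkpoint index, so the files
            # are named <prefix>-%04d.params / .states with that number.
            module.save_checkpoint(prefix, step, save_optimizer_states=True)
            if stop_after is not None and step >= stop_after:
                raise SystemExit(0)  # same effect as sys.exit(0)

    return callback
```

It would be registered as batch_end_callback=make_checkpoint_callback(model, prefix, 1000, stop_after=1000), and the run resumed later with mx.mod.Module.load(prefix, 1000, load_optimizer_states=True).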

##### load #####
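model = mx.mod.Module.load(prefix, 0, load_optimizer_states=True, context=ctx)
model.bind(...)
arg_params, aux_params = model.get_params()

# optimizer_params must be a tuple of (name, value) pairs; it is ignored
# when `optimizer` is already an Optimizer instance such as `opt`
model.fit(train_iter,  # the training DataIter
          optimizer=opt, optimizer_params=(('learning_rate', args.lr),),
          arg_params=arg_params, aux_params=aux_params,
          batch_end_callback=batch_callback)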

However, the model is not resumed correctly: the results are quite bad. I am not sure, but it seems that the model parameters are randomly initialized rather than loaded from the checkpoint.

So what is the correct way to resume training, including the optimizer state?

The code snippets look ok. Can you provide a small reproducible example so that I can debug the issue?
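In the meantime, a quick way to check whether the weights really come from the checkpoint is to compare them with the saved file right after bind(). A minimal sketch, assuming both sides have already been converted to {name: numpy array} dicts (compare_params is a hypothetical helper):

```python
import numpy as np


def compare_params(expected, actual, atol=0.0):
    """Compare two {name: numpy array} dicts.

    Returns (missing, mismatched): names present in `expected` but absent
    from `actual`, and names whose shapes differ or whose values differ by
    more than `atol`.
    """
    missing = sorted(set(expected) - set(actual))
    mismatched = sorted(
        name for name in expected
        if name in actual
        # shape check first so allclose never tries to broadcast
        and (expected[name].shape != actual[name].shape
             or not np.allclose(expected[name], actual[name], rtol=0.0, atol=atol))
    )
    return missing, mismatched
```

With MXNet the inputs would come from `_, ckpt_args, _ = mx.model.load_checkpoint(prefix, 0)` and `arg_params, _ = model.get_params()`, converting each NDArray with .asnumpy(). Two empty lists right after bind() mean the module does hold the checkpoint weights at that point, so the problem is more likely in a later step.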

In general, saving and loading a model together with its optimizer states can be done as follows:

Save:

model.save_checkpoint("test", 0, save_optimizer_states=True)
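With the prefix "test" and epoch 0, this call should write three files: test-symbol.json, test-0000.params and test-0000.states (naming scheme assumed from MXNet's save_checkpoint). A small sketch to check that they are all present before resuming (checkpoint_files and missing_checkpoint_files are hypothetical helpers):

```python
import os


def checkpoint_files(prefix, epoch, with_optimizer_states=True):
    """File names that Module.save_checkpoint is expected to write.

    Assumes MXNet's naming scheme: <prefix>-symbol.json, <prefix>-%04d.params
    and, with save_optimizer_states=True, <prefix>-%04d.states.
    """
    files = ["%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, epoch)]
    if with_optimizer_states:
        files.append("%s-%04d.states" % (prefix, epoch))
    return files


def missing_checkpoint_files(prefix, epoch, with_optimizer_states=True):
    # Names from checkpoint_files() that do not exist on disk.
    return [f for f in checkpoint_files(prefix, epoch, with_optimizer_states)
            if not os.path.exists(f)]
```

If the .states file is missing, the optimizer state cannot be restored by Module.load(..., load_optimizer_states=True).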

Load:

model = mx.mod.Module.load("test", 0, load_optimizer_states=True)
model.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
model.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
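To illustrate why the optimizer states are worth saving at all, here is a framework-independent toy (plain Python, not MXNet code): momentum SGD on f(w) = (w - 3)^2, trained for 20 steps in one go versus 10 steps, a pickled checkpoint, and 10 more steps. Restoring only the weight and resetting the momentum to zero changes the trajectory.

```python
import pickle


def grad(w):
    # Gradient of the toy loss f(w) = (w - 3)^2
    return 2.0 * (w - 3.0)


def sgd_momentum_step(w, v, lr=0.1, momentum=0.9):
    # Classic momentum update: v <- momentum * v - lr * grad, w <- w + v
    v = momentum * v - lr * grad(w)
    return w + v, v


def train(w, v, steps):
    for _ in range(steps):
        w, v = sgd_momentum_step(w, v)
    return w, v


# Reference run: 20 uninterrupted steps.
w_ref, _ = train(0.0, 0.0, 20)

# Interrupted run: 10 steps, checkpoint weight AND momentum, restore, 10 more.
w, v = train(0.0, 0.0, 10)
state = pickle.loads(pickle.dumps({"w": w, "v": v}))
w_resumed, _ = train(state["w"], state["v"], 10)

# Interrupted run that restores only the weight and resets momentum to zero.
w_no_state, _ = train(state["w"], 0.0, 10)

print(w_ref == w_resumed)   # True: identical sequence of float operations
print(w_ref == w_no_state)  # False: the lost momentum changes the trajectory
```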