Alternately training parts of a model with mx.sym

Suppose you have a model with two parts, partA and partB (together they make up the whole model). You want to train partA for a while with partB fixed, and then train partB with partA fixed.

In Gluon, it seems I can do this with something like the following:

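```python
from mxnet import autograd, gluon
# One Trainer per part: each Trainer only updates the parameters it was given.
partA_trainer = gluon.Trainer(net.partA.collect_params(), 'adam', {'learning_rate': lr})
partB_trainer = gluon.Trainer(net.partB.collect_params(), 'adam', {'learning_rate': lr})
```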

Then, in each training iteration:

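```python
with autograd.record():
    loss = net(data)  # here net is assumed to return the loss directly
loss.backward()  # gradients are computed for both parts

# Step only the trainer of the part being trained; the other part keeps its weights.
if (epoch // alternate_epochs) % 2 == 0:
    training_info = '<<< partA training only >>>'
    partA_trainer.step(data.shape[0])
else:
    training_info = '<<< partB training only >>>'
    partB_trainer.step(data.shape[0])
```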

I wonder if there's an easy way to do something similar with the mx.sym API.

This GitHub issue seems to be what you are looking for:


Freeze the gradients of the parts that you don’t want to update
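With the Symbol API, one way to freeze a part (this is an assumption about a workable approach, not necessarily what the linked issue does) is the `fixed_param_names` argument of `mx.mod.Module`: the Module computes no gradients for the parameter names in that list, so the optimizer leaves them unchanged. The sketch below is plain Python and only decides which names to freeze in each phase, reusing the epoch schedule from the Gluon loop. It assumes the layers of the two parts were created with the hypothetical name prefixes `partA_` and `partB_`, so their parameters can be picked out of `sym.list_arguments()`; adapt the prefixes to however your symbol is named.

```python
from typing import List, Sequence


def params_with_prefix(arg_names: Sequence[str], prefix: str) -> List[str]:
    """Return the argument names that belong to one part, identified by a name prefix."""
    return [name for name in arg_names if name.startswith(prefix)]


def fixed_params_for_epoch(arg_names: Sequence[str], epoch: int, alternate_epochs: int,
                           prefix_a: str = 'partA_', prefix_b: str = 'partB_') -> List[str]:
    """Parameter names to freeze in this epoch.

    Same schedule as the Gluon loop: partA trains (so partB is frozen) when
    (epoch // alternate_epochs) is even; otherwise partB trains and partA is frozen.
    """
    train_part_a = (epoch // alternate_epochs) % 2 == 0
    return params_with_prefix(arg_names, prefix_b if train_part_a else prefix_a)


# Hypothetical argument list, similar to what sym.list_arguments() returns for a
# symbol whose layers were built with name='partA_fc' and name='partB_fc'.
arg_names = ['data', 'partA_fc_weight', 'partA_fc_bias',
             'partB_fc_weight', 'partB_fc_bias', 'softmax_label']

for epoch in range(4):
    print(epoch, fixed_params_for_epoch(arg_names, epoch, alternate_epochs=2))
# 0 ['partB_fc_weight', 'partB_fc_bias']
# 1 ['partB_fc_weight', 'partB_fc_bias']
# 2 ['partA_fc_weight', 'partA_fc_bias']
# 3 ['partA_fc_weight', 'partA_fc_bias']
```

Because `fixed_param_names` is set when the Module is constructed, switching parts in this approach means building a new `mx.mod.Module(sym, fixed_param_names=fixed)` for each phase and carrying the weights over, for example by passing the previous module's `get_params()` result as the `arg_params`/`aux_params` arguments of `fit`. Optimizer state such as Adam's moment estimates is not carried across phases this way.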