Training is faster when get_params() is called every mini-batch

I’m training a ResNet-50 model for image classification. I wanted to track the norm of my parameters, so I wrote a custom callback to compute them and a custom mx.mod.Module subclass with the fit() function overridden. The only thing I changed in the fit function was to add a call to get_params() to pull the parameters from the GPUs after every update. My training loop in that function looks like this:

        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()
            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
            except StopIteration:
                end_of_batch = True

            self.update_metric(eval_metric, data_batch.label)

            if monitor is not None:
                monitor.toc_print()

            arg_params, aux_params = self.get_params() # the only line I added

            if batch_end_callback is not None:
                batch_end_params = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                          eval_metric=eval_metric,
                                                          locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

This works, but I noticed something odd: when I added the call to get_params(), the training speed increased. Using the standard mx.mod.Module I got an average training speed of 750 samples/sec, but with that line I get an average speed of 1050 samples/sec.

Any idea why this would speed things up?
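
For anyone trying to reproduce the setup, a per-parameter norm computation of the kind described above might look roughly like the sketch below. It is not the exact callback used here: `to_numpy` and `param_norms` are hypothetical helper names, the example dict is a stand-in for the `arg_params` returned by get_params(), and the only MXNet call it assumes is `NDArray.asnumpy()`.

```python
import numpy as np


def to_numpy(value):
    """Return a NumPy array for an MXNet NDArray or any array-like value."""
    # mx.nd.NDArray exposes asnumpy(); plain arrays and lists go through np.asarray
    if hasattr(value, "asnumpy"):
        return value.asnumpy()
    return np.asarray(value)


def param_norms(params):
    """Map each parameter name to the L2 norm of its flattened values."""
    return {name: float(np.linalg.norm(to_numpy(value).ravel()))
            for name, value in params.items()}


# Stand-in for arg_params (hypothetical names and values, not from a real model)
example = {"fc_weight": np.array([[3.0, 4.0]]), "fc_bias": np.zeros(2)}
norms = param_norms(example)
```

Inside the loop above, this would be called as `param_norms(arg_params)` right after the get_params() line.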


@piiswrong that’s weird. Is it because get_params() adds some synchronization?