Proper usage of BatchNorm during inference?

It’s not clear to me what I should be doing during inference so that the BatchNorm layers behave correctly. The documentation makes reference to “use_global_stats” and mentions that this is “often used during inference.” But if the model was built and saved during training, how and when would we change that parameter? Currently:

-> define module
-> fit module
-> score module

The final rmse numbers I see during the fit process are around 0.04. But when I run score on the same train iter that was used during the fit, the rmse is around 0.14. I finally isolated the problem to something with the BatchNorm layers: if I remove those, the rmse from the fit process is indeed about the same as the rmse from the score (as I would expect).

I’m probably not handling the infer-time BatchNorm params properly, but I don’t really understand from the documentation what I should be doing.

Did you set use_global_stats=True or =False? If it is set to true, then the running-average estimates (global stats) that were computed during training will be used. The difference you are seeing could be caused by using local stats in training and global stats in testing.

use_global_stats=True is only relevant during training. But in general it is better not to use it, because if you don’t use pre-trained weights, the running average estimates will still be at their initial values of 0 and 1. You can find a very detailed description here: Question about batch normalization
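To illustrate why the running averages matter, here is a plain-NumPy sketch of how BatchNorm-style moving statistics are typically maintained during training. This is not MXNet internals; the momentum value of 0.9 and the 0/1 initialization are assumptions for illustration.

```python
import numpy as np

# Sketch (plain NumPy, not MXNet code) of BatchNorm's running statistics.
# The momentum of 0.9 and init values (mean=0, var=1) are assumed here.
moving_mean, moving_var = 0.0, 1.0   # typical initialization
momentum = 0.9

rng = np.random.default_rng(0)
for _ in range(200):                 # one update per training batch
    batch = rng.normal(loc=5.0, scale=2.0, size=64)
    moving_mean = momentum * moving_mean + (1 - momentum) * batch.mean()
    moving_var  = momentum * moving_var  + (1 - momentum) * batch.var()

# After many batches the running estimates approach the data statistics
# (mean ~5, var ~4). With only a handful of updates — or with no training
# at all — they would still sit near 0 and 1, which is why using global
# stats together with untrained weights gives poor results.
print(moving_mean, moving_var)
```

The point is that global stats are only meaningful once enough training batches have flowed through the layer to pull the estimates away from their initial values.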

Yes, I had already seen that other post, but it didn’t help. My problem (I think) is that I don’t really understand what the usage pattern is supposed to be (and I can’t find any fully fleshed-out examples of training and then using a model with BatchNorm). Let me be clear about my questions (and I apologize if these seem silly or obvious; I’m still learning MXNet).

Question 1: How do I set a parameter such as “use_global_stats” differently when training vs inferring? In my (limited) understanding, the model is defined then trained then used for inference. The model - including the “use_global_stats” param - is defined in the first stage. How can I change the “use_global_stats” param AFTER the model has been defined and trained? Do I have to create a new model and then transfer the params from the trained model into the new variant? It would really, really help to see a COMPLETE example of using BatchNorm – from definition to training to saving to reloading to inferring. If that example exists, I certainly haven’t been able to find it.

Question 2: In my example, I did nothing at all special. I defined the model (without specifying “use_global_stats”). I trained the model. I then did scoring with the trained model using the same data as was used for training, just to try and verify/debug the results. I did not change anything between defining, training and scoring with the model. My question is: What is the expected behavior of the score function (in terms of BatchNorm) in this situation? What mean and variance was being used? Should the mean and variance learned by the training process have survived and been utilized automatically in this case? If not, how do I capture that data and pass it on to the model during inference and/or scoring?
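To make concrete what I suspect is happening, here is a plain-NumPy sketch (not MXNet code) comparing normalization with the batch’s own statistics against normalization with global statistics that never moved far from their initial values of 0 and 1:

```python
import numpy as np

# Illustration (NumPy, not MXNet) of why fit and score could disagree:
# local (batch) statistics vs. running statistics that were barely updated.
rng = np.random.default_rng(1)
x = rng.normal(loc=5.0, scale=2.0, size=64)
eps = 1e-5

# Training-time behaviour: normalize with the batch's own statistics.
y_train = (x - x.mean()) / np.sqrt(x.var() + eps)

# Inference with global stats stuck near their init values (0, 1).
moving_mean, moving_var = 0.0, 1.0
y_infer = (x - moving_mean) / np.sqrt(moving_var + eps)

# The two outputs differ substantially, so any layers trained on top of
# y_train would see very different inputs at score time.
print(np.abs(y_train - y_infer).mean())
```

If score were silently using stats like the second case while fit used the first, that would explain the rmse gap I’m seeing.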

Again, I think all of these questions would be answered easily by seeing a fully worked example that goes through all the stages of defining, training, saving, loading and then using a model with BatchNorm.



I tried to come up with an example, and it turns out setting use_global_stats=True crashed my test case. I have to investigate the problem.

Regarding question 1: use_global_stats is only checked during training, so setting it during inference won’t have any effect. I would recommend not setting it: that way BatchNorm will use local statistics during training (i.e. batch estimates of mean and variance).

Regarding question 2: I tested training/inference of BatchNorm in a small example and did not see significantly worse accuracy during inference. Could you maybe share a small reproducible example?

Thanks again for your help.

  1. Yes, I will try to come up with a minimal example that recreates the BatchNorm problem I ran into. I’ll try to do that in the coming week.

  2. Meanwhile… I don’t understand your response to my Question #1. In your response, you said that use_global_stats is only referenced during training. But that doesn’t seem to agree with the documentation:

If ``use_global_stats`` is set to be true, then ``moving_mean`` and
``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
the output. It is often used during inference.
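As I read it, the quoted passage describes the following computation at inference time, written out here as a NumPy sketch. The gamma/beta scale and shift parameters come from training, and the eps value is an assumption on my part (MXNet’s default is on this order):

```python
import numpy as np

# The inference-time computation the docs seem to describe (NumPy sketch).
# gamma/beta are the learned scale/shift; moving_mean and moving_var are
# the running estimates saved with the model. eps is assumed here.
def batchnorm_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-5):
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

x = np.array([3.0, 5.0, 7.0])
out = batchnorm_inference(x, gamma=1.0, beta=0.0,
                          moving_mean=5.0, moving_var=4.0)
print(out)   # roughly [-1, 0, 1]: x standardized by the stored stats
```

So if the stored moving_mean/moving_var are used at inference, they clearly do affect the output — which is why I’m confused by the claim that use_global_stats only matters during training.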

Again - if you or ANYBODY has a complete, end-to-end example of training and then predicting with a BatchNorm model, it would really, really help. Especially if the example was in Scala.