BatchNorm parameters are not properly copied in a multi-GPU setting

I have a gluon.HybridBlock model trained on multiple GPUs, and I want to copy its parameters to another model that also runs on multiple GPUs. The two models have identical architectures, both use the same number of GPUs as contexts, and each contains a batch normalization layer.

I tried two ways to initialize the second model with the parameters of the first. The first was to use the save_parameters and load_parameters methods of the gluon.Block class:

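from mxnet import gpu

ctx = [gpu(i) for i in range(4)]
# model_1 has been initialized and trained on ctx
model_1.save_parameters("")           # write model_1's parameters to a file
model_2.load_parameters("", ctx=ctx)  # load them onto the same four GPUs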

Unfortunately, this worked for all parameters except the running mean and variance of the batch normalization layer. Concretely, after executing the above code, model_1 and model_2 (i) share the same parameters in all layers other than batch normalization, and (ii) in the batch normalization layer, share the same beta and gamma parameters but have different mean and variance parameters.

The second way was to use the set_data method of the gluon.Parameter class:

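params1 = model_1.collect_params()
params2 = model_2.collect_params()
# Both models share the same architecture, so their parameters line up in order
for p1, p2 in zip(params1.values(), params2.values()):
    p1_data = p1.data()   # array taken from a single context
    p2.set_data(p1_data)  # written to every context of p2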

Unfortunately, the problem with the above code is that p1_data holds the data from only one context (i.e., one specific GPU), and p2.set_data(p1_data) then writes that same value to p2 on every GPU. However, when an MXNet model with a batch normalization layer is trained on multiple GPUs, each GPU (context) keeps its own running mean and variance, while the gamma and beta parameters stay identical across all GPUs. (For layers other than batch normalization, every GPU holds the same parameter values.) As a result, the second approach does not work: it sets the mean and variance of model_2's batch normalization layer to a single value on every GPU, even though these statistics differ from GPU to GPU in model_1.

Interestingly, the mean and variance of the batch normalization layer are also exactly where the first approach failed.
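To make the per-GPU behavior described above concrete, here is a minimal plain-Python toy model. It is not MXNet code: the two "devices", the batch values, and the helper names are all hypothetical, and the momentum of 0.9 is assumed (it matches the default of gluon.nn.BatchNorm). Each device updates its running mean only from the shard of the batch it sees, so after training the devices disagree. Writing one value to every device, as p2.set_data(p1_data) does, cannot reproduce that state, while a device-by-device copy can.

```python
# Toy model only (plain Python, not MXNet). Assumption: each device keeps its
# own running mean and updates it with an exponential moving average.
MOMENTUM = 0.9  # assumed, same as gluon.nn.BatchNorm's default momentum


def update_running_mean(running_mean, batch):
    """One moving-average step computed on a single device's shard."""
    batch_mean = sum(batch) / len(batch)
    return MOMENTUM * running_mean + (1 - MOMENTUM) * batch_mean


def train(steps, num_devices):
    """Each device updates its running mean independently (no syncing)."""
    running = [0.0] * num_devices
    for shards in steps:
        for dev in range(num_devices):
            running[dev] = update_running_mean(running[dev], shards[dev])
    return running


def copy_single_value(src, num_devices):
    """Mimics p2.set_data(p1_data): one device's value goes to every device."""
    return [src[0]] * num_devices


def copy_per_device(src):
    """Each device receives the value held by the same device in the source."""
    return list(src)


# Hypothetical data: device 0 sees small values, device 1 sees large ones.
steps = [
    [[1.0, 2.0], [10.0, 12.0]],
    [[2.0, 4.0], [8.0, 10.0]],
]
model_1_means = train(steps, num_devices=2)
model_2_broadcast = copy_single_value(model_1_means, num_devices=2)
model_2_per_device = copy_per_device(model_1_means)

print("model_1:", model_1_means)
print("broadcast copy:", model_2_broadcast)
print("per-device copy:", model_2_per_device)
```

In this toy setup the broadcast copy leaves both devices holding device 0's statistics, which matches the mismatch observed for model_2, whereas the per-device copy reproduces model_1 exactly.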

What is the reason for this behavior? Is there another way to copy the parameters correctly?