Calculating gluon Dataset statistics

In MXNet v1.6.0, what’s the recommended way to calculate per-channel image statistics (mean and standard deviation) for a dataset such as mxnet.gluon.data.vision.datasets.CIFAR10 (or a custom gluon.data.Dataset)?

Figuring out the correct way to do this has been much harder than I expected. Are mx.np and as_np_ndarray the recommended tools for this in the latest MXNet?

For example:

from mxnet import gluon, np
from mxnet.gluon.data.vision import datasets, transforms
from multiprocessing import cpu_count

cifar_train = datasets.CIFAR10(train=True)

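def get_mean_and_std(dataset):
  '''Compute the per-channel mean and std of a dataset.'''
  # ToTensor() converts each HWC uint8 image to a CHW float32 array in [0, 1].
  dataset = dataset.transform_first(transforms.ToTensor())
  # Order does not affect the statistics, so shuffling is unnecessary.
  dataloader = gluon.data.DataLoader(
    dataset, batch_size=1, shuffle=False, num_workers=cpu_count())
  mean = np.zeros(3)
  std = np.zeros(3)
  for X, y in dataloader:
    # (1, 3, H, W) -> (3, H*W); this reshape only works with batch_size=1.
    X = X.as_np_ndarray().reshape(3, -1)
    mean += np.mean(X, axis=1)
    std += np.std(X, axis=1)
  mean /= len(dataloader)
  # Note: this averages the per-image stds, which is not the std over all pixels.
  std /= len(dataloader)
  return mean, std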

mean, std = get_mean_and_std(cifar_train)
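One issue worth flagging in the attempt above: averaging per-image standard deviations leaves out the spread *between* image means (law of total variance), so it underestimates the dataset-wide std. For example, one all-black and one all-white image each have a per-image std of 0, yet the pixels of the two together have a std of 0.5. The sketch below is plain NumPy, not an MXNet API. It accumulates per-channel sums and sums of squares in a single pass over batches of shape (N, C, H, W), the layout a DataLoader yields after ToTensor(), so it also works with batch sizes larger than 1.

```python
import numpy as np


def batched_channel_stats(batches):
    """Per-channel mean and std over all pixels, from an iterable of NCHW batches.

    Each batch is an array of shape (N, C, H, W). Sums are accumulated in
    float64 so a single pass over the data is enough.
    """
    total = total_sq = None
    n_pixels = 0
    for batch in batches:
        x = np.asarray(batch, dtype=np.float64)
        if total is None:
            total = np.zeros(x.shape[1])
            total_sq = np.zeros(x.shape[1])
        # Reduce over batch, height and width; keep the channel axis.
        total += x.sum(axis=(0, 2, 3))
        total_sq += (x ** 2).sum(axis=(0, 2, 3))
        # Count pixels rather than batches so a smaller last batch is handled correctly.
        n_pixels += x.shape[0] * x.shape[2] * x.shape[3]
    mean = total / n_pixels
    # Var[X] = E[X^2] - E[X]^2; clip tiny negative values caused by round-off.
    std = np.sqrt(np.maximum(total_sq / n_pixels - mean ** 2, 0.0))
    return mean, std
```

With Gluon, one would presumably feed it converted batches, e.g. `batched_channel_stats(X.asnumpy() for X, _ in gluon.data.DataLoader(cifar_train.transform_first(transforms.ToTensor()), batch_size=256))`. The batch size of 256 is an arbitrary choice for illustration.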

I’m not sure whether I should:

  • use mx.np and as_np_ndarray
  • use npx.set_np()
  • use a premade convenience function that I haven’t found yet
  • something else entirely
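As one "something else" option, a dataset small enough to fit in memory can skip the DataLoader entirely. The sketch below is plain NumPy rather than an MXNet convenience function. It assumes the images are already stacked into a single (N, H, W, C) array, the layout of CIFAR-10's raw images, and scales values by 255 to match ToTensor(). Channels are processed one at a time, so only one float64 copy of a single channel is alive at once (about 400 MB for CIFAR-10's 50,000 training images).

```python
import numpy as np


def in_memory_channel_stats(data, scale=255.0):
    """Per-channel mean and std of an (N, H, W, C) array held entirely in memory.

    `scale` maps raw values to the [0, 1] range produced by ToTensor()
    (255 for uint8 images).
    """
    data = np.asarray(data)
    means, stds = [], []
    for c in range(data.shape[-1]):
        # Convert one channel at a time to bound peak memory use.
        channel = data[..., c].astype(np.float64) / scale
        means.append(channel.mean())
        stds.append(channel.std())
    return np.array(means), np.array(stds)
```

For a Gluon dataset, building that array might look like `np.stack([cifar_train[i][0].asnumpy() for i in range(len(cifar_train))])`, assuming each item is an (image, label) pair with an HWC image.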

Any help with the recommended way to do this in the current version of MXNet would be greatly appreciated!