Thank you @ThomasDelteil, indeed, the greatest difficulty was setting everything up properly on the HPC environment. Critically, there was a bug due to different gcc compilers that was resolved a few weeks ago, which made installation easy. I finally have a fully working version and it's super fast. I think it's faster even on a single node, but I need to run proper benchmarks to confirm this. The official mxnet_mnist.py example on Horovod is super awesome. I also found this tutorial on ring-allreduce extremely beneficial.
Leaving some tips here for reference on running Horovod + MXNet under the SLURM manager (please correct me if you see anything weird/wrong). The modules to load may differ between HPC environments, but this is my take (assuming you have everything installed and running).
SLURM job submit file (if your HPC environment supports it):
#!/bin/bash
#SBATCH --job-name="HVDRC"
#SBATCH --nodes=12
#SBATCH -t 23:30:30
#SBATCH --cpus-per-task=4
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4 ##### This should be EQUAL to the number of GPUs for MPI; specifying gres=gpu:4 alone doesn't work
#SBATCH --mem=32gb
##### Total number of processes = nodes x ntasks-per-node
echo " "
echo " Nodelist:= " $SLURM_JOB_NODELIST
echo " Number of nodes:= " $SLURM_JOB_NUM_NODES
echo " NGPUs per node:= " $SLURM_GPUS_PER_NODE
echo " Ntasks per node:= " $SLURM_NTASKS_PER_NODE
echo " "
#### Load modules that you used when installed horovod.
module load cuda/9.2.88
module load nccl/2.3.7-cuda92 #### Not sure this is needed, since allreduce goes over MPI
module load cudnn/v7.5.0-cuda92
module load gcc/8.3.0
module load openmpi/4.0.0-simple-gcc ### Working
module load hpc-x
# print on screen what you used
module list
#### Use MPI for communication with Horovod - this can be hard-coded during installation as well.
export HOROVOD_GPU_ALLREDUCE=MPI
export HOROVOD_GPU_ALLGATHER=MPI
export HOROVOD_GPU_BROADCAST=MPI
#### Produce a timeline for debugging purposes
####export HOROVOD_TIMELINE=./timeline.json ### Do not use for production runs, it produces very large files
export NCCL_DEBUG=INFO ### valid levels are VERSION, WARN, INFO, TRACE
ulimit -s 20480 ####### Horovod recommends 10240 for this
echo "Running on multiple nodes and GPU devices"
echo ""
echo "Run started at:- "
date
##### Actual executable
mpirun -np $SLURM_NTASKS -bind-to none -map-by slot -x HOROVOD_TIMELINE -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH -mca pml ob1 -mca btl ^openib python ./main.py ######## SUCCESS
echo "Run completed at:- "
date
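To submit, assuming the script above is saved as, say, submit_horovod.sh (the filename is just an example):
sbatch submit_horovod.sh
squeue -u $USER # check the queue and the nodes assigned to the job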
MXNet/Gluon-specific things I tweaked:
- As you said @ThomasDelteil, there is no split_and_load anymore; the whole code reads as if it were running on a single context. See the sketch below.
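For context, here is a minimal sketch of what the single-context setup looks like with Horovod's MXNet API (build_model and the optimizer settings are placeholders, not from my actual code):
import mxnet as mx
import horovod.mxnet as hvd

hvd.init() # one process per GPU, launched by mpirun

# Each worker pins itself to its own local device - no split_and_load needed.
context = mx.gpu(hvd.local_rank()) if mx.test_utils.list_gpus() else mx.cpu(hvd.local_rank())

net = build_model() # placeholder for your model definition
net.initialize(ctx=context)

# Make sure all workers start from identical weights.
params = net.collect_params()
hvd.broadcast_parameters(params, root_rank=0)

# DistributedTrainer averages gradients across all workers on each step.
trainer = hvd.DistributedTrainer(params, 'sgd', {'learning_rate': 0.01})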
- DataLoader for the train and validation data, with pinned memory (#14136); SplitSampler is the one from the parameter-server distributed tutorial (a sketch follows the snippet below):
data_generator = gluon.data.DataLoader(dataset,
                                       batch_size=self[C.C_BATCH_SIZE],
                                       sampler=SplitSampler(length=len(dataset),
                                                            num_parts=hvd.size(),
                                                            part_index=hvd.rank(),
                                                            shuffle=True),
                                       # ******* fixes a bug causing segm fault ******
                                       pin_memory=True,
                                       pin_device_id=hvd.local_rank(), # See issue 14136
                                       # *********************************************
                                       last_batch='discard',
                                       num_workers=0) # in all my tests, num_workers=0 is fastest
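For completeness, a sketch of SplitSampler along the lines of the one from the distributed (parameter-server) tutorial; the shuffle argument is my addition so it matches the call above:
import random
from mxnet import gluon

class SplitSampler(gluon.data.sampler.Sampler):
    # Splits the dataset into num_parts parts and samples only from part part_index.
    def __init__(self, length, num_parts=1, part_index=0, shuffle=True):
        self.part_len = length // num_parts     # samples per worker
        self.start = self.part_len * part_index # start of this worker's slice
        self.shuffle = shuffle

    def __iter__(self):
        indices = list(range(self.start, self.start + self.part_len))
        if self.shuffle:
            random.shuffle(indices) # shuffle within this worker's slice only
        return iter(indices)

    def __len__(self):
        return self.part_len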
- All the useful Horovod MXNet-related functions are in horovod.mxnet.__init__.py and horovod.mxnet.mpi_tools.py. I found the allreduce function (from mpi_tools) super useful for calculating global/average statistics across all workers. So, assuming one calculates a validation loss per worker (after the split of the validation data), one can compute the average of these losses with:
from horovod.mxnet.mpi_tools import allreduce

# This is per worker; returns an nd.array of shape (1,).
valLoss = some_function_eval_val_loss()
# The argument of allreduce must be an nd.array; average=True returns
# the mean over all workers, and asscalar() converts it to a Python scalar.
valLoss = allreduce(valLoss, average=True).asscalar()
You can then print this average from the first worker only:
if hvd.rank() == 0:
    print("avg validation loss: {}".format(valLoss))
And if you've defined a metric that uses _BinaryClassificationMetrics for binary classification, this function comes in handy for calculating global statistics. In my case I use MCC as the metric.
# This is specific to metrics derived from the class _BinaryClassificationMetrics.
# Construct tensor (i.e. nd.array) objects on the same local context to feed into allreduce.
# Note the square brackets [ ]: these give tp.shape = (1,). If you omit them - I did - you'll get an error.
tp = nd.array([mcc._metrics.global_true_positives], ctx=mx.cpu(hvd.local_rank()))
tn = nd.array([mcc._metrics.global_true_negatives], ctx=mx.cpu(hvd.local_rank()))
fp = nd.array([mcc._metrics.global_false_positives], ctx=mx.cpu(hvd.local_rank()))
fn = nd.array([mcc._metrics.global_false_negatives], ctx=mx.cpu(hvd.local_rank()))
# Sum over all workers' true positives/negatives etc. to get global stats.
tp = allreduce(tp, average=False).asscalar()
tn = allreduce(tn, average=False).asscalar()
fp = allreduce(fp, average=False).asscalar()
fn = allreduce(fn, average=False).asscalar()
# The definitions of accuracy, mcc, precision, recall can be found on wiki (among other places).
# These are functions you need to define yourself, e.g. (the MCC function is named
# mcc_score here to avoid shadowing the mcc metric instance used below):
import numpy as np

def accuracy(tp, tn, fp, fn):
    num = tp + tn
    denom = num + fp + fn
    return num / denom

def mcc_score(tp, tn, fp, fn):
    terms = [(tp + fp), (tp + fn), (tn + fp), (tn + fn)]
    denom = np.prod(terms)
    return (tp * tn - fp * fn) / np.sqrt(denom)

def precision(tp, tn, fp, fn):
    return tp / (tp + fp)

def recall(tp, tn, fp, fn):
    return tp / (tp + fn)

# Put everything in a dict for saving history (at least, that's how I use it).
kwards = dict()
metric_name, metric_value = mcc.get() # This provides per-worker statistics.
kwards[metric_name] = metric_value
kwards['global_acc'] = accuracy(tp, tn, fp, fn)         # global accuracy
kwards['global_mcc'] = mcc_score(tp, tn, fp, fn)        # global mcc
kwards['global_precision'] = precision(tp, tn, fp, fn)  # global precision
kwards['global_recall'] = recall(tp, tn, fp, fn)        # global recall
# Now kwards holds all the statistics.
- When loading parameters from a file (for a neural network), I currently load them on all workers. I think the recommended approach is to load once (on the first worker, hvd.rank() == 0) and then broadcast, but that hasn't worked for me (probably some silly bug on my side); see the sketch below. I don't think it's a huge issue computation-wise.
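For reference, this is the pattern I believe is recommended (model.params is a placeholder filename, net is your network):
# Only the first worker reads the parameter file from disk...
if hvd.rank() == 0:
    net.load_parameters('model.params', ctx=local_context)
else:
    # ...the rest just allocate parameters so they can receive the broadcast
    # (with deferred initialization you may need a dummy forward pass first).
    net.initialize(ctx=local_context)
hvd.broadcast_parameters(net.collect_params(), root_rank=0)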
- The definition of the context for heavy computation is (assuming we want to take advantage of GPUs):
local_context = mx.gpu(hvd.local_rank()) if len(list(mx.test_utils.list_gpus())) > 0 else mx.cpu(hvd.local_rank())
For metric evaluations, or other local operations (e.g. the true positives above), it is unnecessary to copy labels (or tensors) onto the GPU; use mx.cpu(hvd.local_rank()) instead. So, in the definition of the accuracy evaluation in the official mxnet_mnist.py example:
def evaluate(model, data_iter, context):
    data_iter.reset()
    metric = mx.metric.Accuracy()
    for _, batch in enumerate(data_iter):
        data = batch.data[0].as_in_context(context)
        # label = batch.label[0].as_in_context(context) # this is an unnecessary copy
        label = batch.label[0]
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output]) # copied back to cpu() internally
    return metric.get()
You can see this in the official MXNet Accuracy implementation, in these lines:
def update(self, labels, preds):
    """Updates the internal evaluation result.

    Parameters
    ----------
    labels : list of `NDArray`
        The labels of the data with class indices as values, one per sample.
    preds : list of `NDArray`
        Prediction values for samples. Each prediction value can either be the class index,
        or a vector of likelihoods for all classes.
    """
    labels, preds = check_label_shapes(labels, preds, True)

    for label, pred_label in zip(labels, preds):
        if pred_label.shape != label.shape:
            pred_label = ndarray.argmax(pred_label, axis=self.axis)
        pred_label = pred_label.asnumpy().astype('int32')
        label = label.asnumpy().astype('int32') # <============ here ===
Hope it'll prove useful to someone out there :).
Many thanks to the community for all the help.