HDF5 vs .rec: pros and cons?

Dear all,

Any advice from users who have followed both paths? Any pitfalls I should be aware of if I go down the HDF5 path? I currently mainly work with 2D/3D numpy arrays.

Thank you for your time.

Hi @feevos,

Overall I think they’re quite similar, but it’s worth experimenting with both. And this is more of a warning to future readers, but if you’re looking to optimise training speed by optimising data loading, first make sure that data loading is actually the bottleneck! You can check this using MXNet’s Profiler.
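
For example, a rough profiling sketch (just one possible way to do it; adapt it to your own training loop) could be:

import mxnet as mx

# configure the profiler, then run a few iterations of data loading / training
mx.profiler.set_config(profile_all=True, aggregate_stats=True, filename='profile_output.json')
mx.profiler.set_state('run')

# ... run a handful of training iterations here ...

mx.profiler.set_state('stop')
print(mx.profiler.dumps())  # aggregated stats; profile_output.json can also be viewed in chrome://tracing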

The main thing to watch out for with HDF5 is parallel reading (which happens when num_workers>1). You should take a look at Parallel HDF5 for this, or try setting thread_pool=True on the DataLoader.

You should also think about chunking/partitioning for improved speed, but you’d need to change the sampling technique: shuffle all samples before saving them into chunks, then at training time sample a chunk first, read the whole chunk into memory, and finally sample from within that chunk. You can do the same thing with RecordIO by splitting into partition files.
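
To give a rough idea of the chunk-based sampling, a minimal sketch of a custom sampler could look like this (ChunkShuffleSampler is just an illustrative name; you would also want the Dataset to cache the currently active chunk in memory):

import random
from mxnet.gluon.data.sampler import Sampler


class ChunkShuffleSampler(Sampler):
    """Shuffle the order of chunks, then shuffle indices within each chunk,
    so all samples of one chunk are requested back-to-back and the Dataset
    only needs to read each chunk from disk once."""
    def __init__(self, length, chunk_size):
        self._length = length
        self._chunk_size = chunk_size

    def __iter__(self):
        starts = list(range(0, self._length, self._chunk_size))
        random.shuffle(starts)  # random chunk order
        for start in starts:
            idxs = list(range(start, min(start + self._chunk_size, self._length)))
            random.shuffle(idxs)  # random order within the chunk
            yield from idxs

    def __len__(self):
        return self._length

You would then pass sampler=ChunkShuffleSampler(len(dataset), chunk_size) to the DataLoader instead of shuffle=True.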

And compression is another area to think about. h5py seems to offer general compression techniques (e.g. gzip), but often there are specialised compression techniques depending on the data you’re working with (e.g. JPEG for images). With RecordIO you’re responsible for compressing each sample before packing it. With h5py you can apply the general compression easily, but it’s a little more work to customise compression the way you would with RecordIO.
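
For RecordIO that can be as simple as JPEG-encoding each image while packing. A minimal sketch (with random stand-in images and made-up file names) might be:

import numpy as np
import mxnet as mx

# pack jpeg-compressed images into an indexed RecordIO file (pack_img needs OpenCV installed)
record = mx.recordio.MXIndexedRecordIO('data.idx', 'data.rec', 'w')
for i in range(100):
    img = np.random.randint(0, 256, size=(512, 512, 3), dtype=np.uint8)  # stand-in HWC image
    header = mx.recordio.IRHeader(flag=0, label=float(i % 10), id=i, id2=0)
    packed = mx.recordio.pack_img(header, img, quality=90, img_fmt='.jpg')  # jpeg compression happens here
    record.write_idx(i, packed)
record.close()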


And if you’re interested in trying out HDF5 with MXNet, there is already support for it in NDArrayIter (you can pass in a h5py.Dataset; there’s a quick example of that further down). But because I know you’re a fan of Gluon, I’ve just written up a simple Gluon Dataset for testing.

import h5py
import mxnet as mx


class HDF5Dataset(mx.gluon.data.Dataset):
    def __init__(self, filepath, datasets):
        """
        A Gluon Dataset for data stored in a HDF5 file. Can use multiple datasets of equal length.
        :param filepath: path to the .hdf5 file
        :param datasets: list of strings denoting the datasets to use. Order determines the __getitem__ return order.
        """
        self._file = h5py.File(filepath, 'r')
        self._datasets = [self._file[dataset] for dataset in datasets]
        assert len(set(len(dataset) for dataset in self._datasets)) == 1, "Check datasets are of equal length."

    def __getitem__(self, idx):
        return tuple(dataset[idx] for dataset in self._datasets)

    def __len__(self):
        return len(self._datasets[0])

    def close(self):
        self._file.close()

And usage would be something like this…

import mxnet as mx


dataset = HDF5Dataset('mytestfile.hdf5', datasets=['data', 'labels'])
dataloader = mx.gluon.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
for idx, (data, label) in enumerate(dataloader):
    print("Batch {}".format(idx))

Where I have created the HDF5 file like so…

import numpy as np


samples = 100
data = np.random.randint(low=0, high=256, size=(samples, 3, 512, 512))
labels = np.random.randint(low=0, high=10, size=(samples,))
with h5py.File("mytestfile.hdf5", "w") as f:
    f.create_dataset("data", data=data) #, compression="gzip")
    f.create_dataset("labels", data=labels)

Hi @thomelane, thank you very much for your detailed answer!!! I’ve learned a lot from this. Apologies for my late reply (too many things happening, and I wanted to finish my final implementation before getting back to you).

I am mainly interested in the HDF5 format (an excellent tutorial can be found here, for anyone who is interested) for bookkeeping purposes (plus there are some limitations on the number of data files we can store in our in-house HPC facility). In my case I have semantic segmentation problems to tackle, so the data are sets of (imgs, masks), and I don’t think I can use NDArrayIter as it is (due to custom data augmentations). My initial data are large raster files (GeoTIFF) that I need to slice to feed into my neural network, so I extract values and store them as float32 numpy arrays.
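
To make the slicing step concrete, a rough sketch (assuming rasterio and a hypothetical input.tif; the real pipeline is more involved) would be:

import numpy as np
import rasterio
from rasterio.windows import Window

F = 256  # patch (window) size
patches = []
with rasterio.open('input.tif') as src:
    # walk over the raster in non-overlapping F x F windows
    for row in range(0, src.height - F + 1, F):
        for col in range(0, src.width - F + 1, F):
            patch = src.read(window=Window(col, row, F, F)).astype(np.float32)  # shape (bands, F, F)
            patches.append(patch)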

To add to your comments, in my experiments if I do not use thread_pool = True in the gluon.data.DataLoader then I get the following error:

---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/data/dia021/Software/anaconda3/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/data/dia021/Software/mxnet/gluon/data/dataloader.py", line 400, in _worker_fn
    batch = batchify_fn([_worker_dataset[i] for i in samples])
  File "/data/dia021/Software/mxnet/gluon/data/dataloader.py", line 400, in <listcomp>
    batch = batchify_fn([_worker_dataset[i] for i in samples])
  File "/home/dia021/Projects/WA_l8SAR_2018/src/WALS8SARDataset.py", line 47, in __getitem__
    mask = self.masks[self.idx_start+idx,:,:,:]
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "/data/dia021/Software/anaconda3/lib/python3.6/site-packages/h5py/_hl/dataset.py", line 496, in __getitem__
    self.id.read(mspace, fspace, arr, mtype, dxpl=self._dxpl)
  File "h5py/_objects.pyx", line 54, in h5py._objects.with_phil.wrapper
  File "h5py/_objects.pyx", line 55, in h5py._objects.with_phil.wrapper
  File "h5py/h5d.pyx", line 181, in h5py.h5d.DatasetID.read
  File "h5py/_proxy.pyx", line 130, in h5py._proxy.dset_rw
  File "h5py/_proxy.pyx", line 84, in h5py._proxy.H5PY_H5Dread
OSError: Can't read data (wrong B-tree signature)
"""

The above exception was the direct cause of the following exception:

OSError                                   Traceback (most recent call last)
<ipython-input-9-67c2805630ec> in <module>
----> 1 for img,label in datagen:
      2     break
      3 

/data/dia021/Software/mxnet/gluon/data/dataloader.py in __next__(self)
    448         assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
    449         ret = self._data_buffer.pop(self._rcvd_idx)
--> 450         batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
    451         if self._pin_memory:
    452             batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))

/data/dia021/Software/anaconda3/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
    668             return self._value
    669         else:
--> 670             raise self._value
    671 
    672     def _set(self, i, obj):

OSError: Can't read data (wrong B-tree signature)

In addition, I’ve found that thread_pool is slower than the default (which, I think, uses multiprocessing). I have followed the workaround from PyTorch (linked here), and it seemed to give the best performance for data loading (i.e. the best performance using the HDF5 format; I haven’t tested against reading independent *.npy files).
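
For readers who do not want to follow the link, the essence of that workaround is to avoid opening the HDF5 file in __init__ (whose state gets copied into the workers) and to open it lazily inside each worker instead, roughly like this (a minimal sketch assuming a single 'data' dataset):

import h5py
from mxnet.gluon.data import dataset

class LazyHDF5Dataset(dataset.Dataset):
    def __init__(self, filepath):
        self._filepath = filepath
        self._file = None
        with h5py.File(filepath, 'r') as f:  # open briefly just to get the length
            self._length = f['data'].shape[0]

    def __getitem__(self, idx):
        if self._file is None:  # first access inside this worker: open a private file handle
            self._file = h5py.File(self._filepath, 'r')
        return self._file['data'][idx]

    def __len__(self):
        return self._length

My dataset class below uses exactly this trick.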

But first, this is how I create my hdf5 file, storing in chunks of 1 so as to allow the dataloader to randomize (I am pretty sure this is inefficient, so I’ll test your approach of reading in chunks later on):

# Window size F
F = 256
teye = np.eye(len(dictType), dtype=np.uint8)

# Open h5 file to store data
f = h5py.File(r'WA_L8_imgs_masks.hdf5', "w")
f.create_dataset(name="imgs", shape=(1, 41, 256, 256), dtype='float32',
                 maxshape=(None, 41, 256, 256), chunks=(1, 41, 256, 256))
f.create_dataset(name="masks", shape=(1, 43, 256, 256), dtype='float32',
                 maxshape=(None, 43, 256, 256), chunks=(1, 43, 256, 256))
# Do stuff and close file ....

Also, I do not know the total number of samples from the beginning, so I just increase the dataset sizes on the fly:

counter = 0
for name in raster_filenames:
    # read raster file, slice etc.
    # out_image_patch is my image, tlabels_all my masks

    f["imgs"][counter] = out_image_patch
    f["masks"][counter] = tlabels_all
    f["imgs"].resize(f["imgs"].shape[0] + 1, axis=0)  # increase dimension on the fly
    f["masks"].resize(f["masks"].shape[0] + 1, axis=0)
    counter += 1

# fix final dimension (the last resize added one row too many)
f["imgs"].resize(f["imgs"].shape[0] - 1, axis=0)
f["masks"].resize(f["masks"].shape[0] - 1, axis=0)
f.close()  # close the file

And this is a snippet of my dataset class, for a slightly more complicated example where I have an image with many channels (it includes optical and SAR data):

import numpy as np
from mxnet.gluon.data import dataset
import h5py

class WALS8SARDataset(dataset.Dataset):
    """
    SAR images contain 5 channels: [vh, vv, entropy, anisotropy, alpha] and have 7 observations (in time).  
    """


    def __init__(self, h5file , mode='train', twocomp = False, color=False, transform=None, norm=None):

        self._h5file = h5file
        self._mode = mode
        self._twocomp = twocomp # If true, split into LS8+SAR in a list
        self._transform = transform # Transformation for the augmented data
        self._norm = norm # Normalization of img
        self.color = color
        if color:
            self.colornorm = np.array([1./179, 1./255, 1./255])

        # Open briefly just to read the dataset size; the lazy open happens in __getitem__.
        with h5py.File(h5file, "r") as f:
            self.ntotal = f['imgs'].shape[0]
        self.ntrain = int(0.8 * self.ntotal)
        self.nval = int(0.1 * self.ntotal)

        self.f = None  # opened lazily in __getitem__, i.e. once per worker


        if self._mode == 'train':
            self.idx_start = 0
        elif self._mode == 'val':
            self.idx_start = self.ntrain
        elif self._mode == 'test':
            self.idx_start = self.ntrain + self.nval


    def __getitem__(self, idx):
        # According to what I've been reading, this opens the file once per worker.
        if self.f is None:
            self.f = h5py.File(self._h5file,"r")
            self.imgs = self.f['imgs']
            self.masks = self.f['masks']

        # load in float32 - they are stored in float32 by default. 
        base = self.imgs[self.idx_start+idx,:,:,:]
        mask = self.masks[self.idx_start+idx,:,:,:]


        # TODO Add color reconstruction
        if self.color:
            raise NotImplementedError("Currently not implemented")
            # once implemented, the HSV channels would be concatenated onto the mask:
            # mask = np.concatenate([mask, base_hsv], axis=0)

        if self._transform is not None:
            base, mask = self._transform(base, mask)
            if self._norm is not None:
                base = self._norm(base.astype(np.float32))

            if self._twocomp:
                base_ls8 = base[:6,:,:].astype(np.float32)
                base_sar = base[6:,:,:].astype(np.float32)
                # 35 SAR channels -> (5 channels, 7 time observations, H, W)
                base_sar = base_sar.reshape(5, 7, base_sar.shape[1], base_sar.shape[2])

                
                return base_ls8, base_sar, mask.astype(np.float32)

            return base.astype(np.float32), mask.astype(np.float32)

        else:
            if self._norm is not None:
                base = self._norm(base.astype(np.float32))

            if self._twocomp:
                base_ls8 = base[:6,:,:].astype(np.float32)
                base_sar = base[6:,:,:].astype(np.float32)
                # 35 SAR channels -> (5 channels, 7 time observations, H, W)
                base_sar = base_sar.reshape(5, 7, base_sar.shape[1], base_sar.shape[2])

                

                return base_ls8, base_sar, mask.astype(np.float32)

            return base.astype(np.float32), mask.astype(np.float32)

    def __len__(self):

        if self._mode == 'train':
            return self.ntrain
        elif self._mode == 'val':
            return self.nval
        elif self._mode == 'test':
            return self.ntotal - self.ntrain - self.nval

    def __del__(self):
        if self.f is not None:
            self.f.close()

In my tests, if I open the hdf5 file inside the __init__ function and use thread_pool, the timing is:

In [8]: datagen = gluon.data.DataLoader(dataset_train,
   ...:                                 batch_size = 4,
   ...:                                 shuffle = True,
   ...:                                 thread_pool = True,
   ...:                                 last_batch = 'discard',
   ...:                                 num_workers = 4)

In [9]: %timeit for img,label in datagen:break
4.44 s ± 631 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

With multiprocessing (thread_pool=False) and the trick described above, I get:

In [11]: datagen = gluon.data.DataLoader(dataset_train,
    ...:                                 batch_size = 4,
    ...:                                 shuffle = True,
    ...:                                 thread_pool = False,
    ...:                                 last_batch = 'discard',
    ...:                                 num_workers = 4)

In [15]: %timeit for img,label in datagen:break
2.64 s ± 357 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The times may be better on other hardware (I performed these tests on an interactive HPC node where I am competing for resources), but they seem pretty consistent.

Again, thank you very much for your reply. I hope my comments may also be of benefit to someone.

I also need to mention Yannik Rist (an MXNet fan and collaborator, not on the forum yet), who introduced me to HDF5 files.

All the best,
Foivos