How to feed 6-channel images to the network

I have a multi-task network with two softmax classifiers at the end. I want to feed it 6-channel images (basically the image plus the segmentation mask). I am able to train my network on RGB images using the following data iterator and model.fit:

What do I need to change so that I can feed both images to the network?

```python
import mxnet as mx


class MultitaskIterator(mx.io.DataIter):
    """Multi-task iterator: splits a width-2 label into one label per softmax head."""

    def __init__(self, data_iter):
        super(MultitaskIterator, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        # provide_label = self.data_iter.provide_label[0]
        # Different labels should be used here for actual application
        return [('softmax1_label', (self.batch_size,)),
                ('softmax2_label', (self.batch_size,))]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        label = batch.label[0]               # shape (batch_size, 2)
        label1, label2 = label.T.asnumpy()   # one row of labels per task
        label1 = mx.nd.array(label1)
        label2 = mx.nd.array(label2)
        return mx.io.DataBatch(data=batch.data, label=[label1, label2],
                               pad=batch.pad, index=batch.index)


# train_list, train_rec, data_shape, batch_size, val, model, epoch, num_epochs,
# MultitaskAccuracy, optimizer_params, new_args, initializer and checkpoint
# are defined elsewhere in the script.
train = mx.io.ImageRecordIter(
    path_imglist=train_list,  # you have to specify path_imglist when label_width is larger than 2
    path_imgrec=train_rec,
    # mean_img=train_mean,
    data_shape=data_shape,
    batch_size=batch_size,
    rand_crop=True,
    rand_mirror=True,
    shuffle=True,
    label_width=2,  # specify label_width = 2 here
)
# Wrap the record iterator so each softmax head receives its own label
train = MultitaskIterator(train)

model.fit(train,
          begin_epoch=epoch,
          num_epoch=num_epochs,
          eval_data=val,
          eval_metric=MultitaskAccuracy(num=2),
          optimizer='sgd',
          optimizer_params=optimizer_params,
          arg_params=new_args,
          initializer=initializer,
          allow_missing=True,
          batch_end_callback=mx.callback.Speedometer(batch_size, 50),
          epoch_end_callback=checkpoint)
```
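If you stay with the Module API, the change to the iterator comes down to two things: `next()` has to return one data array with 6 channels instead of 3, and `provide_data` has to advertise that wider shape. The sketch below shows both steps with NumPy so it runs on its own. It assumes the mask is stored as a 3-channel NCHW batch with the same spatial size as the image; the function names are hypothetical. Inside a real iterator you would do the same concatenation on NDArrays, e.g. `mx.nd.concat(image_batch.data[0], mask_batch.data[0], dim=1)`, where `mask_batch` comes from a second (hypothetical) mask iterator.

```python
import numpy as np


def stack_image_and_mask(image_batch, mask_batch):
    """Concatenate two NCHW batches along the channel axis (axis 1)."""
    if (image_batch.shape[0] != mask_batch.shape[0]
            or image_batch.shape[2:] != mask_batch.shape[2:]):
        raise ValueError("image and mask batches must share N, H and W")
    return np.concatenate([image_batch, mask_batch], axis=1)


def widen_provide_data(provide_data, extra_channels=3):
    """Return a provide_data list with the channel dimension widened.

    Assumes a single NCHW data input described as (name, shape).
    """
    (name, (n, c, h, w)), = provide_data
    return [(name, (n, c + extra_channels, h, w))]


# Example: batch_size=4, 64x64 inputs, 3-channel image + 3-channel mask
images = np.random.rand(4, 3, 64, 64).astype(np.float32)
masks = np.random.rand(4, 3, 64, 64).astype(np.float32)
batch = stack_image_and_mask(images, masks)
print(batch.shape)  # (4, 6, 64, 64)
print(widen_provide_data([('data', (4, 3, 64, 64))]))  # [('data', (4, 6, 64, 64))]
```

Two caveats, stated as likely rather than certain: with `rand_crop=True` and `rand_mirror=True` on two independent record iterators, the image and mask would probably not be cropped or flipped identically, so either disable those options or augment both together. Also, pretrained weights for the first convolution were learned for 3 input channels and would not fit a 6-channel input, so that layer probably needs to be removed from `new_args` and initialized fresh.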

What you're trying to achieve is much simpler with the Gluon API, and as long as your network is hybridizable you can train it in Gluon and still deploy it with the Module API. Personally, I can't think of a reason to stay in the Module API world. If you believe you have a valid reason not to switch, please let me know (there may be a solution you're not aware of).
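In Gluon, feeding 6 channels is mostly a data-side change: `gluon.data.DataLoader` batches any object that implements `__getitem__` and `__len__`, and `gluon.nn.Conv2D` infers its input channels on the first forward pass when `in_channels` is left at its default. A minimal sketch of such a dataset is below. `ImageMaskDataset` is a hypothetical name, and it assumes the images and masks are already loaded in memory as HWC arrays with integer labels for the two heads. It also shows one way to keep augmentation consistent by drawing a random horizontal flip once and applying it to both the image and the mask before stacking them.

```python
import numpy as np


class ImageMaskDataset:
    """Gluon-style dataset (hypothetical) yielding 6-channel CHW samples.

    images, masks: arrays of shape (num_samples, H, W, 3)
    labels1, labels2: arrays of shape (num_samples,) for the two softmax heads
    """

    def __init__(self, images, masks, labels1, labels2, mirror=True, seed=None):
        if images.shape != masks.shape:
            raise ValueError("images and masks must have the same shape")
        self.images, self.masks = images, masks
        self.labels1, self.labels2 = labels1, labels2
        self.mirror = mirror
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, mask = self.images[idx], self.masks[idx]
        # Draw the flip once and apply it to both so image and mask stay aligned
        if self.mirror and self.rng.random() < 0.5:
            image, mask = image[:, ::-1], mask[:, ::-1]
        # Stack along channels (HWC -> HW6), then move channels first (6HW)
        sample = np.concatenate([image, mask], axis=-1).transpose(2, 0, 1)
        return sample.astype(np.float32), self.labels1[idx], self.labels2[idx]


# Example with 10 random 32x32 samples
images = np.random.rand(10, 32, 32, 3)
masks = np.random.rand(10, 32, 32, 3)
ds = ImageMaskDataset(images, masks, np.arange(10) % 2, np.arange(10) % 5, seed=0)
x, y1, y2 = ds[0]
print(x.shape)  # (6, 32, 32)
```

You could then wrap it with `gluon.data.DataLoader(ds, batch_size=batch_size, shuffle=True)`, which yields `(data, label1, label2)` batches with `data` of shape `(batch_size, 6, H, W)`. A hybridized network with two output heads can be trained by summing two `gluon.loss.SoftmaxCrossEntropyLoss` terms, and `HybridBlock.export` writes the symbol and parameter files that the Module API can load for deployment.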
