Object detection: fine-tuning Faster-RCNN models

Hi there,
I was trying to fine-tune a Faster-RCNN model on my custom dataset, following the corresponding tutorial.
As mentioned at the end of that tutorial, it is meant for SSD models, so I tried to adapt it by pulling in the Faster-RCNN blocks from the train_faster_rcnn.py script.

The main difference from train_faster_rcnn.py is that I need to fine-tune on my own data, so I changed the dataset-loading function to read my own .rec files instead of downloading COCO, VOC, or similar.
I hardcoded the variables that are normally passed as command-line arguments and passed them to the training function. For the rest, I kept the other code blocks from the original script.
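
To make sure the records themselves load as expected, they can be inspected with a quick check like this (a minimal sketch; the file name is just the f-string from the full script below, expanded):

import gluoncv as gcv

# each sample is (image, label); label columns 0-3 are the box corners, column 4 the class id
dataset = gcv.data.RecordFileDetection('custom_dataset/train_train_RCNN.rec', coord_normalized=True)
print(len(dataset))
img, label = dataset[0]
print(img.shape, label.shape)
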
This is what I have now:

import time
import os
import logging
import mxnet as mx
from mxnet import autograd, gluon
import gluoncv as gcv
from mxboard import SummaryWriter
from gluoncv.data.batchify import FasterRCNNTrainBatchify, Tuple, Append
from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultTrainTransform, \
	FasterRCNNDefaultValTransform
from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.parallel import Parallelizable, Parallel
from gluoncv.utils.metrics.rcnn import RPNAccMetric, RPNL1LossMetric, RCNNAccMetric, \
	RCNNL1LossMetric

def main():
	ctx = [mx.gpu(0)]

	# network
	## whether to use a feature pyramid network
	## (kwargs/module_list from the original script are not needed, since the net is built in __main__)
	use_fpn = False

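	# initialize only the parameters that do not have data yet, then move everything to the GPU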
	for param in net.collect_params().values():
		if param._data is not None:
			continue
		param.initialize()
	net.collect_params().reset_ctx(ctx)

	# output log file
	log_file = open(f'{saved_weights_path}{project_name}_{model_name}_log_file.txt', 'w')
	log_file.write("Epoch".rjust(8))
	for class_name in classes:
		log_file.write(f"{class_name:>15}")
	log_file.write("Total".rjust(15))
	log_file.write("\n")
	# summary file for tensorboard
	sw = SummaryWriter(logdir=saved_weights_path+'logs/', flush_secs=30)

	# prepare data
	data_shape = 512  # only needed by the (commented-out) COCO metric below
	train_dataset = gcv.data.RecordFileDetection(f'custom_dataset/train_{project_name}.rec', coord_normalized=True)
	val_dataset = gcv.data.RecordFileDetection(f'custom_dataset/test_{project_name}.rec', coord_normalized=True)
	eval_metric = VOC07MApMetric(iou_thresh=0.5, class_names=classes)
	# COCO metrics seem to work only on the COCO dataset itself, while the custom dataset is a RecordFileDetection!
	# eval_metric = COCODetectionMetric(val_dataset, '_eval', data_shape=(data_shape, data_shape))

	# create train and validation dataloaders from the datasets
	train_data, val_data = get_dataloader(net, train_dataset, val_dataset,
										  FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform,
										  batch_size, len(ctx), use_fpn, num_workers=0)
	print(f"train dataloader -> {len(train_data)}")
	print(f"test dataloader -> {len(val_data)}")

	# training
	train(net, model_name, train_data, val_data, eval_metric, batch_size, ctx,
		  lr=0.001, wd=0.0005, momentum=0.9, lr_decay=0.1, lr_decay_epoch='',
		  lr_warmup=1000, lr_warmup_factor=1. / 3., start_epoch=0, epochs=100,
		  log_interval=100, val_interval=1)



def get_dataloader(net, train_dataset, val_dataset, train_transform, val_transform, batch_size,
				   num_shards, use_fpn, num_workers):
	"""Get dataloader."""
	train_bfn = FasterRCNNTrainBatchify(net, num_shards)
	if hasattr(train_dataset, 'get_im_aspect_ratio'):
		im_aspect_ratio = train_dataset.get_im_aspect_ratio()
	else:
		im_aspect_ratio = [1.] * len(train_dataset)
	train_sampler = \
		gcv.nn.sampler.SplitSortedBucketSampler(im_aspect_ratio, batch_size,
												num_parts=1, part_index=0,
												shuffle=True)
	train_loader = mx.gluon.data.DataLoader(train_dataset.transform(
		train_transform(net.short, net.max_size, net, ashape=net.ashape, multi_stage=use_fpn)),
		batch_sampler=train_sampler, batchify_fn=train_bfn, num_workers=num_workers)
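	# validation batchify: keep (image, label, im_scale) as lists of per-sample arrays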
	val_bfn = Tuple(*[Append() for _ in range(3)])
	short = net.short[-1] if isinstance(net.short, (tuple, list)) else net.short
	# validation uses 1 sample per device
	val_loader = mx.gluon.data.DataLoader(
		val_dataset.transform(val_transform(short, net.max_size)), num_shards, False,
		batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
	return train_loader, val_loader


class ForwardBackwardTask(Parallelizable):
	def __init__(self, net, optimizer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss,
				 mix_ratio):
		super(ForwardBackwardTask, self).__init__()
		self.net = net
		self._optimizer = optimizer
		self.rpn_cls_loss = rpn_cls_loss
		self.rpn_box_loss = rpn_box_loss
		self.rcnn_cls_loss = rcnn_cls_loss
		self.rcnn_box_loss = rcnn_box_loss
		self.mix_ratio = mix_ratio

	def forward_backward(self, x):
		data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks = x
		with autograd.record():
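			# label layout: columns 0-3 are the box corners, column 4 is the class id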
			gt_label = label[:, :, 4:5]
			gt_box = label[:, :, :4]
			cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors, cls_targets, \
			box_targets, box_masks, _ = self.net(data, gt_box, gt_label)
			# losses of rpn
			rpn_score = rpn_score.squeeze(axis=-1)
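			# RPN cls targets are 1 (positive), 0 (negative) and -1 (ignore), so >= 0 selects the sampled anchors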
			num_rpn_pos = (rpn_cls_targets >= 0).sum()
			rpn_loss1 = self.rpn_cls_loss(rpn_score, rpn_cls_targets,
										  rpn_cls_targets >= 0) * rpn_cls_targets.size / num_rpn_pos
			rpn_loss2 = self.rpn_box_loss(rpn_box, rpn_box_targets,
										  rpn_box_masks) * rpn_box.size / num_rpn_pos
			# rpn overall loss, use sum rather than average
			rpn_loss = rpn_loss1 + rpn_loss2
			# losses of rcnn
			num_rcnn_pos = (cls_targets >= 0).sum()
			rcnn_loss1 = self.rcnn_cls_loss(cls_pred, cls_targets,
											cls_targets.expand_dims(-1) >= 0) * cls_targets.size / \
						 num_rcnn_pos
			rcnn_loss2 = self.rcnn_box_loss(box_pred, box_targets, box_masks) * box_pred.size / \
						 num_rcnn_pos
			rcnn_loss = rcnn_loss1 + rcnn_loss2
			# overall losses
			total_loss = rpn_loss.sum() * self.mix_ratio + rcnn_loss.sum() * self.mix_ratio

			rpn_loss1_metric = rpn_loss1.mean() * self.mix_ratio
			rpn_loss2_metric = rpn_loss2.mean() * self.mix_ratio
			rcnn_loss1_metric = rcnn_loss1.mean() * self.mix_ratio
			rcnn_loss2_metric = rcnn_loss2.mean() * self.mix_ratio
			rpn_acc_metric = [[rpn_cls_targets, rpn_cls_targets >= 0], [rpn_score]]
			rpn_l1_loss_metric = [[rpn_box_targets, rpn_box_masks], [rpn_box]]
			rcnn_acc_metric = [[cls_targets], [cls_pred]]
			rcnn_l1_loss_metric = [[box_targets, box_masks], [box_pred]]

			total_loss.backward()

		return rpn_loss1_metric, rpn_loss2_metric, rcnn_loss1_metric, rcnn_loss2_metric, \
			   rpn_acc_metric, rpn_l1_loss_metric, rcnn_acc_metric, rcnn_l1_loss_metric



def train(net, model_name, train_data, val_data, eval_metric, batch_size, ctx, lr, wd,
		  momentum, lr_decay, lr_decay_epoch, lr_warmup, lr_warmup_factor, start_epoch,
		  epochs, log_interval, val_interval):
	"""Training pipeline"""
	kv_store = 'local'
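	# freeze all parameters, then re-enable gradients only for the trainable subset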
	net.collect_params().setattr('grad_req', 'null')
	net.collect_train_params().setattr('grad_req', 'write')
	optimizer_params = {'learning_rate': lr, 'wd': wd, 'momentum': momentum}
	trainer = gluon.Trainer(
		net.collect_train_params(),  # fix batchnorm, fix first stage, etc...
		'sgd',
		optimizer_params,
		update_on_kvstore=None, kvstore=kv_store)


	# lr decay policy
	lr_decay = float(lr_decay)
	lr_steps = sorted([float(ls) for ls in lr_decay_epoch.split(',') if ls.strip()])
	lr_warmup = float(lr_warmup)  # avoid int division

	# TODO(zhreshold) losses?
	rpn_cls_loss = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
	rpn_box_loss = mx.gluon.loss.HuberLoss(rho=1 / 9.)  # == smoothl1
	rcnn_cls_loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
	rcnn_box_loss = mx.gluon.loss.HuberLoss()  # == smoothl1
	metrics = [mx.metric.Loss('RPN_Conf'),
			   mx.metric.Loss('RPN_SmoothL1'),
			   mx.metric.Loss('RCNN_CrossEntropy'),
			   mx.metric.Loss('RCNN_SmoothL1'), ]

	rpn_acc_metric = RPNAccMetric()
	rpn_bbox_metric = RPNL1LossMetric()
	rcnn_acc_metric = RCNNAccMetric()
	rcnn_bbox_metric = RCNNL1LossMetric()
	metrics2 = [rpn_acc_metric, rpn_bbox_metric, rcnn_acc_metric, rcnn_bbox_metric]

	# set up logger
	logging.basicConfig()
	logger = logging.getLogger()
	logger.setLevel(logging.INFO)
	log_file_path = model_name + '_train.log'
	log_dir = os.path.dirname(log_file_path)
	if log_dir and not os.path.exists(log_dir):
		os.makedirs(log_dir)
	fh = logging.FileHandler(log_file_path)
	logger.addHandler(fh)
	logger.info('Start training from [Epoch {}]'.format(start_epoch))
	best_map = [0]
	for epoch in range(start_epoch, epochs):
		mix_ratio = 1.0
		rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss,
										rcnn_box_loss, mix_ratio=1.0)
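		# run forward/backward in a single executor thread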
		executor = Parallel(1, rcnn_task)
		while lr_steps and epoch >= lr_steps[0]:
			new_lr = trainer.learning_rate * lr_decay
			lr_steps.pop(0)
			trainer.set_learning_rate(new_lr)
			logger.info("[Epoch {}] Set learning rate to {}".format(epoch, new_lr))
		for metric in metrics:
			metric.reset()
		tic = time.time()
		btic = time.time()
		base_lr = trainer.learning_rate
		rcnn_task.mix_ratio = mix_ratio
		for i, batch in enumerate(train_data):
			if epoch == 0 and i <= lr_warmup:
				# adjust based on real percentage
				new_lr = base_lr * get_lr_at_iter(i / lr_warmup, lr_warmup_factor)
				if new_lr != trainer.learning_rate:
					if i % log_interval == 0:
						logger.info(
							'[Epoch 0 Iteration {}] Set learning rate to {}'.format(i, new_lr))
					trainer.set_learning_rate(new_lr)
			batch = split_and_load(batch, ctx_list=ctx)
			metric_losses = [[] for _ in metrics]
			add_losses = [[] for _ in metrics2]
			if executor is not None:
				for data in zip(*batch):
					executor.put(data)
			for j in range(len(ctx)):
				if executor is not None:
					result = executor.get()
				else:
					result = rcnn_task.forward_backward(list(zip(*batch))[0])
				for k in range(len(metric_losses)):
					metric_losses[k].append(result[k])
				for k in range(len(add_losses)):
					add_losses[k].append(result[len(metric_losses) + k])
			for metric, record in zip(metrics, metric_losses):
				metric.update(0, record)
			for metric, records in zip(metrics2, add_losses):
				for pred in records:
					metric.update(pred[0], pred[1])
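			# step() rescales the accumulated gradients by 1/batch_size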
			trainer.step(batch_size)

			# update metrics
			if log_interval and not (i + 1) % log_interval:
				msg = ','.join(
					['{}={:.3f}'.format(*metric.get()) for metric in metrics + metrics2])
				logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
					epoch, i, log_interval * batch_size / (time.time() - btic), msg))
				btic = time.time()

		msg = ','.join(['{}={:.3f}'.format(*metric.get()) for metric in metrics])
		logger.info('[Epoch {}] Training cost: {:.3f}, {}'.format(
			epoch, (time.time() - tic), msg))
		if not (epoch + 1) % val_interval:
			# consider reduce the frequency of validation to save time
			map_name, mean_ap = validate(net, val_data, ctx, eval_metric)
			val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
			logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
			current_map = float(mean_ap[-1])
		else:
			current_map = 0.
		save_params(net, logger, best_map, current_map, epoch, 1,
					model_name)


def save_params(net, logger, best_map, current_map, epoch, save_interval, prefix):
	current_map = float(current_map)
	if current_map > best_map[0]:
		logger.info('[Epoch {}] mAP {} higher than current best {}, saving to {}'.format(
			epoch, current_map, best_map, '{:s}_best.params'.format(prefix)))
		best_map[0] = current_map
		net.save_parameters('{:s}_best.params'.format(prefix))
		with open(prefix + '_best_map.log', 'a') as f:
			f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
	if save_interval and (epoch + 1) % save_interval == 0:
		logger.info('[Epoch {}] Saving parameters to {}'.format(
			epoch, '{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map)))
		net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))


def split_and_load(batch, ctx_list):
	"""Split data to 1 batch each device."""
	new_batch = []
	for i, data in enumerate(batch):
		if isinstance(data, (list, tuple)):
			new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
		else:
			new_data = [data.as_in_context(ctx_list[0])]
		new_batch.append(new_data)
	return new_batch


def validate(net, val_data, ctx, eval_metric):
	"""Test on validation dataset."""
	clipper = gcv.nn.bbox.BBoxClipToImage()
	eval_metric.reset()
	net.hybridize(static_alloc=False)
	for batch in val_data:
		batch = split_and_load(batch, ctx_list=ctx)
		det_bboxes = []
		det_ids = []
		det_scores = []
		gt_bboxes = []
		gt_ids = []
		gt_difficults = []
		for x, y, im_scale in zip(*batch):
			# get prediction results
			ids, scores, bboxes = net(x)
			det_ids.append(ids)
			det_scores.append(scores)
			# clip to image size
			det_bboxes.append(clipper(bboxes, x))
			# rescale to original resolution
			im_scale = im_scale.reshape((-1)).asscalar()
			det_bboxes[-1] *= im_scale
			# split ground truths
			gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
			gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
			gt_bboxes[-1] *= im_scale
			gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None)

		# update metric
		for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(det_bboxes, det_ids,
																		det_scores, gt_bboxes,
																		gt_ids, gt_difficults):
			eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff)
	return eval_metric.get()


def get_lr_at_iter(alpha, lr_warmup_factor=1. / 3.):
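	# linearly interpolate the lr multiplier from lr_warmup_factor up to 1 as alpha goes from 0 to 1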
	return lr_warmup_factor * (1 - alpha) + alpha





if __name__ == "__main__":
	# prepare model
	model_name = "faster_rcnn_resnet50_v1b_coco"
	## this will be used to automatically determine input and output file names
	project_name = "train_RCNN"
	classes = ['ball', 'bb_ball', 'drum', 'guitar', 'koshi_bell', 'massager', 'ring', 'snake', 'tinsel']
	batch_size = 8
	# load a pre-trained model, then reset the network to predict the new classes
	net = gcv.model_zoo.get_model(model_name, pretrained=True)
	# net = gcv.model_zoo.get_model(model_name, classes=classes, pretrained=False, transfer='coco')
	net.reset_class(classes)
	# folder where trained model will be saved
	saved_weights_path = f"saved_weights/{project_name}_{model_name}/"
	if not os.path.exists(saved_weights_path):
		os.makedirs(saved_weights_path)

	main()

The problem is that when I try to run it, I get the following error:

mxnet.base.MXNetError: MXNetError: Shape inconsistent, Provided = [1,128], inferred shape=[8,128]

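To see where the [1,128] vs. [8,128] mismatch might come from, the shapes inside a single training batch can be printed (a minimal sketch; it assumes the train_data loader built by get_dataloader above):

batch = next(iter(train_data))
print(len(batch))  # 5 parts: data, label, rpn_cls_targets, rpn_box_targets, rpn_box_masks
for i, part in enumerate(batch):
	# each part is a list with one entry per device (shard)
	print(i, [getattr(x, 'shape', type(x)) for x in part])
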
So it seems that I'm loading the data incorrectly. Any suggestions?