I am migrating from TensorFlow, so I am not sure whether I am doing this properly. Consider the following code snippet:
cross_entropy_cost = mx.symbol.SoftmaxOutput(data=self.logits, label=label, name=name + 'softmax')
# softmax_label
if self.distillation > 0:
    """distill_loss = self.distillation * rmse(
        mx.symbol.SoftmaxActivation(self.logits / self.temperature),
        mx.symbol.SoftmaxActivation(self.mentor_logits / self.temperature))"""
    distill_loss = mx.symbol.SoftmaxOutput(data=self.logits / self.temperature,
                                           label=mx.symbol.SoftmaxActivation(self.mentor_logits / self.temperature),
                                           name=name + '_distillation')
    loss = mx.symbol.Group([cross_entropy_cost, mx.symbol.MakeLoss(distill_loss)])
else:
    loss = cross_entropy_cost
return loss
If self.distillation is 0, the code works. With distillation enabled, neither the commented-out version nor the active one works, and both produce the same error, which is as follows:
2018-01-30 01:53:14 Traceback printing
Traceback (most recent call last):
  File "start.py", line 144, in main
    trainer_student.train()
  File "----", line 14, in train
    return self.fitter.fit(args, sym, args_params=args_params, aux_params=aux_params)
  File "/home/---/fit_adapter.py", line 20, in fit
    f.fit(args, sym, data.get_rec_iter, arg_params=args_params, aux_params=aux_params)
  File "/home/---/fit.py", line 320, in fit
    monitor = monitor)
  File "/home/---/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/module/base_module.py", line 496, in fit
    self.update_metric(eval_metric, data_batch.label)
  File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/module/module.py", line 748, in update_metric
    self._exec_group.update_metric(eval_metric, labels)
  File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/module/executor_group.py", line 588, in update_metric
    eval_metric.update_dict(labels_, preds)
  File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 280, in update_dict
    metric.update_dict(labels, preds)
  File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 108, in update_dict
    self.update(label, pred)
  File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 388, in update
    check_label_shapes(labels, preds)
  File "/home/ec2-user/anaconda3/envs/mxnet_p27/lib/python2.7/site-packages/mxnet/metric.py", line 41, in check_label_shapes
    "predictions {}".format(label_shape, pred_shape))
ValueError: Shape of labels 1 does not match shape of predictions 2
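For comparison with the TensorFlow mindset, here is a framework-free NumPy sketch of what the distillation term in the snippet is meant to compute: soften the student logits and the mentor logits with the same temperature, then compare the two distributions, either by cross-entropy against the mentor's soft targets (the idea behind the active SoftmaxOutput version) or by RMSE (the commented-out version). This is a sketch under assumptions, not MXNet code: the helper names and example logits are hypothetical, and the RMSE here averages over all elements, which may differ from the undefined rmse helper in the snippet.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    """Row-wise softmax of logits / temperature (numerically stabilised)."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    z = z - z.max(axis=-1, keepdims=True)  # subtract the row max to avoid overflow in exp
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_cross_entropy(student_logits, mentor_logits, temperature):
    """Mean cross-entropy of the student's softened distribution against the mentor's soft targets."""
    p_mentor = softmax(mentor_logits, temperature)    # soft targets from the mentor
    p_student = softmax(student_logits, temperature)  # softened student predictions
    return float(-(p_mentor * np.log(p_student)).sum(axis=-1).mean())

def distillation_rmse(student_logits, mentor_logits, temperature):
    """Hypothetical RMSE between the two softened distributions, averaged over all elements."""
    diff = softmax(student_logits, temperature) - softmax(mentor_logits, temperature)
    return float(np.sqrt((diff ** 2).mean()))

# Hypothetical logits for a single example with three classes.
student = np.array([[2.0, 1.0, 0.1]])
mentor = np.array([[3.0, 1.0, 0.2]])
print(distillation_cross_entropy(student, mentor, temperature=2.0))
print(distillation_rmse(student, mentor, temperature=2.0))
```

A higher temperature flattens both distributions, so the student also learns from the mentor's relative scores on the wrong classes rather than only from its top prediction.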