Let’s suppose I have a `HybridBlock` consisting of an MLP model. The `hybrid_forward` method of my model returns the pair `(output, final_hidden_layer)`. The model is trained by connecting a loss function to `output`, which serves as the predictions.
After base training, I create a new model from this `HybridBlock` by adding an extra layer on top of `final_hidden_layer`, like so:
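```python
from mxnet import gluon

class CalibratedModel(gluon.nn.HybridBlock):
    def __init__(self, model, calibration_layer, model_output_name, prefix=None, params=None):
        super().__init__(prefix=prefix, params=params)
        # Position of the requested output (e.g. final_hidden_layer)
        # in the wrapped model's output tuple.
        self.output_idx = model.ordered_output_names.index(model_output_name)
        with self.name_scope():
            # Register the trained network and the new layer as children.
            self.model = model.model
            self.calibration_layer = calibration_layer

    def hybrid_forward(self, F, *args):
        X = self.model(*args)             # tuple: (output, final_hidden_layer)
        Y = X[self.output_idx]            # select the requested output
        return self.calibration_layer(Y)
```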
The problem is that when I instantiate a `CalibratedModel` and call `export()`, I get the following error:
```
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-11-7daab4c36e3a> in <module>
----> 1 cm.export('/ebs/test')
~/.local/lib/python3.6/site-packages/mypackage/Gluon/blocks.py in export(self, path, epoch, remove_amp_cast)
49
50 def export(self, path, epoch=0, remove_amp_cast=True):
---> 51 return self.model.export(path, epoch, remove_amp_cast)
52
53 def getName(self):
~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in export(self, path, epoch, remove_amp_cast)
1133 arg_dict['arg:%s'%name] = param._reduce()
1134 else:
-> 1135 assert name in aux_names
1136 arg_dict['aux:%s'%name] = param._reduce()
1137 save_fn = _mx_npx.save if is_np_array() else ndarray.save
AssertionError:
```
The call to `export()` successfully writes the symbol file but fails while saving the parameters, because of the assertion above. I believe this is caused by my original model having a “dangling” layer, in some sense: the layer and parameters that produce `output` are unused in the calibrated model's graph.
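To make the hypothesis concrete, here is a minimal pure-Python sketch of the check the failing assertion appears to enforce, based on the `export` frame in the traceback: every parameter collected from the block must be either an argument or an auxiliary state of the exported symbol, otherwise `assert name in aux_names` fires. The helper `find_dangling_params` and the parameter names below are hypothetical. With a real model, the inputs would presumably come from `cm.collect_params().keys()` and from the exported symbol's `list_arguments()` and `list_auxiliary_states()` (for example after loading it with `mx.sym.load`).

```python
from typing import Iterable, List


def find_dangling_params(param_names: Iterable[str],
                         arg_names: Iterable[str],
                         aux_names: Iterable[str]) -> List[str]:
    """Return parameter names that are neither symbol arguments nor
    auxiliary states, i.e. the names that would trip the assertion.

    Hypothetical helper: a sketch of the per-parameter check seen in the
    export() frame of the traceback, not MXNet's actual code.
    """
    symbol_inputs = set(arg_names) | set(aux_names)
    # Keep the order in which the parameters were collected.
    return [name for name in param_names if name not in symbol_inputs]


# Hypothetical names: dense1 is the layer that produces `output`,
# which the calibrated graph never uses.
params = ["mlp0_dense0_weight", "mlp0_dense0_bias",
          "mlp0_dense1_weight", "mlp0_dense1_bias",
          "calib0_dense0_weight", "calib0_dense0_bias"]
symbol_args = ["data", "mlp0_dense0_weight", "mlp0_dense0_bias",
               "calib0_dense0_weight", "calib0_dense0_bias"]

print(find_dangling_params(params, symbol_args, []))
# → ['mlp0_dense1_weight', 'mlp0_dense1_bias']
```

A non-empty result lists exactly the parameters that `export()` cannot place under either the `arg:` or the `aux:` prefix of the params file.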
Have any of you experienced this error before and have a fix available?