Because typical Gluon code builds a define-by-run computational graph, there is no model to save! This may sound strange if you have experience with declarative frameworks such as MXNet's Module API, Caffe, or TensorFlow. However, MXNet does let you easily generate a symbol that represents your model's computational graph, provided that graph is non-dynamic.
To construct a non-dynamic graph, every block used in the network must be a HybridBlock. You can read more about them in this Gluon tutorial.
The good news is that CNNs are almost always non-dynamic computational graphs that can be represented by HybridBlocks. In the tutorial that you provided, all you need to do is change net from gluon.nn.Sequential() to gluon.nn.HybridSequential(). Then, instead of passing an NDArray to net, you pass a Symbol, and the returned result is a symbol that represents the computational graph of your network. This symbol can be converted to JSON and saved. Here is an example based on the tutorial you mentioned:
First, create the initial network as a HybridBlock, create a symbol, and convert the symbol to JSON.
import mxnet as mx
from mxnet import gluon

batch_size = 64
num_inputs = 784
num_outputs = 10
num_fc = 512
net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    # The Flatten layer collapses all axes, except the first one, into one axis.
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(num_fc, activation="relu"))
    net.add(gluon.nn.Dense(num_outputs))
# Calling the hybrid network with a Symbol returns a Symbol; serialize it to JSON.
sym_json = net(mx.sym.var('data')).tojson()
You can save this JSON string to a file. When you want to load the model, you can use gluon.nn.SymbolBlock to load the symbol:
net = gluon.nn.SymbolBlock(
    outputs=mx.sym.load_json(sym_json),
    inputs=mx.sym.var('data'))
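The step of saving the JSON string to a file and reading it back needs nothing beyond ordinary file I/O. Here is a minimal sketch using only the Python standard library. The helper names, the file name, and the placeholder string are assumptions made for illustration; in practice sym_json would be the string returned by tojson().

```python
import os
import tempfile

# Placeholder standing in for the (much longer) string returned by
# net(mx.sym.var('data')).tojson(); this value is an assumption for illustration.
sym_json = '{"nodes": [], "arg_nodes": [], "heads": []}'


def save_symbol_json(json_str, path):
    """Write the symbol's JSON string to a text file (hypothetical helper)."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(json_str)


def load_symbol_json(path):
    """Read the symbol's JSON string back from a text file (hypothetical helper)."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


# Hypothetical file name inside a fresh temporary directory.
symbol_path = os.path.join(tempfile.mkdtemp(), "net-symbol.json")
save_symbol_json(sym_json, symbol_path)
restored_json = load_symbol_json(symbol_path)
```

The string read back can then be passed to mx.sym.load_json() in the same way sym_json is used when building the SymbolBlock.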
Now you can use net just like the original network. Specifically, you can call load_params() on it with the path to the file where the trained network's parameters are saved, and then pass an NDArray to it to make predictions:
net.load_params(params_filename)
x = mx.nd.random.uniform(shape=(16, 3, 224, 224))
predictions = net(x)
print(predictions.shape)
The above prints:
(16, 10)
which is the correct output shape for this CNN classification network.
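To see where that shape comes from, the spatial sizes can be traced by hand. The sketch below is plain Python, not MXNet code. It assumes the layer settings from the network above, with stride 1 and no padding for the convolutions and floor rounding for the pooling layers, which I believe are the Gluon defaults. The function names are hypothetical.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool window along one axis (floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(height, width):
    """Follow the spatial size through conv5 -> pool2 -> conv5 -> pool2."""
    h, w = height, width
    # (kernel, stride) for each layer, in the order they appear in the network.
    for kernel, stride in [(5, 1), (2, 2), (5, 1), (2, 2)]:
        h = out_size(h, kernel, stride)
        w = out_size(w, kernel, stride)
    # The second convolution has 50 output channels.
    return 50, h, w


channels, h, w = trace_shapes(224, 224)   # (50, 53, 53)
flattened = channels * h * w              # 140450 features per example
```

The two Dense layers then map those features to 512 and finally to num_outputs = 10, so a batch of 16 inputs yields (16, 10). For 28x28 inputs the same trace gives (50, 4, 4), or 800 features. Since the first Dense layer's weight shape depends on this flattened size, the input size used for prediction should presumably match the size the saved parameters were trained with.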