Hi, I’m using conda_mxnet_p36 on an Amazon SageMaker notebook instance. In the code below I’d like to instantiate a model and display its summary parameters table. I have the following error. Any idea what’s wrong? how to make it work? Cheers
import time
import mxnet as mx
from mxnet import nd
import numpy as np
from mxnet import nd, autograd, gluon
ctx = mx.gpu(0)
# define CNN
num_inputs = 784
num_outputs = 10
num_fc = 256
def BuildNet():
net = gluon.nn.HybridSequential()
with net.name_scope():
net.add(gluon.nn.Conv2D(channels=20, kernel_size=3, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Conv2D(channels=50, kernel_size=3, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
# The Flatten layer collapses all axis, except the first one, into one axis.
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(num_fc, activation="relu"))
net.add(gluon.nn.Dropout(.3))
net.add(gluon.nn.Dense(num_outputs))
return net
net = BuildNet()
net.collect_params().initialize()
x = mx.sym.var('data')
sym = net(x)
mx.viz.print_summary(sym)
returns:
________________________________________________________________________________________________________________________
Layer (type) Output Shape Param # Previous Layer
========================================================================================================================
data(null) 0
________________________________________________________________________________________________________________________
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-5-eaf5ad29bf78> in <module>()
4 x = mx.sym.var('data')
5 sym = net(x)
----> 6 mx.viz.print_summary(sym)
~/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/visualization.py in print_summary(symbol, shape, line_length, positions)
182 if key in shape_dict:
183 out_shape = shape_dict[key][1:]
--> 184 total_params += print_layer_summary(nodes[i], out_shape)
185 if i == len(nodes) - 1:
186 print('=' * line_length)
~/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/visualization.py in print_layer_summary(node, out_shape)
135 cur_param = 0
136 if op == 'Convolution':
--> 137 if ("no_bias" in node["attrs"]) and int(node["attrs"]["no_bias"]):
138 cur_param = pre_filter * int(node["attrs"]["num_filter"])
139 for k in _str2tuple(node["attrs"]["kernel"]):
ValueError: invalid literal for int() with base 10: 'False'