Hi, here are two examples for MXNet version 2.x. You just need to call reshape on the layer's output, which works like the NumPy operation.
Version 1:
In [9]: import mxnet as mx
...: mx.npx.set_np()  # enable NumPy-compatible array semantics
...:
...: class ReshapeExample(mx.gluon.nn.HybridBlock):
...:     def __init__(self, **kwargs):
...:         super().__init__(**kwargs)
...:
...:         # 64*64*16 units so the output can be viewed as (16, 64, 64)
...:         self.dense = mx.gluon.nn.Dense(units=64*64*16)
...:         self.conv = mx.gluon.nn.Conv2D(channels=32, kernel_size=3, padding=1)
...:
...:     def forward(self, x):
...:         out1 = self.dense(x)
...:         print("shape after dense::{}".format(out1.shape))
...:         out1 = out1.reshape(-1, 16, 64, 64)  # <=== this does the job
...:         print("shape after reshape::{}".format(out1.shape))
...:         out2 = self.conv(out1)
...:         print("shape after conv::{}".format(out2.shape))
...:         return out2
...:
In [10]: net = ReshapeExample()
...: net.initialize()
In [11]: xx = mx.np.random.rand(5,32)
In [12]: out = net(xx)
shape after dense::(5, 65536)
shape after reshape::(5, 16, 64, 64)
shape after conv::(5, 32, 64, 64)
Version 2:
In [8]: net = mx.gluon.nn.HybridLambda(lambda F, x: x.reshape(-1,16,64,64))
In [9]: xx = mx.np.random.rand(5,16*64*64)
In [10]: net(xx).shape
Out[10]: (5, 16, 64, 64)
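For reference, the -1 in both versions behaves the same way as in plain NumPy: reshape infers that dimension from the total number of elements. The sketch below uses NumPy rather than MXNet (an assumption made only so it runs without MXNet installed), with a zero array standing in for the Dense output from version 1. The variable names are illustrative and not part of either library.

```python
import numpy as np

# Stand-in for the Dense output in version 1: batch of 5, 64*64*16 features.
batch, channels, height, width = 5, 16, 64, 64
flat = np.zeros((batch, channels * height * width), dtype=np.float32)

# -1 lets reshape infer the batch dimension from the total element count.
reshaped = flat.reshape(-1, channels, height, width)
print(reshaped.shape)

# The explicit dimensions must divide the element count evenly,
# otherwise reshape raises ValueError.
try:
    flat.reshape(-1, channels, height, width + 1)
except ValueError as err:
    print("reshape failed:", err)
```

This is why the Dense layer in version 1 has exactly 64*64*16 units: the feature count has to equal the product of the target channel, height, and width dimensions.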
Regards,