Change layer output without custom layer

Hi, here are examples for MXNet 2.x. You just need to call `reshape` on the layer's output array (not on the layer itself). It works like NumPy's `reshape`.
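Because `mx.np` arrays aim to follow NumPy semantics, the reshape used below can be previewed with plain NumPy. This sketch uses NumPy only so it runs without MXNet installed. The `-1` tells `reshape` to infer that axis, so the 65536 features the `Dense` layer produces per sample become 16 channels of 64x64, and the batch size of 5 is recovered automatically:

```python
import numpy as np

# Stand-in for the Dense output: batch of 5, 16*64*64 = 65536 features each
dense_out = np.random.rand(5, 16 * 64 * 64)

# -1 lets reshape infer the batch axis: (5 * 65536) / (16 * 64 * 64) = 5
feature_maps = dense_out.reshape(-1, 16, 64, 64)
print(feature_maps.shape)  # (5, 16, 64, 64)

# Only the shape changes; the elements keep their row-major order
assert np.array_equal(feature_maps.reshape(5, -1), dense_out)

# If the total size is not a multiple of 16*64*64, the -1 cannot be inferred
try:
    np.random.rand(5, 100).reshape(-1, 16, 64, 64)
except ValueError as err:
    print("reshape failed:", err)
```

This is why the `Dense` layer below is given `units=64*64*16`: its output size per sample must match the product of the target dimensions.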

Version 1:

```python
import mxnet as mx
mx.npx.set_np()  # use NumPy-style arrays (mx.np) in Gluon

class ReshapeExample(mx.gluon.nn.HybridBlock):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = mx.gluon.nn.Dense(units=64*64*16)
        self.conv = mx.gluon.nn.Conv2D(channels=32, kernel_size=3, padding=1)

    def forward(self, input):
        out1 = self.dense(input)
        print("shape after dense::{}".format(out1.shape))
        out1 = out1.reshape(-1, 16, 64, 64)  # <=== this does the job
        print("shape after reshape::{}".format(out1.shape))
        out2 = self.conv(out1)
        print("shape after conv::{}".format(out2.shape))
        return out2

net = ReshapeExample()
net.initialize()

xx = mx.np.random.rand(5, 32)
out = net(xx)
# Printed output:
# shape after dense::(5, 65536)
# shape after reshape::(5, 16, 64, 64)
# shape after conv::(5, 32, 64, 64)
```

Version 2, using `HybridLambda` so no block class of your own is needed:

```python
import mxnet as mx
mx.npx.set_np()

net = mx.gluon.nn.HybridLambda(lambda F, x: x.reshape(-1, 16, 64, 64))
xx = mx.np.random.rand(5, 16*64*64)
print(net(xx).shape)  # (5, 16, 64, 64)
```

Regards,
