Try this:
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.ndarray import tanh, relu
# define a layer that will raise an exception on forward
class BrokenConv(nn.Conv2D):
    """Conv2D stand-in whose forward pass always fails.

    Used as a replacement layer: an exception raised during a forward pass
    shows that the network is really calling the replacement and not the
    original Conv2D.
    """

    def forward(self, *args, **kwargs):
        # Overriding forward means the layer fails before any computation runs.
        raise Exception('BrokenConv.forward was called')

    def hybrid_forward(self, *args, **kwargs):
        # Also fail on the hybrid path, so there is no way to silently run.
        raise Exception('BrokenConv.hybrid_forward was called')
# define standard LeNet
class LeNet(gluon.HybridBlock):
    """LeNet-style CNN.

    The network has two (conv -> tanh -> max-pool) stages followed by two
    tanh-activated dense layers.

    Args:
        kernel_size: kernel size used by both convolutions.
        num_filters: output channels of the first and second convolution.
        pool_size: window size of both max-pool layers.
        strides: stride of both max-pool layers.
        ff_hidden: number of units in the hidden dense layer.
        classes: number of units in the output dense layer (default 10).
        **kwargs: forwarded unchanged to ``gluon.HybridBlock.__init__``.
    """

    def __init__(self, kernel_size=(5, 5), num_filters=(20, 50), pool_size=(2, 2),
                 strides=(2, 2), ff_hidden=500, classes=10, **kwargs):
        super(LeNet, self).__init__(**kwargs)
        with self.name_scope():
            self.conv1 = nn.Conv2D(num_filters[0], kernel_size=kernel_size)
            self.pool1 = nn.MaxPool2D(pool_size=pool_size, strides=strides)
            self.conv2 = nn.Conv2D(num_filters[1], kernel_size=kernel_size)
            self.pool2 = nn.MaxPool2D(pool_size=pool_size, strides=strides)
            self.fc1 = nn.Dense(ff_hidden)
            self.fc2 = nn.Dense(classes)

    def hybrid_forward(self, F, x):
        # Implementing hybrid_forward instead of overriding forward keeps
        # HybridBlock's own dispatch in place, so the net runs imperatively
        # and after hybridize(). F is the ndarray or symbol namespace.
        x = self.pool1(F.tanh(self.conv1(x)))
        x = self.pool2(F.tanh(self.conv2(x)))
        # Collapse everything after the first axis; 0 copies the batch dim.
        x = F.reshape(x, shape=(0, -1))
        x = F.tanh(self.fc1(x))
        return F.tanh(self.fc2(x))
# Build the network and initialize its parameters on the CPU.
ctx = [mx.cpu()]
net = LeNet()
net.initialize(ctx=ctx)
# All-ones input batch of shape (1024, 3, 32, 32). It is fed to the
# network later in the script, after its conv layers have been swapped.
data = mx.nd.ones(shape=(1024, 3, 32, 32))
def _broken_copy_of(layer):
    """Return a new BrokenConv configured like the Conv2D ``layer``.

    Copies channels, kernel, stride, padding, dilation, groups, layout, input
    channels, activation and the bias setting. This reads Gluon's private
    ``_channels``, ``_in_channels``, ``_kwargs`` and ``act`` attributes, so
    it is tied to the MXNet 1.x convolution implementation.
    """
    opts = layer._kwargs
    act = getattr(layer, 'act', None)
    return BrokenConv(
        channels=layer._channels,
        kernel_size=opts['kernel'],
        strides=opts['stride'],
        padding=opts['pad'],
        dilation=opts['dilate'],
        groups=opts['num_group'],
        layout=opts['layout'],
        in_channels=layer._in_channels,
        # NOTE(review): assumes the Activation block keeps its type in
        # _act_type (MXNet 1.x); None when the conv has no activation.
        activation=getattr(act, '_act_type', None),
        use_bias=not opts['no_bias'])


def _param_contexts(layer):
    """Return the contexts ``layer``'s weight is placed on.

    Returns None (i.e. the default context) if the weight was never
    initialized.
    """
    try:
        return layer.weight.list_ctx()
    except RuntimeError:
        return None


def replace_conv2D(net):
    """Recursively replace every Conv2D under ``net`` with a BrokenConv.

    Each replacement copies the original layer's hyperparameters, is
    registered under the same child name, and is Xavier-initialized on the
    contexts the original layer used. Non-Conv2D children are searched
    recursively. Prints one line per replaced layer.
    """
    # Iterate over a snapshot: the loop body rewrites net._children.
    for key, layer in list(net._children.items()):
        if not isinstance(layer, gluon.nn.Conv2D):
            # Not a convolution: recursively replace layers inside it.
            replace_conv2D(layer)
            continue
        ctx = _param_contexts(layer)
        # Construct inside the parent's name scope. Gluon names a block when
        # it is created, so entering the scope after construction is too late.
        with net.name_scope():
            new_conv = _broken_copy_of(layer)
        # Custom blocks reference children by attribute name in their forward
        # pass, while containers such as HybridSequential keep them only in
        # _children. Update both so either kind of parent uses the new layer.
        if hasattr(net, key):
            setattr(net, key, new_conv)
        net.register_child(new_conv, key)
        new_conv.initialize(mx.init.Xavier(), ctx=ctx)
        print('Replacing layer '+key)
# Swap every Conv2D in the net for a BrokenConv, then run a forward pass.
# The pass is expected to raise from BrokenConv, showing the swap took effect.
replace_conv2D(net=net)
out = net(data)
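Running it fails inside BrokenConv, which confirms that the replacement layers are the ones being called:
<ipython-input-210-55f25459abc5> in forward(*x, **y)
7 class BrokenConv(nn.Conv2D):
8 def forward(*x, **y):
----> 9 raise Exception()
10 def hybrid_forward(*x, **y):
11 raise Exception()
Exception: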
Updated to set the attribute as well. This is needed because Gluon blocks hold their children in two ways. Containers such as HybridSequential store them only in the `_children` ordered dict. Custom blocks instead keep their children as attributes, and their forward passes reference them directly by attribute name.
Admittedly this is hacky and may not survive future MXNet versions, since it relies on private attributes such as `_children` and `_kwargs`. The cleaner approach is to construct a new network and copy over the parts you need from the old one by iterating through it. That is easy when you know your network's structure, but awkward if you want a completely generic method.