One way to do this is to read the weights of each parameter from the pre-trained model (using param.data()) and copy them into the corresponding parameter of the custom model (using param.set_data()).
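As a side note, set_data copies the pre-trained values into the custom parameter's existing array rather than rebinding it to a new one. The same semantics can be sketched with NumPy arrays standing in for MXNet parameters (the variable names here are made up for illustration):

```python
import numpy as np

# Stand-ins for a custom model's parameter and a pre-trained one.
custom_weight = np.zeros((3, 3))
pretrained_weight = np.arange(9, dtype=float).reshape(3, 3)

# Copy the values in place, like param.set_data(pretrained_param.data()):
# the custom array object stays the same, only its contents change.
before_id = id(custom_weight)
custom_weight[...] = pretrained_weight

assert id(custom_weight) == before_id  # same underlying array
assert np.array_equal(custom_weight, pretrained_weight)  # values copied
```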
You need to be careful with:
- Mapping parameters from the pre-trained model to the custom model: the parameter names might not match. In your case you have conv0_weight, but the pre-trained model calls the same parameter alexnet0_conv0_weight.
- Shape alignment: each parameter in the custom model must have exactly the same shape as its counterpart in the pre-trained model. Yours don't all match; the last layer has 10 hidden units instead of 1000, for example.
- Initialising the custom model's parameters first, since set_data only works on parameters that have already been initialised.
A code example for this case would look something like this:
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import vision
# define custom model
alex_net = gluon.nn.Sequential()
alex_net.add(nn.Conv2D(64, kernel_size=11, strides=4,
padding=2, activation='relu'))
alex_net.add(nn.MaxPool2D(pool_size=3, strides=2))
alex_net.add(nn.Conv2D(192, kernel_size=5, padding=2,
activation='relu'))
alex_net.add(nn.MaxPool2D(pool_size=3, strides=2))
alex_net.add(nn.Conv2D(384, kernel_size=3, padding=1,
activation='relu'))
alex_net.add(nn.Conv2D(256, kernel_size=3, padding=1,
activation='relu'))
alex_net.add(nn.Conv2D(256, kernel_size=3, padding=1,
activation='relu'))
alex_net.add(nn.MaxPool2D(pool_size=3, strides=2))
alex_net.add(nn.Flatten())
alex_net.add(nn.Dense(4096, activation='relu'))
alex_net.add(nn.Dropout(0.5))
alex_net.add(nn.Dense(4096, activation='relu'))
alex_net.add(nn.Dropout(0.5))
alex_net.add(nn.Dense(10))
# parameters must be initialized before they can be set;
# shapes are inferred lazily, so pass an example batch through first
alex_net.initialize()
alex_net(mx.nd.random.uniform(shape=(10, 3, 224, 224)))
# load pretrained model
pretrained_alex_net = vision.alexnet(pretrained=True)
# create parameter dictionaries
model_params = {name: param for name, param in alex_net.collect_params().items()}
pretrained_model_params = {name: param for name, param in pretrained_alex_net.collect_params().items()}
Before Replacing Parameters
# sample of randomly initialised weights
print(model_params['conv0_weight'].data()[0, 0, :3, :3])
[[-0.04774426 -0.02267893 -0.05454748]
[ 0.02275376 0.04493906 -0.06809997]
[-0.00438883 0.00134741 0.06674656]]
<NDArray 3x3 @cpu(0)>
# sample of pre-trained weights
print(pretrained_model_params['alexnet0_conv0_weight'].data()[0, 0, :3, :3])
[[0.11863963 0.09406868 0.09543519]
[0.07488242 0.03894044 0.05297883]
[0.07542486 0.03877855 0.05493048]]
<NDArray 3x3 @cpu(0)>
Replacing Parameters
for name, param in model_params.items():
    lookup_name = 'alexnet0_' + name
    if lookup_name in pretrained_model_params:
        lookup_param = pretrained_model_params[lookup_name]
        if lookup_param.shape == param.shape:
            param.set_data(lookup_param.data())
            print("Successful match for {}.".format(name))
        else:
            print("Error: Shape mismatch for {}. {}!={}".format(name, lookup_param.shape, param.shape))
    else:
        print("Error: Couldn't find match for {}.".format(name))
Successful match for conv0_weight.
Successful match for conv0_bias.
Successful match for conv1_weight.
Successful match for conv1_bias.
Successful match for conv2_weight.
Successful match for conv2_bias.
Successful match for conv3_weight.
Successful match for conv3_bias.
Successful match for conv4_weight.
Successful match for conv4_bias.
Successful match for dense0_weight.
Successful match for dense0_bias.
Successful match for dense1_weight.
Successful match for dense1_bias.
Error: Shape mismatch for dense2_weight. (1000, 4096)!=(10, 4096)
Error: Shape mismatch for dense2_bias. (1000,)!=(10,)
So it looks like you’re training a model with a different number of classes, 10 instead of 1000, which is why the final layer’s parameters don’t match the pre-trained model’s.
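To see why the shape check in the loop matters, here is a small sketch using NumPy arrays as stand-ins for the mismatched final-layer weights (the variable names are made up, and the shapes mirror the error output above; NumPy here is just an illustration, not the Gluon API):

```python
import numpy as np

# Hypothetical stand-ins for the mismatched final layer's weights.
pretrained_out = np.random.rand(1000, 4096)  # 1000 ImageNet classes
custom_out = np.random.rand(10, 4096)        # 10 classes in the new task

# Copying across mismatched shapes fails, which is why the loop
# checks shapes before calling set_data.
try:
    custom_out[...] = pretrained_out
    copied = True
except ValueError:
    copied = False

assert not copied  # the final layer keeps its fresh random initialisation
```

Since the final layer cannot be copied, it simply keeps the random initialisation from alex_net.initialize() and gets trained from scratch on the new 10-class task.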
After Replacing Parameters
# sample of weights
print(model_params['conv0_weight'].data()[0, 0, :3, :3])
[[0.11863963 0.09406868 0.09543519]
[0.07488242 0.03894044 0.05297883]
[0.07542486 0.03877855 0.05493048]]
<NDArray 3x3 @cpu(0)>
# sample of pre-trained weights
print(pretrained_model_params['alexnet0_conv0_weight'].data()[0, 0, :3, :3])
[[0.11863963 0.09406868 0.09543519]
[0.07488242 0.03894044 0.05297883]
[0.07542486 0.03877855 0.05493048]]
<NDArray 3x3 @cpu(0)>