Gluon pretrained model layer access and usage

Hello, are there sample notebooks or other code showing usage in Gluon for the following:

  1. Simply load a pretrained model, e.g. ResNet, and load an image and get a prediction about it (I know about the Gluon Model Zoo, but am looking for a complete working example);
  2. Load a pretrained model, get a reference to one of its layers (e.g. last fully connected layer), then send data through the net and get the output.

No training is being done. For the second case, I’d like to convert the output to zeros and ones after I get the output.

For 1: reading the model zoo documentation should get you most of the way. Here is a complete working example:

import mxnet as mx
from mxnet import gluon, nd
from mxnet.gluon.model_zoo import vision
import numpy as np
import wget
import json

# Get the image net labels
wget.download('https://gist.githubusercontent.com/ThomasDelteil/3bc3a3a7e9601b2a67646b4813981a40/raw/6fe3860887a3ac6ea1d8301531b57603909b6ff3/image_net_labels.json')
categories = json.load(open('image_net_labels.json', 'r'))

# Get the model
ctx = mx.cpu() #set the context
resnet18 = vision.resnet18_v1(pretrained=True, ctx=ctx) #download the pre-trained model

# load and pre-process the image
image_path = 'dog.jpg'
image = mx.image.imdecode(open(image_path, 'rb').read()).astype(np.float32)
resized = mx.image.resize_short(image, 224) # resize so that the shorter edge is 224 pixels
cropped, crop_info = mx.image.center_crop(resized, (224, 224))
normalized = mx.image.color_normalize(cropped/255,
                                      mean=mx.nd.array([0.485, 0.456, 0.406]),
                                      std=mx.nd.array([0.229, 0.224, 0.225])) 
# the network expects batches of the form (N, 3, 224, 224)
flipped_axis = normalized.transpose((2,0,1))  # Flipping from (224, 224, 3) to (3, 224, 224)
batchified = flipped_axis.expand_dims(axis=0) # change the shape from (3, 224, 224) to (1, 3, 224, 224)

# Run the predictions
predictions = resnet18(batchified)
class_predicted = int(nd.argmax(predictions, axis=1).asscalar())

# Get the label of the class
print(categories[class_predicted])
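If you want more than the single best class, a common next step is to turn the raw scores into probabilities and list the top few classes. The sketch below does this in NumPy, assuming you first convert the output with predictions[0].asnumpy(); top_k is a hypothetical helper name, and the scores in the example are made up.

```python
import numpy as np


def top_k(logits, k=5):
    """Return (class indices, probabilities) of the k highest-scoring classes.

    `logits` is a 1-D array of raw network scores, for example
    predictions[0].asnumpy() from the example above.
    """
    logits = np.asarray(logits, dtype=np.float64)
    # Softmax; subtracting the max first avoids overflow in np.exp
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    # Indices sorted from most to least likely, truncated to k
    idx = np.argsort(probs)[::-1][:k]
    return idx, probs[idx]


# Made-up scores standing in for the network output
scores = np.array([0.1, 2.0, -1.0, 3.5, 0.0])
idx, probs = top_k(scores, k=3)
print(idx, probs)
```

With the real model you could then print the labels with: for i, p in zip(*top_k(predictions[0].asnumpy())): print(categories[i], p)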

For 2: as far as I know, there is no easy way to get access to a given layer of a pre-trained model from the model zoo. You would need to export the network using net.export, load it with the symbolic API, find the layer you are interested in, and then load that symbol and the corresponding weights into a SymbolBlock.

Here is some code that hybridizes a pre-trained Gluon model, exports it to the symbolic world, cherry-picks a given layer, and loads the new symbol and params back into Gluon. I am not sure what you mean by converting the output to zeros and ones, but you can use NumPy operations on the output to do what you want. By default, the previous example already gives you the output of the last fully connected layer.

# We hybridize the model to get it as a symbol we can then export
resnet18.hybridize()
# We need to run at least one batch to have the cached computation graph
output = resnet18(batchified)
# We export the model. This creates the resnet18-0000.params and resnet18-symbol.json files
resnet18.export('resnet18')
# We load the symbol and params in the symbolic world
sym = mx.sym.load('resnet18-symbol.json')
# Check the layers, you can also visualize it using mx.visualization to make it easier
# to find the layer you are interested in
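print(sym.get_internals().list_outputs())
# We pick a given layer. The number in the prefix (resnetv10_, resnetv11_, ...) increases with
# every ResNet created in the session, so copy the exact name from the printed list
new_sym = sym.get_internals()['resnetv10_stage3_relu1_fwd_output']
# We load the symbol and parameters into a Gluon SymbolBlock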
net = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('data'))
# Load the params; ignore_extra=True skips the weights of the layers we cut off
net.collect_params().load('resnet18-0000.params', ctx=ctx, ignore_extra=True)
# We test it
print(net(batchified).shape)
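To turn a layer's output into zeros and ones, one simple option is to threshold it with NumPy after calling .asnumpy() on the output. This is a minimal sketch under assumptions of my own, not part of the answer above: it thresholds either at a fixed value or, by default, at each dimension's median over the batch, and packs the bits with np.packbits so eight bits take one byte. binarize is a hypothetical helper name.

```python
import numpy as np


def binarize(features, threshold=None):
    """Turn (N, ...) float features into an (N, D) array of 0/1 codes.

    Extra trailing dimensions, e.g. (N, C, H, W), are flattened first.
    If no threshold is given, each dimension is thresholded at its median
    over the batch (values equal to the threshold become 0).
    """
    x = np.asarray(features, dtype=np.float32)
    x = x.reshape(x.shape[0], -1)
    if threshold is None:
        threshold = np.median(x, axis=0)
    return (x > threshold).astype(np.uint8)


feats = np.array([[0.0, 1.5, -2.0, 0.3],
                  [1.0, 0.0,  2.0, 0.3]])
codes = binarize(feats, threshold=0.0)
# Pack 8 bits per byte along each row (rows are zero-padded to a multiple of 8 bits)
packed = np.packbits(codes, axis=1)
print(codes)
print(packed)
```

With the SymbolBlock above this would be binarize(net(batchified).asnumpy()). Whether a fixed or a median threshold suits your layer depends on its activations; after a ReLU many values are exactly 0, so a threshold of 0 is a natural first try.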

If you just want to use the pre-trained model as a featurizer, it is even simpler; just use:

resnet18.features(batchified)
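Once you have features, a brute-force cosine-similarity search over a NumPy feature matrix is a reasonable baseline for finding similar images. It costs one matrix product per batch of queries, which is what becomes slow for very large collections. The sketch assumes the features were converted with .asnumpy(); l2_normalize and nearest_neighbors are hypothetical helper names.

```python
import numpy as np


def l2_normalize(features):
    # Flatten to (N, D), e.g. (N, 512, 1, 1) -> (N, 512), then scale rows to unit length
    x = np.asarray(features, dtype=np.float64)
    x = x.reshape(x.shape[0], -1)
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)


def nearest_neighbors(query, database, k=5):
    """Indices and cosine similarities of the k most similar database rows per query."""
    # Dot products of unit vectors are cosine similarities: (num_queries, num_database)
    sims = l2_normalize(query) @ l2_normalize(database).T
    idx = np.argsort(-sims, axis=1)[:, :k]
    return idx, np.take_along_axis(sims, idx, axis=1)


database = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
query = np.array([[2.0, 0.1]])
idx, sims = nearest_neighbors(query, database, k=2)
print(idx, sims)
```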


Thanks, this is very helpful. What I am trying to do is basically the feature extraction part of this Module API tutorial, but in Gluon: https://mxnet.incubator.apache.org/tutorials/python/predict_image.html. I then want to take the features output by the last FC layer and convert them to “neural codes”, i.e. a binary representation that can be used as a key in a hash table, so that similar images can be looked up efficiently. (Computing cosine similarity against many images would likely be slow.)
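One standard way to get binary codes whose Hamming distance tracks cosine similarity is random-hyperplane locality-sensitive hashing (SimHash): each bit is the sign of the feature vector's projection onto a random Gaussian direction, and two vectors disagree on a given bit with probability equal to the angle between them divided by π. This is a swapped-in technique, not the method from the tutorial; the helper names, code length and fixed seed are illustrative assumptions.

```python
from collections import defaultdict

import numpy as np


def make_hyperplanes(dim, n_bits, seed=0):
    # One random Gaussian direction per output bit; reuse the same matrix
    # for database and query images so that their codes are comparable
    rng = np.random.default_rng(seed)
    return rng.standard_normal((dim, n_bits))


def lsh_codes(features, hyperplanes):
    """Map (N, dim) features to (N, n_bits) 0/1 codes via signs of projections."""
    x = np.asarray(features, dtype=np.float64)
    x = x.reshape(x.shape[0], -1)
    return (x @ hyperplanes > 0).astype(np.uint8)


def hamming_distance(a, b):
    # Number of bit positions where two codes differ
    return int(np.count_nonzero(a != b))


def build_index(codes):
    # Hash table from the raw bytes of each code to the ids of images with that code
    table = defaultdict(list)
    for image_id, code in enumerate(codes):
        table[code.tobytes()].append(image_id)
    return table
```

A lookup is then table.get(lsh_codes(query_features, H)[0].tobytes(), []). Exact-match buckets only return images whose codes are identical, so in practice shorter codes, several tables with different seeds, or probing codes within a small Hamming distance are used to trade precision for recall.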


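Note that the keyword arguments of load() have changed across Gluon versions, so check them against the version you have installed. In MXNet 1.x, ParameterDict.load() skips parameters that are not in the network when called with ignore_extra=True.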

Here is another example:

import mxnet as mx
from mxnet import gluon

# To extract the features from the fc1 and fc2 layers of AlexNet:
alexnet = gluon.model_zoo.vision.alexnet(pretrained=True, ctx=mx.cpu(), prefix='model_')
inputs = mx.sym.var('data')
out = alexnet(inputs)
internals = out.get_internals()
print(internals.list_outputs())
outputs = [internals['model_dense0_relu_fwd_output'],
           internals['model_dense1_relu_fwd_output']]
# Create SymbolBlock that shares parameters with alexnet
feat_model = gluon.SymbolBlock(outputs, inputs, params=alexnet.collect_params())
x = mx.nd.random.normal(shape=(16, 3, 224, 224))
print(feat_model(x))

Can you use this model for training and not just inference?

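import mxnet as mx
from mxnet import gluon

prefix = 'my-model'  # hypothetical checkpoint prefix (my-model-symbol.json, my-model-0000.params)
ctx = mx.cpu()
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, 0)
# Dropping the loss from model
new_sym = sym.get_children()[0]
net = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var("data"))
# Note: this initializes all parameters from scratch; arg_params/aux_params above are not used
net.initialize(mx.init.Normal(0.002), ctx=ctx)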

I have this model but cannot train it: the model’s weights don’t change.

Hello,
This is great! I’m quite new to MXNet and just want to extract the features for a set of images. How can I do it in bulk?
Thanks a lot.
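A minimal sketch of bulk extraction, assuming the images have already been pre-processed as in the first example into one (N, 3, 224, 224) array: split it into fixed-size batches, run the featurizer on each batch, and stack the results. extract_features and iter_batches are hypothetical helpers, and the featurizer is passed in as a callable so the batching logic does not depend on MXNet; with the model above it could be something like lambda b: resnet18.features(mx.nd.array(b, ctx=ctx)).asnumpy(). For images stored on disk, Gluon's ImageFolderDataset and DataLoader (in mxnet.gluon.data) provide the same batching with loading included.

```python
import numpy as np


def iter_batches(items, batch_size):
    # Yield consecutive slices of at most batch_size items
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def extract_features(images, featurize, batch_size=32):
    """Run `featurize` over `images` in batches and stack the results.

    images: (N, 3, 224, 224) float array of already pre-processed images.
    featurize: callable mapping a (B, 3, 224, 224) NumPy batch to a (B, ...) NumPy array.
    Returns an (N, D) array, with any trailing dimensions flattened.
    """
    outputs = [featurize(batch) for batch in iter_batches(images, batch_size)]
    feats = np.concatenate(outputs, axis=0)
    return feats.reshape(feats.shape[0], -1)
```

Larger batches are usually faster per image, up to the memory available on your device; the final batch may be smaller than batch_size.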