Object detection using a model from model ZOO, detect only one class

Hi, I was wondering how I can use one of the pretrained models from model ZOO to detect only one specific class of objects.

For example, if I’m using ‘ssd_512_resnet50_v1_coco’, I want to detect only one of the 80 COCO classes in images (say, “person”).
I’ve been trying to slice the array resulting from inference:

class_IDs, scores, bounding_boxes = net(x)

but the problem is that boolean masks are not supported, so if I try to select:

class_IDs[0][class_IDs[0]==0]

it doesn’t work (the inner comparison returns an array of 0s and 1s instead of True and False).

I managed to do it by using the following code:

    class_IDs, scores, bounding_boxes = net(x)

    selected_class_ID = []
    selected_scores = []
    selected_bbox = []

    for ID, score, box in zip(class_IDs[0].asnumpy(), scores[0].asnumpy(), bounding_boxes[0].asnumpy()):
        # ID == 0 is the class ID for "person" in COCO
        if (ID==0) & (score>0.45):
            selected_class_ID.append(ID)
            selected_scores.append(score)
            selected_bbox.append(box)
    selected_class_ID = nd.array(selected_class_ID)
    selected_scores = nd.array(selected_scores)
    selected_bbox = nd.array(selected_bbox)

    ax = utils.viz.plot_bbox(img, selected_bbox, selected_scores,
                                     selected_class_ID, thresh=.4, class_names=net.classes)

But it is kinda involved…is there anything else I could do?

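Try this code. reset_class restricts the network’s output to the classes you list, and reuse_weights keeps the pretrained weights for those classes, so the model only ever predicts “person”.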

    from gluoncv import model_zoo, data, utils
    from matplotlib import pyplot as plt

    net = model_zoo.get_model('ssd_512_resnet50_v1_coco', pretrained=True)
    net.reset_class(["person"], reuse_weights=["person"])

    im_fname = utils.download('https://raw.githubusercontent.com/zhreshold/mxnet-ssd/master/data/demo/person.jpg',
                              path='person.jpg')
    x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)

    class_IDs, scores, bounding_boxs = net(x)

    ax = utils.viz.plot_bbox(img, bounding_boxs[0], scores[0],
                             class_IDs[0], thresh=0.8, class_names=net.classes)
    plt.axis('off')
    plt.show()
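
If you would rather keep the full 80-class model and filter after inference, the loop from the question can probably be replaced with a NumPy boolean mask, since NumPy comparisons return real boolean arrays. Below is a minimal sketch of that idea. The helper name filter_detections and the sample arrays are made up for illustration; with a real model you would pass class_IDs[0].asnumpy(), scores[0].asnumpy() and bounding_boxes[0].asnumpy() instead.

```python
import numpy as np


def filter_detections(class_ids, scores, boxes, wanted_id=0, min_score=0.45):
    """Keep only detections of one class above a score threshold.

    Inputs are NumPy arrays for a single image with shapes
    (N, 1), (N, 1) and (N, 4). The function name and defaults are
    illustrative, not part of GluonCV.
    """
    # Flatten the (N, 1) columns so the mask is one-dimensional.
    ids = class_ids.reshape(-1)
    conf = scores.reshape(-1)
    # NumPy comparisons yield boolean arrays, so they can be used as a mask.
    keep = (ids == wanted_id) & (conf > min_score)
    return class_ids[keep], scores[keep], boxes[keep]


# Made-up detections standing in for the network output on one image.
# The last row of -1s is assumed here to represent an empty slot.
class_ids = np.array([[0.], [2.], [0.], [-1.]])
scores = np.array([[0.9], [0.8], [0.3], [-1.]])
boxes = np.array([[10., 20., 50., 80.],
                  [5., 5., 30., 30.],
                  [60., 60., 90., 90.],
                  [-1., -1., -1., -1.]])

sel_ids, sel_scores, sel_boxes = filter_detections(class_ids, scores, boxes)
print(sel_boxes)
```

Only the first row survives here: the second has a different class ID and the third falls below the score threshold. The filtered arrays keep their (N, 1) and (N, 4) shapes, so they should be usable in place of the lists built by the loop in the question.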