Given a generic pretrained image classification model with softmax output and N classes, I want to compute the gradient of softmax output j (0 <= j < N) with respect to the input image pixel values. My approach so far has been the following:
import numpy as np
import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, epoch) #load pretrained model
all_layers = sym.get_internals()
net = all_layers['fc_output'] #Include all layers up to but not including the SoftmaxOutput layer. I don't want the cross entropy loss function in the symbol, just the softmax outputs
net = mx.symbol.softmax(data = net, name='softmax_label') #add back softmax activation function to last fully connected layer output without adding cross entropy loss
my_model = mx.mod.Module(symbol=net, context = mx.cpu())
my_model.bind(data_shapes = [('data', (1,3,250,250))], inputs_need_grad = True, for_training = True) #input is 250x250 RGB image
my_model.set_params(arg_params, aux_params, allow_missing = False) #set weights to pretrained values
img = get_image('/home/ubuntu/data-2/test_image_normalized.jpg', 250, 'resize') #numpy array with shape (1,3,250,250)
x = nx.nd.array(img)
d = mx.io.DataBatch([x])
label_index = 1 #specifies which softmax output we want to compute the gradient of
my_model.forward(d)
my_model.backward(out_grad=mx.ndarray.one_hot(indices = mx.nd.array([label_index]), depth = N)
) #this doesn't execute as written
w = my_model.get_input_grads()[0].asnumpy()
Unfortunately, the backward() call fails and yields the following error message:
[18:40:59] /home/ubuntu/src/mxnet/dmlc-core/include/dmlc/./logging.h:308: [18:40:59] src/ndarray/ndarray.cc:348: Check failed: from.shape() == to->shape() operands shape mismatchfrom.shape = (1,) to.shape=(1,2)
Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x3c) [0x7fbf65fdef0c]
[bt] (1) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN5mxnet10CopyFromToERKNS_7NDArrayEPS0_i+0x546) [0x7fbf66cc12f6]
[bt] (2) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN5mxnet4exec13GraphExecutor8BackwardERKSt6vectorINS_7NDArrayESaIS3_EEb+0xb3) [0x7fbf67082173]
[bt] (3) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(MXExecutorBackwardEx+0x314) [0x7fbf6704b4f4]
[bt] (4) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7fbf99f90e20]
[bt] (5) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call+0x2eb) [0x7fbf99f9088b]
[bt] (6) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(_ctypes_callproc+0x49a) [0x7fbf99f8b01a]
[bt] (7) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(+0x9fcb) [0x7fbf99f7efcb]
[bt] (8) /usr/bin/python3(PyObject_Call+0x47) [0x5b7167]
[bt] (9) /usr/bin/python3(PyEval_EvalFrameEx+0x4f06) [0x528d06]
---------------------------------------------------------------------------
MXNetError Traceback (most recent call last)
<ipython-input-167-3a4eabe9e4b8> in <module>()
----> 1 my_model.backward(h)
/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/module/module.py in backward(self, out_grads)
611 """
612 assert self.binded and self.params_initialized
--> 613 self._exec_group.backward(out_grads=out_grads)
614
615 def update(self):
/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/module/executor_group.py in backward(self, out_grads)
545 else:
546 out_grads_slice.append(grad.copyto(self.contexts[i]))
--> 547 exec_.backward(out_grads=out_grads_slice)
548
549 def update_metric(self, eval_metric, labels):
/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/executor.py in backward(self, out_grads, is_train)
229 mx_uint(len(out_grads)),
230 ndarray,
--> 231 ctypes.c_int(is_train)))
232
233 def set_monitor_callback(self, callback):
/usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/base.py in check_call(ret)
127 """
128 if ret != 0:
--> 129 raise MXNetError(py_str(_LIB.MXGetLastError()))
130
131 if sys.version_info[0] < 3:
MXNetError: [18:40:59] src/ndarray/ndarray.cc:348: Check failed: from.shape() == to->shape() operands shape mismatchfrom.shape = (1,) to.shape=(1,2)
Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN4dmlc15LogMessageFatalD1Ev+0x3c) [0x7fbf65fdef0c]
[bt] (1) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN5mxnet10CopyFromToERKNS_7NDArrayEPS0_i+0x546) [0x7fbf66cc12f6]
[bt] (2) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(_ZN5mxnet4exec13GraphExecutor8BackwardERKSt6vectorINS_7NDArrayESaIS3_EEb+0xb3) [0x7fbf67082173]
[bt] (3) /usr/local/lib/python3.5/dist-packages/mxnet-0.11.0-py3.5.egg/mxnet/libmxnet.so(MXExecutorBackwardEx+0x314) [0x7fbf6704b4f4]
[bt] (4) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call_unix64+0x4c) [0x7fbf99f90e20]
[bt] (5) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(ffi_call+0x2eb) [0x7fbf99f9088b]
[bt] (6) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(_ctypes_callproc+0x49a) [0x7fbf99f8b01a]
[bt] (7) /usr/lib/python3.5/lib-dynload/_ctypes.cpython-35m-x86_64-linux-gnu.so(+0x9fcb) [0x7fbf99f7efcb]
[bt] (8) /usr/bin/python3(PyObject_Call+0x47) [0x5b7167]
[bt] (9) /usr/bin/python3(PyEval_EvalFrameEx+0x4f06) [0x528d06]
Any help to correct my approach here would be much appreciated.