Transpose x during forward

I created a dummy Block that takes a 2D array, performs a 2D convolution, and feeds the convolution output to a fully connected layer:

class DummyBlock(gluon.Block):
    def __init__(self, **kwargs):
        super(DummyBlock, self).__init__(**kwargs)
        with self.name_scope():
            self.conv = gluon.nn.Conv2D(channels=3, kernel_size=(1, 5), strides=(1, 1), activation='relu')
            self.fc = gluon.nn.Dense(5)

    def forward(self, x):
        # 2D convolution: <NDArray 2x3x4x1 @cpu(0)>
        x = self.conv(x)
        x = self.fc(x)
        return x

I tested DummyBlock using the following code:

import numpy as np
import mxnet as mx
from mxnet import gluon, nd, autograd

X = nd.array([
    [[1,0,0,0,0],[2,0,0,0,0],[3,0,0,0,0],[4,0,0,0,0]],
    [[0,1,0,0,0],[0,2,0,0,0],[0,3,0,0,0],[0,4,0,0,0]],
    [[0,0,1,0,0],[0,0,2,0,0],[0,0,3,0,0],[0,0,4,0,0]],
    [[0,0,0,1,0],[0,0,0,2,0],[0,0,0,3,0],[0,0,0,4,0]],
    [[0,0,0,0,1],[0,0,0,0,2],[0,0,0,0,3],[0,0,0,0,4]]
])

Y = nd.array([0,1,2,3,4])

ctx = mx.cpu()
net = DummyBlock()
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

batch_size = 2
loss_func = gluon.loss.SoftmaxCrossEntropyLoss()
dataloader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, Y), batch_size=batch_size)

for i, (data, label) in enumerate(dataloader):
    data = data.as_in_context(ctx)
    data = data.reshape((0, 1, data.shape[1], data.shape[2]))
    label = label.as_in_context(ctx)
    with autograd.record():
        output = net(data)
        loss = loss_func(output, label)
        loss.backward()
    trainer.step(data.shape[0])

Besides the fact that it doesn’t do anything useful, this runs fine without any error. When I transpose x and feed it into the fully connected layer:

    def forward(self, x):
        # 2D convolution: <NDArray 2x3x4x1 @cpu(0)>
        x = self.conv(x)

        # transpose: <NDArray 2x1x4x3 @cpu(0)>
        x = nd.array([nd.transpose(a).asnumpy() for a in x])
        
        x = self.fc(x)
        return x

it fails after the first batch and gives the following error message:

Traceback (most recent call last):
  File "/Users/jdchoi/workspace/elit/elit/component/postag.py", line 519, in <module>
    trainer.step(data.shape[0])
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/mxnet/gluon/trainer.py", line 147, in step
    %(param.name, str(data.context)))
UserWarning: Gradient of Parameter `dummyblock0_conv0_weight` on context cpu(0) has not been updated by backward since last `step`. This could mean a bug in your model that maked it only use a subset of the Parameters (Blocks) for this iteration. If you are intentionally only using a subset, call step with ignore_stale_grad=True to suppress this warning and skip updating of Parameters with stale gradient

In fact, it gives the same error message if I make a copy of x and pass it to the fully connected layer:

    def forward(self, x):
        # 2D convolution: <NDArray 2x3x4x1 @cpu(0)>
        x = self.conv(x)
        x = x.copy()
        x = self.fc(x)
        return x

When I reshape x and copy transposed values to x, it runs fine:

    def forward(self, x):
        # 2D convolution: <NDArray 2x3x4x1 @cpu(0)>
        x = self.conv(x)

        # reshape and copy: <NDArray 2x1x4x3 @cpu(0)>
        y = [nd.transpose(a).asnumpy() for a in x]
        x = x.reshape((-1, 1, x.shape[2], x.shape[1]))
        for i in range(len(x)): x[i] = y[i]

        x = self.fc(x)
        return x

This is very hacky and not efficient. Could someone explain to me why the first two approaches fail? I often need to transpose the output of the convolution (or even concatenate another vector with the output) and feed it into the next layer, so it would be great to know whether I can do this with Gluon. Thank you.

x = nd.array([nd.transpose(a).asnumpy() for a in x])

This breaks the autograd chain: because you converted to NumPy, MXNet no longer knows the relationship between a and x, so gradient calculation stops at x.
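A minimal sketch of the effect: an NDArray rebuilt from .asnumpy() has no recorded history, so the gradient reaches the copy but never the original array:

from mxnet import nd, autograd

a = nd.array([1.0, 2.0, 3.0])
a.attach_grad()

# `b` is rebuilt from NumPy, so autograd has no record linking it to `a`
b = nd.array(a.asnumpy())
b.attach_grad()

with autograd.record():
    loss = (b * b).sum()
loss.backward()

print(b.grad)   # [2. 4. 6.] -- the gradient reaches the copy
print(a.grad)   # [0. 0. 0.] -- but never reaches `a`: the chain is cut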

You should do

x = nd.stack(*[nd.transpose(a) for a in x])

or better,

x = nd.transpose(x, (0, 2, 1))


@piiswrong: I don’t fully understand how autograd chains the variables, but I will dig into the code. Thank you very much.

So I tried the above two approaches: nd.stack(*[nd.transpose(a) for a in x]) worked well, whereas nd.transpose(x, (0, 2, 1)) gave the following error:

[16:04:01] /Users/travis/build/dmlc/mxnet-distro/mxnet-build/dmlc-core/include/dmlc/logging.h:308: [16:04:01] src/operator/tensor/./matrix_op-inl.h:306: Check failed: shp.ndim() == param.axes.ndim() (4 vs. 3) 

Stack trace returned 7 entries:
[bt] (0) 0   libmxnet.so                         0x0000000104dfa685 _ZN4dmlc15LogMessageFatalD2Ev + 37
[bt] (1) 1   libmxnet.so                         0x00000001054e212f _ZN5mxnet2op14TransposeShapeERKN4nnvm9NodeAttrsEPNSt3__16vectorINS1_6TShapeENS5_9allocatorIS7_EEEESB_ + 1183
[bt] (2) 2   libmxnet.so                         0x0000000105591f7d _Z12SetShapeTypePKN4nnvm2OpERKNS_9NodeAttrsERKN5mxnet7ContextERKNSt3__16vectorINS6_7NDArrayENSA_9allocatorISC_EEEEPSF_ + 1309
[bt] (3) 3   libmxnet.so                         0x000000010559886f _Z20ImperativeInvokeImplRKN5mxnet7ContextERKN4nnvm9NodeAttrsEPNSt3__16vectorINS_7NDArrayENS7_9allocatorIS9_EEEESD_ + 815
[bt] (4) 4   libmxnet.so                         0x0000000105599a41 MXImperativeInvoke + 433
[bt] (5) 5   _ctypes.cpython-36m-darwin.so       0x0000000101fe1247 ffi_call_unix64 + 79
[bt] (6) 6   Python                              0x00007fff5bffdbe0 Python + 140730441915360

Traceback (most recent call last):
  File "/Users/jdchoi/workspace/elit/elit/component/postag.py", line 333, in <module>
    output = net(data)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/mxnet/gluon/block.py", line 268, in __call__
    return self.forward(*args)
  File "/Users/jdchoi/workspace/elit/elit/component/postag.py", line 305, in forward
    x = nd.transpose(x, (0, 2, 1))
  File "<string>", line 13, in transpose
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py", line 89, in _imperative_invoke
    c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals])))
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/mxnet/base.py", line 129, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [16:04:01] src/operator/tensor/./matrix_op-inl.h:306: Check failed: shp.ndim() == param.axes.ndim() (4 vs. 3) 

Stack trace returned 7 entries:
[bt] (0) 0   libmxnet.so                         0x0000000104dfa685 _ZN4dmlc15LogMessageFatalD2Ev + 37
[bt] (1) 1   libmxnet.so                         0x00000001054e212f _ZN5mxnet2op14TransposeShapeERKN4nnvm9NodeAttrsEPNSt3__16vectorINS1_6TShapeENS5_9allocatorIS7_EEEESB_ + 1183
[bt] (2) 2   libmxnet.so                         0x0000000105591f7d _Z12SetShapeTypePKN4nnvm2OpERKNS_9NodeAttrsERKN5mxnet7ContextERKNSt3__16vectorINS6_7NDArrayENSA_9allocatorISC_EEEEPSF_ + 1309
[bt] (3) 3   libmxnet.so                         0x000000010559886f _Z20ImperativeInvokeImplRKN5mxnet7ContextERKN4nnvm9NodeAttrsEPNSt3__16vectorINS_7NDArrayENS7_9allocatorIS9_EEEESD_ + 815
[bt] (4) 4   libmxnet.so                         0x0000000105599a41 MXImperativeInvoke + 433
[bt] (5) 5   _ctypes.cpython-36m-darwin.so       0x0000000101fe1247 ffi_call_unix64 + 79
[bt] (6) 6   Python                              0x00007fff5bffdbe0 Python + 140730441915360

It’s pretty hard to trace errors coming from libmxnet.so, especially when the logic seems correct. Any ideas? Thanks in advance.

How many dimensions does your data have? If it’s 4D, then you need something like mx.nd.transpose(x, (0, 2, 1, 3)). See the doc on transpose for more info.
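For example, to reproduce the per-sample nd.transpose(a) from the earlier forward (each 3x4x1 sample becomes 1x4x3), the batched equivalent is a single 4-axis transpose. A minimal sketch of the forward, assuming the same DummyBlock and that nd is imported from mxnet:

    def forward(self, x):
        # 2D convolution: <NDArray 2x3x4x1 @cpu(0)>
        x = self.conv(x)

        # one NDArray op, so it stays in the autograd graph:
        # reverse the per-sample axes -> <NDArray 2x1x4x3 @cpu(0)>
        x = nd.transpose(x, (0, 3, 2, 1))

        # Dense flattens everything after the batch axis by default
        x = self.fc(x)
        return x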

This works; I thought the second parameter was a shape, but it is actually the order of axes! Thank you very much!
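The same idea covers the concatenation mentioned in the original question: as long as it is done with an NDArray op such as nd.concat, the gradient keeps flowing. A minimal sketch, assuming a second feature array z whose first dimension matches the batch size:

        # flatten the conv output to (batch, features) and append z along axis 1;
        # nd.concat is recorded by autograd just like nd.transpose and nd.stack
        x = nd.concat(x.flatten(), z, dim=1)
        x = self.fc(x)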