TimeDistributed-style layer in mx.sym?

I’m looking for the functionality of Keras’ TimeDistributed layer in MXNet’s symbol workflow: for example, running data of shape (batch_size, seq_length, channels, height, width) through a stack of 2D convolutional layers (the same layers with shared weights applied at every timestep), then pooling the results over seq_length at the end.

Anyone aware of a way to do this?

Thanks 🙂

MXNet does not have a time-distributed layer like Keras, but you can use a Dense layer with flatten=False. The Dense layer then behaves the same as Keras’ time-distributed layer.


Thanks for the response. How does this work in the sense of a wrapper? Could you give a brief example of its usage in the case where you have a symbol that is wrapped by this dense layer?

Thanks

You could do something like the following:

data = mx.sym.Variable('data')
# flatten=False applies the layer to the last axis only, preserving seq_length
fc1 = mx.sym.FullyConnected(data=data, flatten=False, num_hidden=20, name='fc1')

assuming that the placeholder for your input data is directly followed by the Dense layer. The key point is setting flatten=False in the Dense layer.
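
To make the shape behavior concrete, here is a small sketch (the input shape is made up for illustration). With flatten=False, FullyConnected is applied only to the last axis, so the seq_length axis passes through untouched:

import mxnet as mx

data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data=data, flatten=False, num_hidden=20, name='fc1')

# hypothetical input shape: (batch_size=4, seq_length=10, features=32)
_, out_shapes, _ = fc1.infer_shape(data=(4, 10, 32))
print(out_shapes)  # [(4, 10, 20)] -- seq_length is preserved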


But what about non-dense layers, such as wrapping a convolutional layer, or even an entire model, as can be done with TimeDistributed?

model = Sequential()
model.add(TimeDistributed(Conv2D(64, (3, 3)),
                          input_shape=(10, 299, 299, 3)))

Edit: Doesn’t have to be sym, can be gluon.

I think I have hacked something together for reference if anyone is interested:

from mxnet.gluon import HybridBlock


class TimeDistributed(HybridBlock):
    """Applies `model` to every timestep of an input shaped
    (batch_size, seq_length, ...), like Keras' TimeDistributed."""

    def __init__(self, model, **kwargs):
        super(TimeDistributed, self).__init__(**kwargs)
        with self.name_scope():
            self.model = model

    def apply_model(self, x, _):
        # foreach body: receives one timestep slice and the (unused) state list
        return self.model(x), []

    def hybrid_forward(self, F, x):
        x = F.swapaxes(x, 0, 1)  # swap batch and seqlen axes
        # foreach iterates over the first axis, which is now seqlen
        x, _ = F.contrib.foreach(self.apply_model, x, [])
        x = F.swapaxes(x, 0, 1)  # swap seqlen and batch axes back
        return x
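
For what it’s worth, a quick usage sketch (the Conv2D settings and input shape here are arbitrary, just to show the wrapper preserving the seq_length axis):

import mxnet as mx
from mxnet.gluon import nn

net = TimeDistributed(nn.Conv2D(channels=64, kernel_size=(3, 3)))
net.initialize()

# (batch_size=2, seq_length=10, channels=3, height=32, width=32)
x = mx.nd.random.uniform(shape=(2, 10, 3, 32, 32))
print(net(x).shape)  # (2, 10, 64, 30, 30)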

For some reason it seems to fail on backward when hybridize() is used, though.