Hybridized RNN State Initialization Error

I’m using recurrent neural networks from mxnet.gluon.rnn to build a simple language model:

import mxnet as mx

class MyRNN(mx.gluon.HybridBlock):
    def __init__(self, vocab_dim, emb_dim, hidden_size, n_layers, **kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            self.embedding = mx.gluon.nn.Embedding(vocab_dim, emb_dim)
            self.rnn = mx.gluon.rnn.LSTM(hidden_size=hidden_size,
                                         num_layers=n_layers)
            self.output = mx.gluon.nn.Dense(vocab_dim, flatten=False)

    def hybrid_forward(self, f, x, states=None):
        if states is None:
            if f == mx.nd:
                # Imperative mode: the batch size can be read off the data.
                states = self.rnn.begin_state(batch_size=x.shape[0],
                                              func=mx.nd.zeros)
            else:
                # Symbolic mode: x.shape is unavailable, so fall back on
                # the default batch_size and hope it gets inferred later.
                states = self.rnn.begin_state(func=mx.sym.zeros)
        x = f.transpose(x)               # (batch, seq) -> (seq, batch)
        x = self.embedding(x)            # -> (seq, batch, emb_dim)
        x, states = self.rnn(x, states)  # -> (seq, batch, hidden_size)
        return self.output(x).swapaxes(0, 1), states

For the Symbol API, begin_state does not require batch_size (it defaults to 0, which should mean "infer from the data"). Indeed, we cannot read the shape of the symbolic input x, nor can hybrid_forward take an integer batch_size as a formal parameter. Yet once the block is hybridized, forward propagation creates a state list with a literal zero in the batch dimension, and the subsequent operations fail:

>>> x = mx.nd.random.randint(0, 10, shape=(3, 5))
>>> rnn = MyRNN(10, 4, 8, 2)
>>> rnn.hybridize()
>>> rnn.initialize()
>>> rnn(x)
Traceback (most recent call last):
  ...
mxnet.gluon.parameter.DeferredInitializationError: Parameter 'myrnn0_lstm0_l0_i2h_weight' has not been initialized yet because initialization was deferred. Actual initialization happens during the first forward pass. Please pass one batch of data through the network before accessing Parameters. You can also avoid deferred initialization by specifying in_units, num_features, etc., for network layers.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  ...
mxnet.base.MXNetError: MXNetError: Error in operator myrnn0_lstm0_rnn0: Shape inconsistent, Provided = [2,0,8], inferred shape=(2,3,8)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  ...
ValueError: Deferred initialization failed because shape cannot be inferred. MXNetError: Error in operator myrnn0_lstm0_rnn0: Shape inconsistent, Provided = [2,0,8], inferred shape=(2,3,8)
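
For comparison, without hybridize() the imperative branch of hybrid_forward reads the batch size from x.shape, so the same call should succeed, with the output shape following from the transposes:

>>> rnn = MyRNN(10, 4, 8, 2)
>>> rnn.initialize()
>>> out, states = rnn(x)
>>> out.shape
(3, 5, 10)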

I guess the only way to avoid this is to initialize the states outside of hybrid_forward (a sketch follows below). In any case, the error looks like a bug in mxnet.gluon.rnn._RNNLayer.
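
Here is a minimal sketch of that workaround (my own variant; the name MyStatefulRNN and the explicit-states calling convention are illustrative, not an official API): make states a required argument, drop the symbolic branch entirely, and create the states imperatively with begin_state before the call.

import mxnet as mx

class MyStatefulRNN(mx.gluon.HybridBlock):
    # Same model as MyRNN above, except that the caller supplies the
    # initial states, so hybrid_forward never needs to know the batch size.
    def __init__(self, vocab_dim, emb_dim, hidden_size, n_layers, **kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            self.embedding = mx.gluon.nn.Embedding(vocab_dim, emb_dim)
            self.rnn = mx.gluon.rnn.LSTM(hidden_size=hidden_size,
                                         num_layers=n_layers)
            self.output = mx.gluon.nn.Dense(vocab_dim, flatten=False)

    def hybrid_forward(self, f, x, states):
        x = self.embedding(f.transpose(x))   # (batch, seq) -> (seq, batch, emb_dim)
        x, states = self.rnn(x, states)      # states now carry a real batch size
        return self.output(x).swapaxes(0, 1), states

>>> rnn = MyStatefulRNN(10, 4, 8, 2)
>>> rnn.hybridize()
>>> rnn.initialize()
>>> states = rnn.rnn.begin_state(batch_size=x.shape[0], func=mx.nd.zeros)
>>> out, states = rnn(x, states)

Shape inference should now succeed because the states are concrete NDArrays with the real batch size; the cost is that every caller has to manage the states explicitly.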

@sanjaradylov Please help file an issue on the MXNet GitHub repository.