How to create customized recurrent cells?

Hi! I am new to MXNet, and I got stuck when trying to create a customized recurrent cell instead of using the ones provided by the MXNet library.
Assume there is a function s' = f(s, x) that takes a state s and an element x as inputs and returns a new state s'. Using only the interfaces provided by MXNet, is there any way to efficiently (say, without for loops over the sequence) apply this function recurrently to a sequence of elements x1, x2, ..., xn and obtain the corresponding state sequence s1, s2, ..., sn?
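To pin down what is being asked, here is a plain reference version of the recurrence s_t = f(s_{t-1}, x_t), i.e. exactly the naive per-element loop the question hopes to avoid. It is a framework-agnostic NumPy sketch, not MXNet code; the toy cell (a tanh of two linear maps) and the names scan_states, toy_cell, W and U are hypothetical illustrations.

```python
import numpy as np

def scan_states(f, s0, xs):
    """Apply s' = f(s, x) over xs and return the stacked states s1..sn."""
    states = []
    s = s0
    for x in xs:              # one Python-level step per sequence element
        s = f(s, x)
        states.append(s)
    return np.stack(states)   # shape: (n, *s0.shape)

# Hypothetical toy cell: s' = tanh(W s + U x)
rng = np.random.default_rng(0)
hidden, features, n = 4, 3, 5
W = rng.standard_normal((hidden, hidden)) * 0.1
U = rng.standard_normal((hidden, features)) * 0.1

def toy_cell(s, x):
    return np.tanh(W @ s + U @ x)

xs = rng.standard_normal((n, features))
states = scan_states(toy_cell, np.zeros(hidden), xs)
print(states.shape)  # (5, 4)
```

Any MXNet-based solution should produce the same s1, ..., sn as this loop, which makes it a handy correctness check.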

I have just noticed something that might be a solution: mxnet.ndarray.contrib.while_loop. However, as noted in its documentation, this interface is only an experimental feature, and from reading the source code, it seems to be implemented with Python loops. I wonder:

  1. How stable is the API likely to be?
  2. How efficient is it compared with something like mxnet.gluon.rnn.RNN, which, as far as I can see, is implemented with cuDNN?
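To reason about these questions, here is a pure-Python model of the contract that mxnet.ndarray.contrib.while_loop exposes, as I understand it: cond(*loop_vars) decides whether to keep going, func(*loop_vars) returns (step_output, new_loop_vars), the step outputs are stacked, and max_iterations bounds the loop. This is a NumPy sketch, not MXNet's implementation: the real operator works on NDArrays (cond returns a scalar NDArray rather than a Python bool), and the names while_loop_reference, toy_cell, W and U are hypothetical. It mainly shows how the recurrence maps onto the loop variables (i, s).

```python
import numpy as np

def while_loop_reference(cond, func, loop_vars, max_iterations):
    # Pure-Python model: func(*loop_vars) -> (step_output, new_loop_vars)
    outputs = []
    for _ in range(max_iterations):
        if not cond(*loop_vars):
            break
        step_output, loop_vars = func(*loop_vars)
        outputs.append(step_output)
    return np.stack(outputs), loop_vars

rng = np.random.default_rng(0)
hidden, features, n = 4, 3, 5
W = rng.standard_normal((hidden, hidden)) * 0.1
U = rng.standard_normal((hidden, features)) * 0.1
xs = rng.standard_normal((n, features))

def toy_cell(s, x):
    # Hypothetical cell: s' = tanh(W s + U x)
    return np.tanh(W @ s + U @ x)

def cond(i, s):
    return i < n  # keep looping while sequence elements remain

def func(i, s):
    s_new = toy_cell(s, xs[i])
    return s_new, (i + 1, s_new)  # (step_output, new_loop_vars)

states, (steps, last_state) = while_loop_reference(
    cond, func, (0, np.zeros(hidden)), max_iterations=n)
print(states.shape, steps)  # (5, 4) 5
```

Like the NDArray-mode behaviour described above, this model still performs one Python-level iteration per element, which is why measuring its speed against the fused RNN layer is worthwhile.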

I think I have found a related article:

@jason_yu (1) Control flow operators are a work in progress in MXNet; they will remain experimental for a while, until we have concluded they are stable enough.

(2) The built-in RNN layer has been optimized to run on GPU through the cuDNN library, so there would indeed be a performance overhead if you implemented an RNN cell yourself.

One option is to use Gluon and Python flow operators directly. Benchmark it; it might be fast enough for your use case.
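A minimal timing harness for the "benchmark it" advice might look like the sketch below. It is plain Python, and the helper name benchmark and its parameters are hypothetical. The sync hook exists because MXNet executes operators asynchronously, so a fair wall-clock measurement should block until pending work has finished (for example by passing mx.nd.waitall) before the clock is stopped.

```python
import statistics
import time

def benchmark(fn, *args, warmup=3, repeat=10, sync=None):
    """Return the median wall-clock time (seconds) of fn(*args)."""
    for _ in range(warmup):      # exclude one-off costs (allocation, caching)
        fn(*args)
    if sync is not None:
        sync()                   # drain any work queued during warmup
    timings = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()               # wait for asynchronous work to finish
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

t = benchmark(sum, range(10_000))
print(f"median: {t * 1e6:.1f} us")
```

With MXNet, one would call something like benchmark(run_sequence, xs, sync=mx.nd.waitall), where run_sequence is your own (hypothetical) function that runs the custom cell over the sequence, and compare the result with mxnet.gluon.rnn.RNN on the same input.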


Thanks! I have tried these control flow operators, and their performance is indeed quite satisfactory even without hybridization.
