MXNet and sparse attention

Hello everyone,
I am trying to implement a LogSparse Transformer (like this one, or this one).

I managed to reduce memory consumption through the Module API and MXNet Memonger; however, I have some issues with sparse attention.

I feel like something is missing in the nd.array sparse API: the CSR format can only handle 2D data, while a transformer produces tensors with more than two dimensions (batch, heads, sequence, sequence).
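One workaround I have been considering is folding the extra axes into the row axis so the attention matrices become one big 2D block-diagonal CSR matrix. Here is a minimal sketch of the idea using `scipy.sparse` (the 2D restriction is inherent to the CSR format itself, so the same folding would apply to MXNet's `CSRNDArray`; the shapes and the log-sparse mask below are just illustrative):

```python
import numpy as np
from scipy.sparse import csr_matrix, block_diag

# Toy shapes: a batch of (seq, seq) attention score matrices.
batch, seq, d = 2, 8, 4
rng = np.random.default_rng(0)
scores = rng.random((batch, seq, seq)).astype(np.float32)
values = rng.random((batch, seq, d)).astype(np.float32)

# Log-sparse mask: position i attends to itself and to i - 2^k, k = 0, 1, ...
mask = np.zeros((seq, seq), dtype=bool)
for i in range(seq):
    mask[i, i] = True
    k = 0
    while i - 2**k >= 0:
        mask[i, i - 2**k] = True
        k += 1

# CSR is 2D only, so stack the per-batch sparse matrices into one
# block-diagonal (batch*seq, batch*seq) operator.
sparse_scores = block_diag(
    [csr_matrix(scores[b] * mask) for b in range(batch)], format="csr"
)

# Fold values the same way: (batch, seq, d) -> (batch*seq, d),
# multiply, then unfold the result back to 3D.
out = sparse_scores @ values.reshape(batch * seq, d)
out = out.reshape(batch, seq, d)
```

The block-diagonal trick keeps everything in a single 2D sparse matmul, at the cost of building the block structure; whether that pays off memory-wise depends on how sparse the mask really is.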

Am I wrong? Does MXNet have all the tools needed to implement sparse attention? Maybe there are some tricks to overcome these limitations.