Lazy update with Adam optimizer is much slower for sparse input

Hi,

I wanted to train a neural network with sparse input, so I built the network in the following way:

import mxnet as mx

# input mini-batches are stored in CSR format
data = mx.symbol.Variable('data', stype='csr')
# hidden layer: row_sparse weight so its gradient is also row_sparse
fc1_weight = mx.symbol.Variable("fc1_weight", shape=(100000, 200), stype='row_sparse')
net = mx.symbol.dot(data, fc1_weight)
fc1_bias = mx.symbol.Variable(name='fc1_bias', shape=(200,))
net = mx.symbol.broadcast_add(lhs=net, rhs=fc1_bias)
net = mx.symbol.Activation(data=net, name='ac1', act_type='sigmoid')
# output layer
net = mx.sym.FullyConnected(data=net, name='fcout', num_hidden=120000)
net = mx.sym.Activation(data=net, name='acout', act_type='sigmoid')
label = mx.sym.Variable('softmax_label')
net = mx.sym.Custom(data=net, label=label, name='ce', op_type='CrossEntropyLoss')

The training code just prepares the mini-batch input and label, then calls forward/backward/update. When the Adam optimizer is initialized with lazy_update set to True, the time per mini-batch is about 40% lower than for dense training, i.e., training on the same input/label data but with the network built without csr and row_sparse inputs.
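For reference, this is roughly what the training loop looks like (a minimal sketch using the Module API; the names train_iter, batch_size, and num_epochs and the hyperparameters are placeholders, not my exact values):

mod = mx.mod.Module(symbol=net, data_names=['data'],
                    label_names=['softmax_label'], context=mx.gpu(0))
mod.bind(data_shapes=[('data', (batch_size, 100000))],
         label_shapes=[('softmax_label', (batch_size, 120000))])
mod.init_params()
# lazy_update=True is the setting that is ~40% faster than dense;
# switching it to False triggers the slowdown described below
adam = mx.optimizer.Adam(learning_rate=0.001, lazy_update=True)
mod.init_optimizer(optimizer=adam)

for epoch in range(num_epochs):
    for batch in train_iter:          # batch.data[0] is a CSRNDArray
        mod.forward(batch, is_train=True)
        mod.backward()
        mod.update()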

However, if lazy_update is set to False and no other code is changed, training becomes much slower, roughly 3-4 times slower than dense training.

I did some profiling. There are three adam_update operations per mini-batch; the first one should be the update of the weight matrix between the input layer and the hidden layer. A GPU-to-GPU sync copy happens right after that adam_update starts. In dense training this sync copy takes about 15 ms, but for sparse training with lazy_update disabled it takes almost 600 ms to complete.
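In case it helps reproduce, this is roughly how I collected the timings with the built-in profiler (the filename and the exact set_config arguments here are just a sketch):

mx.profiler.set_config(profile_all=True, aggregate_stats=True,
                       filename='sparse_adam_profile.json')
mx.profiler.set_state('run')

# ... run a few mini-batches of forward/backward/update here ...

mx.nd.waitall()                # wait for all pending GPU work to finish
mx.profiler.set_state('stop')
print(mx.profiler.dumps())     # aggregated per-operator timings, incl. adam_update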

All training was on a single GPU, with no data or model parallelism. The input batches for sparse training were prepared as mx.nd.sparse.csr_matrix, and for dense training as plain mx.nd.array. The label batches are always mx.nd.array.
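For completeness, the mini-batch preparation looks roughly like this (a sketch with a made-up batch size and a single non-zero entry just to show the conversion; my real data comes from a data iterator):

import numpy as np

batch_size = 128                      # placeholder value
dense_batch = np.zeros((batch_size, 100000), dtype='float32')
dense_batch[0, 5] = 1.0               # example non-zero feature

# sparse input batch in CSR format on the training GPU
csr_batch = mx.nd.sparse.csr_matrix(dense_batch, ctx=mx.gpu(0))

# labels stay dense
label_batch = mx.nd.array(np.zeros((batch_size, 120000), dtype='float32'),
                          ctx=mx.gpu(0))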

Please kindly take a look. Much appreciated.

This is caused by the inefficient GPU implementation of sparse adam_update when lazy_update is disabled. I posted some details in https://github.com/apache/incubator-mxnet/pull/10062