Sparse _backward_dot is slow

The sparse _backward_dot operator is slow, and its run time seems to depend on the feature dimension rather than on the number of non-zero entries.

Example script:

The number of non-zero entries is kept constant and the feature dimension is varied.

With num_features = 1,000,000, the results are:

With num_features = 100,000,000, the results are:


As the results show, the time for dot barely changes, but the time for backward dot and for the Adam update increases.
The benchmark was run on the CPU of a Mac laptop.

Is there scope for improving the backward_dot operator? I tried to follow the code in, but could not understand well enough why backward dot depends on the dimension size.

In one line: the dot operator first computes a prefix sum over all possible row indices of the row_sparse output, then uses that prefix sum to generate the data and the sorted indices of the output. The CPU and GPU approaches are almost the same, except that the GPU version leverages some existing kernels from NVIDIA's cub( ).
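To make the prefix-sum approach concrete, here is a minimal pure-Python sketch; it is not the actual MXNet C++/CUDA kernel. It assumes that the backward pass computes the gradient of the dense rhs as csr.T @ ograd, stored as a row_sparse (indices, data) pair, and that the csr input is given as plain indptr/indices/values lists. The function name backward_dot_prefix_sum and its parameters are hypothetical. Note that steps 1 to 3 each touch all num_cols (the feature_dim) slots, even when only a handful of columns hold non-zeros.

```python
from itertools import accumulate


def backward_dot_prefix_sum(indptr, indices, values, ograd, num_cols):
    """Hypothetical sketch: grad_rhs = csr.T @ ograd as row_sparse (indices, data).

    indptr, indices, values: CSR arrays of the lhs (num_rows x num_cols).
    ograd: dense output gradient, one list of floats per csr row.
    num_cols: feature_dim (N); the first three steps are O(N).
    """
    # Step 1, O(N): flag every column that holds at least one non-zero
    flags = [0] * num_cols
    for c in indices:
        flags[c] = 1
    # Step 2, O(N): inclusive prefix sum; prefix[c] - 1 is column c's output row
    prefix = list(accumulate(flags))
    nnr = prefix[-1] if prefix else 0
    # Step 3, O(N): scan all columns to emit the sorted output indices
    out_indices = [0] * nnr
    for c in range(num_cols):
        if flags[c]:
            out_indices[prefix[c] - 1] = c
    # Step 4, O(K * dim): add values[k] * ograd[row] into the matching output row
    dim = len(ograd[0]) if ograd else 0
    out_data = [[0.0] * dim for _ in range(nnr)]
    for r in range(len(indptr) - 1):
        for k in range(indptr[r], indptr[r + 1]):
            out_row = out_data[prefix[indices[k]] - 1]
            for j in range(dim):
                out_row[j] += values[k] * ograd[r][j]
    return out_indices, out_data
```

For example, a 2x5 csr with non-zeros (0,3)=2.0, (0,1)=1.0 and (1,3)=1.0, i.e. indptr=[0, 2, 3], indices=[3, 1, 3], values=[2.0, 1.0, 1.0], together with ograd=[[1.0, 2.0], [10.0, 20.0]] and num_cols=5, yields indices [1, 3] and data [[1.0, 2.0], [12.0, 24.0]].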

Profiling shows that the prefix-sum computation takes about 50% of the time. Let N denote the feature_dim of the csr input and K its number of non-zero elements. In the current implementation, both the prefix sum and the generation of the sorted indices are O(N), which is why it slows down in the case you describe.

An alternative implementation sorts csr.indices instead of computing the prefix sum, and derives the unique indices from the sorted result. Its time complexity is O(K log K), so it depends only on the number of non-zeros in the input csr, not on feature_dim. Unfortunately, in our benchmark this implementation was about 5x slower than the current one for feature_dim = 2M and batch_size = 1K, because sorting is much less cache-friendly than a prefix sum. The alternative would be faster than the baseline only when K is very small and N is very large. A faster sorting implementation, such as bitonic sort, might help reduce the cost of sorting.
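For comparison, here is a minimal pure-Python sketch of the sort-based alternative, under the same assumptions: the output is the row_sparse gradient csr.T @ ograd, the csr is given as indptr/indices/values lists, and the function name backward_dot_sort is hypothetical. Nothing in it has size feature_dim. How each non-zero finds its output row is not specified above; this sketch assumes a binary search (bisect_left) over the unique sorted column ids, which keeps the total cost at O(K log K).

```python
from bisect import bisect_left


def backward_dot_sort(indptr, indices, values, ograd):
    """Hypothetical sketch: the same row_sparse grad_rhs = csr.T @ ograd, via sorting.

    Nothing here has size feature_dim (N); the cost depends only on the
    number of non-zeros K (and the dense width of ograd).
    """
    # Step 1, O(K log K): sort a copy of the column ids of the K non-zeros
    sorted_cols = sorted(indices)
    # Step 2, O(K): keep the first occurrence of each id -> unique sorted indices
    out_indices = []
    for c in sorted_cols:
        if not out_indices or out_indices[-1] != c:
            out_indices.append(c)
    # Step 3, O(K log K + K * dim): binary-search each non-zero's output row
    dim = len(ograd[0]) if ograd else 0
    out_data = [[0.0] * dim for _ in out_indices]
    for r in range(len(indptr) - 1):
        for k in range(indptr[r], indptr[r + 1]):
            out_row = out_data[bisect_left(out_indices, indices[k])]
            for j in range(dim):
                out_row[j] += values[k] * ograd[r][j]
    return out_indices, out_data
```

Being pure Python, this sketch does not reproduce the cache effects behind the reported 5x slowdown; it only illustrates why the cost no longer grows with N.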