Kronecker Product functionality

There doesn’t seem to be any Kronecker product functionality in MXNet. I’ve resorted to this (admittedly terrible) hack for the 2D case. I'm keeping it here for reference; it shouldn’t be hard to generalize it to ND:

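import mxnet as mx

def kron(a, b):
	# infer_shape: the poster's own helper (not shown), assumed to return the static 2D shape of a symbol
	s1, s2 = infer_shape(a)
	s3, s4 = infer_shape(b)
	# (s1,1,s2,1) * (1,s3,1,s4) broadcasts to (s1,s3,s2,s4) with entry [i,k,j,l] = a[i,j] * b[k,l]
	outer = mx.sym.broadcast_mul(mx.sym.reshape(a, shape=(s1, 1, s2, 1)), mx.sym.reshape(b, shape=(1, s3, 1, s4)))
	return mx.sym.reshape(outer, shape=(s1 * s3, s2 * s4))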
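For the ND generalization, the same trick works on any number of axes: interleave a singleton axis after each axis of a and before each axis of b, multiply with broadcasting, then merge each (a_i, b_i) axis pair. Below is a minimal NumPy sketch (not MXNet), assuming both inputs have the same ndim; the name kron_nd is hypothetical. It can be checked directly against np.kron.

```python
import numpy as np


def kron_nd(a, b):
    """Kronecker product of two arrays with the same ndim via reshape + broadcast."""
    if a.ndim != b.ndim:
        raise ValueError("a and b must have the same number of dimensions")
    # Interleave singleton axes: a -> (a1, 1, a2, 1, ...), b -> (1, b1, 1, b2, ...)
    a_shape = [d for s in a.shape for d in (s, 1)]
    b_shape = [d for s in b.shape for d in (1, s)]
    # Broadcasting gives shape (a1, b1, a2, b2, ...) holding a[i, j, ...] * b[k, l, ...]
    outer = a.reshape(a_shape) * b.reshape(b_shape)
    # Merge each (ai, bi) axis pair into a single axis of length ai * bi
    return outer.reshape([sa * sb for sa, sb in zip(a.shape, b.shape)])
```

In the symbolic version the only extra requirement is knowing the shapes up front, which is exactly where infer_shape comes in.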

@mseeger, in case you have an op for this 🙂

I’d do it exactly the same way. Wrapping it in a CustomOp has the advantage that it also works inside an mx.sym expression, where you cannot call infer_shape.
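A CustomOp has to supply its own output shape (in CustomOpProp.infer_shape, where the input shapes are known) plus forward and backward passes. The NumPy functions below sketch what those methods could compute for the 2D case; the names kron_infer_shape, kron_forward and kron_backward are hypothetical, and the actual MXNet wiring (mx.operator.CustomOp / CustomOpProp registration) is left out. The backward pass undoes the final reshape and contracts the incoming gradient against the other factor.

```python
import numpy as np


def kron_infer_shape(a_shape, b_shape):
    """Output shape for 2D inputs, i.e. what infer_shape in the op's property class would report."""
    (s1, s2), (s3, s4) = a_shape, b_shape
    return (s1 * s3, s2 * s4)


def kron_forward(a, b):
    s1, s2 = a.shape
    s3, s4 = b.shape
    # Same reshape/broadcast trick as the symbolic kron above
    return (a.reshape(s1, 1, s2, 1) * b.reshape(1, s3, 1, s4)).reshape(s1 * s3, s2 * s4)


def kron_backward(out_grad, a, b):
    """Gradients of sum(out_grad * kron(a, b)) with respect to a and b."""
    s1, s2 = a.shape
    s3, s4 = b.shape
    # Undo the final reshape: g[i, k, j, l] multiplies the product a[i, j] * b[k, l]
    g = out_grad.reshape(s1, s3, s2, s4)
    grad_a = np.einsum("ikjl,kl->ij", g, b)  # sum over the b indices
    grad_b = np.einsum("ikjl,ij->kl", g, a)  # sum over the a indices
    return grad_a, grad_b
```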

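But always ask yourself whether you really need the Kronecker product as an explicit matrix, or whether you only need to multiply by it. It can get very large: the product of an m×n and a p×q matrix is an mp×nq matrix, so two 1000×1000 inputs already give 10^12 entries.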
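If you only need the product with a vector, the identity (A ⊗ B) vec(X) = vec(A X Bᵀ), with row-major vec, avoids forming A ⊗ B at all: two small matmuls cost about m·n·q + m·q·p multiplications instead of storing m·n·p·q entries. A minimal NumPy sketch; kron_matvec is a hypothetical name, and translating it to mx.sym.dot / mx.sym.reshape is an untested assumption.

```python
import numpy as np


def kron_matvec(a, b, x):
    """Compute np.kron(a, b) @ x without materializing the Kronecker product.

    Uses (A kron B) vec(X) = vec(A X B^T) for row-major (C-order) vec,
    where X is x reshaped to (a.shape[1], b.shape[1]).
    """
    m, n = a.shape
    p, q = b.shape
    X = x.reshape(n, q)
    # (m, n) @ (n, q) @ (q, p) -> (m, p), flattened row-major to length m * p
    return (a @ X @ b.T).reshape(m * p)
```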