Multivariate Gaussian Log Density Operator?

Currently, MXNet has no differentiable operators for the log probability density of standard distributions. For discrete distributions, this is trivial. For a vector of univariate Gaussian distributions, implementing such an operator is fairly straightforward. However, implementing an operator for a multivariate Gaussian log density is not so straightforward.

In particular, unless I’m missing something, implementing one would require more linear algebra operations, for determinants and inverses. MXNet has some matrix operations, but I think they’re too narrow in scope to support a multivariate Gaussian log density. My linear algebra could use some brushing up, though, so it’s possible I’m just not seeing how to use what is there.

Has anyone here implemented one, or does anyone see how it could be done?

Hi @jmacglashan,

Would Gluon’s autograd package help with this? See here for an example of usage.

Otherwise, what matrix operations would you need to be able to do this? You can find a more comprehensive list of linear algebra functions in the nd.linalg module here.

Hi, @thomelane

Unfortunately, I don’t think autograd will help here (though in general I do work in Gluon, I think I need a custom operator for this).

The multivariate Gaussian forward operation requires computing a matrix inverse and a determinant (see here). It’s not clear to me that the existing nd.linalg operators can even do that. nd.linalg has no determinant operator that I can see, and its inverse operator might cover too narrow a class of matrices. That is, the matrix inverse operator in MXNet says it uses Cholesky factorization, but my understanding is that Cholesky factorization requires a positive definite matrix, while a covariance matrix is only guaranteed to be positive semi-definite. I could be wrong about that, though.
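To make the positive semi-definite concern concrete, here is a small sketch in NumPy (used instead of MXNet purely to illustrate the math; `has_cholesky` is a hypothetical helper). A covariance with two perfectly correlated variables is positive semi-definite but singular: it has no inverse, and a standard Cholesky routine rejects it. Adding a small multiple of the identity makes it positive definite again.

```python
import numpy as np

# Positive semi-definite but singular covariance:
# the two variables are perfectly correlated, so det(sigma) is zero.
sigma_psd = np.array([[1.0, 1.0],
                      [1.0, 1.0]])


def has_cholesky(matrix):
    """Hypothetical helper: True if NumPy's Cholesky routine accepts `matrix`."""
    try:
        np.linalg.cholesky(matrix)
        return True
    except np.linalg.LinAlgError:
        return False


print(np.linalg.det(sigma_psd))                    # (numerically) zero: no inverse exists
print(has_cholesky(sigma_psd))                     # False: not positive definite
print(has_cholesky(sigma_psd + 1e-6 * np.eye(2)))  # True: small jitter makes it positive definite
```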

Ultimately, the tools I need to compute the forward pass are the same tools I need to write a custom backward pass. So if I could write the forward, I might as well make an operator with a custom backward, which would be more likely to be numerically stable and precise. (I’m also not sure the linear algebra operators support backward passes, which would mean writing my own operator anyway.)

Could you describe in more detail what you want to do?

Some thoughts:

  1. Alternatively, if you just need to estimate the mean and covariance, then the usual MLE should be enough.

  2. If you need more than that, I think potrf/potri and trsm (and similar ops) should be enough, given some useful facts such as: the derivative of log(det(A)) with respect to A is the transpose of the inverse of A. I suggest working out what’s needed and simplifying as much as you can given those ops.
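The log-det fact in point 2 is easy to sanity-check numerically. Here is a minimal NumPy sketch (not MXNet; `log_det` and `numerical_grad` are hypothetical helpers) comparing central finite differences of log(det(A)) against transpose(inv(A)) for a non-symmetric, invertible matrix:

```python
import numpy as np


def log_det(a):
    """log|det(a)| via slogdet, which avoids overflow/underflow of det()."""
    _, value = np.linalg.slogdet(a)
    return value


def numerical_grad(f, a, eps=1e-6):
    """Central finite-difference gradient of scalar f w.r.t. every entry of a."""
    grad = np.zeros_like(a)
    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            step = np.zeros_like(a)
            step[i, j] = eps
            grad[i, j] = (f(a + step) - f(a - step)) / (2 * eps)
    return grad


# Diagonally dominant with a positive diagonal, so det(a) > 0 and a is invertible.
a = np.array([[4.0, 1.0, 0.5],
              [0.3, 3.0, 0.2],
              [0.1, 0.7, 2.0]])

analytic = np.linalg.inv(a).T          # d log(det(A)) / dA = transpose(inv(A))
numeric = numerical_grad(log_det, a)
print(np.allclose(analytic, numeric, atol=1e-6))  # True
```

For a symmetric covariance matrix, transpose(inv(sigma)) is simply inv(sigma).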

AFAIK, MXNet’s LAPACK-based linalg ops are more efficient than those of the other frameworks. See this paper and its examples; they might be enough for you!

Just on this point, the linear algebra operators are differentiable, so you wouldn’t need to compute the backward pass manually. I’ll try and reach out to someone who worked on this component to confirm.

Thanks, guys. I’ll check out that paper, and it’s good to know that the linear algebra operators are in fact differentiable. I’d still be concerned about numerical stability without a custom operator, but it’s worth playing with.

The task I’m aiming this for is related to reinforcement learning with a policy network, so I can’t use MLE approaches and need derivatives.

It’s possible I can constrain things enough to force the covariance matrix to be positive definite and not just positive semi-definite, which would help with the inverse. I’ll have to dig in a bit more to see how to compute the determinant from the available operators. Skimming the paper, it looks like they do that in some places, so perhaps I can reconstruct the steps. It would probably be good to brush up on my linear algebra too, so I understand why it’s doing what it’s doing :stuck_out_tongue:
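One common way to get a covariance that is positive definite by construction (an assumption on my part, not something proposed in this thread) is to have the network output an unconstrained vector, fill a lower-triangular matrix L from it, force the diagonal positive with exp(), and take sigma = L * transpose(L). A NumPy sketch; `covariance_from_unconstrained` and the jitter value are hypothetical:

```python
import numpy as np


def covariance_from_unconstrained(params, dim, jitter=1e-6):
    """Hypothetical helper: build a positive definite covariance from free values.

    `params` holds dim * (dim + 1) / 2 unconstrained values that fill the lower
    triangle of L. The diagonal goes through exp(), so it is strictly positive,
    which makes L @ L.T positive definite. The jitter only adds numerical headroom.
    """
    L = np.zeros((dim, dim))
    rows, cols = np.tril_indices(dim)
    L[rows, cols] = params
    diag = np.arange(dim)
    L[diag, diag] = np.exp(L[diag, diag])  # strictly positive diagonal
    return L @ L.T + jitter * np.eye(dim)


rng = np.random.default_rng(0)
dim = 3
params = rng.normal(size=dim * (dim + 1) // 2)  # e.g. raw outputs of a policy head
sigma = covariance_from_unconstrained(params, dim)

print(np.all(np.linalg.eigvalsh(sigma) > 0))  # True: positive definite
np.linalg.cholesky(sigma)                     # succeeds, so Cholesky-based routines apply
```

With jitter=0, L is itself the Cholesky factor of sigma (a lower-triangular factor with a positive diagonal is unique), so the factorization step could be skipped entirely.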

You can check the writeup about how to compute the multivariate Gaussian log density and its derivative here.

As you can see, you need the inverse of the covariance matrix, which means that this matrix must be positive definite (otherwise the inverse won’t exist).
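For reference, this is the standard form of the log density of a k-dimensional Gaussian, its Cholesky-based pieces (with sigma = L L^T), and its gradients with respect to mu and sigma (treating the entries of sigma as independent):

$$
\log p(x \mid \mu, \Sigma) = -\frac{k}{2}\log(2\pi) - \frac{1}{2}\log\det\Sigma - \frac{1}{2}(x-\mu)^\top \Sigma^{-1}(x-\mu)
$$

$$
\Sigma = LL^\top \;\Rightarrow\; \log\det\Sigma = 2\sum_{i=1}^{k}\log L_{ii}, \qquad (x-\mu)^\top\Sigma^{-1}(x-\mu) = \lVert L^{-1}(x-\mu)\rVert^2
$$

$$
\frac{\partial \log p}{\partial \mu} = \Sigma^{-1}(x-\mu), \qquad \frac{\partial \log p}{\partial \Sigma} = \frac{1}{2}\left(\Sigma^{-1}(x-\mu)(x-\mu)^\top\Sigma^{-1} - \Sigma^{-1}\right)
$$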

All the formulas there should be expressible with the operators in the linalg namespace. The trick is that you don’t invert directly; instead, you first compute the Cholesky decomposition of the covariance matrix (sigma = L * transpose(L)). This is done by the operator potrf. Given the Cholesky factor, you can do the rest:

linalg.potri computes the inverse of sigma from its Cholesky factor L.

Further, log(det(sigma)) = log(det(L * transpose(L))) = 2 * log(det(L)) = 2 * sumlogdiag(L).
This holds because L is triangular, so its determinant is just the product of the elements on its main diagonal. MXNet’s linalg.sumlogdiag() does exactly what you want.
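Both identities can be checked quickly in NumPy (standing in for MXNet here: np.linalg.cholesky plays the role of potrf, and inverting L explicitly stands in for what potri computes):

```python
import numpy as np

# Symmetric, diagonally dominant with a positive diagonal, hence positive definite.
sigma = np.array([[4.0, 1.2, 0.4],
                  [1.2, 3.0, 0.6],
                  [0.4, 0.6, 2.0]])

L = np.linalg.cholesky(sigma)  # lower-triangular L with sigma = L @ L.T

# log(det(sigma)) = 2 * sumlogdiag(L)
log_det_via_cholesky = 2.0 * np.sum(np.log(np.diag(L)))
log_det_direct = np.log(np.linalg.det(sigma))

# inv(sigma) = inv(L).T @ inv(L), i.e. the inverse obtained from the Cholesky factor
L_inv = np.linalg.inv(L)
sigma_inv_via_cholesky = L_inv.T @ L_inv

print(np.isclose(log_det_via_cholesky, log_det_direct))           # True
print(np.allclose(sigma_inv_via_cholesky, np.linalg.inv(sigma)))  # True
```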

Note that, depending on the application, explicit inversion may not be the most stable method, and an implicit method may be preferable. For example, if you need the solution of
sigma * x = y
then rather than computing the inverse of sigma (based on the Cholesky factor), solving the system
L * transpose(L) * x = y
by means of the linalg.trsm() operator (one triangular solve with L, then one with transpose(L)) is likely numerically more stable.
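Here is a NumPy sketch of that implicit approach, with the two triangular solves written out as forward and back substitution so each step is visible. The helper names are hypothetical. In MXNet the two solves would presumably be two linalg.trsm calls, the second with its transpose option enabled; check the trsm documentation for the exact arguments.

```python
import numpy as np


def forward_substitution(L, b):
    """Solve L z = b for lower-triangular L."""
    n = L.shape[0]
    z = np.zeros_like(b, dtype=float)
    for i in range(n):
        z[i] = (b[i] - L[i, :i] @ z[:i]) / L[i, i]
    return z


def back_substitution(U, z):
    """Solve U x = z for upper-triangular U."""
    n = U.shape[0]
    x = np.zeros_like(z, dtype=float)
    for i in range(n - 1, -1, -1):
        x[i] = (z[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x


def cholesky_solve(sigma, y):
    """Solve sigma x = y via sigma = L L^T, without ever forming inv(sigma)."""
    L = np.linalg.cholesky(sigma)
    z = forward_substitution(L, y)    # L z = y
    return back_substitution(L.T, z)  # L^T x = z


sigma = np.array([[4.0, 1.2, 0.4],
                  [1.2, 3.0, 0.6],
                  [0.4, 0.6, 2.0]])
y = np.array([1.0, -2.0, 0.5])
x = cholesky_solve(sigma, y)
print(np.allclose(sigma @ x, y))  # True

# For the log density only the first solve is needed:
# y^T inv(sigma) y equals ||z||^2 where L z = y.
L = np.linalg.cholesky(sigma)
z = forward_substitution(L, y)
print(np.isclose(z @ z, y @ np.linalg.solve(sigma, y)))  # True
```

The last lines show why this matters for the log density: with y = x - mu, the quadratic term needs only one triangular solve and a squared norm.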

And just to confirm: all operators in MXNet’s linalg namespace provide backward computations (i.e. their gradients).


Don’t hesitate to ask if you have further questions. We (the authors of the linalg package) are happy to help. That said, the paper mentioned earlier in the thread should give a lot more context.

Thanks very much! I think I see where to go now. If I run into trouble, I’ll come back :slight_smile:

@jmacglashan, I’d be interested to hear how you got on. I need the same thing, actually! Many thanks in advance.

I ended up not using the operation a lot because other priorities took over, but I did seem to get it working. I’ll see if I can grab the core code to share tomorrow.

@thomelane I ended up using just the forward and letting autograd work through the linear algebra ops. This follows pretty closely what was in the links provided by @asmushetzel.

    import math

    import mxnet as mx
    from mxnet import gluon


    class LgMultivariateNormal(gluon.HybridBlock):
        """Log density of a batch of multivariate normal distributions."""

        def __init__(self, dim: int, prefix=None, params=None):
            """
            :param dim: dimensionality of the multivariate normal
            :param prefix: usual HybridBlock optional prefix passed to super
            :param params: usual HybridBlock optional params passed to super
            """
            super().__init__(prefix=prefix, params=params)
            # -k/2 * log(2*pi): the part that does not depend on the inputs
            self._constant_part = -0.5 * dim * math.log(2 * math.pi)

        def hybrid_forward(self, F, x, mu, sigma):
            # x, mu: (batch, k); sigma: (batch, k, k), must be positive definite
            x = x.expand_dims(axis=2)    # make batch of vectors a batch of kx1 matrices
            mu = mu.expand_dims(axis=2)  # same for the means

            diff = x - mu                            # batch of kx1
            diff_t = diff.transpose(axes=(0, 2, 1))  # batch of 1xk matrices

            cholesky = F.linalg.potrf(sigma)      # sigma = L * transpose(L)
            sigma_inv = F.linalg.potri(cholesky)  # inverse of sigma from L

            # log(det(sigma)) is 2 * sumlogdiag(cholesky),
            # so half of it is sumlogdiag(cholesky)
            half_log_det = F.linalg.sumlogdiag(cholesky)
            half_log_det = half_log_det.expand_dims(axis=1).expand_dims(axis=2)  # batch of 1x1

            # (x - mu)^T * inv(sigma) * (x - mu): a batch of 1x1 matrices
            quad = F.linalg.gemm2(diff_t, sigma_inv)
            quad = F.linalg.gemm2(quad, diff)

            log_density = self._constant_part - half_log_det - 0.5 * quad
            return log_density.flatten()  # shape (batch, 1)


    x = mx.nd.array([[0., 0.], [1., 1.]])
    mu = mx.nd.zeros((2, 2))
    sigma = mx.nd.array([[[2., 0.], [0., 3.]], [[2., 0.], [0., 3.]]])

    lgnormal = LgMultivariateNormal(2)
    lgnormal(x, mu, sigma)
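To sanity-check the block’s output, here is a plain NumPy reference implementation using the same example inputs (`mvn_log_density_reference` is a hypothetical helper, not part of MXNet). It uses a solve with the Cholesky factor instead of an explicit inverse. For these diagonal covariances the values can also be verified by hand: -log(2*pi) - 0.5*log(6) for x = [0, 0], minus a further 0.5 * (1/2 + 1/3) for x = [1, 1].

```python
import numpy as np


def mvn_log_density_reference(x, mu, sigma):
    """Batched multivariate normal log density in plain NumPy.

    x, mu: (batch, k); sigma: (batch, k, k), positive definite.
    """
    k = x.shape[-1]
    L = np.linalg.cholesky(sigma)         # batched Cholesky factors
    diff = (x - mu)[..., None]            # (batch, k, 1)
    z = np.linalg.solve(L, diff)          # L z = x - mu
    quad = np.sum(z ** 2, axis=(-2, -1))  # (x - mu)^T inv(sigma) (x - mu)
    half_log_det = np.sum(np.log(np.diagonal(L, axis1=-2, axis2=-1)), axis=-1)
    return -0.5 * k * np.log(2 * np.pi) - half_log_det - 0.5 * quad


x = np.array([[0., 0.], [1., 1.]])
mu = np.zeros((2, 2))
sigma = np.array([[[2., 0.], [0., 3.]], [[2., 0.], [0., 3.]]])
out = mvn_log_density_reference(x, mu, sigma)
print(out)  # approx. [-2.7338, -3.1504]
```

The Gluon block returns shape (batch, 1) because of flatten(), so compare against this (batch,) result after flattening.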

Great, many thanks for sharing @jmacglashan. I also had an attempt and ended up with the following:

    import math

    import mxnet as mx


    class MultivariateNormalDist:
        def __init__(self, num_var, mean, sigma):
            """
            Distribution object for Multivariate Normal. Works with batches.
            Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
            mean, sigma and data for log_prob must all include a time_step dimension.

            :param num_var: number of variables in distribution
            :type num_var: int
            :param mean: mean for each variable,
                of shape (num_var,) or
                of shape (batch_size, num_var) or
                of shape (batch_size, time_step, num_var).
            :type mean: mxnet.nd.NDArray
            :param sigma: covariance matrix (must be positive definite),
                of shape (num_var, num_var) or
                of shape (batch_size, num_var, num_var) or
                of shape (batch_size, time_step, num_var, num_var).
            :type sigma: mxnet.nd.NDArray
            """
            self.num_var = num_var
            self.mean = mean
            self.sigma = sigma

        @staticmethod
        def inverse_using_cholesky(matrix):
            """
            Calculate inverses for a batch of matrices using Cholesky decomposition method.

            :param matrix: matrix (or matrices) to invert,
                of shape (num_var, num_var) or
                of shape (batch_size, num_var, num_var) or
                of shape (batch_size, time_step, num_var, num_var).
            :type matrix: mxnet.nd.NDArray
            :return: inverted matrix (or matrices), same shape as the input.
            :rtype: mxnet.nd.NDArray
            """
            cholesky_factor = mx.nd.linalg.potrf(matrix)
            return mx.nd.linalg.potri(cholesky_factor)

        def log_prob(self, x):
            """
            Calculate the log probability of data given the current distribution.

            :param x: input data,
                of shape (num_var,) or
                of shape (batch_size, num_var) or
                of shape (batch_size, time_step, num_var).
            :type x: mxnet.nd.NDArray
            :return: log_probability,
                of shape (1,) or
                of shape (batch_size,) or
                of shape (batch_size, time_step).
            :rtype: mxnet.nd.NDArray
            """
            a = (self.num_var / 2) * math.log(2 * math.pi)
            # log(det(sigma)) = 2 * sumlogdiag(L), where sigma = L * transpose(L).
            # sumlogdiag must be applied to the Cholesky factor L, not to sigma itself.
            cholesky_factor = mx.nd.linalg.potrf(self.sigma)
            log_det_sigma = 2 * mx.nd.linalg.sumlogdiag(cholesky_factor)
            b = (1 / 2) * log_det_sigma
            # reuse L here; inverse_using_cholesky would factor sigma a second time
            sigma_inv = mx.nd.linalg.potri(cholesky_factor)
            # deviation from mean; dev_t is equivalent to a transpose on the last two dims.
            dev = (x - self.mean).expand_dims(-1)
            dev_t = (x - self.mean).expand_dims(-2)

            # since batch_dot only works with ndarrays with ndim of 3,
            # and we could have ndarrays with ndim of 4,
            # we flatten batch_size and time_step into a single dim.
            dev_flat = dev.reshape(shape=(-1, 0, 0), reverse=1)
            sigma_inv_flat = sigma_inv.reshape(shape=(-1, 0, 0), reverse=1)
            dev_t_flat = dev_t.reshape(shape=(-1, 0, 0), reverse=1)
            c = (1 / 2) * mx.nd.batch_dot(mx.nd.batch_dot(dev_t_flat, sigma_inv_flat), dev_flat)
            # and now reshape back to (batch_size, time_step) if required.
            c = c.reshape_like(b)

            log_likelihood = -a - b - c
            return log_likelihood
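One thing worth double-checking in a class like this: the log determinant has to come from the Cholesky factor, not from sigma itself, because 2 * sumlogdiag(sigma) equals log(det(sigma)) only when sigma is diagonal. A quick NumPy check (the `sumlogdiag` function here is a hypothetical stand-in for MXNet’s linalg.sumlogdiag):

```python
import numpy as np


def sumlogdiag(m):
    """Hypothetical NumPy stand-in for linalg.sumlogdiag: sum of log of the diagonal."""
    return np.sum(np.log(np.diag(m)))


# Correlated variables, so sigma is not diagonal.
sigma = np.array([[2.0, 0.8],
                  [0.8, 1.0]])
L = np.linalg.cholesky(sigma)

true_log_det = np.log(np.linalg.det(sigma))              # log(1.36)
print(np.isclose(2 * sumlogdiag(L), true_log_det))       # True
print(np.isclose(2 * sumlogdiag(sigma), true_log_det))   # False: that gives log(2.0)
```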

Hope to contribute this to the official MXNet repo.