Great, many thanks for sharing @jmacglashan. I also made an attempt and ended up with the following:
class MultivariateNormalDist:
    def __init__(self, num_var, mean, sigma):
        """
        Distribution object for Multivariate Normal. Works with batches.
        Optionally works with batches and time steps, but be consistent in usage: i.e. if using time_step,
        mean, sigma and data for log_prob must all include a time_step dimension.

        :param num_var: number of variables in distribution
        :type num_var: int
        :param mean: mean for each variable,
            of shape (num_var) or
            of shape (batch_size, num_var) or
            of shape (batch_size, time_step, num_var).
        :type mean: mxnet.nd.NDArray
        :param sigma: covariance matrix,
            of shape (num_var, num_var) or
            of shape (batch_size, num_var, num_var) or
            of shape (batch_size, time_step, num_var, num_var).
        :type sigma: mxnet.nd.NDArray
        """
        self.num_var = num_var
        self.mean = mean
        self.sigma = sigma

    @staticmethod
    def inverse_using_cholesky(matrix):
        """
        Calculate inverses for a batch of matrices using Cholesky decomposition method.

        :param matrix: matrix (or matrices) to invert,
            of shape (num_var, num_var) or
            of shape (batch_size, num_var, num_var) or
            of shape (batch_size, time_step, num_var, num_var).
        :type matrix: mxnet.nd.NDArray
        :return: inverted matrix (or matrices), same shape as the input.
        :rtype: mxnet.nd.NDArray
        """
        cholesky_factor = potrf(matrix)
        return potri(cholesky_factor)

    def log_prob(self, x):
        """
        Calculate the log probability of data given the current distribution.

        See http://www.notenoughthoughts.net/posts/normal-log-likelihood-gradient.html
        and https://discuss.mxnet.apache.org/t/multivariate-gaussian-log-density-operator/1169/7

        :param x: input data,
            of shape (num_var) or
            of shape (batch_size, num_var) or
            of shape (batch_size, time_step, num_var).
        :type x: mxnet.nd.NDArray
        :return: log_probability,
            of shape (1) or
            of shape (batch_size) or
            of shape (batch_size, time_step).
        :rtype: mxnet.nd.NDArray
        """
        # Normalization constant: (k / 2) * log(2 * pi).
        # 0.5 constants are written as floats so the math also survives
        # Python 2 integer division (1 / 2 == 0 there).
        a = (self.num_var / 2.0) * math.log(2 * math.pi)
        # Factor sigma once and reuse the Cholesky factor for both the
        # log-determinant and the inverse (potri takes the factor directly).
        cholesky_factor = potrf(self.sigma)
        # BUG FIX: log|sigma| = 2 * sum(log(diag(L))) where L = cholesky(sigma).
        # The original applied sumlogdiag to sigma itself, which yields
        # sum(log(diag(sigma))) -- not the log-determinant.
        log_det_sigma = 2 * sumlogdiag(cholesky_factor)
        b = 0.5 * log_det_sigma
        sigma_inv = potri(cholesky_factor)
        # Deviation from the mean; dev_t is equivalent to a transpose on the
        # last two dims. Compute the difference once and reuse it.
        deviation = x - self.mean
        dev = deviation.expand_dims(-1)
        dev_t = deviation.expand_dims(-2)
        # batch_dot only works with ndarrays of ndim 3, and we could have
        # ndarrays of ndim 4, so flatten batch_size and time_step into one dim.
        dev_flat = dev.reshape(shape=(-1, 0, 0), reverse=1)
        sigma_inv_flat = sigma_inv.reshape(shape=(-1, 0, 0), reverse=1)
        dev_t_flat = dev_t.reshape(shape=(-1, 0, 0), reverse=1)
        # Mahalanobis term: (1/2) * dev^T * sigma^-1 * dev.
        c = 0.5 * batch_dot(batch_dot(dev_t_flat, sigma_inv_flat), dev_flat)
        # Reshape back to (batch_size, time_step) if required.
        c = c.reshape_like(b)
        log_likelihood = -a - b - c
        return log_likelihood