Help with VAE convergence

Hey,
I am working on a VAE in Gluon and I think I implemented it correctly, but it converges to a really weird latent space…

Can someone take a look and explain what I did wrong?

Thanks

import mxnet as mx
from mxnet import nd, gluon, autograd
from mxnet.gluon import nn


def sample_gaussian(mu, lv, batch_size, latent_z, ctx=mx.cpu()):
    # reparameterisation trick: z = mu + sigma * epsilon with epsilon ~ N(0, I)
    epsilon = nd.random_normal(0, 1, shape=(batch_size, latent_z), ctx=ctx)
    # lv is the log variance, so sigma = sqrt(exp(lv)) = exp(lv / 2)
    sigma = nd.sqrt(nd.exp(lv))
    z = mu + nd.multiply(sigma, epsilon)
    return z
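
# Tiny standalone check I added for the sampling helper (not part of the model
# itself): with zero mean and zero log variance, z is just standard normal
# noise of shape (batch_size, latent_z).
def _check_sample_gaussian(batch=4, latent=100):
    z = sample_gaussian(nd.zeros((batch, latent)), nd.zeros((batch, latent)), batch, latent)
    assert z.shape == (batch, latent)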

class VAEEncoder(gluon.Block):
    def __init__(self, latent_z=100, **kwargs):
        super(VAEEncoder, self).__init__(**kwargs)
        with self.name_scope():
            self.enc = nn.Sequential()
            with self.enc.name_scope():
                self.enc.add(nn.Dense(28*28, activation='relu'))
                self.enc.add(nn.Dense(300, activation='relu'))
                self.enc.add(nn.Dense(128, activation='relu'))
                self.enc.add(nn.Activation(activation='tanh'))

            self.mu = nn.Dense(latent_z) # mu = mean
            self.lv = nn.Dense(latent_z) # lv = log variance

    def forward(self, x):
        x = self.enc(x)
        mu = self.mu(x)
        lv = self.lv(x)
        return mu, lv

    def net_init(self, ctx):
        self.enc.initialize(ctx=ctx)


class VAEDecoder(gluon.Block):
    def __init__(self, latent_z=100, **kwargs):
        super(VAEDecoder, self).__init__(**kwargs)
        with self.name_scope():
            self.dec = nn.Sequential()
            with self.dec.name_scope():
                self.dec.add(nn.Dense(128, activation='relu'))
                self.dec.add(nn.Dense(300, activation='relu'))
                self.dec.add(nn.Dense(28*28))
                self.dec.add(nn.Activation(activation='tanh'))


    def forward(self, x):
        x = self.dec(x)
        return x

    def net_init(self, ctx):
        self.dec.initialize(ctx=ctx)

class VAE(gluon.Block):
    def __init__(self, latent_z=100, batch_size=1, ctx=mx.cpu(),  **kwargs):
        super(VAE, self).__init__(**kwargs)
        with self.name_scope():
            self.enc = VAEEncoder(latent_z=latent_z)
            self.dec = VAEDecoder(latent_z=latent_z)
            self.latent_z = latent_z
            self.batch_size = batch_size
            self.ctx = ctx

    def forward(self, x):
        mu, lv = self.enc(x)
        z = sample_gaussian(mu, lv, self.batch_size, self.latent_z, self.ctx)
        y = self.dec(z)
        return y, mu, lv
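
# Quick end-to-end shape check I use on CPU (separate from the training run
# further down): a (batch, 784) input should give a (batch, 784) reconstruction
# plus (batch, latent_z) mu and log variance.
def _check_vae_shapes(batch=4, latent=100):
    model = VAE(latent_z=latent, batch_size=batch)
    model.collect_params().initialize()
    y, mu, lv = model(nd.zeros((batch, 28 * 28)))
    assert y.shape == (batch, 28 * 28)
    assert mu.shape == (batch, latent) and lv.shape == (batch, latent)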


def vae_loss(x, y, mu, lv):
    l2 = gluon.loss.L2Loss()
    bce = l2(x, y) # reconstruction term (gluon's L2Loss, i.e. squared error)
    bce = nd.sum(bce)
    # KL(q(z|x) || N(0, I)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log(sigma^2))
    kld_el = (nd.power(mu, 2) + nd.exp(lv)) * -1 + 1 + lv
    kld = nd.sum(kld_el) * (-0.5)
    return bce + kld
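
# Quick check of the loss at its optimum (illustrative only): with a perfect
# reconstruction and mu = 0, log variance = 0 the approximate posterior equals
# the prior, so both the reconstruction term and the KL term should be 0.
def _check_vae_loss_at_optimum(batch=4, latent=100):
    x = nd.zeros((batch, 28 * 28))
    zeros = nd.zeros((batch, latent))
    assert vae_loss(x, x, zeros, zeros).asscalar() == 0.0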



use_gpu = True

latent_z = 100
batch_size = 128
learning_rate = 0.001
epochs = 10

ctx = mx.gpu() if use_gpu else mx.cpu()

mnist = mx.test_utils.get_mnist()
flattened_training_data = mnist["train_data"].reshape(60000, 28*28)
data_iter = mx.io.NDArrayIter(flattened_training_data, mnist['train_label'], batch_size=batch_size, shuffle=True)
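
# Purely diagnostic: I print the value range of the inputs, since that is what
# the decoder output has to match.
_peek = data_iter.next().data[0]
print('input range: [{}, {}]'.format(_peek.min().asscalar(), _peek.max().asscalar()))
data_iter.reset()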

vae = VAE(latent_z=latent_z, batch_size=batch_size, ctx=ctx)
vae.collect_params().initialize(ctx=ctx)

vae_trainer = gluon.Trainer(vae.collect_params(), 'adam', {'learning_rate': learning_rate})


for epoch in range(epochs):
    train_loss = 0
    data_iter.reset()
    for batch_index, batch in enumerate(data_iter):
        data = batch.data[0].as_in_context(ctx)
        with mx.autograd.record():
            y, mu, lv = vae(data)
            loss = vae_loss(data, y, mu, lv)
        loss.backward()
        train_loss += loss.asscalar()
        vae_trainer.step(batch_size)
        
        if batch_index % 100 == 0:
            print('Epoch {}\tBatch {}\tLoss {}'.format(epoch, batch_index, train_loss))


data_iter.reset()
batch = data_iter.next().data[0].as_in_context(ctx)
y_batch = vae(batch)[0]

import matplotlib.pyplot as plt
import matplotlib.cm as cm


plt.imshow(y_batch[10].reshape((28,28)).asnumpy(), cmap=cm.Greys)
plt.show()
plt.imshow(batch[10].reshape((28,28)).asnumpy(), cmap=cm.Greys)
plt.show()
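
# To eyeball the learned latent space itself, I also decode a few points drawn
# straight from the prior N(0, I) (inspection only, not part of training):
prior_z = nd.random_normal(0, 1, shape=(batch_size, latent_z), ctx=ctx)
prior_samples = vae.dec(prior_z)
plt.imshow(prior_samples[0].reshape((28, 28)).asnumpy(), cmap=cm.Greys)
plt.show()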