Hey,
I am working on a VAE using gluon and I think I implemented it correctly. However, it converges to a really weird latent space…
Can someone take a look and explain what I did wrong?
Thanks
import mxnet as mx
from mxnet import nd, gluon, autograd
from mxnet.gluon import nn, utils
def sample_gaussian(mu, lv, batch_size, latent_z, ctx=mx.cpu()):
epsilon = nd.random_normal(0, 1, shape=(batch_size, latent_z), ctx=ctx)
sigma = nd.sqrt(nd.exp(lv))
z = mu + nd.multiply(sigma, epsilon)
return z
class VAEEncoder(gluon.Block):
def __init__(self, latent_z=100, **kwargs):
super(VAEEncoder, self).__init__(**kwargs)
with self.name_scope():
self.enc = nn.Sequential()
with self.enc.name_scope():
self.enc.add(nn.Dense(28*28, activation='relu'))
self.enc.add(nn.Dense(300, activation='relu'))
self.enc.add(nn.Dense(128, activation='relu'))
self.enc.add(nn.Activation(activation='tanh'))
self.mu = nn.Dense(latent_z) # mu = mean
self.lv = nn.Dense(latent_z) # lv = log variance
def forward(self, x):
x = self.enc(x)
mu = self.mu(x)
lv = self.lv(x)
return mu, lv
def net_init(self, ctx):
self.enc.initialize(ctx=ctx)
class VAEDecoder(gluon.Block):
def __init__(self, latent_z=100, **kwargs):
super(VAEDecoder, self).__init__(**kwargs)
with self.name_scope():
self.dec = nn.Sequential()
with self.dec.name_scope():
self.dec.add(nn.Dense(128, activation='relu'))
self.dec.add(nn.Dense(300, activation='relu'))
self.dec.add(nn.Dense(28*28))
self.dec.add(nn.Activation(activation='tanh'))
def forward(self, x):
x = self.dec(x)
return x
def net_init(self, ctx):
self.dec.initialize(ctx=ctx)
class VAE(gluon.Block):
def __init__(self, latent_z=100, batch_size=1, ctx=mx.cpu(), **kwargs):
super(VAE, self).__init__(**kwargs)
with self.name_scope():
self.enc = VAEEncoder(latent_z=latent_z)
self.dec = VAEDecoder(latent_z=latent_z)
self.latent_z = latent_z
self.batch_size = batch_size
self.ctx = ctx
def forward(self, x):
mu, lv = self.enc(x)
z = sample_gaussian(mu, lv, self.batch_size, self.latent_z, self.ctx)
y = self.dec(z)
return y, mu, lv
def vae_loss(x, y, mu, lv):
l2 = gluon.loss.L2Loss()
bce = l2(x, y) # MSE loss
bce = nd.sum(bce)
# loss = 0.5 sum(1-log(sigma^2)+mu^2+sigma^2)
kld_el = (nd.power(mu, 2) + nd.exp(lv)) * -1 + 1 + lv
kld = nd.sum(kld_el) * (-0.5)
return bce + kld
use_gpu = True
latent_z = 100
batch_size = 128
learning_rate = 0.001
epochs = 10
ctx = mx.gpu() if use_gpu else mx.cpu()
mnist = mx.test_utils.get_mnist()
flattened_training_data = mnist["train_data"].reshape(60000, 28*28)
data_iter = mx.io.NDArrayIter(flattened_training_data, mnist['train_label'], batch_size=batch_size, shuffle=True)
vae = VAE(latent_z=latent_z, ctx=ctx)
vae.collect_params().initialize(ctx=ctx)
vae_trainer = gluon.Trainer(vae.collect_params(), 'adam', {'learning_rate': learning_rate})
for epoch in range(epochs):
train_loss = 0
data_iter.reset()
for batch_index, batch in enumerate(data_iter):
data = batch.data[0].as_in_context(ctx)
with mx.autograd.record():
y, mu, lv = vae(data)
loss = vae_loss(data, y, mu, lv)
loss.backward()
train_loss += loss.asscalar()
vae_trainer.step(batch_size)
if batch_index % 100 == 0:
print('Epoch {}\tBatch {}\tLoss {}'.format(epoch, batch_index, train_loss))
data_iter.reset()
batch = data_iter.next().data[0].as_in_context(ctx)
y_batch = vae(batch)[0]
import matplotlib.pyplot as plt
import matplotlib.cm as cm
plt.imshow(y_batch[10].reshape((28,28)).asnumpy(), cmap=cm.Greys)
plt.show()
plt.imshow(batch[10].reshape((28,28)).asnumpy(), cmap=cm.Greys)
plt.show()