Day 15: CVAE with Gumbel-softmax (cont)

On Day 14, I added Gumbel-softmax to sample a digit category, but the results did not look good. I have found that switching from the straight-through Gumbel-softmax to the general (non-straight-through) Gumbel-softmax greatly helps the CVAE learn a digit distribution.

Here are the results:

gumbel_softmax_CVAE.png

The result also looks much better than the digits generated by the method used on Day 13 (expanding the expectation term). This is partly because we no longer average the loss over all classes.

I am surprised that the straight-through Gumbel-softmax does not quite work. Although I did use a temperature annealing technique (slowly converting the Gumbel-softmax to a straight-through Gumbel on the backward pass), the result was not good.

At this point, it seems that the general Gumbel-softmax performs much better, at least for the MNIST digit generation task.

 

Day 14: CVAE with Gumbel-Softmax

As noted in my previous post, expanding the expectation term does not scale when the dataset has many classes. Even for the MNIST dataset, with only 10 classes, training is slow because we have to compute the loss for every class before taking a weighted average.

For a Gaussian distribution, we use the reparameterization trick to convert the sampling step into a deterministic operation:

z \sim N(\mu, \sigma^2)

By using shift and scaling operations, we now have:

z = \mu + \sigma * \epsilon

where \epsilon is an extra input to the network, drawn from standard Gaussian noise: \epsilon \sim N(0, 1).
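
As a minimal sketch of the shift-and-scale step (plain Python, not the code used in this post), the function below draws z with mean \mu and standard deviation \sigma from standard Gaussian noise. The name `sample_reparameterized` and the values 2.0 and 0.5 are made up for illustration.

```python
import random

def sample_reparameterized(mu, sigma, rng):
    """Return z = mu + sigma * eps, where eps ~ N(0, 1)."""
    eps = rng.gauss(0.0, 1.0)  # all of the randomness lives in eps
    return mu + sigma * eps    # deterministic shift and scale of the noise

rng = random.Random(0)  # fixed seed so the run is repeatable
samples = [sample_reparameterized(2.0, 0.5, rng) for _ in range(20000)]

# The empirical mean and standard deviation should be close to mu and sigma.
mean = sum(samples) / len(samples)
std = (sum((s - mean) ** 2 for s in samples) / len(samples)) ** 0.5
print(round(mean, 2), round(std, 2))
```

Because z is now an ordinary function of \mu and \sigma, an autograd framework can pass gradients through it to the encoder that produced them.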

In our case, we also need a reparameterization trick for a categorical distribution. Gumbel-Softmax [1] is a surrogate function that approximates this sampling process. I will leave some technical details for a later post, such as why the Gumbel distribution is a suitable choice and why softmax is used to approximate the one-hot vector representation.

For now, I simply added Gumbel-softmax into my CVAE. PyTorch provides a nice API for Gumbel-Softmax, so I don't have to implement it myself. Note that there are two versions of Gumbel-softmax: (1) straight-through and (2) non-straight-through.

The non-straight-through Gumbel outputs a soft version of a one-hot encoding, so more than one entry can have a non-zero value. The straight-through Gumbel (ST-Gumbel) simply outputs a one-hot encoding. However, during backprop, ST-Gumbel uses the gradient from the non-ST Gumbel.
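
To make the two variants concrete, here is a forward-pass-only sketch in plain Python. It is not the PyTorch implementation. The function name, the `hard` flag and the example probabilities are assumptions for illustration, and gradients are not modeled at all. In PyTorch the corresponding call is `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=...)`.

```python
import math
import random

def gumbel_softmax_sample(logits, tau, rng, hard=False):
    """Forward pass of Gumbel-softmax for a single vector of logits."""
    # Gumbel(0, 1) noise: g = -log(-log(u)) with u ~ Uniform(0, 1).
    gumbels = [-math.log(-math.log(rng.uniform(1e-12, 1.0 - 1e-12)))
               for _ in logits]
    scores = [(l + g) / tau for l, g in zip(logits, gumbels)]
    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft  # soft version: every entry is non-zero
    # Hard version: one-hot at the arg max of the soft sample.
    k = soft.index(max(soft))
    return [1.0 if i == k else 0.0 for i in range(len(soft))]

rng = random.Random(0)
logits = [math.log(p) for p in (0.7, 0.2, 0.1)]
soft_sample = gumbel_softmax_sample(logits, tau=1.0, rng=rng)
hard_sample = gumbel_softmax_sample(logits, tau=1.0, rng=rng, hard=True)
print(soft_sample, hard_sample)
```

In an autograd framework, the hard variant would return the one-hot vector on the forward pass while reusing the gradient of the soft sample on the backward pass.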

Here are some initial results (without hyperparameter tuning):

gumbel.png

The output does not look correct. The cause could be the way I train the model, since the original Gumbel-Softmax paper was able to train a semi-supervised VAE. This is something I will address next time.

References:

[1] https://arxiv.org/pdf/1611.01144.pdf

Day 13: Implementation of Semi-supervised VAE

This post is a continuation of Day 12. To train a semi-supervised VAE, we need to expand the expectation as follows:

E_{q(y|x)}[E_{q(z|x)}[\log P(x|z,y) - KL(q(z|x)||p(z))] - \log q(y|x)] =

\sum_{y \in Y} q(y|x)[E_{q(z|x)}[\log P(x|z,y) - KL(q(z|x)||p(z))] - \log q(y|x)] =

\sum_{y \in Y} q(y|x) E_{q(z|x)}[\log P(x|z,y) - KL(q(z|x)||p(z))] - \sum_{y \in Y} q(y|x)\log q(y|x)

The first term turns out to be the same loss as the standard CVAE, weighted by q(y|x) and summed over all classes. The second term is the entropy of q(y|x).

Entropy Loss

Entropy is the average amount of information required to encode the outcome of an event. In our case, the event is the outcome of the image prediction.

H(y) = -\sum_{y \in Y}q(y|x) \log q(y|x)

When the entropy H(y) is high, we need more bits to encode the event. The entropy is highest when q(y|x) is a uniform distribution, which implies that if the discriminator cannot classify the images, the entropy will be high. When q(y|x) is no longer uniform, it hints that the discriminator has at least learned something. By minimizing the entropy, we encourage the model to learn a useful distribution q(y|x) instead of a boring uniform distribution.
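
A quick numerical check of this claim, as a sketch in plain Python: the entropy of a uniform distribution over 10 classes is compared with that of a peaked one. The two distributions are made-up examples, and the natural log is used to match the formula above, so the values are in nats rather than bits.

```python
import math

def entropy(q):
    """H = -sum_y q(y) * log q(y), skipping zero-probability classes."""
    return -sum(p * math.log(p) for p in q if p > 0.0)

uniform = [0.1] * 10            # the classifier has no idea which digit it is
peaked = [0.91] + [0.01] * 9    # the classifier is fairly sure about one class

h_uniform = entropy(uniform)    # equals log(10), the maximum for 10 classes
h_peaked = entropy(peaked)      # much smaller than log(10)
print(h_uniform, h_peaked)
```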

Implementation (Pseudocode)

The implementation is as follows:

  • for each mini-batch batch_x
    • for each class y
      • loss(x, y) = compute CVAE loss of batch_x and y
    • Avg loss = \sum_{y \in Y} q(y|x) * loss(x, y)
    • Avg loss += compute entropy loss for the current batch_x

This implementation can be slow during training because we have to compute the loss for every class. If we have a lot of classes, it will be very slow.
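
The loop above can be sketched for a single unlabeled example as follows. This is plain Python rather than my training code: `cvae_loss` is a hypothetical stand-in for the real per-class CVAE loss, and the entropy term is added with the sign described in the Entropy Loss section (it is minimized), which is an assumption about how the pieces fit together.

```python
import math

def unlabeled_loss(x, q_y_given_x, cvae_loss):
    """Weighted per-class loss plus the entropy term for one example x."""
    # One CVAE loss evaluation per class: this is the slow part.
    expected = sum(q * cvae_loss(x, y) for y, q in enumerate(q_y_given_x))
    # Entropy of q(y|x), skipping zero-probability classes.
    ent = -sum(q * math.log(q) for q in q_y_given_x if q > 0.0)
    return expected + ent

# Toy stand-in loss (hypothetical): distance between the input and the label.
def toy_loss(x, y):
    return abs(x - y)

total = unlabeled_loss(0, [0.5, 0.5], toy_loss)  # 0.5 * 0 + 0.5 * 1 + log(2)
print(total)
```

The cost grows linearly with the number of classes because `cvae_loss` is called once per class for every example.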

CVAE_mnist_0_01.png

Images generated by CVAE. We use 60 labeled images per class.

improved_0_01.png

Images generated by the semi-supervised VAE. We use 60 labeled images and 5400 unlabeled images per class.

To be honest, the results are not impressive. The training could be better. Currently, I simply alternate between training on labeled data and training on unlabeled data.

Next:

Another approach is to use an approximation of the discrete output. One nice approximation is Gumbel-softmax, which I will use, and possibly implement, on Day 14.

 

 

 

 

Day 12: Handling Discrete output in VAE

A limitation of GANs and VAEs is that the generator of a GAN or the encoder of a VAE must be differentiable. This prevents the model from generating discrete outputs, which can be useful for many tasks.

I start my study with the semi-supervised VAE [1]. This model is the same as CVAE but with an extra component for handling the unlabeled training data.

 

semi_VAE.png

A semi-supervised model for VAE proposed by [1]

When the input has a label, we use the first architecture (which is CVAE) to train the model. When the input does not have a label, we use a discriminator to predict a label first. Then, we train the same way as with labeled data.

To see the impact of the unsupervised learning component, we can look at how much the images from the generator improve after adding unlabeled data.

First, we want to see how bad the generated images are when we train CVAE with a limited number of training samples:

CVAE_mnist_0_95.png

Take 100% of images from each class (6000 images per class)

CVAE_mnist_0_01.png

Take 10% of images from each class (600 images per class)

CVAE_mnist_0_01.png

Take 1% of images from each class (60 images per class)

CVAE_mnist_0_001.png

Take 0.1% of images from each class (6 images per class)

It is obvious that, with so few samples, we don't have enough data for the VAE to approximate the image distribution.

Dealing with the Discrete output

It is quite straightforward to predict the label of a given unlabeled image first. But the real issue is that we can no longer train CVAE end-to-end because the output from the discriminator is discrete. We need to find a way to handle this situation.

Compute the Expectation directly (No Monte Carlo approximation)

The first approach to handling the discrete output from the encoder is to average the reconstruction error over all possible discrete outputs. This is exactly what [1] did.

We factor q(y,z|x) = q(y|x)\cdot q(z|x) and derive the ELBO:

E_{q(y|x)}[E_{q(z|x)}[\log P(x|z,y) - KL(q(z|x)||p(z))] - \log q(y|x)] =

\sum_{y \in Y} q(y|x)[E_{q(z|x)}[\log P(x|z,y) - KL(q(z|x)||p(z))] - \log q(y|x)]

Instead of sampling a label from q(y|x), [1] sidestepped the problem by expanding the expectation term. This approach works for datasets with a small number of classes, but it does not scale when there are many classes.

Reference:

[1] https://arxiv.org/abs/1406.5298

Day 11: 2D one-hot representation

In my last post, my CGAN architecture did not work: when trained, its generator learned nothing (a complete mode collapse). After a few days of research and reading a few tips online, I've found an architecture for CGAN that works!

Before I describe the specific architecture for CGAN, here are the results:

cgan_correct.png

MNIST digits generated by CGAN. It is easy to interpret each row as a stroke width or stroke style.

CGAN_fashion.png

Fashion items generated by CGAN. It is harder to interpret each row.

The losses of the discriminator and generator look much better:

DCGAN_MNIST_loss.png

Loss on MNIST dataset

CGAN_fashion_loss.png

Loss on Fashion-MNIST dataset

The most important component of CGAN is the way we combine the class label and the actual image/latent vector into one input unit for the discriminator/generator. The choice of architecture either makes or breaks the model.

Generator

It takes two inputs: latent vector and class label.

  • The latent vector is a random vector drawn from a normal distribution with zero mean and unit variance.
  • A class label is a one-hot vector.
  • We concatenate these two vectors and use a few feedforward layers to merge them.
  • Finally, we pass the new vector into deconvolutional layers.

Discriminator

It takes two inputs: generated image/real image and class label.

  • The generated image or real image is a 2D matrix.
  • A class label is converted to a 2D one-hot representation. This component was missing from my previous CGAN implementation. I will describe it a bit later.
  • We concatenate the generated/real image (1 color channel) with the 10 channels from the 2D one-hot representation.

2D one-hot representation

The idea is simple. We represent each class label as 10 matrices (or a tensor of size 10 by image width by image height). For a class label i, the i-th matrix is filled with ones and the rest are filled with zeros.
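
A small sketch of this representation using nested Python lists (a real model would use tensors; the helper names `one_hot_2d` and `concat_channels` are made up for illustration, not taken from my CGAN code). It also shows the channel-wise concatenation with a 1-channel image described in the Discriminator section.

```python
def one_hot_2d(label, num_classes, height, width):
    """Return num_classes maps of size height x width; map `label` is all ones."""
    return [[[1.0 if c == label else 0.0 for _ in range(width)]
             for _ in range(height)]
            for c in range(num_classes)]

def concat_channels(image, label_maps):
    """Stack image channels and label maps along the channel axis."""
    return image + label_maps

# A dummy 1-channel 28 x 28 image standing in for an MNIST digit.
image = [[[0.5] * 28 for _ in range(28)]]
label_maps = one_hot_2d(label=3, num_classes=10, height=28, width=28)
disc_input = concat_channels(image, label_maps)
print(len(disc_input), len(disc_input[0]), len(disc_input[0][0]))
```

Every pixel location then carries the class label alongside the pixel value, so the convolutional layers of the discriminator see the label at every spatial position.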

This representation is simple and combines nicely with the actual 2D image. I will look deeper into this choice of representation and how it might affect the discriminator.

References:

[1] https://arxiv.org/abs/1411.1784

 

Day 10: Mode Collapsing on my CGAN

I tried to implement a conditional GAN [1]. At first, it seemed to be a straightforward extension of GANs. But I ran into mode collapse, and it was a mess.

Here are samples of the digit 5 generated by the conditional GAN:

Epoch 10:

digit_5_10.png

To start off, it looks okay.

Epoch 50:

digit_5_50.png

Wait, the model ignores the class label.

Epoch 100:

digit_5_100.png

So the generator is giving up now?

Epoch 150:

digit_5_150.png

Yup, it is a complete mode collapse. The generator just gave up.

This is not good. I need to find the right architecture for CGAN or the appropriate hyperparameters.

Ideally, what I want to achieve should look like this:

CGAN_expected.png

This is not my image!

I expect to put more effort into training the CGAN model and picking the right architecture and hyperparameters. Stay tuned.

References:

[1] https://arxiv.org/abs/1411.1784