Inference network: RNNs vs NNets

The standard variational autoencoder [1] uses a neural network to approximate the true posterior distribution by mapping an input to the mean and variance of a Gaussian distribution. A simple modification is to replace the feed-forward inference network with an RNN. That is exactly what this paper presents [2].

Intuitively, an RNN should work well on datasets in which consecutive features are highly correlated. This means that for a public dataset such as MNIST, an RNN should have no problem approximating the posterior distribution of any MNIST digit.
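One way to feed an image to a recurrent inference network is to treat each of the 28 pixel rows as one time step. The post does not say how the digits were presented to the RNN, so this row-by-row ordering is an assumption, and so are the plain tanh recurrence (a simpler stand-in for LSTM or GRU), the helper names `init_rnn_encoder` and `rnn_encode`, and the untrained random weights. A minimal NumPy sketch of the forward pass:

```python
import numpy as np

def init_rnn_encoder(rng, n_in=28, n_hidden=500, n_latent=2, scale=0.01):
    """Random (untrained) weights for a tanh-RNN encoder. Hypothetical helper."""
    shapes = {
        "W_xh": (n_in, n_hidden),
        "W_hh": (n_hidden, n_hidden),
        "W_mu": (n_hidden, n_latent),
        "W_lv": (n_hidden, n_latent),
    }
    params = {name: rng.normal(0.0, scale, shape) for name, shape in shapes.items()}
    params["b_h"] = np.zeros(n_hidden)
    params["b_mu"] = np.zeros(n_latent)
    params["b_lv"] = np.zeros(n_latent)
    return params

def rnn_encode(params, images):
    """Run the RNN over the rows of each image; return mean and log-variance of q(z|x)."""
    batch, n_steps, _ = images.shape
    h = np.zeros((batch, params["W_hh"].shape[0]))
    for t in range(n_steps):
        # Assumption: each row of pixels is one time step of the sequence.
        h = np.tanh(images[:, t, :] @ params["W_xh"] + h @ params["W_hh"] + params["b_h"])
    # Only the final hidden state is used to parameterize the Gaussian.
    mu = h @ params["W_mu"] + params["b_mu"]
    log_var = h @ params["W_lv"] + params["b_lv"]
    return mu, log_var

rng = np.random.default_rng(0)
params = init_rnn_encoder(rng)
images = rng.random((4, 28, 28))      # stand-in for 4 MNIST digits
mu, log_var = rnn_encode(params, images)
print(mu.shape, log_var.shape)
```

The only difference from a feed-forward encoder is that the hidden state is built up row by row, so correlations between neighbouring rows can influence the final mean and variance.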

I started with a classical VAE. First, I trained a VAE on the MNIST dataset with 500 hidden units for both the encoder and the decoder. I set the latent dimension to 2 so that I could quickly visualize the embedding on a 2D plot.
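To make the setup concrete, here is a minimal NumPy sketch of the encoder's forward pass with 500 hidden units and a 2D latent space, followed by the reparameterized sample. The weights are random and untrained, and the single hidden layer, the tanh activation, the log-variance parameterization, and the function names are my assumptions rather than details of the actual experiment.

```python
import numpy as np

def init_encoder(rng, n_in=784, n_hidden=500, n_latent=2, scale=0.01):
    """Random (untrained) weights for a one-hidden-layer encoder. Hypothetical helper."""
    return {
        "W_h": rng.normal(0.0, scale, (n_in, n_hidden)),
        "b_h": np.zeros(n_hidden),
        "W_mu": rng.normal(0.0, scale, (n_hidden, n_latent)),
        "b_mu": np.zeros(n_latent),
        "W_lv": rng.normal(0.0, scale, (n_hidden, n_latent)),
        "b_lv": np.zeros(n_latent),
    }

def encode(params, x):
    """Map a batch of flattened images to the mean and log-variance of q(z|x)."""
    h = np.tanh(x @ params["W_h"] + params["b_h"])
    mu = h @ params["W_mu"] + params["b_mu"]
    log_var = h @ params["W_lv"] + params["b_lv"]
    return mu, log_var

def sample_latent(rng, mu, log_var):
    """Reparameterization: z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

rng = np.random.default_rng(0)
params = init_encoder(rng)
x = rng.random((8, 784))              # stand-in for 8 flattened MNIST digits
mu, log_var = encode(params, x)
z = sample_latent(rng, mu, log_var)   # these 2D points are what the plots show
print(mu.shape, log_var.shape, z.shape)
```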



2D embedding using Neural Nets (2-layers) as inference network

Some digits are clustered together, but others are mixed because the VAE does not know the labels of the digits. It still puts similar digits nearby; for example, the 7’s are right next to the 9’s, and many 3’s and 2’s are mixed together. To get a better separation between the digit classes, the label information should be utilized. In fact, our recent publication at SIGIR’2017 utilizes the label information in order to cluster similar documents together.

But let us come back to the original research question: is an RNN really going to improve the quality of the embedding vectors?


2D embedding using LSTM as inference network


The above 2D plot shows that using an LSTM as the inference network yields a slightly different embedding space.


2D embedding vectors of randomly chosen MNIST digits using GRU as inference network

LSTM and GRU also generate slightly different embedding vectors. The recurrent models tend to spread out each digit class. For example, the 6’s (orange) are spread out. All models mixed the 4’s and 9’s together. Mixing digits together might not be a bad thing, because some handwritten 4’s are very similar to 9’s. This may indicate that the recurrent models can capture more subtle similarities between digits.

Now, we will see whether the RNN models generate better-looking digits than the standard model.






neural nets

It is difficult to tell which model is better. In terms of training time, the neural nets are the fastest and the LSTM is the slowest. It could be that we have not utilized the strength of RNNs yet. Since we are working on the MNIST dataset, it might be easy for a traditional model (neural nets) to perform well. What if we train the model on a text dataset such as 20 Newsgroups? Intuitively, an RNN should be able to capture the sequential information, so we might get a better embedding space. Next time, we will investigate further on text datasets.


[1] Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).

[2] Fabius, Otto, and Joost R. van Amersfoort. “Variational recurrent auto-encoders.” arXiv preprint arXiv:1412.6581 (2014).

Blackbox Variational Inference

When we want to infer the hidden variables of our generative model, we usually employ either a Gibbs sampler or mean-field variational inference. Both methods have pros and cons. A Gibbs sampler is easy to derive because we only need the conditional distribution of each variable, and MCMC will eventually converge to the true posterior. But it is very slow to converge, and we don’t know when it has converged. Mean-field variational inference turns the inference problem into an optimization problem. The inference procedure is then deterministic: there is no more sampling. It is faster than Gibbs sampling, but the closed-form update equations are typically harder to derive.

David Blei’s quote on these two methods:

“Variational inference is that thing you implement while waiting for your Gibbs sampler to converge.”

His message is that deriving the closed-form updates for VI might take as long as waiting for the Gibbs sampler to converge!

So here is an alternative approach from David Blei’s student: “Blackbox Variational Inference” [1]. The authors propose a generic framework for variational inference that helps us avoid deriving the variational updates by hand. The paper has a lot of mathematical details; however, the key ideas are the following:

First, it uses stochastic gradient optimization to maximize the ELBO.

Second, it computes the expectation with respect to the variational distribution using Monte Carlo samples. Basically, we draw many latent variables from the approximate posterior distribution and average the resulting noisy gradients of the ELBO.

Third, and this is an important contribution, it reduces the variance of the gradient estimate. The Monte Carlo gradient has high variance, which makes optimization difficult because the learning rate needs to be very small. The authors use Rao-Blackwellization to reduce the variance by replacing a random variable with its conditional expectation with respect to a conditioning set.

The authors also add control variates and show that a good control variate has high covariance with the function whose expectation is being computed.
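The Monte Carlo gradient and the control variate can be sketched on a toy model. Everything below is an illustration under my own assumptions, not the paper's code: the model is z ~ N(0, 1), x | z ~ N(z, 1), the variational family is q(z) = N(mu, sigma^2), only the gradient with respect to mu is estimated, and the control variate is the score function itself with its scaling estimated from the same samples (which adds a small bias). Rao-Blackwellization is not shown because this model has a single latent variable.

```python
import numpy as np

def log_joint(x, z):
    """log p(x, z) for the toy model z ~ N(0, 1), x | z ~ N(z, 1)."""
    return -0.5 * z**2 - 0.5 * (x - z) ** 2 - np.log(2.0 * np.pi)

def log_q(z, mu, sigma):
    """Log density of the variational distribution q(z) = N(mu, sigma^2)."""
    return -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)

def elbo_grad_mu(x, mu, sigma, n_samples, rng, use_control_variate=True):
    """Score-function estimate of d ELBO / d mu, plus the per-sample variance."""
    z = mu + sigma * rng.standard_normal(n_samples)    # draws from q(z)
    score = (z - mu) / sigma**2                         # d log q(z) / d mu
    f = score * (log_joint(x, z) - log_q(z, mu, sigma))
    if use_control_variate:
        # The score has zero mean under q, so subtracting a * score leaves the
        # expectation unchanged; a = Cov(f, score) / Var(score) minimizes variance.
        a = np.cov(f, score)[0, 1] / np.var(score, ddof=1)
        f = f - a * score
    return f.mean(), f.var(ddof=1)

x, mu, sigma = 2.0, 0.0, 1.0
# The same seed is used twice so that both estimators see identical samples.
plain = elbo_grad_mu(x, mu, sigma, 200_000, np.random.default_rng(0), use_control_variate=False)
reduced = elbo_grad_mu(x, mu, sigma, 200_000, np.random.default_rng(0), use_control_variate=True)
exact = x - 2.0 * mu    # analytic gradient, available only because the toy model is conjugate
print("plain   (estimate, variance):", plain)
print("reduced (estimate, variance):", reduced)
print("exact gradient:", exact)
```

Both estimates should land near the analytic gradient, while the per-sample variance with the control variate is smaller, which is what allows a larger learning rate in practice.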

Blackbox VI seems to be a useful tool for quickly exploring model assumptions.





Entropy in Variational Inference

In variational inference, the evidence lower bound (ELBO) is derived as:

\log p(x) = \log \int_z p(x, z) dz = \log E_{q(z)}[\frac{p(x, z)}{q(z)}]
\geq E_{q(z)}[\log p(x, z)] - E_{q(z)}[\log q(z)]
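The bound can be checked numerically on a model where \log p(x) is known exactly. The sketch below assumes a toy model of my choosing, z ~ N(0, 1) and x | z ~ N(z, 1), so that marginally x ~ N(0, 2) and the true posterior is N(x/2, 1/2); the function names are hypothetical.

```python
import numpy as np

def elbo_estimate(x, mu, sigma, n_samples, rng):
    """Monte Carlo estimate of E_q[log p(x, z)] - E_q[log q(z)] for q = N(mu, sigma^2)."""
    z = mu + sigma * rng.standard_normal(n_samples)
    log_joint = -0.5 * z**2 - 0.5 * (x - z) ** 2 - np.log(2.0 * np.pi)
    log_q = -0.5 * ((z - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)
    return np.mean(log_joint - log_q)

def log_evidence(x):
    """Exact log p(x): marginally x ~ N(0, 2) in this toy model."""
    return -0.25 * x**2 - 0.5 * np.log(4.0 * np.pi)

rng = np.random.default_rng(0)
x = 1.5
# A poor choice of q: the bound is loose.
loose = elbo_estimate(x, mu=0.0, sigma=1.0, n_samples=100_000, rng=rng)
# q equal to the true posterior N(x/2, 1/2): the bound is tight.
tight = elbo_estimate(x, mu=x / 2.0, sigma=np.sqrt(0.5), n_samples=100_000, rng=rng)
print(loose, tight, log_evidence(x))
```

When q(z) equals the true posterior, p(x, z) / q(z) is the constant p(x) for every z, so the inequality becomes an equality.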

At this point, I usually pay attention to the first term, the expected log joint likelihood, and treat the second term, the entropy, as extra baggage. It makes sense that the entropy term seems less relevant, especially in the mean-field approximation: when the complicated approximate distribution is just a product of much simpler factors, q(z) = \prod_i q(z_i), the entropy term is just a sum of the entropies of the individual hidden variables.

The entropy term showed up again when I worked on the variational autoencoder formulation. There, the entropy term is merged with the prior, which results in a relative entropy:

\log p(x) \geq E_{q(z)}[\log p(x | z)] + E_{q(z)}[\log p(z)] - E_{q(z)}[\log q(z)]
= E_{q(z)}[\log p(x | z)] - D_{kl}(q(z) || p(z))
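For the usual VAE choice of a diagonal Gaussian q(z) and a standard normal prior p(z), the relative entropy has a closed form. The sketch below assumes exactly that pair of distributions and compares the closed form with a Monte Carlo estimate of E_{q(z)}[\log q(z) - \log p(z)]; the function names and the example values of `mu` and `log_var` are made up for illustration.

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """Closed-form D_kl( N(mu, diag(exp(log_var))) || N(0, I) )."""
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var)

def kl_monte_carlo(mu, log_var, n_samples, rng):
    """Estimate the same quantity as E_q[log q(z) - log p(z)]."""
    sigma = np.exp(0.5 * log_var)
    z = mu + sigma * rng.standard_normal((n_samples, mu.size))
    # The -0.5 * log(2 * pi) constants of both densities cancel, so they are omitted.
    log_q = np.sum(-0.5 * ((z - mu) / sigma) ** 2 - 0.5 * log_var, axis=1)
    log_p = np.sum(-0.5 * z**2, axis=1)
    return np.mean(log_q - log_p)

rng = np.random.default_rng(0)
mu = np.array([1.0, -0.5])
log_var = np.array([-1.0, 0.3])
exact = kl_to_standard_normal(mu, log_var)
approx = kl_monte_carlo(mu, log_var, 200_000, rng)
print(exact, approx)
```

The term is zero only when q(z) matches the prior, that is, when the mean is 0 and the log-variance is 0.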

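When I pay closer attention, a relative entropy also represents the gap (slack) between the lower bound and the log-likelihood, but it is the one between q(z) and the true posterior, D_{kl}(q(z) || p(z|x)), rather than the D_{kl}(q(z) || p(z)) term above. A large gap means that I approximate the true posterior badly.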
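The size of the gap can be checked in closed form on a toy model where the true posterior is known. Assuming z ~ N(0, 1) and x | z ~ N(z, 1) (my choice of example, with hypothetical function names), the posterior is N(x/2, 1/2), and \log p(x) minus the ELBO should equal the relative entropy from q(z) to that posterior:

```python
import numpy as np

def elbo_closed_form(x, mu, sigma):
    """ELBO for z ~ N(0, 1), x | z ~ N(z, 1), q(z) = N(mu, sigma^2)."""
    expected_log_joint = (-0.5 * (mu**2 + sigma**2)
                          - 0.5 * ((x - mu) ** 2 + sigma**2)
                          - np.log(2.0 * np.pi))
    entropy = 0.5 * np.log(2.0 * np.pi * np.e * sigma**2)
    return expected_log_joint + entropy

def kl_gaussians(m1, s1, m2, s2):
    """D_kl( N(m1, s1^2) || N(m2, s2^2) ) for scalar Gaussians."""
    return np.log(s2 / s1) + (s1**2 + (m1 - m2) ** 2) / (2.0 * s2**2) - 0.5

x, mu, sigma = 1.5, 0.0, 1.0
log_px = -0.25 * x**2 - 0.5 * np.log(4.0 * np.pi)    # x ~ N(0, 2) marginally
gap = log_px - elbo_closed_form(x, mu, sigma)
# The true posterior in this model is N(x / 2, 1 / 2).
kl_to_posterior = kl_gaussians(mu, sigma, x / 2.0, np.sqrt(0.5))
print(gap, kl_to_posterior)
```

The two printed numbers agree up to floating-point error, so shrinking the gap is the same as moving q(z) closer to the true posterior.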

When I studied information theory more thoroughly, I learned that entropy can be viewed as the uncertainty of the event of interest. The more uncertainty, the higher the entropy, because we need extra bits to encode the unknown. The entropy in the ELBO can then be read as a measure of how informative the variational distribution is. If q(z) is a uniform distribution, it is useless because we cannot make any good prediction. Thus, we want q(z) to be less flat, which means more predictable and lower entropy. In the ELBO this pull comes from the expected log-likelihood term; the entropy term itself enters with a positive sign and keeps q(z) from collapsing onto a single point.
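A tiny example makes the link between flatness and entropy concrete. The sketch below uses a discrete distribution over 8 outcomes, which is my simplification (the q(z) above is continuous), and the probabilities are made up for illustration.

```python
import numpy as np

def entropy_bits(p):
    """Shannon entropy of a discrete distribution in bits (0 * log 0 is treated as 0)."""
    p = np.asarray(p, dtype=float)
    nonzero = p[p > 0]
    return float(-np.sum(nonzero * np.log2(nonzero)))

uniform = np.full(8, 1.0 / 8.0)             # completely flat: every outcome equally likely
peaked = np.array([0.93] + [0.01] * 7)      # most of the mass on one outcome

print(entropy_bits(uniform))   # 3 bits are needed for 8 equally likely outcomes
print(entropy_bits(peaked))    # well below 3 bits
```

The flat distribution needs the full 3 bits per outcome, while the peaked one needs far fewer on average because its outcome is easy to predict.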

I think I need to look at things from an information theory perspective to get a better sense of how variational inference works. This knowledge will definitely be helpful for my future research.