Flow Contrastive Estimation of Energy-Based Models



Ruiqi Gao 1, Erik Nijkamp 1, Diederik P. Kingma 2, Zhen Xu 2, Andrew M. Dai 2, Ying Nian Wu 1

1 University of California, Los Angeles (UCLA)
2 Google Brain


Abstract

This paper studies a training method that jointly estimates an energy-based model and a flow-based model, in which the two models are iteratively updated based on a shared adversarial value function. The joint training method has the following traits. (1) The update of the energy-based model is based on noise contrastive estimation, with the flow model serving as a strong noise distribution. (2) The update of the flow model approximately minimizes the Jensen-Shannon divergence between the flow model and the data distribution. (3) Unlike generative adversarial networks (GANs), which estimate an implicit probability distribution defined by a generator model, our method estimates two explicit probability distributions of the data. Using the proposed method, we demonstrate a significant improvement in the synthesis quality of the flow model, and show the effectiveness of unsupervised feature learning by the learned energy-based model. Furthermore, the proposed training method can be easily adapted to semi-supervised learning, where we achieve results competitive with state-of-the-art semi-supervised learning methods.
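For reference, the shared adversarial value function takes the form of noise contrastive estimation with the flow model as the noise distribution; writing $p_\theta$ for the energy-based model and $q_\alpha$ for the flow model (notation ours, summarizing the setup described in the abstract):

$$V(\theta, \alpha) = \mathbb{E}_{p_{\mathrm{data}}}\left[\log \frac{p_\theta(x)}{p_\theta(x) + q_\alpha(x)}\right] + \mathbb{E}_{q_\alpha}\left[\log \frac{q_\alpha(\tilde{x})}{p_\theta(\tilde{x}) + q_\alpha(\tilde{x})}\right].$$

The energy-based model is updated by ascending $V$, and the flow model by descending $V$; when the energy-based model is close to the data distribution, the flow update approximately minimizes the Jensen-Shannon divergence between the flow and the data.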

Paper

The paper can be downloaded at https://arxiv.org/abs/1912.00589.

Code

The TensorFlow 2 code is coming soon!

If you wish to use the code or results, please cite the following paper: 

Flow Contrastive Estimation of Energy-Based Models
@article{gao2019flow,
  title={Flow Contrastive Estimation of Energy-Based Models},
  author={Gao, Ruiqi and Nijkamp, Erik and Kingma, Diederik P and Xu, Zhen and Dai, Andrew M and Wu, Ying Nian},
  journal={arXiv preprint arXiv:1912.00589},
  year={2019}
}

Experiments

Exp 1: Density estimation on 2D synthetic data
Exp 2: Learning on real image datasets
Exp 3: Unsupervised feature learning
Exp 4: Semi-supervised learning

Experiment 1: Density estimation on 2D synthetic data

Figure 1. Training process of EBM and Glow models using our method (FCE) on 2-dimensional data distributions.

Figure 2. Density estimation accuracy on 2D data drawn from a mixture of 8 Gaussian distributions.
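To make the training dynamics in Figures 1 and 2 concrete, below is a minimal, self-contained sketch of FCE on 2D data in TensorFlow 2. It is not the authors' implementation: the Glow model is replaced by a trainable Gaussian stand-in so the example stays short and runnable, and sample_ring_mixture is a hypothetical generator for the 8-Gaussian mixture.

import math
import tensorflow as tf

def sample_ring_mixture(n):
    # Hypothetical stand-in for the 8-Gaussian-mixture data of Figure 2.
    k = tf.random.uniform([n], 0, 8, dtype=tf.int32)
    angles = tf.cast(k, tf.float32) * (2.0 * math.pi / 8.0)
    centers = 4.0 * tf.stack([tf.cos(angles), tf.sin(angles)], axis=-1)
    return centers + 0.1 * tf.random.normal([n, 2])

class EBM(tf.keras.Model):
    # p_theta(x) = exp(f_theta(x)) / Z; log Z is a trainable parameter, as in NCE.
    def __init__(self):
        super().__init__()
        self.f = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation="swish"),
            tf.keras.layers.Dense(64, activation="swish"),
            tf.keras.layers.Dense(1),
        ])
        self.log_z = tf.Variable(0.0)

    def log_density(self, x):
        return tf.squeeze(self.f(x), axis=-1) - self.log_z

class GaussianFlow(tf.keras.Model):
    # Toy stand-in for Glow: exact log-probability and reparameterized sampling.
    def __init__(self, dim=2):
        super().__init__()
        self.mean = tf.Variable(tf.zeros(dim))
        self.log_scale = tf.Variable(tf.zeros(dim))

    def sample(self, n):
        eps = tf.random.normal([n, self.mean.shape[0]])
        return self.mean + tf.exp(self.log_scale) * eps

    def log_prob(self, x):
        z = (x - self.mean) * tf.exp(-self.log_scale)
        return tf.reduce_sum(
            -0.5 * (z ** 2 + math.log(2.0 * math.pi)) - self.log_scale, axis=-1)

ebm, flow = EBM(), GaussianFlow()
opt_ebm = tf.keras.optimizers.Adam(1e-3)
opt_flow = tf.keras.optimizers.Adam(1e-3)

def value(x_data, x_flow):
    # V = E_data[log p/(p+q)] + E_flow[log q/(p+q)], computed via
    # log-sigmoid of (log p - log q) for numerical stability.
    logit_data = ebm.log_density(x_data) - flow.log_prob(x_data)
    logit_flow = ebm.log_density(x_flow) - flow.log_prob(x_flow)
    return (tf.reduce_mean(tf.math.log_sigmoid(logit_data))
            + tf.reduce_mean(tf.math.log_sigmoid(-logit_flow)))

for step in range(2000):
    x = sample_ring_mixture(128)
    with tf.GradientTape() as tape_e:  # EBM update: ascend V
        loss_e = -value(x, flow.sample(128))
    grads = tape_e.gradient(loss_e, ebm.trainable_variables)
    opt_ebm.apply_gradients(zip(grads, ebm.trainable_variables))
    with tf.GradientTape() as tape_f:  # flow update: descend V
        loss_f = value(x, flow.sample(128))
    grads = tape_f.gradient(loss_f, flow.trainable_variables)
    opt_flow.apply_gradients(zip(grads, flow.trainable_variables))

The key point is that both updates share the single function value(): the EBM ascends it (noise contrastive estimation against flow samples), while the flow descends it.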

Experiment 2: Learning on real image datasets

Figure 3. Synthesized examples from the Glow model on SVHN. Left panel: estimated by MLE; right panel: estimated by our method (FCE).

Figure 4. Synthesized examples from the Glow model on CIFAR-10. Left panel: estimated by MLE; right panel: estimated by our method (FCE).

Figure 5. Synthesized examples from the Glow model on CelebA. Left panel: estimated by MLE; right panel: estimated by our method (FCE).


Table 1. FID scores for generated samples.
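For context, FID (Fréchet Inception Distance) is the standard sample-quality metric here (general definition, not specific to this paper): Gaussians $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are fit to Inception features of real and generated samples, and

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right),$$

with lower values indicating better samples.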

Table 2. Bits per dimension on test data.
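Bits per dimension is the negative log-likelihood in base 2 divided by the data dimensionality $D$ (standard definition; lower is better):

$$\mathrm{bpd}(x) = \frac{-\log_2 p(x)}{D} = \frac{-\ln p(x)}{D \ln 2}.$$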

Experiment 3: Unsupervised feature learning

Figure 6. SVHN test-set classification accuracy as a function of the number of labeled examples. Features are extracted from the top-layer feature maps, and a linear classifier is trained on the extracted features.


Table 3. Test-set classification error of an L2-SVM classifier trained on the concatenated features learned from SVHN. DDGM stands for Deep Directed Generative Models. For a fair comparison, all energy-based and discriminative models are trained with the same model structure.
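As a rough sketch of the linear-evaluation protocol behind Figure 6 and Table 3 (ours, not the authors' code; the feature arrays below are random stand-ins for the extracted EBM feature maps):

import numpy as np
from sklearn.svm import LinearSVC

# Stand-in arrays: in the real protocol these are the (flattened, concatenated)
# top-layer EBM feature maps for SVHN images, with their class labels.
rng = np.random.default_rng(0)
feats_train, y_train = rng.normal(size=(1000, 256)), rng.integers(0, 10, 1000)
feats_test, y_test = rng.normal(size=(200, 256)), rng.integers(0, 10, 200)

clf = LinearSVC(loss="squared_hinge", C=1.0)  # L2-SVM, as in Table 3
clf.fit(feats_train, y_train)
print("test error: %.2f%%" % (100.0 * (1.0 - clf.score(feats_test, y_test))))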

Experiment 4: Semi-supervised learning

Figure 7. Illustration of FCE for semi-supervised learning on a 2D example, where the data distribution consists of two spirals belonging to two categories. Within each panel, the top left shows the learned unconditional EBM and the top right shows the learned Glow model; the bottom two show the learned class-conditional EBMs. Of the observed data, seven labeled points are provided for each category.
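One way to read the four panels (a sketch of the setup; the precise objective is given in the paper): the class-conditional densities combine into the unconditional EBM by marginalizing over the label,

$$p_\theta(x) = \sum_{y} p_\theta(x, y), \qquad p_\theta(x, y) \propto \exp\!\left(f_\theta(x, y)\right),$$

so unlabeled examples enter the FCE value function through $p_\theta(x)$, while the few labeled examples additionally fit the softmax conditional $p_\theta(y \mid x)$.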


Table 4. Semi-supervised classification error (%) on the SVHN test set. Marked entries indicate results we derived by running the released code; * indicates that the method uses data augmentation. The other cited results are taken from the original papers. Our results are averaged over three runs.

Acknowledgment

The work is partially supported by DARPA XAI project N66001-17-2-4029 and ARO project W911NF1810296. We thank Pavel Sountsov, Alex Alemi, Matthew D. Hoffman and Srinivas Vasudevan for their helpful discussions.
