This repository contains a reimplementation of WAE (Wasserstein Auto-Encoders) in TensorFlow.
I made one tweak on top of the paper: I used the Wasserstein distance to penalize the encoder.
I personally believe this implementation is much clearer and easier to read and, more importantly, the code almost exactly matches the algorithm shown in the paper, so I hope it helps anyone who wants to dig in further. Enjoy 🍺!
- MMD is implemented
- Question 1: What exactly is the inverse multiquadratic kernel? Its formula looks a bit different from the one in other resources.
- Question 2: In the original implementation, why is MMD evaluated on multiple scales (different C values) even though the true scale is given by the prior? Doesn't that produce wrong MMD values and make Q_z diverge from P_z?
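For concreteness, here is a minimal pure-Python sketch of the MMD penalty with the inverse multiquadratic kernel k(x, y) = C / (C + ||x - y||^2), using the unbiased estimator from the paper's WAE-MMD algorithm. It is not the code in this repository: the function names are hypothetical, the base value C = 2 * z_dim * sigma^2 (for a N(0, sigma^2 I) prior) is an assumption about the intended scale, and the multiplier list in `mmd_imq_multiscale` is an assumed example of the multi-scale evaluation that Question 2 asks about.

```python
import random


def imq_kernel(x, y, c):
    """Inverse multiquadratic kernel as used in WAE: k(x, y) = C / (C + ||x - y||^2)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return c / (c + sq_dist)


def mmd_imq(qz, pz, c):
    """Unbiased MMD^2 estimate between encoded codes qz ~ Q_Z and prior samples pz ~ P_Z.

    Both batches must have the same size n > 1, as in the WAE-MMD algorithm.
    """
    n = len(qz)
    if n != len(pz) or n < 2:
        raise ValueError("qz and pz must be batches of equal size n > 1")
    # Within-batch terms skip the diagonal (i == j) to keep the estimate unbiased.
    k_pp = sum(imq_kernel(pz[i], pz[j], c) for i in range(n) for j in range(n) if i != j)
    k_qq = sum(imq_kernel(qz[i], qz[j], c) for i in range(n) for j in range(n) if i != j)
    # The cross term averages over all n * n pairs.
    k_pq = sum(imq_kernel(pz[i], qz[j], c) for i in range(n) for j in range(n))
    return (k_pp + k_qq) / (n * (n - 1)) - 2.0 * k_pq / (n * n)


def mmd_imq_multiscale(qz, pz, sigma2, scales=(0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0)):
    """Sum of IMQ MMD^2 estimates over several multiples of a base C (assumed scales)."""
    z_dim = len(pz[0])
    c_base = 2.0 * z_dim * sigma2  # assumed base scale for a N(0, sigma2 * I) prior
    return sum(mmd_imq(qz, pz, c_base * s) for s in scales)


if __name__ == "__main__":
    rng = random.Random(0)
    z_dim, n = 2, 64
    pz = [[rng.gauss(0.0, 1.0) for _ in range(z_dim)] for _ in range(n)]
    qz = [[rng.gauss(3.0, 1.0) for _ in range(z_dim)] for _ in range(n)]  # shifted codes
    c = 2.0 * z_dim * 1.0
    print("single-scale MMD^2:", mmd_imq(qz, pz, c))
    print("multi-scale MMD^2:", mmd_imq_multiscale(qz, pz, 1.0))
```

One observation that may be relevant to Question 2: the estimate is linear in the kernel, so the multi-scale sum is exactly the MMD^2 under the summed kernel, and a sum of positive definite kernels is still positive definite. It is therefore a discrepancy under a kernel that mixes several bandwidths rather than an arbitrary number, although its magnitude (and so the effective weight of lambda) changes.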
- Python 3.x (tested with Python 3.5)
- TF v1.x (tested with 1.7rc0)
- tqdm
- etc. (please report any other dependencies you find)
Check the file to change the target dataset or to adjust hyperparameters such as z_dim.
See the MNIST Plot.ipynb and CelebA Plot.ipynb notebooks with Jupyter Notebook.
A pretrained MNIST model is included in the repository, while the CelebA model is uploaded here.
Please download the zip file and decompress it so that its files end up at assets/pretrained_models/celeba/last*. Alternatively, you can easily modify the path in the first cell of the notebook.
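If you prefer to unpack the archive from Python, the sketch below uses only the standard library. `extract_pretrained` is a hypothetical helper, not part of this repository, and it assumes the zip holds the last* checkpoint files at its top level; if they sit inside a folder in the archive, adjust `dest_dir` or the notebook path accordingly.

```python
import glob
import os
import zipfile

# Default target directory, taken from the path mentioned above.
DEFAULT_DEST = os.path.join("assets", "pretrained_models", "celeba")


def extract_pretrained(zip_path, dest_dir=DEFAULT_DEST):
    """Unpack the CelebA checkpoint archive into dest_dir and return the extracted last* files.

    Assumes the checkpoint files sit at the top level of the archive.
    """
    os.makedirs(dest_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
    return sorted(glob.glob(os.path.join(dest_dir, "last*")))
```

Call it with the path of the downloaded zip, e.g. `extract_pretrained("path/to/downloaded.zip")` (a placeholder path); an empty return list means the archive layout differs from the assumption above.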
Trained on a GTX 1080 Ti GPU for about 30 minutes.
Reconstruction Results
(top) original images from the MNIST validation set; (bottom) reconstructed images
The reconstructions are not as sharp as the authors report, but this may be due to insufficient training and untuned hyperparameters such as lambda or the number of layers.
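For reference, lambda is the weight on the latent-space penalty in the WAE objective from the paper, where c is the reconstruction cost, G is the decoder, Q(Z|X) is the encoder, and D_Z is a divergence between the encoded distribution Q_Z and the prior P_Z (for example MMD):

$$
D_{\mathrm{WAE}}(P_X, P_G) = \inf_{Q(Z \mid X) \in \mathcal{Q}} \mathbb{E}_{P_X} \mathbb{E}_{Q(Z \mid X)} \big[ c\big(X, G(Z)\big) \big] + \lambda \, \mathcal{D}_Z(Q_Z, P_Z)
$$

A larger lambda pushes Q_Z closer to P_Z, which tends to help random samples but can trade off against reconstruction quality, so it is a natural first knob to tune.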
- Trained on a GTX 1080 Ti GPU for about 1 day.
- Encodes and decodes images of size 64 by 64.
- Reconstruction Results
(top) original images from the CelebA validation set; (bottom) reconstructed images
- Randomly Sampled Images
With a fully trained model, the results look pretty nice! Can we still say that AE variants generate blurry images?
- Interpolation in $z$ space
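The random samples and the interpolation above both come from decoding points in $z$ space. The sketch below covers only the latent-space side in plain Python; it assumes an isotropic Gaussian prior N(0, sigma^2 I) and straight-line interpolation, and each resulting code would still have to pass through the trained decoder (not shown) to become an image. `sample_prior` and `interpolate` are hypothetical helpers, not functions from this repository.

```python
import random


def sample_prior(n, z_dim, sigma=1.0, seed=None):
    """Draw n latent codes from an (assumed) isotropic Gaussian prior N(0, sigma^2 I)."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, sigma) for _ in range(z_dim)] for _ in range(n)]


def interpolate(z_start, z_end, steps):
    """Return `steps` evenly spaced codes on the line from z_start to z_end, endpoints included."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    path = []
    for i in range(steps):
        t = i / (steps - 1)
        path.append([(1.0 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return path


if __name__ == "__main__":
    z_a, z_b = sample_prior(2, z_dim=8, seed=0)
    for z in interpolate(z_a, z_b, steps=5):
        print([round(v, 3) for v in z])  # each z would be fed to the decoder
```

Linear interpolation is the simplest choice; spherical interpolation is a common alternative when the prior is Gaussian.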
- Other datasets (CelebA or CIFAR10)
- Link to the authors' original TF implementation