diff --git a/README.md b/README.md
index 5d5e680..a9eaee7 100644
--- a/README.md
+++ b/README.md
@@ -14,6 +14,9 @@ In this repository, we implement FAB and provide the code to reproduce our exper
 details about our method and the results of our experiments, please read [our paper](https://arxiv.org/abs/2208.01893).
 
+
+**FAB in JAX**: See the JAX implementation of the FAB algorithm in the [fab-jax](https://github.com/lollcat/fab-jax) repo.
+
 **Note**: The most important thing to get right when applying FAB to a given problem is to
 make sure that AIS is returning reasonable samples, where by reasonable we mean that the
 samples from AIS are closer to the target than the flow.
 See [About the code](#about-the-code) for further details on how to use the FAB codebase on new problems.
@@ -36,6 +39,8 @@
 ## Experiments
+
+NB: See README within experiments/{problem-name} for further details on training and evaluation for each problem.
+
 ### Gaussian Mixture Model
 Open In Colab
@@ -110,7 +115,7 @@ The main FAB loss can be found in [core.py](fab/core.py), and we provide a simpl
 train a flow with this loss (or other flow - loss combinations that meet the spec) in [train.py](fab/train.py)
 The FAB training algorithm **with** the prioritised buffer can be found in
 [train_with_prioritised_buffer.py](fab/train_with_prioritised_buffer.py).
 Additionally, we provide the code for running the SNR/dimensionality analysis with p and q set to independent Gaussians.
-in the [fab-jax](https://github.com/lollcat/fab-jax-old) repository.
+in the [fab-jax-old](https://github.com/lollcat/fab-jax-old) repository.
 For training the CRAFT model on the GMM problem we forked the [Annealed Flow Transport repository](https://github.com/deepmind/annealed_flow_transport).
 This fork may be found [here](https://github.com/lollcat/annealed_flow_transport), and may be used for training the CRAFT model.
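As a quick illustration of the AIS sanity check recommended in the README note above: before training with FAB on a new problem, compare how samples from the flow and from AIS score under the target. The sketch below uses a toy Gaussian target and toy sample batches standing in for the fab flow and AIS sampler; these stand-ins are assumptions for illustration, not the package's actual API.

```python
import torch

torch.manual_seed(0)

def target_log_prob(x):
    # Toy stand-in for the (unnormalised) target log density.
    return -0.5 * ((x - 2.0) ** 2).sum(-1)

# Toy stand-ins for batches drawn from the flow and from AIS.
flow_samples = torch.randn(1000, 2)
ais_samples = torch.randn(1000, 2) + 1.5

# AIS samples should score clearly higher under the target than flow samples;
# if they do not, tune the transition operator / number of intermediate
# distributions before expecting FAB training to work well.
print("mean target log prob (flow):", target_log_prob(flow_samples).mean().item())
print("mean target log prob (AIS): ", target_log_prob(ais_samples).mean().item())
```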
diff --git a/experiments/many_well/evaluation.py b/experiments/many_well/evaluation.py
index 68654f5..0710161 100644
--- a/experiments/many_well/evaluation.py
+++ b/experiments/many_well/evaluation.py
@@ -27,11 +27,10 @@ def evaluate_many_well(cfg: DictConfig, path_to_model: str, target, num_samples=
 def main(cfg: DictConfig):
     """Evaluate each of the models, assume model checkpoints are saved as {model_name}_seed{i}.pt,
     where the model names for each method are `model_names` and `seeds` below."""
-    # model_names = ["target_kld", "flow_nis", "flow_kld", "rbd", "snf_hmc", "fab_no_buffer",
-    #                "fab_buffer"]
-    model_names = ["rbd", "snf_hmc"]
+    model_names = ["target_kld", "flow_nis", "flow_kld", "rbd", "snf_hmc", "fab_no_buffer",
+                   "fab_buffer"]
     seeds = [1, 2, 3]
-    num_samples = int(5e4)
+    num_samples = int(5e4)  # Divided into 50 runs of 1000
 
     results = pd.DataFrame()
     for model_name in model_names:
@@ -61,7 +60,7 @@ def main(cfg: DictConfig):
     keys = ["eval_ess_flow", 'test_set_exact_mean_log_prob', 'test_set_modes_mean_log_prob',
-            'MSE_log_Z_estimate', "forward_kl"]
+            'relative_MSE_Z_estimate', 'abs_MSE_log_Z_estimate', "forward_kl"]
     print("\n ******* mean ********************** \n")
     print(results.groupby("model_name").mean()[keys].to_latex())
     print("\n ******* std ********************** \n")
diff --git a/fab/core.py b/fab/core.py
index a29d18a..c5f2cc1 100644
--- a/fab/core.py
+++ b/fab/core.py
@@ -251,7 +251,7 @@ def load(self,
             base_distribution=self.flow,
             target_log_prob=self.target_distribution.log_prob,
             transition_operator=self.transition_operator,
-            p_target=self.p_target,
+            p_target=False,
             alpha=self.alpha,
             n_intermediate_distributions=self.n_intermediate_distributions,
             distribution_spacing_type=self.ais_distribution_spacing)
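One plausible reading of the two renamed evaluation keys in the hunk above: `relative_MSE_Z_estimate` as the mean squared relative error of the normalising-constant estimate, and `abs_MSE_log_Z_estimate` as the mean squared error of the log normalising-constant estimate. The helper below is a hypothetical sketch of those definitions, not the repo's exact computation:

```python
import torch

def z_estimate_errors(log_z_estimates: torch.Tensor, true_log_z: float):
    """Hypothetical sketch: mean squared relative error of the Z estimate, and
    mean squared error of the log Z estimate, given per-run log Z estimates."""
    relative_se_z = (torch.exp(log_z_estimates - true_log_z) - 1.0) ** 2  # ((Z_hat - Z) / Z)^2
    abs_se_log_z = (log_z_estimates - true_log_z) ** 2                    # (log Z_hat - log Z)^2
    return relative_se_z.mean(), abs_se_log_z.mean()

# Example with estimates scattered around the true value log Z = 0.0:
rel_mse, abs_mse = z_estimate_errors(torch.tensor([0.1, -0.2, 0.05]), 0.0)
print(rel_mse.item(), abs_mse.item())
```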