From 3d370ff7085dadd891fb3be2661c1e523f8bffaf Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 6 Dec 2023 19:58:30 +0100 Subject: [PATCH 1/2] call convergence_check after sampling with numpyro --- pymc/sampling/mcmc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f546837290..858d2c4a4d 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -359,6 +359,10 @@ def _sample_external_nuts( idata_kwargs=idata_kwargs, **nuts_sampler_kwargs, ) + + warns = run_convergence_checks(idata, model) + log_warnings(warns) + return idata elif sampler == "blackjax": From e28c71ea6cdff1e13fc6b8f754e401ef22223d18 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Wed, 6 Dec 2023 20:11:48 +0100 Subject: [PATCH 2/2] Call `run_convergence_checks` after sampling with blackjax --- pymc/sampling/mcmc.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 858d2c4a4d..07c99e3328 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -380,6 +380,9 @@ def _sample_external_nuts( idata_kwargs=idata_kwargs, **nuts_sampler_kwargs, ) + warns = run_convergence_checks(idata, model) + log_warnings(warns) + return idata else: