Skip to content

Commit 465d8ac

Browse files
committed
Add HMC sampling state
1 parent db32421 commit 465d8ac

File tree

5 files changed

+178
-20
lines changed

5 files changed

+178
-20
lines changed

pymc/step_methods/hmc/base_hmc.py

Lines changed: 28 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -27,14 +27,19 @@
2727
from pymc.model import Point, modelcontext
2828
from pymc.pytensorf import floatX
2929
from pymc.stats.convergence import SamplerWarning, WarningType
30-
from pymc.step_methods import step_sizes
3130
from pymc.step_methods.arraystep import GradientSharedStep
3231
from pymc.step_methods.compound import StepMethodState
3332
from pymc.step_methods.hmc import integration
3433
from pymc.step_methods.hmc.integration import IntegrationError, State
35-
from pymc.step_methods.hmc.quadpotential import QuadPotentialDiagAdapt, quad_potential
34+
from pymc.step_methods.hmc.quadpotential import (
35+
PotentialState,
36+
QuadPotentialDiagAdapt,
37+
quad_potential,
38+
)
39+
from pymc.step_methods.state import dataclass_state
40+
from pymc.step_methods.step_sizes import DualAverageAdaptation, StepSizeState
3641
from pymc.tuning import guess_scaling
37-
from pymc.util import get_value_vars_from_user_vars
42+
from pymc.util import RandomGenerator, get_random_generator, get_value_vars_from_user_vars
3843

3944
logger = logging.getLogger(__name__)
4045

@@ -53,12 +58,27 @@ class HMCStepData(NamedTuple):
5358
stats: dict[str, Any]
5459

5560

61+
# Sampling-state record for BaseHMC; BaseHMC points at it through its
# `_state_class` attribute. Built with `dataclass_state` on top of
# StepMethodState.
# The field names mirror attributes BaseHMC keeps on itself (for example
# `step_size`, `step_adapt`, `target_accept` and `tune` are assigned in
# BaseHMC.__init__). `step_adapt` and `potential` are typed as nested state
# records (StepSizeState / PotentialState) rather than as the live objects.
# NOTE(review): how StepMethodState / dataclass_state copy these values in and
# out of the sampler is defined in pymc.step_methods.state - confirm there.
@dataclass_state
62+
class BaseHMCState(StepMethodState):
63+
adapt_step_size: bool
64+
Emax: float
65+
iter_count: int
66+
step_size: np.ndarray
67+
step_adapt: StepSizeState
68+
target_accept: float
69+
tune: bool
70+
potential: PotentialState
71+
_num_divs_sample: int
72+
73+
5674
class BaseHMC(GradientSharedStep):
5775
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
5876

5977
integrator: integration.CpuLeapfrogIntegrator
6078
default_blocked = True
6179

80+
_state_class = BaseHMCState
81+
6282
def __init__(
6383
self,
6484
vars=None,
@@ -134,9 +154,7 @@ def __init__(
134154
size = sum(v.size for v in nuts_vars)
135155

136156
self.step_size = step_scale / (size**0.25)
137-
self.step_adapt = step_sizes.DualAverageAdaptation(
138-
self.step_size, target_accept, gamma, k, t0
139-
)
157+
self.step_adapt = DualAverageAdaptation(self.step_size, target_accept, gamma, k, t0)
140158
self.target_accept = target_accept
141159
self.tune = True
142160

@@ -268,3 +286,7 @@ def reset_tuning(self, start=None):
268286
# Put the sampler back into tuning mode and reset the potential.
# `start` is accepted but not used by the statements shown here.
def reset(self, start=None):
269287
self.tune = True
270288
self.potential.reset()
289+
290+
# Install a new random generator on the sampler, then give the potential its
# own generator spawned from it (the first and only child of `spawn(1)`), so
# the potential does not draw from the sampler's generator directly.
# NOTE(review): `copy=False` presumably makes get_random_generator reuse the
# passed Generator instead of copying it - confirm in pymc.util.
def set_rng(self, rng: RandomGenerator):
291+
self.rng = get_random_generator(rng, copy=False)
292+
self.potential.set_rng(self.rng.spawn(1)[0])

pymc/step_methods/hmc/hmc.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,16 @@
1414

1515
from __future__ import annotations
1616

17+
from dataclasses import field
1718
from typing import Any
1819

1920
import numpy as np
2021

2122
from pymc.stats.convergence import SamplerWarning
2223
from pymc.step_methods.compound import Competence
23-
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
24+
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
2425
from pymc.step_methods.hmc.integration import IntegrationError, State
26+
from pymc.step_methods.state import dataclass_state
2527
from pymc.vartypes import discrete_types
2628

2729
__all__ = ["HamiltonianMC"]
@@ -31,6 +33,12 @@ def unif(step_size, elow=0.85, ehigh=1.15, rng: np.random.Generator | None = Non
3133
return (rng or np.random).uniform(elow, ehigh) * step_size
3234

3335

36+
# Sampling-state record for HamiltonianMC: every BaseHMCState field plus the
# two HMC-specific settings below.
# NOTE(review): both fields carry `metadata={"frozen": True}`; presumably the
# state helpers in pymc.step_methods.state use it to mark values that must not
# change when a state is restored - TODO confirm.
@dataclass_state
37+
class HamiltonianMCState(BaseHMCState):
38+
path_length: float = field(metadata={"frozen": True})
39+
max_steps: int = field(metadata={"frozen": True})
40+
41+
3442
class HamiltonianMC(BaseHMC):
3543
R"""A sampler for continuous variables based on Hamiltonian mechanics.
3644

pymc/step_methods/hmc/nuts.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from collections import namedtuple
18+
from dataclasses import field
1819

1920
import numpy as np
2021

@@ -23,13 +24,20 @@
2324
from pymc.stats.convergence import SamplerWarning
2425
from pymc.step_methods.compound import Competence
2526
from pymc.step_methods.hmc import integration
26-
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
27+
from pymc.step_methods.hmc.base_hmc import BaseHMC, BaseHMCState, DivergenceInfo, HMCStepData
2728
from pymc.step_methods.hmc.integration import IntegrationError, State
29+
from pymc.step_methods.state import dataclass_state
2830
from pymc.vartypes import continuous_types
2931

3032
__all__ = ["NUTS"]
3133

3234

35+
# Sampling-state record for NUTS: every BaseHMCState field plus the two tree
# depth limits below.
# NOTE(review): both fields carry `metadata={"frozen": True}`; presumably the
# state helpers in pymc.step_methods.state use it to mark values that must not
# change when a state is restored - TODO confirm.
@dataclass_state
36+
class NUTSState(BaseHMCState):
37+
max_treedepth: int = field(metadata={"frozen": True})
38+
early_max_treedepth: int = field(metadata={"frozen": True})
39+
40+
3341
class NUTS(BaseHMC):
3442
r"""A sampler for continuous variables based on Hamiltonian mechanics.
3543

pymc/step_methods/hmc/quadpotential.py

Lines changed: 112 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,8 @@
1616

1717
import warnings
1818

19-
from typing import overload
19+
from dataclasses import field
20+
from typing import Any, overload
2021

2122
import numpy as np
2223
import pytensor
@@ -25,6 +26,8 @@
2526
from scipy.sparse import issparse
2627

2728
from pymc.pytensorf import floatX
29+
from pymc.step_methods.state import DataClassState, WithSamplingState, dataclass_state
30+
from pymc.util import RandomGenerator, get_random_generator
2831

2932
__all__ = [
3033
"quad_potential",
@@ -100,11 +103,18 @@ def __str__(self):
100103
return f"Scaling is not positive definite: {self.msg}. Check indexes {self.idx}."
101104

102105

103-
class QuadPotential:
106+
# Base sampling-state record for quadratic potentials; QuadPotential points at
# it through `_state_class`. The only value tracked at this level is the
# potential's random generator.
@dataclass_state
107+
class PotentialState(DataClassState):
108+
rng: np.random.Generator
109+
110+
111+
class QuadPotential(WithSamplingState):
104112
dtype: np.dtype
105113

114+
_state_class = PotentialState
115+
106116
def __init__(self, rng=None):
107-
self.rng = np.random.default_rng(rng)
117+
self.rng = get_random_generator(rng)
108118

109119
@overload
110120
def velocity(self, x: np.ndarray, out: None) -> np.ndarray: ...
@@ -157,15 +167,42 @@ def reset(self):
157167
# Base-class statistics: both eigenvalue entries are reported as NaN.
def stats(self):
158168
return {"largest_eigval": np.nan, "smallest_eigval": np.nan}
159169

170+
# Replace the generator this potential draws from.
# NOTE(review): `copy=False` presumably makes get_random_generator reuse the
# passed Generator instead of copying it - confirm in pymc.util.
def set_rng(self, rng: RandomGenerator):
171+
self.rng = get_random_generator(rng, copy=False)
172+
160173

161174
# True for instances of QuadPotential or of any subclass (plain isinstance
# check); no duck-typing is attempted.
def isquadpotential(value):
162175
"""Check whether an object might be a QuadPotential object."""
163176
return isinstance(value, QuadPotential)
164177

165178

179+
# Sampling-state record for QuadPotentialDiagAdapt (referenced from that class
# via `_state_class`). Extends PotentialState, so `rng` is included as well.
# The fields come in two groups:
#   * plain annotations - the current mass-matrix values, the two nested
#     variance-estimator states, counters and the optional mass trace;
#   * fields tagged `metadata={"frozen": True}` - dtype, size, window settings
#     and the initial mean / diag / weight.
# NOTE(review): "frozen" is presumably interpreted by the helpers in
# pymc.step_methods.state as "must not change on restore" - TODO confirm.
# NOTE(review): WeightedVarianceState is declared further down the file, so
# these annotations are forward references; assumes annotations are evaluated
# lazily in this module - confirm.
@dataclass_state
180+
class QuadPotentialDiagAdaptState(PotentialState):
181+
_var: np.ndarray
182+
_stds: np.ndarray
183+
_inv_stds: np.ndarray
184+
_foreground_var: WeightedVarianceState
185+
_background_var: WeightedVarianceState
186+
_n_samples: int
187+
adaptation_window: int
188+
_mass_trace: list[np.ndarray] | None
189+
190+
dtype: Any = field(metadata={"frozen": True})
191+
_n: int = field(metadata={"frozen": True})
192+
_discard_window: int = field(metadata={"frozen": True})
193+
_early_update: int = field(metadata={"frozen": True})
194+
_initial_mean: np.ndarray = field(metadata={"frozen": True})
195+
_initial_diag: np.ndarray = field(metadata={"frozen": True})
196+
_initial_weight: np.ndarray = field(metadata={"frozen": True})
197+
adaptation_window_multiplier: float = field(metadata={"frozen": True})
198+
_store_mass_matrix_trace: bool = field(metadata={"frozen": True})
199+
200+
166201
class QuadPotentialDiagAdapt(QuadPotential):
167202
"""Adapt a diagonal mass matrix from the sample variances."""
168203

204+
_state_class = QuadPotentialDiagAdaptState
205+
169206
def __init__(
170207
self,
171208
n,
@@ -346,9 +383,20 @@ def raise_ok(self, map_info):
346383
raise ValueError("\n".join(errmsg))
347384

348385

349-
class _WeightedVariance:
386+
# Sampling-state record for _WeightedVariance (referenced from that class via
# `_state_class`): the sample count, mean and raw variance values, plus the
# dtype tagged as frozen metadata.
# NOTE(review): `n_samples` is annotated `int` here but `float` in
# WeightedCovarianceState - confirm which one matches the estimators.
@dataclass_state
387+
class WeightedVarianceState(DataClassState):
388+
n_samples: int
389+
mean: np.ndarray
390+
raw_var: np.ndarray
391+
392+
_dtype: Any = field(metadata={"frozen": True})
393+
394+
395+
class _WeightedVariance(WithSamplingState):
350396
"""Online algorithm for computing mean of variance."""
351397

398+
_state_class = WeightedVarianceState
399+
352400
def __init__(
353401
self, nelem, initial_mean=None, initial_variance=None, initial_weight=0, dtype="d"
354402
):
@@ -390,7 +438,16 @@ def current_mean(self):
390438
return self.mean.copy(dtype=self._dtype)
391439

392440

393-
class _ExpWeightedVariance:
441+
# Sampling-state record for _ExpWeightedVariance (referenced from that class
# via `_state_class`).
# NOTE(review): _ExpWeightedVariance.__init__ assigns `self._variance` and
# `self._mean`, but this record declares `_var`, not `_variance`. If the state
# helpers look attributes up by field name, `_var` will not be found on the
# estimator - verify and align the names.
@dataclass_state
442+
class ExpWeightedVarianceState(DataClassState):
443+
_alpha: float
444+
_mean: np.ndarray
445+
_var: np.ndarray
446+
447+
448+
class _ExpWeightedVariance(WithSamplingState):
449+
_state_class = ExpWeightedVarianceState
450+
394451
def __init__(self, n_vars, *, init_mean, init_var, alpha):
395452
self._variance = init_var
396453
self._mean = init_mean
@@ -415,7 +472,18 @@ def current_mean(self, out=None):
415472
return out
416473

417474

475+
# Sampling-state record for QuadPotentialDiagAdaptExp (referenced from that
# class via `_state_class`): every QuadPotentialDiagAdaptState field plus the
# exponential-weighting settings and the nested estimator state(s).
# `_variance_estimator_grad` is optional and defaults to None.
# NOTE(review): presumably it is only populated when the potential is built
# with `use_grads=True` - confirm against QuadPotentialDiagAdaptExp.__init__.
@dataclass_state
476+
class QuadPotentialDiagAdaptExpState(QuadPotentialDiagAdaptState):
477+
_alpha: float
478+
_stop_adaptation: float
479+
_variance_estimator: ExpWeightedVarianceState
480+
481+
_variance_estimator_grad: ExpWeightedVarianceState | None = None
482+
483+
418484
class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
485+
_state_class = QuadPotentialDiagAdaptExpState
486+
419487
def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, rng=None, **kwargs):
420488
"""Set up a diagonal mass matrix.
421489
@@ -526,7 +594,7 @@ def __init__(self, v, dtype=None, rng=None):
526594
self.s = s
527595
self.inv_s = 1.0 / s
528596
self.v = v
529-
self.rng = np.random.default_rng(rng)
597+
self.rng = get_random_generator(rng)
530598

531599
def velocity(self, x, out=None):
532600
"""Compute the current velocity at a position in parameter space."""
@@ -572,7 +640,7 @@ def __init__(self, A, dtype=None, rng=None):
572640
dtype = pytensor.config.floatX
573641
self.dtype = dtype
574642
self.L = floatX(scipy.linalg.cholesky(A, lower=True))
575-
self.rng = np.random.default_rng(rng)
643+
self.rng = get_random_generator(rng)
576644

577645
def velocity(self, x, out=None):
578646
"""Compute the current velocity at a position in parameter space."""
@@ -621,7 +689,7 @@ def __init__(self, cov, dtype=None, rng=None):
621689
self._cov = np.array(cov, dtype=self.dtype, copy=True)
622690
self._chol = scipy.linalg.cholesky(self._cov, lower=True)
623691
self._n = len(self._cov)
624-
self.rng = np.random.default_rng(rng)
692+
self.rng = get_random_generator(rng)
625693

626694
def velocity(self, x, out=None):
627695
"""Compute the current velocity at a position in parameter space."""
@@ -646,9 +714,31 @@ def velocity_energy(self, x, v_out):
646714
__call__ = random
647715

648716

717+
# Sampling-state record for QuadPotentialFullAdapt (referenced from that class
# via `_state_class`). Extends PotentialState, so `rng` is included as well.
# The fields come in two groups:
#   * plain annotations - the covariance, its Cholesky factor, a stored
#     Cholesky error (if any), the two nested covariance-estimator states and
#     counters;
#   * fields tagged `metadata={"frozen": True}` - dtype, size, window settings
#     and the initial mean / cov / weight.
# NOTE(review): `_chol_error` has a default but is followed by fields without
# one. A plain positional dataclass rejects that ordering, so this relies on
# dataclass_state generating keyword-only fields - confirm.
# NOTE(review): WeightedCovarianceState is declared further down the file, so
# these annotations are forward references; assumes annotations are evaluated
# lazily in this module - confirm.
@dataclass_state
718+
class QuadPotentialFullAdaptState(PotentialState):
719+
_previous_update: int
720+
_cov: np.ndarray
721+
_chol: np.ndarray
722+
_chol_error: scipy.linalg.LinAlgError | ValueError | None = None
723+
_foreground_cov: WeightedCovarianceState
724+
_background_cov: WeightedCovarianceState
725+
_n_samples: int
726+
adaptation_window: int
727+
728+
dtype: Any = field(metadata={"frozen": True})
729+
_n: int = field(metadata={"frozen": True})
730+
_update_window: int = field(metadata={"frozen": True})
731+
_initial_mean: np.ndarray = field(metadata={"frozen": True})
732+
_initial_cov: np.ndarray = field(metadata={"frozen": True})
733+
_initial_weight: np.ndarray = field(metadata={"frozen": True})
734+
adaptation_window_multiplier: float = field(metadata={"frozen": True})
735+
736+
649737
class QuadPotentialFullAdapt(QuadPotentialFull):
650738
"""Adapt a dense mass matrix using the sample covariances."""
651739

740+
_state_class = QuadPotentialFullAdaptState
741+
652742
def __init__(
653743
self,
654744
n,
@@ -689,7 +779,7 @@ def __init__(
689779
self.adaptation_window_multiplier = float(adaptation_window_multiplier)
690780
self._update_window = int(update_window)
691781

692-
self.rng = np.random.default_rng(rng)
782+
self.rng = get_random_generator(rng)
693783

694784
self.reset()
695785

@@ -742,7 +832,16 @@ def raise_ok(self, vmap):
742832
raise ValueError(str(self._chol_error))
743833

744834

745-
class _WeightedCovariance:
835+
# Sampling-state record for _WeightedCovariance (referenced from that class
# via `_state_class`): the sample count, mean and raw covariance values, plus
# the dtype tagged as frozen metadata.
# NOTE(review): `n_samples` is annotated `float` here but `int` in
# WeightedVarianceState - confirm which one matches the estimators.
@dataclass_state
836+
class WeightedCovarianceState(DataClassState):
837+
n_samples: float
838+
mean: np.ndarray
839+
raw_cov: np.ndarray
840+
841+
_dtype: Any = field(metadata={"frozen": True})
842+
843+
844+
class _WeightedCovariance(WithSamplingState):
746845
"""Online algorithm for computing mean and covariance
747846
748847
This implements the `Welford's algorithm
@@ -752,6 +851,8 @@ class _WeightedCovariance:
752851
753852
"""
754853

854+
_state_class = WeightedCovarianceState
855+
755856
def __init__(
756857
self,
757858
nelem,
@@ -827,7 +928,7 @@ def __init__(self, A, rng=None):
827928
self.size = A.shape[0]
828929
self.factor = factor = cholmod.cholesky(A)
829930
self.d_sqrt = np.sqrt(factor.D())
830-
self.rng = np.random.default_rng(rng)
931+
self.rng = get_random_generator(rng)
831932

832933
def velocity(self, x):
833934
"""Compute the current velocity at a position in parameter space."""

0 commit comments

Comments
 (0)