16
16
17
17
import warnings
18
18
19
- from typing import overload
19
+ from dataclasses import field
20
+ from typing import Any , overload
20
21
21
22
import numpy as np
22
23
import pytensor
25
26
from scipy .sparse import issparse
26
27
27
28
from pymc .pytensorf import floatX
29
+ from pymc .step_methods .state import DataClassState , WithSamplingState , dataclass_state
30
+ from pymc .util import RandomGenerator , get_random_generator
28
31
29
32
__all__ = [
30
33
"quad_potential" ,
@@ -100,11 +103,18 @@ def __str__(self):
100
103
return f"Scaling is not positive definite: { self .msg } . Check indexes { self .idx } ."
101
104
102
105
103
- class QuadPotential :
106
+ @dataclass_state
107
+ class PotentialState (DataClassState ):
108
+ rng : np .random .Generator
109
+
110
+
111
+ class QuadPotential (WithSamplingState ):
104
112
dtype : np .dtype
105
113
114
+ _state_class = PotentialState
115
+
106
116
def __init__ (self , rng = None ):
107
- self .rng = np . random . default_rng (rng )
117
+ self .rng = get_random_generator (rng )
108
118
109
119
@overload
110
120
def velocity (self , x : np .ndarray , out : None ) -> np .ndarray : ...
@@ -157,15 +167,42 @@ def reset(self):
157
167
def stats (self ):
158
168
return {"largest_eigval" : np .nan , "smallest_eigval" : np .nan }
159
169
170
+ def set_rng (self , rng : RandomGenerator ):
171
+ self .rng = get_random_generator (rng , copy = False )
172
+
160
173
161
174
def isquadpotential (value ):
162
175
"""Check whether an object might be a QuadPotential object."""
163
176
return isinstance (value , QuadPotential )
164
177
165
178
179
+ @dataclass_state
180
+ class QuadPotentialDiagAdaptState (PotentialState ):
181
+ _var : np .ndarray
182
+ _stds : np .ndarray
183
+ _inv_stds : np .ndarray
184
+ _foreground_var : WeightedVarianceState
185
+ _background_var : WeightedVarianceState
186
+ _n_samples : int
187
+ adaptation_window : int
188
+ _mass_trace : list [np .ndarray ] | None
189
+
190
+ dtype : Any = field (metadata = {"frozen" : True })
191
+ _n : int = field (metadata = {"frozen" : True })
192
+ _discard_window : int = field (metadata = {"frozen" : True })
193
+ _early_update : int = field (metadata = {"frozen" : True })
194
+ _initial_mean : np .ndarray = field (metadata = {"frozen" : True })
195
+ _initial_diag : np .ndarray = field (metadata = {"frozen" : True })
196
+ _initial_weight : np .ndarray = field (metadata = {"frozen" : True })
197
+ adaptation_window_multiplier : float = field (metadata = {"frozen" : True })
198
+ _store_mass_matrix_trace : bool = field (metadata = {"frozen" : True })
199
+
200
+
166
201
class QuadPotentialDiagAdapt (QuadPotential ):
167
202
"""Adapt a diagonal mass matrix from the sample variances."""
168
203
204
+ _state_class = QuadPotentialDiagAdaptState
205
+
169
206
def __init__ (
170
207
self ,
171
208
n ,
@@ -346,9 +383,20 @@ def raise_ok(self, map_info):
346
383
raise ValueError ("\n " .join (errmsg ))
347
384
348
385
349
- class _WeightedVariance :
386
+ @dataclass_state
387
+ class WeightedVarianceState (DataClassState ):
388
+ n_samples : int
389
+ mean : np .ndarray
390
+ raw_var : np .ndarray
391
+
392
+ _dtype : Any = field (metadata = {"frozen" : True })
393
+
394
+
395
+ class _WeightedVariance (WithSamplingState ):
350
396
"""Online algorithm for computing mean of variance."""
351
397
398
+ _state_class = WeightedVarianceState
399
+
352
400
def __init__ (
353
401
self , nelem , initial_mean = None , initial_variance = None , initial_weight = 0 , dtype = "d"
354
402
):
@@ -390,7 +438,16 @@ def current_mean(self):
390
438
return self .mean .copy (dtype = self ._dtype )
391
439
392
440
393
- class _ExpWeightedVariance :
441
+ @dataclass_state
442
+ class ExpWeightedVarianceState (DataClassState ):
443
+ _alpha : float
444
+ _mean : np .ndarray
445
+ _var : np .ndarray
446
+
447
+
448
+ class _ExpWeightedVariance (WithSamplingState ):
449
+ _state_class = ExpWeightedVarianceState
450
+
394
451
def __init__ (self , n_vars , * , init_mean , init_var , alpha ):
395
452
self ._variance = init_var
396
453
self ._mean = init_mean
@@ -415,7 +472,18 @@ def current_mean(self, out=None):
415
472
return out
416
473
417
474
475
+ @dataclass_state
476
+ class QuadPotentialDiagAdaptExpState (QuadPotentialDiagAdaptState ):
477
+ _alpha : float
478
+ _stop_adaptation : float
479
+ _variance_estimator : ExpWeightedVarianceState
480
+
481
+ _variance_estimator_grad : ExpWeightedVarianceState | None = None
482
+
483
+
418
484
class QuadPotentialDiagAdaptExp (QuadPotentialDiagAdapt ):
485
+ _state_class = QuadPotentialDiagAdaptExpState
486
+
419
487
def __init__ (self , * args , alpha , use_grads = False , stop_adaptation = None , rng = None , ** kwargs ):
420
488
"""Set up a diagonal mass matrix.
421
489
@@ -526,7 +594,7 @@ def __init__(self, v, dtype=None, rng=None):
526
594
self .s = s
527
595
self .inv_s = 1.0 / s
528
596
self .v = v
529
- self .rng = np . random . default_rng (rng )
597
+ self .rng = get_random_generator (rng )
530
598
531
599
def velocity (self , x , out = None ):
532
600
"""Compute the current velocity at a position in parameter space."""
@@ -572,7 +640,7 @@ def __init__(self, A, dtype=None, rng=None):
572
640
dtype = pytensor .config .floatX
573
641
self .dtype = dtype
574
642
self .L = floatX (scipy .linalg .cholesky (A , lower = True ))
575
- self .rng = np . random . default_rng (rng )
643
+ self .rng = get_random_generator (rng )
576
644
577
645
def velocity (self , x , out = None ):
578
646
"""Compute the current velocity at a position in parameter space."""
@@ -621,7 +689,7 @@ def __init__(self, cov, dtype=None, rng=None):
621
689
self ._cov = np .array (cov , dtype = self .dtype , copy = True )
622
690
self ._chol = scipy .linalg .cholesky (self ._cov , lower = True )
623
691
self ._n = len (self ._cov )
624
- self .rng = np . random . default_rng (rng )
692
+ self .rng = get_random_generator (rng )
625
693
626
694
def velocity (self , x , out = None ):
627
695
"""Compute the current velocity at a position in parameter space."""
@@ -646,9 +714,31 @@ def velocity_energy(self, x, v_out):
646
714
__call__ = random
647
715
648
716
717
+ @dataclass_state
718
+ class QuadPotentialFullAdaptState (PotentialState ):
719
+ _previous_update : int
720
+ _cov : np .ndarray
721
+ _chol : np .ndarray
722
+ _chol_error : scipy .linalg .LinAlgError | ValueError | None = None
723
+ _foreground_cov : WeightedCovarianceState
724
+ _background_cov : WeightedCovarianceState
725
+ _n_samples : int
726
+ adaptation_window : int
727
+
728
+ dtype : Any = field (metadata = {"frozen" : True })
729
+ _n : int = field (metadata = {"frozen" : True })
730
+ _update_window : int = field (metadata = {"frozen" : True })
731
+ _initial_mean : np .ndarray = field (metadata = {"frozen" : True })
732
+ _initial_cov : np .ndarray = field (metadata = {"frozen" : True })
733
+ _initial_weight : np .ndarray = field (metadata = {"frozen" : True })
734
+ adaptation_window_multiplier : float = field (metadata = {"frozen" : True })
735
+
736
+
649
737
class QuadPotentialFullAdapt (QuadPotentialFull ):
650
738
"""Adapt a dense mass matrix using the sample covariances."""
651
739
740
+ _state_class = QuadPotentialFullAdaptState
741
+
652
742
def __init__ (
653
743
self ,
654
744
n ,
@@ -689,7 +779,7 @@ def __init__(
689
779
self .adaptation_window_multiplier = float (adaptation_window_multiplier )
690
780
self ._update_window = int (update_window )
691
781
692
- self .rng = np . random . default_rng (rng )
782
+ self .rng = get_random_generator (rng )
693
783
694
784
self .reset ()
695
785
@@ -742,7 +832,16 @@ def raise_ok(self, vmap):
742
832
raise ValueError (str (self ._chol_error ))
743
833
744
834
745
- class _WeightedCovariance :
835
+ @dataclass_state
836
+ class WeightedCovarianceState (DataClassState ):
837
+ n_samples : float
838
+ mean : np .ndarray
839
+ raw_cov : np .ndarray
840
+
841
+ _dtype : Any = field (metadata = {"frozen" : True })
842
+
843
+
844
+ class _WeightedCovariance (WithSamplingState ):
746
845
"""Online algorithm for computing mean and covariance
747
846
748
847
This implements the `Welford's algorithm
@@ -752,6 +851,8 @@ class _WeightedCovariance:
752
851
753
852
"""
754
853
854
+ _state_class = WeightedCovarianceState
855
+
755
856
def __init__ (
756
857
self ,
757
858
nelem ,
@@ -827,7 +928,7 @@ def __init__(self, A, rng=None):
827
928
self .size = A .shape [0 ]
828
929
self .factor = factor = cholmod .cholesky (A )
829
930
self .d_sqrt = np .sqrt (factor .D ())
830
- self .rng = np . random . default_rng (rng )
931
+ self .rng = get_random_generator (rng )
831
932
832
933
def velocity (self , x ):
833
934
"""Compute the current velocity at a position in parameter space."""
0 commit comments