Skip to content

Commit db32421

Browse files
committed
Add slice sampling state
1 parent c417476 commit db32421

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

pymc/step_methods/slicer.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
from pymc.model import modelcontext
2222
from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
2323
from pymc.step_methods.arraystep import ArrayStepShared
24-
from pymc.step_methods.compound import Competence
24+
from pymc.step_methods.compound import Competence, StepMethodState
25+
from pymc.step_methods.state import dataclass_state
2526
from pymc.util import get_value_vars_from_user_vars
2627
from pymc.vartypes import continuous_types
2728

@@ -30,6 +31,17 @@
3031
LOOP_ERR_MSG = "max slicer iters %d exceeded"
3132

3233

34+
dataclass_state
35+
36+
37+
@dataclass_state
38+
class SliceState(StepMethodState):
39+
w: np.ndarray
40+
tune: bool
41+
n_tunes: float
42+
iter_limit: float
43+
44+
3345
class Slice(ArrayStepShared):
3446
"""
3547
Univariate slice sampler step method.
@@ -61,6 +73,8 @@ class Slice(ArrayStepShared):
6173
"nstep_in": (int, []),
6274
}
6375

76+
_state_class = SliceState
77+
6478
def __init__(
6579
self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, rng=None, **kwargs
6680
):

0 commit comments

Comments
 (0)