Skip to content

Commit 510d7b8

Browse files
committed
Do not compare arrays with strings in initial_point::make_initial_point_expression
1 parent 8950f21 commit 510d7b8

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pymc/initial_point.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,15 @@ def make_initial_point_expression(
269269
if strategy is None:
270270
strategy = default_strategy
271271

272-
if strategy == "moment":
273-
value = get_moment(variable)
274-
elif strategy == "prior":
275-
value = variable
272+
if isinstance(strategy, str):
273+
if strategy == "moment":
274+
value = get_moment(variable)
275+
elif strategy == "prior":
276+
value = variable
277+
else:
278+
raise ValueError(
279+
f'Invalid string strategy: {strategy}. It must be one of ["moment", "prior"]'
280+
)
276281
else:
277282
value = at.as_tensor(strategy, dtype=variable.dtype).astype(variable.dtype)
278283

pymc/tests/test_initial_point.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def test_new_warnings(self):
5050
assert not hasattr(rv.tag, "test_value")
5151
pass
5252

53+
def test_valid_string_strategy(self):
54+
with pm.Model() as pmodel:
55+
pm.Uniform("x", 0, 1, size=2, initval="unknown")
56+
with pytest.raises(ValueError, match="Invalid string strategy: unknown"):
57+
pmodel.recompute_initial_point(seed=0)
58+
5359

5460
class TestInitvalEvaluation:
5561
def test_make_initial_point_fns_per_chain_checks_kwargs(self):

0 commit comments

Comments
 (0)