Skip to content

Commit 4f06ab7

Browse files
committed
Do not use initval in test model
PRs #7508 and #7492 introduced incompatible changes but were not tested simultaneously. Deepcopying the steps in the tests leads to deepcopying the model which uses `clone_model`, which in turn does not support initvals.
1 parent 465d8ac commit 4f06ab7

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

tests/models.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import pytensor
1919
import pytensor.tensor as pt
2020

21-
from pytensor import config
2221
from pytensor.compile.ops import as_op
2322

2423
import pymc as pm
@@ -30,9 +29,9 @@ def simple_model():
3029
mu = -2.1
3130
tau = 1.3
3231
with Model() as model:
33-
Normal("x", mu, tau=tau, size=2, initval=np.array([0.1, 0.1]).astype(config.floatX))
32+
x = Normal("x", mu, tau=tau, size=2)
3433

35-
return model.initial_point(), model, (mu, tau**-0.5)
34+
return {"x": np.array([0.1, 0.1], dtype=x.type.dtype)}, model, (mu, tau**-0.5)
3635

3736

3837
def another_simple_model():
@@ -46,11 +45,11 @@ def simple_categorical():
4645
p = np.array([0.1, 0.2, 0.3, 0.4])
4746
v = np.array([0.0, 1.0, 2.0, 3.0])
4847
with Model() as model:
49-
Categorical("x", p, size=3, initval=[1, 2, 3])
48+
x = Categorical("x", p, size=3)
5049

5150
mu = np.dot(p, v)
5251
var = np.dot(p, (v - mu) ** 2)
53-
return model.initial_point(), model, (mu, var)
52+
return {"x": np.array([1, 2, 3], dtype=x.type.dtype)}, model, (mu, var)
5453

5554

5655
def multidimensional_model():
@@ -98,15 +97,14 @@ def mv_simple():
9897
p = np.array([[2.0, 0, 0], [0.05, 0.1, 0], [1.0, -0.05, 5.5]])
9998
tau = np.dot(p, p.T)
10099
with pm.Model() as model:
101-
pm.MvNormal(
100+
x = pm.MvNormal(
102101
"x",
103102
pt.constant(mu),
104103
tau=pt.constant(tau),
105-
initval=np.array([0.1, 1.0, 0.8]),
106104
)
107105
H = tau
108106
C = np.linalg.inv(H)
109-
return model.initial_point(), model, (mu, C)
107+
return {"x": np.array([0.1, 1.0, 0.8], dtype=x.type.dtype)}, model, (mu, C)
110108

111109

112110
def mv_simple_coarse():

0 commit comments

Comments
 (0)