Commit 2473bfe

Implement utility to convert Model to and from FunctionGraph
1 parent b730449 commit 2473bfe
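The round-trip these utilities provide is easiest to see in miniature. The sketch below is distilled from the test_basic case in the new test module and only uses the two entry points added to the API reference (model_fgraph.fgraph_from_model and model_fgraph.model_from_fgraph); the variable names are illustrative.

import pymc as pm
from pytensor.graph import FunctionGraph

from pymc_experimental.utils.model_fgraph import fgraph_from_model, model_from_fgraph

# Build an ordinary PyMC model (distilled from test_basic below).
with pm.Model(coords={"test_dim": range(3)}) as m_old:
    x = pm.Normal("x")
    y = pm.Deterministic("y", x + 1)
    z = pm.Normal("z", y, 1.0, observed=[0, 1, 2], dims=("test_dim",))

# Convert the Model to a PyTensor FunctionGraph that rewrites can operate on ...
fgraph = fgraph_from_model(m_old)
assert isinstance(fgraph, FunctionGraph)

# ... and rebuild an equivalent, independent Model from that graph.
m_new = model_from_fgraph(fgraph)
assert isinstance(m_new, pm.Model)
assert set(m_new.named_vars) == {"x", "y", "z"}

Once the model lives in a FunctionGraph, ordinary PyTensor graph rewrites can be applied to it, which is what the non-centering fixture in the tests exercises.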

5 files changed: 567 additions & 0 deletions

docs/api_reference.rst

Lines changed: 2 additions & 0 deletions
@@ -51,3 +51,5 @@ Utils
 
 spline.bspline_interpolation
 prior.prior_from_idata
+model_fgraph.fgraph_from_model
+model_fgraph.model_from_fgraph

pymc_experimental/tests/utils/__init__.py

Whitespace-only changes.
Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest
from pytensor.graph import FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.exceptions import NotScalarConstantError

from pymc_experimental.utils.model_fgraph import (
    ModelFreeRV,
    ModelVar,
    fgraph_from_model,
    model_deterministic,
    model_free_rv,
    model_from_fgraph,
)


def test_basic():
    """Test we can convert from a PyMC Model to a FunctionGraph and back"""
    with pm.Model(coords={"test_dim": range(3)}) as m_old:
        x = pm.Normal("x")
        y = pm.Deterministic("y", x + 1)
        w = pm.HalfNormal("w", pm.math.exp(y))
        z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",))
        pm.Potential("pot", x * 2)

    m_fgraph = fgraph_from_model(m_old)
    assert isinstance(m_fgraph, FunctionGraph)

    m_new = model_from_fgraph(m_fgraph)
    assert isinstance(m_new, pm.Model)

    assert m_new.coords == {"test_dim": tuple(range(3))}
    assert m_new._dim_lengths["test_dim"].eval() == 3
    assert m_new.named_vars_to_dims == {"z": ["test_dim"]}

    named_vars = {"x", "y", "w", "z", "pot"}
    assert set(m_new.named_vars) == named_vars
    for named_var in named_vars:
        assert m_new[named_var] is not m_old[named_var]
    assert m_new["x"] in m_new.free_RVs
    assert m_new["w"] in m_new.free_RVs
    assert m_new["y"] in m_new.deterministics
    assert m_new["z"] in m_new.observed_RVs
    assert m_new["pot"] in m_new.potentials
    assert m_new.rvs_to_transforms[m_new["x"]] is None
    assert m_new.rvs_to_transforms[m_new["w"]] is pm.distributions.transforms.log
    assert m_new.rvs_to_transforms[m_new["z"]] is None

    # Test random
    new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1)
    old_y_draw, old_z_draw = pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1)
    np.testing.assert_array_equal(new_y_draw, old_y_draw)
    np.testing.assert_array_equal(new_z_draw, old_z_draw)

    # Test logp
    ip = m_new.initial_point()
    np.testing.assert_equal(
        m_new.compile_logp()(ip),
        m_old.compile_logp()(ip),
    )


def test_data():
    """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.

    Everything should be preserved across new and old models, except for shared RNGs
    """
    with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
        x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
        y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
        b0 = pm.ConstantData("b0", 0.0)
        b1 = pm.Normal("b1")
        mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
        obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))

    m_new = model_from_fgraph(fgraph_from_model(m_old))

    # ConstantData is preserved
    assert m_new["b0"].data == m_old["b0"].data

    # Shared non-rng shared variables are preserved
    assert m_new["x"].container is x.container
    assert m_new["y"].container is y.container
    assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]
    # Shared rng shared variables are not preserved
    assert m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container

    with m_old:
        pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})

    assert m_new.dim_lengths["test_dim"].eval() == 2
    np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0])


def test_deterministics():
    """Test handling of deterministics.

    We don't want Deterministics in the middle of the FunctionGraph, as they would make rewrites cumbersome
    However we want them in the middle of Model.basic_RVs, so they display nicely in graphviz

    There is one edge case that has to be considered, when a Deterministic is just a copy of a RV.
    In that case we don't bother to reintroduce it in between other Model.basic_RVs
    """
    with pm.Model() as m:
        x = pm.Normal("x")
        mu = pm.Deterministic("mu", pm.math.abs(x))
        sigma = pm.math.exp(x)
        pm.Deterministic("sigma", sigma)
        y = pm.Normal("y", mu, sigma)
        # Special case where the Deterministic
        # is a direct view on another model variable
        y_ = pm.Deterministic("y_", y)
        # Just for kicks, make it a double one!
        y__ = pm.Deterministic("y__", y_)
        z = pm.Normal("z", y__)

    # Deterministic mu is in the graph of x to y but not sigma
    assert m["y"].owner.inputs[3] is m["mu"]
    assert m["y"].owner.inputs[4] is not m["sigma"]

    fg = fgraph_from_model(m)

    # Check that no Deterministics are in graph of x to y and y to z
    x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
    # [Det(mu), Det(sigma)]
    mu = det_mu.owner.inputs[0]
    sigma = det_sigma.owner.inputs[0]
    # [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))]
    assert y.owner.inputs[0].owner.inputs[3] is mu
    assert y.owner.inputs[0].owner.inputs[4] is sigma
    # [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))]
    assert z.owner.inputs[0].owner.inputs[3] is y
    # [Det(y), Det(y)], not [Det(y), Det(Det(y))]
    assert det_y_.owner.inputs[0] is y
    assert det_y__.owner.inputs[0] is y
    assert det_y_ is not det_y__

    # Both mu and sigma deterministics are now in the graph of x to y
    m = model_from_fgraph(fg)
    assert m["y"].owner.inputs[3] is m["mu"]
    assert m["y"].owner.inputs[4] is m["sigma"]
    # But not y_* in y to z, since there was no real Op in between
    assert m["z"].owner.inputs[3] is m["y"]
    assert m["y_"].owner.inputs[0] is m["y"]
    assert m["y__"].owner.inputs[0] is m["y"]


def test_context_error():
    """Test that model_from_fgraph fails when called inside a Model context.

    We can't allow it, because the new Model that's returned would be a child of whatever Model context is active.
    """
    with pm.Model() as m:
        x = pm.Normal("x")

    fg = fgraph_from_model(m)

    with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"):
        model_from_fgraph(fg)


def test_sub_model_error():
    """Test Error is raised when trying to convert a sub-model to fgraph."""
    with pm.Model() as m:
        x = pm.Beta("x", 1, 1)
        with pm.Model() as sub_m:
            y = pm.Normal("y", x)

    nodes = [v for v in fgraph_from_model(m).toposort() if not isinstance(v.op, ModelVar)]
    assert len(nodes) == 2
    assert isinstance(nodes[0].op, pm.Beta)
    assert isinstance(nodes[1].op, pm.Normal)

    with pytest.raises(ValueError, match="Nested sub-models cannot be converted"):
        fgraph_from_model(sub_m)


@pytest.fixture()
def non_centered_rewrite():
    @node_rewriter(tracks=[ModelFreeRV])
    def non_centered_param(fgraph: FunctionGraph, node):
        """Rewrite that replaces centered normal by non-centered parametrization."""

        rv, value, *dims = node.inputs
        if not isinstance(rv.owner.op, pm.Normal):
            return
        rng, size, dtype, loc, scale = rv.owner.inputs

        # Only apply rewrite if size information is explicit
        if size.ndim == 0:
            return None

        try:
            is_unit = (
                pt.get_underlying_scalar_constant_value(loc) == 0
                and pt.get_underlying_scalar_constant_value(scale) == 1
            )
        except NotScalarConstantError:
            is_unit = False

        # Nothing to do here
        if is_unit:
            return

        raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng)
        raw_norm.name = f"{rv.name}_raw_"
        raw_norm_value = raw_norm.clone()
        fgraph.add_input(raw_norm_value)
        raw_norm = model_free_rv(raw_norm, raw_norm_value, node.op.transform, *dims)

        new_norm = loc + raw_norm * scale
        new_norm.name = rv.name
        new_norm_det = model_deterministic(new_norm, *dims)
        fgraph.add_output(new_norm_det)

        return [new_norm]

    return in2out(non_centered_param)


def test_fgraph_rewrite(non_centered_rewrite):
    """Test we can apply a simple rewrite to a PyMC Model."""

    with pm.Model(coords={"subject": range(10)}) as m_old:
        group_mean = pm.Normal("group_mean")
        group_std = pm.HalfNormal("group_std")
        subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))

    fg = fgraph_from_model(m_old)
    non_centered_rewrite.apply(fg)

    m_new = model_from_fgraph(fg)
    assert m_new.named_vars_to_dims == {
        "subject_mean": ["subject"],
        "subject_mean_raw_": ["subject"],
        "obs": ["subject"],
    }
    assert set(m_new.named_vars) == {
        "group_mean",
        "group_std",
        "subject_mean_raw_",
        "subject_mean",
        "obs",
    }
    assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"}
    assert {rv.name for rv in m_new.observed_RVs} == {"obs"}
    assert {rv.name for rv in m_new.deterministics} == {"subject_mean"}

    with pm.Model() as m_ref:
        group_mean = pm.Normal("group_mean")
        group_std = pm.HalfNormal("group_std")
        subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,))
        subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std)
        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10))

    np.testing.assert_array_equal(
        pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1),
        pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1),
    )

    ip = m_new.initial_point()
    np.testing.assert_equal(
        m_new.compile_logp()(ip),
        m_ref.compile_logp()(ip),
    )
