Skip to content

Commit b37651e

Browse files
committed
Return memo dictionary in fgraph_from_model
1 parent ac0455f commit b37651e

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

pymc_experimental/tests/utils/test_model_fgraph.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,17 @@ def test_basic():
2323
y = pm.Deterministic("y", x + 1)
2424
w = pm.HalfNormal("w", pm.math.exp(y))
2525
z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",))
26-
pm.Potential("pot", x * 2)
26+
pot = pm.Potential("pot", x * 2)
2727

28-
m_fgraph = fgraph_from_model(m_old)
28+
m_fgraph, memo = fgraph_from_model(m_old)
2929
assert isinstance(m_fgraph, FunctionGraph)
3030

31+
assert memo[x] in m_fgraph.variables
32+
assert memo[y] in m_fgraph.variables
33+
assert memo[w] in m_fgraph.variables
34+
assert memo[z] in m_fgraph.variables
35+
assert memo[pot] in m_fgraph.variables
36+
3137
m_new = model_from_fgraph(m_fgraph)
3238
assert isinstance(m_new, pm.Model)
3339

@@ -79,7 +85,7 @@ def test_data():
7985
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
8086
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
8187

82-
m_new = model_from_fgraph(fgraph_from_model(m_old))
88+
m_new = model_from_fgraph(fgraph_from_model(m_old)[0])
8389

8490
# ConstantData is preserved
8591
assert m_new["b0"].data == m_old["b0"].data
@@ -125,7 +131,7 @@ def test_deterministics():
125131
assert m["y"].owner.inputs[3] is m["mu"]
126132
assert m["y"].owner.inputs[4] is not m["sigma"]
127133

128-
fg = fgraph_from_model(m)
134+
fg, _ = fgraph_from_model(m)
129135

130136
# Check that no Deterministics are in graph of x to y and y to z
131137
x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
@@ -173,7 +179,7 @@ def test_sub_model_error():
173179
with pm.Model() as sub_m:
174180
y = pm.Normal("y", x)
175181

176-
nodes = [v for v in fgraph_from_model(m).toposort() if not isinstance(v.op, ModelVar)]
182+
nodes = [v for v in fgraph_from_model(m)[0].toposort() if not isinstance(v.op, ModelVar)]
177183
assert len(nodes) == 2
178184
assert isinstance(nodes[0].op, pm.Beta)
179185
assert isinstance(nodes[1].op, pm.Normal)
@@ -234,7 +240,7 @@ def test_fgraph_rewrite(non_centered_rewrite):
234240
subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
235241
obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))
236242

237-
fg = fgraph_from_model(m_old)
243+
fg, _ = fgraph_from_model(m_old)
238244
non_centered_rewrite.apply(fg)
239245

240246
m_new = model_from_fgraph(fg)

pymc_experimental/utils/model_fgraph.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Dict, Optional, Tuple
22

33
import pytensor
44
from pymc.logprob.transforms import RVTransform
@@ -107,16 +107,19 @@ def local_remove_identity(fgraph, node):
107107
remove_identity_rewrite = out2in(local_remove_identity)
108108

109109

110-
def fgraph_from_model(model: Model) -> FunctionGraph:
110+
def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
111111
"""Convert Model to FunctionGraph.
112112
113-
Create a FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops.
114-
115-
PyTensor rewrites can be used to transform the FunctionGraph.
113+
See: model_from_fgraph
116114
117-
It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`.
115+
Returns
116+
-------
117+
fgraph: FunctionGraph
118+
FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops.
119+
It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`.
118120
119-
See: model_from_fgraph
121+
memo: Dict
122+
A dictionary mapping original model variables to the equivalent nodes in the fgraph.
120123
"""
121124

122125
if any(v is not None for v in model.rvs_to_initial_values.values()):
@@ -200,7 +203,17 @@ def fgraph_from_model(model: Model) -> FunctionGraph:
200203
new_var = var
201204
new_vars.append(new_var)
202205

203-
toposort_replace(fgraph, tuple(zip(vars, new_vars)))
206+
replacements = tuple(zip(vars, new_vars))
207+
toposort_replace(fgraph, replacements)
208+
209+
# Reference model vars in memo
210+
inverse_memo = {v: k for k, v in memo.items()}
211+
for var, model_var in replacements:
212+
if isinstance(model_var.owner is not None and model_var.owner.op, ModelDeterministic):
213+
# Ignore extra identity that will be removed at the end
214+
var = var.owner.inputs[0]
215+
original_var = inverse_memo[var]
216+
memo[original_var] = model_var
204217

205218
# Remove value variable as outputs, now that they are graph inputs
206219
first_value_idx = len(fgraph.outputs) - len(value_vars)
@@ -210,7 +223,7 @@ def fgraph_from_model(model: Model) -> FunctionGraph:
210223
# Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
211224
remove_identity_rewrite.apply(fgraph)
212225

213-
return fgraph
226+
return fgraph, memo
214227

215228

216229
def model_from_fgraph(fgraph: FunctionGraph) -> Model:
@@ -280,7 +293,7 @@ def model_from_fgraph(fgraph: FunctionGraph) -> Model:
280293
return model
281294

282295

283-
def clone_model(model: Model) -> Model:
296+
def clone_model(model: Model) -> Tuple[Model]:
284297
"""Clone a PyMC model.
285298
286299
Recreates a PyMC model with clones of the original variables.
@@ -308,4 +321,4 @@ def clone_model(model: Model) -> Model:
308321
z = pm.Deterministic("z", clone_x + 1)
309322
310323
"""
311-
return model_from_fgraph(fgraph_from_model(model))
324+
return model_from_fgraph(fgraph_from_model(model)[0])

0 commit comments

Comments
 (0)