Skip to content

Commit 102522e

Browse files
Warn about resizing MutableData dims that are not symbolically linked
Closes #5812
1 parent aa857f1 commit 102522e

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

pymc/model.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from pymc.distributions import joint_logpt
6161
from pymc.distributions.logprob import _get_scaling
6262
from pymc.distributions.transforms import _default_transform
63-
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError
63+
from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning
6464
from pymc.initial_point import make_initial_point_fn
6565
from pymc.math import flatten_list
6666
from pymc.util import (
@@ -1194,14 +1194,31 @@ def set_data(
11941194
f"{new_length}, so new coord values for the {dname} dimension are required."
11951195
)
11961196
if isinstance(length_tensor, TensorConstant):
1197+
# The dimension was fixed in length.
1198+
# Resizing a data variable in this dimension would
1199+
# definitely lead to shape problems.
11971200
raise ShapeError(
11981201
f"Resizing dimension '{dname}' is impossible, because "
11991202
"a 'TensorConstant' stores its length. To be able "
12001203
"to change the dimension length, pass `mutable=True` when "
12011204
"registering the dimension via `model.add_coord`, "
12021205
"or define it via a `pm.MutableData` variable."
12031206
)
1207+
elif isinstance(length_tensor, ScalarSharedVariable):
1208+
# The dimension is mutable, but was defined without being linked
1209+
# to a shared variable. This is allowed, but slightly dangerous.
1210+
warnings.warn(
1211+
f"You are resizing a variable with dimension '{dname}' which was initialized"
1212+
" as a mutable dimension and is not linked to the `MutableData` variable."
1213+
" Remember to update the dimension length by calling "
1214+
f"`Model.set_dim({dname}, new_length={new_length})` manually,"
1215+
" preferably _before_ updating `MutableData` variables that use this dimension.",
1216+
ShapeWarning,
1217+
stacklevel=2,
1218+
)
12041219
else:
1220+
# The dimension was created from another model variable.
1221+
# If that was a non-mutable variable, there will definitely be shape problems.
12051222
length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
12061223
if not isinstance(length_belongs_to, SharedVariable):
12071224
raise ShapeError(

pymc/tests/test_model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from pymc import Deterministic, Potential
3838
from pymc.blocking import DictToArrayBijection, RaveledVars
3939
from pymc.distributions import Normal, transforms
40-
from pymc.exceptions import ShapeError
40+
from pymc.exceptions import ShapeError, ShapeWarning
4141
from pymc.model import Point, ValueGradFunction
4242
from pymc.tests.helpers import SeededTest
4343

@@ -798,6 +798,24 @@ def test_set_dim_with_coords():
798798
assert pmodel.coords["mdim"] == ("A", "B", "C")
799799

800800

801+
def test_set_data_warns_resize_mutable_dim():
802+
with pm.Model() as pmodel:
803+
pmodel.add_coord("mdim", mutable=True, length=2)
804+
pm.MutableData("mdata", [1, 2], dims="mdim")
805+
806+
# First resize the dimension.
807+
pmodel.dim_lengths["mdim"].set_value(3)
808+
# Then change the data.
809+
pmodel.set_data("mdata", [1, 2, 3])
810+
811+
# Now the other way around.
812+
# Because the dimension doesn't depend on the data variable,
813+
# a warning shoudl be emitted.
814+
with pytest.warns(ShapeWarning, match="update the dimension length"):
815+
pmodel.set_data("mdata", [1, 2, 3, 4])
816+
pass
817+
818+
801819
@pytest.mark.parametrize("jacobian", [True, False])
802820
def test_model_logp(jacobian):
803821
with pm.Model() as m:

0 commit comments

Comments
 (0)