|
60 | 60 | from pymc.distributions import joint_logpt
|
61 | 61 | from pymc.distributions.logprob import _get_scaling
|
62 | 62 | from pymc.distributions.transforms import _default_transform
|
63 |
| -from pymc.exceptions import ImputationWarning, SamplingError, ShapeError |
| 63 | +from pymc.exceptions import ImputationWarning, SamplingError, ShapeError, ShapeWarning |
64 | 64 | from pymc.initial_point import make_initial_point_fn
|
65 | 65 | from pymc.math import flatten_list
|
66 | 66 | from pymc.util import (
|
@@ -1194,14 +1194,31 @@ def set_data(
|
1194 | 1194 | f"{new_length}, so new coord values for the {dname} dimension are required."
|
1195 | 1195 | )
|
1196 | 1196 | if isinstance(length_tensor, TensorConstant):
|
| 1197 | + # The dimension was fixed in length. |
| 1198 | + # Resizing a data variable in this dimension would |
| 1199 | + # definitely lead to shape problems. |
1197 | 1200 | raise ShapeError(
|
1198 | 1201 | f"Resizing dimension '{dname}' is impossible, because "
|
1199 | 1202 | "a 'TensorConstant' stores its length. To be able "
|
1200 | 1203 | "to change the dimension length, pass `mutable=True` when "
|
1201 | 1204 | "registering the dimension via `model.add_coord`, "
|
1202 | 1205 | "or define it via a `pm.MutableData` variable."
|
1203 | 1206 | )
|
| 1207 | + elif isinstance(length_tensor, ScalarSharedVariable): |
| 1208 | + # The dimension is mutable, but was defined without being linked |
| 1209 | + # to a shared variable. This is allowed, but slightly dangerous. |
| 1210 | + warnings.warn( |
| 1211 | + f"You are resizing a variable with dimension '{dname}' which was initialized" |
| 1212 | + " as a mutable dimension and is not linked to the `MutableData` variable." |
| 1213 | + " Remember to update the dimension length by calling " |
| 1214 | + f"`Model.set_dim({dname}, new_length={new_length})` manually," |
| 1215 | + " preferably _before_ updating `MutableData` variables that use this dimension.", |
| 1216 | + ShapeWarning, |
| 1217 | + stacklevel=2, |
| 1218 | + ) |
1204 | 1219 | else:
|
| 1220 | + # The dimension was created from another model variable. |
| 1221 | + # If that was a non-mutable variable, there will definitely be shape problems. |
1205 | 1222 | length_belongs_to = length_tensor.owner.inputs[0].owner.inputs[0]
|
1206 | 1223 | if not isinstance(length_belongs_to, SharedVariable):
|
1207 | 1224 | raise ShapeError(
|
|
0 commit comments