From b6eeec8edca2a3f3db2b9bd90a280f66c5d3fe6a Mon Sep 17 00:00:00 2001 From: mohammed052 <142733772+mohammed052@users.noreply.github.com> Date: Thu, 11 Jan 2024 23:14:57 +0530 Subject: [PATCH] Remove unused comp_shape from NormalMixture --- pymc/distributions/mixture.py | 12 ++++-------- tests/distributions/test_mixture.py | 9 ++------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index a599bebea1..10c2bb14ad 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -524,10 +524,6 @@ class NormalMixture: the component standard deviations tau : tensor_like of float the component precisions - comp_shape : shape of the Normal component - notice that it should be different than the shape - of the mixture distribution, with the last axis representing - the number of components. Notes ----- @@ -554,16 +550,16 @@ class NormalMixture: y = pm.NormalMixture("y", w=weights, mu=μ, sigma=σ, observed=data) """ - def __new__(cls, name, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs): + def __new__(cls, name, w, mu, sigma=None, tau=None, **kwargs): _, sigma = get_tau_sigma(tau=tau, sigma=sigma) - return Mixture(name, w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs) + return Mixture(name, w, Normal.dist(mu, sigma=sigma), **kwargs) @classmethod - def dist(cls, w, mu, sigma=None, tau=None, comp_shape=(), **kwargs): + def dist(cls, w, mu, sigma=None, tau=None, **kwargs): _, sigma = get_tau_sigma(tau=tau, sigma=sigma) - return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs) + return Mixture.dist(w, Normal.dist(mu, sigma=sigma), **kwargs) def _zero_inflated_mixture(*, name, nonzero_p, nonzero_dist, **kwargs): diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index d07a7be927..7ce6084d8d 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -820,10 +820,8 @@ def test_normal_mixture_nd(self, seeded_test, nd, ncomp): mus = Normal("mus", shape=comp_shape) taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape) ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,)) - mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape) - obs0 = NormalMixture( - "obs", w=ws, mu=mus, tau=taus, comp_shape=comp_shape, observed=observed - ) + mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd) + obs0 = NormalMixture("obs", w=ws, mu=mus, tau=taus, observed=observed) with Model() as model1: mus = Normal("mus", shape=comp_shape) @@ -867,7 +865,6 @@ def ref_rand(size, w, mu, sigma): "mu": Domain([[0.05, 2.5], [-5.0, 1.0]], edges=(None, None)), "sigma": Domain([[1, 1], [1.5, 2.0]], edges=(None, None)), }, - extra_args={"comp_shape": 2}, size=1000, ref_rand=ref_rand, ) @@ -878,7 +875,6 @@ def ref_rand(size, w, mu, sigma): "mu": Domain([[-5.0, 1.0, 2.5]], edges=(None, None)), "sigma": Domain([[1.5, 2.0, 3.0]], edges=(None, None)), }, - extra_args={"comp_shape": 3}, size=1000, ref_rand=ref_rand, ) @@ -902,7 +898,6 @@ def test_scalar_components(self): w=np.ones(npop) / npop, mu=mus, sigma=1e-5, - comp_shape=(nd, npop), shape=nd, ) z = Categorical("z", p=np.ones(npop) / npop, shape=nd)