Deprecate n_splits with cv #362

Merged · 4 commits · Jan 12, 2021
4 changes: 2 additions & 2 deletions doc/spec/estimation/dml.rst
@@ -551,10 +551,10 @@ Usage FAQs
If one uses cross-validated estimators as first stages, then model selection for the first stage models
is performed automatically.

- - **How should I set the parameter `n_splits`?**
+ - **How should I set the parameter `cv`?**

This parameter defines the number of data partitions to create in order to fit the first stages in a
- crossfittin manner (see :class:`._OrthoLearner`). The default is 2, which
+ crossfitting manner (see :class:`._OrthoLearner`). The default is 2, which
is the minimal. However, larger values like 5 or 6 can lead to greater statistical stability of the method,
especially if the number of samples is small. So we advise that for small datasets, one should raise this
value. This can increase the computational cost as more first stage models are being fitted.
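
For concreteness, a minimal sketch of raising `cv` on a small dataset (assuming the library's `LinearDML` estimator, with data `y, T, X, W` as in the surrounding examples):

    from econml.dml import LinearDML
    from sklearn.ensemble import RandomForestRegressor

    # More crossfitting folds cost extra first-stage fits but improve
    # statistical stability when the sample is small.
    est = LinearDML(model_y=RandomForestRegressor(),
                    model_t=RandomForestRegressor(),
                    cv=5)
    est.fit(y, T, X=X, W=W)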
8 changes: 4 additions & 4 deletions doc/spec/estimation/dr.rst
@@ -242,7 +242,7 @@ Below we give a brief description of each of these classes:
}, cv=10, n_jobs=-1, scoring='neg_mean_squared_error'
)
est = DRLearner(model_regression=model_reg(), model_propensity=model_clf(),
- model_final=model_reg(), n_splits=5)
+ model_final=model_reg(), cv=5)
est.fit(y, T, X=X, W=W)
point = est.effect(X, T0=T0, T1=T1)

@@ -427,7 +427,7 @@ Usage FAQs
}, cv=5, n_jobs=-1, scoring='neg_mean_squared_error'
)
est = DRLearner(model_regression=model_reg(), model_propensity=model_clf(),
- model_final=model_reg(), n_splits=5)
+ model_final=model_reg(), cv=5)
est.fit(y, T, X=X, W=W)
point = est.effect(X, T0=T0, T1=T1)

@@ -467,10 +467,10 @@ Usage FAQs
If one uses cross-validated estimators as first stages, then model selection for the first stage models
is performed automatically.

- - **How should I set the parameter `n_splits`?**
+ - **How should I set the parameter `cv`?**

This parameter defines the number of data partitions to create in order to fit the first stages in a
- crossfittin manner (see :class:`._OrthoLearner`). The default is 2, which
+ crossfitting manner (see :class:`._OrthoLearner`). The default is 2, which
is the minimal. However, larger values like 5 or 6 can lead to greater statistical stability of the method,
especially if the number of samples is small. So we advise that for small datasets, one should raise this
value. This can increase the computational cost as more first stage models are being fitted.
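
Note that `cv` also accepts a cross-validation splitter object rather than an int, so the crossfitting folds can be pinned down explicitly; a sketch reusing `model_reg`/`model_clf` from the snippet above (the `DRLearner` import path may differ across versions):

    from sklearn.model_selection import KFold
    from econml.drlearner import DRLearner

    # A seeded splitter makes the crossfitting partition reproducible.
    est = DRLearner(model_regression=model_reg(), model_propensity=model_clf(),
                    model_final=model_reg(),
                    cv=KFold(n_splits=5, shuffle=True, random_state=123))
    est.fit(y, T, X=X, W=W)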
34 changes: 25 additions & 9 deletions econml/_ortho_learner.py
@@ -259,7 +259,7 @@ class _OrthoLearner(TreatmentExpansionMixin, LinearCateEstimator):
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.

- n_splits: int, cross-validation generator or an iterable
+ cv: int, cross-validation generator or an iterable
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

@@ -333,7 +333,7 @@ def _gen_ortho_learner_model_final(self):
np.random.seed(123)
X = np.random.normal(size=(100, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.1, size=(100,))
- est = OrthoLearner(n_splits=2, discrete_treatment=False, discrete_instrument=False,
+ est = OrthoLearner(cv=2, discrete_treatment=False, discrete_instrument=False,
categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])

@@ -391,7 +391,7 @@ def _gen_ortho_learner_model_final(self):
import scipy.special
T = np.random.binomial(1, scipy.special.expit(W[:, 0]))
y = T + W[:, 0] + np.random.normal(0, 0.01, size=(100,))
- est = OrthoLearner(n_splits=2, discrete_treatment=True, discrete_instrument=False,
+ est = OrthoLearner(cv=2, discrete_treatment=True, discrete_instrument=False,
categories='auto', random_state=None)
est.fit(y, T, W=W)

@@ -424,8 +424,9 @@ def _gen_ortho_learner_model_final(self):
"""

def __init__(self, *,
- discrete_treatment, discrete_instrument, categories, n_splits, random_state,
- mc_iters=None, mc_agg='mean'):
+ discrete_treatment, discrete_instrument, categories, cv, random_state,
+ n_splits='raise', mc_iters=None, mc_agg='mean'):
+ self.cv = cv
self.n_splits = n_splits
self.discrete_treatment = discrete_treatment
self.discrete_instrument = discrete_instrument
@@ -566,7 +567,7 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, sample_var=None,
Sample variance for each sample
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
- If groups is not None, the n_splits argument passed to this class's initializer
+ If groups is not None, the cv argument passed to this class's initializer
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache the inputs and computed nuisances, which will allow refitting a different final model
@@ -712,16 +713,16 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
if self.discrete_instrument:
Z = self.z_transformer.transform(reshape(Z, (-1, 1)))

- if self.n_splits == 1:  # special case, no cross validation
+ if self.cv == 1:  # special case, no cross validation
folds = None
else:
- splitter = check_cv(self.n_splits, [0], classifier=stratify)
+ splitter = check_cv(self.cv, [0], classifier=stratify)
# if check_cv produced a new KFold or StratifiedKFold object, we need to set shuffle and random_state
# TODO: ideally, we'd also infer whether we need a GroupKFold (if groups are passed)
# however, sklearn doesn't support both stratifying and grouping (see
# https://github.com/scikit-learn/scikit-learn/issues/13621), so for now the user needs to supply
# their own object that supports grouping if they want to use groups.
- if splitter != self.n_splits and isinstance(splitter, (KFold, StratifiedKFold)):
+ if splitter != self.cv and isinstance(splitter, (KFold, StratifiedKFold)):
splitter.shuffle = True
splitter.random_state = self._random_state
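
A quick sketch of the sklearn `check_cv` behavior this branch relies on: an int expands to a fresh, unshuffled `KFold` (or `StratifiedKFold` for classifiers), which is why the code then switches shuffling on and seeds it.

    from sklearn.model_selection import KFold, check_cv

    splitter = check_cv(5, classifier=False)  # int -> KFold(n_splits=5)
    assert isinstance(splitter, KFold)
    assert splitter.shuffle is False  # hence the explicit shuffle/seed above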

@@ -856,3 +857,18 @@ def models_nuisance_(self):
if not hasattr(self, '_models_nuisance'):
raise AttributeError("Model is not fitted!")
return self._models_nuisance

+ #######################################################
+ # These should be removed once `n_splits` is deprecated
+ #######################################################
+
+ @property
+ def n_splits(self):
+     return self.cv
+
+ @n_splits.setter
+ def n_splits(self, value):
+     if value != 'raise':
+         warn("Parameter `n_splits` has been deprecated and will be removed in the next version. "
+              "Use parameter `cv` instead.")
+         self.cv = value
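
The property pair above is a standard alias-with-sentinel deprecation pattern; a self-contained sketch of the same idea with a hypothetical `Estimator` class:

    import warnings

    class Estimator:
        def __init__(self, cv=2, n_splits='raise'):
            self.cv = cv
            self.n_splits = n_splits  # no-op (and no warning) at the sentinel

        @property
        def n_splits(self):
            return self.cv  # reads always alias the new parameter

        @n_splits.setter
        def n_splits(self, value):
            if value != 'raise':
                warnings.warn("`n_splits` is deprecated; use `cv` instead.")
                self.cv = value

    est = Estimator(n_splits=5)  # warns, then behaves exactly like cv=5
    assert est.cv == 5 and est.n_splits == 5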
10 changes: 6 additions & 4 deletions econml/dml/_rlearner.py
@@ -147,7 +147,7 @@ class _RLearner(_OrthoLearner):
The categories to use when encoding discrete treatments (or 'auto' to use the unique sorted values).
The first category will be treated as the control treatment.

- n_splits: int, cross-validation generator or an iterable
+ cv: int, cross-validation generator or an iterable
Determines the cross-validation splitting strategy.
Possible inputs for cv are:

@@ -216,7 +216,7 @@ def _gen_rlearner_model_final(self):
np.random.seed(123)
X = np.random.normal(size=(1000, 3))
y = X[:, 0] + X[:, 1] + np.random.normal(0, 0.01, size=(1000,))
- est = RLearner(n_splits=2, discrete_treatment=False, categories='auto', random_state=None)
+ est = RLearner(cv=2, discrete_treatment=False, categories='auto', random_state=None)
est.fit(y, X[:, 0], X=np.ones((X.shape[0], 1)), W=X[:, 1:])

>>> est.const_marginal_effect(np.ones((1,1)))
@@ -261,10 +261,12 @@ def _gen_rlearner_model_final(self):
is multidimensional, then the average of the MSEs for each dimension of Y is returned.
"""

- def __init__(self, *, discrete_treatment, categories, n_splits, random_state, mc_iters=None, mc_agg='mean'):
+ def __init__(self, *, discrete_treatment, categories, cv, random_state,
+              n_splits='raise', mc_iters=None, mc_agg='mean'):
super().__init__(discrete_treatment=discrete_treatment,
discrete_instrument=False, # no instrument, so doesn't matter
categories=categories,
+ cv=cv,
n_splits=n_splits,
random_state=random_state,
mc_iters=mc_iters,
@@ -345,7 +347,7 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
Sample variance for each sample
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.
- If groups is not None, the n_splits argument passed to this class's initializer
+ If groups is not None, the `cv` argument passed to this class's initializer
must support a 'groups' argument to its split method.
cache_values: bool, default False
Whether to cache inputs and first stage results, which will allow refitting a different final model
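
Because sklearn cannot stratify and group at the same time, a grouping-aware splitter has to be passed as `cv` whenever `groups` is used in `fit`; a sketch with a concrete subclass (assuming `LinearDML` with its default first-stage models, and `y, T, X, W` defined as usual):

    import numpy as np
    from sklearn.model_selection import GroupKFold
    from econml.dml import LinearDML

    groups = np.repeat(np.arange(50), 2)  # two samples per group
    # GroupKFold.split takes a `groups` argument, so rows that share a
    # group stay together across the crossfitting folds.
    est = LinearDML(cv=GroupKFold(n_splits=2))
    est.fit(y, T, X=X, W=W, groups=groups)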
4 changes: 2 additions & 2 deletions econml/dml/causal_forest.py
@@ -439,8 +439,8 @@ def __init__(self, *,
cv = self.n_crossfit_splits
super().__init__(discrete_treatment=discrete_treatment,
categories=categories,
- # TODO. change to `cv=cv, n_splits='raise` when merged with the `n_splits` deprecation PR
- n_splits=cv,
+ cv=cv,
+ n_splits=n_crossfit_splits,
mc_iters=mc_iters,
mc_agg=mc_agg,
random_state=random_state)
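
This double forwarding keeps the legacy spelling alive through the deprecation window; a sketch of the expected behavior (assuming `CausalForestDML`, the public class in this module, exposes both keywords as the diff suggests):

    from econml.dml import CausalForestDML

    est_old = CausalForestDML(n_crossfit_splits=3)  # still works, but warns
    est_new = CausalForestDML(cv=3)                 # the forward-compatible spelling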