Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refitting model_final and nuisance averaging #360

Merged
merged 27 commits into from
Jan 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a3adef1
Support refitting in DML
kbattocchi Dec 16, 2020
d3520ee
Add support for monte carlo nuisance estimation
kbattocchi Oct 28, 2020
b4a6bdf
Address PR feedback
kbattocchi Dec 23, 2020
7b98d46
Address monte carlo feedback
kbattocchi Dec 23, 2020
f636dab
Refit test fixes
kbattocchi Dec 23, 2020
deda008
added RScorer
vasilismsr Dec 28, 2020
b7470f2
Revert "added RScorer"
vasilismsr Dec 28, 2020
cf7122a
abstract class approach to refitting
vasilismsr Jan 2, 2021
72a297a
Merge branch 'master' into vasilis/refit
vsyrgkanis Jan 3, 2021
ba0dbf0
fixed some bugs related to merging with shap and new attribute names
vasilismsr Jan 3, 2021
3b37a4c
changed refit to refit_final to make sure that it's obvious that only…
vasilismsr Jan 4, 2021
d29f8a3
added refit example in the dml notebook
vasilismsr Jan 4, 2021
ad27125
Update econml/dml.py
vsyrgkanis Jan 9, 2021
031ffb0
addressed review comments. Deprecated positional arguments at init.
vasilismsr Jan 9, 2021
12c5518
Merge branch 'vasilis/refit' of github.com:microsoft/EconML into vasi…
vasilismsr Jan 9, 2021
8a49619
linting
vasilismsr Jan 9, 2021
bad9db9
docstring fixes
vasilismsr Jan 9, 2021
1951a31
docstring fixes
vasilismsr Jan 9, 2021
b523ef0
fixed failing tests due to deprecation of positional
vasilismsr Jan 9, 2021
d913aae
fixed failing notebook due to deprecation of positional
vasilismsr Jan 9, 2021
dafa116
fixed failing test due to positional deprecation
vasilismsr Jan 9, 2021
b52408f
merged with master
vasilismsr Jan 9, 2021
c553795
fixed relative import. fixed drlearner test featurizer access
vasilismsr Jan 9, 2021
e35305f
fixed merge bugs
vasilismsr Jan 10, 2021
e9f3f34
fixed bugs from merge
vasilismsr Jan 10, 2021
222d297
fixed docstring
vasilismsr Jan 10, 2021
7f437c1
added refit re-implementation in causalforest dml
vasilismsr Jan 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions econml/_cate_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,11 @@ class LinearModelFinalCateEstimatorMixin(BaseCateEstimator):
"""
Base class for models where the final stage is a linear model.

Subclasses must expose a ``model_final`` attribute containing the model's
final stage model.
Such an estimator must implement a :attr:`model_final_` attribute that points
to the fitted final :class:`.StatsModelsLinearRegression` object that
represents the fitted CATE model. Also must implement :attr:`featurizer_` that points
to the fitted featurizer and :attr:`bias_part_of_coef` that designates
if the intercept is the first element of the :attr:`model_final_` coefficient.

Attributes
----------
Expand All @@ -544,7 +547,9 @@ def _get_inference_options(self):
options.update(auto=LinearModelFinalInference)
return options

bias_part_of_coef = False
@property
def bias_part_of_coef(self):
return False

@property
def coef_(self):
Expand All @@ -561,9 +566,9 @@ def coef_(self):
a vector and not a 2D array. For binary treatment the n_t dimension is
also omitted.
"""
return parse_final_model_params(self.model_final.coef_, self.model_final.intercept_,
return parse_final_model_params(self.model_final_.coef_, self.model_final_.intercept_,
self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef,
self.fit_cate_intercept)[0]
self.fit_cate_intercept_)[0]

@property
def intercept_(self):
Expand All @@ -578,11 +583,11 @@ def intercept_(self):
a vector and not a 2D array. For binary treatment the n_t dimension is
also omitted.
"""
if not self.fit_cate_intercept:
if not self.fit_cate_intercept_:
raise AttributeError("No intercept was fitted!")
return parse_final_model_params(self.model_final.coef_, self.model_final.intercept_,
return parse_final_model_params(self.model_final_.coef_, self.model_final_.intercept_,
self._d_y, self._d_t, self._d_t_in, self.bias_part_of_coef,
self.fit_cate_intercept)[1]
self.fit_cate_intercept_)[1]

@BaseCateEstimator._defer_to_inference
def coef__interval(self, *, alpha=0.1):
Expand Down Expand Up @@ -718,11 +723,11 @@ def summary(self, alpha=0.1, value=0, decimals=3, feature_names=None, treatment_
return smry

def shap_values(self, X, *, feature_names=None, treatment_names=None, output_names=None, background_samples=100):
if hasattr(self, "featurizer") and self.featurizer is not None:
X = self.featurizer.transform(X)
if hasattr(self, "featurizer_") and self.featurizer_ is not None:
X = self.featurizer_.transform(X)
feature_names = self.cate_feature_names(feature_names)
return _shap_explain_joint_linear_model_cate(self.model_final, X, self._d_t, self._d_y,
self.fit_cate_intercept,
return _shap_explain_joint_linear_model_cate(self.model_final_, X, self._d_t, self._d_y,
self.bias_part_of_coef,
feature_names=feature_names, treatment_names=treatment_names,
output_names=output_names,
input_names=self._input_names,
Expand All @@ -736,9 +741,11 @@ class StatsModelsCateEstimatorMixin(LinearModelFinalCateEstimatorMixin):
Mixin class that offers `inference='statsmodels'` options to the CATE estimator
that inherits it.

Such an estimator must implement a :attr:`model_final` attribute that points
Such an estimator must implement a :attr:`model_final_` attribute that points
to the fitted final :class:`.StatsModelsLinearRegression` object that
represents the fitted CATE model.
represents the fitted CATE model. Also must implement :attr:`featurizer_` that points
to the fitted featurizer and :attr:`bias_part_of_coef` that designates
if the intercept is the first element of the :attr:`model_final_` coefficient.
"""

def _get_inference_options(self):
Expand Down Expand Up @@ -771,7 +778,7 @@ def _get_inference_options(self):

@property
def feature_importances_(self):
return self.model_final.feature_importances_
return self.model_final_.feature_importances_


class LinearModelFinalCateEstimatorDiscreteMixin(BaseCateEstimator):
Expand Down Expand Up @@ -822,7 +829,7 @@ def intercept_(self, T):
-------
intercept: float or (n_y,) array like
"""
if not self.fit_cate_intercept:
if not self.fit_cate_intercept_:
raise AttributeError("No intercept was fitted!")
_, T = self._expand_treatments(None, T)
ind = inverse_onehot(T).item() - 1
Expand Down Expand Up @@ -980,7 +987,7 @@ class StatsModelsCateEstimatorDiscreteMixin(LinearModelFinalCateEstimatorDiscret
Mixin class that offers `inference='statsmodels'` options to the CATE estimator
that inherits it.

Such an estimator must implement a :attr:`model_final` attribute that points
Such an estimator must implement a :attr:`model_final_` attribute that points
to a :class:`.StatsModelsLinearRegression` object that is cloned to fit
each discrete treatment target CATE model and a :attr:`fitted_models_final` attribute
that returns the list of fitted final models that represent the CATE for each categorical treatment.
Expand Down
Loading