Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling summary() even when inference not available #363

Merged
merged 6 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions econml/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,14 @@ def __init__(self, *,
categories=categories,
random_state=random_state)

def _get_inference_options(self):
options = super()._get_inference_options()
if not self.multitask_model_final:
options.update(auto=GenericModelFinalInferenceDiscrete)
else:
options.update(auto=lambda: None)
return options

def _gen_ortho_learner_model_nuisance(self):
if self.model_propensity == 'auto':
model_propensity = LogisticRegressionCV(cv=3, solver='lbfgs', multi_class='auto',
Expand All @@ -426,7 +434,7 @@ def _gen_ortho_learner_model_final(self):
@_deprecate_positional("X and W should be passed by keyword only. In a future release "
"we will disallow passing X and W by position.", ['X', 'W'])
def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, groups=None,
cache_values=False, inference=None):
cache_values=False, inference='auto'):
"""
Estimate the counterfactual model from data, i.e. estimates function :math:`\\theta(\\cdot)`.

Expand Down Expand Up @@ -463,6 +471,10 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values, inference=inference)

def refit_final(self, *, inference='auto'):
return super().refit_final(inference=inference)
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__

def score(self, Y, T, X=None, W=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
Expand Down Expand Up @@ -851,10 +863,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
sample_weight=sample_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values, inference=inference)

def refit_final(self, *, inference='auto'):
return super().refit_final(inference=inference)
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__

@property
def fit_cate_intercept_(self):
return self.model_final_.fit_intercept
Expand Down Expand Up @@ -1151,10 +1159,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
sample_weight=sample_weight, sample_var=None, groups=groups,
cache_values=cache_values, inference=inference)

def refit_final(self, *, inference='auto'):
return super().refit_final(inference=inference)
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__

@property
def fit_cate_intercept_(self):
return self.model_final_.fit_intercept
Expand Down Expand Up @@ -1465,10 +1469,6 @@ def fit(self, Y, T, X=None, W=None, *, sample_weight=None, sample_var=None, grou
sample_weight=sample_weight, sample_var=None, groups=groups,
cache_values=cache_values, inference=inference)

def refit_final(self, *, inference='auto'):
return super().refit_final(inference=inference)
refit_final.__doc__ = _OrthoLearner.refit_final.__doc__

def multitask_model_cate(self):
# Replacing to remove docstring
super().multitask_model_cate()
Expand Down
Loading