Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added GridSearchCV list that can help auto select among multiple models #328

Merged
merged 6 commits into from
Dec 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions econml/sklearn_extensions/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

import numbers
import warnings

from sklearn.base import BaseEstimator
from sklearn.utils.multiclass import type_of_target
import numpy as np
import scipy.sparse as sp
from joblib import Parallel, delayed
from sklearn.base import clone, is_classifier
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.model_selection import KFold, StratifiedKFold, check_cv, GridSearchCV
# TODO: consider working around relying on sklearn implementation details
from sklearn.model_selection._validation import (_check_is_permutation,
_fit_and_predict)
Expand Down Expand Up @@ -201,6 +202,63 @@ def split(self, X, y, sample_weight=None):
return _split_weighted_sample(self, X, y, sample_weight, is_stratified=True)


class GridSearchCVList(BaseEstimator):
    """ An extension of GridSearchCV that searches over a list of estimators, each with its own grid.

    One :class:`~sklearn.model_selection.GridSearchCV` is fit per (estimator, parameter grid) pair
    and the pair whose search attains the highest ``best_score_`` is selected. Compared to
    :class:`~sklearn.model_selection.GridSearchCV`, the ``estimator`` parameter is replaced by
    ``estimator_list`` and the ``param_grid`` parameter by ``param_grid_list``. The remaining
    parameters are forwarded unchanged to every underlying
    :class:`~sklearn.model_selection.GridSearchCV`; see the documentation of that class
    for their explanation.

    Parameters
    ----------
    estimator_list : list of estimator object.
        Each estimator in the list is assumed to implement the scikit-learn estimator interface.
        Either estimator needs to provide a ``score`` function,
        or ``scoring`` must be passed.

    param_grid_list : list of dict or list of list of dictionaries
        One entry per estimator in ``estimator_list``, in the same order. Each entry is a
        dictionary with parameters names (`str`) as keys and lists of
        parameter settings to try as values, or a list of such
        dictionaries, in which case the grids spanned by each dictionary
        in the list are explored. This enables searching over any sequence
        of parameter settings.

    Attributes
    ----------
    best_ind_ : int
        Position in ``estimator_list`` of the selected estimator.

    best_estimator_ : estimator
        The ``best_estimator_`` of the selected grid search.

    best_score_ : float
        The ``best_score_`` of the selected grid search.

    best_params_ : dict
        The ``best_params_`` of the selected grid search.
    """

    # NOTE(review): the ``error_score='raise-deprecating'`` default mirrors an older scikit-learn
    # GridSearchCV signature; confirm it is still accepted by the scikit-learn version in use.
    def __init__(self, estimator_list, param_grid_list, scoring=None,
                 n_jobs=None, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
                 error_score='raise-deprecating', return_train_score=False):
        self.estimator_list = estimator_list
        self.param_grid_list = param_grid_list
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.refit = refit
        self.cv = cv
        self.verbose = verbose
        self.pre_dispatch = pre_dispatch
        self.error_score = error_score
        self.return_train_score = return_train_score

    def fit(self, X, y, **fit_params):
        """ Fit one grid search per estimator and keep the results of the best-scoring one.

        Parameters
        ----------
        X : array-like
            Training data, passed to every underlying grid search.
        y : array-like
            Targets, passed to every underlying grid search.
        **fit_params
            Additional keyword arguments forwarded to each ``GridSearchCV.fit`` call.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``estimator_list`` is empty or its length differs from that of ``param_grid_list``.
        """
        estimators = list(self.estimator_list)
        param_grids = list(self.param_grid_list)
        # Validate up front: zip() would otherwise silently drop the unmatched entries.
        if len(estimators) != len(param_grids):
            raise ValueError("estimator_list and param_grid_list must have the same length, "
                             "got {} and {}.".format(len(estimators), len(param_grids)))
        if not estimators:
            raise ValueError("estimator_list must contain at least one estimator.")

        self._gcv_list = [GridSearchCV(estimator, param_grid, scoring=self.scoring,
                                       n_jobs=self.n_jobs, refit=self.refit, cv=self.cv,
                                       verbose=self.verbose, pre_dispatch=self.pre_dispatch,
                                       error_score=self.error_score,
                                       return_train_score=self.return_train_score)
                          for estimator, param_grid in zip(estimators, param_grids)]
        for gcv in self._gcv_list:
            gcv.fit(X, y, **fit_params)

        # Ties are resolved in favor of the earliest estimator (np.argmax returns the first maximum).
        self.best_ind_ = np.argmax([gcv.best_score_ for gcv in self._gcv_list])
        best_search = self._gcv_list[self.best_ind_]
        # NOTE(review): ``best_estimator_`` is read unconditionally, so this presumably requires
        # the searches to have been refit (``refit`` left enabled) -- confirm against callers.
        self.best_estimator_ = best_search.best_estimator_
        self.best_score_ = best_search.best_score_
        self.best_params_ = best_search.best_params_
        return self

    def predict(self, X):
        """ Return ``predict(X)`` of the selected best estimator. """
        return self.best_estimator_.predict(X)

    def predict_proba(self, X):
        """ Return ``predict_proba(X)`` of the selected best estimator. """
        return self.best_estimator_.predict_proba(X)


def _cross_val_predict(estimator, X, y=None, *, groups=None, cv=None,
n_jobs=None, verbose=0, fit_params=None,
pre_dispatch='2*n_jobs', method='predict', safe=True):
Expand Down
Loading