Skip to content

feat: ability to use only precomputed point predictions #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions src/macest/classification/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ class ModelWithConfidence:

def __init__(
self,
point_pred_model: _ClassificationPointPredictionModel,
x_train: np.ndarray,
y_train: Iterable[int],
point_pred_model: Optional[_ClassificationPointPredictionModel] = None,
macest_model_params: MacestConfModelParams = MacestConfModelParams(),
precomputed_neighbour_info: Optional[PrecomputedNeighbourInfo] = None,
graph: Optional[Dict[int, nmslib.dist.FloatIndex]] = None,
Expand Down Expand Up @@ -110,6 +110,9 @@ def __init__(
:param empirical_conflict_constant: Constant confidence assigned to conflicting predictions,
calculated during calibration
"""
if point_pred_model is None and training_preds_by_class is None:
raise ValueError("One of 'point_pred_model' or 'training_preds_by_class'"
"must be specified")
self.point_pred_model = point_pred_model
self.x_train = x_train
self.y_train = y_train
Expand Down Expand Up @@ -161,6 +164,8 @@ def predict(self, x_star: np.ndarray) -> np.ndarray:
:param x_star: The point(s) at which we want to predict
:return: A point prediction for the given x_star
"""
if self.point_pred_model is None:
raise ValueError("Cannot predict as no 'point_pred_model' has been initialized")
return self.point_pred_model.predict(x_star)

def build_class_graphs(self) -> Dict[int, nmslib.dist.FloatIndex]:
Expand Down Expand Up @@ -286,19 +291,26 @@ def predict_proba(
return relative_conf

def predict_confidence_of_point_prediction(
self, x_star: np.ndarray, change_conflicts: bool = False,
self,
x_star: np.ndarray,
prec_point_preds: Optional[np.ndarray] = None,
change_conflicts: bool = False,
) -> np.ndarray:
"""
Estimate a single confidence score, this represents the confidence of the point prediction
being correct rather than a confidence score for each class.

:param x_star: The point to predict confidently
:param prec_point_preds: The pre-computed point predictions for x_star; when given, they replace the stored point predictions
:param change_conflicts: Boolean, true means conflicting predictions between macest and
point prediction are set to an empirical constant

:return: The confidence in the point prediction being correct

"""
if prec_point_preds is not None:
self.point_preds = prec_point_preds

if self.point_preds is not None:
point_prediction = self.point_preds
else:
Expand Down Expand Up @@ -377,6 +389,7 @@ def fit(
self,
x_cal: np.ndarray,
y_cal: np.ndarray,
prec_point_preds: Optional[np.ndarray] = None,
param_range: SearchBounds = SearchBounds(),
optimiser_args: Optional[Dict[Any, Any]] = None,
) -> None:
Expand All @@ -385,11 +398,15 @@ def fit(

:param x_cal: Calibration data
:param y_cal: Target values
:param prec_point_preds: The pre-computed point predictions for the calibration data; when given, they are stored as the model's point predictions
:param param_range: The bounds within which to search for MACEst parameters
:param optimiser_args: Any arguments for the optimiser (see scipy.optimize)

:return: None
"""
if prec_point_preds is not None:
self.point_preds = prec_point_preds

if optimiser_args is None:
optimiser_args = {}

Expand Down Expand Up @@ -471,7 +488,8 @@ def __init__(
self.precomputed_index = self.precomputed_neighbours[1]
self.precomputed_error = self.precomputed_neighbours[2]
self._n_classes = len(np.unique(self.model.y_train))
self.model.point_preds = self.model.predict(self.x_cal)
if self.model.point_preds is None:
self.model.point_preds = self.model.predict(self.x_cal)
self.model.distance_to_neighbours = self.precomputed_distance
self.model.index_of_neighbours = self.precomputed_index
self.model.error_on_neighbours = self.precomputed_error
Expand Down Expand Up @@ -576,12 +594,15 @@ def fit(
self,
optimiser: Literal["de"] = "de",
optimiser_args: Optional[Dict[Any, Any]] = None,
update_empirical_conflict_constant: bool = True,
) -> ModelWithConfidence:
"""
Fit MACEst model using the calibration data.

:param optimiser: The optimisation method
:param optimiser_args: Any arguments for the optimisation strategy
:param update_empirical_conflict_constant: Boolean, true means the constant used as the
confidence of conflicting predictions is recomputed from the calibration data at the end of fit

:return: A ModelWithConfidence object with the parameters that minimise the loss function
"""
Expand Down Expand Up @@ -619,11 +640,12 @@ def fit(

self.model.macest_model_params = self.set_macest_model_params()

point_preds = self.model.predict(self.x_cal)
conflicts = self.model.find_conflicting_predictions(self.x_cal)
self.model.empirical_conflict_constant = np.array(
point_preds[conflicts] == self.y_cal[conflicts]
).mean()
if update_empirical_conflict_constant:
point_preds = self.model.predict(self.x_cal)
conflicts = self.model.find_conflicting_predictions(self.x_cal)
self.model.empirical_conflict_constant = np.array(
point_preds[conflicts] == self.y_cal[conflicts]
).mean()

self.model.distance_to_neighbours = None
self.model.index_of_neighbours = None
Expand Down
23 changes: 20 additions & 3 deletions src/macest/regression/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ class ModelWithPredictionInterval:

def __init__(
self,
model: _RegressionPointPredictionModel,
x_train: np.ndarray,
train_err: np.ndarray,
model: Optional[_RegressionPointPredictionModel] = None,
macest_model_params: MacestPredIntervalModelParams = MacestPredIntervalModelParams(),
error_dist: Literal["normal", "laplace"] = "normal",
dist_func: Literal["linear", "error_weighted_poly"] = "linear",
Expand All @@ -96,6 +96,8 @@ def __init__(
:param prec_ind_of_nn: The pre-computed nearest neighbour indices for the calibration and test data
:param prec_graph: The pre-computed graph to use for online hnsw search
"""
if model is None and prec_point_preds is None:
raise ValueError("One of 'model' or 'prec_point_preds' must be specified")
self.model = model
self.x_train = x_train
self.train_err = train_err
Expand Down Expand Up @@ -126,6 +128,8 @@ def predict(self, x_star: np.ndarray) -> np.ndarray:

:return: pred_star : The point prediction for x_star
"""
if self.model is None:
raise ValueError("Cannot predict as no 'model' has been initialized")
pred_star = self.model.predict(x_star)
return pred_star

Expand Down Expand Up @@ -280,16 +284,23 @@ def _distribution(self, x_star: np.ndarray) -> laplace_gen:
return dist

def predict_interval(
self, x_star: np.ndarray, conf_level: Union[np.ndarray, int, float] = 90,
self,
x_star: np.ndarray,
prec_point_preds: Optional[np.ndarray] = None,
conf_level: Union[np.ndarray, int, float] = 90,
) -> np.ndarray:
"""
Predict the upper and lower prediction interval bounds for a given confidence level.

:param x_star: The position for which we would like to predict
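:param prec_point_preds: The pre-computed point predictions for x_star; when given, they are stored as the model's point predictions
:param conf_level: The confidence level(s) of the interval, given as a percentage (default 90)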

:return: The confidence bounds for each x_star for each confidence level
"""
if prec_point_preds is not None:
self.point_preds = prec_point_preds

dist = self._distribution(x_star)
lower_perc = (100 - conf_level) / 2
upper_perc = 100 - lower_perc
Expand Down Expand Up @@ -330,6 +341,7 @@ def fit(
self,
x_cal: np.ndarray,
y_cal: np.ndarray,
prec_point_preds: Optional[np.ndarray] = None,
param_range: SearchBounds = SearchBounds(),
optimiser_args: Optional[Dict[Any, Any]] = None,
) -> None:
Expand All @@ -338,11 +350,15 @@ def fit(

:param x_cal: Calibration data
:param y_cal: Target values
:param prec_point_preds: The pre-computed point predictions for the calibration data; when given, they are stored as the model's point predictions
:param param_range: The bounds within which to search for MACEst parameters
:param optimiser_args: Any arguments for the optimiser (see scipy.optimize)

:return: None
"""
if prec_point_preds is not None:
self.point_preds = prec_point_preds

if optimiser_args is None:
optimiser_args = {}

Expand Down Expand Up @@ -418,7 +434,8 @@ def __init__(
self.prec_graph = self.model.build_graph()
self.model.prec_graph = self.prec_graph
self.prec_dist, self.prec_ind = self._prec_neighbours()
self.model.point_preds = self.model.predict(self.x_cal)
if self.model.point_preds is None:
self.model.point_preds = self.model.predict(self.x_cal)

def _prec_neighbours(self) -> Tuple[Dict[int, np.ndarray], Dict[int, np.ndarray]]:
"""
Expand Down