
HumbleNet

WORK IN PROGRESS

Intro

The goal of this project was to create an epistemically humble MLP: a model which abstains from making predictions on data points it is uncertain about. This is useful in situations where a wrong prediction is much more problematic than no prediction at all - i.e., a large proportion of real-world applications where ML is used to assist humans in sensitive activities.

More specifically, we consider a scenario where our model should maximise the following objective function:

Objective function: $f(y_{true}, y_{pred}) = w_{correct} \cdot N_{correct} + w_{abstention} \cdot N_{abstention} + w_{wrong} \cdot N_{wrong}$

where we assume $w_{correct} = 1$ and $w_{wrong} = 0$. In this setting, we refer to $w_{abstention}$ as $\gamma$, with $0 \leq \gamma \leq 1$, which expresses how much an abstention is worth relative to a correct prediction. E.g., for $\gamma = 0.5$, an abstention is worth half as much as a correct prediction. (Note: this can be restated as loss minimisation, where $w_{wrong} > w_{abstention} > w_{correct}$, and more generally for other values of $w_{correct}$ and $w_{wrong}$ - but the formulation above was easiest for our purposes.)

The intuition surrounding this objective function is the following:

  • If the model can predict everything correctly, it should always do so. This gains it a reward of 1 per correct prediction.
  • If the model is unsure about some data points, it might want to abstain on them. That way, it still receives a reward of $\gamma$ per abstention.
  • In consequence, abstention has an opportunity cost of $1 - \gamma$. This reflects the fact that an abstention is a less desirable model output than a correct prediction.
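As a concrete reference, here is a minimal sketch of how this objective could be evaluated; the `ABSTAIN` sentinel and the per-sample averaging are illustrative assumptions, not taken from the repository code.

```python
import numpy as np

ABSTAIN = -1  # hypothetical sentinel value encoding "the model abstained"

def objective(y_true, y_pred, gamma, w_correct=1.0, w_wrong=0.0):
    """Average reward per sample under the objective function above."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    abstained = y_pred == ABSTAIN
    correct = (y_pred == y_true) & ~abstained
    wrong = ~abstained & ~correct
    return (w_correct * correct.sum()
            + gamma * abstained.sum()
            + w_wrong * wrong.sum()) / len(y_true)

# e.g. objective([3, 5, 8], [3, ABSTAIN, 2], gamma=0.5) -> (1 + 0.5 + 0) / 3 = 0.5
```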

Methods

Data

Data setup: we build on the digits dataset, which - unsurprisingly - comprises images of digits and their corresponding labels. 70% of the data are left uncorrupted and can easily be classified correctly by many models; 30% have been subjected to strong noise and are virtually impossible to predict correctly. There is no label information indicating which data points are corrupted.
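For concreteness, the corruption process could look something like the sketch below; it assumes the scikit-learn digits dataset and uniform-noise corruption, which may differ from the exact procedure used in this repository.

```python
import numpy as np
from sklearn.datasets import load_digits

rng = np.random.default_rng(0)
X, y = load_digits(return_X_y=True)   # 1797 8x8 digit images, flattened to 64 features
X = X / 16.0                          # pixel values lie in [0, 16]

corrupt = rng.random(len(X)) < 0.3                      # mark ~30% of samples for corruption
X[corrupt] = rng.random((corrupt.sum(), X.shape[1]))    # replace them with pure noise

# `corrupt` is only used to build the dataset and is never shown to the model.
```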

Clean digits: (figure: clean digit samples)
Partially corrupted digits: (figure: corrupted digit samples)

Model architecture

The naive approach is to train a regular MLP to classify, and then quantify uncertainty using e.g. the entropy of the softmax probabilities or the model's confidence in the predicted class. To identify data points where abstention is recommended, one then looks for a good threshold on the uncertainty metric. However, this approach seems suboptimal. Ideally, for a given $\gamma$, the model would explicitly decide on which data points to abstain jointly with the prediction, in order to maximise the objective function (cf. also cost-aware classification). It seems plausible that this would work better than the aforementioned heuristics, as the model should be capable of learning features that indicate uncertainty, as opposed to uncertainty merely being a by-product.
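For reference, a sketch of this naive baseline, reusing the hypothetical `objective` helper and `ABSTAIN` sentinel from the sketch above: abstain whenever the maximum softmax probability falls below a threshold, and pick the threshold that scores best on held-out data.

```python
import numpy as np

def best_threshold_objective(probs, y_true, gamma, thresholds=np.linspace(0.0, 1.0, 101)):
    """probs: (N, K) softmax outputs of a regular K-class classifier."""
    conf = probs.max(axis=1)          # max-confidence uncertainty heuristic
    pred = probs.argmax(axis=1)
    scores = []
    for t in thresholds:
        y_hat = np.where(conf >= t, pred, ABSTAIN)   # abstain below the threshold
        scores.append(objective(y_true, y_hat, gamma))
    return max(scores)
```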

For this purpose, we explore a neural network where the last layer with K neurons (K: number of classes) has been augmented by one additional pseudo-class, "abstention". We treat this neuron essentially as just another class by applying the softmax over all K+1 neurons. This enforces an intuitive constraint, namely that the different classes compete for probability mass, including with abstention. (Note: this is a key difference to some other implementations of similar systems, where the abstention probability is treated as a component independent of the actual classes.)
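A minimal sketch of such an architecture, assuming a PyTorch MLP; the hidden layer size and activation here are illustrative, the only essential point is the extra K+1-th output neuron.

```python
import torch.nn as nn

class HumbleMLP(nn.Module):
    """MLP whose last layer has K+1 outputs: K classes plus an abstention pseudo-class."""

    def __init__(self, in_dim=64, hidden=128, n_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_classes + 1),  # index n_classes = "abstention" neuron
        )

    def forward(self, x):
        return self.net(x)  # raw logits; the softmax over all K+1 neurons lives in the loss
```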

TODO: add reference to other implementations

In order to connect the model training to the objective function framework, the loss is calculated as follows:
$\text{loss} = -\log(p_{\text{true class}} + \gamma \cdot p_{\text{abstention}})$
This corresponds to the usual cross-entropy loss, with the addition that abstention always also counts as a "correct" class that is worth $\gamma$ relative to a regular class.
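In code, and assuming the PyTorch sketch above with the abstention neuron as the last logit, the loss could be written roughly as:

```python
import torch
import torch.nn.functional as F

def humble_loss(logits, targets, gamma):
    """-log(p_true_class + gamma * p_abstention), averaged over the batch."""
    probs = F.softmax(logits, dim=1)                          # classes and abstention compete
    p_true = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    p_abstain = probs[:, -1]                                  # last neuron = abstention
    return -torch.log(p_true + gamma * p_abstain + 1e-12).mean()
```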

Model training

We observe that training can get stuck in full abstention if $\gamma$ is high enough. However, with a "delayed onset" of $\gamma$, where $\gamma = 0$ for the first few epochs, training recovers and the model learns to abstain on some proportion of the data.
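A sketch of this delayed-onset schedule, building on the hypothetical `HumbleMLP`/`humble_loss` sketches above; the warm-up length, optimiser, and `train_loader` are placeholders.

```python
import torch

model = HumbleMLP()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
gamma_target, warmup_epochs, n_epochs = 0.5, 5, 50

for epoch in range(n_epochs):
    gamma = 0.0 if epoch < warmup_epochs else gamma_target   # delayed onset of gamma
    for xb, yb in train_loader:                               # assumed PyTorch DataLoader
        opt.zero_grad()
        loss = humble_loss(model(xb), yb, gamma)
        loss.backward()
        opt.step()
```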

TODO: have a few figures demonstrating training behaviour
TODO: should check how path-dependent this is

Incidentally, it remains unclear why this behaviour occurs. Without the delayed $\gamma$ onset, the model quite easily learns full abstention, with coverage collapsing to 0. The model then continues to assign ever-increasing probability mass to the abstention class. However, it is not clear why the model cannot, at some point, start to progressively shift some probability mass back to the other classes. Perhaps the model's internal representation of the data has become so strongly optimised for abstention that it sits in a local minimum of the loss, where any small deviation from full abstention is punished. To anyone reading this: I would be interested to hear your thoughts on the matter and how it could be investigated further.

Results

Coverage

We have a dataset with 70% uncorrupted, easily classified images and 30% completely non-informative samples. If we consider an ideal model across a range of coverage values (coverage: the proportion of data points on which the model does not abstain), we would expect it to abstain on the corrupted samples for as long as possible. Once coverage rises above 70%, the model also has to start predicting on corrupted samples. This ideal behaviour is plotted as the dashed line and compared to the behaviour of our model for different values of $\gamma$. (figure: corruption rate among predicted samples vs. coverage)

We find that for a range of $\gamma$ values, the model successfully behaves very similarly to the ideal model (dashed line) w.r.t. avoiding corrupted images when possible.

Overall, the coverage distribution across $\gamma$ values behaves qualitatively as expected: for high values the model abstains from a large share of predictions, and for low values we get full coverage. (figure: coverage distribution across $\gamma$)
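As a reference for how these quantities could be computed, a sketch assuming the model and data sketches above (abstention neuron at index 10, boolean `corrupt` mask):

```python
import torch

with torch.no_grad():
    pred = model(torch.as_tensor(X, dtype=torch.float32)).argmax(dim=1).numpy()

predicted = pred != 10                        # 10 = index of the abstention pseudo-class
coverage = predicted.mean()                   # proportion of data without abstention
corruption_rate = corrupt[predicted].mean()   # share of corrupted samples among predictions
```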

Model performance

Core research question: is the abstention model able to maximise the objective function better than a regular model using the best possible threshold on maximum confidence? (figure: model performance across $\gamma$) Yes: for intermediate values of $\gamma$ - the range of interest, where a limited amount of abstention takes place - the abstention model clearly outperforms the regular model on the objective function.

Discussion

In conclusion, the model is able to learn to abstain from predicting on corrupted data, and does so in a way that maximises the objective function better than a regular model. Overall, the behaviour is very promising and matches expectations. In situations where there is some a priori intuition about a plausible value of $\gamma$, this framework seems like a good choice.

This should of course be tested on other datasets, in particular cases where the corruption is less absolute and ideally gradual - something we have not evaluated so far.

The collapse of training into full abstention when $\gamma$ is too high from the start is interesting and deserves further investigation.

Lastly, our model demonstrates epistemic humility by abstaining on data points with high aleatoric uncertainty (inherent noise). This is distinct from being able to deal with epistemic uncertainty (novel data distributions), for which this approach is not particularly suited.
