-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathconsensus.py
96 lines (73 loc) · 3.83 KB
/
consensus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
# Prefer the ``abc_interpreter_m`` variant of the base class when that module
# is importable; otherwise fall back to the standard ``abc_interpreter``.
# Only ImportError triggers the fallback, so unrelated failures (and
# KeyboardInterrupt/SystemExit) are no longer silently swallowed.
try:
    from .abc_interpreter_m import Interpreter
except ImportError:
    from .abc_interpreter import Interpreter
class ConsensusInterpreter(object):
    """
    ConsensusInterpreter applies a given Interpreter to every model in a list of models and gathers the
    resulting explanations. Aggregating them (e.g. averaging, which is left to the user) gives a result that is
    more like an explanation for the data, instead of specific models. For visual object recognition tasks, the
    Consensus explanation would be more aligned with the object than individual models.

    More details regarding the Consensus method can be found in the original paper:
    https://arxiv.org/abs/2109.00707.

    For reference, the ``list_of_models`` can be found from :py:mod:`paddle.vision.models` or
    `PPClas <https://github.com/PaddlePaddle/PaddleClas/blob/release/2.3/ppcls/arch/backbone/__init__.py>`_.
    """

    def __init__(self, InterpreterClass, list_of_models: list, device: str = 'gpu:0', **kwargs):
        """

        Args:
            InterpreterClass ([type]): The given Interpreter defined in InterpretDL. Must be a subclass of
                ``Interpreter``.
            list_of_models (list): a list (or any iterable) of trained models. It is copied into a list at
                construction time.
            device (str): The device used for running ``model``, options: ``"cpu"``, ``"gpu:0"``, ``"gpu:1"``
                etc.
            **kwargs: Extra keyword arguments forwarded to ``InterpreterClass`` each time an interpreter is
                instantiated for a model.
        """
        # Kept as an assertion so that the exception type seen by existing callers does not change.
        assert issubclass(InterpreterClass, Interpreter), \
            "InterpreterClass must be a subclass of Interpreter."

        self.InterpreterClass = InterpreterClass
        # Materialize the collection: a one-shot iterable (e.g. a generator) would otherwise be exhausted
        # after the first call to ``interpret`` and every later call would see no models.
        self.list_of_models = list(list_of_models)
        self.device = device
        self.other_args = kwargs

    def interpret(self, inputs: 'str | list[str] | np.ndarray', **kwargs) -> np.ndarray:
        """
        The technical details are simple to understand for the Consensus method:
        Given the ``inputs`` and the interpretation algorithm (one of the Interpreters), each model in
        ``list_of_models`` will produce an explanation, then Consensus will concatenate all the explanations.
        Subsequent normalization and average can be done as users' preference. The suggested operation for input
        gradient based algorithms is average of the absolute values.

        We leave the visualization to users.
        See the `notebook example
        <https://github.com/PaddlePaddle/InterpretDL/tree/master/tutorials/example_consensus_cv.ipynb>`_
        for an example.

        .. code-block:: python

            import interpretdl as it
            from paddle.vision.models import resnet34, resnet50, resnet101, mobilenet_v2

            list_models = {
                'resnet34': resnet34(pretrained=True),
                'resnet50': resnet50(pretrained=True),
                'resnet101': resnet101(pretrained=True),
                'mobilenet_v2': mobilenet_v2(pretrained=True)
            }
            consensus = ConsensusInterpreter(it.SmoothGradInterpreter, list_models.values(), device='gpu:0')
            exp = consensus.interpret('path/to/image.jpg')

            import matplotlib.pyplot as plt
            import numpy as np

            cols = len(list_models) + 1
            psize = 4
            fig, ax = plt.subplots(1, cols, figsize=(cols*psize, 1*psize))
            for axis in ax:
                axis.axis('off')

            for i in range(len(list_models)):
                ax[i].imshow(np.abs(exp[i]).sum(0))
                ax[i].set_title(list(list_models.keys())[i])

            ax[-1].imshow(np.abs(exp).sum(1).mean(0))
            ax[-1].set_title('Consensus')

        Args:
            inputs (str or list of strs or numpy.ndarray): The input image filepath or a list of filepaths or numpy
                array of read images.
            **kwargs: Extra keyword arguments forwarded to the ``interpret`` method of each interpreter.
                ``visual`` and ``save_path`` are already set by this method (to ``False`` and ``None``).

        Returns:
            np.ndarray: Raw explanations of all models, concatenated along the first axis in the order of
            ``list_of_models``.

        Raises:
            ValueError: If ``list_of_models`` is empty, so there is nothing to concatenate.
        """
        exps = []
        for model in self.list_of_models:
            # A fresh interpreter is built per model; visualization and saving are disabled because only
            # the raw explanation is collected here.
            interpreter = self.InterpreterClass(model, self.device, **self.other_args)
            raw_explanation = interpreter.interpret(inputs, visual=False, save_path=None, **kwargs)
            exps.append(raw_explanation)

        if not exps:
            # np.concatenate would raise a ValueError as well, but with a message that does not point
            # at the actual cause.
            raise ValueError("list_of_models is empty: at least one model is required to compute explanations.")

        return np.concatenate(exps)