-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathemap.py
57 lines (43 loc) · 1.7 KB
/
emap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
'''
Example implementation of EMAP
'''
import numpy as np
import collections
def emap(idx2logits):
    r'''Example implementation of EMAP (more efficient ones exist).

    inputs:
      idx2logits: This nested dictionary maps from image/text indices
        to function evals, i.e., idx2logits[i][j] = f(t_i, v_j).
        Both levels are looked up by the integer positions
        0, 1, ..., len - 1.
    returns:
      projected_preds: a numpy array where projected_preds[i]
        corresponds to \hat f(t_i, v_i), computed as
          mean_j f(t_i, v_j) + mean_k f(t_k, v_i) - mean_{k,j} f(t_k, v_j)
    raises:
      ValueError: if idx2logits contains no logits at all.
    '''
    # Pool every evaluation to get the grand mean over all (i, j) pairs.
    all_logits = []
    for row in idx2logits.values():
        all_logits.extend(row.values())
    if not all_logits:
        # np.vstack would raise a ValueError here anyway; keep the exception
        # type but say what is actually wrong.
        raise ValueError('idx2logits must contain at least one logit')
    logits_mean = np.mean(np.vstack(all_logits), axis=0)

    # Transposed view of the table: reversed_idx2logits[j][i] = f(t_i, v_j).
    n_rows = len(idx2logits)
    reversed_idx2logits = collections.defaultdict(dict)
    for i in range(n_rows):
        for j in range(len(idx2logits[i])):
            reversed_idx2logits[j][i] = idx2logits[i][j]

    projected_preds = []
    for idx in range(n_rows):
        # Mean over the second index with the first index fixed at idx ...
        pred = np.mean(np.vstack(list(idx2logits[idx].values())), axis=0)
        # ... plus the mean over the first index with the second fixed at idx ...
        pred += np.mean(
            np.vstack(list(reversed_idx2logits[idx].values())), axis=0)
        # ... minus the grand mean.
        pred -= logits_mean
        projected_preds.append(pred)
    return np.vstack(projected_preds)
def test_from_paper():
    '''Runs emap on the worked example from the appendix and prints the result.'''
    # worked_example[i][j] is the function eval stored at idx2logits[i][j].
    worked_example = (
        (-1.3, .3, -.2),
        (.8, 3.0, 1.1),
        (1.1, -.1, .7),
    )
    idx2logits = collections.defaultdict(dict)
    for i, row in enumerate(worked_example):
        for j, logit in enumerate(row):
            idx2logits[i][j] = logit

    print('original model predictions:')
    diagonal = [idx2logits[idx][idx] for idx in range(3)]
    print(diagonal)

    print('emap:')
    projected = emap(idx2logits)
    print(projected)
# Script entry point: print the worked example when run directly.
if __name__ == '__main__':
    test_from_paper()