forked from alexysxeightn/MADE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
141 lines (92 loc) · 4.08 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import spacy
import torch
from torchtext.data.metrics import bleu_score
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
def flatten(l):
    """Concatenate a list of iterables into one flat list."""
    flat = []
    for chunk in l:
        flat.extend(chunk)
    return flat
def remove_tech_tokens(mystr, tokens_to_remove=('<eos>', '<sos>', '<unk>', '<pad>')):
    """Strip service/special tokens from a token sequence.

    Args:
        mystr: iterable of token strings.
        tokens_to_remove: collection of tokens to drop; defaults to the
            standard special tokens.

    Returns:
        list: the tokens of ``mystr`` with every banned token removed,
        in original order.
    """
    # Tuple default instead of the original mutable list default (a classic
    # Python pitfall); a set makes each membership test O(1) instead of O(m).
    banned = set(tokens_to_remove)
    return [tok for tok in mystr if tok not in banned]
def get_text(x, TRG_vocab):
    """Decode a sequence of token ids into a list of word strings.

    Looks each id up in ``TRG_vocab.itos``, truncates everything from the
    first '<eos>' onward (if one was generated), then strips the remaining
    technical tokens.

    Args:
        x: iterable of integer token ids.
        TRG_vocab: torchtext vocab with an ``itos`` id-to-string mapping.

    Returns:
        list: decoded word tokens (possibly empty).
    """
    text = [TRG_vocab.itos[token] for token in x]
    # EAFP: truncate at the first <eos>; a missing <eos> just means the
    # decoder hit its length limit, which is fine.
    try:
        text = text[:text.index('<eos>')]
    except ValueError:
        pass
    # NOTE: the original ended with `if len(text) < 1: text = []`, which
    # reassigned an empty list to an empty list — a no-op, removed here.
    return remove_tech_tokens(text)
def generate_translation(src, trg, model, TRG_vocab):
    """Print the reference sentence and the model's greedy translation.

    Runs the model with teacher forcing disabled, takes the argmax token at
    every position, and decodes both target and prediction through the
    target vocabulary.
    """
    model.eval()
    # Teacher-forcing ratio 0: the decoder feeds on its own predictions.
    logits = model(src, trg, 0)
    pred_ids = logits.argmax(dim=-1).cpu().numpy()
    reference = get_text(list(trg[:, 0].cpu().numpy()), TRG_vocab)
    hypothesis = get_text(list(pred_ids[1:, 0]), TRG_vocab)
    print('Original: {}'.format(' '.join(reference)))
    print('Generated: {}'.format(' '.join(hypothesis)))
    print()
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50, is_cnn=False):
    """Greedily translate one source sentence with a trained seq2seq model.

    Args:
        sentence: raw German string (tokenized here with spacy) or an
            already-tokenized list of strings.
        src_field / trg_field: torchtext fields providing vocab and the
            init/eos special tokens.
        model: seq2seq model exposing ``encoder``/``decoder`` and, for the
            non-CNN path, ``make_src_mask``/``make_trg_mask``.
        device: torch device the input tensors are moved to.
        max_len: maximum number of decoding steps.
        is_cnn: selects the convolutional encoder/decoder call signature
            instead of the masked (transformer-style) one.

    Returns:
        tuple: (list of predicted target tokens without the leading
        init token, attention tensor from the final decoder call).
    """
    model.eval()
    # Accept either a raw string or a pre-tokenized sequence.
    if isinstance(sentence, str):
        nlp = spacy.load('de_core_news_sm')
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]
    # Wrap with <sos>/<eos> and map tokens to vocabulary indices.
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    # CNN models take no source mask; only build one for the masked path.
    if not is_cnn:
        src_mask = model.make_src_mask(src_tensor)
    # Encode once; decoding below reuses the encoder output every step.
    with torch.no_grad():
        if is_cnn:
            encoder_conved, encoder_combined = model.encoder(src_tensor)
        else:
            enc_src = model.encoder(src_tensor, src_mask)
    # Start decoding from the target-side init token.
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    for i in range(max_len):
        # Re-feed the whole prefix each step (no incremental decoding).
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        if is_cnn:
            with torch.no_grad():
                output, attention = model.decoder(trg_tensor, encoder_conved, encoder_combined)
        else:
            trg_mask = model.make_trg_mask(trg_tensor)
            with torch.no_grad():
                output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        # Greedy choice: argmax over the vocabulary at the last position.
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)
        # Stop as soon as the model emits <eos> (the <eos> id stays in the
        # output list, so callers may see it as the final token).
        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    # Drop the leading init token; keep attention from the last decoder call.
    return trg_tokens[1:], attention
def display_attention(sentence, translation, attention, n_heads = 8, n_rows = 4, n_cols = 2, is_cnn=False):
    """Plot attention weights as one heatmap subplot per attention head.

    Args:
        sentence: source tokens (x axis, wrapped with <sos>/<eos> labels).
        translation: generated target tokens (y axis).
        attention: attention tensor; indexed per head on the non-CNN path.
        n_heads / n_rows / n_cols: subplot grid layout; must satisfy
            n_rows * n_cols == n_heads.
        is_cnn: CNN attention has no per-head dimension, so the same matrix
            is shown in every subplot.
    """
    assert n_rows * n_cols == n_heads
    fig = plt.figure(figsize=(15,25))
    for i in range(n_heads):
        ax = fig.add_subplot(n_rows, n_cols, i+1)
        if is_cnn:
            # NOTE(review): this repeats the identical matrix in all n_heads
            # subplots for CNN models — possibly intentional, worth confirming.
            _attention = attention.squeeze(0).cpu().detach().numpy()
        else:
            # Select head i from the (heads, trg_len, src_len) tensor.
            _attention = attention.squeeze(0)[i].cpu().detach().numpy()
        cax = ax.matshow(_attention, cmap='bone')
        ax.tick_params(labelsize=12)
        # Leading '' aligns labels with matshow's tick at position -1/0.
        ax.set_xticklabels(['']+['<sos>']+[t.lower() for t in sentence]+['<eos>'],
                           rotation=45)
        ax.set_yticklabels(['']+translation)
        # One tick per token so every label lands on its own row/column.
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    plt.show()
    plt.close()
def calculate_bleu(data, src_field, trg_field, model, device, max_len = 50, is_cnn=False):
    """Compute corpus BLEU for the model over a torchtext dataset.

    Translates every example with ``translate_sentence`` and scores the
    predictions against the single reference per example.

    Args:
        data: iterable of torchtext examples with 'src' and 'trg' attributes.
        src_field / trg_field: torchtext fields (vocab + special tokens).
        model: trained seq2seq model.
        device: torch device for inference.
        max_len: decoding length limit passed through to translation.
        is_cnn: forwarded to ``translate_sentence``.

    Returns:
        float: corpus-level BLEU score.
    """
    trgs = []
    pred_trgs = []
    for datum in data:
        src = vars(datum)['src']
        trg = vars(datum)['trg']
        pred_trg, _ = translate_sentence(src, src_field, trg_field, model, device, max_len, is_cnn)
        # Drop the trailing <eos> only when it is actually there. The
        # original unconditionally cut the last token, which deleted a real
        # word whenever decoding stopped at max_len without emitting <eos>.
        if pred_trg and pred_trg[-1] == trg_field.eos_token:
            pred_trg = pred_trg[:-1]
        pred_trgs.append(pred_trg)
        # bleu_score expects a list of references per candidate.
        trgs.append([trg])
    return bleu_score(pred_trgs, trgs)