Skip to content

Commit f63c67e

Browse files
committed Aug 2, 2020
add baseline results
1 parent d7d311d commit f63c67e

18 files changed

+95
-31
lines changed
 

‎config/mlp/sr/fewrel_nc10.json

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"exp_id": "mlp_fewrel_sr_nc10",
3+
"train_path": "data/fewrel/train.txt",
4+
"test_path": "data/fewrel/test.txt",
5+
"num_output_classes": 64,
6+
"train_nc": 10,
7+
"num_epochs": 300,
8+
"minibatch_size": 128,
9+
"n_aug": 4,
10+
"aug_type": "sr"
11+
}

‎config/mlp/sr/huff_nc10.json

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"exp_id": "mlp_huff_sr_nc10",
3+
"train_path": "data/huff/train.txt",
4+
"test_path": "data/huff/test.txt",
5+
"num_output_classes": 41,
6+
"train_nc": 10,
7+
"num_epochs": 300,
8+
"minibatch_size": 128,
9+
"n_aug": 4,
10+
"aug_type": "sr"
11+
}

‎config/mlp/vanilla/covidcat_nc20.json

-9
This file was deleted.

‎config/mlp/vanilla/fewrel_nc10.json

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"exp_id": "mlp_fewrel_vanilla_nc10",
3+
"train_path": "data/fewrel/train.txt",
4+
"test_path": "data/fewrel/test.txt",
5+
"num_output_classes": 64,
6+
"train_nc": 10,
7+
"num_epochs": 300,
8+
"minibatch_size": 128
9+
}

‎config/mlp/vanilla/huff_nc10.json

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"exp_id": "mlp_huff_vanilla_nc10",
3+
"train_path": "data/huff/train.txt",
4+
"test_path": "data/huff/test.txt",
5+
"num_output_classes": 41,
6+
"train_nc": 10,
7+
"num_epochs": 300,
8+
"minibatch_size": 128
9+
}

‎config/mlp/vanilla/sst2_nc100.json

-9
This file was deleted.

‎config/triplet_ap/sr/fewrel_nc10.json

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"exp_id": "triplet_fewrel_sr_nc10",
3+
"train_path": "data/fewrel/train.txt",
4+
"test_path": "data/fewrel/test.txt",
5+
"num_output_classes": 64,
6+
"train_nc": 10,
7+
"total_updates": 15000,
8+
"n_aug": 4,
9+
"aug_type": "sr"
10+
}

‎config/triplet_ap/sr/huff_nc10.json

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"exp_id": "triplet_huff_sr_nc10",
3+
"train_path": "data/huff/train.txt",
4+
"test_path": "data/huff/test.txt",
5+
"num_output_classes": 41,
6+
"train_nc": 10,
7+
"total_updates": 15000,
8+
"n_aug": 4,
9+
"aug_type": "sr"
10+
}

config/triplet_ap/vanilla/fewrel_nc10.json

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"exp_id": "triplet_fewrel_vanilla_nc10",
3+
"train_path": "data/fewrel/train.txt",
4+
"test_path": "data/fewrel/test.txt",
5+
"num_output_classes": 64,
6+
"train_nc": 10,
7+
"total_updates": 15000
8+
}

config/triplet_ap/vanilla/huff_nc10.json

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"exp_id": "triplet_huff_vanilla_nc10",
3+
"train_path": "data/huff/train.txt",
4+
"test_path": "data/huff/test.txt",
5+
"num_output_classes": 41,
6+
"train_nc": 10,
7+
"total_updates": 15000
8+
}

‎config/triplet_ap/vanilla/imdb_nc10.json

-8
This file was deleted.

‎knn_ap_sr.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
if __name__ == "__main__":
44

55
cfg_json_list = [ #uses same configs as triplet_ap
6+
"config/triplet_ap/sr/fewrel_nc10.json",
7+
"config/triplet_ap/sr/huff_nc10.json",
68
"config/triplet_ap/sr/covidclu_nc3.json",
79
"config/triplet_ap/sr/covidcat_nc10.json",
810
"config/triplet_ap/sr/sst2_nc10.json",

‎knn_ap_vanilla.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
if __name__ == "__main__":
44

55
cfg_json_list = [ #uses same configs as triplet_ap
6+
"config/triplet_ap/vanilla/fewrel_nc10.json",
7+
"config/triplet_ap/vanilla/huff_nc10.json",
68
"config/triplet_ap/vanilla/covidclu_nc3.json",
79
"config/triplet_ap/vanilla/covidcat_nc10.json",
810
"config/triplet_ap/vanilla/sst2_nc10.json",

‎mlp_ap_sr.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
if __name__ == "__main__":
44

55
cfg_json_list = [
6-
"config/lr/sr/fewrel_nc10.json",
6+
# "config/lr/sr/fewrel_nc10.json",
77
# "config/lr/sr/huff_nc10.json",
88
# "config/lr/sr/covidclu_nc3.json",
99
# "config/lr/sr/covidcat_nc10.json",
1010
# "config/lr/sr/trec_nc10.json",
1111
# "config/lr/sr/sst2_nc10.json",
12+
"config/mlp/sr/fewrel_nc10.json",
13+
"config/mlp/sr/huff_nc10.json",
1214
# "config/mlp/sr/covidclu_nc3.json",
1315
# "config/mlp/sr/covidcat_nc10.json",
1416
# "config/mlp/sr/trec_nc10.json",

‎mlp_ap_vanilla.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
if __name__ == "__main__":
44

55
cfg_json_list = [
6-
"config/lr/vanilla/fewrel_nc10.json",
6+
# "config/lr/vanilla/fewrel_nc10.json",
77
# "config/lr/vanilla/huff_nc10.json",
88
# "config/lr/vanilla/covidclu_nc3.json",
99
# "config/lr/vanilla/covidcat_nc10.json",
1010
# "config/lr/vanilla/trec_nc10.json",
1111
# "config/lr/vanilla/sst2_nc10.json",
12+
"config/mlp/vanilla/fewrel_nc10.json",
13+
"config/mlp/vanilla/huff_nc10.json",
1214
# "config/mlp/vanilla/covidclu_nc3.json",
1315
# "config/mlp/vanilla/covidcat_nc10.json",
1416
# "config/mlp/vanilla/trec_nc10.json",

‎triplet_ap_sr.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
if __name__ == "__main__":
44

55
cfg_json_list = [
6+
"config/triplet_ap/sr/fewrel_nc10.json",
7+
"config/triplet_ap/sr/huff_nc10.json",
68
"config/triplet_ap/sr/covidcat_nc10.json",
79
"config/triplet_ap/sr/covidclu_nc3.json",
810
"config/triplet_ap/sr/trec_nc10.json",

‎triplet_ap_vanilla.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
if __name__ == "__main__":
44

55
cfg_json_list = [
6+
"config/triplet_ap/vanilla/fewrel_nc10.json",
7+
"config/triplet_ap/vanilla/huff_nc10.json",
68
"config/triplet_ap/vanilla/covidclu_nc3.json",
79
"config/triplet_ap/vanilla/covidcat_nc10.json",
810
"config/triplet_ap/vanilla/sst2_nc10.json",

‎utils/knn_methods.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from tqdm import tqdm
55
from scipy.spatial import distance
66
from pathlib import Path
7+
import itertools
78

89
def eval_model(
910
train_sentence_to_label,
@@ -14,14 +15,15 @@ def eval_model(
1415
):
1516

1617
def get_closest_train_sentence(test_sentence_encoding, train_sentence_to_encoding):
17-
train_sentence_to_dist_list = [ (train_sentence, distance.cosine(test_sentence_encoding, train_sentence_encoding)) for train_sentence, train_sentence_encoding in train_sentence_to_encoding.items()]
18+
train_sentences = list(itertools.chain.from_iterable(train_label_to_sentences.values()))
19+
train_sentence_to_dist_list = [ (train_sentence, distance.cosine(test_sentence_encoding, train_sentence_to_encoding[train_sentence])) for train_sentence in train_sentences]
1820
sorted_train_sentence_dist_list = list(sorted(train_sentence_to_dist_list, key=lambda tup: tup[1]))
19-
return sorted_train_sentence_dist_list[0][0]
21+
return sorted_train_sentence_dist_list[0][0], sorted_train_sentence_dist_list[0][1]
2022

2123
num_correct = 0 #probably should be refactored
2224
for test_sentence, label in tqdm(test_sentence_to_label.items()):
2325
test_sentence_encoding = test_sentence_to_encoding[test_sentence]
24-
closest_train_sentence = get_closest_train_sentence(test_sentence_encoding, train_sentence_to_encoding)
26+
closest_train_sentence, closest_dist = get_closest_train_sentence(test_sentence_encoding, train_sentence_to_encoding)
2527
predicted_label = train_sentence_to_label[closest_train_sentence]
2628
if predicted_label == label:
2729
num_correct += 1

0 commit comments

Comments
 (0)
Please sign in to comment.