@@ -4,7 +4,7 @@
 from operator import index
 import os
 from pyexpat import model
-os.environ["CUDA_VISIBLE_DEVICES"] = '1'
+os.environ["CUDA_VISIBLE_DEVICES"] = '0'
 import os
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import argparse
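
Note on this hunk: the only functional change is pinning the process to physical GPU 0 instead of GPU 1 (the duplicate `import os` and the stray `operator`/`pyexpat` imports look like auto-import leftovers and could be dropped). `CUDA_VISIBLE_DEVICES` only takes effect if it is set before the first CUDA context is created, which is why the assignment precedes the heavyweight imports. A minimal sketch of that constraint:

```python
# Minimal sketch: CUDA_VISIBLE_DEVICES is read when the CUDA context is
# first created, so it must be set before any torch CUDA call.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only physical GPU 0

import torch

if torch.cuda.is_available():
    # Within this process the single visible GPU is re-indexed as cuda:0.
    print(torch.cuda.device_count())  # -> 1
```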
@@ -47,29 +47,16 @@
 from bleurt import score
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
-# scorer = BERTScorer(lang="en", rescale_with_baseline=True)
-
-
 
 bart_scorer = BARTScorer(device="cuda", checkpoint="/home/yz/liukang/liukang/huggingface/facebook/bart-large-cnn")
 bart_scorer.load(path="/home/yz/liukang/liukang/huggingface/facebook/bart-large-cnn/bart.pth")
-# bart_scorer=None
-# bleurt_scorer=score.BleurtScorer("/home/yz/liukang/liukang/huggingface/bleurt/BLEURT-20")
+
 
 bleurt_tokenizer = AutoTokenizer.from_pretrained("bleurt-large-512")
 
 bleurt_scorer = AutoModelForSequenceClassification.from_pretrained("bleurt-large-512").cuda()
 bleurt_scorer.eval()
-# bleurt_scorer=None
-# import gensim
-# from gensim.test.utils import datapath,get_tmpfile
-# from gensim.scripts.glove2word2vec import glove2word2vec
-# from gensim.models import KeyedVectors
-# wordVecPath = "/home/yz/liukang/liukang/fairseq-main_prefix/fairseq-main_prefix/checkpoints/glove/glove.6B.300d.txt"
-# glove_file = datapath(wordVecPath)
-# tmp_file = get_tmpfile('glove_word2vec.txt')
-# glove2word2vec(glove_file,tmp_file)
-# glove_model = KeyedVectors.load_word2vec_format(tmp_file)
+
 
 
 import json
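
Note: with the gensim/GloVe block and the TensorFlow BLEURT scorer removed, BLEURT is served purely through `transformers`. Assuming `bleurt-large-512` is the Hugging Face port of BLEURT (a regression model whose single logit is the score), the scorer loaded above would be used roughly like this (a sketch, not code from this commit):

```python
# Hedged sketch: score candidate sentences against references with the
# HF-ported BLEURT regression model loaded above. Higher logits = better.
import torch

references = ["He quickly ran home."]
candidates = ["He hurried home."]

with torch.no_grad():
    inputs = bleurt_tokenizer(references, candidates, return_tensors="pt",
                              padding=True, truncation=True, max_length=512)
    inputs = {k: v.cuda() for k, v in inputs.items()}
    scores = bleurt_scorer(**inputs).logits.squeeze(-1)  # shape: (batch,)
```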
@@ -350,13 +337,7 @@ def give_real_scores_ahead(tgt_dict,outputs,scores_with_suffix,scores_with_suffi
             scores_with_suffix[:, i] -= scores_with_suffix[:, first_index - 1]
         else:
             pass
-    # print(outputs)
-    # print(scores_with_suffix[:,0:5])
-    # for i in range(first_index,last_index):
-    #     pass
-    #scores_with_suffix[:,i]/=(len(suffix_tokens)+i-prefix_len+1)
-    #scores_with_suffix[:,i]/=(len(suffix_tokens)+i-prefix_len+1)
-    # print(scores_with_suffix[:,0:5])
+
     scores_with_suffix[scores_with_suffix_masks] = -math.inf
     for j in range(0, first_index):
         scores_with_suffix[:, j] = torch.tensor(-math.inf)
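
Note: the surviving logic re-bases each beam's cumulative log-probability against the shared prefix (subtracting the column at `first_index - 1`) and then floors masked and pre-prefix columns at `-inf` so they can never win a top-k. A toy illustration of that re-basing (my reading of the hunk, not repo code):

```python
# Toy example: column j holds the cumulative log-prob up to step j.
# Subtracting the prefix column leaves only the continuation's score;
# -inf removes prefix columns from any subsequent ranking.
import math
import torch

scores_with_suffix = torch.tensor([[-1.0, -2.5, -4.0],
                                   [-1.0, -3.0, -3.5]])
first_index = 1
scores_with_suffix[:, first_index:] -= scores_with_suffix[:, first_index - 1:first_index]
scores_with_suffix[:, :first_index] = -math.inf
# tensor([[-inf, -1.5000, -3.0000],
#         [-inf, -2.0000, -2.5000]])
```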
@@ -437,15 +418,7 @@ def extract_substitute(output_sentences, original_sentence, complex_word, thresh
 
     index_of_complex_word = -1
 
-    # if complex_word not in original_words:
-    #     i = 0
-    #     for word in original_words:
-    #         if complex_word == word.lower():
-    #             index_of_complex_word = i
-    #             break
-    #         i += 1
-    # else:
-    #     index_of_complex_word = original_words.index(complex_word)
+
     index_of_complex_word = word_index
     if index_of_complex_word == -1:
         print("******************no found the complex word*****************")
@@ -543,19 +516,7 @@ def extract_substitute(output_sentences, original_sentence, complex_word, thresh
         real_prev_scores.append(prev_scores[s1_count])
 
     if len(substitutes) > 0:
-        # bert_scores = substitutes_BertScore(context, complex_word, substitutes)
 
-        # #print(substitutes)
-        # bert_scores = bert_scores.tolist()
-
-        # #pdb.set_trace()
-
-
-        # filter_substitutes, bert_scores = filterSubstitute(substitutes, bert_scores, threshold)
-
-        # rank_bert = sorted(bert_scores,reverse = True)
-
-        # rank_bert_substitutes = [filter_substitutes[bert_scores.index(v)] for v in rank_bert]
         filter_substitutes = substitutes
         rank_bert_substitutes = substitutes
 
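
Note: with the BERTScore branch deleted, `filter_substitutes` and `rank_bert_substitutes` are now plain aliases of `substitutes`. For the record, the deleted ranking idiom (`filter_substitutes[bert_scores.index(v)]`) picks the first match whenever two substitutes share a score; if that ranking is ever restored, an argsort avoids the bug (sketch, names hypothetical):

```python
# Rank substitutes by a parallel score list without the duplicate-score
# pitfall of list.index().
def rank_by_scores(substitutes, scores):
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [substitutes[i] for i in order]
```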
@@ -632,14 +593,7 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
         suffix1 = suffix1.strip()
 
         suffix1 = " ".join(suffix1.split(" ")[:2])
-        # if "," in suffix1:
-        #     if suffix1.index(",")!=0:
-        #         suffix1=suffix1[:suffix1.index(",")]
-        #suffix1 = sentence[index_complex+:index_complex+1].strip()
-        # suffix1 = " ".join(ori_words[ori_words.index(complex_word)+1:ori_words.index(complex_word)+7])
-        # suffix1=process_string(suffix1)
-        # medium_qutos=[",",".","!","?","\"","``",""]
-        # for char1 in suffix1:
+
 
     else:
         pass
@@ -655,9 +609,6 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
     prefix_tokens = prefix_tokens[:-1].view(1, -1)
 
     complex_tokens = model.encode(complex_word)
-    #1.make some change to the original sentence
-    #=prefix.strip()+" "+process_string(complex_word.strip()+" "+stored_suffix1.strip())
-    #sentence=new_sentence
 
 
     sentence_tokens = model.encode(sentence)
@@ -668,11 +619,7 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
     attn_len = len(prefix_tokens[0]) + len(complex_tokens) - 1
     if len((model.tgt_dict.string(prefix_tokens).strip().replace("@@ ", "")).strip().split()) != len(prefix.strip().split()):
         print("finding prefix not good before replace mask token!!!")
-    # if len((model.tgt_dict.string(prefix_tokens).strip().replace("@@ ","")).strip().replace("<unk>",""))!=len(prefix.strip().split()):
-    #     print("finding prefix not good!!!")
-    #outputs = model.generate2(sentence_tokens, beam=20, prefix_tokens=prefix_tokens)
-    # outputs,pre_scores = model.generate2(sentence_tokens.cuda(), beam=beam, prefix_tokens=prefix_tokens.cuda(), attn_len=attn_len)
-    #outputs,pre_scores = model.generate2(sentence_tokens.cuda(), beam=beam, prefix_tokens=prefix_tokens.cuda(), attn_len=attn_len,suffix_ids=suffix_tokens)
+
     outputs, combined_sss, prev_masks, prev_masks2, scores_with_suffix, scores_with_suffix_masks, scores_with_dynamic = model.generate2(sentence_tokens.cuda(),
         beam=beam,
         prefix_tokens=prefix_tokens.cuda(),
@@ -683,47 +630,19 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
         max_aheads=5)
     outputs = outputs.cpu()
 
-    # for i in range(len(combined_sss)):
-    #     if combined_sss[i]!=[]:
-    #         if type(combined_sss[i])==list:
-    #             combined_sss[i][0]=combined_sss[i][0].to("cpu")
-    #             combined_sss[i][1]=combined_sss[i][1].to("cpu")
-    #         else:
-    #             combined_sss[i]=combined_sss[i].to("cpu")
-    # prev_masks=prev_masks.cpu()
-    # prev_masks2=prev_masks2.cpu()
+
     scores_with_suffix = scores_with_suffix.cpu()
     scores_with_suffix_masks = scores_with_suffix_masks.cpu()
 
-    # output_final_scores=give_real_scores(combined_sss,prev_masks,prev_masks2,suffix_tokens)
-    # # import pdb
-    # # pdb.set_trace()
 
-    # if combined_sss[1]!=[]:
-    #     # print("123")
-    #     outputs=outputs[torch.squeeze(torch.topk(output_final_scores,k=combined_sss[0][0].shape[1],dim=1)[1].view(1,-1),1)][0]
-    # else:
-    #     outputs=outputs[torch.squeeze(torch.topk(combined_sss[0][0],k=combined_sss[0][0].shape[1],dim=1)[1].view(1,-1),1)][0]
     embed_scores = give_embedding_scores(outputs, model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu(), complex_tokens=complex_tokens, temperature=0.2)
     #embed_scores=give_embedding_scores_v2(outputs,model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu(),complex_tokens=complex_tokens,temperature=0.2)
     assert embed_scores.size() == scores_with_suffix[:, :(outputs.size()[-1] - 1)].size()
-    # alkl make change the embedding scores
-    #embed_scores=change_embedding_scores(outputs,embed_scores,prefix_len=len(prefix_tokens[0]),max_ahead=5)
-    #scores_with_suffix[:,:(outputs.size()[-1]-1)]=scores_with_suffix[:,:(outputs.size()[-1]-1)]+embed_scores
+
 
     outputs, outputs_scores, candis = give_real_scores_ahead(model.tgt_dict, outputs, scores_with_suffix, scores_with_suffix_masks, suffix_tokens, prefix_len=len(prefix_tokens[0]), prefix_str=prefix, max_ahead=5, flag=1)
 
-    # glove_scores_static=give_embedding_scores_v4(complex_tokens[:-1],candis,model,temperature=0.2,tokens_embedding=model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu())
-    # outputs_scores=torch.tensor(outputs_scores)+glove_scores_static
-    # outputs_scores=outputs_scores.tolist()
-
-    # outputs=outputs[:20]
-    # outputs_scores=outputs_scores[:20]
-    # candis=candis[:20]
 
-    #glove_scores=get_glove_embedding(complex_word,candis,glove_model,temperature=1)
-    #glove_scores=torch.tensor(outputs_scores)-torch.tensor(outputs_scores)
-    #glove_scores=cal_bart_score(sentence,complex_word,word_index,candis)+cal_bleurt_score(sentence,complex_word,word_index,candis)
 
 
     new_outputs_scores = torch.tensor(outputs_scores)
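
Note: `give_embedding_scores` is defined elsewhere in this file; the call site only tells us it consumes the decoder embedding matrix, the complex word's tokens, and a temperature of 0.2, and must return one bonus per candidate position (see the `assert`). A hypothetical sketch of that kind of scorer (the exact formula is an assumption, not this repo's implementation):

```python
# Hypothetical embedding bonus: similarity between each candidate token's
# decoder embedding and the (pooled) complex-word embedding, sharpened by
# a temperature.
import torch
import torch.nn.functional as F

def embedding_similarity_scores(candidate_ids, complex_ids, embed_weight,
                                temperature=0.2):
    cand = embed_weight[candidate_ids]            # (n_candidates, dim)
    comp = embed_weight[complex_ids].mean(dim=0)  # pooled complex word
    sims = F.cosine_similarity(cand, comp.unsqueeze(0), dim=-1)
    return sims / temperature
```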
@@ -735,30 +654,14 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
     outputs_scores = [outputs_scores[index1].tolist() for index1 in new_indices]
     candis = [candis[index1] for index1 in new_indices]
 
-    #glove_scores=[glove_scores[index1].tolist() for index1 in new_indices]
-
-    #outputs_scores=outputs_scores.tolist()
-
-    #print(outputs)
-
-    #outputs=outputs[torch.squeeze(torch.topk(output_final_scores,k=beam,dim=1)[-1].view(1,-1),0)][:50]
-
-    #output_sentences = [model.decode(x['tokens']) for x in outputs]
     output_sentences = [model.decode(x) for x in outputs]
     if output_sentences == []:
         print("find a missing prefix sentence!!!")
         return [], [], [], []
-    # for s1 in output_sentences:
-    #     print(s1[:200])
-    # for s1 in outputs:
-    #     print(model.tgt_dict.string(s1)[:150])
-    #bertscore_substitutes, ranking_bertscore_substitutes = extract_substitute(output_sentences, sentence, complex_word, threshold)
+
     bertscore_substitutes, ranking_bertscore_substitutes, real_prev_scores = extract_substitute(output_sentences, sentence, complex_word,
         threshold, outputs_scores, word_index, sentence_words, target_pos, target_lemma)
-    #print(pre_scores)
 
-    #for sen in output_sentences:
-    #    print(sen)
 
     bertscore_substitutes = bertscore_substitutes[:50]
     ranking_bertscore_substitutes = ranking_bertscore_substitutes[:50]
@@ -767,20 +670,10 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
 
 
     #glove_scores_static=give_embedding_scores_v4(complex_tokens[:-1],bertscore_substitutes,model,temperature=0.2,tokens_embedding=model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu())
-    #glove_scores=cal_bart_score(sentence,complex_word,word_index,bertscore_substitutes)+cal_bleurt_score(sentence,complex_word,word_index,bertscore_substitutes)
-    #glove_scores=cal_bart_score(sentence,complex_word,word_index,bertscore_substitutes)+cal_bert_score(sentence,complex_word,word_index,bertscore_substitutes)
-    #glove_scores=cal_bleurt_score(sentence,complex_word,word_index,bertscore_substitutes)
-    #real_prev_scores=0.03*torch.tensor(real_prev_scores)+glove_scores
-
-
-    #real_prev_scores=real_prev_scores.tolist()
+    glove_scores = cal_bart_score(sentence, complex_word, word_index, bertscore_substitutes) + cal_bleurt_score(sentence, complex_word, word_index, bertscore_substitutes)
 
-    #bertscore_substitutes, ranking_bertscore_substitutes = extractSubstitute_bertscore(output_sentences, sentence, complex_word, threshold)
-    #suffix_substitutes = extractSubstitute_suffix(output_sentences, sentence, complex_word)
+    return bertscore_substitutes, ranking_bertscore_substitutes, real_prev_scores, glove_scores.tolist()
 
-    #return bertscore_substitutes, ranking_bertscore_substitutes,real_prev_scores,glove_scores.tolist()
-    #return bertscore_substitutes, ranking_bertscore_substitutes,real_prev_scores,glove_scores_static.tolist()
-    return bertscore_substitutes, ranking_bertscore_substitutes, real_prev_scores, 1
 
 def pos_filter(pos_vocab, target_pos, candi, candi_lemma):
     PosMap = {"v": "VERB", "n": "NOUN", "a": "ADJ", "r": "ADV"}
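
Note: this is the substantive change of the commit: the function now returns a live `glove_scores` signal (BARTScore + BLEURT, despite the variable name) instead of the placeholder `1`. `cal_bart_score` is defined elsewhere in the file; a plausible sketch of its shape, assuming BARTScore's published `score(srcs, tgts, batch_size)` API (illustrative only, not this repo's code):

```python
# Hedged sketch: substitute each candidate into the target slot and score
# the rewritten sentence against the original with BARTScore
# (log-likelihood of tgt given src).
import torch

def cal_bart_score_sketch(sentence, complex_word, word_index, candidates):
    words = sentence.split()
    rewritten = []
    for cand in candidates:
        tmp = list(words)
        tmp[word_index] = cand
        rewritten.append(" ".join(tmp))
    scores = bart_scorer.score([sentence] * len(rewritten), rewritten,
                               batch_size=4)
    return torch.tensor(scores)
```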
@@ -917,14 +810,7 @@ def main():
     en2en = TransformerModel.from_pretrained(args.paraphraser_path, checkpoint_file=args.paraphraser_model, bpe=args.bpe,
                                              bpe_codes=args.bpe_codes).cuda().eval()
 
-    #CS = []
 
-    #CS2 = []
-
-    #CS3 = []
-
-    #output_sr_file.write("beam:", args.beam, " bertscore:", args.bertscore)
-    #output_sr_file.write('\n')
     bert_substitutes_all = []
     real_prev_scores_all = []
     real_embed_scores_all = []
@@ -936,10 +822,7 @@ def main():
         # continue
         for instance in reader.words_candidate[main_word]:
             for context in reader.words_candidate[main_word][instance]:
-                # import pdb
-                # pdb.set_trace()
-                # if main_word!="forget.V":
-                #     continue
+
 
                 text = context[1]
                 original_text = text
@@ -961,24 +844,14 @@ def main():
                 real_prev_scores_all.append(real_prev_scores)
                 real_embed_scores_all.append(real_embed_scores)
 
-                # write_all_results(main_word, instance, target_pos, args.output_SR_file,
-                #     bert_substitutes, real_prev_scores, evaluation_metric)
-
-                #CS2.append(bert_substitutes[:10])
 
-                #CS3.append(bert_rank_substitutes[:10])
-                #final_str=" ".join(complex_labels[i])+"|||"+" ".join(bert_rank_substitutes[:10])+"|||"+" ".join(list(set(complex_labels[i])&set(bert_rank_substitutes[:10])))+"\n"
-                #final_str="&".join(complex_labels[i])+"|||"+" ".join(bert_substitutes[:10])+"|||"+" ".join(list(set(complex_labels[i])&set(bert_substitutes[:10])))+"\n"
 
     import numpy as np
     import copy
     import os
 
     #range1=np.arange(1,2,1)
-    range1 = np.arange(0.2, 0.4, 0.2)
-    #range1=np.arange(0.005,0.1,0.005)
-    #range2_log_softmax=np.arange(0.2,0.4,0.2)
-    #range2_log_softmax=np.arange(0.,0.4,0.2)
+    range1 = np.arange(0.02, 0.04, 0.02)
     range2_log_softmax = np.arange(1, 2, 1)
 
     for log_quto in range2_log_softmax:
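
Note: both the old and the new sweep collapse to a single value, because `np.arange` excludes its end point; the grid is effectively `embed_quto = 0.02` now (previously `0.2`):

```python
import numpy as np

assert np.allclose(np.arange(0.2, 0.4, 0.2), [0.2])      # old grid
assert np.allclose(np.arange(0.02, 0.04, 0.02), [0.02])  # new grid
```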
@@ -1016,57 +889,42 @@ def main():
                     target_pos=target_pos
                 ).lower().strip()
 
-                # target_lemma = lemma_word_spacy(
-                #     target_word,
-                #     target_pos=target_pos
-                # ).lower().strip()
-
-                #print("ori_score",real_prev_scores_all[count_1][:10])
                 tmp_log_embed_scores = torch.tensor(tmp_real_embed_scores_all[count_1])
-                #tmp_log_embed_scores=torch.tensor(tmp_real_embed_scores_all[count_1])/1
-                #tmp_log_embed_scores=F.log_softmax(tmp_log_embed_scores,dim=0)
-                tmp_log_embed_scores = tmp_log_embed_scores.tolist()
 
+                tmp_log_embed_scores = tmp_log_embed_scores.tolist()
 
-                # for k1 in range(len(tmp_real_prev_scores_all[count_1])):
-                #     tmp_real_prev_scores_all[count_1][k1]=tmp_real_prev_scores_all[count_1][k1]-tmp_real_embed_scores_all[count_1][k1]
-                #     tmp_real_prev_scores_all[count_1][k1]+=embed_quto*tmp_real_embed_scores_all[count_1][k1]
-
-                # for k1 in range(len(tmp_real_prev_scores_all[count_1])):
-                #     #tmp_real_prev_scores_all[count_1][k1]=tmp_real_prev_scores_all[count_1][k1]-tmp_real_embed_scores_all[count_1][k1]
-                #     #tmp_real_prev_scores_all[count_1][k1]+=embed_quto*tmp_log_embed_scores[k1]
-                #     tmp_real_prev_scores_all[count_1][k1]=embed_quto*tmp_real_prev_scores_all[count_1][k1]+tmp_log_embed_scores[k1]
-                #     #tmp_real_prev_scores_all[count_1][k1]=tmp_log_embed_scores[k1]
-                #     pass
+                for k1 in range(len(tmp_real_prev_scores_all[count_1])):
+                    tmp_real_prev_scores_all[count_1][k1] = embed_quto * tmp_real_prev_scores_all[count_1][k1] + tmp_log_embed_scores[k1]
+                    pass
 
                 write_all_results(main_word, instance, target_pos, work_dir + args.output_SR_file + ".embed." + str(embed_quto),
                                   tmp_bert_substitutes_all[count_1], tmp_real_prev_scores_all[count_1], evaluation_metric)
 
                 #print("after_score",real_prev_scores_all[count_1][:10])
 
                 count_1 += 1
-            # print("*"*100)
-            # test_golden_file="data/LS14/test/coinco_test.gold"
-            # output_results=work_dir+args.output_SR_file+".embed."+str(embed_quto)
-            # results_file=work_dir+args.output_score_file+".embed."+str(embed_quto)
-            # evaluation_metric.calculation_perl(
-            #     test_golden_file,
-            #     output_results + ".best",
-            #     output_results + ".oot",
-            #     results_file + ".best",
-            #     results_file + ".oot"
-            # )
-            # evaluation_metric.calculation_p1(
-            #     test_golden_file,
-            #     output_results + "_p1.txt",
-            #     results_file + "_p1.txt"
-            # )
+            print("*" * 100)
+            test_golden_file = "data/LS14/test/coinco_test.gold"
+            output_results = work_dir + args.output_SR_file + ".embed." + str(embed_quto)
+            results_file = work_dir + args.output_score_file + ".embed." + str(embed_quto)
+            evaluation_metric.calculation_perl(
+                test_golden_file,
+                output_results + ".best",
+                output_results + ".oot",
+                results_file + ".best",
+                results_file + ".oot"
+            )
+            evaluation_metric.calculation_p1(
+                test_golden_file,
+                output_results + "_p1.txt",
+                results_file + "_p1.txt"
+            )
 
-            # evaluation_metric.calculation_p3(
-            #     test_golden_file,
-            #     output_results + "_p3.txt",
-            #     results_file + "_p3.txt"
-            # )
+            evaluation_metric.calculation_p3(
+                test_golden_file,
+                output_results + "_p3.txt",
+                results_file + "_p3.txt"
+            )
 
 if __name__ == "__main__":
     main()
 
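
Note: the re-enabled loop rescores every candidate as `embed_quto * beam_score + bart_bleurt_score` (the trailing `pass` is a leftover no-op), and the un-commented block then runs the best/oot Perl evaluation plus P@1 and P@3 once per `embed_quto`. A vectorized equivalent of the rescoring loop (a sketch with the same semantics, not repo code):

```python
# Same combination as the k1 loop above, in one tensor op.
import torch

def combine_scores(prev_scores, embed_scores, embed_quto=0.02):
    prev = torch.tensor(prev_scores, dtype=torch.float)
    embed = torch.tensor(embed_scores, dtype=torch.float)
    return (embed_quto * prev + embed).tolist()
```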