Skip to content

Commit 2466e0c

Browse files
authored May 4, 2023
Update LSPara.multi.ls14.py
1 parent 91379ec commit 2466e0c

File tree

1 file changed

+39
-181
lines changed

1 file changed

+39
-181
lines changed
 

‎LSPara.multi.ls14.py

+39 -181
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from operator import index
55
import os
66
from pyexpat import model
7-
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
7+
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
88
import os
99
os.environ["TOKENIZERS_PARALLELISM"] = "false"
1010
import argparse
@@ -47,29 +47,16 @@
4747
from bleurt import score
4848
from transformers import AutoModelForSequenceClassification, AutoTokenizer
4949

50-
# scorer = BERTScorer(lang="en", rescale_with_baseline=True)
51-
52-
5350

5451
bart_scorer=BARTScorer(device="cuda",checkpoint="/home/yz/liukang/liukang/huggingface/facebook/bart-large-cnn")
5552
bart_scorer.load(path="/home/yz/liukang/liukang/huggingface/facebook/bart-large-cnn/bart.pth")
56-
# bart_scorer=None
57-
# bleurt_scorer=score.BleurtScorer("/home/yz/liukang/liukang/huggingface/bleurt/BLEURT-20")
53+
5854

5955
bleurt_tokenizer = AutoTokenizer.from_pretrained("bleurt-large-512")
6056

6157
bleurt_scorer = AutoModelForSequenceClassification.from_pretrained("bleurt-large-512").cuda()
6258
bleurt_scorer.eval()
63-
# bleurt_scorer=None
64-
# import gensim
65-
# from gensim.test.utils import datapath,get_tmpfile
66-
# from gensim.scripts.glove2word2vec import glove2word2vec
67-
# from gensim.models import KeyedVectors
68-
# wordVecPath = "/home/yz/liukang/liukang/fairseq-main_prefix/fairseq-main_prefix/checkpoints/glove/glove.6B.300d.txt"
69-
# glove_file = datapath(wordVecPath)
70-
# tmp_file = get_tmpfile('glove_word2vec.txt')
71-
# glove2word2vec(glove_file,tmp_file)
72-
# glove_model = KeyedVectors.load_word2vec_format(tmp_file)
59+
7360

7461

7562
import json
@@ -350,13 +337,7 @@ def give_real_scores_ahead(tgt_dict,outputs,scores_with_suffix,scores_with_suffi
350337
scores_with_suffix[:,i]-=scores_with_suffix[:,first_index-1]
351338
else:
352339
pass
353-
# print(outputs)
354-
# print(scores_with_suffix[:,0:5])
355-
# for i in range(first_index,last_index):
356-
# pass
357-
#scores_with_suffix[:,i]/=(len(suffix_tokens)+i-prefix_len+1)
358-
#scores_with_suffix[:,i]/=(len(suffix_tokens)+i-prefix_len+1)
359-
# print(scores_with_suffix[:,0:5])
340+
360341
scores_with_suffix[scores_with_suffix_masks]=-math.inf
361342
for j in range(0,first_index):
362343
scores_with_suffix[:,j]=torch.tensor(-math.inf)
@@ -437,15 +418,7 @@ def extract_substitute(output_sentences, original_sentence, complex_word, thresh
437418

438419
index_of_complex_word = -1
439420

440-
# if complex_word not in original_words:
441-
# i = 0
442-
# for word in original_words:
443-
# if complex_word == word.lower():
444-
# index_of_complex_word = i
445-
# break
446-
# i += 1
447-
# else:
448-
# index_of_complex_word = original_words.index(complex_word)
421+
449422
index_of_complex_word=word_index
450423
if index_of_complex_word == -1:
451424
print("******************no found the complex word*****************")
@@ -543,19 +516,7 @@ def extract_substitute(output_sentences, original_sentence, complex_word, thresh
543516
real_prev_scores.append(prev_scores[s1_count])
544517

545518
if len(substitutes)>0:
546-
# bert_scores = substitutes_BertScore(context, complex_word, substitutes)
547519

548-
# #print(substitutes)
549-
# bert_scores = bert_scores.tolist()
550-
551-
# #pdb.set_trace()
552-
553-
554-
# filter_substitutes, bert_scores = filterSubstitute(substitutes, bert_scores, threshold)
555-
556-
# rank_bert = sorted(bert_scores,reverse = True)
557-
558-
# rank_bert_substitutes = [filter_substitutes[bert_scores.index(v)] for v in rank_bert]
559520
filter_substitutes=substitutes
560521
rank_bert_substitutes=substitutes
561522

@@ -632,14 +593,7 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
632593
suffix1=suffix1.strip()
633594

634595
suffix1=" ".join(suffix1.split(" ")[:2])
635-
# if "," in suffix1:
636-
# if suffix1.index(",")!=0:
637-
# suffix1=suffix1[:suffix1.index(",")]
638-
#suffix1 = sentence[index_complex+:index_complex+1].strip()
639-
# suffix1 = " ".join(ori_words[ori_words.index(complex_word)+1:ori_words.index(complex_word)+7])
640-
# suffix1=process_string(suffix1)
641-
# medium_qutos=[",",".","!","?","\"","``",""]
642-
# for char1 in suffix1:
596+
643597

644598
else:
645599
pass
@@ -655,9 +609,6 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
655609
prefix_tokens = prefix_tokens[:-1].view(1,-1)
656610

657611
complex_tokens = model.encode(complex_word)
658-
#1.make some change to the original sentence
659-
#=prefix.strip()+" "+process_string(complex_word.strip()+" "+stored_suffix1.strip())
660-
#sentence=new_sentence
661612

662613

663614
sentence_tokens = model.encode(sentence)
@@ -668,11 +619,7 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
668619
attn_len = len(prefix_tokens[0])+len(complex_tokens)-1
669620
if len((model.tgt_dict.string(prefix_tokens).strip().replace("@@ ","")).strip().split())!=len(prefix.strip().split()):
670621
print("finding prefix not good before replace mask token!!!")
671-
# if len((model.tgt_dict.string(prefix_tokens).strip().replace("@@ ","")).strip().replace("<unk>",""))!=len(prefix.strip().split()):
672-
# print("finding prefix not good!!!")
673-
#outputs = model.generate2(sentence_tokens, beam=20, prefix_tokens=prefix_tokens)
674-
# outputs,pre_scores = model.generate2(sentence_tokens.cuda(), beam=beam, prefix_tokens=prefix_tokens.cuda(), attn_len=attn_len)
675-
#outputs,pre_scores = model.generate2(sentence_tokens.cuda(), beam=beam, prefix_tokens=prefix_tokens.cuda(), attn_len=attn_len,suffix_ids=suffix_tokens)
622+
676623
outputs,combined_sss,prev_masks,prev_masks2,scores_with_suffix,scores_with_suffix_masks,scores_with_dynamic = model.generate2(sentence_tokens.cuda(),
677624
beam=beam,
678625
prefix_tokens=prefix_tokens.cuda(),
@@ -683,47 +630,19 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
683630
max_aheads=5)
684631
outputs=outputs.cpu()
685632

686-
# for i in range(len(combined_sss)):
687-
# if combined_sss[i]!=[]:
688-
# if type(combined_sss[i])==list:
689-
# combined_sss[i][0]=combined_sss[i][0].to("cpu")
690-
# combined_sss[i][1]=combined_sss[i][1].to("cpu")
691-
# else:
692-
# combined_sss[i]=combined_sss[i].to("cpu")
693-
# prev_masks=prev_masks.cpu()
694-
# prev_masks2=prev_masks2.cpu()
633+
695634
scores_with_suffix=scores_with_suffix.cpu()
696635
scores_with_suffix_masks=scores_with_suffix_masks.cpu()
697636

698-
# output_final_scores=give_real_scores(combined_sss,prev_masks,prev_masks2,suffix_tokens)
699-
# # import pdb
700-
# # pdb.set_trace()
701637

702-
# if combined_sss[1]!=[]:
703-
# # print("123")
704-
# outputs=outputs[torch.squeeze(torch.topk(output_final_scores,k=combined_sss[0][0].shape[1],dim=1)[1].view(1,-1),1)][0]
705-
# else:
706-
# outputs=outputs[torch.squeeze(torch.topk(combined_sss[0][0],k=combined_sss[0][0].shape[1],dim=1)[1].view(1,-1),1)][0]
707638
embed_scores=give_embedding_scores(outputs,model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu(),complex_tokens=complex_tokens,temperature=0.2)
708639
#embed_scores=give_embedding_scores_v2(outputs,model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu(),complex_tokens=complex_tokens,temperature=0.2)
709640
assert embed_scores.size()==scores_with_suffix[:,:(outputs.size()[-1]-1)].size()
710-
# alkl make change the embedding scores
711-
#embed_scores=change_embedding_scores(outputs,embed_scores,prefix_len=len(prefix_tokens[0]),max_ahead=5)
712-
#scores_with_suffix[:,:(outputs.size()[-1]-1)]=scores_with_suffix[:,:(outputs.size()[-1]-1)]+embed_scores
641+
713642

714643
outputs,outputs_scores,candis=give_real_scores_ahead(model.tgt_dict,outputs,scores_with_suffix,scores_with_suffix_masks,suffix_tokens,prefix_len=len(prefix_tokens[0]),prefix_str=prefix,max_ahead=5,flag=1)
715644

716-
# glove_scores_static=give_embedding_scores_v4(complex_tokens[:-1],candis,model,temperature=0.2,tokens_embedding=model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu())
717-
# outputs_scores=torch.tensor(outputs_scores)+glove_scores_static
718-
# outputs_scores=outputs_scores.tolist()
719-
720-
# outputs=outputs[:20]
721-
# outputs_scores=outputs_scores[:20]
722-
# candis=candis[:20]
723645

724-
#glove_scores=get_glove_embedding(complex_word,candis,glove_model,temperature=1)
725-
#glove_scores=torch.tensor(outputs_scores)-torch.tensor(outputs_scores)
726-
#glove_scores=cal_bart_score(sentence,complex_word,word_index,candis)+cal_bleurt_score(sentence,complex_word,word_index,candis)
727646

728647

729648
new_outputs_scores=torch.tensor(outputs_scores)
@@ -735,30 +654,14 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
735654
outputs_scores=[outputs_scores[index1].tolist() for index1 in new_indices]
736655
candis=[candis[index1] for index1 in new_indices]
737656

738-
#glove_scores=[glove_scores[index1].tolist() for index1 in new_indices]
739-
740-
#outputs_scores=outputs_scores.tolist()
741-
742-
#print(outputs)
743-
744-
#outputs=outputs[torch.squeeze(torch.topk(output_final_scores,k=beam,dim=1)[-1].view(1,-1),0)][:50]
745-
746-
#output_sentences = [model.decode(x['tokens']) for x in outputs]
747657
output_sentences=[model.decode(x) for x in outputs]
748658
if output_sentences==[]:
749659
print("find a missing prefix sentence!!!")
750660
return [],[],[],[]
751-
# for s1 in output_sentences:
752-
# print(s1[:200])
753-
# for s1 in outputs:
754-
# print(model.tgt_dict.string(s1)[:150])
755-
#bertscore_substitutes, ranking_bertscore_substitutes = extract_substitute(output_sentences, sentence, complex_word, threshold)
661+
756662
bertscore_substitutes, ranking_bertscore_substitutes,real_prev_scores = extract_substitute(output_sentences, sentence, complex_word,
757663
threshold,outputs_scores,word_index,sentence_words,target_pos,target_lemma)
758-
#print(pre_scores)
759664

760-
#for sen in output_sentences:
761-
# print(sen)
762665

763666
bertscore_substitutes=bertscore_substitutes[:50]
764667
ranking_bertscore_substitutes=ranking_bertscore_substitutes[:50]
@@ -767,20 +670,10 @@ def lexicalSubstitute(model, sentence, sentence_words, prefix,word_index,complex
767670

768671

769672
#glove_scores_static=give_embedding_scores_v4(complex_tokens[:-1],bertscore_substitutes,model,temperature=0.2,tokens_embedding=model.models[0].state_dict()["decoder.embed_tokens.weight"].cpu())
770-
#glove_scores=cal_bart_score(sentence,complex_word,word_index,bertscore_substitutes)+cal_bleurt_score(sentence,complex_word,word_index,bertscore_substitutes)
771-
#glove_scores=cal_bart_score(sentence,complex_word,word_index,bertscore_substitutes)+cal_bert_score(sentence,complex_word,word_index,bertscore_substitutes)
772-
#glove_scores=cal_bleurt_score(sentence,complex_word,word_index,bertscore_substitutes)
773-
#real_prev_scores=0.03*torch.tensor(real_prev_scores)+glove_scores
774-
775-
776-
#real_prev_scores=real_prev_scores.tolist()
673+
glove_scores=cal_bart_score(sentence,complex_word,word_index,bertscore_substitutes)+cal_bleurt_score(sentence,complex_word,word_index,bertscore_substitutes)
777674

778-
#bertscore_substitutes, ranking_bertscore_substitutes = extractSubstitute_bertscore(output_sentences, sentence, complex_word, threshold)
779-
#suffix_substitutes = extractSubstitute_suffix(output_sentences, sentence, complex_word)
675+
return bertscore_substitutes, ranking_bertscore_substitutes,real_prev_scores,glove_scores.tolist()
780676

781-
#return bertscore_substitutes, ranking_bertscore_substitutes,real_prev_scores,glove_scores.tolist()
782-
#return bertscore_substitutes, ranking_bertscore_substitutes,real_prev_scores,glove_scores_static.tolist()
783-
return bertscore_substitutes, ranking_bertscore_substitutes,real_prev_scores,1
784677

785678
def pos_filter(pos_vocab,target_pos,candi,candi_lemma):
786679
PosMap={"v":"VERB", "n":"NOUN", "a":"ADJ", "r":"ADV"}
@@ -917,14 +810,7 @@ def main():
917810
en2en = TransformerModel.from_pretrained(args.paraphraser_path, checkpoint_file=args.paraphraser_model,bpe=args.bpe,
918811
bpe_codes=args.bpe_codes).cuda().eval()
919812

920-
#CS = []
921813

922-
#CS2 = []
923-
924-
#CS3 = []
925-
926-
#output_sr_file.write("beam:", args.beam, " bertscore:", args.bertscore)
927-
#output_sr_file.write('\n')
928814
bert_substitutes_all=[]
929815
real_prev_scores_all=[]
930816
real_embed_scores_all=[]
@@ -936,10 +822,7 @@ def main():
936822
# continue
937823
for instance in reader.words_candidate[main_word]:
938824
for context in reader.words_candidate[main_word][instance]:
939-
# import pdb
940-
# pdb.set_trace()
941-
# if main_word!="forget.V":
942-
# continue
825+
943826

944827
text = context[1]
945828
original_text = text
@@ -961,24 +844,14 @@ def main():
961844
real_prev_scores_all.append(real_prev_scores)
962845
real_embed_scores_all.append(real_embed_scores)
963846

964-
# write_all_results(main_word, instance, target_pos, args.output_SR_file,
965-
# bert_substitutes, real_prev_scores, evaluation_metric)
966-
967-
#CS2.append(bert_substitutes[:10])
968847

969-
#CS3.append(bert_rank_substitutes[:10])
970-
#final_str=" ".join(complex_labels[i])+"|||"+" ".join(bert_rank_substitutes[:10])+"|||"+" ".join(list(set(complex_labels[i])&set(bert_rank_substitutes[:10])))+"\n"
971-
#final_str="&".join(complex_labels[i])+"|||"+" ".join(bert_substitutes[:10])+"|||"+" ".join(list(set(complex_labels[i])&set(bert_substitutes[:10])))+"\n"
972848

973849
import numpy as np
974850
import copy
975851
import os
976852

977853
#range1=np.arange(1,2,1)
978-
range1=np.arange(0.2,0.4,0.2)
979-
#range1=np.arange(0.005,0.1,0.005)
980-
#range2_log_softmax=np.arange(0.2,0.4,0.2)
981-
#range2_log_softmax=np.arange(0.,0.4,0.2)
854+
range1=np.arange(0.02,0.04,0.02)
982855
range2_log_softmax=np.arange(1,2,1)
983856

984857
for log_quto in range2_log_softmax:
@@ -1016,57 +889,42 @@ def main():
1016889
target_pos=target_pos
1017890
).lower().strip()
1018891

1019-
# target_lemma = lemma_word_spacy(
1020-
# target_word,
1021-
# target_pos=target_pos
1022-
# ).lower().strip()
1023-
1024-
#print("ori_score",real_prev_scores_all[count_1][:10])
1025892
tmp_log_embed_scores=torch.tensor(tmp_real_embed_scores_all[count_1])
1026-
#tmp_log_embed_scores=torch.tensor(tmp_real_embed_scores_all[count_1])/1
1027-
#tmp_log_embed_scores=F.log_softmax(tmp_log_embed_scores,dim=0)
1028-
tmp_log_embed_scores=tmp_log_embed_scores.tolist()
1029893

894+
tmp_log_embed_scores=tmp_log_embed_scores.tolist()
1030895

1031-
# for k1 in range(len(tmp_real_prev_scores_all[count_1])):
1032-
# tmp_real_prev_scores_all[count_1][k1]=tmp_real_prev_scores_all[count_1][k1]-tmp_real_embed_scores_all[count_1][k1]
1033-
# tmp_real_prev_scores_all[count_1][k1]+=embed_quto*tmp_real_embed_scores_all[count_1][k1]
1034-
1035-
# for k1 in range(len(tmp_real_prev_scores_all[count_1])):
1036-
# #tmp_real_prev_scores_all[count_1][k1]=tmp_real_prev_scores_all[count_1][k1]-tmp_real_embed_scores_all[count_1][k1]
1037-
# #tmp_real_prev_scores_all[count_1][k1]+=embed_quto*tmp_log_embed_scores[k1]
1038-
# tmp_real_prev_scores_all[count_1][k1]=embed_quto*tmp_real_prev_scores_all[count_1][k1]+tmp_log_embed_scores[k1]
1039-
# #tmp_real_prev_scores_all[count_1][k1]=tmp_log_embed_scores[k1]
1040-
# pass
896+
for k1 in range(len(tmp_real_prev_scores_all[count_1])):
897+
tmp_real_prev_scores_all[count_1][k1]=embed_quto*tmp_real_prev_scores_all[count_1][k1]+tmp_log_embed_scores[k1]
898+
pass
1041899

1042900
write_all_results(main_word, instance, target_pos, work_dir+args.output_SR_file+".embed."+str(embed_quto),
1043901
tmp_bert_substitutes_all[count_1], tmp_real_prev_scores_all[count_1], evaluation_metric)
1044902

1045903
#print("after_score",real_prev_scores_all[count_1][:10])
1046904

1047905
count_1+=1
1048-
# print("*"*100)
1049-
# test_golden_file="data/LS14/test/coinco_test.gold"
1050-
# output_results=work_dir+args.output_SR_file+".embed."+str(embed_quto)
1051-
# results_file=work_dir+args.output_score_file+".embed."+str(embed_quto)
1052-
# evaluation_metric.calculation_perl(
1053-
# test_golden_file,
1054-
# output_results + ".best",
1055-
# output_results + ".oot",
1056-
# results_file + ".best",
1057-
# results_file + ".oot"
1058-
# )
1059-
# evaluation_metric.calculation_p1(
1060-
# test_golden_file,
1061-
# output_results + "_p1.txt",
1062-
# results_file + "_p1.txt"
1063-
# )
906+
print("*"*100)
907+
test_golden_file="data/LS14/test/coinco_test.gold"
908+
output_results=work_dir+args.output_SR_file+".embed."+str(embed_quto)
909+
results_file=work_dir+args.output_score_file+".embed."+str(embed_quto)
910+
evaluation_metric.calculation_perl(
911+
test_golden_file,
912+
output_results + ".best",
913+
output_results + ".oot",
914+
results_file + ".best",
915+
results_file + ".oot"
916+
)
917+
evaluation_metric.calculation_p1(
918+
test_golden_file,
919+
output_results + "_p1.txt",
920+
results_file + "_p1.txt"
921+
)
1064922

1065-
# evaluation_metric.calculation_p3(
1066-
# test_golden_file,
1067-
# output_results + "_p3.txt",
1068-
# results_file + "_p3.txt"
1069-
# )
923+
evaluation_metric.calculation_p3(
924+
test_golden_file,
925+
output_results + "_p3.txt",
926+
results_file + "_p3.txt"
927+
)
1070928
if __name__ == "__main__":
1071929
main()
1072930

0 commit comments

Comments
 (0)
Please sign in to comment.