import abc
+import logging
import re
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union
⋮
from ..tokenizer import Tokenizer

+log = logging.getLogger(__name__)
+

class ICLMetric(Metric):
    # update method does not require access to global metric state
@@ -152,13 +155,17 @@ def __init__(
        dataset_name: Union[str, Sequence[str], None] = None,
        model_ctx_len: int = 2048,
        split="validation",
+        prompts=[None],  # List of prompt variants to use
    ):
        super().__init__()

        self.tokenizer = tokenizer
        self.dataset_path = dataset_path
        self.dataset_name = dataset_name
        self.model_ctx_len = model_ctx_len
+        self.prompts = prompts
+        self.current_prompt = None
+        self.log_instances = 5  # Log the first few instances as a sanity check

        self.samples: List[Dict[str, Any]] = []
        dataset_names: Sequence[Optional[str]]
@@ -174,6 +181,7 @@ def __init__(
                    path=self.dataset_path,
                    name=ds_name,
                    split=split,
+                    trust_remote_code=True,
                )
            )
        self.dataset = datasets.concatenate_datasets(dataset_list)
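
For context, the loading pattern this hunk touches boils down to the sketch below (the subset names are placeholders). Newer versions of `datasets` require an explicit opt-in before running a hub dataset's loading script, which is what `trust_remote_code=True` provides:

```python
import datasets

# Minimal sketch of the load-then-concatenate pattern above; subset names are illustrative.
dataset_list = [
    datasets.load_dataset(
        path="hails/mmlu_no_train",
        name=name,
        split="validation",
        trust_remote_code=True,  # opt in to running the dataset's loading script
    )
    for name in ["abstract_algebra", "anatomy"]
]
dataset = datasets.concatenate_datasets(dataset_list)
print(len(dataset))
```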
@@ -191,51 +199,65 @@ def prep_examples(self):
        """Append doc_ids to each example so that they are processed together in the metric"""
        doc_id = 0
        for doc in self.dataset:
-            # from EAI harness
-            # how this all works:
-            #          CTX      CONT
-            # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
-            # gpt2    \               \
-            # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
-            # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
-
-            continuations = self.doc_to_continuations(doc)
-            label_id = self.doc_to_label(doc)
-            ctx = self.token_encode(self.doc_to_text(doc))
-            dc = self.token_encode(self.doc_to_domain_conditional(doc))
-
-            for cont_id, continuation_str in enumerate(continuations):
-                cont_str_len = len(continuation_str) - 1  # continuations contain a leading blank
-                continuation = self.token_encode(continuation_str)
-
-                # query, remove last token from continuation, truncate from left if longer than model ctx length
-                query = ctx + continuation[:-1]
-                query = query[-self.model_ctx_len :]
-
-                # get domain conditional query
-                # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
-                dc_query = dc + continuation[:-1]
-
-                # form a sample
-                self.samples.append(
-                    {
-                        "doc_id": doc_id,
-                        "cont_id": cont_id,
-                        "ctx": ctx,
-                        "continuation": continuation,
-                        "ctx_len": len(ctx),
-                        "dc_len": len(dc),
-                        "cont_len": len(
-                            continuation
-                        ),  # even if query has last token removed, LM will output same cont len
-                        "cont_str_len": cont_str_len,
-                        "query": query,  # remove last token from continuation
-                        "dc_query": dc_query,
-                        "label_id": label_id,
-                    }
-                )
-
-            doc_id += 1
+            for prompt in self.prompts:
+                self.current_prompt = prompt
+                # from EAI harness
+                # how this all works:
+                #          CTX      CONT
+                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+                # gpt2    \               \
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
+                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
+
+                continuations = self.doc_to_continuations(doc)
+                label_id = self.doc_to_label(doc)
+                doc_text = self.doc_to_text(doc)
+                ctx = self.token_encode(doc_text)
+                dc = self.token_encode(self.doc_to_domain_conditional(doc))
+                if self.log_instances > 0:
+                    self.log_instances -= 1
+                    ds_name = self.dataset_name
+                    if isinstance(ds_name, list):
+                        ds_name = ds_name[0]
+                    log.info(
+                        f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):"
+                        + f"\ndoc_text: {doc_text}\ncontinuations: {continuations}"
+                    )
+
+                for cont_id, continuation_str in enumerate(continuations):
+                    cont_str_len = len(continuation_str) - 1  # continuations contain a leading blank
+                    continuation = self.token_encode(continuation_str)
+
+                    # query, remove last token from continuation, truncate from left if longer than model ctx length
+                    query = ctx + continuation[:-1]
+                    query = query[-self.model_ctx_len :]
+                    # this will be different from len(ctx) when truncated by model_ctx_len
+                    actual_ctx_len = len(query) - len(continuation) + 1
+
+                    # get domain conditional query
+                    # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
+                    dc_query = dc + continuation[:-1]
+
+                    # form a sample
+                    self.samples.append(
+                        {
+                            "doc_id": doc_id,
+                            "cont_id": cont_id,
+                            "ctx": ctx,
+                            "continuation": continuation,
+                            "ctx_len": actual_ctx_len,
+                            "dc_len": len(dc),
+                            "cont_len": len(
+                                continuation
+                            ),  # even if query has last token removed, LM will output same cont len
+                            "cont_str_len": cont_str_len,
+                            "query": query,  # remove last token from continuation
+                            "dc_query": dc_query,
+                            "label_id": label_id,
+                        }
+                    )
+
+                doc_id += 1

    def pad_tokens_until_max(self, tokens, max_len=2048):
        """truncate from left if len(tokens) > model_ctx_len, max_len is not considered then
@@ -655,7 +677,7 @@ def __init__(self, tokenizer, dataset_path="sciq", dataset_name=None):
        )

    def doc_to_text(self, doc):
-        return doc["support"] + "\nQuestion: " + doc["question"] + "\nAnswer:".strip()
+        return doc["support"].strip() + "\nQuestion: " + doc["question"] + "\nAnswer:"

    def doc_to_continuations(self, doc):
        # add spaces in front of continuation
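
The one-line fix above is easy to misread, so here is a self-contained sketch (toy strings, not the real dataset fields) of what the operator precedence was doing:

```python
# .strip() binds only to the nearest string literal, not the whole concatenation,
# so the old line stripped "\nAnswer:" down to "Answer:" (losing its newline)
# while leaving whitespace on the support passage untouched.
old = "passage \n" + "\nQuestion: " + "Q?" + "\nAnswer:".strip()
new = "passage \n".strip() + "\nQuestion: " + "Q?" + "\nAnswer:"
print(repr(old))  # 'passage \n\nQuestion: Q?Answer:'
print(repr(new))  # 'passage\nQuestion: Q?\nAnswer:'
```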
@@ -1055,7 +1077,14 @@ class MMLU(ICLMultiChoiceTaskDataset):
        "other": ["other", "business", "health"],
    }

-    def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"):
+    def __init__(
+        self,
+        tokenizer,
+        dataset_path="hails/mmlu_no_train",
+        dataset_name=None,
+        split="validation",
+        prompt_variations=None,
+    ):
        dataset_names = []
        # Collect the relevant categories
        if dataset_name in MMLU._categories:
@@ -1069,10 +1098,40 @@ def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"):
        for name, cats in MMLU._subcategories.items():
            if dataset_name in cats:
                dataset_names.append(name)
-        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_names, split=split)
+        self.dev_set = {}
+        prompts = [None]  # default: a single pass with no prompt variation
+        if prompt_variations == 1:
+            prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"]
+            # Need to grab the dev set for the few-shot prompts
+            for name in dataset_names:
+                self.dev_set[name] = datasets.load_dataset(
+                    path=dataset_path, name=name, split="dev", trust_remote_code=True
+                )
+        super().__init__(
+            tokenizer=tokenizer,
+            dataset_path=dataset_path,
+            dataset_name=dataset_names,
+            split=split,
+            prompts=prompts,
+        )

    def doc_to_text(self, doc):
-        return "Question: " + doc["question"] + "\nAnswer:"
+        output_text = "Question: " + doc["question"] + "\nAnswer:"
+        if self.current_prompt is not None:
+            prefix = ""
+            if "inst" in self.current_prompt:
+                subject = doc.get("subject").replace("_", " ")
+                prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n"
+            num_shots = re.findall("\\+(\\d+)", self.current_prompt)
+            if num_shots:
+                dev_set = self.dev_set.get(doc.get("subject"), [])
+                num_shots_int = int(num_shots[0])
+                for idx, dev_doc in enumerate(dev_set):
+                    if idx >= num_shots_int:
+                        break
+                    answer = dev_doc["choices"][dev_doc["answer"]]
+                    prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n"
+            output_text = prefix + output_text
+        return output_text

    def doc_to_continuations(self, doc):
        # add spaces in front of continuation
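
To make the variant grammar (`None`, `"inst"`, `"inst+N"`) concrete, here is a self-contained toy re-creation of the prefix construction; the helper name and the data are ours for illustration, where the real method reads dev docs from `self.dev_set`:

```python
import re

def build_prompt(question, subject, dev_docs, variant):
    """Toy re-creation of MMLU.doc_to_text for one prompt variant (illustrative only)."""
    output_text = "Question: " + question + "\nAnswer:"
    if variant is None:
        return output_text
    prefix = ""
    if "inst" in variant:  # "inst" adds the subject instruction line
        prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n"
    num_shots = re.findall(r"\+(\d+)", variant)  # "+N" adds N few-shot examples
    if num_shots:
        for dev_doc in dev_docs[: int(num_shots[0])]:
            answer = dev_doc["choices"][dev_doc["answer"]]
            prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n"
    return prefix + output_text

dev_docs = [
    {"question": "What is 2 + 2?", "choices": ["3", "4", "5", "6"], "answer": 1},
    {"question": "What is 3 * 3?", "choices": ["6", "9", "12", "18"], "answer": 1},
]
print(build_prompt("What is 10 / 2?", "elementary mathematics", dev_docs, "inst+2"))
```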
@@ -1108,4 +1167,8 @@ def doc_to_domain_conditional(self, doc):
    "mmlu_humanities": (MMLU, {"dataset_name": "humanities"}),
    "mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}),
    "mmlu_other": (MMLU, {"dataset_name": "other"}),
+    "mmlu_stem_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}),
+    "mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
+    "mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}),
+    "mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
}
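
For orientation, a hypothetical sketch of how registry entries of this shape are typically consumed; the registry variable name and the stand-in task class below are assumptions for illustration, not code from this PR:

```python
# Stand-in task class so the snippet runs on its own; the real MMLU class is above.
class FakeMMLU:
    def __init__(self, tokenizer, dataset_name=None, prompt_variations=None):
        self.dataset_name = dataset_name
        self.prompt_variations = prompt_variations

# Mirrors the shape of the entries added in this hunk: label -> (task class, kwargs).
label_to_task_map = {
    "mmlu_stem_var": (FakeMMLU, {"dataset_name": "stem", "prompt_variations": 1}),
}

task_class, task_kwargs = label_to_task_map["mmlu_stem_var"]
task = task_class(tokenizer=None, **task_kwargs)
print(task.dataset_name, task.prompt_variations)  # -> stem 1
```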