
Commit ba96c6a

Merge branch 'main' into shanea/storage-cleaner-unshard-improvements

2 parents: 095cd7e + 1d264e4

File tree: 5 files changed, +123 -52 lines

- CHANGELOG.md
- olmo/eval/downstream.py
- olmo/util.py
- pyproject.toml
- scripts/train.py

CHANGELOG.md (+5 -1)

@@ -13,11 +13,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added the option to directly pass input embeddings to `OLMo` and `OLMoForCausalLM`.
 - Added support for Python 3.8.
 - Added code to throw an error if `output_attentions` is set to `True` in forward call to `OLMoForCausalLM`. This functionality hasn't been implemented yet.
+- Fixed running with data loading workers on LUMI

 ### Added
 - Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states.
-- Added MMLU downstream evaluation tasks.
+- Added MMLU downstream evaluation tasks, with prompt variations.
 - Added support for PyTorch v2.2.
+- Added ability to show logs from all ranks
+
+

 ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02

olmo/eval/downstream.py (+112 -49)
@@ -1,4 +1,5 @@
 import abc
+import logging
 import re
 from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union

@@ -10,6 +11,8 @@

 from ..tokenizer import Tokenizer

+log = logging.getLogger(__name__)
+

 class ICLMetric(Metric):
     # update method does not require access to global metric state
@@ -152,13 +155,17 @@ def __init__(
         dataset_name: Union[str, Sequence[str], None] = None,
         model_ctx_len: int = 2048,
         split="validation",
+        prompts=[None],  # List of prompt variants to use
     ):
         super().__init__()

         self.tokenizer = tokenizer
         self.dataset_path = dataset_path
         self.dataset_name = dataset_name
         self.model_ctx_len = model_ctx_len
+        self.prompts = prompts
+        self.current_prompt = None
+        self.log_instances = 5  # Log the first few instances as a sanity check

         self.samples: List[Dict[str, Any]] = []
         dataset_names: Sequence[Optional[str]]
@@ -174,6 +181,7 @@ def __init__(
                     path=self.dataset_path,
                     name=ds_name,
                     split=split,
+                    trust_remote_code=True,
                 )
             )
         self.dataset = datasets.concatenate_datasets(dataset_list)
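
Note: the new `trust_remote_code=True` opts in to running the dataset repository's own loading script, which recent versions of the `datasets` library refuse to do without explicit consent. A minimal standalone sketch of the same call; the subject name here is my illustration, not taken from this diff:

```python
# Minimal sketch: load one MMLU subject the way the harness does.
# "astronomy" is an illustrative subject name, not from this diff.
import datasets

ds = datasets.load_dataset(
    path="hails/mmlu_no_train",
    name="astronomy",
    split="validation",
    trust_remote_code=True,  # allow the hub repo's loading script to run
)
print(len(ds), ds[0]["question"])
```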
@@ -191,51 +199,65 @@ def prep_examples(self):
         """Append doc_ids to each example so that they are processed together in the metric"""
         doc_id = 0
         for doc in self.dataset:
-            # from EAI harness
-            # how this all works:
-            #          CTX      CONT
-            # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
-            # gpt2    \               \
-            # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
-            # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
-
-            continuations = self.doc_to_continuations(doc)
-            label_id = self.doc_to_label(doc)
-            ctx = self.token_encode(self.doc_to_text(doc))
-            dc = self.token_encode(self.doc_to_domain_conditional(doc))
-
-            for cont_id, continuation_str in enumerate(continuations):
-                cont_str_len = len(continuation_str) - 1  # continuation contain leading blank
-                continuation = self.token_encode(continuation_str)
-
-                # query, remove last token from continuation, truncate from left is longer than model ctx length
-                query = ctx + continuation[:-1]
-                query = query[-self.model_ctx_len :]
-
-                # get domain conditional query
-                # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
-                dc_query = dc + continuation[:-1]
-
-                # form a sample
-                self.samples.append(
-                    {
-                        "doc_id": doc_id,
-                        "cont_id": cont_id,
-                        "ctx": ctx,
-                        "continuation": continuation,
-                        "ctx_len": len(ctx),
-                        "dc_len": len(dc),
-                        "cont_len": len(
-                            continuation
-                        ),  # even if query has last token removed, LM will output same cont len
-                        "cont_str_len": cont_str_len,
-                        "query": query,  # remove last token from continuation
-                        "dc_query": dc_query,
-                        "label_id": label_id,
-                    }
-                )
-
-            doc_id += 1
+            for prompt in self.prompts:
+                self.current_prompt = prompt
+                # from EAI harness
+                # how this all works:
+                #          CTX      CONT
+                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+                # gpt2    \               \
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
+                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
+
+                continuations = self.doc_to_continuations(doc)
+                label_id = self.doc_to_label(doc)
+                doc_text = self.doc_to_text(doc)
+                ctx = self.token_encode(doc_text)
+                dc = self.token_encode(self.doc_to_domain_conditional(doc))
+                if self.log_instances > 0:
+                    self.log_instances -= 1
+                    ds_name = self.dataset_name
+                    if isinstance(ds_name, list):
+                        ds_name = ds_name[0]
+                    log.info(
+                        f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):"
+                        + f"\ndoc_text: {doc_text}\ncontinuations: {continuations}"
+                    )
+
+                for cont_id, continuation_str in enumerate(continuations):
+                    cont_str_len = len(continuation_str) - 1  # continuation contain leading blank
+                    continuation = self.token_encode(continuation_str)
+
+                    # query, remove last token from continuation, truncate from left is longer than model ctx length
+                    query = ctx + continuation[:-1]
+                    query = query[-self.model_ctx_len :]
+                    # this will be different from len(ctx) when truncated by model_ctx_len
+                    actual_ctx_len = len(query) - len(continuation) + 1
+
+                    # get domain conditional query
+                    # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
+                    dc_query = dc + continuation[:-1]
+
+                    # form a sample
+                    self.samples.append(
+                        {
+                            "doc_id": doc_id,
+                            "cont_id": cont_id,
+                            "ctx": ctx,
+                            "continuation": continuation,
+                            "ctx_len": actual_ctx_len,
+                            "dc_len": len(dc),
+                            "cont_len": len(
+                                continuation
+                            ),  # even if query has last token removed, LM will output same cont len
+                            "cont_str_len": cont_str_len,
+                            "query": query,  # remove last token from continuation
+                            "dc_query": dc_query,
+                            "label_id": label_id,
+                        }
+                    )
+
+            doc_id += 1

     def pad_tokens_until_max(self, tokens, max_len=2048):
         """truncate from left if len(tokens) > model_ctx_len, max_len is not considered then
@@ -655,7 +677,7 @@ def __init__(self, tokenizer, dataset_path="sciq", dataset_name=None):
         )

     def doc_to_text(self, doc):
-        return doc["support"] + "\nQuestion: " + doc["question"] + "\nAnswer:".strip()
+        return doc["support"].strip() + "\nQuestion: " + doc["question"] + "\nAnswer:"

     def doc_to_continuations(self, doc):
         # add spaces in front of continuation
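
Note: the SciQ fix above corrects an operator-precedence bug. In the old line, `.strip()` attached only to the `"\nAnswer:"` literal (silently deleting the newline before "Answer:"), while `doc["support"]` was never trimmed. A standalone demonstration with invented field values:

```python
# Standalone demonstration of the precedence bug fixed above
# (support/question strings are invented).
support = "  Plants absorb carbon dioxide.  "
question = "What gas do plants absorb?"

old = support + "\nQuestion: " + question + "\nAnswer:".strip()
new = support.strip() + "\nQuestion: " + question + "\nAnswer:"

print(repr(old))  # support keeps stray spaces; "\nAnswer:" collapsed to "Answer:"
print(repr(new))  # support trimmed; newline before "Answer:" preserved
```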
@@ -1055,7 +1077,14 @@ class MMLU(ICLMultiChoiceTaskDataset):
         "other": ["other", "business", "health"],
     }

-    def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"):
+    def __init__(
+        self,
+        tokenizer,
+        dataset_path="hails/mmlu_no_train",
+        dataset_name=None,
+        split="validation",
+        prompt_variations=None,
+    ):
         dataset_names = []
         # Collect the relevant categories
         if dataset_name in MMLU._categories:
@@ -1069,10 +1098,40 @@ def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=N
         for name, cats in MMLU._subcategories.items():
             if dataset_name in cats:
                 dataset_names.append(name)
-        super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_names, split=split)
+        self.dev_set = {}
+        if prompt_variations == 1:
+            prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"]
+            # Need to grab the dev set for the few-shot prompts
+            for name in dataset_names:
+                self.dev_set[name] = datasets.load_dataset(
+                    path=dataset_path, name=name, split="dev", trust_remote_code=True
+                )
+        super().__init__(
+            tokenizer=tokenizer,
+            dataset_path=dataset_path,
+            dataset_name=dataset_names,
+            split=split,
+            prompts=prompts,
+        )

     def doc_to_text(self, doc):
-        return "Question: " + doc["question"] + "\nAnswer:"
+        output_text = "Question: " + doc["question"] + "\nAnswer:"
+        if self.current_prompt is not None:
+            prefix = ""
+            if "inst" in self.current_prompt:
+                subject = doc.get("subject").replace("_", " ")
+                prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n"
+            num_shots = re.findall("\\+(\\d+)", self.current_prompt)
+            if num_shots:
+                dev_set = self.dev_set.get(doc.get("subject"), [])
+                num_shots_int = int(num_shots[0])
+                for idx, dev_doc in enumerate(dev_set):
+                    if idx >= num_shots_int:
+                        break
+                    answer = dev_doc["choices"][dev_doc["answer"]]
+                    prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n"
+            output_text = prefix + output_text
+        return output_text

     def doc_to_continuations(self, doc):
         # add spaces in front of continuation
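
Note: read together, a variant string such as "inst+2" means: prepend the subject instruction line, then prepend that many few-shot Question/Answer pairs taken from the subject's dev split. Roughly the text `doc_to_text` would produce for an "inst+1" prompt; all question and answer contents below are invented for illustration:

```python
# Illustrative rendering only; the questions/answers are invented.
rendered = (
    "The following are multiple choice questions (with answers) about high school physics:\n\n"
    # one few-shot pair pulled from the subject's dev split ("inst+1")
    "Question: What is the SI unit of force?\nAnswer: newton\n\n"
    # the evaluation question itself; the model scores each continuation after "Answer:"
    "Question: What does a voltmeter measure?\nAnswer:"
)
print(rendered)
```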
@@ -1108,4 +1167,8 @@ def doc_to_domain_conditional(self, doc):
     "mmlu_humanities": (MMLU, {"dataset_name": "humanities"}),
     "mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}),
     "mmlu_other": (MMLU, {"dataset_name": "other"}),
+    "mmlu_stem_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}),
+    "mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
+    "mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}),
+    "mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
 }
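
Note: the four new `*_var` labels reuse the existing `(class, kwargs)` convention of this task map. A hedged sketch of how downstream code resolves a label; the tokenizer construction is assumed, not shown in this diff:

```python
# Hedged usage sketch; `my_tokenizer` is assumed to be built elsewhere.
from olmo.eval.downstream import label_to_task_map

task_class, task_kwargs = label_to_task_map["mmlu_stem_var"]
print(task_class.__name__, task_kwargs)
# -> MMLU {'dataset_name': 'stem', 'prompt_variations': 1}
task = task_class(tokenizer=my_tokenizer, **task_kwargs)  # noqa: F821 (assumed tokenizer)
```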

olmo/util.py (+3 -1)
@@ -59,6 +59,7 @@ def __repr__(self) -> str:
 class LogFilterType(StrEnum):
     rank0_only = "rank0_only"
     local_rank0_only = "local_rank0_only"
+    all_ranks = "all_ranks"


 def log_extra_field(field_name: str, field_value: Any) -> None:
@@ -126,11 +127,12 @@ def local_rank0_filter(record: logging.LogRecord) -> int:
         else:
             return 0

-    filter = None
     if log_filter_type == LogFilterType.rank0_only:
         filter = rank0_filter
     elif log_filter_type == LogFilterType.local_rank0_only:
         filter = local_rank0_filter  # type: ignore
+    elif log_filter_type == LogFilterType.all_ranks:
+        filter = None
     else:
         raise ValueError(log_filter_type)
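
Note: with `all_ranks` the selected `filter` is simply `None`, i.e. no rank-based filter gets attached, so every rank's records reach the handlers. A minimal sketch of that pattern with the standard `logging` module; the handler wiring is assumed, not the exact olmo code:

```python
# Minimal sketch of the filter-attachment pattern; not the exact olmo code.
import logging
import sys

filter = None  # what the new LogFilterType.all_ranks branch selects

handler = logging.StreamHandler(sys.stdout)
if filter is not None:
    # rank0_only / local_rank0_only would attach their record filter here
    handler.addFilter(filter)
logging.getLogger().addHandler(handler)  # every rank's records pass through
```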

pyproject.toml (+1 -1)
@@ -21,7 +21,7 @@ dependencies = [
     "google-cloud-storage",
     "tokenizers",
     "packaging",
-    "cached_path",
+    "cached_path>=1.6.2",
     "transformers",
 ]

scripts/train.py (+2 -0)
@@ -2,6 +2,7 @@

 import gzip
 import logging
+import multiprocessing as mp
 import sys
 from pathlib import Path
 from typing import Optional, TextIO
@@ -240,6 +241,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:


 if __name__ == "__main__":
+    mp.set_start_method("spawn")
     # Initialize process group.
     dist.init_process_group(backend="nccl")
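
Note: this matches the CHANGELOG entry "Fixed running with data loading workers on LUMI". Setting the start method to "spawn" before any CUDA or process-group initialization matters because fork()ed DataLoader workers inherit a CUDA-initialized parent and can deadlock, while spawned workers start clean. A self-contained sketch of the same ordering with a toy dataset of my own invention:

```python
# Self-contained sketch of the ordering used above (toy dataset, invented).
import multiprocessing as mp

from torch.utils.data import DataLoader


def main() -> None:
    # Workers are spawned, not forked, because of the start method set below.
    loader = DataLoader(list(range(10)), num_workers=2)
    print(sum(int(batch) for batch in loader))  # 45


if __name__ == "__main__":
    mp.set_start_method("spawn")  # must precede CUDA / dist initialization
    main()
```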
