
Commit 8ddfe79

Merge branch 'main' into shanea/tokenizer-package-data
2 parents d005b16 + 1b2658b commit 8ddfe79

4 files changed: +185 -21 lines changed


CHANGELOG.md (+2)

```diff
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added caching to disk of HF datasets used in downstream evals
 - Added FLOPs logging
 - Added configs for OLMo tiny set of models
+- Added configuration field `optimizer.record_update_metrics`, which defaults to `False`, but when set to True will trigger AdamW to collect the step size norm and absolute max for each parameter.
 - Added `olmo_data`, a package holding data files like tokenizers.
 - Added ability to load tokenizers from `olmo_data` package data.
 
@@ -24,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added original legacy unsharding implementation back, as the default. The new
   shared memory implementation can be used by passing `use_legacy_shared_mem_impl` to `unshard.py`.
 - Refactor weight initialization. IMPORTANT: this does not maintain backwards-compatibility with older configs; the jobs will still run, but may produce different outputs.
+- Changed the behavior of the Lion optimizer to only record the update cosine similarity when `optimizer.record_update_metrics` is `True` in order to be consistent with the API.
 
 ### Fixed
```
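
Not part of the commit itself, but as a quick orientation: a minimal sketch of enabling the new option programmatically. The attribute names (`record_update_metrics`, `name`, `learning_rate`) appear in the diffs on this page; the import paths, the remaining constructor arguments, and their values are assumptions for illustration only.

```python
# Hedged sketch, not from this commit: turn on per-parameter update metrics.
# Only `record_update_metrics` is new here; the other fields mirror attributes
# that build_optimizer (olmo/optim.py below) reads from cfg.optimizer.
from olmo.config import OptimizerConfig, OptimizerType  # assumed import location

opt_cfg = OptimizerConfig(
    name=OptimizerType.adamw,      # build_optimizer dispatches on this
    learning_rate=3e-4,            # forwarded as lr=cfg.optimizer.learning_rate
    record_update_metrics=True,    # new field; defaults to False
)
```

With the flag left at its default of `False`, AdamW takes its usual fast path and Lion skips the update cosine-similarity bookkeeping, so the extra metrics are only computed when explicitly requested.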

olmo/config.py (+6)

```diff
@@ -489,6 +489,12 @@ class OptimizerConfig(BaseConfig):
     If not set, defaults to the wandb `log_interval`.
     """
 
+    record_update_metrics: bool = False
+    """
+    Whether to record detailed metrics about the optimizer's parameter updates, like the norm and max
+    of the update with AdamW.
+    """
+
     def __post_init__(self):
         self.betas = tuple(self.betas)  # type: ignore[assignment]
 
```

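Concretely, the "norm and max of the update" recorded for each parameter correspond to the quantities computed in the new `AdamW.step` in `olmo/optim.py` below. In standard AdamW notation (a restatement of that code, not text from the commit):

$$
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \qquad
\Delta\theta_t = -\frac{\mathrm{lr}}{1-\beta_1^{t}} \cdot \frac{m_t}{\sqrt{v_t}\,/\sqrt{1-\beta_2^{t}} + \epsilon}
$$

The two logged values per parameter are $\lVert \Delta\theta_t \rVert_2$ (`step/<name>.norm`) and $\max_i \lvert \Delta\theta_{t,i} \rvert$ (`step/<name>.max`). The decoupled weight decay, applied via `param.mul_(1 - lr * weight_decay)`, happens separately and is not included in the recorded update.
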
olmo/optim.py (+176 -21)

```diff
@@ -35,6 +35,11 @@
 
 
 class Optimizer(OptimizerBase):
+    def __init__(self, *args, record_update_metrics: bool = False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._record_update_metrics = record_update_metrics
+        self._collecting_metrics = False
+
     def _clean_param_name(self, name: str) -> str:
         return name.replace("_fsdp_wrapped_module.", "")
 
@@ -50,6 +55,7 @@ def clip_grads_and_collect_metrics(
         Clips gradients for every group that has the field `max_grad_norm`.
         At the same time collect metrics for each parameter and its gradient.
         """
+        self._collecting_metrics = collect_param_metrics
         device = get_default_device() if device is None else device
 
         # NOTE (epwalsh): during distributed training we're making an assumption that the order of
@@ -365,12 +371,13 @@ def __init__(
         lr: float = 1e-4,
         betas: Tuple[float, float] = (0.9, 0.99),
         weight_decay: float = 0.0,
+        record_update_metrics: bool = False,
         device: Optional[torch.device] = None,
     ):
         assert lr > 0.0
         assert all([0.0 <= beta <= 1.0 for beta in betas])
         defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, record_update_metrics=record_update_metrics)
         for group in self.param_groups:
             group["initial_lr"] = group["lr"]
         self._update_total_dot_prod: Optional[torch.Tensor] = None
@@ -391,6 +398,10 @@ def get_post_step_metrics(
         if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None:
             return {}
 
+        self._update_total_dot_prod = None
+        self._update_total_norm = None
+        self._signed_update_total_norm = None
+
         if is_distributed() and isinstance(module, FullyShardedDataParallel):
             # Reduce total dot prod and norms across all ranks.
             update_total_norm = update_total_norm**2.0
@@ -419,9 +430,13 @@ def step(self, closure=None) -> None:
             with torch.enable_grad():
                 closure()
 
-        update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32)
-        update_norms = []
-        signed_update_norms = []
+        update_total_dot_prod: Optional[torch.Tensor] = None
+        update_norms: Optional[List[torch.Tensor]] = None
+        signed_update_norms: Optional[List[torch.Tensor]] = None
+        if self._collecting_metrics and self._record_update_metrics:
+            update_total_dot_prod = torch.tensor(0.0, dtype=torch.float32)
+            update_norms = []
+            signed_update_norms = []
 
         for group in self.param_groups:
             for p in group["params"]:
@@ -452,31 +467,169 @@ def step(self, closure=None) -> None:
 
                 # Track dot product and norms of update vs signed update in order to calculate
                 # their cosine similarity.
-                update_total_dot_prod = update_total_dot_prod.to(update.device)
-                update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape))
-                update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32))
-                signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32))
+                if (
+                    update_total_dot_prod is not None
+                    and update_norms is not None
+                    and signed_update_norms is not None
+                ):
+                    update_total_dot_prod = update_total_dot_prod.to(update.device)
+                    update_total_dot_prod += torch.tensordot(update, signed_update, dims=len(update.shape))
+                    update_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32))
+                    signed_update_norms.append(torch.linalg.vector_norm(signed_update, 2.0, dtype=torch.float32))
 
         # Compute cosine similarity between update and signed update.
-        self._update_total_dot_prod = update_total_dot_prod.to(
-            get_default_device() if self._device is None else self._device
-        )
-        self._update_total_norm = torch.linalg.vector_norm(
-            torch.stack(update_norms),
-            2.0,
-            dtype=torch.float32,
-        ).to(get_default_device() if self._device is None else self._device)
-        self._signed_update_total_norm = torch.linalg.vector_norm(
-            torch.stack(signed_update_norms),
-            2.0,
-            dtype=torch.float32,
-        ).to(get_default_device() if self._device is None else self._device)
+        if update_total_dot_prod is not None and update_norms is not None and signed_update_norms is not None:
+            device = get_default_device() if self._device is None else self._device
+            self._update_total_dot_prod = update_total_dot_prod.to(device)
+            self._update_total_norm = torch.linalg.vector_norm(
+                torch.stack(update_norms),
+                2.0,
+                dtype=torch.float32,
+            ).to(device)
+            self._signed_update_total_norm = torch.linalg.vector_norm(
+                torch.stack(signed_update_norms),
+                2.0,
+                dtype=torch.float32,
+            ).to(device)
 
 
 class AdamW(torch.optim.AdamW, Optimizer):
+    def __init__(self, *args, record_update_metrics: bool = False, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
+        # won't be called.
+        self._record_update_metrics = record_update_metrics
+        self._collecting_metrics = False
+
+        self._step_size_param_names: Optional[List[str]] = None
+        self._step_size_norms: Optional[List[torch.Tensor]] = None
+        self._step_size_maxs: Optional[List[torch.Tensor]] = None
+
+    @torch.no_grad()
+    def step(self, closure=None) -> None:
+        if not (self._record_update_metrics and self._collecting_metrics):
+            return super().step(closure=closure)
+
+        device = get_default_device()
+        param_names = []
+        step_size_norms = []
+        step_size_maxs = []
+        for group in self.param_groups:
+            beta1, beta2 = group["betas"]
+            lr = group["lr"]
+            weight_decay = group["weight_decay"]
+            eps = group["eps"]
+            amsgrad = group["amsgrad"]
+            for name, param in zip(group["param_names"], group["params"]):
+                name = self._clean_param_name(name)
+                param_names.append(name)
+                grad = param.grad
+                if grad is None:
+                    step_size_norms.append(torch.tensor([0.0], device=device))
+                    step_size_maxs.append(torch.tensor([0.0], device=device))
+                    continue
+
+                state = self.state[param]
+                # init state if needed
+                if len(state) == 0:
+                    state["step"] = (
+                        torch.zeros((), dtype=torch.float32, device=param.device)
+                        if group["capturable"] or group["fused"]
+                        else torch.tensor(0.0, dtype=torch.float32)
+                    )
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(param, memory_format=torch.preserve_format)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
+
+                exp_avg = state["exp_avg"]
+                exp_avg_sq = state["exp_avg_sq"]
+                step_t = state["step"]
+
+                # Update step.
+                step_t += 1
+
+                # Perform step weight decay.
+                param.mul_(1 - lr * weight_decay)
+
+                # Decay the first and second moment running average coefficient.
+                exp_avg.lerp_(grad, 1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+                step = step_t.item()
+
+                bias_correction1 = 1 - beta1**step
+                bias_correction2 = 1 - beta2**step
+
+                step_size = lr / bias_correction1
+
+                bias_correction2_sqrt = sqrt(bias_correction2)
+
+                if amsgrad:
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = (max_exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+                else:
+                    denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+
+                update = -step_size * torch.div(exp_avg, denom)
+                param.add_(update)
+                step_size_norms.append(torch.linalg.vector_norm(update, 2.0, dtype=torch.float32).unsqueeze(0))
+                step_size_maxs.append(update.abs().max().unsqueeze(0))
+
+        self._step_size_param_names = param_names
+        self._step_size_norms = step_size_norms
+        self._step_size_maxs = step_size_maxs
+
     def get_state_for_param(self, param: nn.Parameter) -> Dict[str, Optional[torch.Tensor]]:
         return {key: self.state[param].get(key) for key in ("exp_avg", "exp_avg_sq")}  # type: ignore
 
+    def get_post_step_metrics(
+        self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+    ) -> Dict[str, torch.Tensor]:
+        if not (self._record_update_metrics and self._collecting_metrics):
+            return {}
+        else:
+            device = get_default_device()
+            dst_rank = 0
+            if process_group is not None:
+                dst_rank = dist.get_global_rank(process_group, 0)
+            param_names = self._step_size_param_names
+            step_size_norms = self._step_size_norms
+            step_size_maxs = self._step_size_maxs
+            assert param_names is not None
+            assert step_size_norms is not None
+            assert step_size_maxs is not None
+
+            # Reduce metrics if needed.
+            if is_distributed() and isinstance(module, FullyShardedDataParallel):
+                # Reduce norms.
+                all_norms = torch.cat(step_size_norms).to(device) ** 2.0
+                dist.reduce(all_norms, dst_rank, op=dist.ReduceOp.SUM, group=process_group)
+                step_size_norms = (all_norms ** (0.5)).squeeze(0).split(1)
+
+                # Reduce maxs.
+                all_maxs = torch.cat(step_size_maxs).to(device)
+                dist.reduce(all_maxs, dst_rank, op=dist.ReduceOp.MAX, group=process_group)
+                step_size_maxs = all_maxs.split(1)
+
+            metrics = {}
+            for param_name, step_size_norm, step_size_max in zip(param_names, step_size_norms, step_size_maxs):  # type: ignore[arg-type]
+                metrics[f"step/{param_name}.norm"] = step_size_norm.squeeze(0)
+                metrics[f"step/{param_name}.max"] = step_size_max.squeeze(0)
+
+            self._step_size_param_names = None
+            self._step_size_norms = None
+            self._step_size_maxs = None
+            return metrics
+
 
 @dataclass
 class Scheduler(metaclass=ABCMeta):
@@ -745,13 +898,15 @@ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
            lr=cfg.optimizer.learning_rate,
            betas=cfg.optimizer.betas,
            weight_decay=cfg.optimizer.weight_decay,
+           record_update_metrics=cfg.optimizer.record_update_metrics,
         )
     elif cfg.optimizer.name == OptimizerType.adamw:
         return AdamW(
            param_groups,
            lr=cfg.optimizer.learning_rate,
            betas=cfg.optimizer.betas,
            weight_decay=cfg.optimizer.weight_decay,
+           record_update_metrics=cfg.optimizer.record_update_metrics,
            eps=cfg.optimizer.eps,
         )
    else:
```
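
For readers skimming the new `AdamW.step`/`get_post_step_metrics` path above, here is a small, self-contained sketch of how the pieces fit together outside of a real training run. It is illustrative only: the toy model, the hand-built param groups, and setting `_collecting_metrics` directly (normally done by `clip_grads_and_collect_metrics`) are assumptions, not how the trainer drives this in practice.

```python
# Hedged sketch: exercise the metrics path of olmo.optim.AdamW on a toy model.
import torch
import torch.nn as nn

from olmo.optim import AdamW

model = nn.Linear(4, 4)
# OLMo's AdamW expects "param_names" alongside "params" in each group
# (see the zip over group["param_names"] in the diff above).
param_groups = [
    {
        "params": [p for _, p in model.named_parameters()],
        "param_names": [n for n, _ in model.named_parameters()],
    }
]
opt = AdamW(param_groups, lr=1e-3, record_update_metrics=True)

model(torch.randn(2, 4)).sum().backward()
opt._collecting_metrics = True  # normally set by clip_grads_and_collect_metrics()
opt.step()

# In the single-process case no reduction happens; keys follow the
# f"step/{param_name}.norm" / f"step/{param_name}.max" pattern from the diff.
metrics = opt.get_post_step_metrics(model)
print(sorted(metrics))  # e.g. ['step/bias.max', 'step/bias.norm', 'step/weight.max', 'step/weight.norm']
```

Under FSDP the same call reduces the per-rank values first: norms are squared, summed across ranks with `ReduceOp.SUM`, and square-rooted (the usual way to combine shard L2 norms), while maxes are combined with `ReduceOp.MAX`.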

olmo/train.py (+1)

```diff
@@ -1245,6 +1245,7 @@ def on_trace_ready(p):
             log.info("Training epoch complete")
             self.epoch = epoch + 1
             self.global_train_examples_seen_this_epoch = 0
+            self.dataset.start_index = 0
             if self.epoch < self.max_epochs:
                 self.dataset.reshuffle()
             continue
```
