class Optimizer(OptimizerBase):
+    def __init__(self, *args, record_update_metrics: bool = False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._record_update_metrics = record_update_metrics
+        self._collecting_metrics = False
+
    def _clean_param_name(self, name: str) -> str:
        return name.replace("_fsdp_wrapped_module.", "")

@@ -50,6 +55,7 @@ def clip_grads_and_collect_metrics(
        Clips gradients for every group that has the field `max_grad_norm`.
        At the same time collect metrics for each parameter and its gradient.
        """
+        self._collecting_metrics = collect_param_metrics
        device = get_default_device() if device is None else device

        # NOTE (epwalsh): during distributed training we're making an assumption that the order of
@@ -365,12 +371,13 @@ def __init__(
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
+        record_update_metrics: bool = False,
        device: Optional[torch.device] = None,
    ):
        assert lr > 0.0
        assert all([0.0 <= beta <= 1.0 for beta in betas])
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
-        super().__init__(params, defaults)
+        super().__init__(params, defaults, record_update_metrics=record_update_metrics)
        for group in self.param_groups:
            group["initial_lr"] = group["lr"]
        self._update_total_dot_prod: Optional[torch.Tensor] = None
@@ -391,6 +398,10 @@ def get_post_step_metrics(
        if update_total_dot_prod is None or update_total_norm is None or signed_update_total_norm is None:
            return {}

+        self._update_total_dot_prod = None
+        self._update_total_norm = None
+        self._signed_update_total_norm = None
+
        if is_distributed() and isinstance(module, FullyShardedDataParallel):
            # Reduce total dot prod and norms across all ranks.
            update_total_norm = update_total_norm ** 2.0
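The three buffers reset above feed a cosine-similarity metric between the raw update and its sign: dot products simply sum across ranks, while L2 norms combine as the square root of the sum of squares, which is why the norm is squared before the reduction. A minimal sketch of the final combination, with made-up scalars standing in for the reduced tensors (the epsilon guard is my addition, not taken from the diff):

import torch

# Stand-ins for the reduced quantities (values are made up for illustration).
update_total_dot_prod = torch.tensor(3.0)      # sum over params of <update, sign(update)>
update_total_norm = torch.tensor(2.5)          # global L2 norm of the update
signed_update_total_norm = torch.tensor(4.0)   # global L2 norm of sign(update)

# Cosine similarity between the update and its sign; the small epsilon is an
# assumed safeguard against division by zero.
cos_sim = update_total_dot_prod / (update_total_norm * signed_update_total_norm + 1e-8)
print(float(cos_sim))

A similarity near 1 means the update is already close to its sign, i.e. its coordinates have roughly equal magnitude.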
@@ -419,9 +430,13 @@ def step(self, closure=None) -> None:
419
430
with torch .enable_grad ():
420
431
closure ()
421
432
422
- update_total_dot_prod = torch .tensor (0.0 , dtype = torch .float32 )
423
- update_norms = []
424
- signed_update_norms = []
433
+ update_total_dot_prod : Optional [torch .Tensor ] = None
434
+ update_norms : Optional [List [torch .Tensor ]] = None
435
+ signed_update_norms : Optional [List [torch .Tensor ]] = None
436
+ if self ._collecting_metrics and self ._record_update_metrics :
437
+ update_total_dot_prod = torch .tensor (0.0 , dtype = torch .float32 )
438
+ update_norms = []
439
+ signed_update_norms = []
425
440
426
441
for group in self .param_groups :
427
442
for p in group ["params" ]:
@@ -452,31 +467,169 @@ def step(self, closure=None) -> None:
452
467
453
468
# Track dot product and norms of update vs signed update in order to calculate
454
469
# their cosine similarity.
455
- update_total_dot_prod = update_total_dot_prod .to (update .device )
456
- update_total_dot_prod += torch .tensordot (update , signed_update , dims = len (update .shape ))
457
- update_norms .append (torch .linalg .vector_norm (update , 2.0 , dtype = torch .float32 ))
458
- signed_update_norms .append (torch .linalg .vector_norm (signed_update , 2.0 , dtype = torch .float32 ))
470
+ if (
471
+ update_total_dot_prod is not None
472
+ and update_norms is not None
473
+ and signed_update_norms is not None
474
+ ):
475
+ update_total_dot_prod = update_total_dot_prod .to (update .device )
476
+ update_total_dot_prod += torch .tensordot (update , signed_update , dims = len (update .shape ))
477
+ update_norms .append (torch .linalg .vector_norm (update , 2.0 , dtype = torch .float32 ))
478
+ signed_update_norms .append (torch .linalg .vector_norm (signed_update , 2.0 , dtype = torch .float32 ))
459
479
460
480
# Compute cosine similarity between update and signed update.
461
- self . _update_total_dot_prod = update_total_dot_prod . to (
462
- get_default_device () if self ._device is None else self ._device
463
- )
464
- self ._update_total_norm = torch .linalg .vector_norm (
465
- torch .stack (update_norms ),
466
- 2.0 ,
467
- dtype = torch .float32 ,
468
- ). to ( get_default_device () if self . _device is None else self . _device )
469
- self ._signed_update_total_norm = torch .linalg .vector_norm (
470
- torch .stack (signed_update_norms ),
471
- 2.0 ,
472
- dtype = torch .float32 ,
473
- ). to ( get_default_device () if self . _device is None else self . _device )
481
+ if update_total_dot_prod is not None and update_norms is not None and signed_update_norms is not None :
482
+ device = get_default_device () if self ._device is None else self ._device
483
+ self . _update_total_dot_prod = update_total_dot_prod . to ( device )
484
+ self ._update_total_norm = torch .linalg .vector_norm (
485
+ torch .stack (update_norms ),
486
+ 2.0 ,
487
+ dtype = torch .float32 ,
488
+ ). to ( device )
489
+ self ._signed_update_total_norm = torch .linalg .vector_norm (
490
+ torch .stack (signed_update_norms ),
491
+ 2.0 ,
492
+ dtype = torch .float32 ,
493
+ ). to ( device )
474
494
475
495
476
496
class AdamW (torch .optim .AdamW , Optimizer ):
497
+ def __init__ (self , * args , record_update_metrics : bool = False , ** kwargs ):
498
+ super ().__init__ (* args , ** kwargs )
499
+
500
+ # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
501
+ # won't be called.
502
+ self ._record_update_metrics = record_update_metrics
503
+ self ._collecting_metrics = False
504
+
505
+ self ._step_size_param_names : Optional [List [str ]] = None
506
+ self ._step_size_norms : Optional [List [torch .Tensor ]] = None
507
+ self ._step_size_maxs : Optional [List [torch .Tensor ]] = None
508
+
509
+ @torch .no_grad ()
510
+ def step (self , closure = None ) -> None :
511
+ if not (self ._record_update_metrics and self ._collecting_metrics ):
512
+ return super ().step (closure = closure )
513
+
514
+ device = get_default_device ()
515
+ param_names = []
516
+ step_size_norms = []
517
+ step_size_maxs = []
518
+ for group in self .param_groups :
519
+ beta1 , beta2 = group ["betas" ]
520
+ lr = group ["lr" ]
521
+ weight_decay = group ["weight_decay" ]
522
+ eps = group ["eps" ]
523
+ amsgrad = group ["amsgrad" ]
524
+ for name , param in zip (group ["param_names" ], group ["params" ]):
525
+ name = self ._clean_param_name (name )
526
+ param_names .append (name )
527
+ grad = param .grad
528
+ if grad is None :
529
+ step_size_norms .append (torch .tensor ([0.0 ], device = device ))
530
+ step_size_maxs .append (torch .tensor ([0.0 ], device = device ))
531
+ continue
532
+
533
+ state = self .state [param ]
534
+ # init state if needed
535
+ if len (state ) == 0 :
536
+ state ["step" ] = (
537
+ torch .zeros ((), dtype = torch .float32 , device = param .device )
538
+ if group ["capturable" ] or group ["fused" ]
539
+ else torch .tensor (0.0 , dtype = torch .float32 )
540
+ )
541
+ # Exponential moving average of gradient values
542
+ state ["exp_avg" ] = torch .zeros_like (param , memory_format = torch .preserve_format )
543
+ # Exponential moving average of squared gradient values
544
+ state ["exp_avg_sq" ] = torch .zeros_like (param , memory_format = torch .preserve_format )
545
+ if amsgrad :
546
+ # Maintains max of all exp. moving avg. of sq. grad. values
547
+ state ["max_exp_avg_sq" ] = torch .zeros_like (param , memory_format = torch .preserve_format )
548
+
549
+ exp_avg = state ["exp_avg" ]
550
+ exp_avg_sq = state ["exp_avg_sq" ]
551
+ step_t = state ["step" ]
552
+
553
+ # Update step.
554
+ step_t += 1
555
+
556
+ # Perform step weight decay.
557
+ param .mul_ (1 - lr * weight_decay )
558
+
559
+ # Decay the first and second moment running average coefficient.
560
+ exp_avg .lerp_ (grad , 1 - beta1 )
561
+ exp_avg_sq .mul_ (beta2 ).addcmul_ (grad , grad , value = 1 - beta2 )
562
+
563
+ step = step_t .item ()
564
+
565
+ bias_correction1 = 1 - beta1 ** step
566
+ bias_correction2 = 1 - beta2 ** step
567
+
568
+ step_size = lr / bias_correction1
569
+
570
+ bias_correction2_sqrt = sqrt (bias_correction2 )
571
+
572
+ if amsgrad :
573
+ max_exp_avg_sq = state ["max_exp_avg_sq" ]
574
+ # Maintains the maximum of all 2nd moment running avg. till now
575
+ torch .maximum (max_exp_avg_sq , exp_avg_sq , out = max_exp_avg_sq )
576
+
577
+ # Use the max. for normalizing running avg. of gradient
578
+ denom = (max_exp_avg_sq .sqrt () / bias_correction2_sqrt ).add_ (eps )
579
+ else :
580
+ denom = (exp_avg_sq .sqrt () / bias_correction2_sqrt ).add_ (eps )
581
+
582
+ update = - step_size * torch .div (exp_avg , denom )
583
+ param .add_ (update )
584
+ step_size_norms .append (torch .linalg .vector_norm (update , 2.0 , dtype = torch .float32 ).unsqueeze (0 ))
585
+ step_size_maxs .append (update .abs ().max ().unsqueeze (0 ))
586
+
587
+ self ._step_size_param_names = param_names
588
+ self ._step_size_norms = step_size_norms
589
+ self ._step_size_maxs = step_size_maxs
590
+
477
591
def get_state_for_param (self , param : nn .Parameter ) -> Dict [str , Optional [torch .Tensor ]]:
478
592
return {key : self .state [param ].get (key ) for key in ("exp_avg" , "exp_avg_sq" )} # type: ignore
479
593
594
+ def get_post_step_metrics (
595
+ self , module : nn .Module , process_group : Optional [dist .ProcessGroup ] = None
596
+ ) -> Dict [str , torch .Tensor ]:
597
+ if not (self ._record_update_metrics and self ._collecting_metrics ):
598
+ return {}
599
+ else :
600
+ device = get_default_device ()
601
+ dst_rank = 0
602
+ if process_group is not None :
603
+ dst_rank = dist .get_global_rank (process_group , 0 )
604
+ param_names = self ._step_size_param_names
605
+ step_size_norms = self ._step_size_norms
606
+ step_size_maxs = self ._step_size_maxs
607
+ assert param_names is not None
608
+ assert step_size_norms is not None
609
+ assert step_size_maxs is not None
610
+
611
+ # Reduce metrics if needed.
612
+ if is_distributed () and isinstance (module , FullyShardedDataParallel ):
613
+ # Reduce norms.
614
+ all_norms = torch .cat (step_size_norms ).to (device ) ** 2.0
615
+ dist .reduce (all_norms , dst_rank , op = dist .ReduceOp .SUM , group = process_group )
616
+ step_size_norms = (all_norms ** (0.5 )).squeeze (0 ).split (1 )
617
+
618
+ # Reduce maxs.
619
+ all_maxs = torch .cat (step_size_maxs ).to (device )
620
+ dist .reduce (all_maxs , dst_rank , op = dist .ReduceOp .MAX , group = process_group )
621
+ step_size_maxs = all_maxs .split (1 )
622
+
623
+ metrics = {}
624
+ for param_name , step_size_norm , step_size_max in zip (param_names , step_size_norms , step_size_maxs ): # type: ignore[arg-type]
625
+ metrics [f"step/{ param_name } .norm" ] = step_size_norm .squeeze (0 )
626
+ metrics [f"step/{ param_name } .max" ] = step_size_max .squeeze (0 )
627
+
628
+ self ._step_size_param_names = None
629
+ self ._step_size_norms = None
630
+ self ._step_size_maxs = None
631
+ return metrics
632
+
480
633
481
634
@dataclass
482
635
class Scheduler (metaclass = ABCMeta ):
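Putting the new pieces together: when `record_update_metrics=True` and the per-step collection flag is on, `AdamW.step` records a per-parameter update norm and max, and `get_post_step_metrics` returns them keyed `step/<param_name>.norm` and `step/<param_name>.max`. A hedged single-process usage sketch; the import path is assumed, `model` and `optim` are placeholders, and outside FSDP any `nn.Module` can be passed as `module`:

import torch
import torch.nn as nn

from olmo.optim import AdamW  # module path assumed

model = nn.Linear(8, 8)
optim = AdamW(
    [
        {
            "params": list(model.parameters()),
            # "param_names" is consumed by the metric-recording step above.
            "param_names": [n for n, _ in model.named_parameters()],
        }
    ],
    lr=1e-3,
    record_update_metrics=True,
)

model(torch.randn(2, 8)).sum().backward()
optim._collecting_metrics = True  # normally toggled by clip_grads_and_collect_metrics(...)
optim.step()

for key, value in optim.get_post_step_metrics(model).items():
    print(key, value.item())  # e.g. "step/weight.norm", "step/weight.max", ...

In the real training loop the collection flag is set by `clip_grads_and_collect_metrics`, not by hand as above.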
@@ -745,13 +898,15 @@ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
745
898
lr = cfg .optimizer .learning_rate ,
746
899
betas = cfg .optimizer .betas ,
747
900
weight_decay = cfg .optimizer .weight_decay ,
901
+ record_update_metrics = cfg .optimizer .record_update_metrics ,
748
902
)
749
903
elif cfg .optimizer .name == OptimizerType .adamw :
750
904
return AdamW (
751
905
param_groups ,
752
906
lr = cfg .optimizer .learning_rate ,
753
907
betas = cfg .optimizer .betas ,
754
908
weight_decay = cfg .optimizer .weight_decay ,
909
+ record_update_metrics = cfg .optimizer .record_update_metrics ,
755
910
eps = cfg .optimizer .eps ,
756
911
)
757
912
else :
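Both branches now forward `cfg.optimizer.record_update_metrics`, so the feature is driven entirely by the optimizer section of the training config, which implies the optimizer config gains a field of that name. A self-contained sketch of that config shape; the class name and default values below are assumptions, only `record_update_metrics` and the other field names are taken from the hunk above:

from dataclasses import dataclass
from typing import Tuple

@dataclass
class OptimizerConfigSketch:
    learning_rate: float = 1e-4
    betas: Tuple[float, float] = (0.9, 0.95)
    weight_decay: float = 0.0
    eps: float = 1e-5
    record_update_metrics: bool = False  # new: enables per-parameter update metrics

cfg_optimizer = OptimizerConfigSketch(record_update_metrics=True)
print(cfg_optimizer.record_update_metrics)  # True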