fix(ppo_trainer): compute mean KL sequence-wise #441
Conversation
```diff
@@ -435,7 +435,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
     start = prompt_tensors.shape[1] - 1

     log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
-    self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
+    mean_kl = (log_ratio.exp() - 1 - log_ratio).sum(1).mean().to(device)
```
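To make the change concrete, here is a small, self-contained sketch of the difference (the tensor shapes and values are made up, not trlx's actual rollout tensors): the old token-wise variant divides by every token position in the batch, so its value shrinks as sequences get longer, while the new sequence-wise variant sums the per-token estimate over each sequence before averaging across the batch.

```python
import torch

# Illustrative sketch only: tensor names and shapes are assumptions, not trlx's internals.
B, T = 4, 12                                        # 4 sequences, 12 response tokens each
logprobs = torch.randn(B, T) * 0.1                  # policy log-probs (placeholder values)
ref_logprobs = logprobs + torch.randn(B, T) * 0.05  # reference-model log-probs
attention_mask = torch.ones(B, T)                   # 1 for real tokens, 0 for padding

log_ratio = (logprobs - ref_logprobs) * attention_mask
kl_per_token = log_ratio.exp() - 1 - log_ratio      # non-negative KL estimator from the diff

token_wise = kl_per_token.mean()            # old: divides by B * T, shrinks as T grows
sequence_wise = kl_per_token.sum(1).mean()  # new: sum over tokens, then average over sequences

# With no padding, the two differ exactly by the sequence length T
assert torch.allclose(sequence_wise, token_wise * T)
```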
Thanks for fixing this. Do you think it makes sense to add an option to the config allowing us to choose between token-wise and sequence-wise KL? I agree that having a KL computation invariant to sequence length is good to keep around.
It's easy enough to log both, but for `AdaptiveKLController`'s purposes I would just stick with one variant, which is sequence-wise as in lm-human-preferences, since the targets are already copied from there and the respective paper (Ziegler et al., 2019) also uses a target KL value of 8:
https://github.com/openai/lm-human-preferences/blob/ec727fde10f1eafb3177e9b0f41a42142e95a2fd/launch.py#L129-L131
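For reference, the controller that consumes this value follows the proportional update from lm-human-preferences / Ziegler et al. (2019). The sketch below uses illustrative names and defaults rather than trlx's exact `AdaptiveKLController`:

```python
# A sketch of the proportional KL-coefficient update from Ziegler et al. (2019) /
# lm-human-preferences. Class name, defaults, and method signature are illustrative,
# not trlx's actual AdaptiveKLController.
class AdaptiveKLControllerSketch:
    def __init__(self, init_kl_coef: float = 0.05, target: float = 8.0, horizon: int = 10000):
        self.value = init_kl_coef  # current KL penalty coefficient
        self.target = target       # target sequence-wise KL (8 in the paper)
        self.horizon = horizon     # number of samples over which to adapt

    def update(self, current_kl: float, n_steps: int) -> None:
        # Proportional error, clipped so one noisy batch can't swing the coefficient too far
        proportional_error = max(min(current_kl / self.target - 1, 0.2), -0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.value *= mult
```

With a sequence-wise KL, the paper's target of 8 is directly meaningful; with a token-wise mean, the same target would have to be rescaled by sequence length.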
```diff
@@ -470,18 +470,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
     )

     rollout_count += 1
     exp_time = clock.tick()

     if torch.distributed.is_initialized():
```
Nice
```diff
         torch.distributed.all_reduce(self.mean_kl, torch.distributed.ReduceOp.AVG)

     stats["policy/sqrt_kl"] = torch.sqrt(self.mean_kl).item()
     stats = {k: sum([xs[k] for xs in accumulated_stats]) / len(accumulated_stats) for k in stats}
```
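For anyone reading along without a multi-GPU setup, the distributed part of this hunk boils down to the following pattern (a sketch with assumed names, guarded so it also runs in a single process):

```python
import torch
import torch.distributed as dist

def reduce_mean_kl(mean_kl: torch.Tensor) -> torch.Tensor:
    """Average a scalar statistic across all data-parallel ranks.

    When torch.distributed isn't initialized (single-GPU or CPU run), the local
    value already is the global value, so it is returned unchanged.
    """
    if dist.is_initialized():
        # ReduceOp.AVG sums the tensor across ranks and divides by the world size
        dist.all_reduce(mean_kl, op=dist.ReduceOp.AVG)
    return mean_kl

# Hypothetical usage inside the rollout loop:
# mean_kl = reduce_mean_kl(mean_kl)
# stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item()
```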
Strictly speaking this isn't necessarily correct, e.g. if we are recording the max over all local rollouts. However, I don't know how to perform the correct reduction without annotating each stat, so this seems fine for now.
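A tiny made-up example of the caveat: averaging per-rollout summaries is exact for mean-type stats but not for max-type ones.

```python
# Hypothetical per-rollout values of some stat recorded as a max:
rollout_maxes = [3.0, 7.0]

averaged = sum(rollout_maxes) / len(rollout_maxes)  # 5.0 -- what a uniform average reports
true_max = max(rollout_maxes)                       # 7.0 -- the correct global reduction
```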
There are no maxes/mins in this particular dict, but I see what you are saying. Also, I guess you're referring to:
trlx/trlx/trainer/accelerate_base_trainer.py, line 544 in 9bc0836:

```python
stats = {key: sum([stats[key] for stats in stats_accum]) / self.num_mb for key in stats_accum[0]}
```
Looks good!
Thanks for this great work! Could you explain a bit why this is sequence-wise KL? I could not quite connect it to the code.
@lzy37ld It's sequence-wise KL because of the `.sum(1)` over the token dimension before taking the batch mean.
@reciprocated Thanks! That makes sense! It's really a fantastic implementation!
@lzy37ld It's optional normalization (disabled with the setting at Line 171 in 6e655a4), which @PhungVanDuy has found to work better than passing the raw reward.
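For readers wondering what that normalization does, here is a minimal sketch of running-moment reward scaling; the class name, defaults, and usage are assumptions for illustration, not trlx's exact implementation or config flag:

```python
import torch

class RunningMoments:
    """Track a running mean and variance of scalar rewards (illustrative only)."""

    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-8  # avoid division by zero before the first update

    def update(self, rewards: torch.Tensor) -> None:
        batch_mean = rewards.mean().item()
        batch_var = rewards.var(unbiased=False).item()
        batch_count = rewards.numel()

        delta = batch_mean - self.mean
        total = self.count + batch_count
        # Parallel-variance (Chan et al.) update of the running moments
        self.var = (
            self.var * self.count
            + batch_var * batch_count
            + delta**2 * self.count * batch_count / total
        ) / total
        self.mean += delta * batch_count / total
        self.count = total

    def normalize(self, rewards: torch.Tensor) -> torch.Tensor:
        # Scale rewards to roughly unit standard deviation
        return rewards / (self.var**0.5 + 1e-8)


# Hypothetical usage before handing scores to the PPO trainer:
# moments.update(scores); scores = moments.normalize(scores)
```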
This PR fixes #438, specifically:
- `exp_*` logging variables are renamed to `rollout_*`, to not confuse them with `sqrt_*`
https://wandb.ai/sorry/trlx/reports/Sequence-wise-v-token-wise-mean-KL--Vmlldzo0MTE0NzUy
https://wandb.ai/sorry/trlx-references/reports/fix-kl-computation-v-main--Vmlldzo0MTE3Nzc4