
fix(ppo_trainer): compute mean KL sequence-wise #441

Merged 4 commits into main on Apr 20, 2023
Conversation

@maxreciprocate (Collaborator) commented Apr 19, 2023

This PR fixes #438, specifically:

  • Mean KL, which is used for statistics and the AdaptiveKLController, is now calculated sequence-wise rather than token-wise (see the sketch below)
  • Statistics are now averaged across rollouts, instead of being taken only from the last one
  • exp_* logging variables are renamed to rollout_*, so they are not confused with sqrt_*

https://wandb.ai/sorry/trlx/reports/Sequence-wise-v-token-wise-mean-KL--Vmlldzo0MTE0NzUy

https://wandb.ai/sorry/trlx-references/reports/fix-kl-computation-v-main--Vmlldzo0MTE3Nzc4
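
For context, a minimal PyTorch sketch (tensor names and shapes here are illustrative, not the exact trainer code) contrasting the old token-wise mean with the new sequence-wise mean of the approximate KL:

```python
import torch

# Hypothetical batch of per-token log-probs, shape (batch, response_len)
logprobs = torch.randn(4, 16)
ref_logprobs = torch.randn(4, 16)
attention_mask = torch.ones(4, 16)

log_ratio = (logprobs - ref_logprobs) * attention_mask
kl_per_token = log_ratio.exp() - 1 - log_ratio

token_wise_kl = kl_per_token.mean()            # old: averages over batch and tokens
sequence_wise_kl = kl_per_token.sum(1).mean()  # new: sum per sequence, then average over the batch
print(token_wise_kl.item(), sequence_wise_kl.item())
```

The sequence-wise value grows with response length, which is the scale that the lm-human-preferences target KL (discussed below) assumes.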

@maxreciprocate requested a review from Dahoas on April 19, 2023 16:40
@@ -435,7 +435,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
start = prompt_tensors.shape[1] - 1

log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
- self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
+ mean_kl = (log_ratio.exp() - 1 - log_ratio).sum(1).mean().to(device)
Collaborator
Thanks for fixing this. Do you think it makes sense to add an option to the config allowing us to choose between token-wise and sequence-wise KL? I agree that having a KL computation invariant to sequence length is good to keep around.

Collaborator Author
It's easy enough to log both, but for `AdaptiveKLController`'s purposes I would stick with one variant, the sequence-wise one as in lm-human-preferences, since the targets are already copied from there and the respective paper (Ziegler et al., 2019) also uses a target KL value of 8:
https://github.com/openai/lm-human-preferences/blob/ec727fde10f1eafb3177e9b0f41a42142e95a2fd/launch.py#L-129-L131
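
For reference, the controller that consumes this target follows the proportional update from Ziegler et al. (2019); a rough paraphrase of the lm-human-preferences logic (not the exact trlx class) looks like this:

```python
class AdaptiveKLController:
    """Paraphrased adaptive KL coefficient controller, as in lm-human-preferences."""

    def __init__(self, init_kl_coef: float, target: float, horizon: int):
        self.value = init_kl_coef  # current KL penalty coefficient
        self.target = target       # e.g. 8 nats of sequence-wise KL
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # Increase the coefficient when the measured KL overshoots the target,
        # decrease it when it undershoots; the error is clipped for stability.
        proportional_error = max(min(current_kl / self.target - 1, 0.2), -0.2)
        self.value *= 1 + proportional_error * n_steps / self.horizon
```

Feeding it a token-wise mean KL would leave `current_kl` far below a target of 8, so by the update rule above the coefficient would keep shrinking.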

@@ -470,18 +470,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
)

rollout_count += 1
exp_time = clock.tick()

if torch.distributed.is_initialized():
Collaborator
Nice

torch.distributed.all_reduce(self.mean_kl, torch.distributed.ReduceOp.AVG)

stats["policy/sqrt_kl"] = torch.sqrt(self.mean_kl).item()
stats = {k: sum([xs[k] for xs in accumulated_stats]) / len(accumulated_stats) for k in stats}
Collaborator
Strictly speaking this isn't necessarily correct, e.g. if we are recording the max over all local rollouts. However, I don't know how to perform the correct reduction without annotating each stat, so this seems fine for now.

Collaborator Author
There are no maxes/mins in this particular dict, but I see what you're saying. Also, I guess you're referring to

`stats = {key: sum([stats[key] for stats in stats_accum]) / self.num_mb for key in stats_accum[0]}`

but it's also pretty easy to fix, perhaps in a separate PR.
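
If it ever needs fixing, one hypothetical way to do it (the names below are made up for illustration) is to pick the reduction per key instead of always averaging:

```python
from typing import Dict, List


def reduce_stats(accumulated_stats: List[Dict[str, float]]) -> Dict[str, float]:
    """Reduce a list of per-rollout stat dicts, honoring max/min-style keys."""
    reduced = {}
    for key in accumulated_stats[0]:
        values = [stats[key] for stats in accumulated_stats]
        if key.endswith("_max"):
            reduced[key] = max(values)
        elif key.endswith("_min"):
            reduced[key] = min(values)
        else:
            reduced[key] = sum(values) / len(values)
    return reduced
```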

@Dahoas (Collaborator) left a comment

Looks good!

@Dahoas merged commit 6e655a4 into main on Apr 20, 2023
@lzy37ld commented Apr 21, 2023

Thanks for this great work! Could you explain a bit why this is the sequence-wise KL? I could not connect `kl = log_ratio.exp() - 1 - log_ratio` to the usual math formulation `sum_x p(x) * log(p(x)/q(x))`.
Gently ping @reciprocated

@maxreciprocate (Collaborator Author)

@lzy37ld It's sequence-wise KL because of the `.sum(1)` before taking the average, while the KL expression itself comes from this blog post http://joschu.net/blog/kl-approx.html and is commonly used, e.g. https://github.com/DLR-RM/stable-baselines3/blob/dc09d81f9c07943ddbeac57405d9ae2a31f4d434/stable_baselines3/ppo/ppo.py#L255
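
As a quick Monte-Carlo sanity check of the k3 estimator from that blog post (the distributions below are arbitrary and only for illustration): sampling x from q and setting r = p(x)/q(x), the quantity (r - 1) - log r is always non-negative and its mean approaches KL(q || p).

```python
import torch

q = torch.distributions.Normal(0.0, 1.0)  # sampling distribution
p = torch.distributions.Normal(0.3, 1.0)  # the other distribution

x = q.sample((200_000,))
log_r = p.log_prob(x) - q.log_prob(x)  # log p(x) - log q(x)
k3 = log_r.exp() - 1 - log_r           # same algebraic form as log_ratio.exp() - 1 - log_ratio

print(k3.mean().item())                                # Monte-Carlo estimate
print(torch.distributions.kl_divergence(q, p).item())  # closed form, ~0.045
```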

@maxreciprocate deleted the fix-kl-computation branch on April 21, 2023 11:28
@lzy37ld commented Apr 22, 2023

@reciprocated Thanks, that makes sense! It's really a fantastic implementation!
One more question: why do we subtract a reward(x, y_original) in the reward_fn? When I looked at the paper, I noticed that they only focus on R(x, y_generated).

@maxreciprocate (Collaborator Author)

@lzy37ld It's optional normalization (disabled by setting `delta_reward` to `False`):

`delta_reward = True`

which @PhungVanDuy has found to work better than passing the raw reward.
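
In other words, the reward used for training is roughly R(x, y_generated) - R(x, y_original); a hypothetical sketch of that idea (function and argument names are made up, not the exact trlx example code):

```python
def delta_rewards(prompts, generations, originals, score_fn):
    """Score generated continuations relative to the dataset's original continuations."""
    generated_scores = score_fn(prompts, generations)
    original_scores = score_fn(prompts, originals)
    return [g - o for g, o in zip(generated_scores, original_scores)]
```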

Successfully merging this pull request may close these issues: The mean_kl implementation is different from that of openai/lm-human-preference