LLaMAMoE fixes #2014

Draft · wants to merge 3 commits into main

Conversation

ysjprojects (Contributor)

Addresses #2013

Currently, there is a mismatch in the MoE block implementation between litgpt and HF for Mixtral models.

HF:

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

In litgpt, the softmax is performed after selecting the top-k experts. While the same experts are selected, the probability values are different and will affect downstream calculations. The L1 normalization in the HF implementation is missing from the litgpt implementation as well.
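
For comparison, the litgpt-side ordering looks roughly like the sketch below (a paraphrase on dummy tensors, not the actual litgpt source; the shapes and top_k value are assumptions for illustration):

import torch
import torch.nn.functional as F

# Dummy router logits, shape (num_tokens, num_experts); values are illustrative only.
router_logits = torch.randn(4, 8)
top_k = 2

# litgpt-style ordering: pick the top-k logits first, then softmax over only those logits.
top_logits, selected_experts = torch.topk(router_logits, top_k, dim=-1)
routing_weights = F.softmax(top_logits, dim=-1)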

@Borda (Member) left a comment

btw, can we add a test for this fix?

@ysjprojects (Contributor, Author)

> btw, can we add a test for this fix?

Hi, can I clarify what the test would look like?

@Borda (Member) commented Apr 15, 2025

> Hi, can I clarify what the test would look like?

Something which would "justify" this change, i.e. having a case that would be failing before and passing now, so we won't accidentally revert this change.

@ysjprojects (Contributor, Author)

> Hi, can I clarify what the test would look like?
>
> Something which would "justify" this change, i.e. having a case that would be failing before and passing now, so we won't accidentally revert this change.

On second thought, the current implementation passes all tests and the logits are close enough to the HF model's.

There is a slight deviation but ultimately it does not actually affect the model, so maybe we can close this PR.

@Borda (Member) commented May 16, 2025

> On second thought, the current implementation passes all tests and the logits are close enough to the HF model's.
> There is a slight deviation but ultimately it does not actually affect the model, so maybe we can close this PR.

Sure. My thinking is: this is a cool fix, but what prevents us from accidentally reverting it, since all tests would be passing?

@t-vi (Collaborator) left a comment

If you do the math, it seems that the new method is just a particularly clumsy way to compute the same value and this is why you would not see much of a difference.

Unless it is causing real problems in real implementations, we should keep the way we currently have with topk first and then softmax.

Edit: If it is causing real problems, we can change it (e.g. we did adapt to some Llama version needing specific casting for the RoPE cache at some point), but then we should add commentary.
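
To spell that argument out (a sketch, not the exact HF or litgpt code paths): softmax is strictly monotone per row, so topk selects the same experts either way, and renormalizing softmax(x)_i over the selected set S gives exp(x_i) / sum_{j in S} exp(x_j), which is exactly a softmax over the selected logits alone. A small numerical check on random logits (the shapes and top_k here are arbitrary, chosen only for illustration):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
router_logits = torch.randn(16, 8)  # (tokens, experts), arbitrary sizes
top_k = 2

# HF-style: softmax over all experts, then topk, then L1-renormalize.
probs = F.softmax(router_logits, dim=-1)
hf_weights, hf_experts = torch.topk(probs, top_k, dim=-1)
hf_weights = hf_weights / hf_weights.sum(dim=-1, keepdim=True)

# litgpt-style: topk on the raw logits, then softmax over the selected logits.
top_logits, experts = torch.topk(router_logits, top_k, dim=-1)
weights = F.softmax(top_logits, dim=-1)

assert torch.equal(experts, hf_experts)          # same experts selected
torch.testing.assert_close(weights, hf_weights)  # same routing weights, up to float error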

@Borda self-requested a review, May 16, 2025 08:13
@ysjprojects (Contributor, Author)

> If you do the math, it seems that the new method is just a particularly clumsy way to compute the same value and this is why you would not see much of a difference.
>
> Unless it is causing real problems in real implementations, we should keep the way we currently have with topk first and then softmax.
>
> Edit: If it is causing real problems, we can change it (e.g. we did adapt to some Llama version needing specific casting for the RoPE cache at some point), but then we should add commentary.

I see now. Thanks for educating me!

@Borda marked this pull request as draft, May 23, 2025 14:12