
Inquiry about the RoPE difference between FAESM and ESM2 #13

Open
Setsuna111 opened this issue Jan 22, 2025 · 3 comments

@Setsuna111

Thanks for your excellent work!

When applying FAESM's RotaryEmbedding before the attention-score calculation, I noticed that the resulting q and k differ from those generated by the vanilla ESM2 implementation with its original RoPE embedding method.

Although both embeddings seem to be implemented correctly, the differences may not be acceptable when using pretrained parameters from ESM2.

Here are the implementation details:

At line 161 in esm.py, during a protein embedding inference task (batch_size=2), I generated q and k with FAESM's RoPE, following FAESM's own scripts:

qkv = self.rotary_embeddings(qkv=qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
q_fa, k_fa = qkv[:, 0], qkv[:, 1]
# slice the unpadded q/k back into the two sequences via cu_seqlens
q_fa_0, q_fa_1 = q_fa[cu_seqlens[0]:cu_seqlens[1]], q_fa[cu_seqlens[1]:cu_seqlens[2]]
k_fa_0, k_fa_1 = k_fa[cu_seqlens[0]:cu_seqlens[1]], k_fa[cu_seqlens[1]:cu_seqlens[2]]

The resulting outputs (partial example):

q_fa_0=
tensor([[[-2.3041e-02,  9.9304e-02,  7.9895e-02,  ..., -1.0767e-03,
           8.3590e-04, -3.1924e-04],
         [-7.0923e-02,  5.4962e-02, -6.2042e-02,  ..., -4.0771e-02,
          -1.2293e-03, -9.8884e-05],
         [-1.1688e-01, -1.8689e-01,  1.7371e-01,  ..., -3.5453e-04,
          -1.6037e-02, -4.9305e-04],
         ...,
q_fa_1=
tensor([[[-2.3041e-02,  9.9304e-02,  7.9895e-02,  ..., -1.0767e-03,
           8.3447e-04, -3.1734e-04],
         [-7.0923e-02,  5.4962e-02, -6.2042e-02,  ..., -4.0771e-02,
          -1.2283e-03, -1.0097e-04],
         [-1.1688e-01, -1.8689e-01,  1.7371e-01,  ..., -3.5381e-04,
          -1.6068e-02, -4.9305e-04],
         ...,
k_fa_0=
tensor([[[ 5.2124e-02,  2.3026e-02,  2.3071e-02,  ...,  1.9455e-02,
           1.7685e-02,  1.7023e-03],
         [ 1.6296e-01, -5.6885e-01,  1.6479e-02,  ...,  1.0586e+00,
           3.5370e-02,  6.6711e-02],
         [-1.2324e+00,  1.3086e+00, -7.3730e-01,  ..., -3.9948e-02,
          -2.5391e-01,  2.9877e-02],
         ...,
k_fa_1=
tensor([[[ 5.2185e-02,  2.3056e-02,  2.3056e-02,  ...,  1.9455e-02,
           1.7670e-02,  1.6966e-03],
         [ 1.6296e-01, -5.6885e-01,  1.6541e-02,  ...,  1.0586e+00,
           3.5370e-02,  6.6711e-02],
         [-1.2314e+00,  1.3096e+00, -7.3779e-01,  ..., -3.9978e-02,
          -2.5366e-01,  2.9861e-02],
         ...,

Subsequently, with the same input, I generated the embeddings using ESM2's RoPE method:

from transformers.models.esm.modeling_esm import RotaryEmbedding
esm_rot = RotaryEmbedding(dim=self.attention_head_size)
# apply ESM2's RoPE to the same per-sequence slices of q and k
q_esm_0, k_esm_0 = esm_rot(qkv[cu_seqlens[0]: cu_seqlens[1], 0], qkv[cu_seqlens[0]: cu_seqlens[1], 1])
q_esm_1, k_esm_1 = esm_rot(qkv[cu_seqlens[1]: cu_seqlens[2], 0], qkv[cu_seqlens[1]: cu_seqlens[2], 1])

Unfortunately, I observed differences in the outputs:

q_esm_0=
tensor([[[[-2.3041e-02,  9.9304e-02,  7.9895e-02,  ..., -1.0767e-03,
            8.3590e-04, -3.1924e-04],
          [-4.2380e-02, -3.2871e-02, -3.6795e-02,  ..., -4.0770e-02,
           -1.2293e-03, -9.8925e-05],
          [ 2.1181e-01,  1.4430e-01, -1.0543e-01,  ..., -3.5444e-04,
           -1.6036e-02, -4.9315e-04],
          ...,
q_esm_1=
tensor([[[[-2.3041e-02,  9.9304e-02,  7.9895e-02,  ..., -1.0767e-03,
            8.3447e-04, -3.1734e-04],
          [-4.2384e-02, -3.2871e-02, -3.6795e-02,  ..., -4.0770e-02,
           -1.2283e-03, -1.0101e-04],
          [ 2.1181e-01,  1.4430e-01, -1.0543e-01,  ..., -3.5373e-04,
           -1.6067e-02, -4.9315e-04],
          ...,
k_esm_0 =
tensor([[[[ 5.2124e-02,  2.3026e-02,  2.3071e-02,  ...,  1.9455e-02,
            1.7685e-02,  1.7023e-03],
          [ 2.5332e-01, -3.5166e-01, -1.3067e-01,  ...,  1.0586e+00,
            3.5377e-02,  6.6712e-02],
          [ 8.5230e-01,  6.8948e-01,  2.0871e-01,  ..., -3.9941e-02,
           -2.5386e-01,  2.9870e-02],
          ...,

k_esm_1 =
tensor([[[[ 5.2185e-02,  2.3056e-02,  2.3056e-02,  ...,  1.9455e-02,
            1.7670e-02,  1.6966e-03],
          [ 2.5343e-01, -3.5170e-01, -1.3049e-01,  ...,  1.0586e+00,
            3.5377e-02,  6.6712e-02],
          [ 8.5145e-01,  6.8955e-01,  2.0850e-01,  ..., -3.9972e-02,
           -2.5361e-01,  2.9855e-02],
          ...,

The embedded features differ significantly except for the first row, which is unaffected by the cosine and sine table differences between the two methods (at position 0 the rotation is the identity).

I'm wondering whether using the speed-optimized FAESM poses a risk because of this deviation from the original ESM implementation. Thanks again for your work!

@pengzhangzhi
Owner

Thanks for posting this. Flash attention actually unpads the sequences. Say x normally has shape [B, L] for ESM2, but it contains some padding tokens, so in flash attention we get rid of them and work with x_unpad of shape [N,], where N is the total number of valid tokens across the B samples. If you print the shapes from flash attention and ESM, they are different, so make sure you are comparing the right elements.
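
For illustration, here is a rough sketch of that unpadding in plain PyTorch (schematic shapes, not FAESM's exact API): two sequences of lengths 5 and 3, padded to L=5, collapse into a single [N=8, ...] tensor, and the tokens of sample i live at x_unpad[cu_seqlens[i]:cu_seqlens[i+1]]:

import torch

lengths = torch.tensor([5, 3])                        # valid tokens per sample (B = 2)
cu_seqlens = torch.zeros(len(lengths) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(lengths, dim=0)         # tensor([0, 5, 8])

x = torch.randn(2, 5, 16)                             # padded [B, L, D]; sample 1 has 2 pad positions
mask = torch.arange(5)[None, :] < lengths[:, None]    # [B, L] mask of valid positions
x_unpad = x[mask]                                     # [N, D] with N = 8 valid tokens

# tokens of sample i sit at x_unpad[cu_seqlens[i] : cu_seqlens[i + 1]]
assert x_unpad.shape[0] == cu_seqlens[-1]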

Also, you want to compute the error (the actual difference between the two outputs) instead of staring at the printed numbers.
In my benchmarking, the errors are acceptable.

@Setsuna111
Author

Thanks for your response! Actually, I did compare the exact elements, because I indexed the per-sequence embeddings from the unpadded q and k generated by the FAESM scripts. The numerical differences I am focusing on concern the theoretical compatibility of reusing ESM2's parameters, which were trained with a different RoPE method.

For my task of fine-tuning ESM2's parameters, I prefer to use the original RoPE method from ESM2 to stay faithful to the pretrained parameters, even though this approach may slow down the embedding process. Nevertheless, I will compare the two embedding methods and look forward to sharing the results with you.

Anyway, FAESM is excellent work that contributes to protein understanding and the AI4Bio community. Thanks!

@pengzhangzhi
Owner

Yeah, you want to show the max abs diff to see the error.
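
Something like this minimal sketch (not FAESM code; it assumes the tensors have first been squeezed/permuted to the same [seq_len, heads, head_dim] layout, and max_abs_diff is just a hypothetical helper):

import torch

def max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    # largest elementwise deviation between two same-shaped tensors
    return (a.float() - b.float()).abs().max().item()

print(f"q diff: {max_abs_diff(q_fa_0, q_esm_0):.3e}")
print(f"k diff: {max_abs_diff(k_fa_0, k_esm_0):.3e}")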

BTW, I trained all my models with FAESM and it works pretty well.
