Before submitting a bug, please make sure the issue hasn't already been addressed by searching through the FAQs and existing/past issues.
Describe the bug
I have one set of weights, one tokenizer, the same prompt, and identical generation parameters. Yet when I load the model with AutoModelForCausalLM I get one output, and when I construct it manually with LlamaForCausalLM plus the same config and state_dict I get a noticeably different one.
The code below reproduces the logit difference on both an A6000 and an A100.
Minimal reproducible example
```python
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LlamaForCausalLM,
    LlamaConfig,
)

# 1) Adjust these as needed
model_name = "meta-llama/Llama-3.1-8B"
prompt = "Hello from Llama 3.1! Tell me something interesting."
dtype = torch.float16  # or torch.float32 if needed

# 2) Get the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Prepare input
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

############################################
# A) Load with AutoModelForCausalLM
############################################
print("=== Loading with AutoModelForCausalLM ===")
model_auto = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="eager",  # matches your usage
    torch_dtype=dtype,
).cuda()
model_auto.eval()  # turn off dropout
config = model_auto.config

with torch.no_grad():
    out_auto = model_auto(**inputs)
logits_auto = out_auto.logits  # shape: [batch_size, seq_len, vocab_size]

del model_auto
torch.cuda.empty_cache()

############################################
# B) Load with LlamaForCausalLM + config
############################################
print("=== Loading with LlamaForCausalLM + config ===")
# Get config from the same checkpoint
# Build Llama model directly
model_llama = LlamaForCausalLM(config).cuda()
model_llama.eval()

# Load the same weights that AutoModelForCausalLM used
model_auto_temp = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
model_llama.load_state_dict(model_auto_temp.state_dict())
del model_auto_temp
torch.cuda.empty_cache()

with torch.no_grad():
    out_llama = model_llama(**inputs)
logits_llama = out_llama.logits

############################################
# C) Compare the logits
############################################
# Compute maximum absolute difference
max_diff = (logits_auto - logits_llama).abs().max()
print(f"\nMax absolute difference between logits: {max_diff.item()}")
if max_diff < 1e-7:
    print("→ The logits are effectively identical (within floating-point precision).")
else:
    print("→ There is a non-trivial difference in logits!")
```
Output
```
Max absolute difference between logits: 0.11245954036712646
→ There is a non-trivial difference in logits!
```
Runtime Environment
Model: meta-llama/Llama-3.1-8B
Using via huggingface?: yes
OS: Linux
GPU VRAM: 40GB
Number of GPUs: 1
GPU Make: Nvidia
Additional context
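One thing I have not verified yet, so treat it as an assumption: in path B, `LlamaForCausalLM(config)` appears to initialize its parameters in torch's default dtype (float32), and `load_state_dict()` then copies the fp16 values into those fp32 parameters, so the two paths may not actually run at the same precision. A minimal sketch of the check I have in mind (model name and prints are illustrative only):

```python
import torch
from transformers import AutoModelForCausalLM, LlamaForCausalLM

model_name = "meta-llama/Llama-3.1-8B"

# Path A: from_pretrained with an explicit torch_dtype
model_auto = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
print("A:", next(model_auto.parameters()).dtype)   # expected: torch.float16

# Path B: build from the same config, then copy the state_dict in
model_llama = LlamaForCausalLM(model_auto.config)  # initialized in the default dtype
model_llama.load_state_dict(model_auto.state_dict())
print("B:", next(model_llama.parameters()).dtype)  # expected: torch.float32
```

If B does report float32, calling `model_llama.to(dtype)` before the forward pass should be a fair way to re-run the comparison at matched precision.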