
State Dict Key Mismatch When Loading Pretrained ResNet Weights into FLAIRModel #8

Open
kannyjyk opened this issue Jan 18, 2025 · 1 comment


@kannyjyk

I encountered a RuntimeError when attempting to load pretrained ResNet weights into the FLAIRModel. The error is caused by a mismatch between the keys in the state_dict from the weights file and the keys in the model's state_dict. Specifically, the following error is raised:

RuntimeError: Error(s) in loading state_dict for FLAIRModel:
    Unexpected key(s) in state_dict: "text_model.model.embeddings.position_ids".

Steps to Reproduce

  • Download pretrained ResNet weights (e.g., IMAGENET1K_V1).
  • Use the weights file to initialize the FLAIRModel.
  • Call load_from_pretrained() with the downloaded weights.
  • Observe the RuntimeError.

Expected Behavior

The pretrained weights should load into the model without key mismatches.

Actual Behavior

The loading process fails because the key "text_model.model.embeddings.position_ids" is present in the weights' state_dict but not in the model's state_dict.

Questions

  • Is this mismatch expected, and if so, what is the recommended way to handle such keys?
  • Should I manually filter out mismatched keys from the state_dict, or is there an existing utility in PyTorch or FLAIR to handle this?
  • Is there a specific version of the ResNet weights that is compatible with the FLAIRModel?
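One way to handle the mismatched keys manually is to drop every checkpoint entry the model does not expect before loading. The snippet below is a minimal plain-Python sketch, not FLAIR or PyTorch code; `filter_state_dict` and the `vision_model...` key name are hypothetical and only illustrate the idea.

```python
from typing import Any, Dict, Iterable, List, Mapping, Tuple


def filter_state_dict(
    state_dict: Mapping[str, Any], model_keys: Iterable[str]
) -> Tuple[Dict[str, Any], List[str]]:
    """Hypothetical helper: keep only the checkpoint entries the model expects.

    Returns the filtered dict (in the checkpoint's original order) and the
    sorted list of dropped keys, so they can be logged instead of silently lost.
    """
    expected = set(model_keys)
    filtered = {k: v for k, v in state_dict.items() if k in expected}
    dropped = sorted(k for k in state_dict if k not in expected)
    return filtered, dropped


# Illustrative data only: the values stand in for tensors, and the
# "vision_model..." key name is made up for this example.
checkpoint = {
    "vision_model.model.conv1.weight": "w0",
    "text_model.model.embeddings.position_ids": "ids",
}
model_keys = ["vision_model.model.conv1.weight"]

filtered, dropped = filter_state_dict(checkpoint, model_keys)
print(dropped)  # ['text_model.model.embeddings.position_ids']
```

With a real model, `model_keys` would typically be `model.state_dict().keys()`, and the filtered dict would then be passed to `model.load_state_dict(...)`. Filtering only removes unexpected keys; keys the model expects but the checkpoint lacks (such as an entire text encoder) would still make a strict load fail.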

Any guidance or suggestions would be greatly appreciated. Thank you!

@jusiro
Owner

jusiro commented Feb 3, 2025

Hi @kannyjyk,

ImageNet weights do not contain any parameters for the text encoder, which is why the state dict key mismatch is raised. This mismatch is therefore expected. If you want to load model weights that do not cover every key in the target model's state dict, you can use: model.load_state_dict(state_dict, strict=False).

However, if you want to use FLAIR with ImageNet pre-trained weights, you should do so carefully: the projection layers will have random weights, and the text encoder won't be pre-trained on fundus data. Still, if you want to evaluate the transferability of ImageNet weights within the same framework, you can check the pipeline in main_transferability.py by setting the flags: --init_imagenet True --load_weights False --project_features False --norm_features False.
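To see what `strict=False` does in practice, here is a minimal, self-contained PyTorch sketch. It is not FLAIR code: `TinyModel`, its `vision`/`text` submodules, and the `extra.position_ids` key are hypothetical stand-ins for a two-branch model and a checkpoint that only covers the vision side.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model with a vision and a text branch."""

    def __init__(self) -> None:
        super().__init__()
        self.vision = nn.Linear(4, 2)
        self.text = nn.Linear(3, 2)


model = TinyModel()

# A checkpoint that covers only the vision branch and carries one extra key,
# loosely mimicking image-only weights that know nothing about a text encoder.
checkpoint = {
    "vision.weight": torch.zeros(2, 4),
    "vision.bias": torch.zeros(2),
    "extra.position_ids": torch.arange(5),
}

# strict=False loads every matching key and reports the rest instead of raising.
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)        # ['text.weight', 'text.bias']
print("unexpected:", result.unexpected_keys)  # ['extra.position_ids']
```

Checking `missing_keys` after such a load is a quick way to confirm which parts of the model keep their random initialization, which is exactly the caveat about the projection layers and the text encoder.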

Kind regards.
