[Single File] Add single file loading for SANA Transformer #10947

Merged 15 commits on Mar 10, 2025
5 changes: 5 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -37,6 +37,7 @@
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_sana_transformer_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
@@ -119,6 +120,10 @@
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
"default_subfolder": "transformer",
},
"SanaTransformer2DModel": {
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
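As a rough, illustrative sketch of how a registry entry like the one added above is consumed (the `loadable_classes` dict and `remap_state_dict` helper below are stand-ins, not diffusers internals; only the `checkpoint_mapping_fn`/`default_subfolder` keys and the converter function come from this diff):

```python
from diffusers.loaders.single_file_utils import convert_sana_transformer_to_diffusers

# Stand-in registry mirroring the entry added above: a model class name maps to
# the function that rewrites original checkpoint keys into diffusers keys.
loadable_classes = {
    "SanaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
}


def remap_state_dict(class_name: str, checkpoint: dict) -> dict:
    entry = loadable_classes[class_name]
    return entry["checkpoint_mapping_fn"](checkpoint)
```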
115 changes: 115 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
@@ -117,6 +117,12 @@
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"sana": [
"blocks.0.cross_attn.q_linear.weight",
"blocks.0.cross_attn.q_linear.bias",
"blocks.0.cross_attn.kv_linear.weight",
"blocks.0.cross_attn.kv_linear.bias",
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
}
@@ -178,6 +184,7 @@
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
"sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
@@ -669,6 +676,9 @@ def infer_diffusers_model_type(checkpoint):
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
model_type = "lumina2"

elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
model_type = "sana"

elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
if "model.diffusion_model.patch_embedding.weight" in checkpoint:
target_key = "model.diffusion_model.patch_embedding.weight"
@@ -2897,6 +2907,111 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
return converted_state_dict


def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401

# Positional and patch embeddings.
checkpoint.pop("pos_embed")
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")

# Timestep embeddings.
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")

# Caption Projection.
checkpoint.pop("y_embedder.y_embedding")
converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")

for i in range(num_layers):
converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
f"blocks.{i}.scale_shift_table"
)

# Self-Attention
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])

# Output Projections
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
f"blocks.{i}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
f"blocks.{i}.attn.proj.bias"
)

# Cross-Attention
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
f"blocks.{i}.cross_attn.q_linear.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
f"blocks.{i}.cross_attn.q_linear.bias"
)

linear_sample_k, linear_sample_v = torch.chunk(
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
)
linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias

# Output Projections
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
f"blocks.{i}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
f"blocks.{i}.cross_attn.proj.bias"
)

# MLP
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
f"blocks.{i}.mlp.inverted_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
f"blocks.{i}.mlp.inverted_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
f"blocks.{i}.mlp.depth_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
f"blocks.{i}.mlp.depth_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
f"blocks.{i}.mlp.point_conv.conv.weight"
)

# Final layer
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")

return converted_state_dict


def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}

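To illustrate the tensor surgery `convert_sana_transformer_to_diffusers` performs per block, here is a tiny self-contained sketch (the sizes are illustrative, not SANA's real dimensions): the original checkpoint fuses self-attention Q/K/V into one `qkv` projection and cross-attention K/V into one `kv_linear` projection, and the converter chunks them into the separate `to_q`/`to_k`/`to_v` tensors diffusers expects.

```python
import torch

dim, cross_dim = 8, 8  # tiny illustrative sizes

# Self-attention: one fused (3*dim, dim) weight is split into separate q, k, v weights.
qkv_weight = torch.randn(3 * dim, dim)
to_q, to_k, to_v = torch.chunk(qkv_weight, 3, dim=0)
assert to_q.shape == to_k.shape == to_v.shape == (dim, dim)

# Cross-attention: one fused (2*dim, cross_dim) weight/bias pair is split into k and v.
kv_weight = torch.randn(2 * dim, cross_dim)
kv_bias = torch.randn(2 * dim)
to_k_cross, to_v_cross = torch.chunk(kv_weight, 2, dim=0)
to_k_bias, to_v_bias = torch.chunk(kv_bias, 2, dim=0)
assert to_k_cross.shape == (dim, cross_dim) and to_k_bias.shape == (dim,)
```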
4 changes: 2 additions & 2 deletions src/diffusers/models/transformers/sana_transformer.py
@@ -18,7 +18,7 @@
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import (
Attention,
@@ -195,7 +195,7 @@ def forward(
return hidden_states


class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.

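With `FromOriginalModelMixin` in the class bases, the transformer gains `from_single_file`, so the original `.pth` checkpoint can be loaded directly. A minimal usage sketch matching the new test below (wiring the loaded module into a pipeline afterwards, e.g. via a `transformer=` override, is left out here):

```python
from diffusers import SanaTransformer2DModel

ckpt_path = (
    "https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
)
transformer = SanaTransformer2DModel.from_single_file(ckpt_path)
```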
61 changes: 61 additions & 0 deletions tests/single_file/test_sana_transformer.py
@@ -0,0 +1,61 @@
import gc
import unittest

import torch

from diffusers import (
SanaTransformer2DModel,
)
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
torch_device,
)


enable_full_determinism()


@require_torch_accelerator
class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
model_class = SanaTransformer2DModel
ckpt_path = (
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
)
alternate_keys_ckpt_paths = [
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
]

repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"

def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)

def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)

def test_single_file_components(self):
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
model_single_file = self.model_class.from_single_file(self.ckpt_path)

PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert (
model.config[param_name] == param_value
), f"{param_name} differs between single file loading and pretrained loading"

def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
torch.cuda.empty_cache()
model = self.model_class.from_single_file(ckpt_path)

del model
gc.collect()
torch.cuda.empty_cache()
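These tests require an accelerator (`@require_torch_accelerator`) and download both the original checkpoint and the diffusers-format repo; assuming such an environment, they can be run with, for example, `pytest tests/single_file/test_sana_transformer.py`.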