
HiDream LoRA support #11383


Closed
vladmandic opened this issue Apr 22, 2025 · 18 comments · Fixed by #11532
Labels: bug (Something isn't working), enhancement (New feature or request)

@vladmandic
Contributor

Describe the bug

I've tried the LoRAs currently available on CivitAI and most appear to load without any errors using load_lora_weights.
However, they later do not appear in get_list_adapters or get_active_adapters and do not seem to be applied to the model.
If the LoRAs were not loaded, I'd expect to see some error.
Some return the error Invalid LoRA checkpoint, but most do not return any error at all.

list of currently available loras on civitai: https://civitai.com/search/models?baseModel=HiDream&modelType=LORA&sortBy=models_v9

Reproduction

N/A
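
A sketch of a minimal reproduction, assembled from the snippets later in this thread (the local .safetensors path is a placeholder for a LoRA downloaded from CivitAI):

import torch
from diffusers import DiffusionPipeline
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast

# Pipeline setup as in the snippets later in this thread.
tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)
pipe = DiffusionPipeline.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Load a LoRA downloaded from CivitAI (placeholder path) and check whether it registered.
pipe.load_lora_weights("/path/to/civitai_lora.safetensors", adapter_name="test")  # returns without error
print(pipe.get_list_adapters())    # reported empty
print(pipe.get_active_adapters())  # reported empty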

Logs

System Info

diffusers==e30d3bf5442fbdbee899e8a5da0b11b621d54f1b

Who can help?

@linoytsaban @sayakpaul

@vladmandic vladmandic added the bug Something isn't working label Apr 22, 2025
@linoytsaban linoytsaban added the enhancement New feature or request label Apr 22, 2025
@linoytsaban
Collaborator

Hey @vladmandic, we haven't yet added support for HiDream LoRAs that are not in diffusers format, we're working on it at the moment and should be supported soon :)

@vladmandic
Contributor Author

Which is fine, but why does load_lora_weights return without any errors for 80% of them?
If a LoRA is not loaded and added to the adapter list, that method should not be a silent black hole.

@linoytsaban
Collaborator

Yeah, it should throw an error if the keys are incompatible. Could you please provide a code snippet for a LoRA that neither loads nor errors?

@vladmandic
Contributor Author

vladmandic commented Apr 23, 2025

@linoytsaban almost every LoRA from the link I provided.
But here's a specific example: https://civitai.com/api/download/models/1686777?type=Model&format=SafeTensor
Regarding code:

pipe.load_lora_weights(filename, adapter_name="test") # no error
print(pipe.get_list_adapters()) # empty list
pipe.set_adapters(adapter_names=["test"], adapter_weights=[1.0]) # fails because the LoRA is not loaded

@sayakpaul
Member

What model is it? Full or Dev?

Could you please provide two things?

  • LoRA file on the Hub (just makes it easier for the folks who use server machines to test things). It helps to view the structure of the LoRA file directly on the Hub.
  • A minimal, fully reproducible snippet.

I did the following and it logged:

from diffusers import DiffusionPipeline
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
import torch 

repo_id = "HiDream-ai/HiDream-I1-Full"

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

pipeline = DiffusionPipeline.from_pretrained(
    repo_id, 
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_lora_weights("sayakpaul/different-lora-from-civitai", weight_name="486_hidream.safetensors")

Log:

No LoRA keys associated to HiDreamImageTransformer2DModel found with the prefix='transformer'. This is safe to ignore if LoRA state dict didn't originally have any HiDreamImageTransformer2DModel related params. You can also try specifying `prefix=None` to resolve the warning. Otherwise, open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new
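
A quick way to see why that warning fires is to inspect the raw keys of the LoRA file; a sketch using the same repo and file name as the snippet above:

from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Download the LoRA file and list its raw keys.
path = hf_hub_download("sayakpaul/different-lora-from-civitai", filename="486_hidream.safetensors")
with safe_open(path, framework="pt") as f:
    keys = list(f.keys())

print(keys[:5])
print(any(k.startswith("transformer.") for k in keys))  # no "transformer." keys, hence the warning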

@vladmandic
Contributor Author

I used HiDream-I1-Full.
Regarding uploading to the Hub: I understand your position, but do note that CivitAI is the de facto standard for anything regarding image models, and if you cannot access a LoRA on CivitAI we have a bigger problem.

Your snippet is fine.
BTW, what level is that log message? Because diffusers can be chatty, I disable everything below warning.

@sayakpaul
Member

sayakpaul commented Apr 23, 2025

Regarding uploading to the Hub: I understand your position, but do note that CivitAI is the de facto standard for anything regarding image models, and if you cannot access a LoRA on CivitAI we have a bigger problem.

Not to digress, https://huggingface.co/spaces/sayakpaul/civitai-to-hub could help.

BTW, what level is that log message? Because diffusers can be chatty, I disable everything below warning.

It's emitted via logger.warning(), i.e. at warning level.
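
For reference, the verbosity of diffusers logging can be controlled explicitly; a minimal sketch (the missing-prefix message above is emitted at this warning level):

from diffusers.utils import logging

logging.set_verbosity_warning()  # show warnings such as the missing-prefix message above
# logging.set_verbosity_error()  # would silence warnings and show only errors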

@mukundkhanna123

Hi, I am facing the same issue. Here's a minimal reproducible example:

import torch
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    HiDreamImagePipeline,
    HiDreamImageTransformer2DModel,
)
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM

llama_repo = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_id = "HiDream-ai/HiDream-I1-Full"

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(
    llama_repo,
)
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    llama_repo,
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)

transformer = HiDreamImageTransformer2DModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    subfolder="transformer",
)
pipeline = HiDreamImagePipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    tokenizer_4=tokenizer_4,
    text_encoder_4=text_encoder_4,
    transformer=transformer,
)
pipeline.load_lora_weights(
    "sayakpaul/different-lora-from-civitai",
    weight_name="486_hidream.safetensors",
)

print(pipeline.get_list_adapters())

This fails for me with the following error:

ValueError: `state_dict` should be empty at this point but has state_dict.keys()=dict_keys(['double_stream_blocks.0.block.adaLN_modulation.1.lora_A.weight', 'double_stream_blocks.0.block.adaLN_modulation.1.lora_B.weight', ... single_stream_blocks.9.block.ff_i.shared_experts.w3.lora_A.weight', 'single_stream_blocks.9.block.ff_i.shared_experts.w3.lora_B.weight'])

Additionally, when I load a LoRA that I trained myself (using diffusers), it doesn't throw an error — but:

  • pipeline.get_list_adapters() shows an empty list,
  • and the final generated image is identical to the base model output, suggesting that the LoRA was not actually applied during generation.

Would appreciate any help on this!
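
A sketch of how one might verify whether a LoRA actually attached to the transformer (assuming the pipeline from the snippet above; lora_path is a placeholder for the locally trained weights):

# Load the trained LoRA under an explicit adapter name and check registration.
pipeline.load_lora_weights(lora_path, adapter_name="mine")

print(pipeline.get_list_adapters())    # expected to show the adapter under the transformer component
print(pipeline.get_active_adapters())  # expected to include "mine"

# peft-injected layers carry lora_A / lora_B submodules; counting them is a second sanity check.
num_lora_layers = sum(1 for _, m in pipeline.transformer.named_modules() if hasattr(m, "lora_A"))
print(f"{num_lora_layers} modules carry LoRA weights")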

@sayakpaul
Member

@mukundkhanna123 that is because the LoRA isn't supported yet but will be supported soon. Additionally, if you use the LoRA mentioned in the snippet, it shouldn't error out. Instead, it should lead to the warning mentioned in #11383 (comment). Did you make any changes to the diffusers codebase?

Additionally, when I load a LoRA that I trained myself (using diffusers), it doesn't throw an error — but:

Can you please open a separate issue for that? Cc: @linoytsaban

@sayakpaul sayakpaul self-assigned this Apr 27, 2025
@mukundkhanna123

Okay, I used the latest commit bd96a08 and it didn't error out with the LoRA mentioned in the snippet; it just gave the warning. I am able to load the LoRA I trained as well (trained with diffusers). However, there is no change in the output image.

@sayakpaul
Member

I am able to load the LoRA I trained as well (trained with diffusers). However, there is no change in the output image.

Could you please create a new issue for that and tag me and @linoytsaban? Linoy has been training a bunch of LoRAs with our script and the resultant LoRAs work. Here is an example: https://huggingface.co/linoyts/hidream-3dicon-lora
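
For reference, loading one of the known-good diffusers-format LoRAs mentioned above might look like the following sketch (it assumes the pipeline from the earlier snippets; the prompt is illustrative and the repo's default weight name is assumed):

pipeline.load_lora_weights("linoyts/hidream-3dicon-lora")

image = pipeline(
    "a 3d icon of a robot",  # illustrative prompt; check the LoRA card for its trigger phrase
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]
image.save("hidream_lora_test.png")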

@mukundkhanna123

Hi, thank you for taking the time to reply. I updated the LoRA script to do diffusion training on HiDream using a LoRA. I am seeing that the grad norms for all lora_A weights are 0, which is why the weights are not getting updated. I have tried debugging but am not able to solve this. The LoRA parameters all have requires_grad=True and are passed to the optimizer.

If you could help with this, that would be really helpful! Thank you.

@linoytsaban
Collaborator

Hey @mukundkhanna123, can you share the configuration you used for training?

@mukundkhanna123

I've tried multiple configurations, all with similar results: a learning rate in the 1e-4 to 1e-5 range, weight decay set to 0 as well as the default value, a batch size of 96 with gradient accumulation, no warmup, the AdamW optimizer, the DPO beta set to 2000, and a rank of 64. The base model is HiDream-I1-Full.

@linoytsaban
Collaborator

I've trained on multiple concepts and didn't experience what you're describing, so it's hard to pinpoint the issue.
Could you try the following config and see whether you can replicate the results?

import os
os.environ['MODEL_NAME'] = "HiDream-ai/HiDream-I1-Full"
os.environ['DATASET_NAME'] ="Norod78/Yarn-art-style"
os.environ['OUTPUT_DIR'] = "hidream-yarn-art-lora-v2-trainer"

!accelerate launch train_dreambooth_lora_hidream.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --dataset_name=$DATASET_NAME \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --lora_layers="to_k,to_q,to_v,to_out" \
  --instance_prompt="a dog, yarn art style" \
  --validation_prompt="yoda, yarn art style" \
  --caption_column="text" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --use_8bit_adam \
  --rank=16 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=200 \
  --max_train_steps=1000 \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

trained weights are here

@mukundkhanna123

@linoytsaban I am doing LoRA training with a DPO loss. Attaching a snippet of my training loop:

for epoch in range(first_epoch, args.num_train_epochs):
        transformer.train()

        for step, batch in enumerate(train_dataloader):
            models_to_accumulate = [transformer]
            prompts = batch["prompts"]

            with accelerator.accumulate(models_to_accumulate):
                
                # Convert images to latent space
                if args.offload:
                    vae = vae.to(accelerator.device)
                pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
                feed_pixel_values = torch.cat(pixel_values.chunk(2, dim=1))
                latents = []
                for i in range(0, feed_pixel_values.shape[0], args.vae_encode_batch_size):
                    latents.append(
                        vae.encode(feed_pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()
                    )
                model_input = torch.cat(latents, dim=0)
                model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                model_input = model_input.to(dtype=weight_dtype)
                if args.offload:
                    vae = vae.to("cpu")
                
                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input).chunk(2)[0].repeat(2,1,1,1)
                bsz = model_input.shape[0] // 2
                        
                # Sample a random timestep for each image
                # for weighting schemes where we sample timesteps non-uniformly
                u = compute_density_for_timestep_sampling(
                    weighting_scheme=args.weighting_scheme,
                    batch_size=bsz,
                    logit_mean=args.logit_mean,
                    logit_std=args.logit_std,
                    mode_scale=args.mode_scale,
                )
                indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device).repeat(2)
                # Add noise according to flow matching.
                # zt = (1 - texp) * x + texp * z1
                sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
                
                
                t5_prompt_embeds, llama3_prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
                     prompts, text_encoding_pipeline
                 )
                t5_prompt_embeds = t5_prompt_embeds.repeat(2, 1, 1).to(accelerator.device, dtype=weight_dtype)
                llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 2, 1, 1).to(accelerator.device, dtype=weight_dtype)
                pooled_prompt_embeds = pooled_prompt_embeds.repeat(2, 1).to(accelerator.device, dtype=weight_dtype) 
                # Predict the noise residual
                
                model_pred = transformer(
                    hidden_states=noisy_model_input,
                    encoder_hidden_states_t5=t5_prompt_embeds,
                    encoder_hidden_states_llama3=llama3_prompt_embeds,
                    pooled_embeds=pooled_prompt_embeds,
                    timesteps=timesteps,
                    return_dict=False,
                )[0]
                model_pred = model_pred * -1
                
                # these weighting schemes use a uniform timestep sampling
                # and instead post-weight the loss
                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
                target = noise - model_input
                # Compute loss.
                model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                model_losses = model_losses * weighting.float()
                model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
                
                model_losses_w, model_losses_l = model_losses.chunk(2)
                # For logging
                raw_model_loss = model_losses.mean()
                model_diff = model_losses_w - model_losses_l  # These are both LBS (as is t)
                
                accelerator.unwrap_model(transformer).disable_adapters()

                with torch.no_grad():
                    ref_pred = transformer(
                        hidden_states=noisy_model_input,
                        encoder_hidden_states_t5=t5_prompt_embeds,
                        encoder_hidden_states_llama3=llama3_prompt_embeds,
                        pooled_embeds=pooled_prompt_embeds,
                        timesteps=timesteps,
                        return_dict=False,
                    )[0]
                    ref_pred = ref_pred * -1 
                    
                    ref_losses = F.mse_loss(ref_pred.float(), target.float(), reduction="none")
                    ref_losses = ref_losses * weighting.float()
                    ref_losses = ref_losses.mean(dim=list(range(1, len(ref_losses.shape))))
                    
                    ref_losses_w, ref_losses_l = ref_losses.chunk(2)
                    raw_ref_loss = ref_losses.mean()
                    ref_diff = ref_losses_w - ref_losses_l
                
                # Re-enable adapters.
                accelerator.unwrap_model(transformer).enable_adapters()

                # Final loss.
                logits = ref_diff - model_diff
                if args.loss_type == "logsigmoid":
                    loss = -1 * F.logsigmoid(0.5 * args.beta_dpo * logits).mean()         
                elif args.loss_type == 'sigmoid':
                    loss = F.sigmoid(args.beta_dpo * logits).mean()
                elif args.loss_type == "hinge":
                    loss = torch.relu(1 - args.beta_dpo * logits).mean()
                elif args.loss_type == "ipo":
                    losses = (logits - 1 / (2 * args.beta)) ** 2
                    loss = losses.mean()
                else:
                    raise ValueError(f"Unknown loss type {args.loss_type}")
                
                implicit_acc = (logits > 0).sum().float() / logits.size(0)
                implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)
                
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = transformer.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad() 
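
One way to narrow down the zero-gradient observation would be to log per-parameter gradient norms between accelerator.backward(loss) and optimizer.step(); a sketch reusing the names from the loop above:

# Debug sketch: report which lora_A parameters received zero (or no) gradient.
# Place between accelerator.backward(loss) and optimizer.step() in the loop above.
unwrapped = accelerator.unwrap_model(transformer)
zero_grads = []
for name, param in unwrapped.named_parameters():
    if "lora_A" in name:
        if param.grad is None:
            zero_grads.append((name, "no grad"))
        elif param.grad.detach().norm().item() == 0.0:
            zero_grads.append((name, "zero grad"))
print(f"{len(zero_grads)} lora_A params without a useful gradient")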

@sayakpaul
Member

sayakpaul commented Apr 28, 2025

Then we're digressing. We cannot really control code we haven't written, so we'd request that you open a "Discussion" instead. It would be helpful for everyone if you respected that.

@sayakpaul
Member

@vladmandic sorry for the spam here on this issue. I am at ICLR, hence there's no update on this issue yet. I will get to it soon and update.
