Skip to content

[LoRA] Implement hot-swapping of LoRA #9453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 55 commits into from
Apr 8, 2025

Conversation

BenjaminBossan
Copy link
Member

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights from one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter.

Generally, hot-swapping may not appear not super useful but when the model is compiled, it is necessary to prevent recompilation. See #9279 for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target exactly the same layers and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: Given that we keep the alpha from the first adapter, the LoRA scaling would be incorrect for the second adapter otherwise.

Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature.

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or there is a separate copy in diffusers to avoid the need for a min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can
offload existing adapters or they can unload them (i.e. delete them).
However, they cannot "hotswap" adapters yet, i.e. substitute the weights
from one LoRA adapter with the weights of another, without the need to
create a separate LoRA adapter.

Generally, hot-swapping may not appear not super useful but when the
model is compiled, it is necessary to prevent recompilation. See huggingface#9279
for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target
exactly the same layers and the "hyper-parameters" of the two adapters
should be identical. For instance, the LoRA alpha has to be the same:
Given that we keep the alpha from the first adapter, the LoRA scaling
would be incorrect for the second adapter otherwise.

Theoretically, we could override the scaling dict with the alpha values
derived from the second adapter's config, but changing the dict will
trigger a guard for recompilation, defeating the main purpose of the
feature.

I also found that compilation flags can have an impact on whether this
works or not. E.g. when passing "reduce-overhead", there will be errors
of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is
problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which
direction to take this. If this PR turns out to be useful, the
hot-swapping functions will be added to PEFT itself and can be imported
here (or there is a separate copy in diffusers to avoid the need for a
min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature,
although we don't necessarily need tests for the hot-swapping
functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other
pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before
putting more time into finalizing it.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for working on this. I left some comments.

@yiyixuxu
Copy link
Collaborator

cc @apolinario
can you take a look at this initial draft?

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Sep 20, 2024

does most lora have same scaling?
just wonder how important (or not important) it is to be able to support hot swap with different scale (without recompile) - maybe more of a question for @apolinario

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Sep 22, 2024

So I played around a little bit, I have two main question:

Do we support hotswap with different lora ranks? the rank config is not checked in the _check_hotswap_configs_compatible step, so it is a bit unclear. However, I would imagine a different rank Lora would most likely trigger recompilation because the weights shapes are different now. If we want to support Lora with different rank, maybe we need to pad the weights to a fixed size.

I think we should also look into supporting hot-swap with different scaling, I checked some popular loras on our hub, I think most of them have different ranks/alphas so this feature will be a lot more impactful if we are able to support different rank & scaling - based on this thread #9279, I understand that the change in the "scaling" dict would trigger a recompilation. But maybe there are ways to avoid it? for example, if scaling value is a tensor, torch.compile will put different guards in it. I played around with this dummy example a little bit

this trigger recompile

import torch
scaling = {}

def fn(x, key):
    return x * scaling[key]


opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)

scaling["first"] = 1.0
opt_fn(x, "first")

print(f" finish first run, updating scaling")

scaling["first"] = 2.0
opt_fn(x, "first")

this won't

import torch
scaling = {}

def fn(x, key):
    return x * scaling[key]


opt_fn = torch.compile(fn, backend="eager")

x = torch.rand(4)

scaling["first"] = torch.tensor(1.0)
opt_fn(x, "first")

print(f" finish first run, updating scaling")

scaling["first"] = torch.tensor(2.0)
opt_fn(x, "first")

I'm very excited about having this in diffusers ! think would be a super nice feature, especially for production use case :)

@sayakpaul
Copy link
Member

I agree with your point on supporting LoRAs with different scaling in this context.

With backend="eager", we may not get the full benefits of torch.compile() I think because parts of the graph would run in eager mode and the benefits of a compiled graph could diminish.

A good way to verify it would be to measure the performance of a pipeline with eager torch.compile() and non-eager torch.compile() 👀

Cc: @anijain2305.

If we want to support Lora with different rank, maybe we need to pad the weights to a fixed size.

I will let @BenjaminBossan comment further but this might require a lot of changes within the tuner modules inside peft.

@BenjaminBossan
Copy link
Member Author

Thanks for all the feedback. I haven't forgotten about this PR, I was just occupied with other things. I'll come back to this as soon as I have a bit of time on my hands.

The idea of using a tensor instead of float for scaling is intriguing, thanks for testing it. It might just work OOTB, as torch broadcasts 0-dim tensors automatically. Another possibility would be to multiply the scaling directly into one of the weights, so that the original alpha can be retained, but that is probably very error prone.

Regarding different ranks, I have yet to test that.

@anijain2305
Copy link

Yes, torch.compile(backend="eager") is just for debugging purposes. In order to see benefits of torch.compile, you will have to use the inductor backend. Simply using torch.compile uses Inductor backend. If your model is overhead-bound, you should use torch.compile(mode="reduce-overhead") to use Cudagraphs.

@sayakpaul
Copy link
Member

If different ranks become a problem, then https://huggingface.co/sayakpaul/lower-rank-flux-lora could provide a meaningful direction.

@apolinario
Copy link
Collaborator

Indeed, although avoiding recompilation altogether with different ranks would be even greater for real time swap applications

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Sep 25, 2024

yep can be a nice feature indeed!
but for this PR we should aim to support different ranks without reduced rank as we are targeting for production use cases

@sayakpaul
Copy link
Member

Indeed. For different ranks, things that come to mind:

  1. Scaling value as a torch tensor --> but this was tested with "eager", so practically no torch.compile().
  2. Some form of padding in the parameter space --> I think this could be interesting but I am not sure how much code changes we're talking about here.
  3. Even if we were to pad, what is going to be the maximum length? Should this be requested from the user? I think we cannot know this value beforehand unless a user specifies it. A sensible choice for this value would be the highest rank that a user is expecting in their pool of LoRAs.

@sayakpaul
Copy link
Member

A reverse direction of what I showed in #9453 is also possible (increase the rank of a LoRA):
https://huggingface.co/sayakpaul/flux-lora-resizing#lora-rank-upsampling

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Sep 27, 2024

hi @BenjaminBossan
I tested out padding the weights + using a tensor to store scaling here in this commit c738f14

and they work for the 4 loras I tested (all with different ranks and scaling) - I'm not as familiar with peft and just made enough changes for the purpose of the experiment & provide a reference point, so the code is very hacky there. sorry for that!

to test ,

# testing hotswap PR

# TORCH_LOGS="guards,recompiles" TORCH_COMPILE_DEBUG=1 TORCH_LOGS_OUT=traces.txt python yiyi_test_3.py
from diffusers import DiffusionPipeline
import torch
import time

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

branch = "test-hotswap"

loras = [
    "Norod78/sd15-megaphone-lora", # rank 16, scaling 0.5
    "artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5", # rank 64, scaling 1.0
    "Norod78/SD15-Rubber-Duck-LoRA", # rank 16, scaling 0.5
    "wooyvern/sd-1.5-dark-fantasy-1.1", # rank 128, scaling 1.0
]

prompts =[
    "Marge Simpson holding a megaphone in her hand with her town in the background",
    "A lion, minimalist, Coloring Book, ColoringBookAF",
    "The girl with a pearl earring Rubber duck",
    "<lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art",    
]

def print_rank_scaling(pipe):
    print(f" rank: {pipe.unet.peft_config['default_0'].r}")
    print(f" scaling: {pipe.unet.down_blocks[0].attentions[0].proj_in.scaling}")


# pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

for i, (lora_repo, prompt) in enumerate(zip(loras, prompts)):
    hotswap = False if i == 0 else True
    print(f"\nProcessing LoRA {i}: {lora_repo}")
    print(f" prompt: {prompt}")
    print(f" hotswap: {hotswap}")
    
    # Start timing for the entire iteration
    start_time = time.time()

    # Load LoRA weights
    pipe.load_lora_weights(lora_repo, hotswap=hotswap, adapter_name = "default_0")
    print_rank_scaling(pipe)

    # Time image generation
    generator = torch.Generator(device="cuda").manual_seed(42)
    generate_start_time = time.time()
    image = pipe(prompt, num_inference_steps=50, generator=generator).images[0]
    generate_time = time.time() - generate_start_time

    # Save the image
    image.save(f"yiyi_test_3_out_{branch}_lora{i}.png")

    # Unload LoRA weights
    pipe.unload_lora_weights()

    # Calculate and print total time for this iteration
    total_time = time.time() - start_time
    
    print(f"Image generation time: {generate_time:.2f} seconds")
    print(f"Total time for LoRA {i}: {total_time:.2f} seconds")

mem_bytes = torch.cuda.max_memory_allocated()
print(f"total Memory: {mem_bytes/(1024*1024):.3f} MB")

output

Loading pipeline components...:  57%|████████████████████████████████████████████████████████████████                                                | 4/7 [00:00<00:00,  4.45it/s]/home/yiyi/diffusers/.venv/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:01<00:00,  5.47it/s]
PyTorch version: 2.4.1+cu121
CUDA available: True
Device: cuda

Processing LoRA 0: Norod78/sd15-megaphone-lora
 prompt: Marge Simpson holding a megaphone in her hand with her town in the background
 hotswap: False
 rank: 16
 scaling: {'default_0': tensor(0.5000, device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:36<00:00,  1.94s/it]
Image generation time: 99.18 seconds
Total time for LoRA 0: 105.29 seconds

Processing LoRA 1: artificialguybr/coloringbook-redmond-1-5v-coloring-book-lora-for-liberteredmond-sd-1-5
 prompt: A lion, minimalist, Coloring Book, ColoringBookAF
 hotswap: True
 rank: 64
 scaling: {'default_0': tensor(1., device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 44.67it/s]
Image generation time: 2.13 seconds
Total time for LoRA 1: 3.09 seconds

Processing LoRA 2: Norod78/SD15-Rubber-Duck-LoRA
 prompt: The girl with a pearl earring Rubber duck
 hotswap: True
 rank: 16
 scaling: {'default_0': tensor(0.5000, device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 44.73it/s]
Image generation time: 1.23 seconds
Total time for LoRA 2: 1.85 seconds

Processing LoRA 3: wooyvern/sd-1.5-dark-fantasy-1.1
 prompt: <lora:fantasyV1.1:1>, a painting of a skeleton with a long cloak and a group of skeletons in a forest with a crescent moon in the background, David Wojnarowicz, dark art, a screenprint, psychedelic art
 hotswap: True
 rank: 128
 scaling: {'default_0': tensor(1., device='cuda:0')}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 44.53it/s]
Image generation time: 2.01 seconds
Total time for LoRA 3: 3.46 seconds
total Memory: 3417.621 MB

confirm outputs are same as in main

yiyi_test_3_out_test-hotswap_lora0

yiyi_test_3_out_test-hotswap_lora1

yiyi_test_3_out_test-hotswap_lora2

yiyi_test_3_out_test-hotswap_lora3

@sayakpaul
Copy link
Member

Very cool!

Could you also try logging the traces just to confirm it does not trigger any recompilation?

TORCH_LOGS="guards,recompiles" TORCH_LOGS_OUT=traces.txt python my_code.py

@yiyixuxu
Copy link
Collaborator

I did and it doesn't

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Sep 27, 2024

also, I think, from the user experience perspective, it might be more convenient to have a "hotswap" mode that, once it's on, everything will be hot-swapped by default. I think, it is not something you use on and off, no?

maybe be a question for @apolinario

@sayakpaul
Copy link
Member

also, I think, from the user experience perspective, it might be more convenient to have a "hotswap" mode that, once it's on, everything will be hot-swapped by default. I think, it is not something you use on and off, no?

I think that is the case, yes! I also agree that the ability to hot-swap LoRAs (with torch.compile()) is a far better and more appealing UX.

But just in case it becomes a memory problem, users can explore the LoRA resizing path to have everything to a small unified rank (if it doesn't lead too much quality degradation).

@sayakpaul
Copy link
Member

Replied.

@BenjaminBossan
Copy link
Member Author

Thanks, now the tests are passing again!

I think as a consequence of that PR, there is now an error about requiring make fix-copies for Amused, CogVideoX, etc. mixins. I suppose this should be safe enough to apply but would like to confirm this first.

Should I run make fix-copies?

@sayakpaul
Copy link
Member

Should I run make fix-copies?

I think this should be fine. Ccing @stevhliu to decide where this hotswap feature should best reside.

@stevhliu
Copy link
Member

Nice! Let's add the hotswap feature in this LoRA section

@sayakpaul
Copy link
Member

@yiyixuxu could you review this once more? The failing tests are unrelated.

@BenjaminBossan
Copy link
Member Author

@stevhliu I added a section as you suggested.

Copy link
Member

@stevhliu stevhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nice, thank you @BenjaminBossan!

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
jonluca added a commit to weights-ai/diffusers that referenced this pull request Mar 20, 2025
* [WIP][LoRA] Implement hot-swapping of LoRA

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can
offload existing adapters or they can unload them (i.e. delete them).
However, they cannot "hotswap" adapters yet, i.e. substitute the weights
from one LoRA adapter with the weights of another, without the need to
create a separate LoRA adapter.

Generally, hot-swapping may not appear not super useful but when the
model is compiled, it is necessary to prevent recompilation. See huggingface#9279
for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target
exactly the same layers and the "hyper-parameters" of the two adapters
should be identical. For instance, the LoRA alpha has to be the same:
Given that we keep the alpha from the first adapter, the LoRA scaling
would be incorrect for the second adapter otherwise.

Theoretically, we could override the scaling dict with the alpha values
derived from the second adapter's config, but changing the dict will
trigger a guard for recompilation, defeating the main purpose of the
feature.

I also found that compilation flags can have an impact on whether this
works or not. E.g. when passing "reduce-overhead", there will be errors
of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is
problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which
direction to take this. If this PR turns out to be useful, the
hot-swapping functions will be added to PEFT itself and can be imported
here (or there is a separate copy in diffusers to avoid the need for a
min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature,
although we don't necessarily need tests for the hot-swapping
functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other
pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before
putting more time into finalizing it.

* Reviewer feedback

* Reviewer feedback, adjust test

* Fix, doc

* Make fix

* Fix for possible g++ error

* Add test for recompilation w/o hotswapping

* Make hotswap work

Requires huggingface/peft#2366

More changes to make hotswapping work. Together with the mentioned PEFT
PR, the tests pass for me locally.

List of changes:

- docstring for hotswap
- remove code copied from PEFT, import from PEFT now
- adjustments to PeftAdapterMixin.load_lora_adapter (unfortunately, some
  state dict renaming was necessary, LMK if there is a better solution)
- adjustments to UNet2DConditionLoadersMixin._process_lora: LMK if this
  is even necessary or not, I'm unsure what the overall relationship is
  between this and PeftAdapterMixin.load_lora_adapter
- also in UNet2DConditionLoadersMixin._process_lora, I saw that there is
  no LoRA unloading when loading the adapter fails, so I added it
  there (in line with what happens in PeftAdapterMixin.load_lora_adapter)
- rewritten tests to avoid shelling out, make the test more precise by
  making sure that the outputs align, parametrize it
- also checked the pipeline code mentioned in this comment:
  huggingface#9453 (comment);
  when running this inside the with
  torch._dynamo.config.patch(error_on_recompile=True) context, there is
  no error, so I think hotswapping is now working with pipelines.

* Address reviewer feedback:

- Revert deprecated method
- Fix PEFT doc link to main
- Don't use private function
- Clarify magic numbers
- Add pipeline test

Moreover:
- Extend docstrings
- Extend existing test for outputs != 0
- Extend existing test for wrong adapter name

* Change order of test decorators

parameterized.expand seems to ignore skip decorators if added in last
place (i.e. innermost decorator).

* Split model and pipeline tests

Also increase test coverage by also targeting conv2d layers (support of
which was added recently on the PEFT PR).

* Reviewer feedback: Move decorator to test classes

... instead of having them on each test method.

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

* Reviewer feedback: version check, TODO comment

* Add enable_lora_hotswap method

* Reviewer feedback: check _lora_loadable_modules

* Revert changes in unet.py

* Add possibility to ignore enabled at wrong time

* Fix docstrings

* Log possible PEFT error, test

* Raise helpful error if hotswap not supported

I.e. for the text encoder

* Formatting

* More linter

* More ruff

* Doc-builder complaint

* Update docstring:

- mention no text encoder support yet
- make it clear that LoRA is meant
- mention that same adapter name should be passed

* Fix error in docstring

* Update more methods with hotswap argument

- SDXL
- SD3
- Flux

No changes were made to load_lora_into_transformer.

* Add hotswap argument to load_lora_into_transformer

For SD3 and Flux. Use shorter docstring for brevity.

* Extend docstrings

* Add version guards to tests

* Formatting

* Fix LoRA loading call to add prefix=None

See:
huggingface#10187 (comment)

* Run make fix-copies

* Add hot swap documentation to the docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Benjamin Bossan <benjamin.bossan@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@BenjaminBossan
Copy link
Member Author

Anything else from my side that needs doing? :)

Before merging, I'd suggest running the full test suite to ensure nothing was broken (I'm not sure how much is covered by the standard CI being run on PRs).

@yiyixuxu yiyixuxu removed the wip label Mar 31, 2025
@jonluca
Copy link

jonluca commented Apr 7, 2025

Is this ready to merge?

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Apr 7, 2025

@sayakpaul are we waiting for a transformer PR to be merged first?

@sayakpaul
Copy link
Member

@BenjaminBossan could you resolve the conflicts and rebase with main? I will then run our LoRA tests.

@yiyixuxu I think this can be merged without that PR. I had requested for your one final review here: #9453 (comment) before we merge. So, that would be great.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Apr 8, 2025

@sayakpaul nice, looks good to me!!

@sayakpaul
Copy link
Member

Thanks @yiyixuxu! Will work with Benjamin to do the final merge :)

@BenjaminBossan
Copy link
Member Author

@sayakpaul The PR is updated, the failing tests appear to be unrelated. As mentioned earlier, if there is a way to run the full test suite with GPUs, it would be a good idea to do that too.

Regarding transformers support for hot-swapping, there is currently no work on that but I added it to the PEFT backlog.

@sayakpaul
Copy link
Member

The failing test is unrelated. Running the LoRA GPU tests now.

@sayakpaul
Copy link
Member

Tests are passing including the integration tests. Merging. Thanks a lot for this contribution!

@sayakpaul sayakpaul merged commit fb54499 into huggingface:main Apr 8, 2025
28 of 29 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Diffusers Roadmap 0.34 Apr 8, 2025
@BenjaminBossan BenjaminBossan deleted the lora-hot-swapping branch April 8, 2025 11:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.