Skip to content

feat: pipeline-level quantization config #11130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 39 commits into from
May 9, 2025
Merged

Conversation

sayakpaul
Copy link
Member

@sayakpaul sayakpaul commented Mar 21, 2025

What does this PR do?

See: #10327

TL;DR: This PR adds support to apply a quantization config when doing DiffusionPipeline.from_pretrained(...), thereby making it easier for the users to benefit from quantization.

Why

To apply quantization to a DiffusionPipeline, a user has to first initialize the models they want to quantize with desired quantization_configs:

quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)

text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

pipe = DiffusionPipeline.from_pretrained(
    ..., transformer=transformer_8bit, text_encoder_2=text_encoder_2_8bit
)

This is cumbersome.

What

@SunMarc and I worked on this PR to show the kind of simple changes we need to enable a user to pass a quantization config directly whilst doing DiffusionPipeline.from_pretrained(..., quantization_config=...). The user experience now becomes:

pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["text_encoder_2", "transformer"]
)
pipe = DiffusionPipeline.from_pretrained(
    ...,
    quantization_config=pipeline_quant_config
)

Users can specify granular level quantization mapping too:

quant_mapping = {
    "transformer": DiffBitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
     ),
    "text_encoder_2": TranBitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
     ),
}
pipeline_quant_config = PipelineQuantizationConfig(quant_mapping=quant_mapping)

This is particularly helpful when using different quantization backends for different modules (below we show a combination of Quanto and BitsAndBytes):

quant_mapping = {
    "transformer": QuantoConfig(weights_dtype="float8"),
    "text_encoder_2": TranBitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
     ),
}
pipeline_quant_config = PipelineQuantizationConfig(quant_mapping=quant_mapping)

Here's a script that might be helpful for others to test:

code
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import DiffusionPipeline
import argparse
import torch


def get_global_config():
    quant_config = PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
        components_to_quantize=["transformer", "text_encoder_2"],
    )
    return quant_config


def get_granular_config(use_quanto=False):
    from diffusers import BitsAndBytesConfig as DiffBitsAndBytesConfig, QuantoConfig
    from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig

    transformer_config = (
        QuantoConfig(weights_dtype="float8")
        if use_quanto
        else DiffBitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
        )
    )

    quant_mapping = {
        "transformer": transformer_config,
        "text_encoder_2": TranBitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
        ),
    }
    quant_config = PipelineQuantizationConfig(quant_mapping=quant_mapping)
    return quant_config


def load_pipeline(quant_config):
    pipe = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", quantization_config=quant_config, torch_dtype=torch.bfloat16
    ).to("cuda")
    return pipe


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_global_config", action="store_true")
    parser.add_argument("--use_quanto", action="store_true")
    args = parser.parse_args()

    quant_config = get_global_config() if args.use_global_config else get_granular_config(args.use_quanto)
    pipe = load_pipeline(quant_config)

    pipe_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "guidance_scale": 3.5,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
    }

    image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
    image.save(f"quant_global@{args.use_global_config}_quanto@{args.use_quanto}.png")

Cc: @asomoza if you want to test this out :)

TODOs

  • Docs
  • Tests
  • Enable testing in CI

sayakpaul and others added 5 commits March 20, 2025 20:22
Co-authored-by: SunMarc <marc.sun@hotmail.fr>

condition better.

support mapping.

improvements.

[Quantization] Add Quanto backend (#10756)

* update

* updaet

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/quanto.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/quanto/utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

[Single File] Add single file loading for SANA Transformer (#10947)

* added support for from_single_file

* added diffusers mapping script

* added testcase

* bug fix

* updated tests

* corrected code quality

* corrected code quality

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)

* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------

Co-authored-by: hlky <hlky@hlky.ac>

[LoRA] CogView4 (#10981)

* update

* make fix-copies

* update

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)

* memory usage tests

* fixes

* gguf

[`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)

* Add initial template

* Second template

* feat: Add TextEmbeddingModule to AnyTextPipeline

* feat: Add AuxiliaryLatentModule template to AnyTextPipeline

* Add bert tokenizer from the anytext repo for now

* feat: Update AnyTextPipeline's modify_prompt method

This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.

* Fill in the `forward` pass of `AuxiliaryLatentModule`

* `make style && make quality`

* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`

* Update error handling to raise and logging

* Add `create_glyph_lines` function into `TextEmbeddingModule`

* make style

* Up

* Up

* Up

* Up

* Remove several comments

* refactor: Remove ControlNetConditioningEmbedding and update code accordingly

* Up

* Up

* up

* refactor: Update AnyTextPipeline to include new optional parameters

* up

* feat: Add OCR model and its components

* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies

* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task

* `make style`

* refactor: Update `AnyTextPipeline`'s docstring

* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once

* simplify

* `make style`

* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function

* Simplify for now

* `make style`

* Up

* feat: Add scripts to convert AnyText controlnet to diffusers

* `make style`

* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`

* make style

* Up

* Simplify

* Up

* feat: Add safetensors module for loading model file

* Fix device issues

* Up

* Up

* refactor: Simplify

* refactor: Simplify code for loading models and handling data types

* `make style`

* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule

* refactor: Update dtype in embedding_manager.py to match proj.weight

* Up

* Add attribution and adaptation information to pipeline_anytext.py

* Update usage example

* Will refactor `controlnet_cond_embedding` initialization

* Add `AnyTextControlNetConditioningEmbedding` template

* Refactor organization

* style

* style

* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`

* Follow one-file policy

* style

* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel

* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py

* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py

* Refactor AnyTextControlNet to use configurable conditioning embedding channels

* Complete control net conditioning embedding in AnyTextControlNetModel

* up

* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel

* up

* up

* style

* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline

* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity

* down

* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing

* update pillow

* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity

* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file

* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency

* 🆙

* style

* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py

* style

* Update examples/research_projects/anytext/README.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Remove commented-out image preparation code in AnyTextPipeline

* Remove unnecessary blank line in README.md

[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6  (#11018)

* update

* update

* update

* update

* update

* update

* update

* update

* update

fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings  (#11012)

small fix on generating time_ids & embeddings

[LoRA] support wan i2v loras from the world. (#11025)

* support wan i2v loras from the world.

* remove copied from.

* upates

* add lora.

Fix SD3 IPAdapter feature extractor (#11027)

chore: fix help messages in advanced diffusion examples (#10923)

Fix missing **kwargs in lora_pipeline.py (#11011)

* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

Fix for multi-GPU WAN inference (#10997)

Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>

[Refactor] Clean up import utils boilerplate (#11026)

* update

* update

* update

Use `output_size` in `repeat_interleave` (#11030)

[hybrid inference 🍯🐝] Add VAE encode (#11017)

* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)

* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review

[LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044)

* move to warning.

* test related changes.

Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827)

* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>

making ```formatted_images``` initialization compact (#10801)

compact writing

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)

* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------

Co-authored-by: Kai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

[Tests] restrict memory tests for quanto for certain schemes. (#11052)

* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] feat: support non-diffusers wan t2v loras. (#11059)

feat: support non-diffusers wan t2v loras.

[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training

Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)

reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Fix deterministic issue when getting pipeline dtype and device (#10696)

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[Tests] add requires peft decorator. (#11037)

* add requires peft decorator.

* install peft conditionally.

* conditional deps.

Co-authored-by: DN6 <dhruv.nair@gmail.com>

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

CogView4 Control Block (#10809)

* cogview4 control training

---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>

[CI] pin transformers version for benchmarking. (#11067)

pin transformers version for benchmarking.

updates

Fix Wan I2V Quality (#11087)

* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update pipeline_wan_i2v.py

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

LTX 0.9.5 (#10968)

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

make PR GPU tests conditioned on styling. (#11099)

Group offloading improvements (#11094)

update

Fix pipeline_flux_controlnet.py (#11095)

* Fix pipeline_flux_controlnet.py

* Fix style

update readme instructions. (#11096)

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)

Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP

Fix Group offloading behaviour when using streams (#11097)

* update

* update

Quality options in `export_to_video` (#11090)

* Quality options in `export_to_video`

* make style

improve more.

add placeholders for docstrings.

formatting.

smol fix.

solidify validation and annotation
Co-authored-by: SunMarc <marc@huggingface.co>
@sayakpaul sayakpaul requested review from DN6, yiyixuxu and hlky March 21, 2025 03:07
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
modules_to_quantize: Optional[List[str]] = None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be a reasonable default for this? @SunMarc had some ideas around this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking it could be nice to have a class attribute e.g modules_to_quantize in each pipeline. Or we can just create a mapping pipeline <-> modules_to_quantize if you prefer to keep this outside of the class. (e.g just like how peft deal with modules_to_target for loras)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we can go ahead and start adding simple test and a bit of documentation !

@hlky hlky mentioned this pull request Apr 3, 2025
@sayakpaul
Copy link
Member Author

@DN6 it would be great if you could do a first review of this

@SunMarc
Copy link
Member

SunMarc commented Apr 18, 2025

cc @DerekLiu35 might be interesting to you !

@sayakpaul sayakpaul removed the request for review from hlky April 21, 2025 09:07
Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the change to pipelines code looks good to me!

the provided code example isn't working though (does not throw an error but no effect) - do we intend to support passing as a dict?

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good. Could we add a test as well please.

@sayakpaul
Copy link
Member Author

@stevhliu where do you think we should document it?

@stevhliu
Copy link
Member

stevhliu commented May 1, 2025

We can document the two ways to pass quantization configs directly on the overview or we can create a new doc and add it after the overview.

I think adding directly to the overview would probably be the easiest.

It would also be nice to recommend when to use each method. With this new PipelineQuantizationConfig (which seems much simpler), are there still scenarios where a user would want to initialize the model first and then pass a quantization config?

@sayakpaul
Copy link
Member Author

We can document the two ways to pass quantization configs directly on the overview or we can create a new doc and add it after the overview.

I think adding directly to the overview would probably be the easiest.

I feel the same. Done in 872c91e.

With this new PipelineQuantizationConfig (which seems much simpler), are there still scenarios where a user would want to initialize the model first and then pass a quantization config?

Good question. I don't think there would be any need. We will see in due time.

@yiyixuxu
Copy link
Collaborator

yiyixuxu commented May 5, 2025

@sayakpaul
I'm not suggesting passing as a dict, I made that comment because in the code you provided for test in the PR description, you passed it as a dict and the pipeline does not throw an error

def get_granular_config(use_quanto=False):
    from diffusers import BitsAndBytesConfig as DiffBitsAndBytesConfig, QuantoConfig
    from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig

    transformer_config = (
        QuantoConfig(weights_dtype="float8")
        if use_quanto
        else DiffBitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
        )
    )

    quant_config = {
        "transformer": transformer_config,
        "text_encoder_2": TranBitsAndBytesConfig(
            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
        ),
    }
    return quant_config

@sayakpaul
Copy link
Member Author

@yiyixuxu my bad. I have updated the code snippet. I have also added a check for validating the quantization_config input to DiffusionPipeline (and a test for it).

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot ! LGTM !

Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments that are not merge blocks. Can be addressed here or in a follow up.

@sayakpaul sayakpaul merged commit 599c887 into main May 9, 2025
33 checks passed
@sayakpaul sayakpaul deleted the feat/pipeline-quant-config branch May 9, 2025 04:34
@sayakpaul
Copy link
Member Author

Thanks all for your comments! Cc: @asomoza @apolinario @linoytsaban for awareness of this feature as it makes the barrier to entry for quantization of a DiffusionPipeline lots easier.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants