
Merge changes #211

Merged
merged 65 commits into from
May 12, 2025
Commits

bd96a08
[train_dreambooth_lora.py] Set LANCZOS as default interpolation mode …
merterbak Apr 26, 2025
aa5f5d4
[tests] add tests to check for graph breaks, recompilation, cuda sync…
sayakpaul Apr 28, 2025
9ce89e2
enable group_offload cases and quanto cases on XPU (#11405)
yao-matrix Apr 28, 2025
a7e9f85
enable test_layerwise_casting_memory cases on XPU (#11406)
yao-matrix Apr 28, 2025
0e3f271
[tests] fix import. (#11434)
sayakpaul Apr 28, 2025
b3b04fe
[train_text_to_image] Better image interpolation in training scripts …
tongyu0924 Apr 28, 2025
3da98e7
[train_text_to_image_lora] Better image interpolation in training scr…
tongyu0924 Apr 28, 2025
7567adf
enable 28 GGUF test cases on XPU (#11404)
yao-matrix Apr 28, 2025
0ac1d5b
[Hi-Dream LoRA] fix bug in validation (#11439)
linoytsaban Apr 28, 2025
4a9ab65
Fixing missing provider options argument (#11397)
urpetkov-amd Apr 28, 2025
58431f1
Set LANCZOS as the default interpolation for image resizing in Contro…
YoulunPeng Apr 29, 2025
8fe5a14
Raise warning instead of error for block offloading with streams (#11…
a-r-r-o-w Apr 30, 2025
60892c5
enable marigold_intrinsics cases on XPU (#11445)
yao-matrix Apr 30, 2025
c865115
`torch.compile` fullgraph compatibility for Hunyuan Video (#11457)
a-r-r-o-w Apr 30, 2025
fbe2fe5
enable consistency test cases on XPU, all passed (#11446)
yao-matrix Apr 30, 2025
35fada4
enable unidiffuser test cases on xpu (#11444)
yao-matrix Apr 30, 2025
fbce7ae
Add generic support for Intel Gaudi accelerator (hpu device) (#11328)
dsocek Apr 30, 2025
8cd7426
Add StableDiffusion3InstructPix2PixPipeline (#11378)
xduzhangjiayu Apr 30, 2025
23c9802
make safe diffusion test cases pass on XPU and A100 (#11458)
yao-matrix Apr 30, 2025
38ced7e
[test_models_transformer_hunyuan_video] help us test torch.compile() …
tongyu0924 Apr 30, 2025
daf0a23
Add LANCZOS as default interplotation mode. (#11463)
Va16hav07 Apr 30, 2025
06beeca
make autoencoders. controlnet_flux and wan_transformer3d_single_file …
yao-matrix Apr 30, 2025
d70f8ee
[WAN] fix recompilation issues (#11475)
sayakpaul May 1, 2025
86294d3
Fix typos in docs and comments (#11416)
co63oc May 1, 2025
5dcdf4a
[tests] xfail recent pipeline tests for specific methods. (#11469)
sayakpaul May 1, 2025
d0c0239
cache packages_distributions (#11453)
vladmandic May 1, 2025
b848d47
[docs] Memory optims (#11385)
stevhliu May 1, 2025
e23705e
[docs] Adapters (#11331)
stevhliu May 2, 2025
ed6cf52
[train_dreambooth_lora_sdxl_advanced] Add LANCZOS as the default inte…
yuanjua May 2, 2025
ec3d582
[train_dreambooth_lora_flux_advanced] Add LANCZOS as the default inte…
ysurs May 2, 2025
a674914
enable semantic diffusion and stable diffusion panorama cases on XPU …
yao-matrix May 5, 2025
8520d49
[Feature] Implement tiled VAE encoding/decoding for Wan model. (#11414)
c8ef May 5, 2025
fc5e906
[train_text_to_image_sdxl]Add LANCZOS as default interpolation mode f…
ParagEkbote May 5, 2025
ec93239
[train_dreambooth_lora_sdxl] Add --image_interpolation_mode option fo…
MinJu-Ha May 5, 2025
ee1516e
[train_dreambooth_lora_lumina2] Add LANCZOS as the default interpolat…
cjfghk5697 May 5, 2025
071807c
[training] feat: enable quantization for hidream lora training. (#11494)
sayakpaul May 5, 2025
9c29e93
Set LANCZOS as the default interpolation method for image resizing. (…
yijun-lee May 5, 2025
ed4efbd
Update training script for txt to img sdxl with lora supp with new in…
RogerSinghChugh May 5, 2025
1fa5639
Fix torchao docs typo for fp8 granular quantization (#11473)
a-r-r-o-w May 6, 2025
53f1043
Update setup.py to pin min version of `peft` (#11502)
sayakpaul May 6, 2025
d88ae1f
update dep table. (#11504)
sayakpaul May 6, 2025
10bee52
[LoRA] use `removeprefix` to preserve sanity. (#11493)
sayakpaul May 6, 2025
d7ffe60
Hunyuan Video Framepack (#11428)
a-r-r-o-w May 6, 2025
8c661ea
enable lora cases on XPU (#11506)
yao-matrix May 6, 2025
7937166
[lora_conversion] Enhance key handling for OneTrainer components in L…
iamwavecut May 6, 2025
fb29132
[docs] minor updates to bitsandbytes docs. (#11509)
sayakpaul May 6, 2025
7b90494
Cosmos (#10660)
a-r-r-o-w May 7, 2025
53bd367
clean up the __Init__ for stable_diffusion (#11500)
yiyixuxu May 7, 2025
87e508f
fix audioldm
sayakpaul May 8, 2025
c5c34a4
Revert "fix audioldm"
sayakpaul May 8, 2025
66e50d4
[LoRA] make lora alpha and dropout configurable (#11467)
linoytsaban May 8, 2025
784db0e
Add cross attention type for Sana-Sprint training in diffusers. (#11514)
scxue May 8, 2025
6674a51
Conditionally import torchvision in Cosmos transformer (#11524)
a-r-r-o-w May 8, 2025
393aefc
[tests] fix audioldm2 for transformers main. (#11522)
sayakpaul May 8, 2025
599c887
feat: pipeline-level quantization config (#11130)
sayakpaul May 9, 2025
7acf834
[Tests] Enable more general testing for `torch.compile()` with LoRA h…
sayakpaul May 9, 2025
0c47c95
[LoRA] support non-diffusers hidream loras (#11532)
sayakpaul May 9, 2025
2d38089
enable 7 cases on XPU (#11503)
yao-matrix May 9, 2025
3c0a012
[LTXPipeline] Update latents dtype to match VAE dtype (#11533)
james-p-xu May 9, 2025
d6bf268
enable dit integration cases on xpu (#11523)
yao-matrix May 9, 2025
0ba1f76
enable print_env on xpu (#11507)
yao-matrix May 9, 2025
92fe689
Change Framepack transformer layer initialization order (#11535)
a-r-r-o-w May 9, 2025
01abfc8
[tests] add tests for framepack transformer model. (#11520)
sayakpaul May 11, 2025
e48f6ae
Hunyuan Video Framepack F1 (#11534)
a-r-r-o-w May 12, 2025
c372615
enable several pipeline integration tests on XPU (#11526)
yao-matrix May 12, 2025
feat: pipeline-level quantization config (huggingface#11130)
* feat: pipeline-level quant config.

Co-authored-by: SunMarc <[email protected]>

condition better.

support mapping.

improvements.

[Quantization] Add Quanto backend (huggingface#10756)

* update

* updaet

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/quanto.md

Co-authored-by: Sayak Paul <[email protected]>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/quanto/utils.py

Co-authored-by: Sayak Paul <[email protected]>

* update

* update

---------

Co-authored-by: Sayak Paul <[email protected]>

[Single File] Add single file loading for SANA Transformer (huggingface#10947)

* added support for from_single_file

* added diffusers mapping script

* added testcase

* bug fix

* updated tests

* corrected code quality

* corrected code quality

---------

Co-authored-by: Dhruv Nair <[email protected]>

[LoRA] Improve warning messages when LoRA loading becomes a no-op (huggingface#10187)

* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------

Co-authored-by: hlky <[email protected]>

[LoRA] CogView4 (huggingface#10981)

* update

* make fix-copies

* update

[Tests] improve quantization tests by additionally measuring the inference memory savings (huggingface#11021)

* memory usage tests

* fixes

* gguf

[`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (huggingface#8998)

* Add initial template

* Second template

* feat: Add TextEmbeddingModule to AnyTextPipeline

* feat: Add AuxiliaryLatentModule template to AnyTextPipeline

* Add bert tokenizer from the anytext repo for now

* feat: Update AnyTextPipeline's modify_prompt method

This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.

* Fill in the `forward` pass of `AuxiliaryLatentModule`

* `make style && make quality`

* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`

* Update error handling to raise and logging

* Add `create_glyph_lines` function into `TextEmbeddingModule`

* make style

* Up

* Up

* Up

* Up

* Remove several comments

* refactor: Remove ControlNetConditioningEmbedding and update code accordingly

* Up

* Up

* up

* refactor: Update AnyTextPipeline to include new optional parameters

* up

* feat: Add OCR model and its components

* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies

* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task

* `make style`

* refactor: Update `AnyTextPipeline`'s docstring

* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once

* simplify

* `make style`

* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function

* Simplify for now

* `make style`

* Up

* feat: Add scripts to convert AnyText controlnet to diffusers

* `make style`

* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`

* make style

* Up

* Simplify

* Up

* feat: Add safetensors module for loading model file

* Fix device issues

* Up

* Up

* refactor: Simplify

* refactor: Simplify code for loading models and handling data types

* `make style`

* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule

* refactor: Update dtype in embedding_manager.py to match proj.weight

* Up

* Add attribution and adaptation information to pipeline_anytext.py

* Update usage example

* Will refactor `controlnet_cond_embedding` initialization

* Add `AnyTextControlNetConditioningEmbedding` template

* Refactor organization

* style

* style

* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`

* Follow one-file policy

* style

* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel

* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py

* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py

* Refactor AnyTextControlNet to use configurable conditioning embedding channels

* Complete control net conditioning embedding in AnyTextControlNetModel

* up

* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel

* up

* up

* style

* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline

* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity

* down

* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing

* update pillow

* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity

* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file

* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency

* 🆙

* style

* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py

* style

* Update examples/research_projects/anytext/README.md

Co-authored-by: Aryan <[email protected]>

* Remove commented-out image preparation code in AnyTextPipeline

* Remove unnecessary blank line in README.md

[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6  (huggingface#11018)

* update

* update

* update

* update

* update

* update

* update

* update

* update

fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings  (huggingface#11012)

small fix on generating time_ids & embeddings

[LoRA] support wan i2v loras from the world. (huggingface#11025)

* support wan i2v loras from the world.

* remove copied from.

* upates

* add lora.

Fix SD3 IPAdapter feature extractor (huggingface#11027)

chore: fix help messages in advanced diffusion examples (huggingface#10923)

Fix missing **kwargs in lora_pipeline.py (huggingface#11011)

* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

Fix for multi-GPU WAN inference (huggingface#10997)

Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>

[Refactor] Clean up import utils boilerplate (huggingface#11026)

* update

* update

* update

Use `output_size` in `repeat_interleave` (huggingface#11030)

[hybrid inference 🍯🐝] Add VAE encode (huggingface#11017)

* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md

Co-authored-by: Sayak Paul <[email protected]>

* update

---------

Co-authored-by: Sayak Paul <[email protected]>

Wan Pipeline scaling fix, type hint warning, multi generator fix (huggingface#11007)

* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review

[LoRA] change to warning from info when notifying the users about a LoRA no-op (huggingface#11044)

* move to warning.

* test related changes.

Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (huggingface#10827)

* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline

---------

Co-authored-by: YiYi Xu <[email protected]>

making ```formatted_images``` initialization compact (huggingface#10801)

compact writing

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (huggingface#10820)

* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------

Co-authored-by: Kai zheng <[email protected]>
Co-authored-by: hlky <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

[Tests] restrict memory tests for quanto for certain schemes. (huggingface#11052)

* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <[email protected]>

* fixes

* style

---------

Co-authored-by: Dhruv Nair <[email protected]>

[LoRA] feat: support non-diffusers wan t2v loras. (huggingface#11059)

feat: support non-diffusers wan t2v loras.

[examples/controlnet/train_controlnet_sd3.py] Fixes huggingface#11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (huggingface#11051)

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training

Co-authored-by: Andreas Jörg <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

reverts accidental change that removes attn_mask in attn. Improves fl… (huggingface#11065)

reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.

Co-authored-by: Juan Acevedo <[email protected]>

Fix deterministic issue when getting pipeline dtype and device (huggingface#10696)

Co-authored-by: Dhruv Nair <[email protected]>

[Tests] add requires peft decorator. (huggingface#11037)

* add requires peft decorator.

* install peft conditionally.

* conditional deps.

Co-authored-by: DN6 <[email protected]>

---------

Co-authored-by: DN6 <[email protected]>

CogView4 Control Block (huggingface#10809)

* cogview4 control training

---------

Co-authored-by: OleehyO <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>

[CI] pin transformers version for benchmarking. (huggingface#11067)

pin transformers version for benchmarking.

updates

Fix Wan I2V Quality (huggingface#11087)

* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <[email protected]>

* Update pipeline_wan_i2v.py

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: hlky <[email protected]>

LTX 0.9.5 (huggingface#10968)

* update

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: hlky <[email protected]>

make PR GPU tests conditioned on styling. (huggingface#11099)

Group offloading improvements (huggingface#11094)

update

Fix pipeline_flux_controlnet.py (huggingface#11095)

* Fix pipeline_flux_controlnet.py

* Fix style

update readme instructions. (huggingface#11096)

Co-authored-by: Juan Acevedo <[email protected]>

Resolve stride mismatch in UNet's ResNet to support Torch DDP (huggingface#11098)

Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP

Fix Group offloading behaviour when using streams (huggingface#11097)

* update

* update

Quality options in `export_to_video` (huggingface#11090)

* Quality options in `export_to_video`

* make style

improve more.

add placeholders for docstrings.

formatting.

smol fix.

solidify validation and annotation

* Revert "feat: pipeline-level quant config."

This reverts commit 316ff46.

* feat: implement pipeline-level quantization config

Co-authored-by: SunMarc <[email protected]>

* update

* fixes

* fix validation.

* add tests and other improvements.

* add tests

* import quality

* remove prints.

* add docs.

* fixes to docs.

* doc fixes.

* doc fixes.

* add validation to the input quantization_config.

* clarify recommendations.

* docs

* add to ci.

* todo.

---------

Co-authored-by: SunMarc <[email protected]>
sayakpaul and SunMarc authored May 9, 2025
commit 599c887164a83b3450a7c2e640ddb86d63c0d518
54 changes: 54 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -525,6 +525,60 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

run_nightly_pipeline_level_quantization_tests:
name: Torch quantization nightly tests
strategy:
fail-fast: false
max-parallel: 2
runs-on:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install -U bitsandbytes optimum_quanto
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
- name: Pipeline-level quantization tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_pipeline_level_quant_torch_cuda \
--report-log=tests_pipeline_level_quant_torch_cuda.log \
tests/quantization/test_pipeline_level_quantization.py
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

# M1 runner currently not well supported
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
# run_nightly_tests_apple_m1:
7 changes: 4 additions & 3 deletions docs/source/en/api/quantization.md
@@ -13,16 +13,17 @@ specific language governing permissions and limitations under the License.

# Quantization

Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).

Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class.
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.

<Tip>

Learn how to quantize models in the [Quantization](../quantization/overview) guide.

</Tip>

## PipelineQuantizationConfig

[[autodoc]] quantizers.PipelineQuantizationConfig

## BitsAndBytesConfig

87 changes: 87 additions & 0 deletions docs/source/en/quantization/overview.md
@@ -39,3 +39,90 @@ Diffusers currently supports the following quantization methods.
- [Quanto](./quanto.md)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.

## Pipeline-level quantization

Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
do this with [`~quantizers.PipelineQuantizationConfig`].

Start by defining a `PipelineQuantizationConfig`:

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers.quantization_config import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig

pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": BitsAndBytesConfig(
load_in_4bit=True, compute_dtype=torch.bfloat16
),
}
)
```

Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:

```py
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe("photo of a cute dog").images[0]
```

This method allows for more granular control over the quantization specifications of individual
model-level components of a pipeline. It also allows for different quantization backends for
different components. In the above example, you used a combination of Quanto and bitsandbytes. However,
one caveat of this method is that users need to know which components come from `transformers` to be able
to import the right quantization config class.
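
One way to check is to read the checkpoint's `model_index.json`, which records the library and class each component is loaded from. The snippet below is a minimal, illustrative sketch (it assumes you have access to the gated FLUX.1-dev checkpoint and uses `hf_hub_download` from `huggingface_hub`):

```py
# A minimal sketch: inspect model_index.json to see whether a component comes
# from `transformers` or `diffusers`, so you can import the matching
# quantization config class. Assumes access to the (gated) checkpoint.
import json

from huggingface_hub import hf_hub_download

index_path = hf_hub_download("black-forest-labs/FLUX.1-dev", "model_index.json")
with open(index_path) as f:
    model_index = json.load(f)

for name, entry in model_index.items():
    if isinstance(entry, list) and len(entry) == 2 and all(entry):
        library, class_name = entry
        print(f"{name}: {library}.{class_name}")
```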

The other method offers a simpler experience but is less flexible. Start by defining a
`PipelineQuantizationConfig`, but in a different way:

```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
```

This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example.

In this case, `quant_kwargs` is used to initialize the quantization specification of the
configuration class backing `quant_backend`, and `components_to_quantize` denotes the components
that will be quantized. For most pipelines, you will want to keep `transformer` in the list, as it
is often the most compute- and memory-intensive component.

The config below will work for most diffusion pipelines that have a `transformer` component present.
In most cases, you will want to quantize the `transformer` component, as it is often the most
compute-intensive part of a diffusion pipeline.

```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer"],
)
```

Below is a list of the supported quantization backends available in both `diffusers` and `transformers`:

* `bitsandbytes_4bit`
* `bitsandbytes_8bit`
* `gguf`
* `quanto`
* `torchao`
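
As a quick illustration of switching backends with the simpler API, the hedged sketch below uses `bitsandbytes_8bit`; it assumes the backend forwards `quant_kwargs` to `BitsAndBytesConfig`, where `load_in_8bit=True` is the relevant flag:

```py
# A hedged sketch: the simpler API with a different backend from the list above.
# Assumes `bitsandbytes_8bit` accepts the standard BitsAndBytesConfig kwargs.
from diffusers.quantizers import PipelineQuantizationConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "text_encoder_2"],
)
```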


Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's
recommended to quantize the text encoders that are memory-intensive. Some examples include T5,
Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through
`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`).
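
To see which components are actually memory-heavy before deciding what to quantize, a minimal sketch (assuming `pipe` is an already-loaded `FluxPipeline`) can compare parameter counts:

```py
# A minimal sketch: compare parameter counts of pipeline components.
# Assumes `pipe` is an already-loaded FluxPipeline (e.g., via from_pretrained).
import torch

for name in ["text_encoder", "text_encoder_2", "transformer"]:
    module = getattr(pipe, name, None)
    if isinstance(module, torch.nn.Module):
        n_params = sum(p.numel() for p in module.parameters())
        print(f"{name}: {n_params / 1e9:.2f}B parameters")
```
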
13 changes: 13 additions & 0 deletions src/diffusers/pipelines/pipeline_loading_utils.py
@@ -675,8 +675,10 @@ def load_sub_model(
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
quantization_config: Optional[Any] = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig

# retrieve class candidates

@@ -769,6 +771,17 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False

if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
and issubclass(class_obj, torch.nn.Module)
):
model_quant_config = quantization_config._resolve_quant_config(
is_diffusers=is_diffusers_model, module_name=name
)
if model_quant_config is not None:
loading_kwargs["quantization_config"] = model_quant_config

# check if the module is in a subdirectory
if dduf_entries:
loading_kwargs["dduf_entries"] = dduf_entries
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -47,6 +47,7 @@
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers import PipelineQuantizationConfig
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
@@ -725,6 +726,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)

if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -741,6 +743,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" install accelerate\n```\n."
)

if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")

if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
@@ -1001,6 +1006,7 @@ def load_module(name, value):
use_safetensors=use_safetensors,
dduf_entries=dduf_entries,
provider_options=provider_options,
quantization_config=quantization_config,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."