bring cache-dit to flux-fast #13
Conversation
Nice, thanks for your contributions!
Would it be possible to also show some numbers on H100? Additionally, does cache-dit work with torch.compile()?
Sorry~ I don't have an H100 device.
@sayakpaul please give it a try~

# bf16 + cache-dit + torch.compile
python run_benchmark.py \
--ckpt "black-forest-labs/FLUX.1-dev" \
--trace-file bfloat16.json.gz \
--compile_export_mode compile \
--disable_fused_projections \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--num_inference_steps 28 \
--enable_cache_dit \
--output-file output_cache_compile.png

Log (my device is an NVIDIA L20):

SingleProcess AUTOTUNE benchmarking takes 1.0541 seconds and 0.0003 seconds precompiling for 21 choices
AUTOTUNE convolution(1x128x1024x1024, 128x128x3x3)
convolution 2.8324 ms 100.0%
triton_convolution2d_8250 2.8774 ms 98.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_8249 2.8918 ms 97.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_8245 2.9604 ms 95.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_8246 3.2451 ms 87.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_8251 3.3147 ms 85.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_8248 3.7745 ms 75.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_8247 9.7270 ms 29.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.5687 seconds and 0.0002 seconds precompiling for 8 choices
100%|██████████| 4/4 [00:02<00:00, 1.85it/s]
100%|██████████| 4/4 [00:02<00:00, 1.85it/s]
100%|██████████| 28/28 [00:21<00:00, 1.32it/s]
100%|██████████| 28/28 [00:16<00:00, 1.71it/s]
100%|██████████| 28/28 [00:16<00:00, 1.70it/s]
time mean/var: tensor([17.3543, 17.4310]) 17.392642974853516 0.0029477551579475403
100%|██████████| 28/28 [00:16<00:00, 1.71it/s]

Metrics (PSNR/SSIM/MSE): cache-dit-metrics-cli all -i1 output.png -i2 output_cache_compile.png
INFO 07-11 14:09:42 [metrics.py:415] output.png vs output_cache_compile.png, Num: 1, PSNR: 34.16003517257299
INFO 07-11 14:09:42 [metrics.py:420] output.png vs output_cache_compile.png, Num: 1, SSIM: 0.9749983003494543
INFO 07-11 14:09:43 [metrics.py:425] output.png vs output_cache_compile.png, Num: 1, MSE: 24.950361569722492
Cool. Let's add some of these numbers to the README then?
@sayakpaul done! PTAL~
Left a couple more comments.
mode="max-autotune" if not is_cached else "max-autotune-no-cudagraphs", | ||
fullgraph=(True if not is_cached else False), | ||
dynamic=True if is_hip() else None |
Why would we want to have graph-breaks when a transformer is cached?
https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/dual_block_cache/cache_context.py#L1163. cache-dit relies heavily on dynamic Python operations to maintain the cache_context, so graph breaks need to be introduced at appropriate positions to stay compatible with torch.compile.
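To make the trade-off concrete, one generic pattern for keeping dynamic bookkeeping out of the compiled graph is shown below. This is only an illustrative sketch, not the actual cache-dit code; should_reuse_cache and cache_context are made-up names.

import torch

@torch.compiler.disable  # Dynamo skips this function, producing a graph break around the call
def should_reuse_cache(cache_context, diff):
    # Hypothetical helper: the decision depends on runtime Python state
    # (step counters, dicts, float thresholds) that is awkward to trace.
    return (
        cache_context.cached_steps < cache_context.max_cached_steps
        and diff < cache_context.residual_diff_threshold
    )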
Cc: @anijain2305 are there any workarounds here?
@DefTruth can we add this explanation in the comment?
Cc: @anijain2305 are there any workarounds here?
@DefTruth can we add this explanation in the comment?
done
ohh that's unfortunate, I expect some non-negligible slowdowns from the introduction of graph breaks
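If it helps quantify that, one way to count the breaks before a full benchmark run is torch._dynamo.explain. This is a sketch assuming a recent PyTorch 2.x build; transformer and example_inputs stand in for the FLUX transformer module and a representative set of its call arguments.

import torch

# `transformer` and `example_inputs` are placeholders for the pipeline's
# transformer module and one set of its forward() arguments.
explanation = torch._dynamo.explain(transformer)(*example_inputs)
print(f"graphs: {explanation.graph_count}, graph breaks: {explanation.graph_break_count}")
for reason in explanation.break_reasons:
    print(reason)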
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": 8,
    "Fn_compute_blocks": 12,
    "Bn_compute_blocks": 12,
    "residual_diff_threshold": 0.12,
}
apply_cache_on_pipe(pipeline, **cache_options) |
cool, not too bad :) I think it's fine for the purposes of this PR to hard-code this stuff, but maybe we eventually want the caching to be configurable somehow
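If it does become configurable later, a minimal sketch could expose the knobs as CLI flags. The flag names below are hypothetical and not part of the current script; the default values mirror the ones hard-coded in this PR.

import argparse

from cache_dit.cache_factory import CacheType

parser = argparse.ArgumentParser()
# Hypothetical flags; the current PR hard-codes these values.
parser.add_argument("--cache_fn_blocks", type=int, default=12)
parser.add_argument("--cache_bn_blocks", type=int, default=12)
parser.add_argument("--cache_warmup_steps", type=int, default=8)
parser.add_argument("--cache_residual_diff_threshold", type=float, default=0.12)
args = parser.parse_args()

cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": args.cache_warmup_steps,
    "max_cached_steps": 8,
    "Fn_compute_blocks": args.cache_fn_blocks,
    "Bn_compute_blocks": args.cache_bn_blocks,
    "residual_diff_threshold": args.cache_residual_diff_threshold,
}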
@jbschlosser
design: https://github.com/vipshop/cache-dit?tab=readme-ov-file#dbcache. Different configurations of compute blocks (F8B12, etc.) can be customized in cache-dit's DBCache.
⚡️DBCache: Dual Block Cache
DBCache provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:
- Fn: Specifies that DBCache uses the first n Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
- Bn: Further fuses approximate information in the last n Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
- warmup_steps: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
- max_cached_steps: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
- residual_diff_threshold: The residual diff threshold; a higher value leads to faster performance at the cost of lower precision.
For a good balance between performance and precision, DBCache is configured by default with F8B8, 8 warmup steps, and unlimited cached steps.
import torch

from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
).to("cuda")
# Default options, F8B8, good balance between performance and precision
cache_options = CacheType.default_options(CacheType.DBCache)
# Custom options, F8B16, higher precision
cache_options = {
"cache_type": CacheType.DBCache,
"warmup_steps": 8,
"max_cached_steps": 8, # -1 means no limit
"Fn_compute_blocks": 8, # Fn, F8, etc.
"Bn_compute_blocks": 16, # Bn, B16, etc.
"residual_diff_threshold": 0.12,
}
apply_cache_on_pipe(pipe, **cache_options)
Moreover, users configuring higher Bn values (e.g., F8B16) while aiming to maintain good performance can specify Bn_compute_blocks_ids to work with Bn. DBCache will only compute the specified blocks, with the remaining estimated using the previous step's residual cache.
# Custom options, F8B16, higher precision with good performance.
cache_options = {
# 0, 2, 4, ..., 14, 15, etc. [0,16)
"Bn_compute_blocks_ids": CacheType.range(0, 16, 2),
# If the L1 difference is below this threshold, skip the Bn blocks
# not in `Bn_compute_blocks_ids` (1, 3, ..., etc.); otherwise,
# compute these blocks.
"non_compute_blocks_diff_threshold": 0.08,
}
DBCache, L20x1, Steps: 28, "A cat holding a sign that says hello world with complex background"

| Baseline (L20x1) | F1B0 (0.08) | F1B0 (0.20) | F8B8 (0.15) | F12B12 (0.20) | F16B16 (0.20) |
|---|---|---|---|---|---|
| 24.85s | 15.59s | 8.58s | 15.41s | 15.11s | 17.74s |
🔥Hybrid TaylorSeer
We have added support for the TaylorSeer algorithm (TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers) to further improve the precision of DBCache when the number of cached steps is large, namely Hybrid TaylorSeer + DBCache. At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, which significantly harms generation quality.
TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features at future timesteps with a Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. That is:
cache_options = {
# TaylorSeer options
"enable_taylorseer": True,
"enable_encoder_taylorseer": True,
# TaylorSeer cache type can be hidden_states or residual.
"taylorseer_cache_type": "residual",
# Higher values of n_derivatives will lead to longer
# computation time but may improve precision significantly.
"taylorseer_kwargs": {
"n_derivatives": 2, # default is 2.
},
"warmup_steps": 3, # prefer: >= n_derivatives + 1
"residual_diff_threshold": 0.12,
}
Important
Please note that if you use TaylorSeer as the calibrator for approximate hidden states, the Bn param of DBCache can be set to 0. In essence, DBCache's Bn also acts as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend the TaylorSeer + DBCache FnB0 configuration scheme.
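As a concrete example, a sketch of that recommended combination (TaylorSeer as calibrator, Bn = 0), assembled from the option names documented above; the specific Fn and threshold values here are illustrative, not a recommendation from the library.

from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

cache_options = {
    "cache_type": CacheType.DBCache,
    # FnB0: a few Fn blocks for a stable L1 diff, no Bn calibration blocks
    "Fn_compute_blocks": 1,
    "Bn_compute_blocks": 0,
    "warmup_steps": 3,  # prefer >= n_derivatives + 1
    "residual_diff_threshold": 0.12,
    # TaylorSeer takes over the calibrator role that Bn would otherwise play
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,
    "taylorseer_cache_type": "residual",
    "taylorseer_kwargs": {"n_derivatives": 2},
}
apply_cache_on_pipe(pipe, **cache_options)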
DBCache F1B0 + TaylorSeer, L20x1, Steps: 28,
"A cat holding a sign that says hello world with complex background"
| Baseline (L20x1) | F1B0 (0.12) | +TaylorSeer | F1B0 (0.15) | +TaylorSeer | +compile |
|---|---|---|---|---|---|
| 24.85s | 12.85s | 12.86s | 10.27s | 10.28s | 8.48s |
Awesome, thanks for the contribution! I'll see if I can get some time to run some benchmarks with caching on my H100 machine
Feel free to update this PR and replace the L20 results with the H100 results.
Will merge the PR after that. It would be good to investigate the graph-break problems for potential improvements, though. Do you want to open an issue in the cache-dit repo and tag us?
I have opened an issue in the cache-dit repo:
@sayakpaul @jbschlosser A relatively safe approach is to modify the
Sure, let's do that. Thanks for being willing to help.
@DefTruth I was also wondering if the cache-dit configuration options are hardware-dependent. If so, I think we should mention it in the README and provide guidance to users about how they should tune it. WDYT?
The configuration options for cache-dit are not hardware-dependent. I have already added a link to the documentation that provides guidance to users on how they should tune it.
Done~
LGTM! Thanks for your hard work 🖖
Will wait a bit for @jbschlosser to run H100 exps if possible before merging.
just launched a set of experiments with caching enabled :)
ah one sec, looks like the flash_attn imports were messed up and ruined my runs
on H100 for Flux Schnell, I'm not seeing a huge difference with
it doesn't work with
@jbschlosser

# bf16
python run_benchmark.py \
--ckpt "black-forest-labs/FLUX.1-dev" \
--trace-file bfloat16.json.gz \
--compile_export_mode disabled \
--disable_fused_projections \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--num_inference_steps 28 \
--output-file output.png
# bf16 + cache-dit + torch.compile
python run_benchmark.py \
--ckpt "black-forest-labs/FLUX.1-dev" \
--trace-file bfloat16.json.gz \
--compile_export_mode compile \
--disable_fused_projections \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--num_inference_steps 28 \
--enable_cache_dit \
--output-file output_cache_compile.png

cache-dit doesn't work with
Weird, cache-dit doesn't depend on flash_attn.
If needed, I can add a check that skips the export process when both cache-dit and export are enabled, along with a warning log to inform users that they need to disable cache-dit before using the export function.
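For reference, a minimal sketch of what such a guard could look like in run_benchmark.py; the args attribute names mirror the existing CLI flags, but the "export" mode string and the exact placement are assumptions.

import logging

logger = logging.getLogger(__name__)

def maybe_disable_export(args):
    # Sketch: cache-dit's dynamic Python cache bookkeeping is not
    # export-friendly, so fall back to the "disabled" path and warn.
    if args.enable_cache_dit and args.compile_export_mode == "export":
        logger.warning(
            "cache-dit is enabled; skipping torch.export. "
            "Remove --enable_cache_dit to use the export path."
        )
        args.compile_export_mode = "disabled"
    return args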
@jbschlosser done~ PTAL
Install
docs: https://github.com/vipshop/cache-dit
BFloat16
python run_benchmark.py \
--ckpt "black-forest-labs/FLUX.1-dev" \
--trace-file bfloat16.json.gz \
--compile_export_mode disabled \
--disable_fused_projections \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--num_inference_steps 28 \
--output-file output.png
BFloat16 + cache-dit
python run_benchmark.py \
--ckpt "black-forest-labs/FLUX.1-dev" \
--trace-file bfloat16.json.gz \
--compile_export_mode disabled \
--disable_fused_projections \
--disable_channels_last \
--disable_fa3 \
--disable_quant \
--disable_inductor_tuning_flags \
--num_inference_steps 28 \
--enable_cache_dit \
--output-file output_cache.png
BFloat16 + cache-dit + torch.compile
Metrics