bring cache-dit to flux-fast #13

Open · wants to merge 19 commits into main

Conversation

@DefTruth (Author) commented Jul 10, 2025

Install

docs: https://github.com/vipshop/cache-dit

pip install -U cache-dit

BFloat16

python run_benchmark.py \
    --ckpt "black-forest-labs/FLUX.1-dev" \
    --trace-file bfloat16.json.gz \
    --compile_export_mode disabled \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --num_inference_steps 28 \
    --output-file output.png 

BFloat16 + cache-dit

python run_benchmark.py \
    --ckpt "black-forest-labs/FLUX.1-dev" \
    --trace-file bfloat16.json.gz \
    --compile_export_mode disabled \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --num_inference_steps 28 \
    --enable_cache_dit \
    --output-file output_cache.png 

BFloat16 + cache-dit + torch.compile

# bf16 + cache-dit + torch.compile
python run_benchmark.py \
    --ckpt "black-forest-labs/FLUX.1-dev" \
    --trace-file bfloat16.json.gz \
    --compile_export_mode compile \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --num_inference_steps 28 \
    --enable_cache_dit \
    --output-file output_cache_compile.png

Metrics

| BF16        | BF16 + cache-dit | BF16 + cache-dit + compile |
|-------------|------------------|----------------------------|
| Baseline    | PSNR: 34.23      | PSNR: 34.16                |
| L20: 24.94s | L20: 20.85s      | L20: 17.39s                |
| output      | output_cache     | output_cache_compile       |
cache-dit-metrics-cli all -i1 output.png -i2 output_cache.png
INFO 07-11 14:19:04 [metrics.py:415] output.png vs output_cache.png, Num: 1, PSNR: 34.232669211224206
INFO 07-11 14:19:05 [metrics.py:420] output.png vs output_cache.png, Num: 1, SSIM: 0.9748773813481453
INFO 07-11 14:19:05 [metrics.py:425] output.png vs output_cache.png, Num: 1,  MSE: 24.53654670715332
cache-dit-metrics-cli all -i1 output.png -i2 output_cache_compile.png
INFO 07-11 14:09:42 [metrics.py:415] output.png vs output_cache_compile.png, Num: 1, PSNR: 34.16003517257299
INFO 07-11 14:09:42 [metrics.py:420] output.png vs output_cache_compile.png, Num: 1, SSIM: 0.9749983003494543
INFO 07-11 14:09:43 [metrics.py:425] output.png vs output_cache_compile.png, Num: 1,  MSE: 24.950361569722492

@sayakpaul (Member) left a comment

Nice, thanks for your contributions!

Would it be possible to also show some numbers on H100? Additionally, does cache-dit work with torch.compile()?

@sayakpaul requested a review from jbschlosser on Jul 10, 2025 at 12:11
@DefTruth (Author) commented Jul 10, 2025

Nice, thanks for your contributions!

Would it be possible to also show some numbers on H100? Additionally, does cache-dit work with torch.compile()?

Sorry~ I don't have an H100 device.

cache-dit works with torch.compile() when fullgraph=False; please check https://github.com/vipshop/cache-dit?tab=readme-ov-file#compile and https://github.com/vipshop/cache-dit/blob/main/bench/bench.py#L189 for more details.
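
For reference, here is a minimal Python-level sketch of that combination, assembled from the snippets elsewhere in this thread and the cache-dit README (a sketch, not the exact bench.py code):

import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Apply DBCache with its default options (F8B8) before compiling.
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

# fullgraph must stay False so cache-dit's dynamic cache logic can graph-break;
# cudagraphs are also skipped when caching is enabled.
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False
)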

@sayakpaul please give it a try~

bf16 + cache-dit + torch.compile

# bf16 + cache-dit + torch.compile
python run_benchmark.py \
    --ckpt "black-forest-labs/FLUX.1-dev" \
    --trace-file bfloat16.json.gz \
    --compile_export_mode compile \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --num_inference_steps 28 \
    --enable_cache_dit \
    --output-file output_cache_compile.png

Log (my device is an NVIDIA L20):

SingleProcess AUTOTUNE benchmarking takes 1.0541 seconds and 0.0003 seconds precompiling for 21 choices
AUTOTUNE convolution(1x128x1024x1024, 128x128x3x3)
  convolution 2.8324 ms 100.0%
  triton_convolution2d_8250 2.8774 ms 98.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_8249 2.8918 ms 97.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_8245 2.9604 ms 95.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_8246 3.2451 ms 87.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_8251 3.3147 ms 85.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_8248 3.7745 ms 75.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_8247 9.7270 ms 29.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.5687 seconds and 0.0002 seconds precompiling for 8 choices
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.85it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.85it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:21<00:00,  1.32it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:16<00:00,  1.71it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:16<00:00,  1.70it/s]
time mean/var: tensor([17.3543, 17.4310]) 17.392642974853516 0.0029477551579475403
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:16<00:00,  1.71it/s]

PSNR:

cache-dit-metrics-cli all -i1 output.png -i2 output_cache_compile.png
INFO 07-11 14:09:42 [metrics.py:415] output.png vs output_cache_compile.png, Num: 1, PSNR: 34.16003517257299
INFO 07-11 14:09:42 [metrics.py:420] output.png vs output_cache_compile.png, Num: 1, SSIM: 0.9749983003494543
INFO 07-11 14:09:43 [metrics.py:425] output.png vs output_cache_compile.png, Num: 1,  MSE: 24.950361569722492
| BF16        | BF16 + cache-dit | BF16 + cache-dit + compile |
|-------------|------------------|----------------------------|
| Baseline    | PSNR: 34.23      | PSNR: 34.16                |
| L20: 24.94s | L20: 20.85s      | L20: 17.39s                |
| output      | output_cache     | output_cache_compile       |

@sayakpaul (Member)

Cool. Let's add some of these numbers to the README then?

@DefTruth (Author)

Cool. Let's add some of these numbers to the README then?

@sayakpaul done! PTAL~

@sayakpaul (Member) left a comment

Left a couple more comments.

Comment on lines +243 to +245
mode="max-autotune" if not is_cached else "max-autotune-no-cudagraphs",
fullgraph=(True if not is_cached else False),
dynamic=True if is_hip() else None
Member

Why would we want to have graph-breaks when a transformer is cached?

Author

See https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/cache_factory/dual_block_cache/cache_context.py#L1163. cache-dit relies heavily on dynamic Python operations to maintain the cache_context, so it is necessary to introduce graph breaks at appropriate positions to stay compatible with torch.compile.
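
To illustrate why (a minimal, hypothetical sketch, not cache-dit's actual code): a branch that depends on tensor values, such as comparing a residual diff against a threshold, has to run in Python, so torch.compile falls back to eager at that point, i.e. a graph break, which is why fullgraph=False is required:

import torch

cache = {"residual": None}

def maybe_use_cache(hidden_states: torch.Tensor, threshold: float = 0.12) -> torch.Tensor:
    if cache["residual"] is not None:
        # .item() pulls a Python float out of the graph -> graph break under torch.compile
        diff = (hidden_states - cache["residual"]).abs().mean().item()
        if diff < threshold:
            return cache["residual"]  # reuse the cached value, skip compute
    cache["residual"] = hidden_states
    return hidden_states

# Works with fullgraph=False; fullgraph=True would error on the graph break.
compiled = torch.compile(maybe_use_cache, fullgraph=False)
out = compiled(torch.randn(2, 16))
out = compiled(torch.randn(2, 16))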

Member

Cc: @anijain2305 are there any workarounds here?

@DefTruth can we add this explanation in the comment?

Author

Cc: @anijain2305 are there any workarounds here?

@DefTruth can we add this explanation in the comment?

done

Collaborator

ohh that's unfortunate, I expect some non-negligible slowdowns from the introduction of graph breaks

Comment on lines +405 to +413
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": 8,
    "Fn_compute_blocks": 12,
    "Bn_compute_blocks": 12,
    "residual_diff_threshold": 0.12,
}
apply_cache_on_pipe(pipeline, **cache_options)
Collaborator

cool, not too bad :) I think it's fine for the purposes of this PR to hard-code this stuff, but maybe we eventually want the caching to be configurable somehow

Author

@jbschlosser
Design: https://github.com/vipshop/cache-dit?tab=readme-ov-file#dbcache. Different configurations of compute blocks (F8B12, etc.) can be customized in cache-dit's DBCache.

⚡️DBCache: Dual Block Cache

DBCache provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:

  • Fn: Specifies that DBCache uses the first n Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
  • Bn: Further fuses approximate information in the last n Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.

  • warmup_steps: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
  • max_cached_steps: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
  • residual_diff_threshold: the residual diff threshold; a higher value leads to faster performance at the cost of lower precision.

For a good balance between performance and precision, DBCache is configured by default with F8B8, 8 warmup steps, and unlimited cached steps.

import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Default options, F8B8, good balance between performance and precision
cache_options = CacheType.default_options(CacheType.DBCache)

# Custom options, F8B16, higher precision
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": 8,    # -1 means no limit
    "Fn_compute_blocks": 8,   # Fn, F8, etc.
    "Bn_compute_blocks": 16,  # Bn, B16, etc.
    "residual_diff_threshold": 0.12,
}

apply_cache_on_pipe(pipe, **cache_options)

Moreover, users configuring higher Bn values (e.g., F8B16) while aiming to maintain good performance can specify Bn_compute_blocks_ids to work with Bn. DBCache will only compute the specified blocks, with the remaining estimated using the previous step's residual cache.

# Custom options, F8B16, higher precision with good performance.
cache_options = {
    # 0, 2, 4, ..., 14, 15, etc. [0,16)
    "Bn_compute_blocks_ids": CacheType.range(0, 16, 2),
    # If the L1 difference is below this threshold, skip the Bn blocks
    # not in `Bn_compute_blocks_ids` (1, 3, ..., etc.); otherwise,
    # compute these blocks.
    "non_compute_blocks_diff_threshold": 0.08,
}

DBCache, L20x1, Steps: 28, "A cat holding a sign that says hello world with complex background"

| Baseline (L20x1) | F1B0 (0.08) | F1B0 (0.20) | F8B8 (0.15) | F12B12 (0.20) | F16B16 (0.20) |
|------------------|-------------|-------------|-------------|---------------|---------------|
| 24.85s           | 15.59s      | 8.58s       | 15.41s      | 15.11s        | 17.74s        |

🔥Hybrid TaylorSeer

We now support the TaylorSeer algorithm (TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers) to further improve the precision of DBCache when the number of cached steps is large, namely Hybrid TaylorSeer + DBCache. At timesteps separated by significant intervals, the feature similarity in diffusion models decreases substantially, which significantly harms generation quality.

$$ \mathcal{F}_{\text{pred},m}\left(x_{t-k}^l\right)=\mathcal{F}\left(x_t^l\right)+\sum_{i=1}^m \frac{\Delta^i \mathcal{F}\left(x_t^l\right)}{i!\cdot N^i}(-k)^i $$

TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features at future timesteps with a Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden-state and residual cache types; that is, $\mathcal{F}_{\text{pred},m}\left(x_{t-k}^l\right)$ can be either a residual cache or a hidden-state cache.

cache_options = {
    # TaylorSeer options
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,
    # TaylorSeer cache type can be hidden_states or residual.
    "taylorseer_cache_type": "residual",
    # Higher values of n_derivatives will lead to longer 
    # computation time but may improve precision significantly.
    "taylorseer_kwargs": {
        "n_derivatives": 2, # default is 2.
    },
    "warmup_steps": 3, # prefer: >= n_derivatives + 1
    "residual_diff_threshold": 0.12,
}

Important

Please note that if you use TaylorSeer as the calibrator for approximate hidden states, the Bn param of DBCache can be set to 0. In essence, DBCache's Bn also acts as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend the TaylorSeer + DBCache FnB0 configuration scheme, as in the sketch below.
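
A minimal sketch of that recommended scheme, combining the DBCache and TaylorSeer options documented above (the exact values here are illustrative, not tuned):

# Sketch only: TaylorSeer + DBCache F1B0 (Bn = 0, TaylorSeer is the calibrator).
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 3,                 # prefer >= n_derivatives + 1
    "Fn_compute_blocks": 1,            # F1
    "Bn_compute_blocks": 0,            # B0
    "residual_diff_threshold": 0.12,
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,
    "taylorseer_cache_type": "residual",
    "taylorseer_kwargs": {"n_derivatives": 2},
}

apply_cache_on_pipe(pipe, **cache_options)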

DBCache F1B0 + TaylorSeer, L20x1, Steps: 28,
"A cat holding a sign that says hello world with complex background"

| Baseline (L20x1) | F1B0 (0.12) | +TaylorSeer | F1B0 (0.15) | +TaylorSeer | +compile |
|------------------|-------------|-------------|-------------|-------------|----------|
| 24.85s           | 12.85s      | 12.86s      | 10.27s      | 10.28s      | 8.48s    |

@jbschlosser (Collaborator) left a comment

Awesome, thanks for the contribution! I'll see if I can get some time to run some benchmarks with caching on my H100 machine

@DefTruth (Author)

Awesome, thanks for the contribution! I'll see if I can get some time to run some benchmarks with caching on my H100 machine

Feel free to update this PR and replace the L20 results with the H100 results.

@DefTruth changed the title from "bring almost lossless cache-dit to flux-fast" to "bring cache-dit to flux-fast" on Jul 10, 2025
@sayakpaul (Member)

Will merge the PR after that. It would be good to investigate the graph break problems for potential improvements though. Do you want to open an issue in the cache-dit repo and tag us?

@DefTruth (Author)

Will merge the PR after that. It would be good to investigate the graph break problems for potential improvements though. Do you want to open an issue in the cache-dit repo and tag us?

I have opened an issue in the cache-dit repo:

@DefTruth (Author)

@sayakpaul @jbschlosser A relatively safe approach is to modify the --disable_cache_dit (default: False) option to --enable_cache_dit (default: False). This way, under the default settings, all the original optimization options will not be affected. If you wish to do this, I can help with the modification.
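
For illustration, a hypothetical argparse sketch of that opt-in flag (the real run_benchmark.py setup may differ):

import argparse

parser = argparse.ArgumentParser()
# Opt-in flag: cache-dit stays off unless explicitly requested, so the default
# behavior of all the original optimization options is unchanged.
parser.add_argument(
    "--enable_cache_dit",
    action="store_true",
    default=False,
    help="Apply cache-dit (DBCache) to the Flux transformer.",
)

args = parser.parse_args([])                          # no flag  -> False
args_on = parser.parse_args(["--enable_cache_dit"])   # with flag -> True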

@sayakpaul (Member)

Sure, let's do that. Thanks for being willing to help.

@sayakpaul (Member)

@DefTruth I was also wondering if the cache-dit configuration options are hardware-dependent. If so, I think we should mention it in the README and provide guidance to the users about how they should tune it.

WDYT?

@DefTruth (Author)

@DefTruth I was also wondering if the cache-dit configuration options are hardware-dependent. If so, I think we should mention it in the README and provide guidance to the users about how they should tune it.

WDYT?

The configuration options for cache-dit are not hardware-dependent. I have already added a link to the documentation that provides guidance to users on how they should tune it.

@DefTruth (Author)

Sure, let's do that. Thanks for willing to help.

Done~

@sayakpaul (Member) left a comment

LGTM! Thanks for your hard work 🖖

Will wait a bit for @jbschlosser to run H100 exps if possible before merging.

@jbschlosser (Collaborator)

Will wait a bit for @jbschlosser to run H100 exps if possible before merging.

just launched a set of experiments with caching enabled :)

@jbschlosser (Collaborator)

ah one sec, looks like the flash_attn imports were messed up and ruined my runs

@jbschlosser (Collaborator)

on H100 for Flux Schnell, I'm not seeing a huge difference with --enable_cache_dit and cache_dit==0.2.8 installed:

results_mean_sec = [
    ("bfloat16", 1.1336731910705566),
    ("compile", 0.7163330912590027),
    ("qkv proj", 0.7130974531173706),
    ("channels_last", 0.6999137997627258),
    ("flash v3", 0.6080203652381897),
    ("float8", 0.4970979690551758),
    ("flags", 0.49017876386642456),
    ("export", "N/A"),
]

it doesn't work with torch.export, but I think that's expected due to the way cache_dit implements the Flux diffusers adapter. thoughts?

@DefTruth (Author) commented Jul 12, 2025

on H100 for Flux Schnell, I'm not seeing a huge difference with --enable_cache_dit and cache_dit==0.2.8 installed:

results_mean_sec = [
    ("bfloat16", 1.1336731910705566),
    ("compile", 0.7163330912590027),
    ("qkv proj", 0.7130974531173706),
    ("channels_last", 0.6999137997627258),
    ("flash v3", 0.6080203652381897),
    ("float8", 0.4970979690551758),
    ("flags", 0.49017876386642456),
    ("export", "N/A"),
]

it doesn't work with torch.export, but I think that's expected due to the way cache_dit implements the Flux diffusers adapter. thoughts?

@jbschlosser
cache-dit is not suitable for Flux Schnell with only 4 steps (there's no need to use a cache either); please try FLUX.1-dev with 28 steps:

# bf16
python run_benchmark.py \
    --ckpt "black-forest-labs/FLUX.1-dev" \
    --trace-file bfloat16.json.gz \
    --compile_export_mode disabled \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --num_inference_steps 28 \
    --output-file output.png 

# bf16 + cache-dit + torch.compile
python run_benchmark.py \
    --ckpt "black-forest-labs/FLUX.1-dev" \
    --trace-file bfloat16.json.gz \
    --compile_export_mode compile \
    --disable_fused_projections \
    --disable_channels_last \
    --disable_fa3 \
    --disable_quant \
    --disable_inductor_tuning_flags \
    --num_inference_steps 28 \
    --enable_cache_dit \
    --output-file output_cache_compile.png

cache-dit doesn't work with torch.export right now. cache-dit extends Flux with some dynamic Python operations, so it may not be possible to export the model using torch.export.

@DefTruth (Author) commented Jul 12, 2025

flash_attn

Weird, cache-dit doesn't depend on flash_attn.

@DefTruth (Author)

@jbschlosser

If you need, I can add a check to skip the export process when both cache-dit and export are enabled, and add a warning log to inform users that they need to disable cache-dit before using the export function.
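
Something along these lines, for example (a hypothetical sketch; the attribute names on args are assumptions based on the CLI flags shown in this PR):

import warnings

def maybe_skip_export(args):
    # Skip the export path when cache-dit is enabled, since the two are incompatible.
    if args.enable_cache_dit and args.compile_export_mode == "export":
        warnings.warn(
            "cache-dit is not compatible with torch.export; skipping export. "
            "Disable cache-dit to use the export path."
        )
        return "disabled"
    return args.compile_export_mode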

@DefTruth (Author)

@jbschlosser

If you need, I can add a check to skip the export process when both cache-dit and export are enabled, and add a warning log to inform users that they need to disable cache-dit before using the export function.

@jbschlosser done~ PTAL
