
[From Single File] support from_single_file method for WanVACE3DTransformer #11807


Merged: 7 commits into huggingface:main, Jul 2, 2025

Conversation

J4BEZ
Contributor

@J4BEZ J4BEZ commented Jun 25, 2025

What does this PR do?

This PR solves the problem reported in #11630.

First of all, I would like to sincerely thank the team for your continued hard work in making state-of-the-art generative models accessible to everyone.

While encountering the same issue as reported in #11630, I was able to find a solution. I’m submitting this pull request to share that fix with the community in the hope that it may help others facing the same problem.

Best regards,
J4BEZ

Fixes #11630

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@nitinmukesh

nitinmukesh commented Jun 25, 2025

@J4BEZ

Thank you for the PR. Tested by uninstalling diffusers and installing:
pip install git+https://github.com/huggingface/diffusers.git@refs/pull/11807/head

AttributeError: module diffusers has no attribute WandVACETransformer3DModel. Did you mean: 'WanVACETransformer3DModel'?

There seems to be a typo:

src/diffusers/loaders/single_file_model.py

    "WandVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },

I think it should be WanVACETransformer3DModel instead of WandVACETransformer3DModel
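
With the typo fixed, the entry would read:

    "WanVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },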

Sincere thanks to @nitinmukesh 🙇‍♂️
@J4BEZ
Contributor Author

J4BEZ commented Jun 25, 2025

@nitinmukesh
That was a close call😳 thank you sincerely for taking the time to test this with me 🙇‍♂️
I’ve promptly updated the code based on the guidance you provided.
Wishing you a smooth and peaceful day!

@nitinmukesh

nitinmukesh commented Jun 25, 2025

Thank you @J4BEZ .

Could you please share the code you are using? I am getting an error with mine.

Some weights of the model checkpoint were not used when initializing WanVACETransformer3DModel:
 ['vace_blocks.10.after_proj.bias, vace_blocks.10.after_proj.weight, vace_blocks.11.after_proj.bias, vace_blocks.11.after_proj.weight, vace_blocks.12.after_proj.bias, vace_blocks.12.after_proj.weight, vace_blocks.13.after_proj.bias, vace_blocks.13.after_proj.weight, vace_blocks.14.after_proj.bias, vace_blocks.14.after_proj.weight, vace_blocks.8.after_proj.bias, vace_blocks.8.after_proj.weight, vace_blocks.9.after_proj.bias, vace_blocks.9.after_proj.weight']
Traceback (most recent call last):
  File "C:\aiOWN\diffuser_webui\WanVace_GGUF.py", line 12, in <module>
    transformer_gguf = WanVACETransformer3DModel.from_single_file(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\huggingface_hub\utils\_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\loaders\single_file_model.py", line 451, in from_single_file
    dispatch_model(model, **device_map_kwargs)
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\accelerate\big_modeling.py", line 502, in dispatch_model
    model.to(device)
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\modeling_utils.py", line 1383, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1355, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 915, in _apply
    module._apply(fn)
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 942, in _apply
    param_applied = fn(param)
                    ^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1348, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

I'm using GGUF from https://huggingface.co/samuelchristlie/Wan2.1-VACE-1.3B-GGUF

import torch
from diffusers import WanVACETransformer3DModel, GGUFQuantizationConfig

model_id = "a-r-r-o-w/Wan-VACE-1.3B-diffusers"
transformer_path = "https://huggingface.co/samuelchristlie/Wan2.1-VACE-1.3B-GGUF/blob/main/Wan2.1-VACE-1.3B-Q8_0.gguf"
transformer_gguf = WanVACETransformer3DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    config=model_id,
    subfolder="transformer",
)

@J4BEZ
Contributor Author

J4BEZ commented Jun 25, 2025

@nitinmukesh
Oh! Thank you very much for helping me notice something I had previously overlooked. 🙇‍♂️
I tested the loading of the 14B version GGUF model as follows:

import torch

from diffusers import AutoencoderKLWan, WanVACEPipeline, WanVACETransformer3DModel, GGUFQuantizationConfig, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download

model_id = "Wan-AI/Wan2.1-VACE-14B-diffusers"
gguf_model_id = "QuantStack/Wan2.1_14B_VACE-GGUF"
gguf_model_name = "Wan2.1_14B_VACE-Q3_K_S.gguf"
FLOW_SHIFT = 5.0

gguf_path = hf_hub_download(gguf_model_id, gguf_model_name)

transformer = WanVACETransformer3DModel.from_single_file(
    gguf_path,
    quantization_config=GGUFQuantizationConfig(
        compute_dtype=torch.bfloat16,
    )
)

Thanks to your feedback, I found that—unlike the 14B model—the 1.3B model includes additional layers ranging from vace_blocks.8 to vace_blocks.14.

I will make the necessary adjustments shortly and follow up with you as soon as possible.
Thank you again🙇‍♂️
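
As a quick way to check which VACE block indices a checkpoint actually contains, a minimal sketch along these lines (assuming the gguf and huggingface_hub packages) should print 0–14 for this 1.3B file, versus 0–7 for the 14B file:

import re

from gguf.gguf_reader import GGUFReader
from huggingface_hub import hf_hub_download

# List the vace_blocks.<i> indices present in a GGUF checkpoint.
path = hf_hub_download("samuelchristlie/Wan2.1-VACE-1.3B-GGUF", "Wan2.1-VACE-1.3B-Q8_0.gguf")
indices = {
    int(m.group(1))
    for tensor in GGUFReader(path).tensors
    if (m := re.match(r"vace_blocks\.(\d+)\.", tensor.name))
}
print(sorted(indices))  # expected: [0, 1, ..., 14] for the 1.3B checkpoint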

Sincere thanks to @nitinmukesh again🙇‍♂️
@J4BEZ J4BEZ marked this pull request as draft June 25, 2025 16:04
@J4BEZ
Contributor Author

J4BEZ commented Jun 25, 2025

@nitinmukesh

Upon execution, I did not encounter any warnings such as
Some weights of the model checkpoint were not used when initializing WanVACETransformer3DModel: ....
Instead, the following error was raised directly:

NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

To investigate further, I compared the state dict keys between the original model and the GGUF-converted model.
As a result, I found that the key vace_patch_embedding.weight is missing from the GGUF model, while it does exist in the original checkpoint.

This may possibly be contributing to the issue.

Below is the code I used for the key comparison:

from gguf.gguf_reader import GGUFReader
from safetensors import safe_open
from huggingface_hub import hf_hub_download

# Load GGUF keys
gguf_file_path = hf_hub_download("samuelchristlie/Wan2.1-VACE-1.3B-GGUF", "Wan2.1-VACE-1.3B-Q8_0.gguf")
original_file_path = hf_hub_download("Wan-AI/Wan2.1-VACE-1.3B", "diffusion_pytorch_model.safetensors")

def read_gguf_file(gguf_file_path):
    keys = set()
    reader = GGUFReader(gguf_file_path)
    for tensor in reader.tensors:
        keys.add(tensor.name)
    return keys

gguf_keys = read_gguf_file(gguf_file_path)

# Load original model keys (tensor values are not needed for the comparison)
with safe_open(original_file_path, framework="pt", device="cpu") as f:
    original_keys = set(f.keys())

# Check key differences
print(original_keys - gguf_keys)  # {'vace_patch_embedding.weight'}
print(gguf_keys - original_keys)  # set()

I hope this information helps clarify the root of the issue.
Hope you have a peaceful day

Sincerely yours,
J4BEZ

@yiyixuxu
Collaborator

thanks for the PR!
let us know when it's ready for review
cc @DN6

@nitinmukesh

@yiyixuxu
cc @DN6 @a-r-r-o-w

I guess support is needed from the diffusers team. The 14B GGUF is working but the 1.3B is not.

from typing import List
import torch
import PIL.Image
from diffusers import AutoencoderKLWan, WanVACEPipeline, WanVACETransformer3DModel
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image, load_video
from diffusers import GGUFQuantizationConfig

model_id = "a-r-r-o-w/Wan-VACE-1.3B-diffusers"
# transformer_path = "https://huggingface.co/newgenai79/Wan-VACE-1.3B-diffusers-gguf/blob/main/Wan-VACE-1.3B-diffusers-Q8_0.gguf"
transformer_path = "https://huggingface.co/samuelchristlie/Wan2.1-VACE-1.3B-GGUF/blob/main/Wan2.1-VACE-1.3B-Q8_0.gguf"
transformer_gguf = WanVACETransformer3DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    config=model_id,
    subfolder="transformer",
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanVACEPipeline.from_pretrained(
    model_id,
    transformer=transformer_gguf,
    vae=vae, 
    torch_dtype=torch.bfloat16
)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()


prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=832,
    height=480,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    conditioning_scale=0.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "output_GGUF1.mp4", fps=16)
Traceback (most recent call last):
  File "C:\aiOWN\diffuser_webui\WanVace_GGUF.py", line 12, in <module>
    transformer_gguf = WanVACETransformer3DModel.from_single_file(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\huggingface_hub\utils\_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\loaders\single_file_model.py", line 451, in from_single_file
    dispatch_model(model, **device_map_kwargs)
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\accelerate\big_modeling.py", line 502, in dispatch_model
    model.to(device)
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\modeling_utils.py", line 1383, in to
    return super().to(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1355, in to
    return self._apply(convert)
           ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 915, in _apply
    module._apply(fn)
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 942, in _apply
    param_applied = fn(param)
                    ^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\torch\nn\modules\module.py", line 1348, in convert
    raise NotImplementedError(
NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
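
This NotImplementedError indicates that at least one parameter was never materialized and is still on the meta device; a checkpoint key that is absent (such as the missing vace_patch_embedding.weight discussed above) leaves its parameter with no data to copy. A generic torch sketch to surface such parameters, assuming one can get hold of the partially initialized module:

import torch

def meta_params(module: torch.nn.Module) -> list[str]:
    # Parameters still on the meta device have no data and will raise the
    # NotImplementedError above as soon as .to(device) touches them.
    return [name for name, p in module.named_parameters() if p.is_meta]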

@nitinmukesh

And if I comment out config and subfolder:

transformer_path = "https://huggingface.co/samuelchristlie/Wan2.1-VACE-1.3B-GGUF/blob/main/Wan2.1-VACE-1.3B-Q8_0.gguf"
transformer_gguf = WanVACETransformer3DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    # config=model_id,
    # subfolder="transformer",
)
  File "C:\aiOWN\diffuser_webui\WanVace_GGUF.py", line 12, in <module>
    transformer_gguf = WanVACETransformer3DModel.from_single_file(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\huggingface_hub\utils\_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\loaders\single_file_model.py", line 389, in from_single_file
    model = cls.from_config(diffusers_model_config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\configuration_utils.py", line 263, in from_config
    model = cls(**init_dict)
            ^^^^^^^^^^^^^^^^
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\configuration_utils.py", line 693, in inner_init
    init(self, *args, **init_kwargs)
  File "C:\Users\nitin\miniconda3\envs\sddw-dev\Lib\site-packages\diffusers\models\transformers\transformer_wan_vace.py", line 214, in __init__
    raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
ValueError: VACE layers [0, 5, 10, 15, 20, 25, 30, 35] exceed the number of transformer layers 30.
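
Reading the error: without an explicit config, the loader inferred a VACE placement of [0, 5, ..., 35], which cannot fit a 30-layer transformer. A minimal illustration of the consistency check that fails (values taken from the error message above, not from the diffusers internals):

# Values from the error message above.
vace_layers = [0, 5, 10, 15, 20, 25, 30, 35]
num_layers = 30
if any(i >= num_layers for i in vace_layers):
    raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")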

@DN6
Collaborator

DN6 commented Jun 26, 2025

@J4BEZ will take a look into the conversion issue

@DN6
Collaborator

DN6 commented Jun 27, 2025

Hmm @J4BEZ, it does look like the issue loading the 1.3B checkpoint you linked is indeed due to the missing key in the file. This version has that key, and loading works fine:
https://huggingface.co/calcuis/wan-gguf/blob/main/wan2.1-v4-vace-1.3b-q4_0.gguf
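
A quick way to confirm the difference between the two files, reusing the key-inspection approach from earlier in this thread (assuming the gguf package):

from gguf.gguf_reader import GGUFReader
from huggingface_hub import hf_hub_download

# Check that the previously missing key exists in this checkpoint.
path = hf_hub_download("calcuis/wan-gguf", "wan2.1-v4-vace-1.3b-q4_0.gguf")
names = {tensor.name for tensor in GGUFReader(path).tensors}
print("vace_patch_embedding.weight" in names)  # expected: True per the comment above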

@DN6 DN6 marked this pull request as ready for review June 27, 2025 09:45
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Collaborator

DN6 commented Jun 27, 2025

@bot /style

Contributor

github-actions bot commented Jun 27, 2025

Style bot fixed some files and pushed the changes.

@J4BEZ
Contributor Author

J4BEZ commented Jun 28, 2025

Thank you so much 🙇‍♂️
I'm truly delighted to hear that everything is loading perfectly!

I'm especially thrilled to contribute to diffusers, which I consider a true work of art.

I deeply appreciate your continued dedication to supporting the open-source ecosystem.

Wishing everyone a peaceful and smooth end to the week.

@juntaosun

Great, when will it be merged so GGUF can be loaded?

@yiyixuxu
Collaborator

yiyixuxu commented Jul 2, 2025

cc @DN6 looks like it's ready to merge now?

@DN6 DN6 merged commit 0e95aa8 into huggingface:main Jul 2, 2025
12 checks passed
tolgacangoz pushed a commit to tolgacangoz/diffusers that referenced this pull request Jul 5, 2025
…ansformer` (huggingface#11807)

* add `WandVACETransformer3DModel` in `SINGLE_FILE_LOADABLE_CLASSES`

* add rename keys for `VACE`

add rename keys for `VACE`

* fix typo

Sincere thanks to @nitinmukesh 🙇‍♂️

* support for `1.3B VACE` model

Sincere thanks to @nitinmukesh again🙇‍♂️

* update

* update

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
@tin2tin

tin2tin commented Jul 7, 2025

Trying the code above with the weight DN6 linked throws an error.

The code:

from typing import List
import torch
import PIL.Image
from diffusers import AutoencoderKLWan, WanVACEPipeline, WanVACETransformer3DModel
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image, load_video
from diffusers import GGUFQuantizationConfig

model_id = "a-r-r-o-w/Wan-VACE-1.3B-diffusers"
# transformer_path = "https://huggingface.co/newgenai79/Wan-VACE-1.3B-diffusers-gguf/blob/main/Wan-VACE-1.3B-diffusers-Q8_0.gguf"
transformer_path = "https://huggingface.co/calcuis/wan-gguf/blob/main/wan2.1-v4-vace-1.3b-q4_0.gguf"
transformer_gguf = WanVACETransformer3DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    config=model_id,
    subfolder="transformer",
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanVACEPipeline.from_pretrained(
    model_id,
    transformer=transformer_gguf,
    vae=vae, 
    torch_dtype=torch.bfloat16
)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()


prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=832,
    height=480,
    num_frames=81,
    num_inference_steps=30,
    guidance_scale=5.0,
    conditioning_scale=0.0,
    generator=torch.Generator().manual_seed(0),
).frames[0]
export_to_video(output, "output_GGUF1.mp4", fps=16)

The Error:

config.json: 100%|████████████████████████████████████████████████████████████████████████████| 662/662 [00:00<?, ?B/s]
config.json: 100%|████████████████████████████████████████████████████████████████████████████| 724/724 [00:00<?, ?B/s]
diffusion_pytorch_model.safetensors: 100%|██████████████████████████████████████████| 508M/508M [00:15<00:00, 33.7MB/s]
model_index.json: 100%|███████████████████████████████████████████████████████████████████████| 408/408 [00:00<?, ?B/s]
scheduler_config.json: 100%|██████████████████████████████████████████████████████████████████| 751/751 [00:00<?, ?B/s]
special_tokens_map.json: 7.08kB [00:00, 1.23MB/s]
config.json: 100%|████████████████████████████████████████████████████████████████████████████| 850/850 [00:00<?, ?B/s]
model.safetensors.index.json: 22.5kB [00:00, ?B/s]
tokenizer_config.json: 61.8kB [00:00, ?B/s]
spiece.model: 100%|███████████████████████████████████████████████████████████████| 4.55M/4.55M [00:00<00:00, 10.3MB/s]
tokenizer.json: 100%|█████████████████████████████████████████████████████████████| 16.8M/16.8M [00:03<00:00, 4.45MB/s]
model-00003-of-00003.safetensors: 100%|███████████████████████████████████████████| 1.44G/1.44G [02:59<00:00, 8.03MB/s]
model-00002-of-00003.safetensors: 100%|███████████████████████████████████████████| 4.98G/4.98G [03:52<00:00, 21.4MB/s]
model-00001-of-00003.safetensors: 100%|███████████████████████████████████████████| 4.94G/4.94G [04:25<00:00, 18.6MB/s]
Fetching 13 files: 100%|███████████████████████████████████████████████████████████████| 13/13 [04:26<00:00, 20.49s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 146.18it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████| 5/5 [00:00<00:00, 11.13it/s]
  0%|                                                                                           | 0/30 [00:02<?, ?it/s]
Error: Python: Traceback (most recent call last):
  File ".\python\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\pipelines\wan\pipeline_wan_vace.py", line 909, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\accelerate\hooks.py", line 175, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\models\transformers\transformer_wan_vace.py", line 324, in forward
    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
                                                                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\models\transformers\transformer_wan.py", line 178, in forward
    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\models\embeddings.py", line 1308, in forward
    sample = self.linear_1(sample)
             ^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\torch\nn\modules\module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".\python\Lib\site-packages\diffusers\quantizers\gguf\utils.py", line 460, in forward
    output = torch.nn.functional.linear(inputs, weight, bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 must have the same dtype, but got Byte and BFloat16
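
The mismatch happens inside diffusers' GGUF linear layer, so the weight reaching torch.nn.functional.linear was still raw quantized bytes rather than bfloat16. A hedged sketch for inspecting the offending layer, with the module path inferred from the traceback and the quant_type attribute assumed rather than guaranteed:

# Module path inferred from the traceback above; quant_type is an assumed
# attribute name for the GGUF quantization tag on the weight.
w = pipe.transformer.condition_embedder.time_embedder.linear_1.weight
print(type(w).__name__, w.dtype, w.shape, getattr(w, "quant_type", None))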

@nitinmukesh

nitinmukesh commented Jul 7, 2025

1.3B didn't work for me either, and I gave up. Not sure what is wrong.

wan2.1-v4-vace-1.3b-q8_0.gguf is only ~2GB and would have been very helpful for low VRAM.

@tin2tin

tin2tin commented Jul 7, 2025

@nitinmukesh Can you get 14B working on consumer hardware?

@nitinmukesh

nitinmukesh commented Jul 7, 2025

@tin2tin

I would not even try it, given that I have only 8 GB + 16 GB.

It should easily work for you, as you have 24 GB VRAM.

https://huggingface.co/calcuis/wan-gguf/resolve/main/wan2.1-v2-vace-14b-q4_0.gguf?download=true
~10 GB

@nitinmukesh

Also logged the issue here
#11878

@tin2tin

tin2tin commented Jul 7, 2025

14B on a 4090 just sits there like this for a very long time, until I kill it:
[screenshot]

@nitinmukesh

Exactly the same problem I face with 1.3B too.
Something is wrong; not sure why.

@nitinmukesh

@tin2tin

Can you try removing these and see if it works:
config=model_id,
subfolder="transformer",

transformer_gguf = WanVACETransformer3DModel.from_single_file(
    transformer_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
