
Support ControlNet models with different number of channels in control images #3815


JCBrouwer
Contributor

What does this PR do?

I'd like to be able to load TemporalNet2 into a ControlNetModel.

This ControlNet concatenates the previous image in a video and the optical flow together to form a 6-channel tensor as the control image.
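As a rough illustration of the shape arithmetic only (this is not TemporalNet2's actual preprocessing code), concatenating two 3-channel images along the channel axis yields a 6-channel control image. The helper name concat_channels_shape, the NCHW layout, the 512x512 resolution, and the assumption that the optical flow is rendered as a 3-channel image are all hypothetical choices made for this sketch.

```python
def concat_channels_shape(*shapes):
    """Return the NCHW shape produced by concatenating tensors along the channel axis."""
    batch, _, height, width = shapes[0]
    for n, _, h, w in shapes:
        # Channel-wise concatenation requires every other dimension to match.
        if (n, h, w) != (batch, height, width):
            raise ValueError(f"incompatible shapes: {shapes}")
    channels = sum(shape[1] for shape in shapes)
    return (batch, channels, height, width)


prev_frame = (1, 3, 512, 512)    # assumed: RGB previous frame
optical_flow = (1, 3, 512, 512)  # assumed: flow rendered as a 3-channel image
control_image = concat_channels_shape(prev_frame, optical_flow)
print(control_image)
```

With real tensors the same step would typically be a single concatenation along dimension 1, so the first convolution of the conditioning embedding has to accept 6 input channels instead of 3.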

At the moment, the convert_original_controlnet_to_diffusers.py script cannot convert this checkpoint. It raises the following error because the control image has more channels than expected:

Traceback
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/jcbgb/anaconda3/envs/maua/lib/python3.10/runpy.py:196 in _run_module_as_main               │
│                                                                                                  │
│   193 │   main_globals = sys.modules["__main__"].__dict__                                        │
│   194 │   if alter_argv:                                                                         │
│   195 │   │   sys.argv[0] = mod_spec.origin                                                      │
│ ❱ 196 │   return _run_code(code, main_globals, None,                                             │
│   197 │   │   │   │   │    "__main__", mod_spec)                                                 │
│   198                                                                                            │
│   199 def run_module(mod_name, init_globals=None,                                                │
│                                                                                                  │
│ /home/jcbgb/anaconda3/envs/maua/lib/python3.10/runpy.py:86 in _run_code                          │
│                                                                                                  │
│    83 │   │   │   │   │      __loader__ = loader,                                                │
│    84 │   │   │   │   │      __package__ = pkg_name,                                             │
│    85 │   │   │   │   │      __spec__ = mod_spec)                                                │
│ ❱  86 │   exec(code, run_globals)                                                                │
│    87 │   return run_globals                                                                     │
│    88                                                                                            │
│    89 def _run_module_code(code, init_globals=None,                                              │
│                                                                                                  │
│ /HUGE/Code/verwarming/diffusers/scripts/convert_original_controlnet_to_diffusers.py:96 in        │
│ <module>                                                                                         │
│                                                                                                  │
│    93 │                                                                                          │
│    94 │   args = parser.parse_args()                                                             │
│    95 │                                                                                          │
│ ❱  96 │   controlnet = download_controlnet_from_original_ckpt(                                   │
│    97 │   │   checkpoint_path=args.checkpoint_path,                                              │
│    98 │   │   original_config_file=args.original_config_file,                                    │
│    99 │   │   image_size=args.image_size,                                                        │
│                                                                                                  │
│ /home/jcbgb/anaconda3/envs/maua/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusio │
│ n/convert_from_ckpt.py:1415 in download_controlnet_from_original_ckpt                            │
│                                                                                                  │
│   1412 │   if "control_stage_config" not in original_config.model.params:                        │
│   1413 │   │   raise ValueError("`control_stage_config` not present in original config")         │
│   1414 │                                                                                         │
│ ❱ 1415 │   controlnet_model = convert_controlnet_checkpoint(                                     │
│   1416 │   │   checkpoint,                                                                       │
│   1417 │   │   original_config,                                                                  │
│   1418 │   │   checkpoint_path,                                                                  │
│                                                                                                  │
│ /home/jcbgb/anaconda3/envs/maua/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusio │
│ n/convert_from_ckpt.py:1002 in convert_controlnet_checkpoint                                     │
│                                                                                                  │
│    999 │   │   skip_extract_state_dict=skip_extract_state_dict,                                  │
│   1000 │   )                                                                                     │
│   1001 │                                                                                         │
│ ❱ 1002 │   controlnet_model.load_state_dict(converted_ctrl_checkpoint)                           │
│   1003 │                                                                                         │
│   1004 │   return controlnet_model                                                               │
│   1005                                                                                           │
│                                                                                                  │
│ /home/jcbgb/anaconda3/envs/maua/lib/python3.10/site-packages/torch/nn/modules/module.py:2041 in  │
│ load_state_dict                                                                                  │
│                                                                                                  │
│   2038 │   │   │   │   │   │   ', '.join('"{}"'.format(k) for k in missing_keys)))               │
│   2039 │   │                                                                                     │
│   2040 │   │   if len(error_msgs) > 0:                                                           │
│ ❱ 2041 │   │   │   raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(     │
│   2042 │   │   │   │   │   │   │      self.__class__.__name__, "\n\t".join(error_msgs)))         │
│   2043 │   │   return _IncompatibleKeys(missing_keys, unexpected_keys)                           │
│   2044                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Error(s) in loading state_dict for ControlNetModel:
        size mismatch for controlnet_cond_embedding.conv_in.weight: copying a param with shape torch.Size([16, 6, 3, 3]) from checkpoint, the shape in current model is torch.Size([16, 3, 3, 3]).
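
To make the mismatch concrete: a 2D convolution weight is laid out as (out_channels, in_channels, kernel_height, kernel_width), so the checkpoint expects 6 input channels while the freshly built model expects 3. The sketch below reproduces that comparison using plain tuples. The helper find_shape_mismatches is hypothetical and only mimics the kind of check that produces the error; it is not how torch implements load_state_dict.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Map each shared parameter name to (checkpoint_shape, model_shape) when they differ."""
    mismatches = {}
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        # Only compare parameters that exist on both sides.
        if model_shape is not None and tuple(model_shape) != tuple(ckpt_shape):
            mismatches[name] = (tuple(ckpt_shape), tuple(model_shape))
    return mismatches


key = "controlnet_cond_embedding.conv_in.weight"
checkpoint_shapes = {key: (16, 6, 3, 3)}  # shape reported for the checkpoint
model_shapes = {key: (16, 3, 3, 3)}       # shape reported for the current model

mismatches = find_shape_mismatches(checkpoint_shapes, model_shapes)
for name, (ckpt_shape, model_shape) in mismatches.items():
    # Index 1 of a conv weight shape is the number of input channels.
    print(f"{name}: checkpoint in_channels={ckpt_shape[1]}, model in_channels={model_shape[1]}")
```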

Passing the --num_in_channels argument doesn't help, as it appears to change the input channels of the whole UNet rather than only those of the control image (?).

Diving into the ControlNetModel code, I found that the parameter causing the above error is sized by the conditioning_channels argument, which wasn't exposed in the UNet config. I've added it to the UNet config creation function in convert_from_ckpt.py, reading its value from the hint_channels entry in the ControlNet .yaml/.json file.
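
A minimal sketch of the idea, using plain dictionaries rather than the real config objects: read hint_channels from the control_stage_config section and pass it on as conditioning_channels. The function name conditioning_channels_from_config, the exact nesting of the config, and the fallback of 3 (taken from the 3-channel shape in the error above) are assumptions made for illustration, not the code in convert_from_ckpt.py.

```python
DEFAULT_CONDITIONING_CHANNELS = 3  # assumed fallback, matching the 3-channel model shape in the error


def conditioning_channels_from_config(original_config):
    """Return the number of control-image channels declared in an original ControlNet config."""
    model_params = original_config["model"]["params"]
    if "control_stage_config" not in model_params:
        # Mirrors the check visible in the traceback.
        raise ValueError("`control_stage_config` not present in original config")
    control_params = model_params["control_stage_config"]["params"]
    return control_params.get("hint_channels", DEFAULT_CONDITIONING_CHANNELS)


# Hypothetical, heavily trimmed config for a 6-channel ControlNet.
six_channel_config = {
    "model": {"params": {"control_stage_config": {"params": {"hint_channels": 6}}}}
}

unet_config = {"conditioning_channels": conditioning_channels_from_config(six_channel_config)}
print(unet_config)
```

Falling back to 3 when hint_channels is absent keeps the behaviour unchanged for configs that do not declare it.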

I'm not 100% sure that hint_channels is the right config value to read from (I'm not too familiar with ControlNet), but with this change I can convert the TemporalNet2 checkpoint as expected and load it into a ControlNetModel. The results look fine on some test videos.

Are there any tests I can add to verify that the change works?

@patrickvonplaten @sayakpaul

Contributor

@patrickvonplaten patrickvonplaten left a comment

Works for me! @williamberman wdyt?

@sayakpaul
Member

Hello @JCBrouwer. Thanks for your contribution. While @williamberman reviews your PR, I wanted to check if you'd like to:

  • Host the converted ControlNet weights on the HF Hub along with some guidance on the usage.
  • Include some examples that the users can expect from the pipeline.
  • Include the steps used to get the checkpoints successfully converted to the diffusers format.

I think the community would benefit a lot from this info.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jun 20, 2023

The documentation is not available anymore as the PR was closed or merged.

@JCBrouwer
Contributor Author

Hi there, thanks for the response!

  • I've hosted the converted model in this repository along with a script to do vid2vid styling based on a text prompt.

  • I can take a look at generating some example outputs and comparing naive batched img2img, TemporalNet, and TemporalNet2.

  • With the change in this PR applied, I could just use the standard convert_original_controlnet_to_diffusers.py script:

git clone https://github.com/huggingface/diffusers
cd diffusers
git clone https://huggingface.co/CiaraRowles/TemporalNet2
python scripts/convert_original_controlnet_to_diffusers.py --checkpoint_path TemporalNet2/temporalnetversion2.ckpt --original_config_file TemporalNet2/temporalnetversion2.yaml --dump_path ./TemporalNet2Diffusers --extract_ema --to_safetensors --device cpu

@anotherjesse

@JCBrouwer - thank you for sharing your progress!

I tried running it, but my results were poor.

If you have any known-good inputs for your script, I can try them out and share any issues I see.

@patrickvonplaten patrickvonplaten merged commit ef3844d into huggingface:main Jun 21, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…l images (huggingface#3815)

support ControlNet models with a different hint_channels value (e.g. TemporalNet2)
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
…l images (huggingface#3815)

support ControlNet models with a different hint_channels value (e.g. TemporalNet2)