Commit c28d694

[Community Pipeline] Checkpoint Merger based on Automatic1111 (huggingface#1472)
* Add checkpoint_merger pipeline
* Added missing docs for a parameter.
* Formatting fixes.
* Fixed code quality issues.
* Bug fix: Off by 1 index
* Added docs for pipeline
1 parent 5177e65 commit c28d694

File tree

2 files changed: +309 −0 lines changed

examples/community/README.md

+47
@@ -23,6 +23,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |

@@ -727,3 +728,49 @@ image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]

![diffusers_euler](https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/k_diffusion/astronaut_euler_k_diffusion.png)

### Checkpoint Merger Pipeline
Based on the AUTOMATIC1111/webui checkpoint merger. This is a custom pipeline that merges up to 3 pretrained model checkpoints, as long as they are in the Hugging Face model_index.json format.

The checkpoint merging is currently memory intensive, as it modifies the weights of a DiffusionPipeline object in place. Expect at least 13GB of RAM usage on Kaggle GPU kernels; on Colab you might run out of the 12GB of memory even while merging just two checkpoints.

Usage:
```python
from diffusers import DiffusionPipeline

# Returns a CheckpointMergerPipeline class that allows you to merge checkpoints.
# The checkpoint passed here is ignored, but still pass one of the checkpoints
# you plan to merge, for convenience.
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger")

# There are multiple possible scenarios; the pipeline with the merged checkpoints
# is returned in all of them.

# Compatible checkpoints, i.e. matching model_index.json files. The meta attributes
# in model_index.json (attrs prefixed with _) are ignored during the comparison.
merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-2"], interp="sigmoid", alpha=0.4)

# Checkpoints with incompatible model_index.json files, where a merge might still be
# possible. Use force=True to ignore model_index.json compatibility.
merged_pipe_1 = pipe.merge(["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion"], force=True, interp="sigmoid", alpha=0.4)

# Three-checkpoint merging. Only the "add_difference" method actually uses all three
# checkpoints; any other option ignores the 3rd checkpoint.
merged_pipe_2 = pipe.merge(["CompVis/stable-diffusion-v1-4", "hakurei/waifu-diffusion", "prompthero/openjourney"], force=True, interp="add_difference", alpha=0.4)

prompt = "An astronaut riding a horse on Mars"

image = merged_pipe(prompt).images[0]
```
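
Once merged, the pipeline behaves like any other `DiffusionPipeline`, so you can persist it and skip re-merging on later runs. A minimal sketch, assuming the `merged_pipe` from the usage above (the output directory name is hypothetical; `save_pretrained`/`from_pretrained` are the standard diffusers serialization APIs):

```python
# Write the merged weights to disk in the usual diffusers layout ...
merged_pipe.save_pretrained("./merged_checkpoint")

# ... and reload them later without repeating the merge.
from diffusers import DiffusionPipeline
restored_pipe = DiffusionPipeline.from_pretrained("./merged_checkpoint")
```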

Some examples along with the merge details:

1. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion"; Sigmoid interpolation; alpha = 0.8

![Stable plus Waifu Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stability_v1_4_waifu_sig_0.8.png)

2. "hakurei/waifu-diffusion" + "prompthero/openjourney"; Inverse Sigmoid interpolation; alpha = 0.8

![Waifu plus openjourney Inv Sigmoid 0.8](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/waifu_openjourney_inv_sig_0.8.png)

3. "CompVis/stable-diffusion-v1-4" + "hakurei/waifu-diffusion" + "prompthero/openjourney"; Add Difference interpolation; alpha = 0.5

![Stable plus Waifu plus openjourney add_diff 0.5](https://huggingface.co/datasets/NagaSaiAbhinay/CheckpointMergerSamples/resolve/main/stable_waifu_openjourney_add_diff_0.5.png)

examples/community/checkpoint_merger.py

+262

@@ -0,0 +1,262 @@
import glob
import math
import os
from typing import Dict, List, Union

import torch

from diffusers import DiffusionPipeline, __version__
from diffusers.pipeline_utils import (
    CONFIG_NAME,
    DIFFUSERS_CACHE,
    ONNX_WEIGHTS_NAME,
    SCHEDULER_CONFIG_NAME,
    WEIGHTS_NAME,
)
from huggingface_hub import snapshot_download


class CheckpointMergerPipeline(DiffusionPipeline):
    """
    A class that supports merging diffusion models based on the discussion here:
    https://github.com/huggingface/diffusers/issues/877

    Example usage:

    pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="checkpoint_merger")

    merged_pipe = pipe.merge(["CompVis/stable-diffusion-v1-4", "prompthero/openjourney"], interp='inv_sigmoid', alpha=0.8, force=True)

    merged_pipe.to('cuda')

    prompt = "An astronaut riding a unicycle on Mars"

    results = merged_pipe(prompt)

    ## For more details, see the docstring for the merge method.
    """

    def __init__(self):
        super().__init__()

    def _compare_model_configs(self, dict0, dict1):
        if dict0 == dict1:
            return True
        else:
            config0, meta_keys0 = self._remove_meta_keys(dict0)
            config1, meta_keys1 = self._remove_meta_keys(dict1)
            if config0 == config1:
                print(f"Warning!: Mismatch in keys {meta_keys0} and {meta_keys1}.")
                return True
        return False

    def _remove_meta_keys(self, config_dict: Dict):
        meta_keys = []
        temp_dict = config_dict.copy()
        for key in config_dict.keys():
            if key.startswith("_"):
                temp_dict.pop(key)
                meta_keys.append(key)
        return (temp_dict, meta_keys)
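
    # Illustration (hypothetical config dict): _remove_meta_keys({"_class_name": "StableDiffusionPipeline", "unet": ["diffusers", "UNet2DConditionModel"]})
    # returns ({"unet": ["diffusers", "UNet2DConditionModel"]}, ["_class_name"]), i.e. underscore-prefixed
    # meta attributes are split out so _compare_model_configs can ignore them when checking compatibility.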

    @torch.no_grad()
    def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
        """
        Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints (weights) of the models passed
        in the argument 'pretrained_model_name_or_path_list' as a list.

        Parameters:
        -----------
            pretrained_model_name_or_path_list : A list of valid pretrained model names in the Hugging Face hub or paths to locally stored models in the Hugging Face format.

            **kwargs:
                Supports all the default DiffusionPipeline.get_config_dict kwargs, viz.
                cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.

                alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. An alpha
                of 0.8 means the first model's checkpoints affect the final result far less than they would with an alpha of 0.2.

                interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
                Passing None uses the default interpolation, which is weighted-sum interpolation. For merging three checkpoints, only "add_difference" is supported.

                force - Whether to ignore mismatches in model_index.json for the current models. Defaults to False.
        """
        # Default kwargs from DiffusionPipeline
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        device_map = kwargs.pop("device_map", None)

        alpha = kwargs.pop("alpha", 0.5)
        interp = kwargs.pop("interp", None)

        print("Received list", pretrained_model_name_or_path_list)

        checkpoint_count = len(pretrained_model_name_or_path_list)
        # Ignore the result of the model_index.json comparison of the checkpoints
        force = kwargs.pop("force", False)

        # If fewer than 2 checkpoints, there is nothing to merge. If more than 3, it is not supported for now.
        if checkpoint_count > 3 or checkpoint_count < 2:
            raise ValueError(
                "Received incorrect number of checkpoints to merge. Ensure that either 2 or 3 checkpoints are being"
                " passed."
            )

        print("Received the right number of checkpoints")
        # Validate that the checkpoints can be merged
        # Step 1: Load the model configs and compare the checkpoints. We'll compare the model_index.json files first, while ignoring the keys starting with '_'
        config_dicts = []
        for pretrained_model_name_or_path in pretrained_model_name_or_path_list:
            if not os.path.isdir(pretrained_model_name_or_path):
                config_dict = DiffusionPipeline.get_config_dict(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    resume_download=resume_download,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                )
                config_dicts.append(config_dict)

        comparison_result = True
        for idx in range(1, len(config_dicts)):
            comparison_result &= self._compare_model_configs(config_dicts[idx - 1], config_dicts[idx])
            if not force and comparison_result is False:
                raise ValueError("Incompatible checkpoints. Please check model_index.json for the models.")
        print(config_dicts[0], config_dicts[1])
        print("Compatible model_index.json files found")
        # Step 2: Basic validation has succeeded. Let's download the models and save them into our local files.
        cached_folders = []
        for pretrained_model_name_or_path, config_dict in zip(pretrained_model_name_or_path_list, config_dicts):
            folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
            allow_patterns = [os.path.join(k, "*") for k in folder_names]
            allow_patterns += [
                WEIGHTS_NAME,
                SCHEDULER_CONFIG_NAME,
                CONFIG_NAME,
                ONNX_WEIGHTS_NAME,
                DiffusionPipeline.config_name,
            ]
            requested_pipeline_class = config_dict.get("_class_name")
            user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}

            cached_folder = snapshot_download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                allow_patterns=allow_patterns,
                user_agent=user_agent,
            )
            print("Cached Folder", cached_folder)
            cached_folders.append(cached_folder)
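
        # At this point each entry in cached_folders is a local snapshot laid out like a
        # diffusers repo (hypothetical example): <cache>/snapshots/<revision>/ containing
        # model_index.json plus one subfolder per component (unet/, vae/, text_encoder/, ...).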

        # Step 3:
        # Load the first checkpoint as a diffusion pipeline and modify its modules' state_dicts in place
        final_pipe = DiffusionPipeline.from_pretrained(
            cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
        )

        checkpoint_path_2 = None
        if len(cached_folders) > 2:
            checkpoint_path_2 = cached_folders[2]

        if interp == "sigmoid":
            theta_func = CheckpointMergerPipeline.sigmoid
        elif interp == "inv_sigmoid":
            theta_func = CheckpointMergerPipeline.inv_sigmoid
        elif interp == "add_difference":
            theta_func = CheckpointMergerPipeline.add_difference
        else:
            theta_func = CheckpointMergerPipeline.weighted_sum
        # Find each module's state dict.
        for attr in final_pipe.config.keys():
            if not attr.startswith("_"):
                checkpoint_path_1 = os.path.join(cached_folders[1], attr)
                if os.path.exists(checkpoint_path_1):
                    files = glob.glob(os.path.join(checkpoint_path_1, "*.bin"))
                    checkpoint_path_1 = files[0] if len(files) > 0 else None
                # Resolve the third checkpoint's weight file for this attr without
                # clobbering the folder path needed by later iterations.
                attr_checkpoint_path_2 = None
                if checkpoint_path_2 is not None and os.path.exists(os.path.join(checkpoint_path_2, attr)):
                    files = glob.glob(os.path.join(checkpoint_path_2, attr, "*.bin"))
                    attr_checkpoint_path_2 = files[0] if len(files) > 0 else None
                # For an attr, if both checkpoint_path_1 and attr_checkpoint_path_2 are None, ignore it.
                # If at least one is present, handle it according to the interp method, of course only if the state_dict keys match.
                if checkpoint_path_1 is None and attr_checkpoint_path_2 is None:
                    print("SKIPPING ATTR ", attr)
                    continue
                try:
                    module = getattr(final_pipe, attr)
                    theta_0 = module.state_dict()

                    theta_1 = torch.load(checkpoint_path_1)
                    theta_2 = torch.load(attr_checkpoint_path_2) if attr_checkpoint_path_2 else None

                    if not theta_0.keys() == theta_1.keys():
                        print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
                        continue
                    if theta_2 and not theta_1.keys() == theta_2.keys():
                        print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")
                        continue
                except Exception:
                    print("SKIPPING ATTR ", attr)
                    continue
                print("Found dicts for")
                print(attr)
                print(checkpoint_path_1)
                print(attr_checkpoint_path_2)

                for key in theta_0.keys():
                    if theta_2:
                        theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key], alpha)
                    else:
                        theta_0[key] = theta_func(theta_0[key], theta_1[key], None, alpha)

                del theta_1
                del theta_2
                module.load_state_dict(theta_0)

                del theta_0
        print("Diffusion pipeline successfully updated with merged weights")

        return final_pipe

    @staticmethod
    def weighted_sum(theta0, theta1, theta2, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

    # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    @staticmethod
    def sigmoid(theta0, theta1, theta2, alpha):
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
    @staticmethod
    def inv_sigmoid(theta0, theta1, theta2, alpha):
        alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
        return theta0 + ((theta1 - theta0) * alpha)

    @staticmethod
    def add_difference(theta0, theta1, theta2, alpha):
        return theta0 + (theta1 - theta2) * (1.0 - alpha)
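

# A minimal sanity check for the interpolation helpers above (a hedged sketch, not part of
# the pipeline; run this file directly to see the scalar behavior of each method).
if __name__ == "__main__":
    zeros, ones = torch.zeros(2), torch.ones(2)
    halves = torch.full((2,), 0.5)
    # weighted_sum: 0.6 * 0 + 0.4 * 1 -> 0.4
    print(CheckpointMergerPipeline.weighted_sum(zeros, ones, None, 0.4))
    # sigmoid (smoothstep): alpha 0.4 is remapped to 0.4 * 0.4 * (3 - 0.8) = 0.352
    print(CheckpointMergerPipeline.sigmoid(zeros, ones, None, 0.4))
    # add_difference: 0 + (1 - 0.5) * (1 - 0.5) -> 0.25
    print(CheckpointMergerPipeline.add_difference(zeros, ones, halves, 0.5))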
