Skip to content

Commit f6a5c35

Browse files
[Community] Fix merger (huggingface#2006)
* [Community] Fix merger * finish
1 parent 651c5ad commit f6a5c35

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

examples/community/checkpoint_merger.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
3232
"""
3333

3434
def __init__(self):
35+
self.register_to_config()
3536
super().__init__()
3637

3738
def _compare_model_configs(self, dict0, dict1):
@@ -167,6 +168,7 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
167168
final_pipe = DiffusionPipeline.from_pretrained(
168169
cached_folders[0], torch_dtype=torch_dtype, device_map=device_map
169170
)
171+
final_pipe.to(self.device)
170172

171173
checkpoint_path_2 = None
172174
if len(cached_folders) > 2:
@@ -202,9 +204,9 @@ def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]
202204
theta_0 = theta_0()
203205

204206
update_theta_0 = getattr(module, "load_state_dict")
205-
theta_1 = torch.load(checkpoint_path_1)
207+
theta_1 = torch.load(checkpoint_path_1, map_location="cpu")
206208

207-
theta_2 = torch.load(checkpoint_path_2) if checkpoint_path_2 else None
209+
theta_2 = torch.load(checkpoint_path_2, map_location="cpu") if checkpoint_path_2 else None
208210

209211
if not theta_0.keys() == theta_1.keys():
210212
print("SKIPPING ATTR ", attr, " DUE TO MISMATCH")

0 commit comments

Comments
 (0)