Skip to content

Commit cd1b8d7

Browse files
[WIP] Refactor UniDiffuser Pipeline and Tests (huggingface#4948)
* Add VAE slicing and tiling methods. * Switch to using VaeImageProcessing for preprocessing and postprocessing of images. * Rename the VaeImageProcessor to vae_image_processor to avoid a name clash with the CLIPImageProcessor (image_processor). * Remove the postprocess() function because we're using a VaeImageProcessor instead. * Remove UniDiffuserPipeline.decode_image_latents because we're using VaeImageProcessor instead. * Refactor generating text from text latents into a decode_text_latents method. * Add enable_full_determinism() to UniDiffuser tests. * make style * Add PipelineLatentTesterMixin to UniDiffuserPipelineFastTests. * Remove enable_model_cpu_offload since it is now part of DiffusionPipeline. * Rename the VaeImageProcessor instance to self.image_processor for consistency with other pipelines and rename the CLIPImageProcessor instance to clip_image_processor to avoid a name clash. * Update UniDiffuser conversion script. * Make safe_serialization configurable in UniDiffuser conversion script. * Rename image_processor to clip_image_processor in UniDiffuser tests. * Add PipelineKarrasSchedulerTesterMixin to UniDiffuserPipelineFastTests. * Add initial test for compiling the UniDiffuser model (not tested yet). * Update encode_prompt and _encode_prompt to match that of StableDiffusionPipeline. * Turn off standard classifier-free guidance for now. * make style * make fix-copies * apply suggestions from review --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent db91e71 commit cd1b8d7

File tree

3 files changed

+270
-153
lines changed

3 files changed

+270
-153
lines changed

scripts/convert_unidiffuser_to_diffusers.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
7373
new_item = new_item.replace("norm.weight", "group_norm.weight")
7474
new_item = new_item.replace("norm.bias", "group_norm.bias")
7575

76-
new_item = new_item.replace("q.weight", "query.weight")
77-
new_item = new_item.replace("q.bias", "query.bias")
76+
new_item = new_item.replace("q.weight", "to_q.weight")
77+
new_item = new_item.replace("q.bias", "to_q.bias")
7878

79-
new_item = new_item.replace("k.weight", "key.weight")
80-
new_item = new_item.replace("k.bias", "key.bias")
79+
new_item = new_item.replace("k.weight", "to_k.weight")
80+
new_item = new_item.replace("k.bias", "to_k.bias")
8181

82-
new_item = new_item.replace("v.weight", "value.weight")
83-
new_item = new_item.replace("v.bias", "value.bias")
82+
new_item = new_item.replace("v.weight", "to_v.weight")
83+
new_item = new_item.replace("v.bias", "to_v.bias")
8484

85-
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
86-
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
85+
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
86+
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
8787

8888
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
8989

@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
9292
return mapping
9393

9494

95+
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
96+
def conv_attn_to_linear(checkpoint):
97+
keys = list(checkpoint.keys())
98+
attn_keys = ["query.weight", "key.weight", "value.weight"]
99+
for key in keys:
100+
if ".".join(key.split(".")[-2:]) in attn_keys:
101+
if checkpoint[key].ndim > 2:
102+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
103+
elif "proj_attn.weight" in key:
104+
if checkpoint[key].ndim > 2:
105+
checkpoint[key] = checkpoint[key][:, :, 0]
106+
107+
95108
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
96109
# config.num_head_channels => num_head_channels
97110
def assign_to_checkpoint(
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
104117
):
105118
"""
106119
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
107-
attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
108-
checkpoint.
120+
attention layers, and takes into account additional replacements that may arise.
121+
122+
Assigns the weights to the new checkpoint.
109123
"""
110124
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
111125

@@ -143,25 +157,16 @@ def assign_to_checkpoint(
143157
new_path = new_path.replace(replacement["old"], replacement["new"])
144158

145159
# proj_attn.weight has to be converted from conv 1D to linear
146-
if "proj_attn.weight" in new_path:
160+
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
161+
shape = old_checkpoint[path["old"]].shape
162+
if is_attn_weight and len(shape) == 3:
147163
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
164+
elif is_attn_weight and len(shape) == 4:
165+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
148166
else:
149167
checkpoint[new_path] = old_checkpoint[path["old"]]
150168

151169

152-
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
153-
def conv_attn_to_linear(checkpoint):
154-
keys = list(checkpoint.keys())
155-
attn_keys = ["query.weight", "key.weight", "value.weight"]
156-
for key in keys:
157-
if ".".join(key.split(".")[-2:]) in attn_keys:
158-
if checkpoint[key].ndim > 2:
159-
checkpoint[key] = checkpoint[key][:, :, 0, 0]
160-
elif "proj_attn.weight" in key:
161-
if checkpoint[key].ndim > 2:
162-
checkpoint[key] = checkpoint[key][:, :, 0]
163-
164-
165170
def create_vae_diffusers_config(config_type):
166171
# Hardcoded for now
167172
if args.config_type == "test":
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
339344
return text_decoder_config
340345

341346

342-
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
347+
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
343348
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
344349
"""
345350
Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
@@ -674,6 +679,11 @@ def convert_caption_decoder_to_diffusers(ckpt, diffusers_model):
674679
type=int,
675680
help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
676681
)
682+
parser.add_argument(
683+
"--safe_serialization",
684+
action="store_true",
685+
help="Whether to use safetensors/safe serialization when saving the pipeline.",
686+
)
677687

678688
args = parser.parse_args()
679689

@@ -766,11 +776,11 @@ def convert_caption_decoder_to_diffusers(ckpt, diffusers_model):
766776
vae=vae,
767777
text_encoder=text_encoder,
768778
image_encoder=image_encoder,
769-
image_processor=image_processor,
779+
clip_image_processor=image_processor,
770780
clip_tokenizer=clip_tokenizer,
771781
text_decoder=text_decoder,
772782
text_tokenizer=text_tokenizer,
773783
unet=unet,
774784
scheduler=scheduler,
775785
)
776-
pipeline.save_pretrained(args.pipeline_output_path)
786+
pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)

0 commit comments

Comments (0)