
Commit 1051ca8

Authored by yiyixuxu, Patrick von Platen, and Pedro Cuenca
Stable Diffusion Latent Upscaler (huggingface#2059)
* Modify UNet2DConditionModel
  - allow skipping mid_block
  - add a norm_group_size argument so that the `num_groups` for group norm can be set as `num_channels // norm_group_size`
  - allow the user to set the dimension of the timestep embedding (`time_embed_dim`)
  - make the kernel_size for `conv_in` and `conv_out` configurable
  - add a random Fourier feature layer (`GaussianFourierProjection`) for `time_proj`
  - allow the user to add the time and class embeddings together before passing them through the projection layer: `time_embedding(t_emb + class_label)`
  - add 2 arguments, `attn1_types` and `attn2_types`: with the existing `only_cross_attention` argument set to `True` we get a `BasicTransformerBlock` with 2 cross-attention layers, otherwise a self-attention followed by a cross-attention; the k-upscaler needs blocks with just one cross-attention, or self-attention -> cross-attention, so `attn1_types` and `attn2_types` let the user specify the attention type for each of the 2 positions in every block. Note that `only_cross_attention` is still kept on the UNet for easy configuration, but it is converted to `attn1_type` and `attn2_type` when passed down to the down blocks
  - make the position of the downsample and upsample layers configurable
  - add `skip_freq = "block"`: in the k-upscaler UNet there is only one skip connection per up/down block (instead of one per layer as in the Stable Diffusion UNet)
  - if the user passes an attention_mask to the UNet, prepare the mask and pass a flag to the cross-attention processor so it skips the `prepare_attention_mask` step inside the cross-attention block

  Add up/down blocks for the k-upscaler.

  Modify the CrossAttention class
  - make the `dropout` layer in `to_out` optional
  - `use_conv_proj`: use conv instead of linear for all projection layers (i.e. `to_q`, `to_k`, `to_v`, `to_out`) whenever possible; note that when used for cross-attention, `to_k` and `to_v` have to stay linear because the `encoder_hidden_states` is not 2D
  - `cross_attention_norm`: add an optional layernorm on encoder_hidden_states
  - `attention_dropout`: add an optional dropout on the attention scores

  Adapt BasicTransformerBlock
  - add an ada groupnorm layer to condition the attention input on the timestep embedding
  - allow skipping the FeedForward layer in between the attentions
  - replace the only_cross_attention argument with attn1_type and attn2_type for more flexible configuration

  Update the timestep embedding: add a new act_fn gelu and an optional act_2.

  Modify ResnetBlock2D
  - refactor with the AdaGroupNorm class (the timestep scale-shift normalization)
  - add a `mid_channel` argument: allow the first conv to have a different output dimension from the second conv
  - add an option to use AdaGroupNorm on the input instead of groupnorm
  - add options to add a dropout layer after each conv
  - allow the user to set the bias in conv_shortcut (needed for the k-upscaler)
  - add gelu

  Add a conversion script for the k-upscaler UNet; add the pipeline.

* fix attention mask * fix a typo * fix a bug * make sure model can be used with GPU * make pipeline work with fp16 * fix an error in BasicTransformerBlock * make style * fix typo * some more fixes * uP * up * correct more * some clean-up * clean time proj * up * uP * more changes * remove the upcast_attention=True from unet config * remove attn1_types, attn2_types etc * fix * revert incorrect changes up/down samplers * make style * remove outdated files * Apply suggestions from code review * attention refactor * refactor cross attention * Apply suggestions from code review * update * up * update * Apply suggestions from code review * finish * Update src/diffusers/models/cross_attention.py * more fixes * up * up * up * finish * more corrections of conversion state * act_2 -> act_2_fn * remove dropout_after_conv from ResnetBlock2D * make style * simplify KAttentionBlock * add fast test for latent upscaler pipeline * add slow test * slow test fp16 * make style * add doc string for pipeline_stable_diffusion_latent_upscale * add api doc page for latent upscaler pipeline * deprecate attention mask * clean up embeddings * simplify resnet * up * clean up resnet * up * correct more * up * up * improve a bit more * correct more * more clean-ups * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Patrick von Platen <[email protected]> * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Patrick von Platen <[email protected]> * add docstrings for new unet config * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen <[email protected]> * # Copied from * encode the image if not latent * remove force casting vae to fp32 * fix * add comments about preconditioning parameters from k-diffusion paper * attn1_type, attn2_type -> add_self_attention * clean up get_down_block and get_up_block * fix * fixed a typo(?) in ada group norm * update slice attention processor for cross attention * update slice * fix fast test * update the checkpoint * finish tests * fix-copies * fix-copy for modeling_text_unet.py * make style * make style * fix f-string * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen <[email protected]> * fix import * correct changes * fix resnet * make fix-copies * correct euler scheduler * add missing #copied from for preprocess * revert * fix * fix copies * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca <[email protected]> * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca <[email protected]> * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca <[email protected]> * Update docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/models/cross_attention.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Pedro Cuenca <[email protected]> * clean up conversion script * KDownsample2d,KUpsample2d -> KDownsample2D,KUpsample2D * more * Update src/diffusers/models/unet_2d_condition.py Co-authored-by: Pedro Cuenca <[email protected]> * remove prepare_extra_step_kwargs * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py Co-authored-by: Patrick von Platen <[email protected]> * fix a typo in timestep embedding * remove num_image_per_prompt * fix fasttest * make style + fix-copies * fix * fix xformer test * fix style * doc string * make style * fix-copies * docstring for time_embedding_norm * make style * final finishes * make fix-copies * fix tests

---------

Co-authored-by: yiyixuxu <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
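
For orientation, the two-stage use the new pipeline is designed for looks roughly like the following. This is a minimal sketch, not part of the commit: the upscaler checkpoint id comes from the docs added in this PR, while the base checkpoint, prompt, and argument values are only illustrative.

```python
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline

# stage 1 model: any Stable Diffusion checkpoint (illustrative choice)
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# stage 2 model: the latent upscaler added by this PR
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut, high resolution, ultra realistic"
generator = torch.manual_seed(33)

# stage 1: generate low-resolution latents instead of decoded images
low_res_latents = pipe(prompt, generator=generator, output_type="latent").images

# stage 2: the latent upscaler doubles the resolution of those latents
upscaled_image = upscaler(
    prompt=prompt,
    image=low_res_latents,
    num_inference_steps=20,
    guidance_scale=0,
    generator=generator,
).images[0]

upscaled_image.save("astronaut_upscaled.png")
```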
1 parent 3b66cc0 commit 1051ca8

21 files changed: +2077 −97 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -145,6 +145,8 @@
     title: Image-Variation
   - local: api/pipelines/stable_diffusion/upscale
     title: Super-Resolution
+  - local: api/pipelines/stable_diffusion/latent_upscale
+    title: Stable-Diffusion-Latent-Upscaler
   - local: api/pipelines/stable_diffusion/pix2pix
     title: InstructPix2Pix
   title: Stable Diffusion
```

docs/source/en/api/pipelines/stable_diffusion/latent_upscale.mdx

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
```md
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Diffusion Latent Upscaler

## StableDiffusionLatentUpscalePipeline

The Stable Diffusion Latent Upscaler model was created by [Katherine Crowson](https://github.com/crowsonkb/k-diffusion) in collaboration with [Stability AI](https://stability.ai/). It can be used on top of any [`StableDiffusionUpscalePipeline`] checkpoint to enhance its output image resolution by a factor of 2.

A notebook that demonstrates the original implementation can be found here:
- [Stable Diffusion Upscaler Demo](https://colab.research.google.com/drive/1o1qYJcFeywzCIdkfKJy7cTpgZTCM2EI4)

Available Checkpoints are:
- *stabilityai/latent-upscaler*: [stabilityai/sd-x2-latent-upscaler](https://huggingface.co/stabilityai/sd-x2-latent-upscaler)


[[autodoc]] StableDiffusionLatentUpscalePipeline
    - all
    - __call__
    - enable_sequential_cpu_offload
    - enable_attention_slicing
    - disable_attention_slicing
    - enable_xformers_memory_efficient_attention
    - disable_xformers_memory_efficient_attention
```
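
The memory and speed helpers listed in the autodoc directive above are regular pipeline methods. A minimal sketch of toggling them, assuming the checkpoint documented above; each call is optional and `enable_xformers_memory_efficient_attention` / `enable_sequential_cpu_offload` additionally require xformers / accelerate to be installed:

```python
import torch
from diffusers import StableDiffusionLatentUpscalePipeline

upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
)

# compute attention in slices to lower peak memory usage
upscaler.enable_attention_slicing()

# use xformers memory-efficient attention kernels if xformers is installed
upscaler.enable_xformers_memory_efficient_attention()

# keep submodules on CPU and move them to GPU only when they are needed
upscaler.enable_sequential_cpu_offload()
```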

docs/source/en/api/pipelines/stable_diffusion/overview.mdx

Lines changed: 1 addition & 0 deletions
```diff
@@ -31,6 +31,7 @@ For more details about how Stable Diffusion works and how it differs from the ba
 | [StableDiffusionDepth2ImgPipeline](./depth2img) | **Experimental** *Depth-to-Image Text-Guided Generation * | | Coming soon
 | [StableDiffusionImageVariationPipeline](./image_variation) | **Experimental** *Image Variation Generation * | | [🤗 Stable Diffusion Image Variations](https://huggingface.co/spaces/lambdalabs/stable-diffusion-image-variations)
 | [StableDiffusionUpscalePipeline](./upscale) | **Experimental** *Text-Guided Image Super-Resolution * | | Coming soon
+| [StableDiffusionLatentUpscalePipeline](./latent_upscale) | **Experimental** *Text-Guided Image Super-Resolution * | | Coming soon
 | [StableDiffusionInstructPix2PixPipeline](./pix2pix) | **Experimental** *Text-Based Image Editing * | | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://huggingface.co/spaces/timbrooks/instruct-pix2pix)
 
 
```
New file: conversion script for the k-upscaler UNet

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
```python
import argparse

import torch

import huggingface_hub
import k_diffusion as K
from diffusers import UNet2DConditionModel


UPSCALER_REPO = "pcuenq/k-upscaler"


def resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.linear.weight": checkpoint[f"{resnet_prefix}.main.0.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm1.linear.bias": checkpoint[f"{resnet_prefix}.main.0.mapper.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.main.2.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.main.2.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.linear.weight": checkpoint[f"{resnet_prefix}.main.4.mapper.weight"],
        f"{diffusers_resnet_prefix}.norm2.linear.bias": checkpoint[f"{resnet_prefix}.main.4.mapper.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.main.6.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.main.6.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.skip.weight"],
            }
        )

    return rv


def self_attn_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    weight_q, weight_k, weight_v = checkpoint[f"{attention_prefix}.qkv_proj.weight"].chunk(3, dim=0)
    bias_q, bias_k, bias_v = checkpoint[f"{attention_prefix}.qkv_proj.bias"].chunk(3, dim=0)
    rv = {
        # norm
        f"{diffusers_attention_prefix}.norm1.linear.weight": checkpoint[f"{attention_prefix}.norm_in.mapper.weight"],
        f"{diffusers_attention_prefix}.norm1.linear.bias": checkpoint[f"{attention_prefix}.norm_in.mapper.bias"],
        # to_q
        f"{diffusers_attention_prefix}.attn1.to_q.weight": weight_q.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_q.bias": bias_q,
        # to_k
        f"{diffusers_attention_prefix}.attn1.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn1.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn1.to_out.0.weight": checkpoint[f"{attention_prefix}.out_proj.weight"]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn1.to_out.0.bias": checkpoint[f"{attention_prefix}.out_proj.bias"],
    }

    return rv


def cross_attn_to_diffusers_checkpoint(
    checkpoint, *, diffusers_attention_prefix, diffusers_attention_index, attention_prefix
):
    weight_k, weight_v = checkpoint[f"{attention_prefix}.kv_proj.weight"].chunk(2, dim=0)
    bias_k, bias_v = checkpoint[f"{attention_prefix}.kv_proj.bias"].chunk(2, dim=0)

    rv = {
        # norm2 (ada groupnorm)
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.weight": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.weight"
        ],
        f"{diffusers_attention_prefix}.norm{diffusers_attention_index}.linear.bias": checkpoint[
            f"{attention_prefix}.norm_dec.mapper.bias"
        ],
        # layernorm on encoder_hidden_state
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.weight": checkpoint[
            f"{attention_prefix}.norm_enc.weight"
        ],
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.norm_cross.bias": checkpoint[
            f"{attention_prefix}.norm_enc.bias"
        ],
        # to_q
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.weight": checkpoint[
            f"{attention_prefix}.q_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_q.bias": checkpoint[
            f"{attention_prefix}.q_proj.bias"
        ],
        # to_k
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.weight": weight_k.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_k.bias": bias_k,
        # to_v
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.weight": weight_v.squeeze(-1).squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_v.bias": bias_v,
        # to_out
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.weight": checkpoint[
            f"{attention_prefix}.out_proj.weight"
        ]
        .squeeze(-1)
        .squeeze(-1),
        f"{diffusers_attention_prefix}.attn{diffusers_attention_index}.to_out.0.bias": checkpoint[
            f"{attention_prefix}.out_proj.bias"
        ],
    }

    return rv


def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
    block_prefix = "inner_model.u_net.u_blocks" if block_type == "up" else "inner_model.u_net.d_blocks"
    block_prefix = f"{block_prefix}.{block_idx}"

    diffusers_checkpoint = {}

    if not hasattr(block, "attentions"):
        n = 1  # resnet only
    elif not block.attentions[0].add_self_attention:
        n = 2  # resnet -> cross-attention
    else:
        n = 3  # resnet -> self-attention -> cross-attention

    for resnet_idx, resnet in enumerate(block.resnets):
        # diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
        diffusers_resnet_prefix = f"{block_type}_blocks.{block_idx}.resnets.{resnet_idx}"
        idx = n * resnet_idx if block_type == "up" else n * resnet_idx + 1
        resnet_prefix = f"{block_prefix}.{idx}" if block_type == "up" else f"{block_prefix}.{idx}"

        diffusers_checkpoint.update(
            resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    if hasattr(block, "attentions"):
        for attention_idx, attention in enumerate(block.attentions):
            diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
            idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
            self_attention_prefix = f"{block_prefix}.{idx}"
            cross_attention_prefix = f"{block_prefix}.{idx }"
            cross_attention_index = 1 if not attention.add_self_attention else 2
            idx = (
                n * attention_idx + cross_attention_index
                if block_type == "up"
                else n * attention_idx + cross_attention_index + 1
            )
            cross_attention_prefix = f"{block_prefix}.{idx }"

            diffusers_checkpoint.update(
                cross_attn_to_diffusers_checkpoint(
                    checkpoint,
                    diffusers_attention_prefix=diffusers_attention_prefix,
                    diffusers_attention_index=2,
                    attention_prefix=cross_attention_prefix,
                )
            )

            if attention.add_self_attention is True:
                diffusers_checkpoint.update(
                    self_attn_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=self_attention_prefix,
                    )
                )

    return diffusers_checkpoint


def unet_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # pre-processing
    diffusers_checkpoint.update(
        {
            "conv_in.weight": checkpoint["inner_model.proj_in.weight"],
            "conv_in.bias": checkpoint["inner_model.proj_in.bias"],
        }
    )

    # timestep and class embedding
    diffusers_checkpoint.update(
        {
            "time_proj.weight": checkpoint["inner_model.timestep_embed.weight"].squeeze(-1),
            "time_embedding.linear_1.weight": checkpoint["inner_model.mapping.0.weight"],
            "time_embedding.linear_1.bias": checkpoint["inner_model.mapping.0.bias"],
            "time_embedding.linear_2.weight": checkpoint["inner_model.mapping.2.weight"],
            "time_embedding.linear_2.bias": checkpoint["inner_model.mapping.2.bias"],
            "time_embedding.cond_proj.weight": checkpoint["inner_model.mapping_cond.weight"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.down_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(down_block, checkpoint, down_block_idx, "down"))

    # up_blocks
    for up_block_idx, up_block in enumerate(model.up_blocks):
        diffusers_checkpoint.update(block_to_diffusers_checkpoint(up_block, checkpoint, up_block_idx, "up"))

    # post-processing
    diffusers_checkpoint.update(
        {
            "conv_out.weight": checkpoint["inner_model.proj_out.weight"],
            "conv_out.bias": checkpoint["inner_model.proj_out.bias"],
        }
    )

    return diffusers_checkpoint


def unet_model_from_original_config(original_config):
    in_channels = original_config["input_channels"] + original_config["unet_cond_dim"]
    out_channels = original_config["input_channels"] + (1 if original_config["has_variance"] else 0)

    block_out_channels = original_config["channels"]

    assert (
        len(set(original_config["depths"])) == 1
    ), "UNet2DConditionModel currently do not support blocks with different number of layers"
    layers_per_block = original_config["depths"][0]

    class_labels_dim = original_config["mapping_cond_dim"]
    cross_attention_dim = original_config["cross_cond_dim"]

    attn1_types = []
    attn2_types = []
    for s, c in zip(original_config["self_attn_depths"], original_config["cross_attn_depths"]):
        if s:
            a1 = "self"
            a2 = "cross" if c else None
        elif c:
            a1 = "cross"
            a2 = None
        else:
            a1 = None
            a2 = None
        attn1_types.append(a1)
        attn2_types.append(a2)

    unet = UNet2DConditionModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=("KDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D", "KCrossAttnDownBlock2D"),
        mid_block_type=None,
        up_block_types=("KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KCrossAttnUpBlock2D", "KUpBlock2D"),
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        act_fn="gelu",
        norm_num_groups=None,
        cross_attention_dim=cross_attention_dim,
        attention_head_dim=64,
        time_cond_proj_dim=class_labels_dim,
        resnet_time_scale_shift="scale_shift",
        time_embedding_type="fourier",
        timestep_post_act="gelu",
        conv_in_kernel=1,
        conv_out_kernel=1,
    )

    return unet


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    orig_config_path = huggingface_hub.hf_hub_download(UPSCALER_REPO, "config_laion_text_cond_latent_upscaler_2.json")
    orig_weights_path = huggingface_hub.hf_hub_download(
        UPSCALER_REPO, "laion_text_cond_latent_upscaler_2_1_00470000_slim.pth"
    )
    print(f"loading original model configuration from {orig_config_path}")
    print(f"loading original model checkpoint from {orig_weights_path}")

    print("converting to diffusers unet")
    orig_config = K.config.load_config(open(orig_config_path))["model"]
    model = unet_model_from_original_config(orig_config)

    orig_checkpoint = torch.load(orig_weights_path, map_location=device)["model_ema"]
    converted_checkpoint = unet_to_diffusers_checkpoint(model, orig_checkpoint)

    model.load_state_dict(converted_checkpoint, strict=True)
    model.save_pretrained(args.dump_path)
    print(f"saving converted unet model in {args.dump_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    args = parser.parse_args()

    main(args)
```
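
Once the script has run, the saved weights can be loaded back through the regular diffusers API for a quick sanity check. A minimal sketch; the script filename in the comment and the dump path are hypothetical, only `--dump_path` is actually defined by the argument parser above:

```python
# hypothetical invocation of the conversion script above:
#   python convert_k_upscaler_to_diffusers.py --dump_path ./k-upscaler-unet
from diffusers import UNet2DConditionModel

# load the UNet saved by model.save_pretrained(args.dump_path)
unet = UNet2DConditionModel.from_pretrained("./k-upscaler-unet")

# a couple of the config values set in unet_model_from_original_config
print(unet.config.time_embedding_type)  # "fourier"
print(unet.config.mid_block_type)       # None (the k-upscaler UNet skips the mid block)
```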

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -115,6 +115,7 @@
     StableDiffusionInpaintPipeline,
     StableDiffusionInpaintPipelineLegacy,
     StableDiffusionInstructPix2PixPipeline,
+    StableDiffusionLatentUpscalePipeline,
     StableDiffusionPipeline,
     StableDiffusionPipelineSafe,
     StableDiffusionUpscalePipeline,
```
src/diffusers/models/attention.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -480,3 +480,38 @@ def forward(self, x, timestep, class_labels, hidden_dtype=None):
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
         x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
         return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+
+class AdaGroupNorm(nn.Module):
+    """
+    GroupNorm layer modified to incorporate timestep embeddings.
+    """
+
+    def __init__(
+        self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
+    ):
+        super().__init__()
+        self.num_groups = num_groups
+        self.eps = eps
+        self.act = None
+        if act_fn == "swish":
+            self.act = lambda x: F.silu(x)
+        elif act_fn == "mish":
+            self.act = nn.Mish()
+        elif act_fn == "silu":
+            self.act = nn.SiLU()
+        elif act_fn == "gelu":
+            self.act = nn.GELU()
+
+        self.linear = nn.Linear(embedding_dim, out_dim * 2)
+
+    def forward(self, x, emb):
+        if self.act:
+            emb = self.act(emb)
+        emb = self.linear(emb)
+        emb = emb[:, :, None, None]
+        scale, shift = emb.chunk(2, dim=1)
+
+        x = F.group_norm(x, self.num_groups, eps=self.eps)
+        x = x * (1 + scale) + shift
+        return x
```
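
The new AdaGroupNorm applies a parameter-free group norm and then a per-channel scale and shift predicted from the conditioning embedding. A minimal sketch of how it behaves on dummy tensors, assuming the class is importable from `diffusers.models.attention` where this diff adds it; the shapes are only illustrative:

```python
import torch
from diffusers.models.attention import AdaGroupNorm  # location as of this commit

norm = AdaGroupNorm(embedding_dim=512, out_dim=64, num_groups=32, act_fn="gelu")

x = torch.randn(2, 64, 16, 16)  # feature map with out_dim channels
emb = torch.randn(2, 512)       # timestep embedding

# emb is projected to a per-channel (scale, shift) pair, then applied
# after group norm as: x * (1 + scale) + shift
out = norm(x, emb)
print(out.shape)  # torch.Size([2, 64, 16, 16])
```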

0 commit comments