Skip to content

Commit 3651b14

Browse files
pcuenca, mar-muel, patil-suraj, patrickvonplaten
authored
SDXL flax (huggingface#4254)
* support transformer_layers_per_block in flax UNet
* add support for text_time additional embeddings to Flax UNet
* rename attention layers for VAE
* add shape asserts when renaming attention layers
* transpose VAE attention layers
* add pipeline flax SDXL code [WIP]
* continue add pipeline flax SDXL code [WIP]
* cleanup
* Working on JIT support
  Fixed prompt embedding shapes so they work in parallel mode. Assuming we always have both text encoders for now, for simplicity.
* Fixing embeddings (untested)
* Remove spurious line
* Shard guidance_scale when jitting.
* Decode images
* Fix sharding
* style
* Refiner UNet can be loaded.
* Refiner / img2img pipeline
* Allow latent outputs from base and latent inputs in refiner
  This makes it possible to chain base + refiner without having to use the vae decoder in the base model, the vae encoder in the refiner, skipping conversions to/from PIL, and avoiding TPU <-> CPU memory copies.
* Adapt to FlaxCLIPTextModelOutput
* Update Flax XL pipeline to FlaxCLIPTextModelOutput
* make fix-copies
* make style
* add euler scheduler
* Fix import
* Fix copies, comment unused code.
* Fix SDXL Flax imports
* Fix euler discrete begin
* improve init import
* finish
* put discrete euler in init
* fix flax euler
* Fix more
* make style
* correct init
* correct init
* Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline
* correct pipelines
* finish

---------

Co-authored-by: Martin Müller <[email protected]>
Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2e860e8 commit 3651b14

17 files changed

+1248
-488
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -368,6 +368,7 @@
368368
"FlaxDDIMScheduler",
369369
"FlaxDDPMScheduler",
370370
"FlaxDPMSolverMultistepScheduler",
371+
"FlaxEulerDiscreteScheduler",
371372
"FlaxKarrasVeScheduler",
372373
"FlaxLMSDiscreteScheduler",
373374
"FlaxPNDMScheduler",
@@ -395,6 +396,7 @@
395396
"FlaxStableDiffusionImg2ImgPipeline",
396397
"FlaxStableDiffusionInpaintPipeline",
397398
"FlaxStableDiffusionPipeline",
399+
"FlaxStableDiffusionXLPipeline",
398400
]
399401
)
400402

@@ -673,6 +675,7 @@
673675
FlaxDDIMScheduler,
674676
FlaxDDPMScheduler,
675677
FlaxDPMSolverMultistepScheduler,
678+
FlaxEulerDiscreteScheduler,
676679
FlaxKarrasVeScheduler,
677680
FlaxLMSDiscreteScheduler,
678681
FlaxPNDMScheduler,
@@ -691,6 +694,7 @@
691694
FlaxStableDiffusionImg2ImgPipeline,
692695
FlaxStableDiffusionInpaintPipeline,
693696
FlaxStableDiffusionPipeline,
697+
FlaxStableDiffusionXLPipeline,
694698
)
695699

696700
try:

src/diffusers/models/modeling_flax_pytorch_utils.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,9 +42,25 @@ def rename_key(key):
4242
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
4343
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
4444
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
45-
4645
# conv norm or layer norm
4746
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
47+
48+
# rename attention layers
49+
if len(pt_tuple_key) > 1:
50+
for rename_from, rename_to in (
51+
("to_out_0", "proj_attn"),
52+
("to_k", "key"),
53+
("to_v", "value"),
54+
("to_q", "query"),
55+
):
56+
if pt_tuple_key[-2] == rename_from:
57+
weight_name = pt_tuple_key[-1]
58+
weight_name = "kernel" if weight_name == "weight" else weight_name
59+
renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
60+
if renamed_pt_tuple_key in random_flax_state_dict:
61+
assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
62+
return renamed_pt_tuple_key, pt_tensor.T
63+
4864
if (
4965
any("norm" in str_ for str_ in pt_tuple_key)
5066
and (pt_tuple_key[-1] == "bias")

src/diffusers/models/modeling_flax_utils.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -303,23 +303,23 @@ def from_pretrained(
303303
"framework": "flax",
304304
}
305305

306-
# Load config if we don't provide a configuration
307-
config_path = config if config is not None else pretrained_model_name_or_path
308-
model, model_kwargs = cls.from_config(
309-
config_path,
310-
cache_dir=cache_dir,
311-
return_unused_kwargs=True,
312-
force_download=force_download,
313-
resume_download=resume_download,
314-
proxies=proxies,
315-
local_files_only=local_files_only,
316-
use_auth_token=use_auth_token,
317-
revision=revision,
318-
subfolder=subfolder,
319-
# model args
320-
dtype=dtype,
321-
**kwargs,
322-
)
306+
# Load config if we don't provide one
307+
if config is None:
308+
config, unused_kwargs = cls.load_config(
309+
pretrained_model_name_or_path,
310+
cache_dir=cache_dir,
311+
return_unused_kwargs=True,
312+
force_download=force_download,
313+
resume_download=resume_download,
314+
proxies=proxies,
315+
local_files_only=local_files_only,
316+
use_auth_token=use_auth_token,
317+
revision=revision,
318+
subfolder=subfolder,
319+
**kwargs,
320+
)
321+
322+
model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
323323

324324
# Load model
325325
pretrained_path_with_subfolder = (

src/diffusers/models/unet_2d_blocks_flax.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
5252
only_cross_attention: bool = False
5353
use_memory_efficient_attention: bool = False
5454
dtype: jnp.dtype = jnp.float32
55+
transformer_layers_per_block: int = 1
5556

5657
def setup(self):
5758
resnets = []
@@ -72,7 +73,7 @@ def setup(self):
7273
in_channels=self.out_channels,
7374
n_heads=self.num_attention_heads,
7475
d_head=self.out_channels // self.num_attention_heads,
75-
depth=1,
76+
depth=self.transformer_layers_per_block,
7677
use_linear_projection=self.use_linear_projection,
7778
only_cross_attention=self.only_cross_attention,
7879
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -192,6 +193,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
192193
only_cross_attention: bool = False
193194
use_memory_efficient_attention: bool = False
194195
dtype: jnp.dtype = jnp.float32
196+
transformer_layers_per_block: int = 1
195197

196198
def setup(self):
197199
resnets = []
@@ -213,7 +215,7 @@ def setup(self):
213215
in_channels=self.out_channels,
214216
n_heads=self.num_attention_heads,
215217
d_head=self.out_channels // self.num_attention_heads,
216-
depth=1,
218+
depth=self.transformer_layers_per_block,
217219
use_linear_projection=self.use_linear_projection,
218220
only_cross_attention=self.only_cross_attention,
219221
use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -331,6 +333,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
331333
use_linear_projection: bool = False
332334
use_memory_efficient_attention: bool = False
333335
dtype: jnp.dtype = jnp.float32
336+
transformer_layers_per_block: int = 1
334337

335338
def setup(self):
336339
# there is always at least one resnet
@@ -350,7 +353,7 @@ def setup(self):
350353
in_channels=self.in_channels,
351354
n_heads=self.num_attention_heads,
352355
d_head=self.in_channels // self.num_attention_heads,
353-
depth=1,
356+
depth=self.transformer_layers_per_block,
354357
use_linear_projection=self.use_linear_projection,
355358
use_memory_efficient_attention=self.use_memory_efficient_attention,
356359
dtype=self.dtype,

src/diffusers/models/unet_2d_condition.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -883,7 +883,6 @@ def forward(
883883
time_ids = added_cond_kwargs.get("time_ids")
884884
time_embeds = self.add_time_proj(time_ids.flatten())
885885
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
886-
887886
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
888887
add_embeds = add_embeds.to(emb.dtype)
889888
aug_emb = self.add_embedding(add_embeds)

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 65 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Optional, Tuple, Union
14+
from typing import Dict, Optional, Tuple, Union
1515

1616
import flax
1717
import flax.linen as nn
@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
116116
flip_sin_to_cos: bool = True
117117
freq_shift: int = 0
118118
use_memory_efficient_attention: bool = False
119+
transformer_layers_per_block: Union[int, Tuple[int]] = 1
120+
addition_embed_type: Optional[str] = None
121+
addition_time_embed_dim: Optional[int] = None
122+
addition_embed_type_num_heads: int = 64
123+
projection_class_embeddings_input_dim: Optional[int] = None
119124

120125
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
121126
# init input tensors
@@ -127,7 +132,17 @@ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
127132
params_rng, dropout_rng = jax.random.split(rng)
128133
rngs = {"params": params_rng, "dropout": dropout_rng}
129134

130-
return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
135+
added_cond_kwargs = None
136+
if self.addition_embed_type == "text_time":
137+
# TODO: how to get this from the config? It's no longer cross_attention_dim
138+
text_embeds_dim = 1280
139+
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
140+
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
141+
added_cond_kwargs = {
142+
"text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
143+
"time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
144+
}
145+
return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]
131146

132147
def setup(self):
133148
block_out_channels = self.block_out_channels
@@ -168,6 +183,24 @@ def setup(self):
168183
if isinstance(num_attention_heads, int):
169184
num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
170185

186+
# transformer layers per block
187+
transformer_layers_per_block = self.transformer_layers_per_block
188+
if isinstance(transformer_layers_per_block, int):
189+
transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
190+
191+
# addition embed types
192+
if self.addition_embed_type is None:
193+
self.add_embedding = None
194+
elif self.addition_embed_type == "text_time":
195+
if self.addition_time_embed_dim is None:
196+
raise ValueError(
197+
f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
198+
)
199+
self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
200+
self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
201+
else:
202+
raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
203+
171204
# down
172205
down_blocks = []
173206
output_channel = block_out_channels[0]
@@ -182,6 +215,7 @@ def setup(self):
182215
out_channels=output_channel,
183216
dropout=self.dropout,
184217
num_layers=self.layers_per_block,
218+
transformer_layers_per_block=transformer_layers_per_block[i],
185219
num_attention_heads=num_attention_heads[i],
186220
add_downsample=not is_final_block,
187221
use_linear_projection=self.use_linear_projection,
@@ -207,6 +241,7 @@ def setup(self):
207241
in_channels=block_out_channels[-1],
208242
dropout=self.dropout,
209243
num_attention_heads=num_attention_heads[-1],
244+
transformer_layers_per_block=transformer_layers_per_block[-1],
210245
use_linear_projection=self.use_linear_projection,
211246
use_memory_efficient_attention=self.use_memory_efficient_attention,
212247
dtype=self.dtype,
@@ -218,6 +253,7 @@ def setup(self):
218253
reversed_num_attention_heads = list(reversed(num_attention_heads))
219254
only_cross_attention = list(reversed(only_cross_attention))
220255
output_channel = reversed_block_out_channels[0]
256+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
221257
for i, up_block_type in enumerate(self.up_block_types):
222258
prev_output_channel = output_channel
223259
output_channel = reversed_block_out_channels[i]
@@ -231,6 +267,7 @@ def setup(self):
231267
out_channels=output_channel,
232268
prev_output_channel=prev_output_channel,
233269
num_layers=self.layers_per_block + 1,
270+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
234271
num_attention_heads=reversed_num_attention_heads[i],
235272
add_upsample=not is_final_block,
236273
dropout=self.dropout,
@@ -269,6 +306,7 @@ def __call__(
269306
sample,
270307
timesteps,
271308
encoder_hidden_states,
309+
added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
272310
down_block_additional_residuals=None,
273311
mid_block_additional_residual=None,
274312
return_dict: bool = True,
@@ -300,6 +338,31 @@ def __call__(
300338
t_emb = self.time_proj(timesteps)
301339
t_emb = self.time_embedding(t_emb)
302340

341+
# additional embeddings
342+
aug_emb = None
343+
if self.addition_embed_type == "text_time":
344+
if added_cond_kwargs is None:
345+
raise ValueError(
346+
f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
347+
)
348+
text_embeds = added_cond_kwargs.get("text_embeds")
349+
if text_embeds is None:
350+
raise ValueError(
351+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
352+
)
353+
time_ids = added_cond_kwargs.get("time_ids")
354+
if time_ids is None:
355+
raise ValueError(
356+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
357+
)
358+
# compute time embeds
359+
time_embeds = self.add_time_proj(jnp.ravel(time_ids)) # (1, 6) => (6,) => (6, 256)
360+
time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
361+
add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
362+
aug_emb = self.add_embedding(add_embeds)
363+
364+
t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
365+
303366
# 2. pre-process
304367
sample = jnp.transpose(sample, (0, 2, 3, 1))
305368
sample = self.conv_in(sample)

0 commit comments

Comments
 (0)