You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* support transformer_layers_per block in flax UNet
* add support for text_time additional embeddings to Flax UNet
* rename attention layers for VAE
* add shape asserts when renaming attention layers
* transpose VAE attention layers
* add pipeline flax SDXL code [WIP]
* continue add pipeline flax SDXL code [WIP]
* cleanup
* Working on JIT support
Fixed prompt embedding shapes so they work in parallel mode. Assuming we
always have both text encoders for now, for simplicity.
* Fixing embeddings (untested)
* Remove spurious line
* Shard guidance_scale when jitting.
* Decode images
* Fix sharding
* style
* Refiner UNet can be loaded.
* Refiner / img2img pipeline
* Allow latent outputs from base and latent inputs in refiner
This makes it possible to chain base + refiner without having to use the
vae decoder in the base model, the vae encoder in the refiner, skipping
conversions to/from PIL, and avoiding TPU <-> CPU memory copies.
* Adapt to FlaxCLIPTextModelOutput
* Update Flax XL pipeline to FlaxCLIPTextModelOutput
* make fix-copies
* make style
* add euler scheduler
* Fix import
* Fix copies, comment unused code.
* Fix SDXL Flax imports
* Fix euler discrete begin
* improve init import
* finish
* put discrete euler in init
* fix flax euler
* Fix more
* make style
* correct init
* correct init
* Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline
* correct pipelines
* finish
---------
Co-authored-by: Martin Müller <[email protected]>
Co-authored-by: patil-suraj <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
347
+
)
348
+
text_embeds=added_cond_kwargs.get("text_embeds")
349
+
iftext_embedsisNone:
350
+
raiseValueError(
351
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
352
+
)
353
+
time_ids=added_cond_kwargs.get("time_ids")
354
+
iftime_idsisNone:
355
+
raiseValueError(
356
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
0 commit comments