Added the ability to set SDXL Micro-Conditioning embeddings as 0 #4208


Closed · budui opened this issue Jul 22, 2023 · 6 comments · Fixed by #4774

@budui (Contributor) commented Jul 22, 2023

**Is your feature request related to a problem? Please describe.**

During SDXL training, it may be necessary to pass zero embeddings as the micro-conditioning embeddings:

https://github.com/Stability-AI/generative-models/blob/e25e4c0df1d01fb9720f62c73b4feab2e4003e3f/sgm/modules/encoders/modules.py#L151-L161

```python
# These lines randomly zero the embedding when `ucg_rate` > 0:
if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
    emb = (
        expand_dims_like(
            torch.bernoulli(
                (1.0 - embedder.ucg_rate)
                * torch.ones(emb.shape[0], device=emb.device)
            ),
            emb,
        )
        * emb
    )
```
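For readability, here is the same per-sample dropout as a standalone sketch, with `expand_dims_like` replaced by a plain view-based broadcast (not the library code):

```python
import torch

def ucg_dropout(emb: torch.Tensor, ucg_rate: float) -> torch.Tensor:
    # One Bernoulli keep/drop decision per sample, broadcast over the
    # remaining embedding dimensions: kept samples pass through unchanged,
    # dropped samples become all-zero embeddings.
    keep = torch.bernoulli((1.0 - ucg_rate) * torch.ones(emb.shape[0], device=emb.device))
    return keep.view(-1, *([1] * (emb.dim() - 1))) * emb
```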

https://github.com/Stability-AI/generative-models/blob/e25e4c0df1d01fb9720f62c73b4feab2e4003e3f/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml#L65

```yaml
# SDXL sets the `ucg_rate` of the `original_size_as_tuple` embedder to 0.1,
# so during training we need to pass a zero embedding as the added embedding
# for the UNet's time embedding.
ucg_rate: 0.1
input_key: original_size_as_tuple
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
  outdim: 256  # multiplied by two
```

The current SDXL `UNet2DConditionModel` accepts `encoder_hidden_states`, `time_ids`, and `add_text_embeds` as conditioning inputs:

```python
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
    raise ValueError(
        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
    )
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
```
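To make the shapes concrete, here is a small walk-through of that excerpt under the SDXL defaults (a sketch; assumes diffusers' `Timesteps` module with 256 channels, matching `outdim: 256` above):

```python
import torch
from diffusers.models.embeddings import Timesteps

# Stand-in for self.add_time_proj (SDXL uses a 256-channel Timesteps module).
add_time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

batch = 2
# Per sample: (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch)  # (2, 6)
time_embeds = add_time_proj(time_ids.flatten())                    # (12, 256)
time_embeds = time_embeds.reshape((batch, -1))                     # (2, 1536)
# Each micro-condition ends up as a distinct 256-dim slice of the 1536 dims,
# so zeroing just one of them requires access before this reshape/concat.
```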

To correctly finetune the SDXL model, we need to randomly set the condition embeddings to 0 with a suitable probability. While it is easy to set `encoder_hidden_states` and `add_text_embeds` to zero embeddings, it is impossible to zero `time_embeds` at line 849.

The original SDXL uses a separate embedder to convert each micro-condition into Fourier features, and during training each embedder's output is randomly zeroed independently of the others. Therefore, `UNet2DConditionModel` needs to be able to zero each part of `time_embeds` independently.
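For illustration, a minimal sketch of that independent zeroing (names here are hypothetical, not diffusers API):

```python
import torch

def embed_micro_conditions(values, embedders, ucg_rates):
    # Each micro-condition (original_size, crop_coords, target_size, ...)
    # has its own embedder and its own dropout rate, so one condition's
    # Fourier features can be zeroed while the others are kept intact.
    parts = []
    for value, embedder, rate in zip(values, embedders, ucg_rates):
        emb = embedder(value)  # (batch, feature_dim)
        keep = torch.bernoulli((1.0 - rate) * torch.ones(emb.shape[0], device=emb.device))
        parts.append(keep[:, None] * emb)  # independent per-condition zeroing
    return torch.cat(parts, dim=-1)
```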

**Describe the solution you'd like**

Add the ability to set the SDXL micro-conditioning embeddings to 0.

**Describe alternatives you've considered**

Perhaps diffusers could allow users to pass `time_embeds` in directly; when `time_embeds` is present, `time_ids` would no longer be used:

if "time_embeds" in added_cond_kwargs:
    time_embeds = added_cond_kwargs.get("time_embeds") 
else:
    time_ids = added_cond_kwargs.get("time_ids") 
     time_embeds = self.add_time_proj(time_ids.flatten()) 
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 
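If that alternative were adopted, a caller could zero an individual micro-condition before the UNet forward pass. A hypothetical sketch (`time_embeds` is not an accepted key today; `latents`, `t`, `prompt_embeds`, and `pooled_embeds` are assumed to be prepared elsewhere):

```python
# Precompute time_embeds outside the UNet, zero the original_size slice
# (the first 2 * 256 Fourier dims), and bypass `time_ids` entirely.
time_embeds = unet.add_time_proj(time_ids.flatten()).reshape(time_ids.shape[0], -1)
time_embeds[:, : 2 * 256] = 0.0

out = unet(
    sample=latents,
    timestep=t,
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs={"text_embeds": pooled_embeds, "time_embeds": time_embeds},
)
```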
@sayakpaul (Member)

Thanks for the detailed issue. Yes, we're aware of this issue.

@patrickvonplaten I suppose you were working on it?

@patrickvonplaten (Contributor)

Actually only now noticed this - thanks for bringing it up @budui !

Do you think it's also important to provide this feature for inference or just for training?

@budui (Contributor, Author) commented Jul 25, 2023

Both training and inference should require this feature. For training, diffusers may need the ability to reproduce Stability AI's training scripts. For inference, the current SDXL pipeline lacks a way to specify a negative micro-condition (as either a specific value or a zero embedding).

I did a quick experiment with specifying a negative condition:

A: The condition and the negative condition use the same micro-conditions, as the diffusers SDXL pipeline currently does.

```python
# prompt: "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# seed: 1000
# original size (1024, 1024) vs (1024, 1024)
condition=dict(
    caption=prompt,
    crop_left=0,
    crop_top=0,
    original_height=1024,
    original_width=1024,
    target_height=1024,
    target_width=1024,
),
negative_condition=dict(
    caption="",
    crop_left=0,
    crop_top=0,
    original_height=1024,
    original_width=1024,
    target_height=1024,
    target_width=1024,
),
```

[image: size1]

B: The negative condition uses a smaller original size, resulting in a clearer image.

```python
# prompt: "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# seed: 1000
# original size (1024, 1024) vs (512, 512)
condition=dict(
    caption=prompt,
    crop_left=0,
    crop_top=0,
    original_height=1024,
    original_width=1024,
    target_height=1024,
    target_width=1024,
),
negative_condition=dict(
    caption="",
    crop_left=0,
    crop_top=0,
    original_height=512,
    original_width=512,
    target_height=1024,
    target_width=1024,
),
```

[image: size2-512]

I haven't tested the effect of using a zero embedding as the negative condition, because I haven't found a quick workaround to do it. But I'd be happy to do more testing after diffusers adds a way to specify zero embeddings in the UNet.
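For reference, experiment B would map onto a pipeline call along these lines, if the pipeline exposed negative micro-conditioning (a sketch; the `negative_original_size`/`negative_target_size` parameter names are assumptions here, not existing API at the time of this issue):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    generator=torch.Generator("cuda").manual_seed(1000),
    original_size=(1024, 1024),
    target_size=(1024, 1024),
    negative_original_size=(512, 512),   # lower "negative" original size, as in B
    negative_target_size=(1024, 1024),
).images[0]
```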

@sayakpaul (Member)

@budui sorry for the delay on our end. Would you maybe be willing to contribute this feature in a PR? We're more than happy to help out.

@patrickvonplaten (Contributor)

@sayakpaul do you want to give this PR/issue a try?

@sayakpaul (Member)

Yeah
