Skip to content

[fix] fix for prior preservation and mixed precision sampling #11873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_flux_kontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@ def main(args):
subfolder="transformer",
revision=args.revision,
variant=args.variant,
torch_dtype=torch_dtype,
)
pipeline = FluxKontextPipeline.from_pretrained(
args.pretrained_model_name_or_path,
Expand All @@ -1215,7 +1216,8 @@ def main(args):
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
images = pipeline(prompt=example["prompt"]).images

for i, image in enumerate(images):
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
Expand Down Expand Up @@ -1789,6 +1791,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
device=accelerator.device,
prompt=args.instance_prompt,
)
else:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
)

# Convert images to latent space
if args.cache_latents:
Expand Down