Skip to content

Commit cdcc01b

Browse files
authored
[Examples] add compute_snr() to training utils. (huggingface#5188)
add compute_snr() to training utils.
1 parent ba59e92 commit cdcc01b

File tree

11 files changed: 46 additions (+46) and 256 deletions (-256).

examples/dreambooth/train_dreambooth.py

Lines changed: 2 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@
5252
UNet2DConditionModel,
5353
)
5454
from diffusers.optimization import get_scheduler
55+
from diffusers.training_utils import compute_snr
5556
from diffusers.utils import check_min_version, is_wandb_available
5657
from diffusers.utils.import_utils import is_xformers_available
5758

@@ -224,30 +225,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
224225
raise ValueError(f"{model_class} is not supported.")
225226

226227

227-
def compute_snr(timesteps, noise_scheduler):
228-
"""
229-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
230-
"""
231-
alphas_cumprod = noise_scheduler.alphas_cumprod
232-
sqrt_alphas_cumprod = alphas_cumprod**0.5
233-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
234-
# Expand the tensors.
235-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
236-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
237-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
238-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
239-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
240-
241-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
242-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
243-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
244-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
245-
246-
# Compute SNR
247-
snr = (alpha / sigma) ** 2
248-
return snr
249-
250-
251228
def parse_args(input_args=None):
252229
parser = argparse.ArgumentParser(description="Simple example of a training script.")
253230
parser.add_argument(
@@ -1302,7 +1279,7 @@ def compute_text_embeddings(prompt):
13021279
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
13031280
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
13041281
# This is discussed in Section 4.2 of the same paper.
1305-
snr = compute_snr(timesteps, noise_scheduler)
1282+
snr = compute_snr(noise_scheduler, timesteps)
13061283
base_weight = (
13071284
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
13081285
)

examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py

Lines changed: 2 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
import diffusers
4343
from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
4444
from diffusers.optimization import get_scheduler
45-
from diffusers.training_utils import EMAModel
45+
from diffusers.training_utils import EMAModel, compute_snr
4646
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
4747
from diffusers.utils.import_utils import is_xformers_available
4848

@@ -530,30 +530,6 @@ def deepspeed_zero_init_disabled_context_manager():
530530
else:
531531
raise ValueError("xformers is not available. Make sure it is installed correctly")
532532

533-
def compute_snr(timesteps):
534-
"""
535-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
536-
"""
537-
alphas_cumprod = noise_scheduler.alphas_cumprod
538-
sqrt_alphas_cumprod = alphas_cumprod**0.5
539-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
540-
541-
# Expand the tensors.
542-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
543-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
544-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
545-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
546-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
547-
548-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
549-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
550-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
551-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
552-
553-
# Compute SNR.
554-
snr = (alpha / sigma) ** 2
555-
return snr
556-
557533
# `accelerate` 0.16.0 will have better support for customized saving
558534
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
559535
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -800,7 +776,7 @@ def collate_fn(examples):
800776
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
801777
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
802778
# This is discussed in Section 4.2 of the same paper.
803-
snr = compute_snr(timesteps)
779+
snr = compute_snr(noise_scheduler, timesteps)
804780
base_weight = (
805781
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
806782
)

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py

Lines changed: 2 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
from diffusers.loaders import AttnProcsLayers
4242
from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor
4343
from diffusers.optimization import get_scheduler
44+
from diffusers.training_utils import compute_snr
4445
from diffusers.utils import check_min_version, is_wandb_available
4546

4647

@@ -419,30 +420,6 @@ def main():
419420

420421
unet.set_attn_processor(lora_attn_procs)
421422

422-
def compute_snr(timesteps):
423-
"""
424-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
425-
"""
426-
alphas_cumprod = noise_scheduler.alphas_cumprod
427-
sqrt_alphas_cumprod = alphas_cumprod**0.5
428-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
429-
430-
# Expand the tensors.
431-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
432-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
433-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
434-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
435-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
436-
437-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
438-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
439-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
440-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
441-
442-
# Compute SNR.
443-
snr = (alpha / sigma) ** 2
444-
return snr
445-
446423
lora_layers = AttnProcsLayers(unet.attn_processors)
447424

448425
if args.allow_tf32:
@@ -653,7 +630,7 @@ def collate_fn(examples):
653630
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
654631
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
655632
# This is discussed in Section 4.2 of the same paper.
656-
snr = compute_snr(timesteps)
633+
snr = compute_snr(noise_scheduler, timesteps)
657634
base_weight = (
658635
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
659636
)

examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py

Lines changed: 2 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
from diffusers.loaders import AttnProcsLayers
4242
from diffusers.models.attention_processor import LoRAAttnProcessor
4343
from diffusers.optimization import get_scheduler
44+
from diffusers.training_utils import compute_snr
4445
from diffusers.utils import check_min_version, is_wandb_available
4546

4647

@@ -413,31 +414,6 @@ def main():
413414
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)
414415

415416
prior.set_attn_processor(lora_attn_procs)
416-
417-
def compute_snr(timesteps):
418-
"""
419-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
420-
"""
421-
alphas_cumprod = noise_scheduler.alphas_cumprod
422-
sqrt_alphas_cumprod = alphas_cumprod**0.5
423-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
424-
425-
# Expand the tensors.
426-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
427-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
428-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
429-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
430-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
431-
432-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
433-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
434-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
435-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
436-
437-
# Compute SNR.
438-
snr = (alpha / sigma) ** 2
439-
return snr
440-
441417
lora_layers = AttnProcsLayers(prior.attn_processors)
442418

443419
if args.allow_tf32:
@@ -684,7 +660,7 @@ def collate_fn(examples):
684660
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
685661
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
686662
# This is discussed in Section 4.2 of the same paper.
687-
snr = compute_snr(timesteps)
663+
snr = compute_snr(noise_scheduler, timesteps)
688664
base_weight = (
689665
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
690666
)

examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py

Lines changed: 2 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@
4242
import diffusers
4343
from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
4444
from diffusers.optimization import get_scheduler
45-
from diffusers.training_utils import EMAModel
45+
from diffusers.training_utils import EMAModel, compute_snr
4646
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
4747

4848

@@ -523,30 +523,6 @@ def deepspeed_zero_init_disabled_context_manager():
523523
ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)
524524
ema_prior.to(accelerator.device)
525525

526-
def compute_snr(timesteps):
527-
"""
528-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
529-
"""
530-
alphas_cumprod = noise_scheduler.alphas_cumprod
531-
sqrt_alphas_cumprod = alphas_cumprod**0.5
532-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
533-
534-
# Expand the tensors.
535-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
536-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
537-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
538-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
539-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
540-
541-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
542-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
543-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
544-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
545-
546-
# Compute SNR.
547-
snr = (alpha / sigma) ** 2
548-
return snr
549-
550526
# `accelerate` 0.16.0 will have better support for customized saving
551527
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
552528
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -832,7 +808,7 @@ def collate_fn(examples):
832808
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
833809
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
834810
# This is discussed in Section 4.2 of the same paper.
835-
snr = compute_snr(timesteps)
811+
snr = compute_snr(noise_scheduler, timesteps)
836812
base_weight = (
837813
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
838814
)

examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py

Lines changed: 2 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@
4444
import diffusers
4545
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
4646
from diffusers.optimization import get_scheduler
47-
from diffusers.training_utils import EMAModel
47+
from diffusers.training_utils import EMAModel, compute_snr
4848
from diffusers.utils import check_min_version, deprecate, is_wandb_available
4949
from diffusers.utils.import_utils import is_xformers_available
5050

@@ -524,30 +524,6 @@ def deepspeed_zero_init_disabled_context_manager():
524524
else:
525525
raise ValueError("xformers is not available. Make sure it is installed correctly")
526526

527-
def compute_snr(timesteps):
528-
"""
529-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
530-
"""
531-
alphas_cumprod = noise_scheduler.alphas_cumprod
532-
sqrt_alphas_cumprod = alphas_cumprod**0.5
533-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
534-
535-
# Expand the tensors.
536-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
537-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
538-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
539-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
540-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
541-
542-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
543-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
544-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
545-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
546-
547-
# Compute SNR.
548-
snr = (alpha / sigma) ** 2
549-
return snr
550-
551527
# `accelerate` 0.16.0 will have better support for customized saving
552528
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
553529
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -871,7 +847,7 @@ def collate_fn(examples):
871847
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
872848
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
873849
# This is discussed in Section 4.2 of the same paper.
874-
snr = compute_snr(timesteps)
850+
snr = compute_snr(noise_scheduler, timesteps)
875851
base_weight = (
876852
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
877853
)

examples/text_to_image/train_text_to_image.py

Lines changed: 2 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
import diffusers
4444
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
4545
from diffusers.optimization import get_scheduler
46-
from diffusers.training_utils import EMAModel
46+
from diffusers.training_utils import EMAModel, compute_snr
4747
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
4848
from diffusers.utils.import_utils import is_xformers_available
4949

@@ -601,30 +601,6 @@ def deepspeed_zero_init_disabled_context_manager():
601601
else:
602602
raise ValueError("xformers is not available. Make sure it is installed correctly")
603603

604-
def compute_snr(timesteps):
605-
"""
606-
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
607-
"""
608-
alphas_cumprod = noise_scheduler.alphas_cumprod
609-
sqrt_alphas_cumprod = alphas_cumprod**0.5
610-
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
611-
612-
# Expand the tensors.
613-
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
614-
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
615-
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
616-
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
617-
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
618-
619-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
620-
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
621-
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
622-
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
623-
624-
# Compute SNR.
625-
snr = (alpha / sigma) ** 2
626-
return snr
627-
628604
# `accelerate` 0.16.0 will have better support for customized saving
629605
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
630606
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -951,7 +927,7 @@ def collate_fn(examples):
951927
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
952928
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
953929
# This is discussed in Section 4.2 of the same paper.
954-
snr = compute_snr(timesteps)
930+
snr = compute_snr(noise_scheduler, timesteps)
955931
base_weight = (
956932
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
957933
)

0 commit comments

Comments
 (0)