You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Min-SNR gamma support for Dreambooth training (huggingface#5107)
* min-SNR gamma for Dreambooth training
* Align the mse_loss_weights style with SDXL training example
---------
Co-authored-by: bghira <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Copy file name to clipboardExpand all lines: examples/dreambooth/train_dreambooth.py
+53-5Lines changed: 53 additions & 5 deletions
Original file line number
Diff line number
Diff line change
@@ -224,6 +224,30 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
224
224
raiseValueError(f"{model_class} is not supported.")
225
225
226
226
227
+
defcompute_snr(timesteps, noise_scheduler):
228
+
"""
229
+
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
0 commit comments