Skip to content

Commit 856331c

Browse files
authored
Support training SD V2 with Flax (huggingface#1783)
* Support training SD V2 with Flax

  Mostly involves supporting a v_prediction scheduler. The implementation in huggingface#1777 doesn't take into account a recent refactor of `scheduling_utils_flax`, so this implementation should be used instead.

* Add to other top-level files.
1 parent f7154f8 commit 856331c

File tree

6 files changed

+75
-16
lines changed

6 files changed

+75
-16
lines changed

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -525,28 +525,35 @@ def compute_loss(params):
525525
)[0]
526526

527527
# Predict the noise residual
528-
unet_outputs = unet.apply(
528+
model_pred = unet.apply(
529529
{"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
530-
)
531-
noise_pred = unet_outputs.sample
530+
).sample
531+
532+
# Get the target for loss depending on the prediction type
533+
if noise_scheduler.config.prediction_type == "epsilon":
534+
target = noise
535+
elif noise_scheduler.config.prediction_type == "v_prediction":
536+
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
537+
else:
538+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
532539

533540
if args.with_prior_preservation:
534541
# Chunk the model prediction and the target into two parts and compute the loss on each part separately.
535-
noise_pred, noise_pred_prior = jnp.split(noise_pred, 2, axis=0)
536-
noise, noise_prior = jnp.split(noise, 2, axis=0)
542+
model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
543+
target, target_prior = jnp.split(target, 2, axis=0)
537544

538545
# Compute instance loss
539-
loss = (noise - noise_pred) ** 2
546+
loss = (target - model_pred) ** 2
540547
loss = loss.mean()
541548

542549
# Compute prior loss
543-
prior_loss = (noise_prior - noise_pred_prior) ** 2
550+
prior_loss = (target_prior - model_pred_prior) ** 2
544551
prior_loss = prior_loss.mean()
545552

546553
# Add the prior loss to the instance loss.
547554
loss = loss + args.prior_loss_weight * prior_loss
548555
else:
549-
loss = (noise - noise_pred) ** 2
556+
loss = (target - model_pred) ** 2
550557
loss = loss.mean()
551558

552559
return loss

examples/text_to_image/train_text_to_image_flax.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,9 +459,19 @@ def compute_loss(params):
459459
)[0]
460460

461461
# Predict the noise residual and compute loss
462-
unet_outputs = unet.apply({"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True)
463-
noise_pred = unet_outputs.sample
464-
loss = (noise - noise_pred) ** 2
462+
model_pred = unet.apply(
463+
{"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True
464+
).sample
465+
466+
# Get the target for loss depending on the prediction type
467+
if noise_scheduler.config.prediction_type == "epsilon":
468+
target = noise
469+
elif noise_scheduler.config.prediction_type == "v_prediction":
470+
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
471+
else:
472+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
473+
474+
loss = (target - model_pred) ** 2
465475
loss = loss.mean()
466476

467477
return loss

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -536,11 +536,20 @@ def compute_loss(params):
536536
encoder_hidden_states = state.apply_fn(
537537
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
538538
)[0]
539-
unet_outputs = unet.apply(
539+
# Predict the noise residual and compute loss
540+
model_pred = unet.apply(
540541
{"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
541-
)
542-
noise_pred = unet_outputs.sample
543-
loss = (noise - noise_pred) ** 2
542+
).sample
543+
544+
# Get the target for loss depending on the prediction type
545+
if noise_scheduler.config.prediction_type == "epsilon":
546+
target = noise
547+
elif noise_scheduler.config.prediction_type == "v_prediction":
548+
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
549+
else:
550+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
551+
552+
loss = (target - model_pred) ** 2
544553
loss = loss.mean()
545554

546555
return loss

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
FlaxSchedulerMixin,
3030
FlaxSchedulerOutput,
3131
add_noise_common,
32+
get_velocity_common,
3233
)
3334

3435

@@ -301,5 +302,14 @@ def add_noise(
301302
) -> jnp.ndarray:
302303
return add_noise_common(state.common, original_samples, noise, timesteps)
303304

305+
def get_velocity(
306+
self,
307+
state: DDIMSchedulerState,
308+
sample: jnp.ndarray,
309+
noise: jnp.ndarray,
310+
timesteps: jnp.ndarray,
311+
) -> jnp.ndarray:
312+
return get_velocity_common(state.common, sample, noise, timesteps)
313+
304314
def __len__(self):
305315
return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
FlaxSchedulerMixin,
3030
FlaxSchedulerOutput,
3131
add_noise_common,
32+
get_velocity_common,
3233
)
3334

3435

@@ -293,5 +294,14 @@ def add_noise(
293294
) -> jnp.ndarray:
294295
return add_noise_common(state.common, original_samples, noise, timesteps)
295296

297+
def get_velocity(
298+
self,
299+
state: DDPMSchedulerState,
300+
sample: jnp.ndarray,
301+
noise: jnp.ndarray,
302+
timesteps: jnp.ndarray,
303+
) -> jnp.ndarray:
304+
return get_velocity_common(state.common, sample, noise, timesteps)
305+
296306
def __len__(self):
297307
return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_utils_flax.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def create(cls, scheduler):
242242
)
243243

244244

245-
def add_noise_common(
245+
def get_sqrt_alpha_prod(
246246
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
247247
):
248248
alphas_cumprod = state.alphas_cumprod
@@ -255,5 +255,18 @@ def add_noise_common(
255255
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
256256
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
257257

258+
return sqrt_alpha_prod, sqrt_one_minus_alpha_prod
259+
260+
261+
def add_noise_common(
262+
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
263+
):
264+
sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
258265
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
259266
return noisy_samples
267+
268+
269+
def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
270+
sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps)
271+
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
272+
return velocity

0 commit comments

Comments
 (0)