Validation Loss Enhancements #1900

Open · wants to merge 7 commits into sd3
2 changes: 1 addition & 1 deletion flux_train_network.py
```diff
@@ -323,7 +323,7 @@ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) ->
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler

-    def encode_images_to_latents(self, args, accelerator, vae, images):
+    def encode_images_to_latents(self, args, vae, images):
        return vae.encode(images)

    def shift_scale_latents(self, args, latents):
```
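The `accelerator` argument was unused in the FLUX override, so it is dropped from `encode_images_to_latents`. A minimal sketch of the new signature in isolation, with a stand-in VAE so it can actually run (all names besides the method itself are illustrative, not from this PR):

```python
# Minimal sketch, assuming the PR's new signature; FakeVAE stands in
# for the real VAE so the call can be exercised standalone.
class FakeVAE:
    def encode(self, images):
        return images  # a real VAE would return latents here

def encode_images_to_latents(args, vae, images):  # accelerator no longer passed
    return vae.encode(images)

print(encode_images_to_latents(None, FakeVAE(), [0.1, 0.2]))
```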
98 changes: 30 additions & 68 deletions train_network.py

```diff
@@ -1163,6 +1163,18 @@ def load_model_hook(models, input_dir):
            args.max_train_steps > initial_step
        ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"

+        if args.validate_every_n_steps is not None:
+            validation_steps = (
+                min(args.max_validation_steps, len(val_dataloader))
+                if args.max_validation_steps is not None
+                else len(val_dataloader)
+            )
+            val_progress_bar = tqdm(
+                range(validation_steps), smoothing=0,
+                disable=not accelerator.is_local_main_process,
+                desc="validation steps"
+            )
+
        progress_bar = tqdm(
            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
        )
```
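Creating the progress bar once up front also means `validation_steps` is computed earlier: it is `len(val_dataloader)`, optionally capped by `--max_validation_steps`. A runnable sketch of just that cap, with illustrative numbers that are not from the PR:

```python
# Illustrative values: how the cap on validation batches behaves.
max_validation_steps = 20   # from --max_validation_steps; may be None
num_val_batches = 135       # stands in for len(val_dataloader)

validation_steps = (
    min(max_validation_steps, num_val_batches)
    if max_validation_steps is not None
    else num_val_batches
)
assert validation_steps == 20  # with max_validation_steps=None it would be 135
```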
```diff
@@ -1247,12 +1259,6 @@ def remove_model(old_ckpt_name):
                # log empty object to commit the sample images to wandb
                accelerator.log({}, step=0)

-        validation_steps = (
-            min(args.max_validation_steps, len(val_dataloader))
-            if args.max_validation_steps is not None
-            else len(val_dataloader)
-        )
-
        # training loop
        if initial_step > 0:  # only if skip_until_initial_step is specified
            for skip_epoch in range(epoch_to_start):  # skip epochs
```
```diff
@@ -1383,19 +1389,18 @@ def remove_model(old_ckpt_name):
                        maximum_norm
                    )
                accelerator.log(logs, step=global_step)

                # VALIDATION PER STEP
                should_validate_step = (
-                    args.validate_every_n_steps is not None
-                    and global_step != 0  # Skip first step
-                    and global_step % args.validate_every_n_steps == 0
+                    args.validate_every_n_steps is not None
+                    and args.validation_at_start
+                    and (global_step - 1) % args.validate_every_n_steps == 0  # Note: should use global_step - 1 since global_step is incremented prior to this being run
                )
-                if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
-                    val_progress_bar = tqdm(
-                        range(validation_steps), smoothing=0,
-                        disable=not accelerator.is_local_main_process,
-                        desc="validation steps"
-                    )
+
+                # Break out validation processing so that it does not need to be repeated
+                def process_validation():
+                    val_progress_bar.reset()
                    for val_step, batch in enumerate(val_dataloader):
                        if val_step >= validation_steps:
                            break
```
```diff
@@ -1440,6 +1445,10 @@ def remove_model(old_ckpt_name):
                                "loss/validation/step_divergence": loss_validation_divergence,
                            }
                            accelerator.log(logs, step=global_step)
+                # END VALIDATION PROCESSING
+
+                if accelerator.sync_gradients and should_validate_step:
+                    process_validation()

                if global_step >= args.max_train_steps:
                    break
```
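Because `global_step` is already incremented by the time the check runs, the trigger is rewritten as `(global_step - 1) % n == 0` rather than `global_step % n == 0`. The effect is that validation fires right after the first optimizer step and every `n` steps thereafter. A small runnable comparison (the value of `n` is illustrative):

```python
# Which optimizer steps fire validation, old condition vs. new.
# After optimizer step k, global_step == k (already incremented).
n = 3  # illustrative --validate_every_n_steps
old_trigger = [k for k in range(1, 10) if k != 0 and k % n == 0]
new_trigger = [k for k in range(1, 10) if (k - 1) % n == 0]
print(old_trigger)  # [3, 6, 9]
print(new_trigger)  # [1, 4, 7]  -> validates right after the first step
```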
```diff
@@ -1452,59 +1461,7 @@ def remove_model(old_ckpt_name):
            )

            if should_validate_epoch and len(val_dataloader) > 0:
-                val_progress_bar = tqdm(
-                    range(validation_steps), smoothing=0,
-                    disable=not accelerator.is_local_main_process,
-                    desc="epoch validation steps"
-                )
-
-                for val_step, batch in enumerate(val_dataloader):
-                    if val_step >= validation_steps:
-                        break
-
-                    # temporary, for batch processing
-                    self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
-
-                    loss = self.process_batch(
-                        batch,
-                        text_encoders,
-                        unet,
-                        network,
-                        vae,
-                        noise_scheduler,
-                        vae_dtype,
-                        weight_dtype,
-                        accelerator,
-                        args,
-                        text_encoding_strategy,
-                        tokenize_strategy,
-                        is_train=False,
-                        train_text_encoder=False,
-                        train_unet=False
-                    )
-
-                    current_loss = loss.detach().item()
-                    val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
-                    val_progress_bar.update(1)
-                    val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
-
-                    if is_tracking:
-                        logs = {
-                            "loss/validation/epoch_current": current_loss,
-                            "epoch": epoch + 1,
-                            "val_step": (epoch * validation_steps) + val_step
-                        }
-                        accelerator.log(logs, step=global_step)
-
-                if is_tracking:
-                    avr_loss: float = val_epoch_loss_recorder.moving_average
-                    loss_validation_divergence = val_step_loss_recorder.moving_average - avr_loss
-                    logs = {
-                        "loss/validation/epoch_average": avr_loss,
-                        "loss/validation/epoch_divergence": loss_validation_divergence,
-                        "epoch": epoch + 1
-                    }
-                    accelerator.log(logs, step=global_step)
+                process_validation()

            # END OF EPOCH
            if is_tracking:
```
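The fifty-odd lines of per-epoch validation are replaced by a single call to the same `process_validation()` closure used per step, so the loop, progress bar, and logging live in one place. A self-contained sketch of the pattern, with stand-in names (`make_validator`, `eval_fn`) that are not from the PR:

```python
# Sketch of the dedup pattern: one closure over the dataloader and
# progress bar, callable from both the per-step and per-epoch sites.
from tqdm import tqdm

def make_validator(val_batches, eval_fn, max_steps):
    bar = tqdm(range(max_steps), desc="validation steps")

    def run_validation():
        bar.reset()                  # reuse the same bar on every call
        total, count = 0.0, 0
        for i, batch in enumerate(val_batches):
            if i >= max_steps:
                break
            total += eval_fn(batch)  # forward pass only in the real code
            count += 1
            bar.update(1)
        return total / count if count else float("nan")

    return run_validation

validate = make_validator([1.0, 2.0, 3.0], lambda b: b, max_steps=2)
print(validate())  # 1.5; call again at epoch end without rebuilding anything
```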
```diff
@@ -1722,6 +1679,11 @@ def setup_parser() -> argparse.ArgumentParser:
        default=None,
        help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
    )
+    parser.add_argument(
+        "--validation_at_start",
+        action="store_true",
+        help="Calculate validation loss at run start"
+    )
    return parser
```
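The new flag is a plain `store_true` option, so it defaults to `False` and combines on the command line with the existing `--validate_every_n_steps` and `--max_validation_steps` switches. A runnable check of the argparse wiring, mirroring the added lines above:

```python
# Illustrative check of the new flag's argparse behavior.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--validation_at_start", action="store_true",
                    help="Calculate validation loss at run start")

assert parser.parse_args(["--validation_at_start"]).validation_at_start is True
assert parser.parse_args([]).validation_at_start is False  # store_true default
```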