Skip to content

Validation Loss Enhancements #1900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: sd3
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
Clean up positioning of validation progress bar
  • Loading branch information
stepfunction83 committed Jan 26, 2025
commit f2d880660c83aa1955ec1e2944230c98ce0543b5
31 changes: 14 additions & 17 deletions train_network.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1163,6 +1163,18 @@ def load_model_hook(models, input_dir):
args.max_train_steps > initial_step
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"

if args.validate_every_n_steps is not None:
validation_steps = (
min(args.max_validation_steps, len(val_dataloader))
if args.max_validation_steps is not None
else len(val_dataloader)
)
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
)

progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)
Expand Down Expand Up @@ -1247,12 +1259,6 @@ def remove_model(old_ckpt_name):
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)

validation_steps = (
min(args.max_validation_steps, len(val_dataloader))
if args.max_validation_steps is not None
else len(val_dataloader)
)

# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
Expand Down Expand Up @@ -1391,11 +1397,7 @@ def remove_model(old_ckpt_name):
and (global_step - 1) % args.validate_every_n_steps == 0 # Note: Should use global step - 1 since the global step is incremented prior to this being run
)
if accelerator.sync_gradients and should_validate_step:
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
)
val_progress_bar.reset()
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
Expand Down Expand Up @@ -1452,12 +1454,7 @@ def remove_model(old_ckpt_name):
)

if should_validate_epoch and len(val_dataloader) > 0:
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="epoch validation steps"
)

val_progress_bar.reset()
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
Expand Down