Val loss improvement #1903

Merged: 19 commits, Feb 26, 2025

46 changes: 6 additions & 40 deletions flux_train_network.py
@@ -381,8 +381,7 @@ def get_noise_pred_and_target(
             t5_attn_mask = None

         def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
-            # if not args.split_mode:
-            # normal forward
+            # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
             with torch.set_grad_enabled(is_train), accelerator.autocast():
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                 model_pred = unet(
@@ -395,44 +394,6 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
                     guidance=guidance_vec,
                     txt_attention_mask=t5_attn_mask,
                 )
-            """
-            else:
-                # split forward to reduce memory usage
-                assert network.train_blocks == "single", "train_blocks must be single for split mode"
-                with accelerator.autocast():
-                    # move flux lower to cpu, and then move flux upper to gpu
-                    unet.to("cpu")
-                    clean_memory_on_device(accelerator.device)
-                    self.flux_upper.to(accelerator.device)
-
-                    # upper model does not require grad
-                    with torch.no_grad():
-                        intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
-                            img=packed_noisy_model_input,
-                            img_ids=img_ids,
-                            txt=t5_out,
-                            txt_ids=txt_ids,
-                            y=l_pooled,
-                            timesteps=timesteps / 1000,
-                            guidance=guidance_vec,
-                            txt_attention_mask=t5_attn_mask,
-                        )
-
-                    # move flux upper back to cpu, and then move flux lower to gpu
-                    self.flux_upper.to("cpu")
-                    clean_memory_on_device(accelerator.device)
-                    unet.to(accelerator.device)
-
-                    # lower model requires grad
-                    intermediate_img.requires_grad_(True)
-                    intermediate_txt.requires_grad_(True)
-                    vec.requires_grad_(True)
-                    pe.requires_grad_(True)
-
-                    with torch.set_grad_enabled(is_train and train_unet):
-                        model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
-            """

             return model_pred

         model_pred = call_dit(
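The new comment in `call_dit` is worth unpacking: `torch.set_grad_enabled` is independent of a module's `train()`/`eval()` state, so autograd can stay on for a frozen DiT while the Text Encoder trains, and can be turned off wholesale when `is_train` is False. A minimal standalone sketch (toy module, not from this PR) illustrating that behavior:

```python
# Toy illustration: set_grad_enabled, not train()/eval(), controls autograd.
import torch

model = torch.nn.Linear(4, 4)  # hypothetical stand-in for the DiT
x = torch.randn(2, 4)

model.eval()  # eval mode alone does not disable gradient tracking
with torch.set_grad_enabled(True):
    y = model(x)
print(y.requires_grad)  # True: grads can still flow, e.g. into a trainable Text Encoder

with torch.set_grad_enabled(False):  # what a validation step (is_train=False) gets
    y = model(x)
print(y.requires_grad)  # False: no autograd graph is built, saving memory
```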
@@ -551,6 +512,11 @@ def forward(hidden_states):
             text_encoder.to(te_weight_dtype)  # fp8
             prepare_fp8(text_encoder, weight_dtype)

+    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        if self.is_swapping_blocks:
+            # prepare for next forward: because backward pass is not called, we need to prepare it here
+            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
+
     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
     ) -> torch.nn.Module:
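For context, the new `on_validation_step_end` hook exists because block swapping normally re-stages blocks around the backward pass; a validation step never calls backward, so the hook restores the expected block layout before the next forward. A hedged sketch of where such a hook would sit in a validation loop (the real loop lives in train_network.py, which is not rendered below; `validate_epoch` and `process_batch` are hypothetical names):

```python
# Hypothetical validation loop showing where on_validation_step_end is called.
import torch

def validate_epoch(trainer, args, accelerator, network, text_encoders, unet, val_loader, weight_dtype):
    total_loss, n_steps = 0.0, 0
    for batch in val_loader:
        with torch.no_grad():
            loss = trainer.process_batch(batch, is_train=False)  # hypothetical helper
        total_loss += loss.item()
        n_steps += 1
        # no backward pass happened, so give block swapping a chance to re-prepare
        trainer.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
    return total_loss / max(n_steps, 1)
```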
62 changes: 29 additions & 33 deletions library/train_util.py
@@ -13,17 +13,7 @@
 import shutil
 import time
 import typing
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import glob
 import math
@@ -146,12 +136,13 @@
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"

+
 def split_train_val(
-    paths: List[str],
+    paths: List[str],
     sizes: List[Optional[Tuple[int, int]]],
-    is_training_dataset: bool,
-    validation_split: float,
-    validation_seed: int | None
+    is_training_dataset: bool,
+    validation_split: float,
+    validation_seed: int | None,
 ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
     """
     Split the dataset into train and validation
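The docstring is cut off by the diff view, but the contract is: shuffle deterministically with `validation_seed`, then slice, so the training dataset and the validation dataset (which each call this function independently) end up with disjoint, reproducible subsets. A minimal sketch of that idea (assumed behavior, simplified to paths only; the real function also splits `sizes` in parallel):

```python
# Simplified sketch of a seeded train/val split; not the library's exact code.
import random
from typing import List, Optional

def split_paths_sketch(
    paths: List[str], is_training_dataset: bool, validation_split: float, validation_seed: Optional[int]
) -> List[str]:
    indices = list(range(len(paths)))
    # same seed in both dataset instances -> same shuffle -> disjoint slices
    random.Random(validation_seed).shuffle(indices)
    n_val = int(len(paths) * validation_split)
    chosen = indices[n_val:] if is_training_dataset else indices[:n_val]
    return [paths[i] for i in chosen]
```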
@@ -1842,7 +1833,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
 class DreamBoothDataset(BaseDataset):
     IMAGE_INFO_CACHE_FILE = "metadata_cache.json"

-    # The is_training_dataset defines the type of dataset, training or validation
+    # The is_training_dataset defines the type of dataset, training or validation
     # if is_training_dataset is True -> training dataset
     # if is_training_dataset is False -> validation dataset
     def __init__(
@@ -1981,29 +1972,25 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")

# We want to create a training and validation split. This should be improved in the future
# to allow a clearer distinction between training and validation. This can be seen as a
# to allow a clearer distinction between training and validation. This can be seen as a
# short-term solution to limit what is necessary to implement validation datasets
#
#
# We split the dataset for the subset based on if we are doing a validation split
# The self.is_training_dataset defines the type of dataset, training or validation
# The self.is_training_dataset defines the type of dataset, training or validation
# if self.is_training_dataset is True -> training dataset
# if self.is_training_dataset is False -> validation dataset
if self.validation_split > 0.0:
# For regularization images we do not want to split this dataset.
# For regularization images we do not want to split this dataset.
if subset.is_reg is True:
# Skip any validation dataset for regularization images
if self.is_training_dataset is False:
img_paths = []
sizes = []
# Otherwise the img_paths remain as original img_paths and no split
# Otherwise the img_paths remain as original img_paths and no split
# required for training images dataset of regularization images
else:
img_paths, sizes = split_train_val(
img_paths,
sizes,
self.is_training_dataset,
self.validation_split,
self.validation_seed
img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
)

logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
@@ -2373,7 +2360,7 @@ def __init__(
         bucket_no_upscale: bool,
         debug_dataset: bool,
         validation_split: float,
-        validation_seed: Optional[int],
+        validation_seed: Optional[int],
     ) -> None:
         super().__init__(resolution, network_multiplier, debug_dataset)

@@ -2431,9 +2418,9 @@ def __init__(
         self.image_data = self.dreambooth_dataset_delegate.image_data
         self.batch_size = batch_size
         self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
-        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
+        self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
         self.validation_split = validation_split
-        self.validation_seed = validation_seed
+        self.validation_seed = validation_seed

         # assert all conditioning data exists
         missing_imgs = []
@@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common(


 def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
-    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
+    if min_timestep < max_timestep:
+        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
+    else:
+        timesteps = torch.full((b_size,), max_timestep, device="cpu")
     timesteps = timesteps.long().to(device)
     return timesteps


-def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
+def get_noise_noisy_latents_and_timesteps(
+    args, noise_scheduler, latents: torch.FloatTensor
+) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
     # Sample noise that we'll add to the latents
     noise = torch.randn_like(latents, device=latents.device)
     if args.noise_offset:
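The new guard in `get_timesteps` matters because `torch.randint` requires `low < high`; with a pinned timestep range (`min_timestep == max_timestep`, such as a fixed timestep used for deterministic evaluation), the old code raised a RuntimeError. A quick demonstration of both branches, reusing the patched logic:

```python
import torch

def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
    if min_timestep < max_timestep:
        timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
    else:
        timesteps = torch.full((b_size,), max_timestep, device="cpu")
    return timesteps.long().to(device)

cpu = torch.device("cpu")
print(get_timesteps(0, 1000, 4, cpu))  # four random draws from [0, 1000)
print(get_timesteps(500, 500, 4, cpu))  # tensor([500, 500, 500, 500]); previously a RuntimeError
```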
@@ -6441,7 +6433,7 @@ def sample_image_inference(
                 wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption


-def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
+def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
     """
     Initialize experiment trackers with tracker specific behaviors
     """
@@ -6458,13 +6450,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
     )

     if "wandb" in [tracker.name for tracker in accelerator.trackers]:
-        import wandb
+        import wandb

         wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)

+        # Define specific metrics to handle validation and epochs "steps"
+        wandb_tracker.define_metric("epoch", hidden=True)
+        wandb_tracker.define_metric("val_step", hidden=True)

+        wandb_tracker.define_metric("global_step", hidden=True)


 # endregion


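The `define_metric(..., hidden=True)` calls keep `epoch`, `val_step`, and `global_step` out of wandb's auto-generated charts while still letting them serve as step counters. A hedged sketch of how such hidden counters are typically used (offline mode and project name are hypothetical; the `step_metric` pairing is a common pattern, not something this diff itself sets up):

```python
import wandb

run = wandb.init(project="example-val-loss", mode="offline")  # offline: runs without an account
run.define_metric("val_step", hidden=True)  # usable as an axis, hidden as a chart
run.define_metric("loss/validation", step_metric="val_step")  # plot val loss against val_step

for val_step in range(3):
    run.log({"loss/validation": 1.0 / (val_step + 1), "val_step": val_step})
run.finish()
```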
9 changes: 7 additions & 2 deletions sd3_train_network.py
@@ -450,14 +450,19 @@ def forward(hidden_states):
             text_encoder.to(te_weight_dtype)  # fp8
             prepare_fp8(text_encoder, weight_dtype)

-    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
-        # drop cached text encoder outputs
+    def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
+        # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
             text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
             text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
             batch["text_encoder_outputs_list"] = text_encoder_outputs_list

+    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
+        if self.is_swapping_blocks:
+            # prepare for next forward: because backward pass is not called, we need to prepare it here
+            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
+
     def prepare_unet_with_accelerator(
         self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
     ) -> torch.nn.Module:
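The updated comment on `on_step_start` points at a subtlety: dropping cached text encoder outputs is a random augmentation, and validation loss is only comparable across runs if that randomness is reproducible. A hedged sketch of the seeded-dropout idea (illustrative only; the actual dropping lives in `Sd3TextEncodingStrategy.drop_cached_text_encoder_outputs`, and `maybe_drop_output` is a hypothetical helper):

```python
import random

def maybe_drop_output(output, dropout_rate: float, is_train: bool, step: int, base_seed: int = 42):
    # training: free-running RNG; validation: per-step fixed seed for reproducibility
    rng = random.Random() if is_train else random.Random(base_seed + step)
    return None if rng.random() < dropout_rate else output
```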
429 changes: 263 additions & 166 deletions train_network.py

Large diffs are not rendered by default.