Commit d30ebb2

update readme, add metadata for network module

1 parent 90b1879

File tree: 3 files changed (+77 −26 lines)


README.md

Lines changed: 43 additions & 4 deletions
@@ -150,14 +150,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

- Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
- See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
- Colab seems to stop with log output. Try specifying the `--console_log_simple` option in the training script to disable rich logging.
- Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.

#### Training scripts

- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
- Fixed a bug where the U-Net and Text Encoders were included in the state in `train_network.py` and `sdxl_train_network.py`. Saving and loading the state is now faster, the file size is smaller, and memory usage when loading is reduced.
- DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139). Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207). See [Masked loss](#about-masked-loss) for details.
- Scheduled Huber Loss has been introduced to each training script. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/). Thanks to kabachuha for the PR and to cheald, drhead, and others for the discussion! See [Scheduled Huber Loss](#about-scheduled-huber-loss) for details.
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177). Thanks to KohakuBlueleaf!
- The option `--save_state_on_train_end` is added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168). Thanks to gesen2egee!
- The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and are ignored when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue.

@@ -199,6 +200,23 @@ The feature is not fully tested, so there may be bugs. If you find any issues, p

ControlNet dataset is used to specify the mask. The mask images should be RGB images. A pixel value of 255 in the R channel is treated as masked (the loss is calculated only for masked pixels), and 0 is treated as unmasked. Pixel values 0-255 are converted to 0-1 (i.e., a pixel value of 128 gives half the loss weight). See the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset) for the dataset specification.
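
As an illustration of the mask semantics just described (this sketch is not part of the commit), the following minimal Python snippet converts an RGB mask image into per-pixel loss weights; the helper name and file path are hypothetical:

```python
import numpy as np
from PIL import Image

def mask_to_loss_weights(mask_path: str) -> np.ndarray:
    """Convert an RGB mask image to per-pixel loss weights (illustrative only)."""
    mask_rgb = np.array(Image.open(mask_path).convert("RGB"))
    r_channel = mask_rgb[..., 0].astype(np.float32)
    # 255 -> 1.0 (full loss), 0 -> 0.0 (no loss), 128 -> ~0.5 (half weight)
    return r_channel / 255.0

# weights = mask_to_loss_weights("mask.png")  # hypothetical file name
```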
#### About Scheduled Huber Loss

Scheduled Huber Loss has been introduced to each training script. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.

With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.

To address this, the proposed method applies the Huber loss function selectively: by scheduling Huber loss for the early stages of training (when noise is high) and MSE for the later stages, it strikes a balance between outlier robustness and fine detail reproduction.

Experimental results have confirmed that this method achieves higher accuracy on data containing outliers than pure Huber loss or MSE, with only a minimal increase in computational cost.

The newly added arguments `loss_type`, `huber_schedule`, and `huber_c` allow selection of the loss function type (Huber, smooth L1, MSE), the scheduling method (exponential, constant, SNR), and the Huber parameter. This enables optimization based on the characteristics of the dataset.

See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.

- `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
- `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `exponential`.
- `huber_c`: Specify the Huber loss parameter. The default is `0.1`.
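
To make the `huber_schedule` options just listed concrete, here is a small standalone sketch of the scheduling math; it mirrors the logic added to `get_timesteps_and_huber_c` in `library/train_util.py` (shown later in this diff), with `num_train_timesteps` and `sigma` as stand-ins for values the noise scheduler would normally provide:

```python
import math

def scheduled_huber_c(huber_c: float, schedule: str, timestep: int,
                      num_train_timesteps: int = 1000, sigma: float = 1.0) -> float:
    if schedule == "exponential":
        # Decays from 1.0 at timestep 0 to huber_c at the last (noisiest) timestep:
        # MSE-like behavior at low noise, Huber-like behavior at high noise.
        alpha = -math.log(huber_c) / num_train_timesteps
        return math.exp(-alpha * timestep)
    if schedule == "snr":
        # Interpolates between 1.0 (sigma -> 0, low noise) and huber_c (large sigma).
        return (1 - huber_c) / (1 + sigma) ** 2 + huber_c
    if schedule == "constant":
        return huber_c
    raise NotImplementedError(f"Unknown Huber loss schedule {schedule}!")

# With the default huber_c=0.1 the exponential schedule spans roughly 1.0 -> 0.1:
print(scheduled_huber_c(0.1, "exponential", 0))    # ~1.0
print(scheduled_huber_c(0.1, "exponential", 999))  # ~0.1
```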

#### Major changes

@@ -211,14 +229,15 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG

- Also, if an absolute path is specified, that path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. An INFO log is displayed in such cases.
- See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
- When running on Colab, training seems to stop at log output. Try specifying the `--console_log_simple` option in the training script to disable rich logging.
- Other changes include the addition of masked loss, Scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.

#### Training scripts

- `train_network.py` and `sdxl_train_network.py` have been modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
- Fixed a bug in `train_network.py` and `sdxl_train_network.py` where the U-Net and Text Encoder were included in the state. Saving and loading the state is now faster, the file size is smaller, and memory usage when loading is reduced.
- DeepSpeed is now supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101), [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139). Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
- Masked loss is now supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207). See [About masked loss](#マスクロスについて) for details.
- Scheduled Huber Loss has been added to each training script. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/). Thanks to kabachuha for the proposal, and to cheald, drhead, and everyone else who deepened the discussion! See [About Scheduled Huber Loss](#scheduled-huber-loss-について) for details.
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength`, which vary the noise offset and ip noise gamma in the range of 0 to the specified value, have been added to each training script. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177). Thanks to KohakuBlueleaf!
- The option `--save_state_on_train_end`, which saves the state at the end of training, has been added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168). Thanks to gesen2egee!
- When a number less than or equal to `0` is specified for the `--sample_every_n_epochs` and `--sample_every_n_steps` options in each training script, a warning is now displayed and the option is ignored. Thanks to S-Del for raising the issue.
@@ -262,6 +281,26 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG

Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.

#### About Scheduled Huber Loss

Scheduled Huber Loss, a method to improve tolerance to outliers and anomalies (data corruption) in the training data, has been introduced to each training script.

With the conventional MSE (L2) loss function, outliers can have a large impact, potentially degrading the quality of generated images. The Huber loss function, on the other hand, can suppress the influence of outliers, but tends to compromise the reproduction of fine details in images.

This method schedules the application of the Huber loss function: Huber loss is used in the early stages of training (when noise is high) and MSE in the later stages, balancing outlier tolerance and detail reproduction.

Experiments confirm that this method achieves higher accuracy on data containing outliers than pure Huber loss or MSE, with only a slight increase in computational cost.

Specifically, the newly added arguments `loss_type`, `huber_schedule`, and `huber_c` let you select the loss function type (Huber, smooth L1, MSE) and the scheduling method (exponential, constant, SNR). This enables optimization according to the characteristics of the dataset.

See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.

- `loss_type`: Specifies the type of loss function. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, or `l2` for MSE loss. The default is `l2`, the same as before.
- `huber_schedule`: Specifies the scheduling method. Choose `exponential` for an exponential schedule, `constant` for a fixed value, or `snr` for scheduling based on the signal-to-noise ratio. The default is `exponential`.
- `huber_c`: Specifies the parameter of the Huber loss. The default is `0.1`.

Some comparisons are shared in the PR. If you try this feature, it may be a good idea to start with something like `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1`.

## Additional Information

### Naming of LoRA

library/train_util.py

Lines changed: 25 additions & 20 deletions
```diff
@@ -3241,20 +3241,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         type=str,
         default="l2",
         choices=["l2", "huber", "smooth_l1"],
-        help="The type of loss to use and whether it's scheduled based on the timestep"
+        help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類(L2、Huber、またはsmooth L1)、デフォルトはL2",
     )
     parser.add_argument(
         "--huber_schedule",
         type=str,
         default="exponential",
         choices=["constant", "exponential", "snr"],
-        help="The type of loss to use and whether it's scheduled based on the timestep"
+        help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is exponential"
+        + " / Huber損失のスケジューリング方法(constant、exponential、またはSNRベース)。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトはexponential",
     )
     parser.add_argument(
         "--huber_c",
         type=float,
         default=0.1,
-        help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+        help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
     )
 
     parser.add_argument(
```

```diff
@@ -4862,39 +4863,39 @@ def save_sd_model_on_train_end_common(
     if args.huggingface_repo_id is not None:
         huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
 
+
 def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
 
-    #TODO: if a huber loss is selected, it will use constant timesteps for each batch
+    # TODO: if a huber loss is selected, it will use constant timesteps for each batch
     # as. In the future there may be a smarter way
 
-    if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
-        timesteps = torch.randint(
-            min_timestep, max_timestep, (1,), device='cpu'
-        )
+    if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+        timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
         timestep = timesteps.item()
 
         if args.huber_schedule == "exponential":
-            alpha = - math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+            alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
             huber_c = math.exp(-alpha * timestep)
         elif args.huber_schedule == "snr":
             alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
             sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
-            huber_c = (1 - args.huber_c) / (1 + sigmas)**2 + args.huber_c
+            huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
         elif args.huber_schedule == "constant":
             huber_c = args.huber_c
         else:
-            raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')
+            raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
 
         timesteps = timesteps.repeat(b_size).to(device)
-    elif args.loss_type == 'l2':
+    elif args.loss_type == "l2":
         timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
-        huber_c = 1 # may be anything, as it's not used
+        huber_c = 1  # may be anything, as it's not used
     else:
-        raise NotImplementedError(f'Unknown loss type {args.loss_type}')
+        raise NotImplementedError(f"Unknown loss type {args.loss_type}")
     timesteps = timesteps.long()
 
     return timesteps, huber_c
 
+
 def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
     # Sample noise that we'll add to the latents
     noise = torch.randn_like(latents, device=latents.device)
```

```diff
@@ -4929,27 +4930,31 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
 
     return noise, noisy_latents, timesteps, huber_c
 
+
 # NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
-def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str="mean", loss_type:str="l2", huber_c:float=0.1):
-
-    if loss_type == 'l2':
+def conditional_loss(
+    model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
+):
+
+    if loss_type == "l2":
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
-    elif loss_type == 'huber':
+    elif loss_type == "huber":
         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
-    elif loss_type == 'smooth_l1':
+    elif loss_type == "smooth_l1":
         loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
     else:
-        raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
+        raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
     return loss
 
+
 def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
     names = []
     if including_unet:
```
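
An aside on `conditional_loss` above: the `huber` branch uses the pseudo-Huber form 2·c·(√(x² + c²) − c), which is approximately x² for residuals much smaller than `huber_c` and approximately 2·c·|x| for much larger ones. A minimal self-contained check (not part of the commit):

```python
import torch

def pseudo_huber(residual: torch.Tensor, huber_c: float) -> torch.Tensor:
    # Same elementwise form as the "huber" branch of conditional_loss.
    return 2 * huber_c * (torch.sqrt(residual**2 + huber_c**2) - huber_c)

c = 0.1
small = torch.tensor(0.01)   # |x| << c: quadratic regime, close to x**2
large = torch.tensor(10.0)   # |x| >> c: linear regime, close to 2 * c * |x|
print(pseudo_huber(small, c).item(), (small**2).item())       # ~1e-4 vs 1e-4
print(pseudo_huber(large, c).item(), (2 * c * large).item())  # ~1.98 vs 2.0
```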

train_network.py

Lines changed: 9 additions & 2 deletions
```diff
@@ -476,7 +476,7 @@ def save_model_hook(models, weights, output_dir):
     # pop weights of other models than network to save only network weights
     if accelerator.is_main_process:
         remove_indices = []
-        for i,model in enumerate(models):
+        for i, model in enumerate(models):
             if not isinstance(model, type(accelerator.unwrap_model(network))):
                 remove_indices.append(i)
         for i in reversed(remove_indices):
```

```diff
@@ -569,6 +569,11 @@ def load_model_hook(models, input_dir):
         "ss_scale_weight_norms": args.scale_weight_norms,
         "ss_ip_noise_gamma": args.ip_noise_gamma,
         "ss_debiased_estimation": bool(args.debiased_estimation_loss),
+        "ss_noise_offset_random_strength": args.noise_offset_random_strength,
+        "ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
+        "ss_loss_type": args.loss_type,
+        "ss_huber_schedule": args.huber_schedule,
+        "ss_huber_c": args.huber_c,
     }
 
     if use_user_config:
```
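
Since the point of this change is that the loss settings now travel with the trained network, here is a hedged sketch of reading them back from a saved LoRA file with the `safetensors` library; the file name is hypothetical, and note that safetensors metadata values are stored as strings:

```python
from safetensors import safe_open

# "my_lora.safetensors" is a hypothetical LoRA trained after this commit.
with safe_open("my_lora.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

# The new keys written by train_network.py above.
for key in ("ss_loss_type", "ss_huber_schedule", "ss_huber_c",
            "ss_noise_offset_random_strength", "ss_ip_noise_gamma_random_strength"):
    print(key, "=", metadata.get(key))  # e.g. ss_loss_type = l2
```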
```diff
@@ -873,7 +878,9 @@ def remove_model(old_ckpt_name):
             else:
                 target = noise
 
-            loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+            loss = train_util.conditional_loss(
+                noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+            )
             if args.masked_loss:
                 loss = apply_masked_loss(loss, batch)
             loss = loss.mean([1, 2, 3])
```
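
A note on why `reduction="none"` is passed in the call above: the loss must stay at per-pixel granularity so the mask can be applied before averaging down to one value per sample. A rough editorial sketch of the shape flow (tensor sizes are illustrative, and the simple multiply stands in for what `apply_masked_loss` does):

```python
import torch

batch, channels, height, width = 2, 4, 64, 64
per_pixel_loss = torch.rand(batch, channels, height, width)  # conditional_loss(..., reduction="none")
mask = torch.rand(batch, 1, height, width)                   # per-pixel weights in 0..1

weighted = per_pixel_loss * mask       # mask applied while the loss is still per-pixel
per_sample = weighted.mean([1, 2, 3])  # then collapse C, H, W -> one loss per sample
print(per_sample.shape)                # torch.Size([2])
```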
