Skip to content

Commit f036a1d

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents feea39d + a21b6a9 commit f036a1d

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

.github/FUNDING.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
+# These are supported funding model platforms
+
+github: kohya-ss

library/train_util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,8 +957,11 @@ def make_buckets(self):
             self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
             logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
 
-        img_ar_errors = np.array(img_ar_errors)
-        mean_img_ar_error = np.mean(np.abs(img_ar_errors))
+        if len(img_ar_errors) == 0:
+            mean_img_ar_error = 0  # avoid NaN
+        else:
+            img_ar_errors = np.array(img_ar_errors)
+            mean_img_ar_error = np.mean(np.abs(img_ar_errors))
         self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
         logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
 

train_network.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -913,14 +913,22 @@ def remove_model(old_ckpt_name):
             if "latents" in batch and batch["latents"] is not None:
                 latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
             else:
-                with torch.no_grad():
-                    # latentに変換
-                    latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
-
+                if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
+                    with torch.no_grad():
+                        # latentに変換
+                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
+                else:
+                    chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)]
+                    list_latents = []
+                    for chunk in chunks:
+                        with torch.no_grad():
+                            # latentに変換
+                            list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype))
+                    latents = torch.cat(list_latents, dim=0)
             # NaNが含まれていれば警告を表示し0に置き換える
-            if torch.any(torch.isnan(latents)):
-                accelerator.print("NaN found in latents, replacing with zeros")
-                latents = torch.nan_to_num(latents, 0, out=latents)
+            if torch.any(torch.isnan(latents)):
+                accelerator.print("NaN found in latents, replacing with zeros")
+                latents = torch.nan_to_num(latents, 0, out=latents)
             latents = latents * self.vae_scale_factor
 
             # get multiplier for each sample

0 commit comments

Comments
 (0)