
Commit 52c8dec

Merge pull request kohya-ss#2015 from DKnight54/uncache_vae_batch
Using --vae_batch_size to set batch size for dynamic latent generation
2 parents a0f1173 + 381303d commit 52c8dec
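
Summary (as implemented in the diff below): when latents are not cached and a batch contains more images than --vae_batch_size, the images are now encoded by the VAE in chunks of at most --vae_batch_size and the resulting latents are concatenated, instead of encoding the whole batch in a single VAE call.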


train_network.py

Lines changed: 15 additions & 7 deletions
@@ -912,14 +912,22 @@ def remove_model(old_ckpt_name):
                     if "latents" in batch and batch["latents"] is not None:
                         latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                     else:
-                        with torch.no_grad():
-                            # latentに変換
-                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
-
+                        if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
+                            with torch.no_grad():
+                                # latentに変換
+                                latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
+                        else:
+                            chunks = [batch["images"][i:i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)]
+                            list_latents = []
+                            for chunk in chunks:
+                                with torch.no_grad():
+                                    # latentに変換
+                                    list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype))
+                            latents = torch.cat(list_latents, dim=0)
                         # NaNが含まれていれば警告を表示し0に置き換える
-                            if torch.any(torch.isnan(latents)):
-                                accelerator.print("NaN found in latents, replacing with zeros")
-                                latents = torch.nan_to_num(latents, 0, out=latents)
+                        if torch.any(torch.isnan(latents)):
+                            accelerator.print("NaN found in latents, replacing with zeros")
+                            latents = torch.nan_to_num(latents, 0, out=latents)
                     latents = latents * self.vae_scale_factor

                     # get multiplier for each sample
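
For context, here is a minimal standalone sketch of the same chunked-encoding pattern, assuming a generic encode_fn in place of vae.encode(...).latent_dist.sample(); the names encode_in_chunks, encode_fn, and chunk_size are hypothetical and are not part of the commit.

# Minimal sketch (not from the repository): encode a batch in chunks of `chunk_size`
# along the batch dimension and concatenate the results, bounding peak memory use.
from typing import Callable, Optional

import torch


def encode_in_chunks(
    images: torch.Tensor,
    encode_fn: Callable[[torch.Tensor], torch.Tensor],
    chunk_size: Optional[int],
) -> torch.Tensor:
    # Small batch (or no limit set): encode everything in one call.
    if chunk_size is None or len(images) <= chunk_size:
        with torch.no_grad():
            return encode_fn(images)
    # Large batch: encode at most chunk_size images at a time, then concatenate.
    outputs = []
    for start in range(0, len(images), chunk_size):
        with torch.no_grad():
            outputs.append(encode_fn(images[start : start + chunk_size]))
    return torch.cat(outputs, dim=0)


# Usage with a dummy encoder that downsamples 8x, standing in for a VAE encode call.
batch = torch.randn(10, 3, 64, 64)
latents = encode_in_chunks(batch, lambda x: torch.nn.functional.avg_pool2d(x, 8), chunk_size=4)
print(latents.shape)  # torch.Size([10, 3, 8, 8])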
