File tree Expand file tree Collapse file tree 1 file changed +15
-7
lines changed Expand file tree Collapse file tree 1 file changed +15
-7
lines changed Original file line number Diff line number Diff line change @@ -912,14 +912,22 @@ def remove_model(old_ckpt_name):
912
912
if "latents" in batch and batch ["latents" ] is not None :
913
913
latents = batch ["latents" ].to (accelerator .device ).to (dtype = weight_dtype )
914
914
else :
915
- with torch .no_grad ():
916
- # latentに変換
917
- latents = vae .encode (batch ["images" ].to (dtype = vae_dtype )).latent_dist .sample ().to (dtype = weight_dtype )
918
-
915
+ if args .vae_batch_size is None or len (batch ["images" ]) <= args .vae_batch_size :
916
+ with torch .no_grad ():
917
+ # latentに変換
918
+ latents = vae .encode (batch ["images" ].to (dtype = vae_dtype )).latent_dist .sample ().to (dtype = weight_dtype )
919
+ else :
920
+ chunks = [batch ["images" ][i :i + args .vae_batch_size ] for i in range (0 , len (batch ["images" ]), args .vae_batch_size )]
921
+ list_latents = []
922
+ for chunk in chunks :
923
+ with torch .no_grad ():
924
+ # latentに変換
925
+ list_latents .append (vae .encode (chunk .to (dtype = vae_dtype )).latent_dist .sample ().to (dtype = weight_dtype ))
926
+ latents = torch .cat (list_latents , dim = 0 )
919
927
# NaNが含まれていれば警告を表示し0に置き換える
920
- if torch .any (torch .isnan (latents )):
921
- accelerator .print ("NaN found in latents, replacing with zeros" )
922
- latents = torch .nan_to_num (latents , 0 , out = latents )
928
+ if torch .any (torch .isnan (latents )):
929
+ accelerator .print ("NaN found in latents, replacing with zeros" )
930
+ latents = torch .nan_to_num (latents , 0 , out = latents )
923
931
latents = latents * self .vae_scale_factor
924
932
925
933
# get multiplier for each sample
You can’t perform that action at this time.
0 commit comments