3 files changed: +23 −9 lines changed
```diff
+# These are supported funding model platforms
+
+github: kohya-ss
```
```diff
@@ -957,8 +957,11 @@ def make_buckets(self):
             self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
             logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
 
-        img_ar_errors = np.array(img_ar_errors)
-        mean_img_ar_error = np.mean(np.abs(img_ar_errors))
+        if len(img_ar_errors) == 0:
+            mean_img_ar_error = 0  # avoid NaN
+        else:
+            img_ar_errors = np.array(img_ar_errors)
+            mean_img_ar_error = np.mean(np.abs(img_ar_errors))
 
         self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
         logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
```
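The guard matters because `np.mean` over an empty array returns `nan` (and emits a RuntimeWarning), which would then be stored in `bucket_info` and logged. A minimal standalone sketch of the guarded computation, using a hypothetical helper name rather than anything from the repository:

```python
import numpy as np

def mean_abs_ar_error(img_ar_errors):
    # Hypothetical helper mirroring the patched logic: np.mean of an
    # empty array yields nan, so an empty error list maps to 0 instead.
    if len(img_ar_errors) == 0:
        return 0  # avoid NaN
    return float(np.mean(np.abs(np.array(img_ar_errors))))

print(mean_abs_ar_error([]))             # 0
print(mean_abs_ar_error([0.02, -0.05]))  # 0.035
```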
```diff
@@ -913,14 +913,22 @@ def remove_model(old_ckpt_name):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
                 else:
-                    with torch.no_grad():
-                        # convert images to latents
-                        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
-
+                    if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
+                        with torch.no_grad():
+                            # convert images to latents
+                            latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
+                    else:
+                        chunks = [batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)]
+                        list_latents = []
+                        for chunk in chunks:
+                            with torch.no_grad():
+                                # convert images to latents
+                                list_latents.append(vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype))
+                        latents = torch.cat(list_latents, dim=0)
                 # if NaN is present, show a warning and replace with zeros
-                if torch.any(torch.isnan(latents)):
-                    accelerator.print("NaN found in latents, replacing with zeros")
-                    latents = torch.nan_to_num(latents, 0, out=latents)
+                if torch.any(torch.isnan(latents)):
+                    accelerator.print("NaN found in latents, replacing with zeros")
+                    latents = torch.nan_to_num(latents, 0, out=latents)
 
                 latents = latents * self.vae_scale_factor
 
                 # get multiplier for each sample
```
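For readers skimming the hunk: the new branch encodes the image batch through the VAE in sub-batches of at most `args.vae_batch_size` and concatenates the resulting latents, which bounds peak memory during latent conversion; the NaN replacement afterwards is unchanged in behaviour. A self-contained sketch of the same pattern, assuming a diffusers-style VAE whose `encode()` returns an object exposing `latent_dist.sample()` (the function name and default dtypes here are illustrative, not the training script's):

```python
import torch

def encode_latents(vae, images, vae_batch_size=None,
                   vae_dtype=torch.float32, weight_dtype=torch.float32):
    # Encode `images` (N, C, H, W) with the VAE, optionally in sub-batches
    # of at most `vae_batch_size` to limit peak memory, then concatenate.
    if vae_batch_size is None or len(images) <= vae_batch_size:
        with torch.no_grad():
            return vae.encode(images.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
    list_latents = []
    for start in range(0, len(images), vae_batch_size):
        chunk = images[start:start + vae_batch_size]
        with torch.no_grad():
            list_latents.append(
                vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
            )
    return torch.cat(list_latents, dim=0)
```

As in the patch, `torch.nan_to_num(latents, 0, out=latents)` can then replace any NaNs produced by the encode with zeros in place before the latents are scaled by the VAE factor.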