     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import gc
 import glob
 import math
@@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict:
 
     return prompt_dict
 
-
 def sample_images_common(
     pipe_class,
     accelerator: Accelerator,
@@ -4654,6 +4653,7 @@ def sample_images_common(
     """
     Uses a modified StableDiffusionLongPromptWeightingPipeline, so clip skip and prompt weighting are supported
     """
+
     if steps == 0:
         if not args.sample_at_first:
             return
@@ -4668,13 +4668,15 @@ def sample_images_common(
     if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
         return
 
+    distributed_state = PartialState()  # testing implementation of multi-GPU distributed inference
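+    # PartialState exposes the per-process device, the process count, and split_between_processes(),
+    # which are used below to shard the prompt list across the available devices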
+
     print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
     if not os.path.isfile(args.sample_prompts):
         print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
         return
 
     org_vae_device = vae.device  # should be on the CPU
-    vae.to(device)
+    vae.to(distributed_state.device)
 
     # unwrap unet and text_encoder(s)
     unet = accelerator.unwrap_model(unet)
@@ -4700,12 +4702,11 @@ def sample_images_common(
     with open(args.sample_prompts, "r", encoding="utf-8") as f:
         prompts = json.load(f)
 
-    schedulers: dict = {}
+    # schedulers: dict = {}  # the per-sampler scheduler cache is no longer used after the refactor below
     default_scheduler = get_my_scheduler(
         sample_sampler=args.sample_sampler,
         v_parameterization=args.v_parameterization,
     )
-    schedulers[args.sample_sampler] = default_scheduler
 
     pipeline = pipe_class(
         text_encoder=text_encoder,
@@ -4718,114 +4719,135 @@ def sample_images_common(
         requires_safety_checker=False,
         clip_skip=args.clip_skip,
     )
-    pipeline.to(device)
-
+    pipeline.to(distributed_state.device)
     save_dir = args.output_dir + "/sample"
     os.makedirs(save_dir, exist_ok=True)
+
+    # Create a list with N elements, where each element is a list of prompt_dicts and N is the
+    # number of available processes (devices). prompt_dicts are assigned round-robin by process
+    # order so that generation time lines up with enum order; this probably only holds when
+    # steps and sampler are identical for every prompt.
+    per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes=distributed_state.num_processes, prompt_replacement=prompt_replacement)
 
     rng_state = torch.get_rng_state()
     cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+    # True random sample image generation
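+    # (the RNG state captured above is restored after sampling, so training determinism
+    # is not affected by the re-seeding here)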
+    torch.seed()
+    torch.cuda.seed()
 
     with torch.no_grad():
-        # with accelerator.autocast():
-        for i, prompt_dict in enumerate(prompts):
-            if not accelerator.is_main_process:
-                continue
-
-            if isinstance(prompt_dict, str):
-                prompt_dict = line_to_prompt_dict(prompt_dict)
-
-            assert isinstance(prompt_dict, dict)
-            negative_prompt = prompt_dict.get("negative_prompt")
-            sample_steps = prompt_dict.get("sample_steps", 30)
-            width = prompt_dict.get("width", 512)
-            height = prompt_dict.get("height", 512)
-            scale = prompt_dict.get("scale", 7.5)
-            seed = prompt_dict.get("seed")
-            controlnet_image = prompt_dict.get("controlnet_image")
-            prompt: str = prompt_dict.get("prompt", "")
-            sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
-
-            if seed is not None:
-                torch.manual_seed(seed)
-                torch.cuda.manual_seed(seed)
-
-            scheduler = schedulers.get(sampler_name)
-            if scheduler is None:
-                scheduler = get_my_scheduler(
-                    sample_sampler=sampler_name,
-                    v_parameterization=args.v_parameterization,
-                )
-                schedulers[sampler_name] = scheduler
-            pipeline.scheduler = scheduler
-
-            if prompt_replacement is not None:
-                prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-                if negative_prompt is not None:
-                    negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-
-            if controlnet_image is not None:
-                controlnet_image = Image.open(controlnet_image).convert("RGB")
-                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
-
-            height = max(64, height - height % 8)  # round to divisible by 8
-            width = max(64, width - width % 8)  # round to divisible by 8
-            print(f"prompt: {prompt}")
-            print(f"negative_prompt: {negative_prompt}")
-            print(f"height: {height}")
-            print(f"width: {width}")
-            print(f"sample_steps: {sample_steps}")
-            print(f"scale: {scale}")
-            print(f"sample_sampler: {sampler_name}")
-            if seed is not None:
-                print(f"seed: {seed}")
-            with accelerator.autocast():
-                latents = pipeline(
-                    prompt=prompt,
-                    height=height,
-                    width=width,
-                    num_inference_steps=sample_steps,
-                    guidance_scale=scale,
-                    negative_prompt=negative_prompt,
-                    controlnet=controlnet,
-                    controlnet_image=controlnet_image,
-                )
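+        # split_between_processes() hands each process its own share of per_process_prompts;
+        # since that list holds exactly one sub-list per process, the share is a one-element
+        # list, hence prompt_dict_lists[0] below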
+        with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+            for prompt_dict in prompt_dict_lists[0]:
+                sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet)
 
-            image = pipeline.latents_to_image(latents)[0]
-
-            ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
-            num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
-            seed_suffix = "" if seed is None else f"_{seed}"
-            img_filename = (
-                f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
-            )
-
-            image.save(os.path.join(save_dir, img_filename))
-
-            # send logs to wandb only when it is enabled
-            try:
-                wandb_tracker = accelerator.get_tracker("wandb")
-                try:
-                    import wandb
-                except ImportError:  # checked once beforehand, so this should not raise here
-                    raise ImportError("No wandb / wandb がインストールされていないようです")
-
-                wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-            except:  # wandb is disabled
-                pass
 
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
 
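+    # the device context manager is presumably meant to ensure each process empties the
+    # cache on its own GPU rather than on a default device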
+    with torch.cuda.device(torch.cuda.current_device()):
+        torch.cuda.empty_cache()
+
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
         torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)
 
+def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None):
+
+    # Create a list with N elements, where each element is a list of prompt_dicts and N is the
+    # number of available processes (devices). prompt_dicts are assigned round-robin by process
+    # order so that generation time lines up with enum order; this probably only holds when
+    # steps and sampler are identical for every prompt.
+    per_process_prompts = [[] for _ in range(num_of_processes)]
+    for i, prompt in enumerate(prompts):
+        if isinstance(prompt, str):
+            prompt = line_to_prompt_dict(prompt)
+        assert isinstance(prompt, dict)
+        prompt.pop("subset", None)  # clean up extra data in the original prompt dict
+        # add an enumerator based on prompt position; used later to name image files
+        prompt["enum"] = i
+        # prompt replacement is handled here to keep sample_image_inference simple
+        if prompt_replacement is not None:
+            prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1])
+            if prompt.get("negative_prompt") is not None:
+                prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1])
+        per_process_prompts[i % num_of_processes].append(prompt)
+    return per_process_prompts
+
+def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None):
+    assert isinstance(prompt_dict, dict)
+    negative_prompt = prompt_dict.get("negative_prompt")
+    sample_steps = prompt_dict.get("sample_steps", 30)
+    width = prompt_dict.get("width", 512)
+    height = prompt_dict.get("height", 512)
+    scale = prompt_dict.get("scale", 7.5)
+    seed = prompt_dict.get("seed")
+    controlnet_image = prompt_dict.get("controlnet_image")
+    prompt: str = prompt_dict.get("prompt", "")
+    sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    scheduler = get_my_scheduler(
+        sample_sampler=sampler_name,
+        v_parameterization=args.v_parameterization,
+    )
+    pipeline.scheduler = scheduler
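+    # the scheduler is rebuilt for every prompt; the previous per-sampler cache was
+    # dropped in this refactor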
+
+    if controlnet_image is not None:
+        controlnet_image = Image.open(controlnet_image).convert("RGB")
+        controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+
+    height = max(64, height - height % 8)  # round to divisible by 8
+    width = max(64, width - width % 8)  # round to divisible by 8
+    print(f"\nprompt: {prompt}")
+    print(f"negative_prompt: {negative_prompt}")
+    print(f"height: {height}")
+    print(f"width: {width}")
+    print(f"sample_steps: {sample_steps}")
+    print(f"scale: {scale}")
+    print(f"sample_sampler: {sampler_name}")
+    if seed is not None:
+        print(f"seed: {seed}")
+    with accelerator.autocast():
+        latents = pipeline(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=sample_steps,
+            guidance_scale=scale,
+            negative_prompt=negative_prompt,
+            controlnet=controlnet,
+            controlnet_image=controlnet_image,
+        )
+    image = pipeline.latents_to_image(latents)[0]
+    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images
+    # are saved in the same order as the original prompt list
+    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+    seed_suffix = "" if seed is None else f"_{seed}"
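+    # "enum" was assigned in generate_per_device_prompt_list and records the prompt's position
+    # in the original list, so filenames stay ordered even across processes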
+    i: int = prompt_dict["enum"]
+    img_filename = (
+        f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+    )
+
+    image.save(os.path.join(save_dir, img_filename))
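+    # if a fixed seed was used for this prompt, re-randomize the RNGs so later prompts
+    # without an explicit seed do not repeat the same generator state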
+    if seed is not None:
+        torch.seed()
+        torch.cuda.seed()
+    # send logs to wandb only when it is enabled
+    try:
+        wandb_tracker = accelerator.get_tracker("wandb")
+        try:
+            import wandb
+        except ImportError:  # checked once beforehand, so this should not raise here
+            raise ImportError("No wandb / wandb がインストールされていないようです")
 
+        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
+    except:  # wandb is disabled
+        pass
 # endregion
 
+
+
+
 # region preprocessing
 
 