Commit 1567ce1

DKnight54 and akx authored
Enable distributed sample image generation on multi-GPU environment (kohya-ss#1061)
* Update train_util.py: modify to attempt to enable multi-GPU inference
* Update train_util.py: additional VRAM checking; refactor check_vram_usage to return a string for use with accelerator.print
* Update train_network.py
* Update train_util.py (a long run of iterative commits, including removal of sample image debug outputs)
* Update train_network.py and train_util.py (further iterative commits)
* Cleanup of debugging outputs
* Adopt more elegant coding (Co-authored-by: Aarni Koskela <[email protected]>)
* Update train_util.py: fix leftover debugging code; attempt to refactor inference into a separate function
* Refactor generation of the distributed prompt list into generate_per_device_prompt_list()
* Clean up missing variables
* Fix syntax error
* Update train_util.py (iterative commits)
* True random sample image generation: reinitialize the random seed to true random if a seed was set
* Simplify per-process prompt list
* Update train_util.py and train_network.py (final iterative commits)

---------

Co-authored-by: Aarni Koskela <[email protected]>
1 parent 7f948db commit 1567ce1

File tree

1 file changed: +115 −93 lines changed

library/train_util.py

Lines changed: 115 additions & 93 deletions
@@ -19,7 +19,7 @@
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import gc
 import glob
 import math
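The only import change is PartialState, accelerate's lightweight handle on the distributed environment: it can be constructed in every launched process without building a full Accelerator, and it reports the per-process device and world size. A minimal sketch of the API relied on below (illustrative, not part of this commit):

from accelerate import PartialState

state = PartialState()        # safe to construct in every process
print(state.process_index)    # rank of this process: 0, 1, ...
print(state.num_processes)    # total number of launched processes
print(state.device)           # e.g. cuda:0 on rank 0, cuda:1 on rank 1

Run under `accelerate launch --num_processes 2 script.py`, each process prints its own rank and device.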
@@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict:
 
     return prompt_dict
 
-
 def sample_images_common(
     pipe_class,
     accelerator: Accelerator,
@@ -4654,6 +4653,7 @@ def sample_images_common(
     """
     Uses a modified version of StableDiffusionLongPromptWeightingPipeline, so clip skip and prompt weighting are supported
     """
+
     if steps == 0:
         if not args.sample_at_first:
             return
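Together with the step-modulus check in the next hunk, the gating ahead of sampling can be summarized as below. This is a hypothetical helper for illustration only; the real code inlines these checks and also honors sampling options not visible in this diff.

def should_sample(steps: int, epoch, args) -> bool:
    # Summarizes only the branches shown in this diff (hypothetical helper).
    if steps == 0:
        return bool(args.sample_at_first)  # optional sample before training starts
    if not args.sample_every_n_steps:
        return False
    # skip when steps is not a multiple, or when called at an epoch boundary
    return steps % args.sample_every_n_steps == 0 and epoch is None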
@@ -4668,13 +4668,15 @@
     if steps % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
         return
 
+    distributed_state = PartialState()  # testing implementation of multi gpu distributed inference
+
     print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
     if not os.path.isfile(args.sample_prompts):
         print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
         return
 
     org_vae_device = vae.device  # should be on the CPU
-    vae.to(device)
+    vae.to(distributed_state.device)
 
     # unwrap unet and text_encoder(s)
     unet = accelerator.unwrap_model(unet)
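Because every process constructs the same PartialState, `vae.to(distributed_state.device)` now lands the VAE on a different GPU in each process, where the old `vae.to(device)` targeted a single device on the main process only. A hedged sketch of the pattern with a stand-in module:

import torch.nn as nn
from accelerate import PartialState

state = PartialState()
vae_like = nn.Identity()     # stand-in for the real VAE module
vae_like.to(state.device)    # rank 0 -> cuda:0, rank 1 -> cuda:1, ...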
@@ -4700,12 +4702,11 @@
     with open(args.sample_prompts, "r", encoding="utf-8") as f:
         prompts = json.load(f)
 
-    schedulers: dict = {}
+    # schedulers: dict = {}  cannot find where this is used
     default_scheduler = get_my_scheduler(
         sample_sampler=args.sample_sampler,
         v_parameterization=args.v_parameterization,
     )
-    schedulers[args.sample_sampler] = default_scheduler
 
     pipeline = pipe_class(
         text_encoder=text_encoder,
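This hunk drops the scheduler cache keyed by sampler name; sample_image_inference (added further down) now builds a fresh scheduler per prompt. A sketch of the before/after, assuming this module's get_my_scheduler helper; helper names are illustrative:

# Before: schedulers were cached in a dict keyed by sampler name.
schedulers: dict = {}

def get_cached_scheduler(sampler_name: str, v_parameterization: bool):
    if sampler_name not in schedulers:
        schedulers[sampler_name] = get_my_scheduler(
            sample_sampler=sampler_name,
            v_parameterization=v_parameterization,
        )
    return schedulers[sampler_name]

# After: a fresh scheduler per prompt -- cheap next to sampling itself, and it
# avoids shared mutable state now that several processes sample concurrently.
def get_fresh_scheduler(sampler_name: str, v_parameterization: bool):
    return get_my_scheduler(sample_sampler=sampler_name, v_parameterization=v_parameterization)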
@@ -4718,114 +4719,135 @@
         requires_safety_checker=False,
         clip_skip=args.clip_skip,
     )
-    pipeline.to(device)
-
+    pipeline.to(distributed_state.device)
     save_dir = args.output_dir + "/sample"
     os.makedirs(save_dir, exist_ok=True)
+
+    # Creating a list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
+    # prompt_dicts are assigned to lists based on order of processes, to attempt to match image creation time to enum order. Probably only works when steps and sampler are identical.
+    per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes=distributed_state.num_processes, prompt_replacement=prompt_replacement)
 
     rng_state = torch.get_rng_state()
     cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+    # True random sample image generation
+    torch.seed()
+    torch.cuda.seed()
 
     with torch.no_grad():
-        # with accelerator.autocast():
-        for i, prompt_dict in enumerate(prompts):
-            if not accelerator.is_main_process:
-                continue
-
-            if isinstance(prompt_dict, str):
-                prompt_dict = line_to_prompt_dict(prompt_dict)
-
-            assert isinstance(prompt_dict, dict)
-            negative_prompt = prompt_dict.get("negative_prompt")
-            sample_steps = prompt_dict.get("sample_steps", 30)
-            width = prompt_dict.get("width", 512)
-            height = prompt_dict.get("height", 512)
-            scale = prompt_dict.get("scale", 7.5)
-            seed = prompt_dict.get("seed")
-            controlnet_image = prompt_dict.get("controlnet_image")
-            prompt: str = prompt_dict.get("prompt", "")
-            sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
-
-            if seed is not None:
-                torch.manual_seed(seed)
-                torch.cuda.manual_seed(seed)
-
-            scheduler = schedulers.get(sampler_name)
-            if scheduler is None:
-                scheduler = get_my_scheduler(
-                    sample_sampler=sampler_name,
-                    v_parameterization=args.v_parameterization,
-                )
-                schedulers[sampler_name] = scheduler
-            pipeline.scheduler = scheduler
-
-            if prompt_replacement is not None:
-                prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-                if negative_prompt is not None:
-                    negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-
-            if controlnet_image is not None:
-                controlnet_image = Image.open(controlnet_image).convert("RGB")
-                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
-
-            height = max(64, height - height % 8)  # round to divisible by 8
-            width = max(64, width - width % 8)  # round to divisible by 8
-            print(f"prompt: {prompt}")
-            print(f"negative_prompt: {negative_prompt}")
-            print(f"height: {height}")
-            print(f"width: {width}")
-            print(f"sample_steps: {sample_steps}")
-            print(f"scale: {scale}")
-            print(f"sample_sampler: {sampler_name}")
-            if seed is not None:
-                print(f"seed: {seed}")
-            with accelerator.autocast():
-                latents = pipeline(
-                    prompt=prompt,
-                    height=height,
-                    width=width,
-                    num_inference_steps=sample_steps,
-                    guidance_scale=scale,
-                    negative_prompt=negative_prompt,
-                    controlnet=controlnet,
-                    controlnet_image=controlnet_image,
-                )
+        with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+            for prompt_dict in prompt_dict_lists[0]:
+                sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet)
 
-            image = pipeline.latents_to_image(latents)[0]
-
-            ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
-            num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
-            seed_suffix = "" if seed is None else f"_{seed}"
-            img_filename = (
-                f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
-            )
-
-            image.save(os.path.join(save_dir, img_filename))
-
-            # send logs only when wandb is enabled
-            try:
-                wandb_tracker = accelerator.get_tracker("wandb")
-                try:
-                    import wandb
-                except ImportError:  # checked once beforehand, so this should not raise here
-                    raise ImportError("No wandb / wandb がインストールされていないようです")
-
-                wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-            except:  # when wandb is disabled
-                pass
 
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
 
+    with torch.cuda.device(torch.cuda.current_device()):
+        torch.cuda.empty_cache()
+
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
         torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)
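The RNG handling above follows a save, reseed, restore pattern: training's RNG streams are snapshotted, sampling is reseeded from the OS so it is decoupled from training, and the snapshot is restored afterwards. The same pattern in isolation, using only standard torch calls:

import torch

rng_state = torch.get_rng_state()  # snapshot the CPU RNG before sampling
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

torch.seed()                       # reseed from the OS for "true random" samples
if torch.cuda.is_available():
    torch.cuda.seed()

# ... generate sample images here; per-prompt seeds may call torch.manual_seed ...

torch.set_rng_state(rng_state)     # training resumes with its original RNG stream
if cuda_rng_state is not None:
    torch.cuda.set_rng_state(cuda_rng_state)

The hunk continues with the two helper functions factored out of the old loop: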
 
+def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None):
+
+    # Creating a list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
+    # prompt_dicts are assigned to lists based on order of processes, to attempt to match image creation time to enum order. Probably only works when steps and sampler are identical.
+    per_process_prompts = [[] for i in range(num_of_processes)]
+    for i, prompt in enumerate(prompts):
+        if isinstance(prompt, str):
+            prompt = line_to_prompt_dict(prompt)
+        assert isinstance(prompt, dict)
+        prompt.pop("subset", None)  # Clean up subset key
+        prompt["enum"] = i
+        # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleans up extra data in the original prompt dict.
+        if prompt_replacement is not None:
+            prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1])
+            if prompt["negative_prompt"] is not None:
+                prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1])
+        # Prompt replacement is refactored to here to simplify the sample_image_inference function.
+        per_process_prompts[i % num_of_processes].append(prompt)
+    return per_process_prompts
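The `i % num_of_processes` bucketing above distributes prompts round-robin, so each device gets a near-equal share and the enum order stays roughly aligned across devices. A self-contained toy run with hypothetical data:

prompts = ["p0", "p1", "p2", "p3", "p4"]
num_of_processes = 2

per_process_prompts = [[] for _ in range(num_of_processes)]
for i, prompt in enumerate(prompts):
    per_process_prompts[i % num_of_processes].append(prompt)

print(per_process_prompts)  # [['p0', 'p2', 'p4'], ['p1', 'p3']]

The per-prompt work itself moves into the new sample_image_inference function, added next: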
+
+def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None):
+    assert isinstance(prompt_dict, dict)
+    negative_prompt = prompt_dict.get("negative_prompt")
+    sample_steps = prompt_dict.get("sample_steps", 30)
+    width = prompt_dict.get("width", 512)
+    height = prompt_dict.get("height", 512)
+    scale = prompt_dict.get("scale", 7.5)
+    seed = prompt_dict.get("seed")
+    controlnet_image = prompt_dict.get("controlnet_image")
+    prompt: str = prompt_dict.get("prompt", "")
+    sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    scheduler = get_my_scheduler(
+        sample_sampler=sampler_name,
+        v_parameterization=args.v_parameterization,
+    )
+    pipeline.scheduler = scheduler
+
+    if controlnet_image is not None:
+        controlnet_image = Image.open(controlnet_image).convert("RGB")
+        controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+
+    height = max(64, height - height % 8)  # round to divisible by 8
+    width = max(64, width - width % 8)  # round to divisible by 8
+    print(f"\nprompt: {prompt}")
+    print(f"negative_prompt: {negative_prompt}")
+    print(f"height: {height}")
+    print(f"width: {width}")
+    print(f"sample_steps: {sample_steps}")
+    print(f"scale: {scale}")
+    print(f"sample_sampler: {sampler_name}")
+    if seed is not None:
+        print(f"seed: {seed}")
+    with accelerator.autocast():
+        latents = pipeline(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=sample_steps,
+            guidance_scale=scale,
+            negative_prompt=negative_prompt,
+            controlnet=controlnet,
+            controlnet_image=controlnet_image,
+        )
+    image = pipeline.latents_to_image(latents)[0]
+    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+    seed_suffix = "" if seed is None else f"_{seed}"
+    i: int = prompt_dict["enum"]
+    img_filename = (
+        f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+    )
+
+    image.save(os.path.join(save_dir, img_filename))
+    if seed is not None:
+        torch.seed()
+        torch.cuda.seed()
+    # send logs only when wandb is enabled
+    try:
+        wandb_tracker = accelerator.get_tracker("wandb")
+        try:
+            import wandb
+        except ImportError:  # checked once beforehand, so this should not raise here
+            raise ImportError("No wandb / wandb がインストールされていないようです")
 
+        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
+    except:  # when wandb is disabled
+        pass
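Back in sample_images_common, per_process_prompts has exactly num_processes elements, so split_between_processes hands each process a one-element slice of it; that is why the loop iterates over prompt_dict_lists[0]. A minimal standalone sketch (assumes exactly two processes launched via `accelerate launch`):

from accelerate import PartialState

state = PartialState()
buckets = [["a", "c"], ["b", "d"]]  # one pre-built bucket per process

with state.split_between_processes(buckets) as my_buckets:
    # each process receives a list holding just its own bucket
    for item in my_buckets[0]:
        print(f"rank {state.process_index} handles {item}")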
 # endregion
 
+
+
+
 # region 前処理用 (preprocessing)
 