
Commit 089f0f4

update to latest colossalai (huggingface#1951)
1 parent aba2a65 commit 089f0f4

File tree: 1 file changed (+18 −20 lines)


examples/research_projects/colossalai/train_dreambooth_colossalai.py

Lines changed: 18 additions & 20 deletions
```diff
@@ -15,8 +15,7 @@
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
```
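The import swap tracks the updated ColossalAI API: `convert_to_torch_module` is replaced by `get_static_torch_model` (used in the checkpointing hunks below), and the `ProcessGroup` import is dropped because the revised `gemini_zero_dpp` no longer takes a process-group argument.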
```diff
@@ -356,26 +355,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
 
     model = GeminiDDP(
-        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64
     )
     return model
 
 
 def main(args):
-    # config for colossalai
-
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})
 
     if args.seed is not None:
         gpc.set_seed(args.seed)
```
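Two API shifts land in this hunk: `launch_from_torch` now takes an empty config (batch size, accumulation steps, and grad clipping no longer live there), and `GeminiDDP` no longer needs an explicit `ProcessGroup`, so `gemini_zero_dpp` drops that parameter while raising `search_range_mb` from 32 to 64. A minimal sketch of the resulting setup, assuming the ColossalAI version this commit targets and a `torchrun` launch; the `Linear` module is a placeholder for the UNet:

```python
import colossalai
import torch
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device

# Empty config: gradient clipping moves to the optimizer (see the later hunk).
colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 16)  # placeholder standing in for the UNet
model = GeminiDDP(
    model,
    device=get_current_device(),
    placement_policy="auto",  # the script's default policy
    pin_memory=True,
    search_range_mb=64,  # raised from 32 by this commit
)
```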
```diff
@@ -472,7 +462,7 @@ def main(args):
     )
 
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
+    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
         )
```
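In the updated API the target device is passed to `ColoInitContext` explicitly rather than assumed. A short sketch under the same assumption; the small `Linear` is a placeholder for the UNet built here:

```python
import torch
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

# Parameters created inside the context are materialized as ColoTensors
# directly on the requested device.
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Linear(16, 16)  # placeholder for UNet2DConditionModel.from_pretrained(...)
```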
```diff
@@ -484,12 +474,19 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
+        args.learning_rate = (
+            args.learning_rate
+            * args.gradient_accumulation_steps
+            * args.train_batch_size
+            * gpc.get_world_size(ParallelMode.DATA)
+        )
 
-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(
+        unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
+    )
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
```
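The `scale_lr` branch replaces the hard-coded factor of 2 with the actual data-parallel world size, so the scaled learning rate tracks the real global batch: with a base rate of 5e-6, batch size 1, one accumulation step, and 4 GPUs, it becomes 5e-6 × 1 × 1 × 4 = 2e-5. Meanwhile the old `clip_grad_norm` config entry resurfaces as the optimizer's `clipping_norm`. A tiny self-contained check of the scaling arithmetic; the values are illustrative, not from the commit:

```python
# Illustrative values only
learning_rate = 5e-6
gradient_accumulation_steps = 1
train_batch_size = 1
data_parallel_world_size = 4  # stands in for gpc.get_world_size(ParallelMode.DATA)

scaled = learning_rate * gradient_accumulation_steps * train_batch_size * data_parallel_world_size
assert abs(scaled - 2e-5) < 1e-15
```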
```diff
@@ -657,10 +654,11 @@ def collate_fn(examples):
 
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
+                torch_unet = get_static_torch_model(unet)
                 if gpc.get_local_rank(ParallelMode.DATA) == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
-                        unet=convert_to_torch_module(unet),
+                        unet=torch_unet,
                         revision=args.revision,
                     )
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
```
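Note the placement of the new call: `get_static_torch_model(unet)` now runs on every rank, before the rank-0 guard. Presumably gathering the ZeRO-sharded Gemini parameters into a single static module is a collective step that all ranks must join, even though only data-parallel rank 0 writes the checkpoint. A sketch of the pattern, under that assumption and using the script's own names:

```python
# Every rank participates in gathering the sharded weights ...
torch.cuda.synchronize()
torch_unet = get_static_torch_model(unet)  # assumed collective across ranks

# ... but only data-parallel rank 0 assembles and saves the pipeline.
if gpc.get_local_rank(ParallelMode.DATA) == 0:
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=torch_unet,
        revision=args.revision,
    )
    pipeline.save_pretrained(save_path)
```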
```diff
@@ -670,7 +668,7 @@ def collate_fn(examples):
                 break
 
     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)
 
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
```
