Commit 221b032

add sample run and training script
1 parent 46c52f9 commit 221b032

File tree

4 files changed: +163 -0 lines changed

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
set -euv

# TODO: how to fine-tune?
# For training, `diffusers` needs to be installed from a local checkout (editable install):
# https://huggingface.co/docs/diffusers/installation#install-from-source
#
# pip install -e ".[torch]"

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
# moved from ~/.cache/huggingface/.....
# export MODEL_NAME='sd-compvis-model'

OUTPUT_DIR="dreambooth_model1"
INSTANCE_DIR="data/instance_images"
CAPTIONS_DIR="data/captions"

# If running on a GPU, set --mixed_precision="fp16".
#
# Training params left to tune:
# --max_train_steps=15000 \
# --learning_rate=1e-05 \
# --use_8bit_adam \
# --captions_dir="$CAPTIONS_DIR" \

# dump only textenc

accelerate launch --mixed_precision="no" examples/dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400
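Once this run finishes, the fine-tuned weights end up in dreambooth_model1 and can be sampled with the same pipeline API used in sample_inference_txt2img.py below. A minimal sketch, assuming train_dreambooth.py writes a complete pipeline to OUTPUT_DIR and reusing the instance prompt from the command above:

import torch
from diffusers import StableDiffusionPipeline

# Load the DreamBooth output directory as a regular pipeline
# (assumes a complete pipeline was written to dreambooth_model1).
pipe = StableDiffusionPipeline.from_pretrained(
    "dreambooth_model1",
    torch_dtype=torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe = pipe.to("cpu")

# Reuse the instance prompt the model was fine-tuned on.
image = pipe("a photo of sks dog", num_inference_steps=50).images[0]
image.save("dreambooth_sample.png")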
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
"""
usage:
    python sample_inference_txt2img.py

sample code from https://github.com/huggingface/diffusers

one-time setup:

    conda create -n sd2 pytorch==1.12.1 torchvision==0.13.1
    conda activate sd2
    # conda install -c conda-forge diffusers==0.12.1       <-- conda version
    # conda install -c conda-forge transformers==4.19.2    <-- one repo needs this version
    conda install -c conda-forge transformers==4.27.4
    conda install -c conda-forge accelerate==0.18.0
    conda install -c conda-forge datasets==2.11.0
    conda install -c conda-forge ftfy==6.1.1
    pip install invisible-watermark

for training, `diffusers` needs to be installed from a local checkout (editable install):
https://huggingface.co/docs/diffusers/installation#install-from-source
"""

import time
import torch
from diffusers import StableDiffusionPipeline


load_from_local = False

if not load_from_local:
    # option 1: download from the Hub
    # (downloads to ~/.cache/huggingface/...)
    model_path = 'runwayml/stable-diffusion-v1-5'
    # model_path = '~/.cache/huggingface/diffusers/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819'
    # model_path = 'CompVis/stable-diffusion-v1-4'
    # model_path = '~/.cache/huggingface/hub/models--CompVis--stable-diffusion-v1-4/snapshots/249dd2d739844dea6a0bc7fc27b3c1d014720b28'
    # model_path = 'sd-compvis-model'  # moved from ~/.cache/huggingface/...
    print(f"downloading {model_path}")
else:
    # option 2: load from a local path
    model_path = 'sd-pokemon-model'
    print(f"loading from local path {model_path}")

start = time.time()
pipe = StableDiffusionPipeline.from_pretrained(
    model_path, torch_dtype=torch.float32, safety_checker=None, requires_safety_checker=False
)
pipe = pipe.to("cpu")
# Recommended if your computer has < 64 GB of RAM.
pipe.enable_attention_slicing()

# Note: prompts longer than this model's maximum sequence length get truncated.
# prompt = "yoda"
prompt = "This Elegant 14K Solid Two Tone Gold Mens Wedding Band is 6mm wide. Center of the Ring has a Satin Finished and edges are Shiny Finish. This Ring is comfort Fitted.\n\n Manufactured in New York, USA. Available in different Metals, Widths, Colors and Finishing."
# prompt = "beautiful elven woman sitting in a white elven city, (full body), (blush), (sitting on stone staircase), pinup pose, (world of warcraft blood elf), (cosplay wig), (medium blonde hair:1.3), (light blue eyes:1.2), ((red, and gold elf minidress)), intricate elven dress"

print(f"=== prompt ===\n{prompt}\n===========\n")

# First-time "warmup" pass (see explanation above)
_ = pipe(prompt, num_inference_steps=1)

# Results match those from the CPU device after the warmup pass.
img_list = pipe(prompt, num_inference_steps=80).images

print(len(img_list))
image = img_list[0]

output_fn = 'output1.png'
print(f"after {(time.time() - start) / 60.0:.2f} minutes, saving file into {output_fn}")
image.save(output_fn)
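The script above runs in float32 on CPU, which is why a single image takes minutes. On a machine with a CUDA GPU the same pipeline can be loaded in half precision instead; a rough sketch (the fp16/CUDA settings are an assumption, not part of this commit):

import torch
from diffusers import StableDiffusionPipeline

# Same pipeline as above, but with half-precision weights on a CUDA device.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe = pipe.to("cuda")

image = pipe("yoda", num_inference_steps=80).images[0]
image.save("output_gpu.png")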
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
set -euv

# TODO: how to fine-tune?
# For training, `diffusers` needs to be installed from a local checkout (editable install):
# https://huggingface.co/docs/diffusers/installation#install-from-source
#
# pip install -e ".[torch]"

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
# moved from ~/.cache/huggingface/.....
# export MODEL_NAME='sd-compvis-model'

export dataset_name="lambdalabs/pokemon-blip-captions"


# If running on a GPU, set --mixed_precision="fp16".
#
# Training params left to tune:
# --max_train_steps=15000 \
# --learning_rate=1e-05 \
# --use_8bit_adam \

accelerate launch --mixed_precision="no" examples/text_to_image/train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=1000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --checkpointing_steps=5 \
  --checkpoints_total_limit=2 \
  --resume_from_checkpoint="latest" \
  --output_dir="sd-pokemon-model"
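This run fine-tunes on the lambdalabs/pokemon-blip-captions dataset pulled through the `datasets` library installed in the conda setup above. A quick sketch for peeking at what the trainer will consume; the image/text column names are an assumption about that dataset, not something verified in this commit:

from datasets import load_dataset

# First use downloads the dataset to ~/.cache/huggingface/datasets.
ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")

print(ds)                  # row count and column names
sample = ds[0]
print(sample["text"])      # assumed caption column
sample["image"].save("pokemon_sample0.png")  # assumed PIL image column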

examples/text_to_image/train_text_to_image.py

Lines changed: 16 additions & 0 deletions
@@ -780,9 +780,14 @@ def collate_fn(examples):
     progress_bar.set_description("Steps")

     for epoch in range(first_epoch, args.num_train_epochs):
+        logger.info(f"epoch = {epoch}: start")
         unet.train()
+        logger.info(f"epoch = {epoch}: finish train()")
+
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
+            logger.info(f"step = {step}: start")
+
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                 if step % args.gradient_accumulation_steps == 0:
@@ -802,6 +807,7 @@ def collate_fn(examples):
                     (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                 )

+                print(f"latents.shape = {latents.shape}")
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
@@ -823,6 +829,7 @@ def collate_fn(examples):
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                 # Predict the noise residual and compute loss
+                logger.info(f"step = {step}: Predict the noise residual and compute loss")
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                 if args.snr_gamma is None:
@@ -842,17 +849,23 @@ def collate_fn(examples):
                     loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                     loss = loss.mean()

+                logger.info(f"step = {step}: after accelerator loss={loss:.3f} ==")
+
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                 train_loss += avg_loss.item() / args.gradient_accumulation_steps

                 # Backpropagate
+                logger.info(f"step = {step}: Backpropagate with {loss:.3f}")
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                 optimizer.step()
+                logger.info(f"step = {step}: optimizer.step() done")
                 lr_scheduler.step()
+                logger.info(f"step = {step}: lr_scheduler done")
                 optimizer.zero_grad()
+                logger.info(f"step = {step}: optimizer.zero_grad() done")

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -872,10 +885,13 @@ def collate_fn(examples):
             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)

+            logger.info(f"step_loss: {loss.detach().item()}, lr: {lr_scheduler.get_last_lr()[0]}")
+
             if global_step >= args.max_train_steps:
                 break

         if accelerator.is_main_process:
+            logger.info("== accelerator.is_main_process ==")
             if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
                 if args.use_ema:
                     # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
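The logger.info calls added above rely on the module-level logger already defined in train_text_to_image.py; upstream, that logger comes from accelerate, so per-step messages should only be printed by the main process. A rough sketch of that setup for reference (paraphrased from the upstream script, not part of this diff):

from accelerate.logging import get_logger

# Upstream train_text_to_image.py creates its logger roughly like this;
# the accelerate wrapper makes logger.info print only on the main process by default.
logger = get_logger(__name__, log_level="INFO")

# The per-step messages added in this commit then flow through it, e.g.:
# logger.info(f"step = {step}: optimizer.step() done")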
