
Commit d06e069

andsteing and sayakpaul authored
Adds profiling flags, computes train metrics average. (huggingface#3053)
* WIP controlnet training
  - bugfix --streaming
  - bugfix running report_to!='wandb'
  - adds memory profile before validation
* Adds final logging statement.
* Sets train epochs to 11. Looking at a longer ~16ep run, we see only good validation images after ~11ep: https://wandb.ai/andsteing/controlnet_fill50k/runs/3j2hx6n8
* Removes --logging_dir (it's not used).
* Adds --profile flags.
* Updates --output_dir=runs/fill-circle-{timestamp}.
* Compute mean of `train_metrics` (see the sketch below). Previously `train_metrics[-1]` was logged, resulting in very bumpy train metrics.
* Improves logging a bit.
  - adds l2_grads gradient norm logging
  - adds steps_per_sec
  - sets walltime as x coordinate of train/step
  - logs controlnet_params config
* Adds --ccache (doesn't really help though).
* minor fix in controlnet flax example (huggingface#2986)
* fix the error when push_to_hub but not log validation
* controlnet_from_pt & controlnet_revision
* add intermediate checkpointing to the guide
* Bugfix --profile_steps
* Sets `RACKER_PROJECT_NAME='controlnet_fill50k'`.
* Logs fractional epoch.
* Adds relative `walltime` metric.
* Adds `StepTraceAnnotation` and uses `global_step` instead of `step`.
* Applied `black`.
* Streamlines commands in README a bit.
* Removes `--ccache`. This makes only a very small difference (~1 min) with this model size, so removing the option introduced in cdb3cc.
* Re-ran `black`.
* Update examples/controlnet/README.md
  Co-authored-by: Sayak Paul <[email protected]>
* Converts spaces to tab.
* Removes repeated args.
* Skips first step (compilation) in profiling
* Updates README with profiling instructions.
* Unifies tabs/spaces in README.
* Re-ran style & quality.

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 0a73b4d commit d06e069
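The headline change (logging the mean of the collected per-step metrics instead of only `train_metrics[-1]`) and the new `l2_grads` value can be illustrated with a small standalone sketch. The helper names and dummy metric values below are illustrative, not code from the training script:

```python
# Minimal sketch of the metric-averaging and gradient-norm-logging ideas.
# The dicts below are dummy stand-ins for the per-step metrics the script collects.
import jax
import jax.numpy as jnp


def average_metrics(per_step_metrics):
    """Element-wise mean over a list of metric dicts (what gets logged now)."""
    return jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *per_step_metrics)


def l2_norm(tree):
    """Global L2 norm of a pytree, analogous to the logged `l2_grads` value."""
    return jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(tree)))


per_step_metrics = [{"loss": jnp.array(0.9)}, {"loss": jnp.array(0.7)}, {"loss": jnp.array(0.8)}]
print(average_metrics(per_step_metrics))  # mean loss 0.8 instead of the last step's 0.8-ish value
print(l2_norm({"w": jnp.ones((3,)), "b": jnp.zeros((2,))}))  # sqrt(3)
```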

File tree

2 files changed (+119 lines, -45 lines)


examples/controlnet/README.md

Lines changed: 49 additions & 25 deletions
@@ -284,9 +284,9 @@ TPU_TYPE=v4-8
 VM_NAME=hg_flax
 
 gcloud alpha compute tpus tpu-vm create $VM_NAME \
-  --zone $ZONE \
-  --accelerator-type $TPU_TYPE \
-  --version tpu-vm-v4-base
+  --zone $ZONE \
+  --accelerator-type $TPU_TYPE \
+  --version tpu-vm-v4-base
 
 gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE -- \
 ```
@@ -326,6 +326,7 @@ If you want to use Weights and Biases logging, you should also install `wandb` n
 pip install wandb
 ```
 
+
 Now let's downloading two conditioning images that we will use to run validation during the training in order to track our progress
 
 ```
@@ -343,8 +344,8 @@ Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment v
 
 ```bash
 export MODEL_DIR="runwayml/stable-diffusion-v1-5"
-export OUTPUT_DIR="control_out"
-export HUB_MODEL_ID="fill-circle-controlnet"
+export OUTPUT_DIR="runs/fill-circle-{timestamp}"
+export HUB_MODEL_ID="controlnet-fill-circle"
 ```
 
 And finally start the training
@@ -363,32 +364,36 @@ python3 train_controlnet_flax.py \
   --revision="non-ema" \
   --from_pt \
   --report_to="wandb" \
-  --max_train_steps=10000 \
+  --tracker_project_name=$HUB_MODEL_ID \
+  --num_train_epochs=11 \
   --push_to_hub \
   --hub_model_id=$HUB_MODEL_ID
 ```
 
 Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
 
-Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
+Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command (from [this blog article](https://huggingface.co/blog/train-your-controlnet)):
 
 ```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+export OUTPUT_DIR="runs/uncanny-faces-{timestamp}"
+export HUB_MODEL_ID="controlnet-uncanny-faces"
+
 python3 train_controlnet_flax.py \
-  --pretrained_model_name_or_path=$MODEL_DIR \
-  --output_dir=$OUTPUT_DIR \
-  --dataset_name=multimodalart/facesyntheticsspigacaptioned \
-  --streaming \
-  --conditioning_image_column=spiga_seg \
-  --image_column=image \
-  --caption_column=image_caption \
-  --resolution=512 \
-  --max_train_samples 50 \
-  --max_train_steps 5 \
-  --learning_rate=1e-5 \
-  --validation_steps=2 \
-  --train_batch_size=1 \
-  --revision="flax" \
-  --report_to="wandb"
+  --pretrained_model_name_or_path=$MODEL_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --dataset_name=multimodalart/facesyntheticsspigacaptioned \
+  --streaming \
+  --conditioning_image_column=spiga_seg \
+  --image_column=image \
+  --caption_column=image_caption \
+  --resolution=512 \
+  --max_train_samples 100000 \
+  --learning_rate=1e-5 \
+  --train_batch_size=1 \
+  --revision="flax" \
+  --report_to="wandb" \
+  --tracker_project_name=$HUB_MODEL_ID
 ```
 
 Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:
@@ -400,16 +405,35 @@ Note, however, that the performance of the TPUs might get bottlenecked as stream
 When work with a larger dataset, you may need to run training process for a long time and it’s useful to save regular checkpoints during the process. You can use the following argument to enable intermediate checkpointing:
 
 ```bash
---checkpointing_steps=500
+--checkpointing_steps=500
 ```
 This will save the trained model in subfolders of your output_dir. Subfolder names is the number of steps performed so far; for example: a checkpoint saved after 500 training steps would be saved in a subfolder named 500
 
 You can then start your training from this saved checkpoint with
 
 ```bash
---controlnet_model_name_or_path="./control_out/500"
+--controlnet_model_name_or_path="./control_out/500"
 ```
 
 We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556) which helps to achieve faster convergence by rebalancing the loss. To use it, one needs to set the `--snr_gamma` argument. The recommended value when using it is `5.0`.
 
-We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
+We also support gradient accumulation - it is a technique that lets you use a bigger batch size than your machine would normally be able to fit into memory. You can use `gradient_accumulation_steps` argument to set gradient accumulation steps. The ControlNet author recommends using gradient accumulation to achieve better convergence. Read more [here](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md#more-consideration-sudden-converge-phenomenon-and-gradient-accumulation).
+
+You can **profile your code** with:
+
+```bash
+--profile_steps==5
+```
+
+Refer to the [JAX documentation on profiling](https://jax.readthedocs.io/en/latest/profiling.html). To inspect the profile trace, you'll have to install and start Tensorboard with the profile plugin:
+
+```bash
+pip install tensorflow tensorboard-plugin-profile
+tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
+```
+
+The profile can then be inspected at http://localhost:6006/#profile
+
+Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
+
+Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
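For readers who want to see what `--profile_steps` does under the hood, here is a rough, self-contained JAX sketch; `train_step` and `logdir` below are illustrative placeholders, not the training script's API. The resulting trace directory is what the TensorBoard profile plugin reads:

```python
# Rough sketch of tracing a few steps with the JAX profiler, roughly mirroring
# what --profile_steps enables in the training script. `train_step` and `logdir`
# are illustrative placeholders.
import jax
import jax.numpy as jnp

logdir = "runs/fill-circle-example"


@jax.jit
def train_step(x):
    return (x @ x.T).mean()


x = jnp.ones((1024, 1024))
train_step(x).block_until_ready()  # first call compiles; skip it, as the script does

jax.profiler.start_trace(logdir)  # writes a trace under <logdir>/plugins/profile/...
for step in range(5):
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
        train_step(x).block_until_ready()
jax.profiler.stop_trace()
# Then: pip install tensorflow tensorboard-plugin-profile && tensorboard --logdir runs/fill-circle-example
```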

examples/controlnet/train_controlnet_flax.py

Lines changed: 70 additions & 20 deletions
@@ -18,6 +18,7 @@
 import math
 import os
 import random
+import time
 from pathlib import Path
 
 import jax
@@ -220,6 +221,28 @@ def parse_args():
         default=None,
         help="Revision of controlnet model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--profile_steps",
+        type=int,
+        default=0,
+        help="How many training steps to profile in the beginning.",
+    )
+    parser.add_argument(
+        "--profile_validation",
+        action="store_true",
+        help="Whether to profile the (last) validation.",
+    )
+    parser.add_argument(
+        "--profile_memory",
+        action="store_true",
+        help="Whether to dump an initial (before training loop) and a final (at program end) memory profile.",
+    )
+    parser.add_argument(
+        "--ccache",
+        type=str,
+        default=None,
+        help="Enables compilation cache.",
+    )
     parser.add_argument(
         "--controlnet_from_pt",
         action="store_true",
@@ -234,8 +257,9 @@ def parse_args():
     parser.add_argument(
         "--output_dir",
         type=str,
-        default="controlnet-model",
-        help="The output directory where the model predictions and checkpoints will be written.",
+        default="runs/{timestamp}",
+        help="The output directory where the model predictions and checkpoints will be written. "
+        "Can contain placeholders: {timestamp}.",
     )
     parser.add_argument(
         "--cache_dir",
@@ -317,15 +341,6 @@ def parse_args():
         default=None,
         help="The name of the repository to keep in sync with the local `output_dir`.",
     )
-    parser.add_argument(
-        "--logging_dir",
-        type=str,
-        default="logs",
-        help=(
-            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
-        ),
-    )
     parser.add_argument(
         "--logging_steps",
         type=int,
@@ -459,6 +474,8 @@ def parse_args():
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 
     args = parser.parse_args()
+    args.output_dir = args.output_dir.replace("{timestamp}", time.strftime("%Y%m%d_%H%M%S"))
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -952,6 +969,11 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         metrics = {"loss": loss}
         metrics = jax.lax.pmean(metrics, axis_name="batch")
 
+        def l2(xs):
+            return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
+
+        metrics["l2_grads"] = l2(jax.tree_util.tree_leaves(grad))
+
         return new_state, metrics, new_train_rng
 
     # Create parallel version of the train step
@@ -983,32 +1005,38 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
     logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
     logger.info(f" Total optimization steps = {args.num_train_epochs * num_update_steps_per_epoch}")
 
-    if jax.process_index() == 0:
+    if jax.process_index() == 0 and args.report_to == "wandb":
         wandb.define_metric("*", step_metric="train/step")
+        wandb.define_metric("train/step", step_metric="walltime")
         wandb.config.update(
             {
                 "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
                 "total_train_batch_size": total_train_batch_size,
                 "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
                 "num_devices": jax.device_count(),
+                "controlnet_params": sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(state.params)),
             }
         )
 
-    global_step = 0
+    global_step = step0 = 0
     epochs = tqdm(
         range(args.num_train_epochs),
         desc="Epoch ... ",
         position=0,
        disable=jax.process_index() > 0,
     )
+    if args.profile_memory:
+        jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_initial.prof"))
+    t00 = t0 = time.monotonic()
     for epoch in epochs:
        # ======================== Training ================================
 
         train_metrics = []
+        train_metric = None
 
         steps_per_epoch = (
             args.max_train_samples // total_train_batch_size
-            if args.streaming
+            if args.streaming or args.max_train_samples
             else len(train_dataset) // total_train_batch_size
         )
         train_step_progress_bar = tqdm(
@@ -1020,10 +1048,18 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         )
         # train
         for batch in train_dataloader:
+            if args.profile_steps and global_step == 1:
+                train_metric["loss"].block_until_ready()
+                jax.profiler.start_trace(args.output_dir)
+            if args.profile_steps and global_step == 1 + args.profile_steps:
+                train_metric["loss"].block_until_ready()
+                jax.profiler.stop_trace()
+
             batch = shard(batch)
-            state, train_metric, train_rngs = p_train_step(
-                state, unet_params, text_encoder_params, vae_params, batch, train_rngs
-            )
+            with jax.profiler.StepTraceAnnotation("train", step_num=global_step):
+                state, train_metric, train_rngs = p_train_step(
+                    state, unet_params, text_encoder_params, vae_params, batch, train_rngs
+                )
             train_metrics.append(train_metric)
 
             train_step_progress_bar.update(1)
@@ -1041,13 +1077,19 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
 
             if global_step % args.logging_steps == 0 and jax.process_index() == 0:
                 if args.report_to == "wandb":
+                    train_metrics = jax_utils.unreplicate(train_metrics)
+                    train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
                     wandb.log(
                         {
+                            "walltime": time.monotonic() - t00,
                            "train/step": global_step,
-                            "train/epoch": epoch,
-                            "train/loss": jax_utils.unreplicate(train_metric)["loss"],
+                            "train/epoch": global_step / dataset_length,
+                            "train/steps_per_sec": (global_step - step0) / (time.monotonic() - t0),
+                            **{f"train/{k}": v for k, v in train_metrics.items()},
                         }
                     )
+                    t0, step0 = time.monotonic(), global_step
+                    train_metrics = []
             if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
                controlnet.save_pretrained(
                     f"{args.output_dir}/{global_step}",
@@ -1058,10 +1100,14 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
         train_step_progress_bar.close()
         epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
 
-    # Create the pipeline using using the trained modules and save it.
+    # Final validation & store model.
     if jax.process_index() == 0:
         if args.validation_prompt is not None:
+            if args.profile_validation:
+                jax.profiler.start_trace(args.output_dir)
             image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+            if args.profile_validation:
+                jax.profiler.stop_trace()
         else:
             image_logs = None
 
@@ -1084,6 +1130,10 @@ def cumul_grad_step(grad_idx, loss_grad_rng):
                 ignore_patterns=["step_*", "epoch_*"],
             )
 
+    if args.profile_memory:
+        jax.profiler.save_device_memory_profile(os.path.join(args.output_dir, "memory_final.prof"))
+    logger.info("Finished training.")
+
 
 if __name__ == "__main__":
     main()
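As a companion to the `--profile_memory` flag added above, the device-memory dumps it produces can be reproduced with a few lines of standalone JAX. The output directory and the dummy workload below are illustrative; the resulting `.prof` files are in pprof format:

```python
# Small sketch of the --profile_memory behaviour: dump a device-memory profile
# before and after some work, then inspect the files with pprof, e.g.
#   go tool pprof -http=:8080 runs/example/memory_final.prof
# The output directory and the dummy "model state" below are illustrative.
import os

import jax
import jax.numpy as jnp

output_dir = "runs/example"
os.makedirs(output_dir, exist_ok=True)

jax.profiler.save_device_memory_profile(os.path.join(output_dir, "memory_initial.prof"))

params = [jnp.ones((2048, 2048)) for _ in range(4)]  # stand-in for replicated params
loss = sum((p * p).sum() for p in params)
loss.block_until_ready()

jax.profiler.save_device_memory_profile(os.path.join(output_dir, "memory_final.prof"))
```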

0 commit comments
