Skip to content

Commit 8ecdd3e

Browse files
authored
Optimize log_validation in train_controlnet_flax (huggingface#3110)
extract pipeline from log_validation
1 parent cd8b750 commit 8ecdd3e

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

examples/controlnet/train_controlnet_flax.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,11 @@ def image_grid(imgs, rows, cols):
7676
return grid
7777

7878

79-
def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_dtype):
80-
logger.info("Running validation... ")
79+
def log_validation(pipeline, pipeline_params, controlnet_params, tokenizer, args, rng, weight_dtype):
80+
logger.info("Running validation...")
8181

82-
pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
83-
args.pretrained_model_name_or_path,
84-
tokenizer=tokenizer,
85-
controlnet=controlnet,
86-
safety_checker=None,
87-
dtype=weight_dtype,
88-
revision=args.revision,
89-
from_pt=args.from_pt,
90-
)
91-
params = jax_utils.replicate(params)
92-
params["controlnet"] = controlnet_params
82+
pipeline_params = pipeline_params.copy()
83+
pipeline_params["controlnet"] = controlnet_params
9384

9485
num_samples = jax.device_count()
9586
prng_seed = jax.random.split(rng, jax.device_count())
@@ -121,7 +112,7 @@ def log_validation(controlnet, controlnet_params, tokenizer, args, rng, weight_d
121112
images = pipeline(
122113
prompt_ids=prompt_ids,
123114
image=processed_image,
124-
params=params,
115+
params=pipeline_params,
125116
prng_seed=prng_seed,
126117
num_inference_steps=50,
127118
jit=True,
@@ -176,6 +167,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
176167
- text-to-image
177168
- diffusers
178169
- controlnet
170+
- jax-diffusers-event
179171
inference: true
180172
---
181173
"""
@@ -800,6 +792,17 @@ def main():
800792
]:
801793
controlnet_params[key] = unet_params[key]
802794

795+
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
796+
args.pretrained_model_name_or_path,
797+
tokenizer=tokenizer,
798+
controlnet=controlnet,
799+
safety_checker=None,
800+
dtype=weight_dtype,
801+
revision=args.revision,
802+
from_pt=args.from_pt,
803+
)
804+
pipeline_params = jax_utils.replicate(pipeline_params)
805+
803806
# Optimization
804807
if args.scale_lr:
805808
args.learning_rate = args.learning_rate * total_train_batch_size
@@ -1073,7 +1076,7 @@ def l2(xs):
10731076
and global_step % args.validation_steps == 0
10741077
and jax.process_index() == 0
10751078
):
1076-
_ = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
1079+
_ = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)
10771080

10781081
if global_step % args.logging_steps == 0 and jax.process_index() == 0:
10791082
if args.report_to == "wandb":
@@ -1105,7 +1108,7 @@ def l2(xs):
11051108
if args.validation_prompt is not None:
11061109
if args.profile_validation:
11071110
jax.profiler.start_trace(args.output_dir)
1108-
image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
1111+
image_logs = log_validation(pipeline, pipeline_params, state.params, tokenizer, args, validation_rng, weight_dtype)
11091112
if args.profile_validation:
11101113
jax.profiler.stop_trace()
11111114
else:

0 commit comments

Comments
 (0)