Commit 6a7f1f0

Flax: avoid recompilation when params change (huggingface#1096)
* Do not recompile when guidance_scale changes.
* Remove debug for simplicity.
* make style
* Make guidance_scale an array.
* Make DEBUG a constant to avoid passing it down.
* Add comments for clarification.
1 parent 170ebd2 commit 6a7f1f0
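
For context, a minimal sketch (not part of this commit; names and shapes are illustrative) of the recompilation behaviour the change avoids: a value passed through `static_broadcasted_argnums` is baked into the compiled program, so every new Python float triggers a fresh compile, while a value passed as a device-mapped array is traced and the compiled function is reused.

from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical toy functions, only to illustrate the compile behaviour.
@partial(jax.pmap, static_broadcasted_argnums=(1,))
def apply_scale_static(x, scale):
    # `scale` is a static Python float: every new value means a new compilation.
    return x * scale

@jax.pmap
def apply_scale_traced(x, scale):
    # `scale` is a per-device array entry: new values reuse the same compiled code.
    return x * scale

n = jax.device_count()
x = jnp.ones((n, 4))
apply_scale_static(x, 7.5)                    # compiles
apply_scale_static(x, 8.0)                    # compiles again (new static value)
apply_scale_traced(x, jnp.full((n,), 7.5))    # compiles once
apply_scale_traced(x, jnp.full((n,), 8.0))    # reuses the compiled function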

1 file changed: +20 −11 lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

@@ -42,6 +42,9 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
 
 class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
     r"""
@@ -187,7 +190,6 @@ def _generate(
         width: Optional[int] = None,
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
-        debug: bool = False,
         neg_prompt_ids: jnp.array = None,
     ):
         # 0. Default height and width to unet
@@ -260,8 +262,7 @@ def loop_body(step, args):
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
-
-        if debug:
+        if DEBUG:
             # run with python for loop
             for i in range(num_inference_steps):
                 latents, scheduler_state = loop_body(i, (latents, scheduler_state))
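
The module-level DEBUG constant replaces the per-call `debug` argument. A self-contained toy sketch of the two loop forms it switches between (names, shapes, and the `fori_loop` branch here are illustrative, not taken from the pipeline): a plain Python loop that is easy to step through versus `jax.lax.fori_loop`, which compiles the whole loop.

import jax
import jax.numpy as jnp

DEBUG = False  # module-level constant, mirroring the flag added above
num_inference_steps = 5

def loop_body(step, args):
    # Stand-in for one denoising step: takes (step, carry) and returns the carry.
    latents, scheduler_state = args
    latents = latents * 0.9
    return latents, scheduler_state

latents = jnp.ones((1, 4, 8, 8))
scheduler_state = jnp.zeros(())  # placeholder carry element

if DEBUG:
    # Eager Python loop: every iteration runs separately and can be inspected.
    for i in range(num_inference_steps):
        latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
    # Compiled loop: loop_body's (step, carry) -> carry signature is what
    # jax.lax.fori_loop expects.
    latents, scheduler_state = jax.lax.fori_loop(
        0, num_inference_steps, loop_body, (latents, scheduler_state)
    )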
@@ -283,11 +284,10 @@ def __call__(
         num_inference_steps: int = 50,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        guidance_scale: float = 7.5,
+        guidance_scale: Union[float, jnp.array] = 7.5,
         latents: jnp.array = None,
         return_dict: bool = True,
         jit: bool = False,
-        debug: bool = False,
         neg_prompt_ids: jnp.array = None,
     ):
         r"""
@@ -334,6 +334,14 @@ def __call__(
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
+        if isinstance(guidance_scale, float):
+            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+            # shape information, as they may be sharded (when `jit` is `True`), or not.
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            if len(prompt_ids.shape) > 2:
+                # Assume sharded
+                guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
+
         if jit:
             images = _p_generate(
                 self,
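
To illustrate the conversion above with assumed example shapes (not values from the pipeline): under `jit=True` the prompt ids are typically sharded as (num_devices, batch_per_device, sequence_length), so the scalar guidance scale is replicated along the leading axis and then reshaped to match the first two dimensions.

import jax.numpy as jnp

# Hypothetical sharded prompt ids: 2 devices, 1 prompt per device, 77 tokens each.
prompt_ids = jnp.zeros((2, 1, 77), dtype=jnp.int32)

guidance_scale = 7.5
if isinstance(guidance_scale, float):
    # One copy of the scale per entry along the leading (device) axis.
    guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
    if len(prompt_ids.shape) > 2:
        # Assume sharded: match the (num_devices, batch_per_device) layout.
        guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])

print(guidance_scale.shape)  # (2, 1)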
@@ -345,7 +353,6 @@ def __call__(
                 width,
                 guidance_scale,
                 latents,
-                debug,
                 neg_prompt_ids,
             )
         else:
@@ -358,7 +365,6 @@ def __call__(
                 width,
                 guidance_scale,
                 latents,
-                debug,
                 neg_prompt_ids,
             )
 
@@ -388,8 +394,13 @@ def __call__(
         return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
 
 
-# TODO: maybe use a config dict instead of so many static argnums
-@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
+# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+    jax.pmap,
+    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
+    static_broadcasted_argnums=(0, 4, 5, 6),
+)
 def _p_generate(
     pipe,
     prompt_ids,
@@ -400,7 +411,6 @@ def _p_generate(
     width,
     guidance_scale,
     latents,
-    debug,
     neg_prompt_ids,
 ):
     return pipe._generate(
@@ -412,7 +422,6 @@ def _p_generate(
         width,
         guidance_scale,
         latents,
-        debug,
         neg_prompt_ids,
     )
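
A small sketch (assumed names, not from the file) of how the `in_axes`/`static_broadcasted_argnums` combination above behaves: entries with axis `0` are sharded across devices along their leading dimension, `None` entries are broadcast unchanged, and static argnums are hashable Python values compiled into the program, so only a change to those triggers recompilation.

from functools import partial

import jax
import jax.numpy as jnp

# Arguments here are illustrative stand-ins for the pipeline's inputs.
@partial(jax.pmap, in_axes=(0, 0, None), static_broadcasted_argnums=(2,))
def run_step(latents, guidance_scale, num_inference_steps):
    # latents: one shard per device; guidance_scale: one scalar per device;
    # num_inference_steps: static Python int baked into the compiled program.
    return latents * guidance_scale + num_inference_steps

n = jax.device_count()
latents = jnp.ones((n, 4, 4))               # sharded along the device axis
guidance = jnp.full((n,), 7.5)              # one copy of the scale per device
run_step(latents, guidance, 50)             # compiles for num_inference_steps=50
run_step(latents, jnp.full((n,), 8.0), 50)  # reuses the compiled function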
418427
