42
42
43
43
logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
44
44
45
+ # Set to True to use python for loop instead of jax.fori_loop for easier debugging
46
+ DEBUG = False
47
+
45
48
46
49
class FlaxStableDiffusionPipeline (FlaxDiffusionPipeline ):
47
50
r"""
@@ -187,7 +190,6 @@ def _generate(
187
190
width : Optional [int ] = None ,
188
191
guidance_scale : float = 7.5 ,
189
192
latents : Optional [jnp .array ] = None ,
190
- debug : bool = False ,
191
193
neg_prompt_ids : jnp .array = None ,
192
194
):
193
195
# 0. Default height and width to unet
@@ -260,8 +262,7 @@ def loop_body(step, args):
260
262
261
263
# scale the initial noise by the standard deviation required by the scheduler
262
264
latents = latents * self .scheduler .init_noise_sigma
263
-
264
- if debug :
265
+ if DEBUG :
265
266
# run with python for loop
266
267
for i in range (num_inference_steps ):
267
268
latents , scheduler_state = loop_body (i , (latents , scheduler_state ))
@@ -283,11 +284,10 @@ def __call__(
283
284
num_inference_steps : int = 50 ,
284
285
height : Optional [int ] = None ,
285
286
width : Optional [int ] = None ,
286
- guidance_scale : float = 7.5 ,
287
+ guidance_scale : Union [ float , jnp . array ] = 7.5 ,
287
288
latents : jnp .array = None ,
288
289
return_dict : bool = True ,
289
290
jit : bool = False ,
290
- debug : bool = False ,
291
291
neg_prompt_ids : jnp .array = None ,
292
292
):
293
293
r"""
@@ -334,6 +334,14 @@ def __call__(
334
334
height = height or self .unet .config .sample_size * self .vae_scale_factor
335
335
width = width or self .unet .config .sample_size * self .vae_scale_factor
336
336
337
+ if isinstance (guidance_scale , float ):
338
+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
339
+ # shape information, as they may be sharded (when `jit` is `True`), or not.
340
+ guidance_scale = jnp .array ([guidance_scale ] * prompt_ids .shape [0 ])
341
+ if len (prompt_ids .shape ) > 2 :
342
+ # Assume sharded
343
+ guidance_scale = guidance_scale .reshape (prompt_ids .shape [:2 ])
344
+
337
345
if jit :
338
346
images = _p_generate (
339
347
self ,
@@ -345,7 +353,6 @@ def __call__(
345
353
width ,
346
354
guidance_scale ,
347
355
latents ,
348
- debug ,
349
356
neg_prompt_ids ,
350
357
)
351
358
else :
@@ -358,7 +365,6 @@ def __call__(
358
365
width ,
359
366
guidance_scale ,
360
367
latents ,
361
- debug ,
362
368
neg_prompt_ids ,
363
369
)
364
370
@@ -388,8 +394,13 @@ def __call__(
388
394
return FlaxStableDiffusionPipelineOutput (images = images , nsfw_content_detected = has_nsfw_concept )
389
395
390
396
391
- # TODO: maybe use a config dict instead of so many static argnums
392
- @partial (jax .pmap , static_broadcasted_argnums = (0 , 4 , 5 , 6 , 7 , 9 ))
397
+ # Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
398
+ # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
399
+ @partial (
400
+ jax .pmap ,
401
+ in_axes = (None , 0 , 0 , 0 , None , None , None , 0 , 0 , 0 ),
402
+ static_broadcasted_argnums = (0 , 4 , 5 , 6 ),
403
+ )
393
404
def _p_generate (
394
405
pipe ,
395
406
prompt_ids ,
@@ -400,7 +411,6 @@ def _p_generate(
400
411
width ,
401
412
guidance_scale ,
402
413
latents ,
403
- debug ,
404
414
neg_prompt_ids ,
405
415
):
406
416
return pipe ._generate (
@@ -412,7 +422,6 @@ def _p_generate(
412
422
width ,
413
423
guidance_scale ,
414
424
latents ,
415
- debug ,
416
425
neg_prompt_ids ,
417
426
)
418
427
0 commit comments