4242
4343logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
4444
45+ # Set to True to use python for loop instead of jax.fori_loop for easier debugging
46+ DEBUG = False
47+
4548
4649class FlaxStableDiffusionPipeline (FlaxDiffusionPipeline ):
4750 r"""
@@ -187,7 +190,6 @@ def _generate(
187190 width : Optional [int ] = None ,
188191 guidance_scale : float = 7.5 ,
189192 latents : Optional [jnp .array ] = None ,
190- debug : bool = False ,
191193 neg_prompt_ids : jnp .array = None ,
192194 ):
193195 # 0. Default height and width to unet
@@ -260,8 +262,7 @@ def loop_body(step, args):
260262
261263 # scale the initial noise by the standard deviation required by the scheduler
262264 latents = latents * self .scheduler .init_noise_sigma
263-
264- if debug :
265+ if DEBUG :
265266 # run with python for loop
266267 for i in range (num_inference_steps ):
267268 latents , scheduler_state = loop_body (i , (latents , scheduler_state ))
@@ -283,11 +284,10 @@ def __call__(
283284 num_inference_steps : int = 50 ,
284285 height : Optional [int ] = None ,
285286 width : Optional [int ] = None ,
286- guidance_scale : float = 7.5 ,
287+ guidance_scale : Union [ float , jnp . array ] = 7.5 ,
287288 latents : jnp .array = None ,
288289 return_dict : bool = True ,
289290 jit : bool = False ,
290- debug : bool = False ,
291291 neg_prompt_ids : jnp .array = None ,
292292 ):
293293 r"""
@@ -334,6 +334,14 @@ def __call__(
334334 height = height or self .unet .config .sample_size * self .vae_scale_factor
335335 width = width or self .unet .config .sample_size * self .vae_scale_factor
336336
337+ if isinstance (guidance_scale , float ):
338+ # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
339+ # shape information, as they may be sharded (when `jit` is `True`), or not.
340+ guidance_scale = jnp .array ([guidance_scale ] * prompt_ids .shape [0 ])
341+ if len (prompt_ids .shape ) > 2 :
342+ # Assume sharded
343+ guidance_scale = guidance_scale .reshape (prompt_ids .shape [:2 ])
344+
337345 if jit :
338346 images = _p_generate (
339347 self ,
@@ -345,7 +353,6 @@ def __call__(
345353 width ,
346354 guidance_scale ,
347355 latents ,
348- debug ,
349356 neg_prompt_ids ,
350357 )
351358 else :
@@ -358,7 +365,6 @@ def __call__(
358365 width ,
359366 guidance_scale ,
360367 latents ,
361- debug ,
362368 neg_prompt_ids ,
363369 )
364370
@@ -388,8 +394,13 @@ def __call__(
388394 return FlaxStableDiffusionPipelineOutput (images = images , nsfw_content_detected = has_nsfw_concept )
389395
390396
391- # TODO: maybe use a config dict instead of so many static argnums
392- @partial (jax .pmap , static_broadcasted_argnums = (0 , 4 , 5 , 6 , 7 , 9 ))
397+ # Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
398+ # Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
399+ @partial (
400+ jax .pmap ,
401+ in_axes = (None , 0 , 0 , 0 , None , None , None , 0 , 0 , 0 ),
402+ static_broadcasted_argnums = (0 , 4 , 5 , 6 ),
403+ )
393404def _p_generate (
394405 pipe ,
395406 prompt_ids ,
@@ -400,7 +411,6 @@ def _p_generate(
400411 width ,
401412 guidance_scale ,
402413 latents ,
403- debug ,
404414 neg_prompt_ids ,
405415):
406416 return pipe ._generate (
@@ -412,7 +422,6 @@ def _p_generate(
412422 width ,
413423 guidance_scale ,
414424 latents ,
415- debug ,
416425 neg_prompt_ids ,
417426 )
418427
0 commit comments