@@ -165,6 +165,7 @@ def _generate(
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -177,10 +178,14 @@ def _generate(
         batch_size = prompt_ids.shape[0]

         max_length = prompt_ids.shape[-1]
-        uncond_input = self.tokenizer(
-            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
-        )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
+
+        if neg_prompt_ids is None:
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+            ).input_ids
+        else:
+            uncond_input = neg_prompt_ids
+        uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
         context = jnp.concatenate([uncond_embeddings, text_embeddings])

         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -251,6 +256,7 @@ def __call__(
         return_dict: bool = True,
         jit: bool = False,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = None,
         **kwargs,
     ):
         r"""
@@ -298,11 +304,30 @@ def __call__(
         """
         if jit:
             images = _p_generate(
-                self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                self,
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )
         else:
             images = self._generate(
-                prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )

         if self.safety_checker is not None:
@@ -333,10 +358,29 @@ def __call__(
 # TODO: maybe use a config dict instead of so many static argnums
 @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
 def _p_generate(
-    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+    pipe,
+    prompt_ids,
+    params,
+    prng_seed,
+    num_inference_steps,
+    height,
+    width,
+    guidance_scale,
+    latents,
+    debug,
+    neg_prompt_ids,
 ):
     return pipe._generate(
-        prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+        prompt_ids,
+        params,
+        prng_seed,
+        num_inference_steps,
+        height,
+        width,
+        guidance_scale,
+        latents,
+        debug,
+        neg_prompt_ids,
     )
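For reference, here is a minimal usage sketch of the new `neg_prompt_ids` argument. It assumes a loaded FlaxStableDiffusionPipeline bound to `pipeline` with per-device replicated weights in `params`; the prompt strings, the `shard` helper, and the use of `model_max_length` are illustrative assumptions, while the `neg_prompt_ids` keyword and the tokenizer call pattern come from the diff above.

```python
# Minimal sketch (not part of this commit): calling the pipeline with
# pre-tokenized negative prompts via the new `neg_prompt_ids` argument.
# `pipeline` and `params` are assumed to exist; `params` replicated per device.
import jax

num_devices = jax.device_count()
prompts = ["a photograph of an astronaut riding a horse"] * num_devices
neg_prompts = ["blurry, low quality, deformed"] * num_devices

# Tokenize both lists to the same fixed length, mirroring what `_generate`
# does internally for the empty-string unconditional input.
max_length = pipeline.tokenizer.model_max_length
prompt_ids = pipeline.tokenizer(
    prompts, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids
neg_prompt_ids = pipeline.tokenizer(
    neg_prompts, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids


def shard(x):
    # Add a leading device axis so jax.pmap can split the batch across devices.
    return x.reshape((num_devices, -1) + x.shape[1:])


prng_seed = jax.random.split(jax.random.PRNGKey(0), num_devices)
output = pipeline(
    prompt_ids=shard(prompt_ids),
    params=params,
    prng_seed=prng_seed,
    jit=True,
    neg_prompt_ids=shard(neg_prompt_ids),
)
```

Because `neg_prompt_ids` is not listed in `static_broadcasted_argnums`, `jax.pmap` maps it across devices just like `prompt_ids`, which is why it is given the same leading device axis here when `jit=True`.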