@@ -275,3 +275,185 @@ def pag_attn_processors(self):
             if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
                 processors[name] = proc
         return processors
+
+
+class PixArtPAGMixin:
+    @staticmethod
+    def _check_input_pag_applied_layer(layer):
+        r"""
+        Check that each entry of `pag_applied_layers` is a valid block index (an int or a digit string).
+        """
+
+        # Check if the layer index is valid (should be an int or a string of digits)
+        if isinstance(layer, int):
+            return  # Valid layer index
+
+        if isinstance(layer, str):
+            if layer.isdigit():
+                return  # Valid layer index
+
+        # If it is not a valid layer index, raise a ValueError
+        raise ValueError(f"A PAG layer should be a block index: an int or a digit string like '3', got {layer}.")
+
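+    # Illustrative behaviour of the check above (example values only):
+    #   PixArtPAGMixin._check_input_pag_applied_layer(14)        -> returns (valid int index)
+    #   PixArtPAGMixin._check_input_pag_applied_layer("14")      -> returns (valid digit string)
+    #   PixArtPAGMixin._check_input_pag_applied_layer("attn_14") -> raises ValueError
+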
+    def _set_pag_attn_processor(self, pag_applied_layers, do_classifier_free_guidance):
+        r"""
+        Set the attention processor for the PAG layers.
+        """
+        if do_classifier_free_guidance:
+            pag_attn_proc = PAGCFGIdentitySelfAttnProcessor2_0()
+        else:
+            pag_attn_proc = PAGIdentitySelfAttnProcessor2_0()
+
+        def is_self_attn(module_name):
+            r"""
+            Check if the module is a self-attention module based on its name.
+            """
+            return (
+                "attn1" in module_name and len(module_name.split(".")) == 3
+            )  # include transformer_blocks.1.attn1, exclude transformer_blocks.18.attn1.to_q, transformer_blocks.1.attn1.add_q_proj, ...
+
+        def get_block_index(module_name):
+            r"""
+            Get the transformer block index from the module name, returned as a digit string.
+            """
+            # e.g. transformer_blocks.23.attn1 -> "23"
+            return module_name.split(".")[1]
+
+        for pag_layer_input in pag_applied_layers:
+            # For each PAG layer input, find the corresponding self-attention layers in the transformer model
+            target_modules = []
+
+            block_index = str(pag_layer_input)
+
+            for name, module in self.transformer.named_modules():
+                if is_self_attn(name) and get_block_index(name) == block_index:
+                    target_modules.append(module)
+
+            if len(target_modules) == 0:
+                raise ValueError(f"Cannot find PAG layer to set attention processor for: {pag_layer_input}")
+
+            for module in target_modules:
+                module.processor = pag_attn_proc
+
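+    # Sketch of how the matching above resolves a layer input (module names are
+    # illustrative): with pag_layer_input=14, "transformer_blocks.14.attn1" is
+    # selected, while "transformer_blocks.14.attn2" and
+    # "transformer_blocks.14.attn1.to_q" are skipped by is_self_attn.
+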
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.set_pag_applied_layers
+    def set_pag_applied_layers(self, pag_applied_layers):
+        r"""
+        Set the self-attention layers to apply PAG. Raises a ValueError if the input is invalid.
+        """
+
+        if not isinstance(pag_applied_layers, list):
+            pag_applied_layers = [pag_applied_layers]
+
+        for pag_layer in pag_applied_layers:
+            self._check_input_pag_applied_layer(pag_layer)
+
+        self.pag_applied_layers = pag_applied_layers
+
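+    # Rough usage sketch, assuming a pipeline instance `pipe` that includes this
+    # mixin (the variable name and block indices are hypothetical):
+    #   pipe.set_pag_applied_layers([14, "20"])  # apply PAG on transformer blocks 14 and 20
+    #   pipe.set_pag_applied_layers(8)           # a single index is wrapped into a list
+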
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._get_pag_scale
+    def _get_pag_scale(self, t):
+        r"""
+        Get the scale factor for the perturbed attention guidance at timestep `t`.
+        """
+
+        if self.do_pag_adaptive_scaling:
+            signal_scale = self.pag_scale - self.pag_adaptive_scale * (1000 - t)
+            if signal_scale < 0:
+                signal_scale = 0
+            return signal_scale
+        else:
+            return self.pag_scale
+
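+    # Worked example of the adaptive scaling above (hypothetical settings): with
+    # pag_scale=3.0 and pag_adaptive_scale=0.005, the scale at t=800 is
+    # 3.0 - 0.005 * (1000 - 800) = 2.0; at t=200 it would be -1.0 and is clamped to 0.
+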
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._apply_perturbed_attention_guidance
+    def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
+        r"""
+        Apply perturbed attention guidance to the noise prediction.
+
+        Args:
+            noise_pred (torch.Tensor): The noise prediction tensor.
+            do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
+            guidance_scale (float): The scale factor for the guidance term.
+            t (int): The current time step.
+
+        Returns:
+            torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
+        """
+        pag_scale = self._get_pag_scale(t)
+        if do_classifier_free_guidance:
+            noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
+            noise_pred = (
+                noise_pred_uncond
+                + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                + pag_scale * (noise_pred_text - noise_pred_perturb)
+            )
+        else:
+            noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
+            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
+        return noise_pred
+
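+    # Shape walkthrough for the combination above (batch size B is hypothetical):
+    # with classifier-free guidance the incoming noise_pred has batch dimension
+    # 3 * B ordered [uncond, text, perturbed], so chunk(3) yields the three terms
+    # combined above; without it the batch is 2 * B ordered [text, perturbed].
+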
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin._prepare_perturbed_attention_guidance
+    def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
+        """
+        Prepares the perturbed attention guidance for the PAG model.
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor.
+            uncond (torch.Tensor): The unconditional input tensor.
+            do_classifier_free_guidance (bool): Flag indicating whether to perform classifier-free guidance.
+
+        Returns:
+            torch.Tensor: The prepared perturbed attention guidance tensor.
+        """
+
+        cond = torch.cat([cond] * 2, dim=0)
+
+        if do_classifier_free_guidance:
+            cond = torch.cat([uncond, cond], dim=0)
+        return cond
+
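+    # Illustrative shapes for the preparation above (sizes are examples only): a
+    # cond tensor of shape (B, ...) is duplicated to (2 * B, ...); with
+    # classifier-free guidance the uncond tensor is prepended, giving (3 * B, ...)
+    # in the order [uncond, cond, cond] that the chunk(3) call above expects.
+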
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_scale
+    def pag_scale(self):
+        """
+        Get the scale factor for the perturbed attention guidance.
+        """
+        return self._pag_scale
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_adaptive_scale
+    def pag_adaptive_scale(self):
+        """
+        Get the adaptive scale factor for the perturbed attention guidance.
+        """
+        return self._pag_adaptive_scale
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_pag_adaptive_scaling
+    def do_pag_adaptive_scaling(self):
+        """
+        Check if the adaptive scaling is enabled for the perturbed attention guidance.
+        """
+        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and len(self.pag_applied_layers) > 0
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.do_perturbed_attention_guidance
+    def do_perturbed_attention_guidance(self):
+        """
+        Check if the perturbed attention guidance is enabled.
+        """
+        return self._pag_scale > 0 and len(self.pag_applied_layers) > 0
+
+    @property
+    # Copied from diffusers.pipelines.pag.pag_utils.PAGMixin.pag_attn_processors with unet->transformer
+    def pag_attn_processors(self):
+        r"""
+        Returns:
+            `dict` of PAG attention processors: A dictionary containing all PAG attention processors used in the
+            model, keyed by the name of the layer.
+        """
+
+        processors = {}
+        for name, proc in self.transformer.attn_processors.items():
+            if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
+                processors[name] = proc
+        return processors