
 import inspect
 import math
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -76,7 +76,7 @@ def get_empty_store():

     def __call__(self, attn, is_cross: bool, place_in_unet: str):
         if self.cur_att_layer >= 0 and is_cross:
-            if attn.shape[1] == self.attn_res**2:
+            if attn.shape[1] == np.prod(self.attn_res):
                 self.step_store[place_in_unet].append(attn)

         self.cur_att_layer += 1
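
The new check compares the number of flattened spatial tokens against np.prod(self.attn_res), so rectangular attention maps are stored as well as square ones. A minimal standalone sketch of why the two checks agree only in the square case (the attn_res values here are illustrative):

import numpy as np

# Square case: the old (attn_res ** 2) and new (np.prod) checks match.
attn_res = (16, 16)
assert np.prod(attn_res) == 16**2  # 256 spatial tokens

# Rectangular case: only the product matches the real token count.
attn_res = (24, 16)
assert np.prod(attn_res) == 24 * 16  # 384 spatial tokens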
@@ -98,7 +98,7 @@ def aggregate_attention(self, from_where: List[str]) -> torch.Tensor:
         attention_maps = self.get_average_attention()
         for location in from_where:
             for item in attention_maps[location]:
-                cross_maps = item.reshape(-1, self.attn_res, self.attn_res, item.shape[-1])
+                cross_maps = item.reshape(-1, self.attn_res[0], self.attn_res[1], item.shape[-1])
                 out.append(cross_maps)
         out = torch.cat(out, dim=0)
         out = out.sum(0) / out.shape[0]
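
With attn_res stored as a pair, the flattened maps are reshaped back to their 2D layout before aggregation. A toy sketch of that reshape with dummy tensors; the head count and text-token count below are illustrative, not values taken from the pipeline:

import torch

attn_res = (16, 24)          # the two spatial dimensions of the attention map
heads, text_tokens = 8, 77   # illustrative sizes

# A flattened cross-attention map: (heads, attn_res[0] * attn_res[1], text_tokens)
item = torch.rand(heads, attn_res[0] * attn_res[1], text_tokens)

# Same reshape as in the hunk above: recover the spatial grid per head.
cross_maps = item.reshape(-1, attn_res[0], attn_res[1], item.shape[-1])
print(cross_maps.shape)  # torch.Size([8, 16, 24, 77])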
@@ -109,7 +109,7 @@ def reset(self):
         self.step_store = self.get_empty_store()
         self.attention_store = {}

-    def __init__(self, attn_res=16):
+    def __init__(self, attn_res):
         """
         Initialize an empty AttentionStore :param step_index: used to visualize only a specific step in the diffusion
         process
@@ -724,7 +724,7 @@ def __call__(
         max_iter_to_alter: int = 25,
         thresholds: dict = {0: 0.05, 10: 0.5, 20: 0.8},
         scale_factor: int = 20,
-        attn_res: int = 16,
+        attn_res: Optional[Tuple[int]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -796,8 +796,8 @@ def __call__(
                 Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in.
             scale_factor (`int`, *optional*, default to 20):
                 Scale factor that controls the step size of each Attend and Excite update.
-            attn_res (`int`, *optional*, default to 16):
-                The resolution of most semantic attention map.
+            attn_res (`tuple`, *optional*, default computed from width and height):
+                The 2D resolution of the semantic attention map.

         Examples:

@@ -870,7 +870,9 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        self.attention_store = AttentionStore(attn_res=attn_res)
+        if attn_res is None:
+            attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
+        self.attention_store = AttentionStore(attn_res)
         self.register_attention_control()

         # default config for step size from original repo
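
When attn_res is not supplied, it is now derived from the requested width and height (ceil of each divided by 32), so the stored maps track non-square generations. A hedged usage sketch; the pipe instance, prompt, and token_indices values are assumed for illustration only:

import numpy as np

# Default resolution for a non-square 768x512 generation.
width, height = 768, 512
attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
print(attn_res)  # (24, 16)

# Callers can still pass an explicit tuple, e.g.:
# image = pipe(prompt, token_indices=[2, 5], width=768, height=512,
#              attn_res=(24, 16)).images[0]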