@@ -640,6 +640,79 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class AttnDownEncoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        output_scale_factor=1.0,
+        add_downsample=True,
+        downsample_padding=1,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                AttentionBlockNew(
+                    out_channels,
+                    num_head_channels=attn_num_head_channels,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    num_groups=resnet_groups,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            # The downsampler input has `out_channels` channels after the resnet/attention
+            # stack, so pass `out_channels` here (the post-loop `in_channels` only matches
+            # when num_layers > 1).
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = attn(hidden_states)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipDownBlock2D(nn.Module):
     def __init__(
         self,
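The new encoder block can be exercised on its own; a minimal usage sketch follows (not part of the diff). The import path, channel counts, and tensor shapes below are illustrative assumptions.

    import torch
    from diffusers.models.unet_blocks import AttnDownEncoderBlock2D  # hypothetical import path

    block = AttnDownEncoderBlock2D(in_channels=64, out_channels=128, num_layers=2)
    x = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
    y = block(x)                    # resnet + attention pairs, then a strided-conv downsample
    print(y.shape)                  # expected: torch.Size([1, 128, 16, 16])

With add_downsample=True (the default), the trailing Downsample2D halves the spatial resolution, while the attention layers run at the pre-downsample resolution.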
@@ -1087,6 +1160,73 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class AttnUpDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        output_scale_factor=1.0,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                AttentionBlockNew(
+                    out_channels,
+                    num_head_channels=attn_num_head_channels,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    num_groups=resnet_groups,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = attn(hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipUpBlock2D(nn.Module):
     def __init__(
         self,
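Analogously, the decoder-side block upsamples after its resnet/attention stack; a minimal sketch, again with illustrative shapes and a hypothetical import path:

    import torch
    from diffusers.models.unet_blocks import AttnUpDecoderBlock2D  # hypothetical import path

    block = AttnUpDecoderBlock2D(in_channels=128, out_channels=64, num_layers=2)
    x = torch.randn(1, 128, 16, 16)
    y = block(x)                    # resnet + attention pairs, then Upsample2D (interpolate + conv)
    print(y.shape)                  # expected: torch.Size([1, 64, 32, 32])

Note that neither new block takes a timestep embedding (temb_channels=None, and forward passes temb=None), which is what separates these encoder/decoder variants from the UNet down/up blocks that do.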