
Commit b344c95

add attention up/down blocks for VAE (huggingface#161)
1 parent dd10da7 commit b344c95

File tree

1 file changed: +140 -0 lines changed


src/diffusers/models/unet_blocks.py

Lines changed: 140 additions & 0 deletions
@@ -640,6 +640,79 @@ def forward(self, hidden_states):
         return hidden_states


+class AttnDownEncoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        output_scale_factor=1.0,
+        add_downsample=True,
+        downsample_padding=1,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                AttentionBlockNew(
+                    out_channels,
+                    num_head_channels=attn_num_head_channels,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    num_groups=resnet_groups,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = attn(hidden_states)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipDownBlock2D(nn.Module):
     def __init__(
         self,
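
For orientation, here is a minimal sketch of how the new AttnDownEncoderBlock2D could be exercised on its own. The channel counts, input shape, and import path are illustrative assumptions rather than part of this commit, and the expected output shape assumes the usual stride-2 behaviour of Downsample2D.

import torch
from diffusers.models.unet_blocks import AttnDownEncoderBlock2D  # assumed import path at this commit

# Illustrative channel counts; a VAE encoder config would pick these per level.
block = AttnDownEncoderBlock2D(
    in_channels=64,
    out_channels=128,
    num_layers=2,
    attn_num_head_channels=1,
    add_downsample=True,
)

sample = torch.randn(1, 64, 32, 32)  # (batch, channels, height, width)
out = block(sample)                  # resnet + self-attention per layer, then a strided conv downsample
print(out.shape)                     # expected: torch.Size([1, 128, 16, 16]) if Downsample2D halves H and W
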
@@ -1087,6 +1160,73 @@ def forward(self, hidden_states):
         return hidden_states


+class AttnUpDecoderBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        attn_num_head_channels=1,
+        output_scale_factor=1.0,
+        add_upsample=True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            attentions.append(
+                AttentionBlockNew(
+                    out_channels,
+                    num_head_channels=attn_num_head_channels,
+                    rescale_output_factor=output_scale_factor,
+                    eps=resnet_eps,
+                    num_groups=resnet_groups,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(self, hidden_states):
+        for resnet, attn in zip(self.resnets, self.attentions):
+            hidden_states = resnet(hidden_states, temb=None)
+            hidden_states = attn(hidden_states)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
 class AttnSkipUpBlock2D(nn.Module):
     def __init__(
         self,
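
Similarly, a minimal sketch for the new AttnUpDecoderBlock2D; the numbers and import path are again illustrative assumptions, and the expected shape assumes Upsample2D doubles the spatial resolution.

import torch
from diffusers.models.unet_blocks import AttnUpDecoderBlock2D  # assumed import path at this commit

# Illustrative channel counts; a VAE decoder config would pick these per level.
block = AttnUpDecoderBlock2D(
    in_channels=128,
    out_channels=64,
    num_layers=2,
    add_upsample=True,
)

latent = torch.randn(1, 128, 16, 16)  # (batch, channels, height, width)
out = block(latent)                   # resnet + self-attention per layer, then a conv upsample
print(out.shape)                      # expected: torch.Size([1, 64, 32, 32]) if Upsample2D doubles H and W
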
