Skip to content

Commit cfc99ad

Browse files
Add global pooling to controlnet (huggingface#3121)
1 parent 807f69b commit cfc99ad

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/diffusers/models/controlnet.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
projection_class_embeddings_input_dim: Optional[int] = None,
120120
controlnet_conditioning_channel_order: str = "rgb",
121121
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
122+
global_pool_conditions: bool = False,
122123
):
123124
super().__init__()
124125

@@ -566,6 +567,12 @@ def forward(
566567
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
567568
mid_block_res_sample *= conditioning_scale
568569

570+
if self.config.global_pool_conditions:
571+
down_block_res_samples = [
572+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
573+
]
574+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
575+
569576
if not return_dict:
570577
return (down_block_res_samples, mid_block_res_sample)
571578

0 commit comments

Comments
 (0)