Skip to content

Commit a638eef

Browse files
authored
Update pipeline_controlnet_inpaint.py
1 parent cec8e32 commit a638eef

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,7 @@ def __call__(
995995
image: PipelineImageInput = None,
996996
mask_image: PipelineImageInput = None,
997997
control_image: PipelineImageInput = None,
998+
control_image_alt: PipelineImageInput = None, # 新增第二个控制条件
998999
height: Optional[int] = None,
9991000
width: Optional[int] = None,
10001001
padding_mask_crop: Optional[int] = None,
@@ -1053,6 +1054,9 @@ def __call__(
10531054
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
10541055
images must be passed as a list such that each element of the list can be correctly batched for input
10551056
to a single ControlNet.
1057+
control_image_alt (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
1058+
`List[List[torch.Tensor]]`, or `List[List[PIL.Image.Image]]`):
1059+
第二个ControlNet输入条件,与control_image提供的逻辑相同。
10561060
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
10571061
The height in pixels of the generated image.
10581062
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -1277,6 +1281,22 @@ def __call__(
12771281
do_classifier_free_guidance=self.do_classifier_free_guidance,
12781282
guess_mode=guess_mode,
12791283
)
1284+
1285+
# 准备第二个条件输入
1286+
if control_image_alt is not None:
1287+
control_image_alt = self.prepare_control_image(
1288+
image=control_image_alt,
1289+
width=width,
1290+
height=height,
1291+
batch_size=batch_size * num_images_per_prompt,
1292+
num_images_per_prompt=num_images_per_prompt,
1293+
device=device,
1294+
dtype=controlnet.dtype,
1295+
crops_coords=crops_coords,
1296+
resize_mode=resize_mode,
1297+
do_classifier_free_guidance=self.do_classifier_free_guidance,
1298+
guess_mode=guess_mode,
1299+
)
12801300
elif isinstance(controlnet, MultiControlNetModel):
12811301
control_images = []
12821302

@@ -1298,6 +1318,29 @@ def __call__(
12981318
control_images.append(control_image_)
12991319

13001320
control_image = control_images
1321+
1322+
# 为MultiControlNetModel准备第二个条件列表
1323+
if control_image_alt is not None:
1324+
control_images_alt = []
1325+
1326+
for control_image_alt_ in control_image_alt:
1327+
control_image_alt_ = self.prepare_control_image(
1328+
image=control_image_alt_,
1329+
width=width,
1330+
height=height,
1331+
batch_size=batch_size * num_images_per_prompt,
1332+
num_images_per_prompt=num_images_per_prompt,
1333+
device=device,
1334+
dtype=controlnet.dtype,
1335+
crops_coords=crops_coords,
1336+
resize_mode=resize_mode,
1337+
do_classifier_free_guidance=self.do_classifier_free_guidance,
1338+
guess_mode=guess_mode,
1339+
)
1340+
1341+
control_images_alt.append(control_image_alt_)
1342+
1343+
control_image_alt = control_images_alt
13011344
else:
13021345
assert False
13031346

@@ -1417,6 +1460,7 @@ def __call__(
14171460
t,
14181461
encoder_hidden_states=controlnet_prompt_embeds,
14191462
controlnet_cond=control_image,
1463+
controlnet_cond_alt=control_image_alt, # 添加第二个控制条件
14201464
conditioning_scale=cond_scale,
14211465
guess_mode=guess_mode,
14221466
return_dict=False,

0 commit comments

Comments
 (0)