@@ -995,6 +995,7 @@ def __call__(
995995 image : PipelineImageInput = None ,
996996 mask_image : PipelineImageInput = None ,
997997 control_image : PipelineImageInput = None ,
998+ control_image_alt : PipelineImageInput = None , # 新增第二个控制条件
998999 height : Optional [int ] = None ,
9991000 width : Optional [int ] = None ,
10001001 padding_mask_crop : Optional [int ] = None ,
@@ -1053,6 +1054,9 @@ def __call__(
10531054 width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
10541055 images must be passed as a list such that each element of the list can be correctly batched for input
10551056 to a single ControlNet.
1057+ control_image_alt (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`,
1058+ `List[List[torch.Tensor]]`, or `List[List[PIL.Image.Image]]`):
1059+ 第二个ControlNet输入条件,与control_image提供的逻辑相同。
10561060 height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
10571061 The height in pixels of the generated image.
10581062 width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -1277,6 +1281,22 @@ def __call__(
12771281 do_classifier_free_guidance = self .do_classifier_free_guidance ,
12781282 guess_mode = guess_mode ,
12791283 )
1284+
1285+ # 准备第二个条件输入
1286+ if control_image_alt is not None :
1287+ control_image_alt = self .prepare_control_image (
1288+ image = control_image_alt ,
1289+ width = width ,
1290+ height = height ,
1291+ batch_size = batch_size * num_images_per_prompt ,
1292+ num_images_per_prompt = num_images_per_prompt ,
1293+ device = device ,
1294+ dtype = controlnet .dtype ,
1295+ crops_coords = crops_coords ,
1296+ resize_mode = resize_mode ,
1297+ do_classifier_free_guidance = self .do_classifier_free_guidance ,
1298+ guess_mode = guess_mode ,
1299+ )
12801300 elif isinstance (controlnet , MultiControlNetModel ):
12811301 control_images = []
12821302
@@ -1298,6 +1318,29 @@ def __call__(
12981318 control_images .append (control_image_ )
12991319
13001320 control_image = control_images
1321+
1322+ # 为MultiControlNetModel准备第二个条件列表
1323+ if control_image_alt is not None :
1324+ control_images_alt = []
1325+
1326+ for control_image_alt_ in control_image_alt :
1327+ control_image_alt_ = self .prepare_control_image (
1328+ image = control_image_alt_ ,
1329+ width = width ,
1330+ height = height ,
1331+ batch_size = batch_size * num_images_per_prompt ,
1332+ num_images_per_prompt = num_images_per_prompt ,
1333+ device = device ,
1334+ dtype = controlnet .dtype ,
1335+ crops_coords = crops_coords ,
1336+ resize_mode = resize_mode ,
1337+ do_classifier_free_guidance = self .do_classifier_free_guidance ,
1338+ guess_mode = guess_mode ,
1339+ )
1340+
1341+ control_images_alt .append (control_image_alt_ )
1342+
1343+ control_image_alt = control_images_alt
13011344 else :
13021345 assert False
13031346
@@ -1417,6 +1460,7 @@ def __call__(
14171460 t ,
14181461 encoder_hidden_states = controlnet_prompt_embeds ,
14191462 controlnet_cond = control_image ,
1463+ controlnet_cond_alt = control_image_alt , # 添加第二个控制条件
14201464 conditioning_scale = cond_scale ,
14211465 guess_mode = guess_mode ,
14221466 return_dict = False ,
0 commit comments