diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 6a3cf77a7df7..a6a41d2524d3 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -467,17 +467,17 @@ def _resize_and_crop( def resize( self, - image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], + image: PipelineImageInput, height: int, width: int, resize_mode: str = "default", # "default", "fill", "crop" - ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: + ) -> PipelineImageInput: """ Resize image. Args: - image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): - The image input, can be a PIL image, numpy array or pytorch tensor. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + The image batch input, can be a PIL image, numpy array or pytorch tensor, or a list of these elements. height (`int`): The height to resize to. width (`int`): @@ -492,7 +492,7 @@ def resize( supported for PIL image input. Returns: - `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: + `torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`: The resized image. """ if resize_mode != "default" and not isinstance(image, PIL.Image.Image): @@ -523,6 +523,16 @@ def resize( size=(height, width), ) image = self.pt_to_numpy(image) + elif isinstance(image, list): + image = [ + self.resize( + img, + height=height, + width=width, + resize_mode=resize_mode, + ) + for img in image + ] return image def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py index e9e5c0670676..f4f52b4fa4aa 100644 --- a/tests/others/test_image_processor.py +++ b/tests/others/test_image_processor.py @@ -308,3 +308,16 @@ def test_vae_image_processor_resize_np(self): assert out_np.shape == exp_np_shape, ( f"resized image output shape '{out_np.shape}' didn't match expected shape '{exp_np_shape}'." ) + + def test_vae_image_processor_resize_list_pt(self): + image_processor = VaeImageProcessor(do_resize=True, vae_scale_factor=1) + input_pt = self.dummy_sample + b, c, h, w = input_pt.shape + scale = 2 + input_pt_list = [input_pt] + out_pt_list = image_processor.resize(image=input_pt_list, height=h // scale, width=w // scale) + exp_pt_shape = (b, c, h // scale, w // scale) + for out_pt in out_pt_list: + assert out_pt.shape == exp_pt_shape, ( + f"resized image output shape '{out_pt.shape}' didn't match expected shape '{exp_pt_shape}'." + )