
Commit 18b018c

[SDXL Refiner] Fix refiner forward pass for batched input (huggingface#4327)
* fix_batch_xl
* Fix other pipelines as well
* up
* up
* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
* sort
* up
* Finish it all up

Co-authored-by: Bagheera <[email protected]>
1 parent 54fab2c commit 18b018c

6 files changed: 43 additions & 13 deletions


src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 3 additions & 1 deletion
@@ -906,15 +906,17 @@ def denoising_value_valid(dnv):
             negative_aesthetic_score,
             dtype=prompt_embeds.dtype,
         )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
 
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.to(device)
 
         # 9. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
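
For orientation, a minimal sketch of the row bookkeeping this hunk fixes (my illustration, not part of the diff; the values are only illustrative, loosely based on the refiner defaults). _get_add_time_ids produces a single row of micro-conditioning values per prompt. The old code concatenated the negative and positive rows first and then tiled the pair, which interleaves the rows as [neg, pos, neg, pos, ...], while prompt_embeds after the classifier-free-guidance concat is laid out as [all negative, all positive]. Repeating each half to the full batch size before the concat keeps the aesthetic-score conditioning aligned for batch sizes greater than one:

import torch

batch_size, num_images_per_prompt = 2, 1
B = batch_size * num_images_per_prompt

# One row each, as produced per prompt (illustrative values only).
add_time_ids = torch.tensor([[512.0, 512.0, 0.0, 0.0, 6.0]])      # positive half, aesthetic_score = 6.0
add_neg_time_ids = torch.tensor([[512.0, 512.0, 0.0, 0.0, 2.5]])  # negative half, negative_aesthetic_score = 2.5

# Old order: concat first, then tile the pair -> rows come out as [neg, pos, neg, pos].
old = torch.cat([add_neg_time_ids, add_time_ids], dim=0).repeat(B, 1)

# Fixed order: repeat each half to the batch size, then concat -> [neg, neg, pos, pos].
new = torch.cat([add_neg_time_ids.repeat(B, 1), add_time_ids.repeat(B, 1)], dim=0)

print(old[:, -1])  # tensor([2.5000, 6.0000, 2.5000, 6.0000]) -- misaligned with the prompt batch
print(new[:, -1])  # tensor([2.5000, 2.5000, 6.0000, 6.0000]) -- matches the [negative, positive] layout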

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 3 additions & 1 deletion
@@ -1168,15 +1168,17 @@ def denoising_value_valid(dnv):
             negative_aesthetic_score,
             dtype=prompt_embeds.dtype,
         )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
 
         prompt_embeds = prompt_embeds.to(device)
         add_text_embeds = add_text_embeds.to(device)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.to(device)
 
         # 11. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 3 additions & 1 deletion
@@ -811,6 +811,7 @@ def __call__(
             negative_aesthetic_score,
             dtype=prompt_embeds.dtype,
         )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
 
         original_prompt_embeds_len = len(prompt_embeds)
         original_add_text_embeds_len = len(add_text_embeds)
@@ -819,6 +820,7 @@ def __call__(
         if do_classifier_free_guidance:
             prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
             add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
             add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
 
         # Make dimensions consistent
@@ -828,7 +830,7 @@ def __call__(
 
         prompt_embeds = prompt_embeds.to(device).to(torch.float32)
         add_text_embeds = add_text_embeds.to(device).to(torch.float32)
-        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.to(device)
 
         # 11. Denoising loop
         self.unet = self.unet.to(torch.float32)
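
For reference, a hedged reproduction sketch of the batched refiner call this change targets; the checkpoint id, placeholder image, and parameters below are assumptions for illustration, not taken from the commit:

import torch
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline

# Any SDXL refiner checkpoint (requires_aesthetics_score=True) exercises the fixed code path.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

# Batched input: two prompts in one call. Before this fix, the aesthetic-score time IDs
# were tiled in the wrong order whenever batch_size * num_images_per_prompt > 1.
placeholder = Image.new("RGB", (1024, 1024))
images = pipe(
    prompt=["a photo of a cat", "a photo of a dog"],
    image=[placeholder, placeholder],
    num_inference_steps=20,
).images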

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

Lines changed: 12 additions & 3 deletions
@@ -64,7 +64,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
             cross_attention_dim=64 if not skip_first_text_encoder else 32,
         )
         scheduler = EulerDiscreteScheduler(
@@ -113,9 +113,18 @@ def get_dummy_components(self, skip_first_text_encoder=False):
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "requires_aesthetics_score": True,
         }
         return components
 
+    def test_components_function(self):
+        init_components = self.get_dummy_components()
+        init_components.pop("requires_aesthetics_score")
+        pipe = self.pipeline_class(**init_components)
+
+        self.assertTrue(hasattr(pipe, "components"))
+        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
     def get_dummy_inputs(self, device, seed=0):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         image = image / 2 + 0.5
@@ -147,7 +156,7 @@ def test_stable_diffusion_xl_img2img_euler(self):
 
         assert image.shape == (1, 32, 32, 3)
 
-        expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165])
+        expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -165,7 +174,7 @@ def test_stable_diffusion_xl_refiner(self):
 
         assert image.shape == (1, 32, 32, 3)
 
-        expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198])
+        expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
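
The projection_class_embeddings_input_dim change in the dummy UNet mirrors the refiner conditioning these tests now exercise: with requires_aesthetics_score=True the added time IDs carry 5 scalars (original_size: 2, crops_coords_top_left: 2, aesthetic_score: 1) instead of the base model's 6 (target_size in place of the aesthetic score), and the projection input is that count times addition_time_embed_dim plus the pooled text-embedding width of the dummy text_encoder_2. A quick check of the arithmetic, under those assumptions:

# 5 micro-conditioning scalars for the refiner: original_size (2) + crops_coords_top_left (2) + aesthetic_score (1)
num_micro_conditions = 2 + 2 + 1
addition_time_embed_dim = 8
pooled_text_embed_dim = 32  # projection width of the dummy text_encoder_2 (assumption)
assert num_micro_conditions * addition_time_embed_dim + pooled_text_embed_dim == 72  # base model: 6 * 8 + 32 = 80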

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py

Lines changed: 12 additions & 4 deletions
@@ -66,7 +66,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
             cross_attention_dim=64 if not skip_first_text_encoder else 32,
         )
         scheduler = EulerDiscreteScheduler(
@@ -115,6 +115,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
             "tokenizer": tokenizer if not skip_first_text_encoder else None,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
+            "requires_aesthetics_score": True,
         }
         return components
 
@@ -142,6 +143,14 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def test_components_function(self):
+        init_components = self.get_dummy_components()
+        init_components.pop("requires_aesthetics_score")
+        pipe = self.pipeline_class(**init_components)
+
+        self.assertTrue(hasattr(pipe, "components"))
+        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
     def test_stable_diffusion_xl_inpaint_euler(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -155,7 +164,7 @@ def test_stable_diffusion_xl_inpaint_euler(self):
 
         assert image.shape == (1, 64, 64, 3)
 
-        expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
+        expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -250,10 +259,9 @@ def test_stable_diffusion_xl_refiner(self):
         image = sd_pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
 
-        print(torch.from_numpy(image_slice).flatten())
         assert image.shape == (1, 64, 64, 3)
 
-        expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418])
+        expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py

Lines changed: 10 additions & 3 deletions
@@ -68,7 +68,7 @@ def get_dummy_components(self):
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
             cross_attention_dim=64,
         )
 
@@ -118,8 +118,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
-            # "safety_checker": None,
-            # "feature_extractor": None,
+            "requires_aesthetics_score": True,
         }
         return components
 
@@ -141,6 +140,14 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def test_components_function(self):
+        init_components = self.get_dummy_components()
+        init_components.pop("requires_aesthetics_score")
+        pipe = self.pipeline_class(**init_components)
+
+        self.assertTrue(hasattr(pipe, "components"))
+        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)

0 commit comments
