Commit 4c660d1

[Stable Diffusion] Fix padding / truncation (huggingface#1226)

* [Stable Diffusion] Fix padding / truncation
* finish

1 parent 8171566, commit 4c660d1

9 files changed, +88 -25 lines
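
Every pipeline's _encode_prompt receives the same fix: the tokenizer now truncates the prompt itself (truncation=True), and a second, untruncated tokenization is used only to detect whether anything was cut off and to log what was removed; the old manual slice of text_input_ids is gone. Below is a minimal, self-contained sketch of that pattern for the PyTorch pipelines; the checkpoint name, the prompt, and the plain print in place of the pipeline's logger are illustrative only, not part of the commit.

# Sketch of the truncation handling introduced by this commit (illustrative setup).
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")  # model_max_length == 77
prompt = "a photo of " + "a very long prompt " * 20  # deliberately longer than 77 tokens

text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,  # new: the tokenizer cuts the sequence to 77 ids itself
    return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
# Second pass without truncation, used only to detect and report what was dropped.
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

if not torch.equal(text_input_ids, untruncated_ids):
    # Decode only the dropped part: everything after the last kept prompt token,
    # excluding the untruncated sequence's final EOS token.
    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
    print(f"Truncated to {tokenizer.model_max_length} tokens, removed: {removed_text}")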

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 4 additions & 3 deletions

@@ -248,17 +248,18 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method
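
Note on the new slice: after truncation the CLIP tokenizer keeps model_max_length ids, with the final position taken by the EOS token, so the first prompt token that was actually dropped sits at index model_max_length - 1 of the untruncated encoding. The slice untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] therefore decodes exactly the dropped tokens, and the trailing -1 leaves out the untruncated sequence's own EOS token.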

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 5 additions & 3 deletions

@@ -114,17 +114,19 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 5 additions & 3 deletions

@@ -161,17 +161,19 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 5 additions & 3 deletions

@@ -175,17 +175,19 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
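
The three ONNX pipelines above apply the identical change with NumPy arrays, swapping torch.equal for np.array_equal. The check also covers prompts that overflow the 77-token limit, because both functions return False on a shape mismatch rather than raising; a tiny illustration with made-up token ids:

# Illustration only: np.array_equal on differently shaped id arrays (made-up values).
import numpy as np

truncated_ids = np.array([[49406, 320, 49407]])          # ids kept after truncation
untruncated_ids = np.array([[49406, 320, 525, 49407]])   # full encoding before truncation
assert not np.array_equal(truncated_ids, untruncated_ids)  # different shapes -> False, no error
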

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 4 additions & 3 deletions

@@ -236,17 +236,18 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 4 additions & 3 deletions

@@ -244,17 +244,18 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 4 additions & 3 deletions

@@ -244,17 +244,18 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 4 additions & 3 deletions

@@ -213,17 +213,18 @@ def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_fr
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

         # duplicate text embeddings for each generation per prompt, using mps friendly method

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 53 additions & 1 deletion

@@ -33,9 +33,10 @@
     UNet2DConditionModel,
     UNet2DModel,
     VQModel,
+    logging,
 )
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from ...test_pipelines_common import PipelineTesterMixin

@@ -619,6 +620,57 @@ def test_stable_diffusion_fp16(self):

         assert image.shape == (1, 128, 128, 3)

+    def test_stable_diffusion_long_prompt(self):
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        do_classifier_free_guidance = True
+        negative_prompt = None
+        num_images_per_prompt = 1
+        logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
+
+        prompt = 25 * "@"
+        with CaptureLogger(logger) as cap_logger_3:
+            text_embeddings_3 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        prompt = 100 * "@"
+        with CaptureLogger(logger) as cap_logger:
+            text_embeddings = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        negative_prompt = "Hello"
+        with CaptureLogger(logger) as cap_logger_2:
+            text_embeddings_2 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
+        assert text_embeddings.shape[1] == 77
+
+        assert cap_logger.out == cap_logger_2.out
+        # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
+        assert cap_logger.out.count("@") == 25
+        assert cap_logger_3.out == ""
+

 @slow
 @require_torch_gpu
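
The count assertion follows directly from the slice used in _encode_prompt: with a 100-character "@" prompt, each "@" becomes one CLIP token, so the untruncated encoding has 102 ids (BOS + 100 + EOS) and the decoded removed span covers indices 76 through 100, i.e. 25 tokens. A quick sanity check of that arithmetic (assuming, as the test does, one token per "@"):

# Worked check of the expected warning length in test_stable_diffusion_long_prompt.
prompt_len = 100                      # prompt = 100 * "@"
max_len = 77                          # tokenizer.model_max_length
untruncated_len = prompt_len + 2      # + BOS + EOS = 102 ids
# removed_text decodes untruncated_ids[:, max_len - 1 : -1]
removed = (untruncated_len - 1) - (max_len - 1)   # indices 76..100 -> 25 tokens
assert removed == 25                  # matches cap_logger.out.count("@")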
