
Commit 92f15f5

Model CPU offload fix for BLIPDiffusion (huggingface#5174)
cpu offload fix for blip diffusion
1 parent 22b19d5 commit 92f15f5

File tree: 2 files changed, +29 -12 lines

  src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
  src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
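This change adds a model_cpu_offload_seq to both BLIP Diffusion pipelines and resolves devices through self._execution_device, so enable_model_cpu_offload() works with them. A minimal usage sketch; the checkpoint id, image URL, and call arguments below are illustrative assumptions rather than part of this commit:

    import torch
    from diffusers import BlipDiffusionPipeline
    from diffusers.utils import load_image

    # Assumed checkpoint id for illustration.
    pipe = BlipDiffusionPipeline.from_pretrained(
        "Salesforce/blipdiffusion", torch_dtype=torch.float16
    )
    # With model_cpu_offload_seq declared, each sub-model (qformer -> text_encoder
    # -> unet -> vae) is moved to the GPU only for its forward pass.
    pipe.enable_model_cpu_offload()

    reference_image = load_image("https://example.com/dog.png")  # placeholder URL
    image = pipe(
        "painting of a dog",  # prompt
        reference_image,      # subject reference image
        "dog",                # source_subject_category
        "dog",                # target_subject_category
        guidance_scale=7.5,
        num_inference_steps=25,
    ).images[0]
    image.save("blip_diffusion_offload.png")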

src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py

Lines changed: 15 additions & 6 deletions
@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -155,7 +157,9 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -166,7 +170,7 @@ def encode_prompt(self, query_embeds, prompt):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -249,11 +253,12 @@ def __call__(
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
 
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -271,7 +276,7 @@ def __call__(
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
             max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ def __call__(
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
@@ -300,7 +305,7 @@ def __call__(
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -330,9 +335,13 @@ def __call__(
                 t,
                 latents,
             )["prev_sample"]
+
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)

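The recurring replacement of self.device with a device resolved from self._execution_device is what makes offload work: with enable_model_cpu_offload() the sub-models sit on the CPU between forward passes, so a lookup based on module placement can report cpu, while _execution_device reports the device the accelerate hooks will execute on. A simplified illustration of the distinction, assuming the pipe from the sketch above (not the exact upstream implementation):

    # With offload enabled, idle sub-models are parked on the CPU ...
    print(pipe.unet.device)        # may print: cpu
    # ... but inputs must be created on the device the hooks execute on.
    print(pipe._execution_device)  # e.g. cuda:0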
src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py

Lines changed: 14 additions & 6 deletions
@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ def prepare_latents(self, batch_size, num_channels, height, width, dtype, device
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ def encode_prompt(self, query_embeds, prompt):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ def __call__(
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
 
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -319,7 +324,7 @@ def __call__(
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         # 3. unconditional embedding
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
@@ -332,7 +337,7 @@ def __call__(
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ def __call__(
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -399,6 +404,9 @@ def __call__(
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
 
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)

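The maybe_free_model_hooks() call added at the end of __call__ puts the offloaded sub-models back on the CPU and resets the offload hooks once generation finishes, so GPU memory is released and a subsequent pipeline call starts from a clean state. A rough way to observe this, reusing pipe and reference_image from the earlier sketch (illustrative only):

    import torch

    _ = pipe("painting of a dog", reference_image, "dog", "dog", num_inference_steps=20)
    # After the call returns, the sub-models have been offloaded again, so GPU
    # memory allocation should be back near its pre-call level.
    print(f"{torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated after the call")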