Commit 98c9aac

[SDXL] Fix all sequential offload (huggingface#4010)
* Fix all sequential offload
* make style
* make style

1 parent e3d71ad · commit 98c9aac

File tree: 3 files changed (+52 / -69 lines)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 4 additions & 3 deletions
@@ -176,7 +176,6 @@ def disable_vae_tiling(self):
         """
         self.vae.disable_tiling()
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_sequential_cpu_offload
     def enable_sequential_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
@@ -196,10 +195,12 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
         self.to("cpu", silence_dtype_warnings=True)
         torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder_2, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_model_cpu_offload
+        if self.text_encoder is not None:
+            cpu_offload(self.text_encoder, device)
+
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
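The substance of the fix: the SDXL img2img pipeline doubles as the refiner pipeline, whose text_encoder is None, so unconditionally offloading every component would fail in enable_sequential_cpu_offload. The method is now defined locally (hence the dropped "# Copied from" markers) and offloads text_encoder only when it exists. A minimal usage sketch; the checkpoint id and dtype below are illustrative, not taken from the commit:

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline

# Illustrative refiner-style checkpoint: its `text_encoder` is None,
# which is exactly the case the fixed loop has to skip.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
)

# Before this commit the call below would trip over the missing
# `text_encoder`; it now offloads unet, text_encoder_2, and vae, and
# offloads `text_encoder` only when it is not None.
pipe.enable_sequential_cpu_offload()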

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py

Lines changed: 24 additions & 33 deletions
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import unittest
 
 import numpy as np
@@ -22,12 +21,11 @@
 
 from diffusers import (
     AutoencoderKL,
-    DiffusionPipeline,
     EulerDiscreteScheduler,
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import slow, torch_device
+from diffusers.utils import torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -190,38 +188,31 @@ def test_attention_slicing_forward_pass(self):
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)
 
+    @require_torch_gpu
+    def test_stable_diffusion_xl_offloads(self):
+        pipes = []
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components).to(torch_device)
+        pipes.append(sd_pipe)
 
-@slow
-@require_torch_gpu
-class StableDiffusionXLPipelineSlowTests(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        gc.collect()
-        torch.cuda.empty_cache()
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe.enable_model_cpu_offload()
+        pipes.append(sd_pipe)
 
-    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
-        generator = torch.Generator(device=generator_device).manual_seed(seed)
-        latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
-        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
-        inputs = {
-            "prompt": "a photograph of an astronaut riding a horse",
-            "latents": latents,
-            "generator": generator,
-            "num_inference_steps": 3,
-            "guidance_scale": 7.5,
-            "output_type": "numpy",
-        }
-        return inputs
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe.enable_sequential_cpu_offload()
+        pipes.append(sd_pipe)
+
+        image_slices = []
+        for pipe in pipes:
+            pipe.unet.set_default_attn_processor()
 
-    def test_stable_diffusion_default_euler(self):
-        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+            inputs = self.get_dummy_inputs(torch_device)
+            image = pipe(**inputs).images
 
-        inputs = self.get_inputs(torch_device)
-        image = pipe(**inputs).images
-        image_slice = image[0, -3:, -3:, -1].flatten()
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
 
-        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506])
-        assert np.abs(image_slice - expected_slice).max() < 7e-3
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

Lines changed: 24 additions & 33 deletions
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import random
 import unittest
 
@@ -23,12 +22,11 @@
 
 from diffusers import (
     AutoencoderKL,
-    DiffusionPipeline,
     EulerDiscreteScheduler,
     StableDiffusionXLImg2ImgPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import floats_tensor, slow, torch_device
+from diffusers.utils import floats_tensor, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
 
 from ..pipeline_params import (
@@ -205,38 +203,31 @@ def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self):
         # make sure that it's equal
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
+    @require_torch_gpu
+    def test_stable_diffusion_xl_offloads(self):
+        pipes = []
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
+        pipes.append(sd_pipe)
 
-@slow
-@require_torch_gpu
-class StableDiffusionXLImg2ImgPipelineSlowTests(unittest.TestCase):
-    def tearDown(self):
-        super().tearDown()
-        gc.collect()
-        torch.cuda.empty_cache()
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe.enable_model_cpu_offload()
+        pipes.append(sd_pipe)
 
-    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
-        generator = torch.Generator(device=generator_device).manual_seed(seed)
-        latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
-        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
-        inputs = {
-            "prompt": "a photograph of an astronaut riding a horse",
-            "latents": latents,
-            "generator": generator,
-            "num_inference_steps": 3,
-            "guidance_scale": 7.5,
-            "output_type": "numpy",
-        }
-        return inputs
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe.enable_sequential_cpu_offload()
+        pipes.append(sd_pipe)
+
+        image_slices = []
+        for pipe in pipes:
+            pipe.unet.set_default_attn_processor()
 
-    def test_stable_diffusion_default_euler(self):
-        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+            inputs = self.get_dummy_inputs(torch_device)
+            image = pipe(**inputs).images
 
-        inputs = self.get_inputs(torch_device)
-        image = pipe(**inputs).images
-        image_slice = image[0, -3:, -3:, -1].flatten()
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
 
-        assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506])
-        assert np.abs(image_slice - expected_slice).max() < 7e-3
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
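Both rewritten test modules drop the slow, checkpoint-downloading test classes in favor of fast offload-consistency tests on dummy components: the same pipeline runs with no offload, model offload, and sequential offload, and the 3x3 corner slices of the outputs must agree to within 1e-3. A hedged sketch for invoking just these tests locally (plain pytest usage, nothing commit-specific):

import pytest

# Select only the new offload tests from both SDXL test modules; they
# are decorated with @require_torch_gpu, so a CUDA device is required.
pytest.main(
    [
        "tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py",
        "tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py",
        "-k", "test_stable_diffusion_xl_offloads",
    ]
)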
