Skip to content

Commit 1e5eaca

Browse files
stable unclip integration tests turn on memory saving (huggingface#2408)
* stable unclip integration tests turn on memory saving * add note on turning on memory saving
1 parent 55de509 commit 1e5eaca

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

tests/pipelines/stable_unclip/test_stable_unclip.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ def test_stable_unclip(self):
189189
pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16)
190190
pipe.to(torch_device)
191191
pipe.set_progress_bar_config(disable=None)
192+
# stable unclip will oom when integration tests are run on a V100,
193+
# so turn on memory savings
194+
pipe.enable_attention_slicing()
195+
pipe.enable_sequential_cpu_offload()
192196

193197
generator = torch.Generator(device="cpu").manual_seed(0)
194198
output = pipe("anime turle", generator=generator, output_type="np")

tests/pipelines/stable_unclip/test_stable_unclip_img2img.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ def test_stable_unclip_l_img2img(self):
185185
)
186186
pipe.to(torch_device)
187187
pipe.set_progress_bar_config(disable=None)
188+
# stable unclip will oom when integration tests are run on a V100,
189+
# so turn on memory savings
190+
pipe.enable_attention_slicing()
191+
pipe.enable_sequential_cpu_offload()
188192

189193
generator = torch.Generator(device="cpu").manual_seed(0)
190194
output = pipe("anime turle", image=input_image, generator=generator, output_type="np")
@@ -209,6 +213,10 @@ def test_stable_unclip_h_img2img(self):
209213
)
210214
pipe.to(torch_device)
211215
pipe.set_progress_bar_config(disable=None)
216+
# stable unclip will oom when integration tests are run on a V100,
217+
# so turn on memory savings
218+
pipe.enable_attention_slicing()
219+
pipe.enable_sequential_cpu_offload()
212220

213221
generator = torch.Generator(device="cpu").manual_seed(0)
214222
output = pipe("anime turle", image=input_image, generator=generator, output_type="np")

0 commit comments

Comments
 (0)