|
50 | 50 | class IPAdapterNightlyTestsMixin(unittest.TestCase): |
51 | 51 | dtype = torch.float16 |
52 | 52 |
|
| 53 | + def setUp(self): |
| 54 | + # clean up the VRAM before each test |
| 55 | + super().setUp() |
| 56 | + gc.collect() |
| 57 | + torch.cuda.empty_cache() |
| 58 | + |
53 | 59 | def tearDown(self): |
| 60 | + # clean up the VRAM after each test |
54 | 61 | super().tearDown() |
55 | 62 | gc.collect() |
56 | 63 | torch.cuda.empty_cache() |
@@ -313,7 +320,7 @@ def test_text_to_image_sdxl(self): |
313 | 320 | feature_extractor=feature_extractor, |
314 | 321 | torch_dtype=self.dtype, |
315 | 322 | ) |
316 | | - pipeline.to(torch_device) |
| 323 | + pipeline.enable_model_cpu_offload() |
317 | 324 | pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") |
318 | 325 |
|
319 | 326 | inputs = self.get_dummy_inputs() |
@@ -373,7 +380,7 @@ def test_image_to_image_sdxl(self): |
373 | 380 | feature_extractor=feature_extractor, |
374 | 381 | torch_dtype=self.dtype, |
375 | 382 | ) |
376 | | - pipeline.to(torch_device) |
| 383 | + pipeline.enable_model_cpu_offload() |
377 | 384 | pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") |
378 | 385 |
|
379 | 386 | inputs = self.get_dummy_inputs(for_image_to_image=True) |
@@ -442,7 +449,7 @@ def test_inpainting_sdxl(self): |
442 | 449 | feature_extractor=feature_extractor, |
443 | 450 | torch_dtype=self.dtype, |
444 | 451 | ) |
445 | | - pipeline.to(torch_device) |
| 452 | + pipeline.enable_model_cpu_offload() |
446 | 453 | pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") |
447 | 454 |
|
448 | 455 | inputs = self.get_dummy_inputs(for_inpainting=True) |
@@ -490,7 +497,7 @@ def test_ip_adapter_single_mask(self): |
490 | 497 | image_encoder=image_encoder, |
491 | 498 | torch_dtype=self.dtype, |
492 | 499 | ) |
493 | | - pipeline.to(torch_device) |
| 500 | + pipeline.enable_model_cpu_offload() |
494 | 501 | pipeline.load_ip_adapter( |
495 | 502 | "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors" |
496 | 503 | ) |
@@ -518,7 +525,7 @@ def test_ip_adapter_multiple_masks(self): |
518 | 525 | image_encoder=image_encoder, |
519 | 526 | torch_dtype=self.dtype, |
520 | 527 | ) |
521 | | - pipeline.to(torch_device) |
| 528 | + pipeline.enable_model_cpu_offload() |
522 | 529 | pipeline.load_ip_adapter( |
523 | 530 | "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] * 2 |
524 | 531 | ) |
|
0 commit comments