Commit 76ec3d1

Support dynamically loading/unloading loras with group offloading (#11804)
* update
* add test
* address review comments
* update
* fixes
* change decorator order to fix tests
* try fix
* fight tests
1 parent cdaf84a commit 76ec3d1
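
In practice, this commit enables the kind of workflow sketched below. The sketch is illustrative only: the checkpoint and LoRA repository names are placeholders, the component name (`transformer` vs. `unet`) depends on the pipeline, and the `enable_group_offload` call mirrors the one added to tests/lora/utils.py further down. Before this change the LoRA loaders only detected and restored CPU/sequential offloading; they now also detect group offloading and re-apply it after loading or unloading.

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)  # placeholder id

# Group offloading is enabled per component; the parameters mirror the new test below.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

pipe.load_lora_weights("some/lora-repo")   # placeholder id; hooks are removed and re-applied internally
image = pipe("a prompt").images[0]

pipe.unload_lora_weights()                 # group offloading stays active afterwards
pipe.load_lora_weights("some/lora-repo")   # and LoRAs can be re-loaded dynamically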

File tree

7 files changed: +290 additions, -176 deletions

src/diffusers/hooks/group_offloading.py

Lines changed: 140 additions & 153 deletions
Large diffs are not rendered by default.
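
The group_offloading.py diff itself is collapsed here, but the loader and test changes below depend on three helpers from that module. The stubs below only paraphrase the contract those call sites appear to rely on; they are an assumption for readability, not the actual implementation.

# Assumed contract only (inferred from the call sites in this commit), not the real code:

def _is_group_offload_enabled(module) -> bool:
    """Whether group offloading has been applied to `module` (e.g. via `enable_group_offload`)."""

def _get_top_level_group_offload_hook(module):
    """Return the module's top-level group-offload hook; the new test treats a non-None
    result as 'group offloading is active'."""

def _maybe_remove_and_reapply_group_offloading(module) -> None:
    """If `module` is group-offloaded, remove the existing offloading hooks and re-apply
    them with the same configuration, so layers injected or removed by (un)loading a
    LoRA are covered again; otherwise do nothing (assumed behavior)."""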

src/diffusers/loaders/lora_base.py

Lines changed: 30 additions & 17 deletions
@@ -25,6 +25,7 @@
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE

+from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
             adapter_name = get_adapter_name(text_encoder)

         # <Unsafe code
-        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+            _pipeline
+        )
         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
         text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />

     if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):

     Returns:
         tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
     """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False

     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )

-                logger.info(
-                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                )
-                if is_sequential_cpu_offload or is_model_cpu_offload:
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)


 class LoraBaseMixin:

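Condensed, the caller pattern this file (and peft.py / unet.py below) now follows looks like this. It is a simplified illustration assembled from the hunks above, not a verbatim excerpt:

# 1) Detect which offloading scheme, if any, is active; accelerate hooks are stripped
#    when CPU/sequential offloading is detected.
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(_pipeline)

# 2) Inject the LoRA layers / load the state dict (text_encoder.load_adapter, load_lora_adapter, ...).

# 3) Restore whichever scheme was active before.
if is_model_cpu_offload:
    _pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
    _pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
    # Group offloading is configured per module rather than pipeline-wide, so each
    # torch.nn.Module component gets its group-offload hooks re-applied individually.
    for component in _pipeline.components.values():
        if isinstance(component, torch.nn.Module):
            _maybe_remove_and_reapply_group_offloading(component)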
src/diffusers/loaders/peft.py

Lines changed: 10 additions & 1 deletion
@@ -22,6 +22,7 @@
 import safetensors
 import torch

+from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
 from ..utils import (
     MIN_PEFT_VERSION,
     USE_PEFT_BACKEND,
@@ -256,7 +257,9 @@ def load_lora_adapter(

             # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
             # otherwise loading LoRA weights will lead to an error.
-            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+                _pipeline
+            )
             peft_kwargs = {}
             if is_peft_version(">=", "0.13.1"):
                 peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +350,10 @@ def map_state_dict_for_hotswap(sd):
                 _pipeline.enable_model_cpu_offload()
             elif is_sequential_cpu_offload:
                 _pipeline.enable_sequential_cpu_offload()
+            elif is_group_offload:
+                for component in _pipeline.components.values():
+                    if isinstance(component, torch.nn.Module):
+                        _maybe_remove_and_reapply_group_offloading(component)
             # Unsafe code />

         if prefix is not None and not state_dict:
@@ -687,6 +694,8 @@ def unload_lora(self):
         if hasattr(self, "peft_config"):
             del self.peft_config

+        _maybe_remove_and_reapply_group_offloading(self)
+
     def disable_lora(self):
         """
         Disables the active LoRA layers of the underlying model.
src/diffusers/loaders/unet.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch.nn.functional as F
2323
from huggingface_hub.utils import validate_hf_hub_args
2424

25+
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
2526
from ..models.embeddings import (
2627
ImageProjection,
2728
IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
203204
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
204205
is_model_cpu_offload = False
205206
is_sequential_cpu_offload = False
207+
is_group_offload = False
206208

207209
if is_lora:
208210
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
211213
if is_custom_diffusion:
212214
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
213215
elif is_lora:
214-
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
216+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
215217
state_dict=state_dict,
216218
unet_identifier_key=self.unet_name,
217219
network_alphas=network_alphas,
@@ -230,7 +232,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
230232

231233
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
232234
if is_custom_diffusion and _pipeline is not None:
233-
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
235+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
236+
_pipeline=_pipeline
237+
)
234238

235239
# only custom diffusion needs to set attn processors
236240
self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
241245
_pipeline.enable_model_cpu_offload()
242246
elif is_sequential_cpu_offload:
243247
_pipeline.enable_sequential_cpu_offload()
248+
elif is_group_offload:
249+
for component in _pipeline.components.values():
250+
if isinstance(component, torch.nn.Module):
251+
_maybe_remove_and_reapply_group_offloading(component)
244252
# Unsafe code />
245253

246254
def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ def _process_lora(
307315

308316
is_model_cpu_offload = False
309317
is_sequential_cpu_offload = False
318+
is_group_offload = False
310319
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
311320

312321
if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ def _process_lora(
356365

357366
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
358367
# otherwise loading LoRA weights will lead to an error
359-
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
368+
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
369+
_pipeline
370+
)
360371
peft_kwargs = {}
361372
if is_peft_version(">=", "0.13.1"):
362373
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ def _process_lora(
389400
if warn_msg:
390401
logger.warning(warn_msg)
391402

392-
return is_model_cpu_offload, is_sequential_cpu_offload
403+
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
393404

394405
@classmethod
395406
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading

tests/lora/test_lora_layers_cogvideox.py

Lines changed: 9 additions & 0 deletions
@@ -16,6 +16,7 @@
 import unittest

 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, T5EncoderModel

 from diffusers import (
@@ -28,6 +29,7 @@
 from diffusers.utils.testing_utils import (
     floats_tensor,
     require_peft_backend,
+    require_torch_accelerator,
 )


@@ -127,6 +129,13 @@ def test_simple_inference_with_text_denoiser_lora_unfused(self):
     def test_lora_scale_kwargs_match_fusion(self):
         super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)

+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass

tests/lora/test_lora_layers_cogview4.py

Lines changed: 15 additions & 1 deletion
@@ -18,10 +18,17 @@

 import numpy as np
 import torch
+from parameterized import parameterized
 from transformers import AutoTokenizer, GlmModel

 from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    require_peft_backend,
+    require_torch_accelerator,
+    skip_mps,
+    torch_device,
+)


 sys.path.append(".")
@@ -141,6 +148,13 @@ def test_simple_inference_save_pretrained(self):
             "Loading from saved checkpoints should give same results.",
         )

+    @parameterized.expand([("block_level", True), ("leaf_level", False)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+        # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
     @unittest.skip("Not supported in CogView4.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass

tests/lora/utils.py

Lines changed: 71 additions & 0 deletions
@@ -39,6 +39,7 @@
     is_torch_version,
     require_peft_backend,
     require_peft_version_greater,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -2355,3 +2356,73 @@ def test_inference_load_delete_load_adapters(self):
             pipe.load_lora_weights(tmpdirname)
             output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
+
+        onload_device = torch_device
+        offload_device = torch.device("cpu")
+
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            self.pipeline_class.save_lora_weights(
+                save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+            components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            # Test group offloading with load_lora_weights
+            denoiser.enable_group_offload(
+                onload_device=onload_device,
+                offload_device=offload_device,
+                offload_type=offload_type,
+                num_blocks_per_group=1,
+                use_stream=use_stream,
+            )
+            group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_1 is not None)
+            output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            # Test group offloading after removing the lora
+            pipe.unload_lora_weights()
+            group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_2 is not None)
+            output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]  # noqa: F841
+
+            # Add the lora again and check if group offloading works
+            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+            check_if_lora_correctly_set(denoiser)
+            group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+            self.assertTrue(group_offload_hook_3 is not None)
+            output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
+
+    @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
+    @require_torch_accelerator
+    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
+                # Skip this test if it is overwritten by child class. We need to do this because parameterized
+                # materializes the test methods on invocation which cannot be overridden.
+                return
+        self._test_group_offloading_inference_denoiser(offload_type, use_stream)

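A note on the test plumbing: `parameterized.expand` materializes concrete test methods in a way that a child class cannot simply override, so the base-class wrapper walks the MRO and returns early whenever a subclass defines its own `test_group_offloading_inference_denoiser`. A model-specific suite that needs a narrower parameter grid (as the CogVideoX and CogView4 tests above do) would therefore look roughly like the sketch below; the class name is hypothetical.

class MyModelLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    # Hypothetical subclass: re-declare the test with only the supported combinations
    # and delegate to the shared helper for the actual checks.
    @parameterized.expand([("block_level", True), ("leaf_level", False)])
    @require_torch_accelerator
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        super()._test_group_offloading_inference_denoiser(offload_type, use_stream)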