Support dynamically loading/unloading loras with group offloading #11804
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
cc @zhangvia I think this should fix the issues you were facing. Could you test? Thanks 🤗
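For reference, a minimal sketch of the scenario this PR targets (the checkpoint and LoRA ids below are placeholders, and this is illustrative rather than the PR's actual test code): group offloading is enabled first, then a LoRA is loaded and later unloaded on the same pipeline.

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder model id; any pipeline whose denoiser supports group offloading works.
pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16)

# Enable group offloading on the denoiser.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

# Dynamically load a LoRA *after* group offloading is enabled, the case this PR fixes.
pipe.load_lora_weights("some/lora")  # placeholder id
image = pipe("a prompt", num_inference_steps=2).images[0]

# ...and unload it again without re-initializing the pipeline.
pipe.unload_lora_weights()
image = pipe("a prompt", num_inference_steps=2).images[0]
```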
Thanks! Just some minor comments.
And maybe we could include disabling and enabling group offloading in the existing _func_optionally_*() function (a rough sketch of the idea follows below). But no strong opinions.
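A rough sketch of that suggestion; the helper name and the enable/disable calls are assumptions, since the real helper is only referenced by the wildcard `_func_optionally_*()` above:

```python
from contextlib import contextmanager

@contextmanager
def _func_optionally_disable_group_offloading(module):
    # Hypothetical helper: if group offloading hooks are attached, lift them
    # while LoRA weights are being loaded/unloaded, then restore them.
    had_group_offloading = getattr(module, "_group_offloading_enabled", False)  # assumed flag
    if had_group_offloading:
        module.disable_group_offload()  # assumed API
    try:
        yield
    finally:
        if had_group_offloading:
            module.enable_group_offload()  # assumed API; the real call would need its original config
```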
Yep, done!
Fantastic!
```diff
@@ -268,9 +288,12 @@ class GroupOffloadingHook(ModelHook):

     _is_stateful = False

-    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
+    def __init__(
+        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
```
❤️
```diff
     else:
-        raise ValueError(f"Unsupported offload_type: {offload_type}")
+        assert False
```
Should this be turned into a sensible ValueError like the previous one? 👁️
The statement cannot be reached because we instantiate GroupOffloadingType with the value passed by the user; the dataclass and enum will raise an error if the value is invalid.
assert False is just a placeholder.
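A minimal sketch of why the branch is unreachable, assuming GroupOffloadingType is a str-valued enum with the two offload types seen in the tests below:

```python
from enum import Enum

class GroupOffloadingType(str, Enum):
    BLOCK_LEVEL = "block_level"
    LEAF_LEVEL = "leaf_level"

GroupOffloadingType("block_level")  # fine
GroupOffloadingType("not_a_type")   # raises ValueError before any dispatch code runs
```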
Thanks for the quick fix! I've confirmed it.
Nice 👍🏽
A bit of a problem with the tests that wasn't caught until now... Group offloading with streams is limited in what it can do. If the same layer is invoked twice in the same parent layer's forward, the prefetching logic becomes completely incorrect. This is the case with CogView4 and CogVideoX.
One option to make it work is to concatenate inputs across the sequence dim and then split again (see the sketch below). This will, however, incur some extra perf cost because concatenation/splitting is not free. Another option is creating a separate layer for each data stream and sharing the weights. I don't think this should incur memory overhead since the data reference will be the same for both layers, but I need to test to be sure. Any other ideas are welcome. If these don't sound good, I propose we skip the tests for now and wait for the group offloading logic to become more mature/stable. Other than that, the PR looks good to merge to me.
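For the first option, something along these lines (a hypothetical module, not the actual CogView4/CogVideoX code):

```python
import torch
import torch.nn as nn

class DualStreamBlock(nn.Module):
    """Calls a shared layer once instead of once per stream."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.shared = nn.Linear(dim, dim)  # stands in for the layer invoked twice per forward

    def forward(self, stream_a: torch.Tensor, stream_b: torch.Tensor):
        # Concatenate along the sequence dim so the shared layer runs exactly once,
        # keeping the stream-based prefetch order correct, then split back.
        combined = torch.cat([stream_a, stream_b], dim=1)
        out = self.shared(combined)
        return out.split([stream_a.shape[1], stream_b.shape[1]], dim=1)
```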
Thanks for investigating these and also for proposing the potential solutions. On the surface, I would say we evaluate both approaches and then decide. However, the two models you mentioned probably have limited usage, at least with group offloading, for now. So,
I tested the first approach as it's a super simple change. The performance penalty is not noticeable end-to-end and only shows up at a small microsecond scale. I don't think it really matters because, like you mentioned, these models probably have very limited usage in the context of group offloading. For now, I'll specialize the tests for CogView4/CogVideoX by parameterizing them with only the non-stream cases instead of xfailing them, and add a note. Sounds good?
I am good!
Updated the tests. There are some gymnastics involved in skipping tests marked with parameterized because it seems like they can't be overridden or specialized in child classes.
```python
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
    # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
    # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
    super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
```
We can do this, or we could detect if the test class is either CogView4 or CogVideoX and use pytest.skip(). Up to you.
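i.e., something like the following sketch; the class names here are assumptions:

```python
import pytest

class ModelTesterMixin:
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        # Hypothetical model detection inside the base test.
        if use_stream and type(self).__name__ in ("CogView4TransformerTests", "CogVideoXTransformerTests"):
            pytest.skip("Prefetching is incorrect when the same layer runs twice in one forward.")
        self._test_group_offloading_inference_denoiser(offload_type, use_stream)
```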
I prefer the base test class not having any information about the model test classes that derive from it. The current implementation will work for any model that overrides the test, so it's also much cleaner.
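The pattern, roughly (names beyond those in the test shown above are assumptions):

```python
from parameterized import parameterized

class ModelTesterMixin:
    # The shared body lives in an underscore-prefixed helper so child classes
    # can re-parameterize without fighting the @parameterized.expand machinery.
    def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        ...

    @parameterized.expand(
        [("block_level", True), ("block_level", False), ("leaf_level", True), ("leaf_level", False)]
    )
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        self._test_group_offloading_inference_denoiser(offload_type, use_stream)

class CogVideoXTransformerTests(ModelTesterMixin):
    # Knows about its own limitation; the base class stays model-agnostic.
    @parameterized.expand([("block_level", True), ("leaf_level", False)])
    def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
        self._test_group_offloading_inference_denoiser(offload_type, use_stream)
```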
Support dynamically loading/unloading loras with group offloading (huggingface#11804)

* update
* add test
* address review comments
* update
* fixes
* change decorator order to fix tests
* try fix
* fight tests
Fixes #11791.