Skip to content

Add offload optimizer function #10607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions paddlenlp/trainer/utils/offload_optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import _C_ops
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.hybrid_parallel_optimizer import (
HybridParallelOptimizer,
)
from paddle.optimizer import Optimizer

from .sharding_io import to_device


def offload(tensor):
if paddle.is_compiled_with_cuda():
place = paddle.CUDAPinnedPlace()
else:
place = paddle.CPUPlace()

new_tensor = to_device(tensor, place)
assert new_tensor is tensor, "to_device must be inplace operation"


def reload(tensor):
new_tensor = to_device(tensor)
assert new_tensor is tensor, "to_device must be inplace operation"


def hack_offload_optimizer():
# Step 1: mock _add_accumulator
origin_add_accumulator = getattr(Optimizer, "_add_accumulator")

def new_add_accumulator(self, *args, **kwargs):
x = origin_add_accumulator(self, *args, **kwargs)
return offload(x)

setattr(Optimizer, "_add_accumulator", new_add_accumulator)

# Step 2: mock _C_ops.adamw_ and _C_ops.adamw
for name in ["adam_", "adamw_"]:
origin_op = getattr(_C_ops, name)

def new_opt_op(*args):
for arg in args:
if isinstance(arg, paddle.Tensor):
reload(arg)

ret = origin_op(*args)

for i, arg in enumerate(args):
if i >= 2 and isinstance(arg, paddle.Tensor): # do not offload parameter and gradient
offload(arg)
return ret

setattr(_C_ops, name, new_opt_op)

# Step 3: mock _insert_sync
opt_type = HybridParallelOptimizer
origin_insert_sync = getattr(opt_type, "_insert_sync")

def new_insert_sync(self, sync_var, *args, **kwargs):
origin_place = sync_var.place
reload(sync_var)
ret = origin_insert_sync(self, sync_var, *args, **kwargs)
new_sync_var = to_device(sync_var, origin_place)
assert new_sync_var is sync_var, "to_device must be inplace operation"
return ret

setattr(opt_type, "_insert_sync", new_insert_sync)