
update packing #3751

Closed · wants to merge 12 commits

Commit 986fbde40a9a2ecc31c8d937efb3a37ac835244d: update
Jintao-Huang committed Apr 5, 2025

@@ -90,7 +90,7 @@
 "    save_steps=50,\n",
 "    eval_strategy='steps',\n",
 "    eval_steps=50,\n",
-"    gradient_accumulation_steps=2,\n",
+"    gradient_accumulation_steps=16,\n",
 "    num_train_epochs=1,\n",
 "    metric_for_best_model='loss',\n",
 "    save_total_limit=2,\n",
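
Note: this hunk raises `gradient_accumulation_steps` from 2 to 16, so sixteen micro-batches now feed each optimizer step instead of two, multiplying the effective batch size accordingly. A minimal runnable sketch of the mechanism; the model, the data, and the micro-batch size of 2 are illustrative assumptions, since the per-device batch size is not visible in this hunk:

import torch
from torch import nn

# Sketch of what gradient_accumulation_steps=16 does inside the trainer:
# gradients from 16 micro-batches are summed before one optimizer step.
# The model, data, and micro-batch size (2) are dummy assumptions.
gradient_accumulation_steps = 16
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
micro_batches = [torch.randn(2, 4) for _ in range(32)]

for step, x in enumerate(micro_batches):
    loss = model(x).pow(2).mean() / gradient_accumulation_steps  # scale so summed grads average
    loss.backward()                                              # grads accumulate in .grad
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()       # one parameter update per 16 micro-batches
        optimizer.zero_grad()

# Effective batch size per update = micro_batch * accumulation * world_size
# = 2 * 16 * 1 = 32 samples under these assumptions.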

swift/trainers/rlhf_trainer/rlhf_mixin.py (0 additions, 27 deletions)

@@ -100,38 +100,11 @@ def __init__(self,
         self.is_vision_model = False
         self.label_pad_token_id = -100
         self.use_dpo_data_collator = True
-        if args.use_liger_kernel:
-            for m in [model, ref_model]:
-                if m is not None:
-                    self._apply_liger(m)
-
         if is_deepspeed_zero3_enabled() and ref_model is not None:
             model = ModelWrapper(model, ref_model)
-        args.use_liger_kernel = False  # compat zero3
         super().__init__(model, *_args, **kwargs)
-        args.use_liger_kernel = True  # recover
         self.padding_value = self.tokenizer.pad_token_id
 
-    @staticmethod
-    def _apply_liger(model):
-        # copy from transformers trainer
-        from transformers.utils import is_liger_kernel_available
-        if is_liger_kernel_available():
-            from liger_kernel.transformers import _apply_liger_kernel_to_instance
-
-            if isinstance(model, PreTrainedModel):
-                # Patch the model with liger kernels. Use the default kernel configurations.
-                _apply_liger_kernel_to_instance(model=model)
-            elif hasattr(model, 'get_base_model') and isinstance(model.get_base_model(), PreTrainedModel):
-                # Patch the base model with liger kernels where model is a PeftModel.
-                # Use the default kernel configurations.
-                _apply_liger_kernel_to_instance(model=model.get_base_model())
-            else:
-                logger.warning('The model is not an instance of PreTrainedModel. No liger kernels will be applied.')
-        else:
-            raise ImportError('You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. '
-                              'Please install it with `pip install liger-kernel`')
-
     def _save_checkpoint(self, model, *args, **kwargs):
         context = nullcontext()
         if hasattr(model, '_save_load_context'):
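
Note: the surviving ZeRO-3 branch bundles the policy model and the frozen reference model into one `ModelWrapper` before `super().__init__`, so the single DeepSpeed engine created by the parent trainer partitions both together. A hypothetical reconstruction of such a wrapper, for orientation only; the real `ModelWrapper` is defined elsewhere in swift and may differ:

import torch.nn as nn

# Hypothetical reconstruction, not swift's actual ModelWrapper: bundle the
# trainable policy model and the frozen reference model into one nn.Module
# so DeepSpeed ZeRO-3 initializes and partitions both in the same engine.
class ModelWrapper(nn.Module):

    def __init__(self, model: nn.Module, ref_model: nn.Module):
        super().__init__()
        self.model = model          # trainable policy model
        self.ref_model = ref_model  # frozen reference model (KL / DPO baseline)

    def forward(self, *args, **kwargs):
        # Delegate to the policy model; callers reach the reference model
        # through the `ref_model` attribute when computing reference logps.
        return self.model(*args, **kwargs)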
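
Note: the deleted `_apply_liger` helper was a copy of the liger-kernel patching logic in the transformers `Trainer`, and the removed `args.use_liger_kernel` toggle existed only to keep the parent trainer from re-applying it around `super().__init__`. After this commit the mixin presumably leaves kernel patching to the parent trainer entirely. For reference, a minimal sketch of the same patching call the helper used; the model id is an illustrative assumption and `liger-kernel` must be installed:

from transformers import AutoModelForCausalLM
from liger_kernel.transformers import _apply_liger_kernel_to_instance

# Illustrative model id; any architecture supported by liger-kernel works.
model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')
_apply_liger_kernel_to_instance(model=model)  # patch in place with the default kernel configurations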