Skip to content

[Integration] add swanlab logger #10594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions paddlenlp/trainer/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def is_wandb_available():
return importlib.util.find_spec("wandb") is not None


def is_swanlab_available():
    """Return ``True`` when the ``swanlab`` package can be located for import."""
    spec = importlib.util.find_spec("swanlab")
    return spec is not None


def is_ray_available():
    """Return ``True`` when the ``ray.air`` module can be located for import."""
    spec = importlib.util.find_spec("ray.air")
    return spec is not None

Expand All @@ -55,6 +59,8 @@ def get_available_reporting_integrations():
integrations.append("wandb")
if is_tensorboardX_available():
integrations.append("tensorboard")
if is_swanlab_available():
integrations.append("swanlab")

return integrations

Expand Down Expand Up @@ -395,6 +401,85 @@ def on_save(self, args, state, control, **kwargs):
self._wandb.log_artifact(artifact, aliases=[f"checkpoint-{state.global_step}"])


class SwanLabCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that logs metrics and media to [SwanLab](https://swanlab.cn/).

    Raises:
        RuntimeError: at construction time if the ``swanlab`` package is not installed.
    """

    def __init__(self):
        # Fail fast: a misconfigured environment is reported at construction
        # time instead of silently skipping logging during training.
        # (The original double-checked availability after already raising.)
        if not is_swanlab_available():
            raise RuntimeError("SwanlabCallback requires swanlab to be installed. Run `pip install swanlab`.")
        import swanlab

        self._swanlab = swanlab
        # The swanlab run is created lazily in `setup()` the first time a hook fires.
        self._initialized = False

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional SwanLab integration.

        One can subclass and override this method to customize the setup if needed.

        Environment variables:
        - **SWANLAB_MODE** (`str`, *optional*, defaults to `"cloud"`):
            Whether to use swanlab cloud, local or disabled. Set `SWANLAB_MODE="local"` to use local.
            Set `SWANLAB_MODE="disabled"` to disable.
        - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `"PaddleNLP"`):
            Set this to a custom string to store results in a different project.
        """
        # Guard kept for subclasses that may null out `_swanlab`; after a
        # successful `__init__` it is always set.
        if self._swanlab is None:
            return

        self._initialized = True

        # Only the main process talks to SwanLab; other ranks stay silent.
        if state.is_world_process_zero:
            logger.info('Automatic Swanlab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')

            # Training arguments take precedence over model config on key clashes.
            combined_dict = {**args.to_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}

            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                # Hyper-parameter search: one run per trial, grouped by run_name.
                init_args["name"] = trial_name
                init_args["group"] = args.run_name
            else:
                if not (args.run_name is None or args.run_name == args.output_dir):
                    init_args["name"] = args.run_name
            init_args["dir"] = args.logging_dir
            # Reuse an already-active run (e.g. one the user started manually).
            if self._swanlab.get_run() is None:
                self._swanlab.init(
                    project=os.getenv("SWANLAB_PROJECT", "PaddleNLP"),
                    **init_args,
                )
            self._swanlab.config.update(combined_dict, allow_val_change=True)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._swanlab is None:
            return
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        # Intentionally a no-op beyond the guard: the swanlab run is finished by
        # the caller (or at interpreter exit). Kept for interface symmetry.
        if self._swanlab is None:
            return

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._swanlab is None:
            return
        if not self._initialized:
            # Pass kwargs through, consistent with `on_train_begin`.
            self.setup(args, state, model, **kwargs)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._swanlab.log({**logs, "train/global_step": state.global_step}, step=state.global_step)


class AutoNLPCallback(TrainerCallback):
"""
A [`TrainerCallback`] that sends the logs to [`Ray Tune`] for [`AutoNLP`]
Expand Down Expand Up @@ -423,6 +508,7 @@ def on_evaluate(self, args, state, control, **kwargs):
"autonlp": AutoNLPCallback,
"wandb": WandbCallback,
"tensorboard": TensorBoardCallback,
"swanlab": SwanLabCallback,
}


Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class TrainingArguments:
instance of `Dataset`.
report_to (`str` or `List[str]`, *optional*, defaults to `"visualdl"`):
The list of integrations to report the results and logs to.
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`.
Supported platforms are `"visualdl"`/`"wandb"`/`"tensorboard"`/`"swanlab"`.
`"none"` for no integrations.
ddp_find_unused_parameters (`bool`, *optional*):
When using distributed training, the value of the flag `find_unused_parameters` passed to
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ rouge
tiktoken
visualdl
wandb
swanlab
tensorboard
tensorboardX
modelscope
Expand Down
31 changes: 31 additions & 0 deletions tests/trainer/test_trainer_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from paddlenlp.trainer import TrainerControl, TrainerState, TrainingArguments
from paddlenlp.trainer.integrations import (
SwanLabCallback,
TensorBoardCallback,
VisualDLCallback,
WandbCallback,
Expand Down Expand Up @@ -66,6 +67,36 @@ def test_wandbcallback(self):
shutil.rmtree(output_dir)


class TestSwanlabCallback(unittest.TestCase):
    def test_swanlabcallback(self):
        """SwanLabCallback initializes lazily on_train_begin and accepts logged metrics with SWANLAB_MODE=disabled."""
        output_dir = tempfile.mkdtemp()
        # Register cleanups up front so the temp dir and env var are removed
        # even when an assertion below fails (the original leaked both on failure).
        self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)
        os.environ["SWANLAB_MODE"] = "disabled"
        self.addCleanup(os.environ.pop, "SWANLAB_MODE", None)

        args = TrainingArguments(
            output_dir=output_dir,
            max_steps=200,
            logging_steps=20,
            run_name="test_swanlabcallback",
            logging_dir=output_dir,
        )
        state = TrainerState(trial_name="PaddleNLP")
        control = TrainerControl()
        config = RegressionModelConfig(a=1, b=1)
        model = RegressionPretrainedModel(config)

        swanlabcallback = SwanLabCallback()
        self.assertFalse(swanlabcallback._initialized)
        swanlabcallback.on_train_begin(args, state, control)
        self.assertTrue(swanlabcallback._initialized)
        for global_step in range(args.max_steps):
            state.global_step = global_step
            if global_step % args.logging_steps == 0:
                log = {"loss": 100 - 0.4 * global_step, "learning_rate": 0.1, "global_step": global_step}
                swanlabcallback.on_log(args, state, control, logs=log)
        swanlabcallback.on_train_end(args, state, control, model=model)
        swanlabcallback._swanlab.finish()


class TestTensorboardCallback(unittest.TestCase):
def test_tbcallback(self):
output_dir = tempfile.mkdtemp()
Expand Down
Loading