Skip to content

Commit 39c1e29

Browse files
erictang000hiyouga
andauthored
[ray] allow for specifying ray.init kwargs (i.e. runtime_env) (hiyouga#7647)
* ray init kwargs * Update trainer_utils.py * fix ray args --------- Co-authored-by: hoshi-hiyouga <[email protected]>
1 parent ee840b4 commit 39c1e29

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

examples/train_lora/llama3_lora_sft_ray.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,16 @@ report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow]
3131
### ray
3232
ray_run_name: llama3_8b_sft_lora
3333
ray_storage_path: ./saves
34-
ray_num_workers: 4 # number of GPUs to use
34+
ray_num_workers: 4 # Number of GPUs to use.
35+
placement_strategy: PACK
3536
resources_per_worker:
3637
GPU: 1
37-
placement_strategy: PACK
38+
# ray_init_kwargs:
39+
# runtime_env:
40+
# env_vars:
41+
# <YOUR-ENV-VAR-HERE>: "<YOUR-ENV-VAR-HERE>"
42+
# pip:
43+
# - emoji
3844

3945
### train
4046
per_device_train_batch_size: 1

src/llamafactory/hparams/training_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ class RayArguments:
4646
default="PACK",
4747
metadata={"help": "The placement strategy for Ray training. Default is PACK."},
4848
)
49+
ray_init_kwargs: Optional[dict] = field(
50+
default=None,
51+
metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
52+
)
4953

5054
def __post_init__(self):
5155
self.use_ray = use_ray()

src/llamafactory/train/trainer_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949

5050
if is_ray_available():
51+
import ray
5152
from ray.train import RunConfig, ScalingConfig
5253
from ray.train.torch import TorchTrainer
5354

@@ -644,6 +645,9 @@ def get_ray_trainer(
644645
if not ray_args.use_ray:
645646
raise ValueError("Ray was not enabled. Please set `USE_RAY=1` to enable ray.")
646647

648+
if ray_args.ray_init_kwargs is not None:
649+
ray.init(**ray_args.ray_init_kwargs)
650+
647651
trainer = TorchTrainer(
648652
training_function,
649653
train_loop_config=train_loop_config,

0 commit comments

Comments
 (0)