Skip to content

Commit fbd3d48

Browse files
committed
push-to-hub fix
1 parent 06dca18 commit fbd3d48

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

run_clm_streaming.sh

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#! /bin/bash
22
./run_clm_streaming_flax_v2.py \
3-
--output_dir $HOME/gpt-neo-125M-code-clippy \
3+
--output_dir $HOME/gpt-neo-125M-test \
44
--model_name_or_path="EleutherAI/gpt-neo-125M" \
55
--dataset_name $HOME/gpt-code-clippy/code_clippy.py \
66
--data_dir /home/shared/code-clippy-dataset/merged-data \
@@ -11,27 +11,27 @@
1111
--per_device_eval_batch_size="16" \
1212
--preprocessing_num_workers="8" \
1313
--learning_rate="6e-4" \
14-
--adafactor \
15-
--max_steps 10000 \
16-
--warmup_steps 3000 \
17-
--decay_steps 5000 \
14+
--max_steps 500 \
15+
--warmup_steps 150 \
16+
--decay_steps 250 \
1817
--adam_beta1="0.9" \
1918
--adam_beta2="0.95" \
2019
--weight_decay="0.01" \
2120
--overwrite_output_dir \
22-
--logging_steps="100" \
23-
--eval_steps="100" \
24-
--push_to_hub="False" \
21+
--logging_steps="10" \
22+
--eval_steps="50" \
23+
--push_to_hub="True" \
2524
--report_to="all" \
2625
--dtype="bfloat16" \
2726
--skip_memory_metrics="False" \
28-
--save_steps="100" \
27+
--save_steps="50" \
2928
--save_total_limit 2 \
3029
--gradient_accumulation_steps 8 \
3130
--report_to="wandb" \
32-
--run_name="testing" \
31+
--run_name="testing-mini" \
3332
--max_eval_samples 100 \
3433
--save_optimizer true \
34+
# --adafactor \
3535
# --resume_from_checkpoint $HOME/gpt-neo-125M-code-clippy/ \
3636
# --max_train_samples="10000" \
3737

run_clm_streaming_flax_v2.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -807,7 +807,7 @@ def eval_step(params, batch):
807807
push_to_hub=training_args.push_to_hub)
808808
if model_args.save_optimizer:
809809
# this saves full state including optimizer
810-
save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=False)
810+
save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=True)
811811
if training_args.save_total_limit is not None:
812812
rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
813813

0 commit comments

Comments (0)