
 import math
 from enum import Enum
-from typing import Optional, Union
+from typing import Callable, Optional, Union

 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
@@ -34,6 +34,7 @@ class SchedulerType(Enum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    CONSTANT_WITH_RULES = "constant_with_rules"


 def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -77,6 +78,49 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


+def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate that is scaled by a set of piecewise rules.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        rules (`str`):
+            The rules for scaling the learning rate, given as comma-separated `multiplier:step` pairs
+            followed by a final multiplier. For example, rules="1:10,0.1:20,0.01:30,0.005" means the
+            learning rate is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30,
+            and by 0.005 for all remaining steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    rules_dict = {}
+    rule_list = rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr = float(rule_list[-1])
+
+    def create_rules_function():
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr
+
+        return rule_func
+
+    rules_f = create_rules_function()
+
+    return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
+
+
 def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
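Below is a minimal usage sketch of the `get_constant_schedule_with_rules` added above. The import path, toy model, base learning rate, and rule string are illustrative assumptions, not part of this commit:

import torch
from torch.optim import SGD

# Hypothetical import; adjust to wherever this optimization module lives in your codebase.
from optimization import get_constant_schedule_with_rules

model = torch.nn.Linear(4, 2)                # toy model, for illustration only
optimizer = SGD(model.parameters(), lr=0.1)  # base lr that the rule multipliers scale

# Multiply the base lr by 1.0 until step 10, by 0.1 until step 20, and by 0.01 afterwards.
scheduler = get_constant_schedule_with_rules(optimizer, rules="1:10,0.1:20,0.01")

for _ in range(30):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # ~[0.001]: base lr 0.1 scaled by the final multiplier 0.01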