@@ -34,6 +34,7 @@ class SchedulerType(Enum):
3434 POLYNOMIAL = "polynomial"
3535 CONSTANT = "constant"
3636 CONSTANT_WITH_WARMUP = "constant_with_warmup"
37+ PIECEWISE_CONSTANT = "piecewise_constant"
3738
3839
3940def get_constant_schedule (optimizer : Optimizer , last_epoch : int = - 1 ):
@@ -77,6 +78,48 @@ def lr_lambda(current_step: int):
7778 return LambdaLR (optimizer , lr_lambda , last_epoch = last_epoch )
7879
7980
81+ def get_piecewise_constant_schedule (optimizer : Optimizer , step_rules : str , last_epoch : int = - 1 ):
82+ """
83+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
84+
85+ Args:
86+ optimizer ([`~torch.optim.Optimizer`]):
87+ The optimizer for which to schedule the learning rate.
88+ step_rules (`string`):
89+ The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
90+ if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
91+ steps and multiple 0.005 for the other steps.
92+ last_epoch (`int`, *optional*, defaults to -1):
93+ The index of the last epoch when resuming training.
94+
95+ Return:
96+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
97+ """
98+
99+ rules_dict = {}
100+ rule_list = step_rules .split ("," )
101+ for rule_str in rule_list [:- 1 ]:
102+ value_str , steps_str = rule_str .split (":" )
103+ steps = int (steps_str )
104+ value = float (value_str )
105+ rules_dict [steps ] = value
106+ last_lr_multiple = float (rule_list [- 1 ])
107+
108+ def create_rules_function (rules_dict , last_lr_multiple ):
109+ def rule_func (steps : int ) -> float :
110+ sorted_steps = sorted (rules_dict .keys ())
111+ for i , sorted_step in enumerate (sorted_steps ):
112+ if steps < sorted_step :
113+ return rules_dict [sorted_steps [i ]]
114+ return last_lr_multiple
115+
116+ return rule_func
117+
118+ rules_func = create_rules_function (rules_dict , last_lr_multiple )
119+
120+ return LambdaLR (optimizer , rules_func , last_epoch = last_epoch )
121+
122+
80123def get_linear_schedule_with_warmup (optimizer , num_warmup_steps , num_training_steps , last_epoch = - 1 ):
81124 """
82125 Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
@@ -232,12 +275,14 @@ def lr_lambda(current_step: int):
232275 SchedulerType .POLYNOMIAL : get_polynomial_decay_schedule_with_warmup ,
233276 SchedulerType .CONSTANT : get_constant_schedule ,
234277 SchedulerType .CONSTANT_WITH_WARMUP : get_constant_schedule_with_warmup ,
278+ SchedulerType .PIECEWISE_CONSTANT : get_piecewise_constant_schedule ,
235279}
236280
237281
238282def get_scheduler (
239283 name : Union [str , SchedulerType ],
240284 optimizer : Optimizer ,
285+ step_rules : Optional [str ] = None ,
241286 num_warmup_steps : Optional [int ] = None ,
242287 num_training_steps : Optional [int ] = None ,
243288 num_cycles : int = 1 ,
@@ -252,6 +297,8 @@ def get_scheduler(
252297 The name of the scheduler to use.
253298 optimizer (`torch.optim.Optimizer`):
254299 The optimizer that will be used during training.
300+ step_rules (`str`, *optional*):
301+ A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
255302 num_warmup_steps (`int`, *optional*):
256303 The number of warmup steps to do. This is not required by all schedulers (hence the argument being
257304 optional), the function will raise an error if it's unset and the scheduler type requires it.
@@ -270,6 +317,9 @@ def get_scheduler(
270317 if name == SchedulerType .CONSTANT :
271318 return schedule_func (optimizer , last_epoch = last_epoch )
272319
320+ if name == SchedulerType .PIECEWISE_CONSTANT :
321+ return schedule_func (optimizer , rules = step_rules , last_epoch = last_epoch )
322+
273323 # All other schedulers require `num_warmup_steps`
274324 if num_warmup_steps is None :
275325 raise ValueError (f"{ name } requires `num_warmup_steps`, please provide that argument." )
0 commit comments