@@ -34,7 +34,7 @@ class SchedulerType(Enum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
-    CONSTANT_WITH_RULES = "constant_with_rules"
+    PIECEWISE_CONSTANT = "piecewise_constant"


 def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -78,17 +78,17 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


-def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoch: int = -1):
+def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
     """
-    Create a schedule with a constant learning rate with rule for the learning rate.
+    Create a schedule with a piecewise constant learning rate, based on the rules given in `step_rules`.

     Args:
         optimizer ([`~torch.optim.Optimizer`]):
             The optimizer for which to schedule the learning rate.
-        rule (`string`):
-            The rules for the learning rate. ex: rules="1:10,0.1:20,0.01:30,0.005" it means that the learning rate is
-            multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps
-            and multiple 0.005 for the other steps.
+        step_rules (`string`):
+            The rules for the learning rate, given as comma-separated `multiplier:step` pairs followed by a final
+            multiplier. For example, step_rules="1:10,0.1:20,0.01:30,0.005" means the learning rate is multiplied
+            by 1 until step 10, by 0.1 until step 20, by 0.01 until step 30, and by 0.005 for all later steps.
         last_epoch (`int`, *optional*, defaults to -1):
             The index of the last epoch when resuming training.

@@ -97,27 +97,27 @@ def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoc
     """

     rules_dict = {}
-    rule_list = rules.split(",")
+    rule_list = step_rules.split(",")
     for rule_str in rule_list[:-1]:
         value_str, steps_str = rule_str.split(":")
         steps = int(steps_str)
         value = float(value_str)
         rules_dict[steps] = value
-    last_lr = float(rule_list[-1])
+    last_lr_multiple = float(rule_list[-1])

-    def create_rules_function(rules_dict, last_lr):
+    def create_rules_function(rules_dict, last_lr_multiple):
         def rule_func(steps: int) -> float:
             sorted_steps = sorted(rules_dict.keys())
             for i, sorted_step in enumerate(sorted_steps):
                 if steps < sorted_step:
                     return rules_dict[sorted_steps[i]]
-            return last_lr
+            return last_lr_multiple

         return rule_func

-    rules_f = create_rules_function(rules_dict, last_lr)
+    rules_func = create_rules_function(rules_dict, last_lr_multiple)

-    return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
+    return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)


 def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
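# Illustrative sketch (not part of the patch): how a `step_rules` string such as
# "1:10,0.1:20,0.01:30,0.005" is parsed above and how the resulting rule function picks a
# learning-rate multiplier. The wrapper `create_rules_function` is inlined here for brevity;
# the assertions follow the committed comparison `steps < sorted_step`.
step_rules = "1:10,0.1:20,0.01:30,0.005"
rules_dict = {}
rule_list = step_rules.split(",")
for rule_str in rule_list[:-1]:
    value_str, steps_str = rule_str.split(":")
    rules_dict[int(steps_str)] = float(value_str)  # {10: 1.0, 20: 0.1, 30: 0.01}
last_lr_multiple = float(rule_list[-1])  # 0.005

def rule_func(step: int) -> float:
    # Return the multiplier attached to the first boundary the step has not yet reached.
    for boundary in sorted(rules_dict):
        if step < boundary:
            return rules_dict[boundary]
    return last_lr_multiple

assert rule_func(0) == 1.0  # steps 0-9
assert rule_func(15) == 0.1  # steps 10-19
assert rule_func(25) == 0.01  # steps 20-29
assert rule_func(100) == 0.005  # step 30 and beyond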
@@ -275,14 +275,14 @@ def lr_lambda(current_step: int):
     SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
     SchedulerType.CONSTANT: get_constant_schedule,
     SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
-    SchedulerType.CONSTANT_WITH_RULES: get_constant_schedule_with_rules,
+    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
 }


 def get_scheduler(
     name: Union[str, SchedulerType],
     optimizer: Optimizer,
-    rules: Optional[str] = None,
+    step_rules: Optional[str] = None,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
     num_cycles: int = 1,
@@ -297,6 +297,8 @@ def get_scheduler(
             The name of the scheduler to use.
         optimizer (`torch.optim.Optimizer`):
             The optimizer that will be used during training.
+        step_rules (`str`, *optional*):
+            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
         num_warmup_steps (`int`, *optional*):
             The number of warmup steps to do. This is not required by all schedulers (hence the argument being
             optional), the function will raise an error if it's unset and the scheduler type requires it.
@@ -315,8 +317,8 @@ def get_scheduler(
     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer, last_epoch=last_epoch)

-    if name == SchedulerType.CONSTANT_WITH_RULES:
-        return schedule_func(optimizer, rules=rules, last_epoch=last_epoch)
+    if name == SchedulerType.PIECEWISE_CONSTANT:
+        return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch)

     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
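# Illustrative usage sketch (not part of the patch): driving the renamed scheduler through
# `get_scheduler`. The import path `diffusers.optimization` and the toy model/optimizer below
# are assumptions made for this example only.
import torch

from diffusers.optimization import get_scheduler

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

lr_scheduler = get_scheduler(
    "piecewise_constant",  # resolves to SchedulerType.PIECEWISE_CONSTANT
    optimizer=optimizer,
    step_rules="1:10,0.1:20,0.01:30,0.005",  # multipliers applied at the listed step boundaries
)

for _ in range(40):
    optimizer.step()
    lr_scheduler.step()  # lr follows 1e-3, then 1e-4, 1e-5, and finally 5e-6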