@@ -106,7 +106,7 @@ def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoc
         rules_dict[steps] = value
     last_lr = float(rule_list[-1])

-    def create_rules_function():
+    def create_rules_function(rules_dict, last_lr):
         def rule_func(steps: int) -> float:
             sorted_steps = sorted(rules_dict.keys())
             for i, sorted_step in enumerate(sorted_steps):
@@ -116,7 +116,7 @@ def rule_func(steps: int) -> float:

         return rule_func

-    rules_f = create_rules_function()
+    rules_f = create_rules_function(rules_dict, last_lr)

     return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)

@@ -283,6 +283,7 @@ def lr_lambda(current_step: int):
 def get_scheduler(
     name: Union[str, SchedulerType],
     optimizer: Optimizer,
+    rules: Optional[str] = None,
     num_warmup_steps: Optional[int] = None,
     num_training_steps: Optional[int] = None,
     num_cycles: int = 1,
@@ -315,6 +316,9 @@ def get_scheduler(
     if name == SchedulerType.CONSTANT:
         return schedule_func(optimizer, last_epoch=last_epoch)

+    if name == SchedulerType.CONSTANT_WITH_RULES:
+        return schedule_func(optimizer, rules=rules, last_epoch=last_epoch)
+
     # All other schedulers require `num_warmup_steps`
     if num_warmup_steps is None:
         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
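A minimal usage sketch of the new code path wired up by these hunks. Two details are assumptions, since this diff does not show them: that the string alias "constant_with_rules" resolves to SchedulerType.CONSTANT_WITH_RULES, and that the rules string is a comma-separated list of "multiplier:step" pairs ending with the multiplier used after the last boundary (inferred from how rules_dict and last_lr are built above). The import path is also hypothetical.

    import torch
    from torch.optim import AdamW

    # get_scheduler comes from the module patched above; the import path is a placeholder.
    from optimization import get_scheduler

    model = torch.nn.Linear(16, 2)
    optimizer = AdamW(model.parameters(), lr=1e-3)

    # Assumed rules format: full LR for the first 1000 steps, 0.1x until step 2000,
    # then 0.01x for the rest of training (each value is a multiplier on the base LR,
    # since rule_func feeds LambdaLR).
    scheduler = get_scheduler(
        "constant_with_rules",  # assumed string alias for SchedulerType.CONSTANT_WITH_RULES
        optimizer=optimizer,
        rules="1:1000,0.1:2000,0.01",
    )

    for _ in range(3000):
        optimizer.step()
        scheduler.step()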