Skip to content

Commit 348fdbd

Browse files
committed
change name constant_with_rules to piecewise constant
1 parent 489bb56 commit 348fdbd

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

src/diffusers/optimization.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class SchedulerType(Enum):
3434
POLYNOMIAL = "polynomial"
3535
CONSTANT = "constant"
3636
CONSTANT_WITH_WARMUP = "constant_with_warmup"
37-
CONSTANT_WITH_RULES = "constant_with_rules"
37+
PIECEWISE_CONSTANT = "piecewise_constant"
3838

3939

4040
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -78,17 +78,17 @@ def lr_lambda(current_step: int):
7878
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
7979

8080

81-
def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoch: int = -1):
81+
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1):
8282
"""
83-
Create a schedule with a constant learning rate with rule for the learning rate.
83+
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
8484
8585
Args:
8686
optimizer ([`~torch.optim.Optimizer`]):
8787
The optimizer for which to schedule the learning rate.
88-
rule (`string`):
89-
The rules for the learning rate. ex: rules="1:10,0.1:20,0.01:30,0.005" it means that the learning rate is
90-
multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 steps
91-
and multiple 0.005 for the other steps.
88+
step_rules (`string`):
89+
The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
90+
if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
91+
steps and multiple 0.005 for the other steps.
9292
last_epoch (`int`, *optional*, defaults to -1):
9393
The index of the last epoch when resuming training.
9494
@@ -97,27 +97,27 @@ def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoc
9797
"""
9898

9999
rules_dict = {}
100-
rule_list = rules.split(",")
100+
rule_list = step_rules.split(",")
101101
for rule_str in rule_list[:-1]:
102102
value_str, steps_str = rule_str.split(":")
103103
steps = int(steps_str)
104104
value = float(value_str)
105105
rules_dict[steps] = value
106-
last_lr = float(rule_list[-1])
106+
last_lr_multiple = float(rule_list[-1])
107107

108-
def create_rules_function(rules_dict, last_lr):
108+
def create_rules_function(rules_dict, last_lr_multiple):
109109
def rule_func(steps: int) -> float:
110110
sorted_steps = sorted(rules_dict.keys())
111111
for i, sorted_step in enumerate(sorted_steps):
112112
if steps < sorted_step:
113113
return rules_dict[sorted_steps[i]]
114-
return last_lr
114+
return last_lr_multiple
115115

116116
return rule_func
117117

118-
rules_f = create_rules_function(rules_dict, last_lr)
118+
rules_func = create_rules_function(rules_dict, last_lr_multiple)
119119

120-
return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
120+
return LambdaLR(optimizer, rules_func, last_epoch=last_epoch)
121121

122122

123123
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
@@ -275,14 +275,14 @@ def lr_lambda(current_step: int):
275275
SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
276276
SchedulerType.CONSTANT: get_constant_schedule,
277277
SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
278-
SchedulerType.CONSTANT_WITH_RULES: get_constant_schedule_with_rules,
278+
SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
279279
}
280280

281281

282282
def get_scheduler(
283283
name: Union[str, SchedulerType],
284284
optimizer: Optimizer,
285-
rules: Optional[str] = None,
285+
step_rules: Optional[str] = None,
286286
num_warmup_steps: Optional[int] = None,
287287
num_training_steps: Optional[int] = None,
288288
num_cycles: int = 1,
@@ -297,6 +297,8 @@ def get_scheduler(
297297
The name of the scheduler to use.
298298
optimizer (`torch.optim.Optimizer`):
299299
The optimizer that will be used during training.
300+
step_rules (`str`, *optional*):
301+
A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
300302
num_warmup_steps (`int`, *optional*):
301303
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
302304
optional), the function will raise an error if it's unset and the scheduler type requires it.
@@ -315,8 +317,8 @@ def get_scheduler(
315317
if name == SchedulerType.CONSTANT:
316318
return schedule_func(optimizer, last_epoch=last_epoch)
317319

318-
if name == SchedulerType.CONSTANT_WITH_RULES:
319-
return schedule_func(optimizer, rules=rules, last_epoch=last_epoch)
320+
if name == SchedulerType.PIECEWISE_CONSTANT:
321+
return schedule_func(optimizer, rules=step_rules, last_epoch=last_epoch)
320322

321323
# All other schedulers require `num_warmup_steps`
322324
if num_warmup_steps is None:

0 commit comments

Comments
 (0)