Skip to content

Commit 1ad3520

Browse files
committed
add constant lr rate with rule
1 parent 8dde211 commit 1ad3520

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/optimization.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoc
106106
rules_dict[steps] = value
107107
last_lr = float(rule_list[-1])
108108

109-
def create_rules_function():
109+
def create_rules_function(rules_dict, last_lr):
110110
def rule_func(steps: int) -> float:
111111
sorted_steps = sorted(rules_dict.keys())
112112
for i, sorted_step in enumerate(sorted_steps):
@@ -116,7 +116,7 @@ def rule_func(steps: int) -> float:
116116

117117
return rule_func
118118

119-
rules_f = create_rules_function()
119+
rules_f = create_rules_function(rules_dict, last_lr)
120120

121121
return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
122122

@@ -283,6 +283,7 @@ def lr_lambda(current_step: int):
283283
def get_scheduler(
284284
name: Union[str, SchedulerType],
285285
optimizer: Optimizer,
286+
rules: Optional[str] = None,
286287
num_warmup_steps: Optional[int] = None,
287288
num_training_steps: Optional[int] = None,
288289
num_cycles: int = 1,
@@ -315,6 +316,9 @@ def get_scheduler(
315316
if name == SchedulerType.CONSTANT:
316317
return schedule_func(optimizer, last_epoch=last_epoch)
317318

319+
if name == SchedulerType.CONSTANT_WITH_RULES:
320+
return schedule_func(optimizer, rules=rules, last_epoch=last_epoch)
321+
318322
# All other schedulers require `num_warmup_steps`
319323
if num_warmup_steps is None:
320324
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

0 commit comments

Comments
 (0)