Commit 4184eb1

add constant lr with rules
1 parent cfc99ad commit 4184eb1

File tree

1 file changed: +45 -1 lines changed


src/diffusers/optimization.py

Lines changed: 45 additions & 1 deletion
@@ -16,7 +16,7 @@
 
 import math
 from enum import Enum
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LambdaLR
@@ -34,6 +34,7 @@ class SchedulerType(Enum):
     POLYNOMIAL = "polynomial"
     CONSTANT = "constant"
     CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    CONSTANT_WITH_RULES = "constant_with_rules"
 
 
 def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
@@ -77,6 +78,49 @@ def lr_lambda(current_step: int):
     return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
 
 
+def get_constant_schedule_with_rules(optimizer: Optimizer, rules: str, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate, modified at specific steps according to the given rules.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        rules (`string`):
+            The rules for the learning rate.
+            ex: rules="1:10,0.1:20,0.01:30,0.005"
+            It means that the learning rate is multiplied by 1 until step 10, by 0.1 until step 20, by 0.01
+            until step 30, and by 0.005 for all remaining steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    rules_dict = {}
+    rule_list = rules.split(",")
+    for rule_str in rule_list[:-1]:
+        value_str, steps_str = rule_str.split(":")
+        steps = int(steps_str)
+        value = float(value_str)
+        rules_dict[steps] = value
+    last_lr = float(rule_list[-1])
+
+    def create_rules_function():
+        def rule_func(steps: int) -> float:
+            sorted_steps = sorted(rules_dict.keys())
+            for i, sorted_step in enumerate(sorted_steps):
+                if steps < sorted_step:
+                    return rules_dict[sorted_steps[i]]
+            return last_lr
+
+        return rule_func
+
+    rules_f = create_rules_function()
+
+    return LambdaLR(optimizer, rules_f, last_epoch=last_epoch)
+
+
 def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
     """
     Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
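For reference, a minimal usage sketch (not part of the commit) showing how the new get_constant_schedule_with_rules could be driven. The toy model, the base learning rate of 0.1, and the rule string are illustrative assumptions; the import path assumes the package layout implied by src/diffusers/optimization.py.

    # Usage sketch only; values below are illustrative, not from the commit.
    import torch
    from diffusers.optimization import get_constant_schedule_with_rules

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # "1:10,0.1:20,0.01:30,0.005": multiplier 1 until step 10, 0.1 until step 20,
    # 0.01 until step 30, then 0.005 for every later step.
    scheduler = get_constant_schedule_with_rules(optimizer, rules="1:10,0.1:20,0.01:30,0.005")

    for step in range(35):
        optimizer.step()   # normally preceded by loss.backward()
        scheduler.step()   # advance the piecewise-constant multiplier
        if step in (9, 10, 19, 20, 29, 30):
            print(step, scheduler.get_last_lr())  # base lr (0.1) times the current multiplier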
