Skip to content

Commit ab7b231

Browse files
committed
init
1 parent 29177d2 commit ab7b231

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

library/train_util.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2994,7 +2994,7 @@ def int_or_float(value):
29942994
"--optimizer_type",
29952995
type=str,
29962996
default="",
2997-
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
2997+
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
29982998
)
29992999

30003000
# backward compatibility
@@ -4032,7 +4032,7 @@ def task():
40324032

40334033

40344034
def get_optimizer(args, trainable_params):
4035-
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
4035+
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
40364036

40374037
optimizer_type = args.optimizer_type
40384038
if args.use_8bit_adam:
@@ -4141,7 +4141,22 @@ def get_optimizer(args, trainable_params):
41414141
raise AttributeError(
41424142
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
41434143
)
4144-
4144+
elif optimizer_type == "Ademamix8bit".lower():
4145+
logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}")
4146+
try:
4147+
optimizer_class = bnb.optim.AdEMAMix8bit
4148+
except AttributeError:
4149+
raise AttributeError(
4150+
"No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
4151+
)
4152+
elif optimizer_type == "PagedAdemamix8bit".lower():
4153+
logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}")
4154+
try:
4155+
optimizer_class = bnb.optim.PagedAdEMAMix8bit
4156+
except AttributeError:
4157+
raise AttributeError(
4158+
"No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
4159+
)
41454160
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
41464161

41474162
elif optimizer_type == "PagedAdamW".lower():

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ ftfy==6.1.1
66
opencv-python==4.8.1.78
77
einops==0.7.0
88
pytorch-lightning==1.9.0
9-
bitsandbytes==0.43.0
9+
bitsandbytes==0.44.0
1010
prodigyopt==1.0
1111
lion-pytorch==0.0.6
1212
tensorboard

0 commit comments

Comments
 (0)