Skip to content

Commit 6b1520a

Browse files
authored
Merge pull request kohya-ss#1187 from kohya-ss/fix-timeemb
fix sdxl timestep embedding
2 parents 2d73891 + f811b11 commit 6b1520a

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,16 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
249249

250250
## Change History
251251

252+
### Mar 15, 2024 / 2024/3/15: v0.8.5
253+
254+
- Fixed a bug that the value of timestep embedding during SDXL training was incorrect.
255+
- The inference with the generation script is also fixed.
256+
- The impact is unknown, but please update for SDXL training.
257+
258+
- SDXL 学習時の timestep embedding の値が誤っていたのを修正しました。
259+
- 生成スクリプトでの推論時についてもあわせて修正しました。
260+
- 影響の度合いは不明ですが、SDXL の学習時にはアップデートをお願いいたします。
261+
252262
### Feb 24, 2024 / 2024/2/24: v0.8.4
253263

254264
- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!

library/sdxl_original_unet.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
from torch.nn import functional as F
3232
from einops import rearrange
3333
from .utils import setup_logging
34+
3435
setup_logging()
3536
import logging
37+
3638
logger = logging.getLogger(__name__)
3739

3840
IN_CHANNELS: int = 4
@@ -1074,7 +1076,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
10741076
timesteps = timesteps.expand(x.shape[0])
10751077

10761078
hs = []
1077-
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
1079+
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
10781080
t_emb = t_emb.to(x.dtype)
10791081
emb = self.time_embed(t_emb)
10801082

@@ -1132,7 +1134,7 @@ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
11321134
# call original model's methods
11331135
def __getattr__(self, name):
11341136
return getattr(self.delegate, name)
1135-
1137+
11361138
def __call__(self, *args, **kwargs):
11371139
return self.delegate(*args, **kwargs)
11381140

@@ -1164,7 +1166,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
11641166
timesteps = timesteps.expand(x.shape[0])
11651167

11661168
hs = []
1167-
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
1169+
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
11681170
t_emb = t_emb.to(x.dtype)
11691171
emb = _self.time_embed(t_emb)
11701172

0 commit comments

Comments
 (0)