 from torch.nn import functional as F
 from einops import rearrange
 from .utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)

 IN_CHANNELS: int = 4
@@ -1074,7 +1076,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         timesteps = timesteps.expand(x.shape[0])

         hs = []
-        t_emb = get_timestep_embedding(timesteps, self.model_channels)  # , repeat_only=False)
+        t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)  # , repeat_only=False)
         t_emb = t_emb.to(x.dtype)
         emb = self.time_embed(t_emb)

@@ -1132,7 +1134,7 @@ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
     # call original model's methods
     def __getattr__(self, name):
         return getattr(self.delegate, name)
-
+
     def __call__(self, *args, **kwargs):
         return self.delegate(*args, **kwargs)

@@ -1164,7 +1166,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
         timesteps = timesteps.expand(x.shape[0])

         hs = []
-        t_emb = get_timestep_embedding(timesteps, _self.model_channels)  # , repeat_only=False)
+        t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0)  # , repeat_only=False)
         t_emb = t_emb.to(x.dtype)
         emb = _self.time_embed(t_emb)
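
Both hunks make the same change: `get_timestep_embedding` is now called with `downscale_freq_shift=0`. For context, below is a minimal sketch of a diffusers-style sinusoidal timestep embedding showing what that argument controls; the helper name and body are illustrative, not this repository's implementation. With a shift of 0 the frequency exponent is divided by `half_dim` (the original DDPM/SDXL formulation), while the diffusers default of 1 divides by `half_dim - 1`.

import math
import torch

def sinusoidal_timestep_embedding(timesteps, embedding_dim, downscale_freq_shift=0.0, max_period=10000):
    # Illustrative sketch only (hypothetical helper, not the repo's get_timestep_embedding).
    # downscale_freq_shift changes only the denominator of the frequency exponent:
    #   shift=1 (diffusers default) -> divide by half_dim - 1
    #   shift=0                     -> divide by half_dim (original DDPM/SDXL formulation)
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)
    args = timesteps[:, None].float() * torch.exp(exponent)[None, :]
    # shape (batch, embedding_dim); sin/cos ordering may differ between implementations
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

# usage sketch, e.g. with SDXL's model_channels of 320:
# emb = sinusoidal_timestep_embedding(torch.tensor([0, 500, 999]), 320, downscale_freq_shift=0)

Passing `downscale_freq_shift=0` at both call sites therefore only rescales the sinusoid frequencies slightly, presumably to match the reference SDXL timestep embedding rather than the diffusers default.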