Skip to content

Commit 426bc17

Browse files
authored
Fix DFT_series_decomp: use the top_k parameter
1 parent 947f915 commit 426bc17

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

models/TimeMixer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ class DFT_series_decomp(nn.Module):
1111
Series decomposition block
1212
"""
1313

14-
def __init__(self, top_k=5):
14+
def __init__(self, top_k: int = 5):
1515
super(DFT_series_decomp, self).__init__()
1616
self.top_k = top_k
1717

1818
def forward(self, x):
1919
xf = torch.fft.rfft(x)
2020
freq = abs(xf)
2121
freq[0] = 0
22-
top_k_freq, top_list = torch.topk(freq, 5)
22+
top_k_freq, top_list = torch.topk(freq, k=self.top_k)
2323
xf[freq <= top_k_freq.min()] = 0
2424
x_season = torch.fft.irfft(xf)
2525
x_trend = x - x_season

0 commit comments

Comments
 (0)