We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 947f915 commit 426bc17Copy full SHA for 426bc17
models/TimeMixer.py
@@ -11,15 +11,15 @@ class DFT_series_decomp(nn.Module):
11
Series decomposition block
12
"""
13
14
- def __init__(self, top_k=5):
+ def __init__(self, top_k: int = 5):
15
super(DFT_series_decomp, self).__init__()
16
self.top_k = top_k
17
18
def forward(self, x):
19
xf = torch.fft.rfft(x)
20
freq = abs(xf)
21
freq[0] = 0
22
- top_k_freq, top_list = torch.topk(freq, 5)
+ top_k_freq, top_list = torch.topk(freq, k=self.top_k)
23
xf[freq <= top_k_freq.min()] = 0
24
x_season = torch.fft.irfft(xf)
25
x_trend = x - x_season
0 commit comments