@@ -12,23 +12,26 @@ class Sequence(nn.Module):
12
12
def __init__ (self ):
13
13
super (Sequence , self ).__init__ ()
14
14
self .lstm1 = nn .LSTMCell (1 , 51 )
15
- self .lstm2 = nn .LSTMCell (51 , 1 )
15
+ self .lstm2 = nn .LSTMCell (51 , 51 )
16
+ self .linear = nn .Linear (51 , 1 )
16
17
17
18
def forward (self , input , future = 0 ):
18
19
outputs = []
19
20
h_t = Variable (torch .zeros (input .size (0 ), 51 ).double (), requires_grad = False )
20
21
c_t = Variable (torch .zeros (input .size (0 ), 51 ).double (), requires_grad = False )
21
- h_t2 = Variable (torch .zeros (input .size (0 ), 1 ).double (), requires_grad = False )
22
- c_t2 = Variable (torch .zeros (input .size (0 ), 1 ).double (), requires_grad = False )
22
+ h_t2 = Variable (torch .zeros (input .size (0 ), 51 ).double (), requires_grad = False )
23
+ c_t2 = Variable (torch .zeros (input .size (0 ), 51 ).double (), requires_grad = False )
23
24
24
25
for i , input_t in enumerate (input .chunk (input .size (1 ), dim = 1 )):
25
26
h_t , c_t = self .lstm1 (input_t , (h_t , c_t ))
26
27
h_t2 , c_t2 = self .lstm2 (h_t , (h_t2 , c_t2 ))
27
- outputs += [h_t2 ]
28
+ output = self .linear (h_t2 )
29
+ outputs += [output ]
28
30
for i in range (future ):# if we should predict the future
29
- h_t , c_t = self .lstm1 (h_t2 , (h_t , c_t ))
31
+ h_t , c_t = self .lstm1 (output , (h_t , c_t ))
30
32
h_t2 , c_t2 = self .lstm2 (h_t , (h_t2 , c_t2 ))
31
- outputs += [h_t2 ]
33
+ output = self .linear (h_t2 )
34
+ outputs += [output ]
32
35
outputs = torch .stack (outputs , 1 ).squeeze (2 )
33
36
return outputs
34
37
0 commit comments