@@ -193,40 +193,48 @@ def forward_backward(self, y):
 
         # set up
         N = tf.shape(y)[0]
-        nT = self.length or y.shape[1]
-        # nT = tf.shape(y)[1]
 
-        forward = []
+        # y (batch, recurrent, features) -> (recurrent, batch, features)
+        y = tf.transpose(y, (1, 0, 2))
 
         # forward pass
-        forward.append(tf.ones((N, self.K)) * (1.0 / self.K))
-        for t in range(nT):
-            tmp = tf.multiply(tf.matmul(forward[t], self.P), y[:, t])
-
-            forward.append(tmp / tf.expand_dims(tf.reduce_sum(tmp, axis=1), axis=1))
+        def forward_function(last_forward, yi):
+            tmp = tf.multiply(tf.matmul(last_forward, self.P), yi)
+            return tmp / tf.reduce_sum(tmp, axis=1, keep_dims=True)
+
+        forward = tf.scan(
+            forward_function,
+            y,
+            initializer=tf.ones((N, self.K)) * (1.0 / self.K),
+        )
 
         # backward pass
-        backward = [None] * (nT + 1)
-        backward[-1] = tf.ones((N, self.K)) * (1.0 / self.K)
-        for t in range(nT, 0, -1):
+        def backward_function(last_backward, yi):
             # combine transition matrix with observations
             combined = tf.multiply(
-                tf.expand_dims(self.P, 0), tf.expand_dims(y[:, t - 1], 1)
+                tf.expand_dims(self.P, 0), tf.expand_dims(yi, 1)
             )
             tmp = tf.reduce_sum(
-                tf.multiply(combined, tf.expand_dims(backward[t], 1)), axis=2
+                tf.multiply(combined, tf.expand_dims(last_backward, 1)), axis=2
             )
-            backward[t - 1] = tmp / tf.expand_dims(tf.reduce_sum(tmp, axis=1), axis=1)
+            return tmp / tf.reduce_sum(tmp, axis=1, keep_dims=True)
 
-        # remove initial/final probabilities
-        forward = forward[1:]
-        backward = backward[:-1]
+        backward = tf.scan(
+            backward_function,
+            tf.reverse(y, [0]),
+            initializer=tf.ones((N, self.K)) * (1.0 / self.K),
+        )
+        backward = tf.reverse(backward, [0])
 
+        # combine forward and backward into posterior probabilities
+        # (recurrent, batch, features)
+        posterior = forward * backward
+        posterior = posterior / tf.reduce_sum(posterior, axis=2, keep_dims=True)
 
-        # combine and normalize
-        posterior = [f * b for f, b in zip(forward, backward)]
-        posterior = [p / tf.expand_dims(tf.reduce_sum(p, axis=1), axis=1) for p in posterior]
-        posterior = tf.stack(posterior, axis=1)
+        # (recurrent, batch, features) -> (batch, recurrent, features)
+        posterior = tf.transpose(posterior, (1, 0, 2))
+        forward = tf.transpose(forward, (1, 0, 2))
+        backward = tf.transpose(backward, (1, 0, 2))
 
         return posterior, forward, backward
 
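As a sanity check on the change, here is a minimal standalone sketch (not part of the diff) of the forward recursion written both ways, confirming that tf.scan over time-major observations reproduces the Python loop it replaces. It assumes TF 2.x eager execution, where the reduce_sum keyword is spelled keepdims rather than the keep_dims used above, and the names P, y, N, T and K are stand-ins for the model's transition matrix, per-step observation likelihoods, batch size, sequence length and number of states.

import numpy as np
import tensorflow as tf

N, T, K = 4, 6, 3  # batch size, sequence length, number of hidden states
rng = np.random.default_rng(0)

# stand-ins for self.P and the observation likelihoods y (assumed names)
P = tf.constant(rng.dirichlet(np.ones(K), size=K), dtype=tf.float32)      # (K, K) transition matrix
y = tf.constant(rng.uniform(0.1, 1.0, size=(N, T, K)), dtype=tf.float32)  # (batch, time, states)

# scan version: time axis first, exactly as in the diff
y_t = tf.transpose(y, (1, 0, 2))  # (T, N, K)

def forward_function(last_forward, yi):
    tmp = tf.multiply(tf.matmul(last_forward, P), yi)
    return tmp / tf.reduce_sum(tmp, axis=1, keepdims=True)

forward_scan = tf.scan(forward_function, y_t,
                       initializer=tf.ones((N, K)) / K)  # (T, N, K)

# loop version: the code removed by the diff
forward_loop = [tf.ones((N, K)) / K]
for t in range(T):
    tmp = tf.multiply(tf.matmul(forward_loop[t], P), y[:, t])
    forward_loop.append(tmp / tf.reduce_sum(tmp, axis=1, keepdims=True))
forward_loop = tf.stack(forward_loop[1:], axis=0)  # drop the uniform prior -> (T, N, K)

np.testing.assert_allclose(forward_scan.numpy(), forward_loop.numpy(), rtol=1e-5)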
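Continuing the same sketch, the backward pass is the same kind of scan run over the time-reversed observations and flipped back afterwards, mirroring the tf.reverse(y, [0]) / tf.reverse(backward, [0]) pair in the diff; the final assertion only checks that each posterior slice normalizes over the K states, not that it matches a reference implementation.

def backward_function(last_backward, yi):
    # combined[n, i, j] = P[i, j] * yi[n, j]: transition i -> j weighted by the observation at j
    combined = tf.multiply(tf.expand_dims(P, 0), tf.expand_dims(yi, 1))
    tmp = tf.reduce_sum(tf.multiply(combined, tf.expand_dims(last_backward, 1)), axis=2)
    return tmp / tf.reduce_sum(tmp, axis=1, keepdims=True)

backward = tf.scan(backward_function, tf.reverse(y_t, [0]),
                   initializer=tf.ones((N, K)) / K)
backward = tf.reverse(backward, [0])  # back to forward time order

posterior = forward_scan * backward
posterior = posterior / tf.reduce_sum(posterior, axis=2, keepdims=True)

# every (time, batch) slice is a distribution over the K states
np.testing.assert_allclose(tf.reduce_sum(posterior, axis=2).numpy(),
                           np.ones((T, N)), rtol=1e-5)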