@@ -35,42 +35,43 @@ class HMMNumpy(HMM):
3535
3636 def forward_backward (self , y ):
3737 # set up
38- nT = y .shape [0 ]
39- posterior = np .zeros ((nT , self .K ))
40- forward = np .zeros ((nT + 1 , self .K ))
41- backward = np .zeros ((nT + 1 , self .K ))
38+ if y .ndim == 2 :
39+ y = y [np .newaxis , ...]
40+ nT = y .shape [1 ]
41+ nB = y .shape [0 ]
42+ posterior = np .zeros ((nB , nT , self .K ))
43+ forward = np .zeros ((nB , nT + 1 , self .K ))
44+ backward = np .zeros ((nB , nT + 1 , self .K ))
4245
4346 # forward pass
44- forward [0 , :] = 1.0 / self .K
47+ forward [:, 0 , :] = 1.0 / self .K
4548 for t in range (nT ):
4649 tmp = np .multiply (
47- np .matmul (forward [t , :], self .P ),
48- y [t ]
50+ np .matmul (forward [:, t , :], self .P ),
51+ y [:, t ]
4952 )
50-
51- forward [t + 1 , :] = tmp / np .sum (tmp )
53+ # normalize
54+ forward [:, t + 1 , :] = tmp / np .sum (tmp , axis = 1 )[:, np . newaxis ]
5255
5356 # backward pass
54- backward [- 1 , :] = 1.0 / self .K
57+ backward [:, - 1 , :] = 1.0 / self .K
5558 for t in range (nT , 0 , - 1 ):
56- tmp = np .matmul (
57- np .matmul (
58- self .P , np .diag (y [t - 1 ])
59- ),
60- backward [t , :].transpose ()
61- ).transpose ()
62-
63- backward [t - 1 , :] = tmp / np .sum (tmp )
59+ # TODO[marcel]: double check whether y[:,t-1] should be y[:,t]
60+ tmp = np .matmul (self .P , (y [:, t - 1 ] * backward [:, t , :]).T ).T
61+ # normalize
62+ backward [:, t - 1 , :] = tmp / np .sum (tmp , axis = 1 )[:, np .newaxis ]
6463
65- # remove initial/final probabilities
66- forward = forward [1 :, :]
67- backward = backward [:- 1 , :]
64+ # remove initial/final probabilities and squeeze for non-batched tests
65+ forward = np . squeeze ( forward [:, 1 :, :])
66+ backward = np . squeeze ( backward [:, : - 1 , :])
6867
68+ # TODO[marcel]: posterior missing initial probabilities
6969 # combine and normalize
7070 posterior = np .array (forward ) * np .array (backward )
7171 # [:,None] expands sum to be correct size
72- posterior = posterior / np .sum (posterior , 1 )[:, None ]
72+ posterior = posterior / np .sum (posterior , axis = - 1 )[..., np . newaxis ]
7373
74+ # squeeze for non-batched tests
7475 return posterior , forward , backward
7576
7677 def _viterbi_partial_forward (self , scores ):
0 commit comments