Skip to content

Commit 006d3a5

Browse files
committed
replace python for loop with tf.scan
1 parent 9fac589 commit 006d3a5

File tree

3 files changed

+38
-25
lines changed

3 files changed

+38
-25
lines changed

tensorflow_hmm/hmm.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -193,40 +193,48 @@ def forward_backward(self, y):
193193

194194
# set up
195195
N = tf.shape(y)[0]
196-
nT = self.length or y.shape[1]
197-
# nT = tf.shape(y)[1]
198196

199-
forward = []
197+
# y (batch, recurrent, features) -> (recurrent, batch, features)
198+
y = tf.transpose(y, (1, 0, 2))
200199

201200
# forward pass
202-
forward.append(tf.ones((N, self.K)) * (1.0 / self.K))
203-
for t in range(nT):
204-
tmp = tf.multiply(tf.matmul(forward[t], self.P), y[:, t])
205-
206-
forward.append(tmp / tf.expand_dims(tf.reduce_sum(tmp, axis=1), axis=1))
201+
def forward_function(last_forward, yi):
202+
tmp = tf.multiply(tf.matmul(last_forward, self.P), yi)
203+
return tmp / tf.reduce_sum(tmp, axis=1, keep_dims=True)
204+
205+
forward = tf.scan(
206+
forward_function,
207+
y,
208+
initializer=tf.ones((N, self.K)) * (1.0 / self.K),
209+
)
207210

208211
# backward pass
209-
backward = [None] * (nT + 1)
210-
backward[-1] = tf.ones((N, self.K)) * (1.0 / self.K)
211-
for t in range(nT, 0, -1):
212+
def backward_function(last_backward, yi):
212213
# combine transition matrix with observations
213214
combined = tf.multiply(
214-
tf.expand_dims(self.P, 0), tf.expand_dims(y[:, t - 1], 1)
215+
tf.expand_dims(self.P, 0), tf.expand_dims(yi, 1)
215216
)
216217
tmp = tf.reduce_sum(
217-
tf.multiply(combined, tf.expand_dims(backward[t], 1)), axis=2
218+
tf.multiply(combined, tf.expand_dims(last_backward, 1)), axis=2
218219
)
219-
backward[t - 1] = tmp / tf.expand_dims(tf.reduce_sum(tmp, axis=1), axis=1)
220+
return tmp / tf.reduce_sum(tmp, axis=1, keep_dims=True)
220221

221-
# remove initial/final probabilities
222-
forward = forward[1:]
223-
backward = backward[:-1]
222+
backward = tf.scan(
223+
backward_function,
224+
tf.reverse(y, [0]),
225+
initializer=tf.ones((N, self.K)) * (1.0 / self.K),
226+
)
227+
backward = tf.reverse(backward, [0])
224228

229+
# combine forward and backward into posterior probabilities
230+
# (recurrent, batch, features)
231+
posterior = forward * backward
232+
posterior = posterior / tf.reduce_sum(posterior, axis=2, keep_dims=True)
225233

226-
# combine and normalize
227-
posterior = [f * b for f, b in zip(forward, backward)]
228-
posterior = [p / tf.expand_dims(tf.reduce_sum(p, axis=1), axis=1) for p in posterior]
229-
posterior = tf.stack(posterior, axis=1)
234+
# (recurrent, batch, features) -> (batch, recurrent, features)
235+
posterior = tf.transpose(posterior, (1, 0, 2))
236+
forward = tf.transpose(forward, (1, 0, 2))
237+
backward = tf.transpose(backward, (1, 0, 2))
230238

231239
return posterior, forward, backward
232240

tensorflow_hmm/hmm_layer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(self, states, length=None):
1414
for i in range(states):
1515
self.P[i, i] = 0.99
1616

17-
self.hmm = HMMTensorflow(self.P, length=length)
17+
self.hmm = HMMTensorflow(self.P)
1818

1919
super(HMMLayer, self).__init__()
2020

test/test_hmm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,21 @@ def test_hmm_tf_latch_forward_backward_multiple_batch(hmm_tf_latch, hmm_latch):
101101
y = lik(np.array([0, 0, 1, 1]))
102102
y = np.stack([y] * 3)
103103

104-
np_posterior, np_forward, _ = hmm_latch.forward_backward(y)
104+
np_posterior, np_forward, np_backward = hmm_latch.forward_backward(y)
105105
print('tf')
106-
g_posterior, g_forward, _ = hmm_tf_latch.forward_backward(y)
106+
g_posterior, g_forward, g_backward = hmm_tf_latch.forward_backward(y)
107107
tf_posterior = tf.Session().run(g_posterior)
108+
tf_forward = tf.Session().run(g_forward)
109+
tf_backward = tf.Session().run(g_backward)
108110

111+
assert np.isclose(np_forward, tf_forward).all()
112+
print('np_backward', np_backward)
113+
print('tf_backward', tf_backward)
114+
assert np.isclose(np_backward, tf_backward).all()
109115
print('np_posterior', np_posterior)
110116
print('tf_posterior', tf_posterior)
111117
assert np.isclose(np_posterior, tf_posterior).all()
112118

113-
114119
def test_lik():
115120
yin = np.array([0, 0.25, 0.5, 0.75, 1])
116121
y = lik(yin)

0 commit comments

Comments
 (0)