Commit 2cc7fb4

Polished attention
1 parent a354adb commit 2cc7fb4

File tree (2 files changed: +32 additions, −57 deletions)

lstm/attention/train.py
lstm/attention/util.py


lstm/attention/train.py

Lines changed: 24 additions & 16 deletions
@@ -9,7 +9,6 @@
 from keras.layers import CuDNNLSTM
 from keras.layers.wrappers import TimeDistributed, Bidirectional
 from attention_decoder import AttentionDecoder
-from nmt import simpleNMT
 from reader import Data, Vocabulary
 import numpy as np
 from keras import backend as K
@@ -31,13 +30,24 @@


 def run_example(model, input_vocabulary, output_vocabulary, text):
+    """Predict a single example"""
     encoded = input_vocabulary.string_to_int(text)
     prediction = model.predict(np.array([encoded]))
     prediction = np.argmax(prediction[0], axis=-1)
-    return "".join([s for s in output_vocabulary.int_to_string(prediction) if s != "<unk>"])
+    return output_vocabulary.int_to_string(prediction)
+
+
+def decode(chars, sanitize=False):
+    """Join a list of chars removing <unk> and invalid utf-8"""
+    string = "".join([c for c in chars if c != "<unk>"])
+    if sanitize:
+        string = "".join(i for i in string if ord(i) < 2048)
+    return bytes(string, 'utf-8').decode('utf-8', 'ignore')


 class Examples(Callback):
+    """Keras callback to log examples"""
+
     def __init__(self, viz):
         self.visualizer = viz

@@ -53,22 +63,20 @@ def on_epoch_end(self, epoch, logs):
         self.visualizer.proba_model.get_layer(
             "attention_decoder_prob").set_weights(weights)
         for i, o in zip(data_in, data_out):
-            text = "".join(
-                [s for s in input_vocab.int_to_string(i) if s != "<unk>"])
-            truth = "".join([s for s in output_vocab.int_to_string(
-                np.argmax(o, -1)) if s != "<unk>"])
-            out = run_example(self.model, input_vocab, output_vocab, text)
-            print(f"{text} -> {out} ({truth})")
-            examples.append([bytes(text, 'utf-8').decode('utf-8', 'ignore'), bytes(
-                out, 'utf-8').decode('utf-8', 'ignore'), bytes(truth, 'utf-8').decode('utf-8', 'ignore')])
+            text = decode(input_vocab.int_to_string(i)).replace('<eot>', '')
+            truth = decode(output_vocab.int_to_string(np.argmax(o, -1)), True)
+            pred = run_example(self.model, input_vocab, output_vocab, text)
+            out = decode(pred, True)
+            print(f"{decode(text, True)} -> {out} ({truth})")
+            examples.append([decode(text, True), out, truth])
             amap = self.visualizer.attention_map(text)
             if amap:
-                viz.append(wandb.Image(amap, caption=text))
+                viz.append(wandb.Image(amap))
                 amap.close()
         if len(viz) > 0:
             logs["attention_map"] = viz[:5]
-        wandb.log(
-            {"examples": wandb.Table(data=examples), **logs})
+        logs["examples"] = wandb.Table(data=examples)
+        wandb.log(logs)


 def all_acc(y_true, y_pred):
@@ -94,7 +102,7 @@ def all_acc(y_true, y_pred):
 input_vocab = Vocabulary('./human_vocab.json', padding=config.padding)
 output_vocab = Vocabulary('./machine_vocab.json', padding=config.padding)

-print('Loading datasets.')
+print('Loading datasets...')

 training = Data(training_data, input_vocab, output_vocab)
 validation = Data(validation_data, input_vocab, output_vocab)
@@ -125,7 +133,7 @@ def build_models(pad_length=config.padding, n_chars=input_vocab.size(), n_labels
                               name='attention_decoder_prob',
                               output_dim=n_labels,
                               return_probabilities=True,
-                              trainable=trainable)(rnn_encoded)
+                              trainable=False)(rnn_encoded)

     y_pred = AttentionDecoder(decoder_units,
                               name='attention_decoder_1',
@@ -137,7 +145,7 @@ def build_models(pad_length=config.padding, n_chars=input_vocab.size(), n_labels
     model.summary()
     model.compile(optimizer='adam',
                   loss='categorical_crossentropy',
-                  metrics=['accuracy', all_acc])
+                  metrics=['accuracy'])
     prob_model = Model(inputs=input_, outputs=y_prob)
     return model, prob_model
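For reference, the decode helper added above can be exercised on its own. A minimal sketch: the function body is copied from the train.py diff, while the character list is a hypothetical stand-in for a Vocabulary.int_to_string result.

def decode(chars, sanitize=False):
    """Join a list of chars removing <unk> and invalid utf-8"""
    string = "".join([c for c in chars if c != "<unk>"])
    if sanitize:
        string = "".join(i for i in string if ord(i) < 2048)
    return bytes(string, 'utf-8').decode('utf-8', 'ignore')

# Hypothetical stand-in for a Vocabulary.int_to_string result: characters plus <unk> padding.
chars = ["2", "0", "1", "9", "-", "0", "7", "-", "0", "1", "<unk>", "<unk>"]

print(decode(chars))        # -> 2019-07-01 (the <unk> tokens are dropped)
print(decode("é日", True))  # -> é (sanitize=True additionally strips code points >= 2048)

The callback passes sanitize=True for the strings it prints and logs, presumably to keep the logged examples valid UTF-8.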

lstm/attention/util.py

Lines changed: 8 additions & 41 deletions
@@ -1,18 +1,20 @@
+from reader import Vocabulary
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import argparse
 import matplotlib  # pylint: disable
 matplotlib.use("Agg")  # pylint: disable
-import argparse
-import os
-import numpy as np
-import matplotlib.pyplot as plt
-import matplotlib.patches as mpatches
-from reader import Vocabulary
+

 def run_example(model, input_vocabulary, output_vocabulary, text):
     encoded = input_vocabulary.string_to_int(text)
     prediction = model.predict(np.array([encoded]))
     prediction = np.argmax(prediction[0], axis=-1)
     return output_vocabulary.int_to_string(prediction)

+
 class Visualizer(object):

     def __init__(self, input_vocab, output_vocab):
@@ -82,38 +84,3 @@ def attention_map(self, text):
         # ax.legend(loc='best')

         return plt
-
-
-def main(examples, args):
-    print('Total Number of Examples:', len(examples))
-    weights_file = os.path.expanduser(args.weights)
-    print('Weights loading from:', weights_file)
-    viz = Visualizer(padding=args.padding,
-                     input_vocab=args.human_vocab,
-                     output_vocab=args.machine_vocab)
-    print('Loading models')
-    pred_model = simpleNMT(trainable=False,
-                           pad_length=args.padding,
-                           n_chars=viz.input_vocab.size(),
-                           n_labels=viz.output_vocab.size())
-
-    pred_model.load_weights(weights_file, by_name=True)
-    pred_model.compile(optimizer='adam', loss='categorical_crossentropy')
-
-    proba_model = simpleNMT(trainable=False,
-                            pad_length=args.padding,
-                            n_chars=viz.input_vocab.size(),
-                            n_labels=viz.output_vocab.size(),
-                            return_probabilities=True)
-
-    proba_model.load_weights(weights_file, by_name=True)
-    proba_model.compile(optimizer='adam', loss='categorical_crossentropy')
-
-    viz.set_models(pred_model, proba_model)
-
-    print('Models loaded')
-
-    for example in examples:
-        viz.attention_map(example)
-
-    print('Completed visualizations')
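With main() removed here and the matching simpleNMT import dropped from train.py, the visualizer is presumably wired up from the training script instead. A rough sketch under that assumption, reusing only names that appear in this commit's diffs (build_models, Visualizer, set_models, Examples); it is illustrative, not part of the commit:

# Hypothetical wiring inside train.py, assuming build_models() and the
# vocabularies are defined as shown in the diff above.
model, prob_model = build_models()

viz = Visualizer(input_vocab, output_vocab)  # two-argument constructor from util.py
viz.set_models(model, prob_model)            # set_models() as called by the removed main()

# Examples(viz) then logs example predictions, a wandb.Table of examples and
# up to five attention maps at the end of each epoch via its on_epoch_end hook.
callbacks = [Examples(viz)]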
