Skip to content

Commit 74de7a8

Browse files
author
l2k2
committed
new video
1 parent 68086b7 commit 74de7a8

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

videos/seq2seq/train.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# adapted from https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html
2+
3+
from keras.models import Sequential
4+
from keras.layers import LSTM, TimeDistributed, RepeatVector, Dense
5+
import numpy as np
6+
import wandb
7+
from wandb.keras import WandbCallback
8+
9+
# Start a Weights & Biases run; the hyperparameters assigned further down are
# stored on the run's config object.
wandb.init()
config = wandb.config
11+
12+
class CharacterTable(object):
    """Two-way lookup between characters and one-hot rows.

    For a fixed alphabet this class can:
    + Turn a string into a one-hot matrix
    + Turn a one-hot (or probability) matrix back into a string
    + Turn a sequence of integer indices back into a string
    """

    def __init__(self, chars):
        """Build the lookup tables.

        # Arguments
            chars: Characters that can appear in the input. Duplicates are
                dropped and the alphabet is stored in sorted order.
        """
        self.chars = sorted(set(chars))
        self.char_indices = {ch: idx for idx, ch in enumerate(self.chars)}
        self.indices_char = dict(enumerate(self.chars))

    def encode(self, C, num_rows):
        """Return the one-hot encoding of string C.

        # Arguments
            C: String to encode; every character must be in the alphabet.
            num_rows: Number of rows in the returned matrix. Rows beyond
                len(C) stay all-zero, which keeps every sample the same
                height.
        """
        lookup = self.char_indices
        onehot = np.zeros((num_rows, len(self.chars)))
        for row, ch in enumerate(C):
            onehot[row, lookup[ch]] = 1
        return onehot

    def decode(self, x, calc_argmax=True):
        """Return the string represented by x.

        # Arguments
            x: Either a matrix of per-row scores (one-hot or probabilities)
                or, when calc_argmax is False, an iterable of indices.
            calc_argmax: Whether to reduce x to indices with argmax over the
                last axis before looking up characters.
        """
        indices = x.argmax(axis=-1) if calc_argmax else x
        return ''.join(self.indices_char[idx] for idx in indices)
42+
43+
# Parameters for the model and dataset.
config.training_size = 50000  # number of unique questions to generate
config.digits = 5  # maximum number of digits in each operand
config.hidden_size = 128  # units in each LSTM layer
config.batch_size = 128  # batch size passed to model.fit

# Maximum length of input is 'int-int' (e.g., '345-678'). Maximum length of
# int is config.digits.
maxlen = config.digits + 1 + config.digits

# All the digits, the plus and minus signs, and space for padding.
chars = '0123456789+- '
ctable = CharacterTable(chars)
56+
57+
questions = []
expected = []
seen = set()
print('Generating data...')

# Operands are integers in [0, 10**digits), so there are only this many
# distinct ordered (a, b) pairs. Without this check the loop below would spin
# forever once every pair has been seen.
max_unique = (10 ** config.digits) ** 2
if config.training_size > max_unique:
    raise ValueError(
        'training_size ({}) exceeds the {} distinct questions available '
        'with {} digits'.format(config.training_size, max_unique,
                                config.digits))


def _random_operand():
    """Return a random non-negative int built from 1 to config.digits digits."""
    num_digits = np.random.randint(1, config.digits + 1)
    return int(''.join(np.random.choice(list('0123456789'))
                       for _ in range(num_digits)))


while len(questions) < config.training_size:
    a, b = _random_operand(), _random_operand()
    # Skip any subtraction question we have already seen. The key keeps the
    # operand order because a-b and b-a are different questions with
    # different answers.
    key = (a, b)
    if key in seen:
        continue
    seen.add(key)
    # Pad the question with spaces so that it is always maxlen long.
    q = '{}-{}'.format(a, b)
    query = q.ljust(maxlen)
    # Answers are at most config.digits + 1 characters (a sign plus digits).
    ans = str(a - b).ljust(config.digits + 1)

    questions.append(query)
    expected.append(ans)

# NOTE: the label still says "addition"; it is kept unchanged so existing log
# output stays the same even though the questions are subtractions.
print('Total addition questions:', len(questions))
82+
83+
print('Vectorization...')
# One-hot tensors: x is (samples, maxlen, alphabet) and y is
# (samples, digits + 1, alphabet). The builtin bool is used as dtype because
# the np.bool alias was deprecated in NumPy 1.20 and removed in 1.24.
x = np.zeros((len(questions), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(questions), config.digits + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, maxlen)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, config.digits + 1)

# Shuffle x and y with the same permutation so that question/answer pairs
# stay aligned.
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

# Explicitly set apart 10% for validation data that we never train over.
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]
102+
103+
# Encoder LSTM -> repeat its output once per answer character -> decoder LSTM
# -> per-timestep softmax over the alphabet.
model = Sequential([
    LSTM(config.hidden_size, input_shape=(maxlen, len(chars))),
    RepeatVector(config.digits + 1),
    LSTM(config.hidden_size, return_sequences=True),
    TimeDistributed(Dense(len(chars), activation='softmax')),
])
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()
112+
113+
# Train the model one epoch at a time and, after each epoch, show predictions
# for a few random samples from the validation dataset.
for iteration in range(1, 200):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(x_train, y_train,
              batch_size=config.batch_size,
              epochs=1,
              validation_data=(x_val, y_val),
              callbacks=[WandbCallback()])
    # Select 10 samples from the validation set at random so we can visualize
    # errors.
    for _ in range(10):
        ind = np.random.randint(0, len(x_val))
        # Slice (rather than index) to keep the leading batch dimension.
        rowx, rowy = x_val[ind:ind + 1], y_val[ind:ind + 1]
        # Sequential.predict_classes was removed in TensorFlow 2.6; taking the
        # argmax of predict() yields the same class indices on every version.
        preds = model.predict(rowx, verbose=0).argmax(axis=-1)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        # Questions are never reversed in this script and config.reverse is
        # never set, so the question is printed as-is.
        print('Q', q, end=' ')
        print('T', correct, end=' ')
        if correct == guess:
            print('☑', end=' ')
        else:
            print('☒', end=' ')
        print(guess)

0 commit comments

Comments
 (0)