
Commit 594bec7

new seq2seq model
1 parent 75b6341 commit 594bec7

File tree

1 file changed: +199 -0 lines changed


keras-seq2seq/train.py

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
# -*- coding: utf-8 -*-
'''An implementation of sequence to sequence learning for performing addition

Input: "535+61"
Output: "596"
Padding is handled by using a repeated sentinel character (space).

Input may optionally be reversed, which was shown to increase performance in
many tasks in:
"Learning to Execute"
http://arxiv.org/abs/1410.4615
and
"Sequence to Sequence Learning with Neural Networks"
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf
Theoretically, reversing introduces shorter-term dependencies between source
and target.

Two digits reversed:
+ One layer LSTM (128 HN), 5k training examples = 99% train/test accuracy in 55 epochs
Three digits reversed:
+ One layer LSTM (128 HN), 50k training examples = 99% train/test accuracy in 100 epochs
Four digits reversed:
+ One layer LSTM (128 HN), 400k training examples = 99% train/test accuracy in 20 epochs
Five digits reversed:
+ One layer LSTM (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
'''  # noqa

from __future__ import print_function
from keras.models import Sequential
from keras import layers
import numpy as np
from six.moves import range
import wandb
from wandb.keras import WandbCallback

wandb.init()


class CharacterTable(object):
    """Given a set of characters:
    + Encode them to a one-hot integer representation
    + Decode the one-hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    """
    def __init__(self, chars):
        """Initialize character table.

        # Arguments
            chars: Characters that can appear in the input.
        """
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))

    def encode(self, C, num_rows):
        """One-hot encode given string C.

        # Arguments
            num_rows: Number of rows in the returned one-hot encoding. This is
                used to keep the number of rows for each data point the same.
        """
        x = np.zeros((num_rows, len(self.chars)))
        for i, c in enumerate(C):
            x[i, self.char_indices[c]] = 1
        return x

    def decode(self, x, calc_argmax=True):
        if calc_argmax:
            x = x.argmax(axis=-1)
        return ''.join(self.indices_char[i] for i in x)

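# Illustrative sketch (not part of the training flow): how CharacterTable
# round-trips a string. Rows beyond the encoded characters stay all-zero, and
# argmax of an all-zero row is index 0, which here is the space sentinel,
# since space sorts first in sorted(set(chars)).
_demo_table = CharacterTable('0123456789+ ')
_demo_onehot = _demo_table.encode('12+3', num_rows=7)  # shape (7, 12)
assert _demo_table.decode(_demo_onehot) == '12+3   '
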

class colors:
    ok = '\033[92m'
    fail = '\033[91m'
    close = '\033[0m'

# Parameters for the model and dataset.
TRAINING_SIZE = 50000
DIGITS = 3
REVERSE = True

# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
# int is DIGITS.
MAXLEN = DIGITS + 1 + DIGITS

# All the numbers, plus sign and space for padding.
chars = '0123456789+ '
ctable = CharacterTable(chars)

questions = []
expected = []
seen = set()
print('Generating data...')
while len(questions) < TRAINING_SIZE:
    f = lambda: int(''.join(np.random.choice(list('0123456789'))
                    for i in range(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()
    # Skip any addition questions we've already seen.
    # Also skip any such that x+y == y+x (hence the sorting).
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # Pad the data with spaces such that it is always MAXLEN.
    q = '{}+{}'.format(a, b)
    query = q + ' ' * (MAXLEN - len(q))
    ans = str(a + b)
    # Answers can be of maximum size DIGITS + 1.
    ans += ' ' * (DIGITS + 1 - len(ans))
    if REVERSE:
        # Reverse the query, e.g., '12+345 ' becomes ' 543+21'. (Note the
        # space used for padding.)
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('Total addition questions:', len(questions))

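# Worked example of the generation above (descriptive comment only): with
# DIGITS = 3, MAXLEN = 7, the pair (12, 345) gives q = '12+345', which pads
# to query = '12+345 ' and reverses to ' 543+21'; the answer str(357) pads
# to '357 ' (DIGITS + 1 = 4 characters).
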
print('Vectorization...')
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)

# Shuffle (x, y) in unison as the later parts of x will almost all be larger
# digits.
indices = np.arange(len(y))
np.random.shuffle(indices)
x = x[indices]
y = y[indices]

# Explicitly set apart 10% for validation data that we never train on.
split_at = len(x) - len(x) // 10
(x_train, x_val) = x[:split_at], x[split_at:]
(y_train, y_val) = y[:split_at], y[split_at:]

print('Training Data:')
print(x_train.shape)
print(y_train.shape)

print('Validation Data:')
print(x_val.shape)
print(y_val.shape)

# Try replacing LSTM with GRU or SimpleRNN.
RNN = layers.LSTM
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1

print('Build model...')
model = Sequential()
# "Encode" the input sequence using an RNN, producing an output of HIDDEN_SIZE.
# Note: In a situation where your input sequences have a variable length,
# use input_shape=(None, num_feature).
model.add(RNN(HIDDEN_SIZE, input_shape=(MAXLEN, len(chars))))
# As the decoder RNN's input, repeatedly provide the last output of the
# encoder RNN for each time step. Repeat 'DIGITS + 1' times as that's the
# maximum length of the output, e.g., when DIGITS=3, max output is
# 999+999=1998.
model.add(layers.RepeatVector(DIGITS + 1))
# The decoder RNN could be multiple layers stacked or a single layer.
for _ in range(LAYERS):
    # By setting return_sequences to True, return not only the last output
    # but all the outputs so far in the form of (num_samples, timesteps,
    # output_dim). This is necessary as TimeDistributed below expects the
    # first dimension after the batch to be the timesteps.
    model.add(RNN(HIDDEN_SIZE, return_sequences=True))

# Apply a dense layer to every temporal slice of the input. For each step
# of the output sequence, decide which character should be chosen.
model.add(layers.TimeDistributed(layers.Dense(len(chars))))
model.add(layers.Activation('softmax'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.summary()

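# Shape walkthrough (descriptive comment; values assume DIGITS=3 and the
# 12-character vocabulary above):
#   input                 (batch, 7, 12)   one-hot '345+678' (reversed)
#   encoder LSTM          (batch, 128)     last hidden state only
#   RepeatVector(4)       (batch, 4, 128)  same state fed to each output step
#   decoder LSTM          (batch, 4, 128)  return_sequences=True
#   TimeDistributed Dense (batch, 4, 12)   per-step character scores
#   softmax               (batch, 4, 12)   per-step character probabilities
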
# Train the model each generation and show predictions against the validation
# dataset.
for iteration in range(1, 200):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(x_train, y_train,
              batch_size=BATCH_SIZE,
              epochs=1,
              validation_data=(x_val, y_val),
              callbacks=[WandbCallback()])
    # Select 10 samples from the validation set at random so we can visualize
    # errors.
    for i in range(10):
        ind = np.random.randint(0, len(x_val))
        rowx, rowy = x_val[np.array([ind])], y_val[np.array([ind])]
        preds = model.predict_classes(rowx, verbose=0)
        q = ctable.decode(rowx[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Q', q[::-1] if REVERSE else q, end=' ')
        print('T', correct, end=' ')
        if correct == guess:
            print(colors.ok + '☑' + colors.close, end=' ')
        else:
            print(colors.fail + '☒' + colors.close, end=' ')
        print(guess)
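
# A minimal inference sketch (illustrative, assuming the model has finished
# training above): encode a new sum with the same padding/reversal rules,
# predict, and decode. On older Keras, Sequential.predict_classes returns the
# argmax directly; np.argmax(model.predict(...), axis=-1) is equivalent.
def answer(question):
    padded = question + ' ' * (MAXLEN - len(question))
    if REVERSE:
        padded = padded[::-1]
    onehot = ctable.encode(padded, MAXLEN)[np.newaxis]  # add batch dimension
    pred = np.argmax(model.predict(onehot, verbose=0), axis=-1)
    return ctable.decode(pred[0], calc_argmax=False).strip()

print('535+61 =', answer('535+61'))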
