from keras.models import Sequential
from keras.layers import LSTM, TimeDistributed, RepeatVector, Dense
import numpy as np
import wandb
from wandb.keras import WandbCallback

# Start a Weights & Biases run; `config` records hyperparameters so they
# are logged alongside the training metrics.
wandb.init()
config = wandb.config
class CharacterTable (object ):
11
12
"""Given a set of characters:
@@ -38,17 +39,16 @@ def decode(self, x, calc_argmax=True):
38
39
x = x .argmax (axis = - 1 )
39
40
return '' .join (self .indices_char [x ] for x in x )
40
41
41
# Parameters for the model and dataset, stored on wandb.config so every
# run logs the exact settings it was trained with.
config.training_size = 50000  # number of unique addition questions to generate
config.digits = 3             # max digits per operand
config.reverse = True         # feed the input string reversed (helps seq2seq)
config.hidden_size = 128      # LSTM state size
config.batch_size = 128

# Maximum length of input is 'int + int' (e.g., '345+678'). Maximum length of
# an int is config.digits.
maxlen = config.digits + 1 + config.digits

# All the numbers, plus sign and space for padding.
chars = '0123456789+ '
@@ -58,9 +58,9 @@ def decode(self, x, calc_argmax=True):
58
58
expected = []
59
59
seen = set ()
60
60
print ('Generating data...' )
61
- while len (questions ) < TRAINING_SIZE :
61
+ while len (questions ) < config . training_size :
62
62
f = lambda : int ('' .join (np .random .choice (list ('0123456789' ))
63
- for i in range (np .random .randint (1 , DIGITS + 1 ))))
63
+ for i in range (np .random .randint (1 , config . digits + 1 ))))
64
64
a , b = f (), f ()
65
65
# Skip any addition questions we've already seen
66
66
# Also skip any such that x+Y == Y+x (hence the sorting).
@@ -70,25 +70,26 @@ def decode(self, x, calc_argmax=True):
70
70
seen .add (key )
71
71
# Pad the data with spaces such that it is always MAXLEN.
72
72
q = '{}+{}' .format (a , b )
73
- query = q + ' ' * (MAXLEN - len (q ))
73
+ query = q + ' ' * (maxlen - len (q ))
74
74
ans = str (a + b )
75
75
# Answers can be of maximum size DIGITS + 1.
76
- ans += ' ' * (DIGITS + 1 - len (ans ))
77
- if REVERSE :
76
+ ans += ' ' * (config . digits + 1 - len (ans ))
77
+ if config . reverse :
78
78
# Reverse the query, e.g., '12+345 ' becomes ' 543+21'. (Note the
79
79
# space used for padding.)
80
80
query = query [::- 1 ]
81
81
questions .append (query )
82
82
expected .append (ans )
83
+
83
84
print ('Total addition questions:' , len (questions ))
84
85
85
86
print('Vectorization...')
# One-hot encode every question/answer into boolean tensors:
#   x: (num_questions, maxlen, vocab)  — padded, possibly reversed queries
#   y: (num_questions, digits + 1, vocab) — padded answers
# Use the builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20
# and removed in NumPy 1.24.
x = np.zeros((len(questions), maxlen, len(chars)), dtype=bool)
y = np.zeros((len(questions), config.digits + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, maxlen)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, config.digits + 1)
92
93
93
94
# Shuffle (x, y) in unison as the later parts of x will almost all be larger
94
95
# digits.
@@ -110,34 +111,22 @@ def decode(self, x, calc_argmax=True):
print(x_val.shape)
print(y_val.shape)

# Sequence-to-sequence model: encoder LSTM -> repeated state -> decoder LSTM
# -> per-timestep softmax over the character vocabulary.
model = Sequential()
# "Encode" the input sequence using an RNN, producing a single output of
# size config.hidden_size.
# Note: In a situation where your input sequences have a variable length,
# use input_shape=(None, num_feature).
model.add(LSTM(config.hidden_size, input_shape=(maxlen, len(chars))))
# As the decoder RNN's input, repeatedly provide the last hidden state of
# the encoder for each time step. Repeat 'digits + 1' times as that's the
# maximum length of output, e.g., when digits=3, max output is 999+999=1998.
model.add(RepeatVector(config.digits + 1))
# return_sequences=True yields the output at every timestep, shaped
# (num_samples, timesteps, output_dim), which TimeDistributed below expects.
model.add(LSTM(config.hidden_size, return_sequences=True))

# Apply a dense layer to every temporal slice of the input: for each step
# of the output sequence, decide which character should be chosen.
model.add(TimeDistributed(Dense(len(chars), activation='softmax')))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
@@ -150,7 +139,7 @@ def decode(self, x, calc_argmax=True):
150
139
print ('-' * 50 )
151
140
print ('Iteration' , iteration )
152
141
model .fit (x_train , y_train ,
153
- batch_size = BATCH_SIZE ,
142
+ batch_size = config . batch_size ,
154
143
epochs = 1 ,
155
144
validation_data = (x_val , y_val ),callbacks = [WandbCallback ()])
156
145
# Select 10 samples from the validation set at random so we can visualize
@@ -162,7 +151,7 @@ def decode(self, x, calc_argmax=True):
162
151
q = ctable .decode (rowx [0 ])
163
152
correct = ctable .decode (rowy [0 ])
164
153
guess = ctable .decode (preds [0 ], calc_argmax = False )
165
- print ('Q' , q [::- 1 ] if REVERSE else q , end = ' ' )
154
+ print ('Q' , q [::- 1 ] if config . reverse else q , end = ' ' )
166
155
print ('T' , correct , end = ' ' )
167
156
if correct == guess :
168
157
print ('☑' , end = ' ' )
0 commit comments