Commit ac49412

Author l2k2 committed: cleanup
1 parent e5d6730 commit ac49412

File tree: 1 file changed (+26, -52 lines)

keras-lstm/lstm-train.py

Lines changed: 26 additions & 52 deletions
@@ -1,13 +1,3 @@
-
-'''Example script to generate text from Nietzsche's writings.
-At least 20 epochs are required before the generated text
-starts sounding coherent.
-It is recommended to run this script on GPU, as recurrent
-networks are quite computationally intensive.
-If you try this script on new data, make sure your corpus
-has at least ~100k characters. ~1M is better.
-'''
-
 import keras
 from keras.models import Sequential
 from keras.layers import Dense, Activation
@@ -19,62 +9,50 @@
 import sys
 import io
 import wandb
-from wandb.wandb_keras import WandbKerasCallback
-from keras.callbacks import ModelCheckpoint
-
-
+from wandb.keras import WandbCallback
 import argparse
 
-
 parser = argparse.ArgumentParser()
-parser.add_argument("text", type=str,
-                    help="the text file to learn from")
+parser.add_argument("text", type=str)
 
 args = parser.parse_args()
 
 run = wandb.init()
 config = run.config
 config.hidden_nodes = 128
+config.batch_size = 256
 config.file = args.text
+config.maxlen = 200
+config.step = 3
 
-
-path = args.text
-text = io.open(path, encoding='utf-8').read().lower()
-print('corpus length:', len(text))
-
+text = io.open(config.file, encoding='utf-8').read()
 chars = sorted(list(set(text)))
-print('total chars:', len(chars))
+
 char_indices = dict((c, i) for i, c in enumerate(chars))
 indices_char = dict((i, c) for i, c in enumerate(chars))
 
-# cut the text in semi-redundant sequences of maxlen characters
-maxlen = 40
-step = 3
+# build a sequence for every <config.step>-th character in the text
+
 sentences = []
 next_chars = []
-for i in range(0, len(text) - maxlen, step):
-    sentences.append(text[i: i + maxlen])
-    next_chars.append(text[i + maxlen])
-print('nb sequences:', len(sentences))
+for i in range(0, len(text) - config.maxlen, config.step):
+    sentences.append(text[i: i + config.maxlen])
+    next_chars.append(text[i + config.maxlen])
+
+# build up one-hot encoded input x and output y where x is a character
+# in the text y is the next character in the text
 
-print('Vectorization...')
-x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)
+x = np.zeros((len(sentences), config.maxlen, len(chars)), dtype=np.bool)
 y = np.zeros((len(sentences), len(chars)), dtype=np.bool)
 for i, sentence in enumerate(sentences):
     for t, char in enumerate(sentence):
         x[i, t, char_indices[char]] = 1
     y[i, char_indices[next_chars[i]]] = 1
 
-
-# build the model: a single LSTM
-print('Build model...')
 model = Sequential()
-model.add(LSTM(128, input_shape=(maxlen, len(chars))))
+model.add(LSTM(128, input_shape=(config.maxlen, len(chars))))
 model.add(Dense(len(chars), activation='softmax'))
-
-
-optimizer = RMSprop(lr=0.01)
-model.compile(loss='categorical_crossentropy', optimizer=optimizer)
+model.compile(loss='categorical_crossentropy', optimizer="rmsprop")
 
 
 def sample(preds, temperature=1.0):
@@ -88,20 +66,20 @@ def sample(preds, temperature=1.0):
 
 class SampleText(keras.callbacks.Callback):
     def on_epoch_end(self, batch, logs={}):
-        start_index = random.randint(0, len(text) - maxlen - 1)
+        start_index = random.randint(0, len(text) - config.maxlen - 1)
 
-        for diversity in [0.2, 0.5, 1.0, 1.2]:
+        for diversity in [0.5, 1.2]:
            print()
            print('----- diversity:', diversity)
 
            generated = ''
-            sentence = text[start_index: start_index + maxlen]
+            sentence = text[start_index: start_index + config.maxlen]
            generated += sentence
            print('----- Generating with seed: "' + sentence + '"')
            sys.stdout.write(generated)
 
-            for i in range(50):
-                x_pred = np.zeros((1, maxlen, len(chars)))
+            for i in range(200):
+                x_pred = np.zeros((1, config.maxlen, len(chars)))
                for t, char in enumerate(sentence):
                    x_pred[0, t, char_indices[char]] = 1.
 
@@ -115,10 +93,6 @@ def on_epoch_end(self, batch, logs={}):
                sys.stdout.write(next_char)
                sys.stdout.flush()
            print()
-# train the model, output generated text after each iteration
-filepath=str(run.dir)+"/model-{epoch:02d}-{loss:.4f}.hdf5"
-
-
-model.fit(x, y,
-          batch_size=config.hidden_nodes,
-          epochs=1000, callbacks=[SampleText(), WandbKerasCallback()])
+
+model.fit(x, y, batch_size=config.batch_size,
+          epochs=1000, callbacks=[SampleText(), WandbCallback()])
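Note: the hunk header `@@ -88,20 +66,20 @@ def sample(preds, temperature=1.0):` shows the body of `sample` only as unchanged context, so it is not part of this diff. For readers following along, here is a minimal sketch of that helper, assuming it matches the standard Keras text-generation example this script is based on:

    import numpy as np

    def sample(preds, temperature=1.0):
        # rescale the predicted character distribution by temperature:
        # temperatures < 1 sharpen it, temperatures > 1 flatten it
        preds = np.asarray(preds).astype('float64')
        preds = np.log(preds) / temperature
        exp_preds = np.exp(preds)
        preds = exp_preds / np.sum(exp_preds)
        # draw a single character index from the rescaled distribution
        probas = np.random.multinomial(1, preds, 1)
        return np.argmax(probas)

Under that reading, the `diversity` values 0.5 and 1.2 passed by `SampleText` act as a conservative and a more adventurous sampling temperature.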
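With the simplified argument parser, the script takes the corpus path as its only positional argument, so a plausible invocation (file name hypothetical) is:

    python lstm-train.py nietzsche.txt

Since the hyperparameters now live on `run.config` (`batch_size`, `maxlen`, `step`), wandb records them with the run, and the fit call no longer reuses `config.hidden_nodes` as the batch size, as the old code did.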
