Skip to content

Commit 2cae5a2

Browse files
author
l2k2
committed
back to simple rnn
1 parent f7b37f2 commit 2cae5a2

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

videos/text-gen/char-gen.py

Lines changed: 98 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,98 @@
import argparse
import io
import random
import sys

import numpy as np

import keras
from keras.layers import Dense, Activation
from keras.layers import LSTM, SimpleRNN
from keras.models import Sequential
from keras.optimizers import RMSprop
from keras.utils.data_utils import get_file
import wandb
from wandb.keras import WandbCallback
# Read the path of the training corpus from the command line.
parser = argparse.ArgumentParser()
parser.add_argument("text", type=str)
args = parser.parse_args()

# Start a Weights & Biases run and record the hyperparameters on it.
run = wandb.init()
config = run.config
config.hidden_nodes = 128  # size of the RNN hidden state
config.batch_size = 256    # training batch size
config.file = args.text    # path to the corpus file
config.maxlen = 200        # characters per input sequence
config.step = 3            # stride between consecutive sequences
# Read the whole corpus and derive the character vocabulary from it.
# Use a context manager so the file handle is closed (the original
# io.open(...).read() leaked it).
with io.open(config.file, encoding='utf-8') as f:
    text = f.read()
chars = sorted(list(set(text)))

# Bidirectional lookup tables between characters and integer indices.
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

# Build a training sequence starting at every <config.step>-th character
# of the text, each paired with the character that immediately follows it.
sentences = []
next_chars = []
for i in range(0, len(text) - config.maxlen, config.step):
    sentences.append(text[i: i + config.maxlen])
    next_chars.append(text[i + config.maxlen])
# One-hot encode the training data: x[i] is a (maxlen, vocab) matrix of
# input characters, y[i] is the one-hot next character for that sequence.
# The builtin bool is used as dtype — the np.bool alias was removed in
# NumPy 1.24.
x = np.zeros((len(sentences), config.maxlen, len(chars)), dtype=bool)
y = np.zeros((len(sentences), len(chars)), dtype=bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

# A single recurrent layer followed by a softmax over the vocabulary.
# The hidden size comes from config.hidden_nodes (not a hard-coded 128)
# so the logged hyperparameter always matches the actual model.
model = Sequential()
model.add(SimpleRNN(config.hidden_nodes,
                    input_shape=(config.maxlen, len(chars))))
model.add(Dense(len(chars), activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer="rmsprop")
57+
58+
def sample(preds, temperature=1.0):
    """Sample an index from a probability distribution.

    Args:
        preds: 1-D array-like of probabilities over the vocabulary.
        temperature: softmax temperature; values < 1 sharpen the
            distribution, values > 1 flatten it.

    Returns:
        The sampled integer index.
    """
    preds = np.asarray(preds).astype('float64')
    # Floor the probabilities before the log so an exact zero does not
    # produce -inf (and NaN after the temperature division).
    preds = np.log(np.maximum(preds, 1e-10)) / temperature
    exp_preds = np.exp(preds)
    # Renormalize so the reweighted values form a valid distribution.
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
67+
class SampleText(keras.callbacks.Callback):
    """Keras callback that prints generated text at the end of each epoch.

    Uses the module-level text, config, model, char tables, and sample().
    """

    # The first positional argument is the epoch index (renamed from the
    # misleading `batch`); logs defaults to None rather than a shared
    # mutable {} default.
    def on_epoch_end(self, epoch, logs=None):
        # Seed the generator with a random window of the training text;
        # the same seed window is reused for every diversity value.
        start_index = random.randint(0, len(text) - config.maxlen - 1)

        for diversity in [0.5, 1.2]:
            print()
            print('----- diversity:', diversity)

            generated = ''
            sentence = text[start_index: start_index + config.maxlen]
            generated += sentence
            print('----- Generating with seed: "' + sentence + '"')
            sys.stdout.write(generated)

            # Generate 200 characters, feeding each prediction back in.
            for i in range(200):
                # One-hot encode the current window for the model.
                x_pred = np.zeros((1, config.maxlen, len(chars)))
                for t, char in enumerate(sentence):
                    x_pred[0, t, char_indices[char]] = 1.

                preds = model.predict(x_pred, verbose=0)[0]
                next_index = sample(preds, diversity)
                next_char = indices_char[next_index]

                generated += next_char
                # Slide the window: drop the first char, append the new one.
                sentence = sentence[1:] + next_char

                sys.stdout.write(next_char)
                sys.stdout.flush()
            print()
96+
97+
# Train for 100 epochs; after every epoch the SampleText callback prints
# generated text and WandbCallback logs metrics to Weights & Biases.
model.fit(
    x,
    y,
    batch_size=config.batch_size,
    epochs=100,
    callbacks=[SampleText(), WandbCallback()],
)

0 commit comments

Comments
 (0)