@@ -1,13 +1,3 @@
-
- '''Example script to generate text from Nietzsche's writings.
- At least 20 epochs are required before the generated text
- starts sounding coherent.
- It is recommended to run this script on GPU, as recurrent
- networks are quite computationally intensive.
- If you try this script on new data, make sure your corpus
- has at least ~100k characters. ~1M is better.
- '''
-
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation
@@ -19,62 +9,50 @@
import sys
import io
import wandb
- from wandb.wandb_keras import WandbKerasCallback
- from keras.callbacks import ModelCheckpoint
-
-
+ from wandb.keras import WandbCallback
import argparse

-
parser = argparse.ArgumentParser()
- parser.add_argument("text", type=str,
-                     help="the text file to learn from")
+ parser.add_argument("text", type=str)

args = parser.parse_args()

run = wandb.init()
config = run.config
config.hidden_nodes = 128
+ config.batch_size = 256
config.file = args.text
+ config.maxlen = 200
+ config.step = 3

-
- path = args.text
- text = io.open(path, encoding='utf-8').read().lower()
- print('corpus length:', len(text))
-
+ text = io.open(config.file, encoding='utf-8').read()
chars = sorted(list(set(text)))
- print('total chars:', len(chars))
+
char_indices = dict((c, i) for i, c in enumerate(chars))
indices_char = dict((i, c) for i, c in enumerate(chars))

- # cut the text in semi-redundant sequences of maxlen characters
- maxlen = 40
- step = 3
+ # build a sequence for every <config.step>-th character in the text
+
sentences = []
next_chars = []
- for i in range(0, len(text) - maxlen, step):
-     sentences.append(text[i: i + maxlen])
-     next_chars.append(text[i + maxlen])
- print('nb sequences:', len(sentences))
+ for i in range(0, len(text) - config.maxlen, config.step):
+     sentences.append(text[i: i + config.maxlen])
+     next_chars.append(text[i + config.maxlen])
+
+ # build one-hot encoded input x and output y, where x is a window of
+ # characters from the text and y is the next character in the text

- print('Vectorization...')
- x = np.zeros((len(sentences), maxlen, len(chars)), dtype=np.bool)
+ x = np.zeros((len(sentences), config.maxlen, len(chars)), dtype=np.bool)
y = np.zeros((len(sentences), len(chars)), dtype=np.bool)
for i, sentence in enumerate(sentences):
    for t, char in enumerate(sentence):
        x[i, t, char_indices[char]] = 1
    y[i, char_indices[next_chars[i]]] = 1

-
- # build the model: a single LSTM
- print('Build model...')
model = Sequential()
- model.add(LSTM(128, input_shape=(maxlen, len(chars))))
+ model.add(LSTM(128, input_shape=(config.maxlen, len(chars))))
model.add(Dense(len(chars), activation='softmax'))
-
-
- optimizer = RMSprop(lr=0.01)
- model.compile(loss='categorical_crossentropy', optimizer=optimizer)
+ model.compile(loss='categorical_crossentropy', optimizer="rmsprop")


def sample(preds, temperature=1.0):
@@ -88,20 +66,20 @@ def sample(preds, temperature=1.0):

class SampleText(keras.callbacks.Callback):
    def on_epoch_end(self, batch, logs={}):
-         start_index = random.randint(0, len(text) - maxlen - 1)
+         start_index = random.randint(0, len(text) - config.maxlen - 1)

-         for diversity in [0.2, 0.5, 1.0, 1.2]:
+         for diversity in [0.5, 1.2]:
            print()
            print('----- diversity:', diversity)

            generated = ''
-             sentence = text[start_index: start_index + maxlen]
+             sentence = text[start_index: start_index + config.maxlen]
            generated += sentence
            print('----- Generating with seed: "' + sentence + '"')
            sys.stdout.write(generated)

-             for i in range(50):
-                 x_pred = np.zeros((1, maxlen, len(chars)))
+             for i in range(200):
+                 x_pred = np.zeros((1, config.maxlen, len(chars)))
                for t, char in enumerate(sentence):
                    x_pred[0, t, char_indices[char]] = 1.

@@ -115,10 +93,6 @@ def on_epoch_end(self, batch, logs={}):
                sys.stdout.write(next_char)
                sys.stdout.flush()
            print()
- # train the model, output generated text after each iteration
- filepath = str(run.dir)+"/model-{epoch:02d}-{loss:.4f}.hdf5"
-
-
- model.fit(x, y,
-           batch_size=config.hidden_nodes,
-           epochs=1000, callbacks=[SampleText(), WandbKerasCallback()])
+
+ model.fit(x, y, batch_size=config.batch_size,
+           epochs=1000, callbacks=[SampleText(), WandbCallback()])
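
Note: the hunk headers above skip the unchanged body of sample(). For readers
without the full file, here is a minimal sketch of a temperature-scaled sampler
consistent with how sample() is called in the diversity loop (an assumption
based on the stock Keras text-generation example, not the verified contents of
this file):

    import numpy as np

    def sample(preds, temperature=1.0):
        # rescale the predicted distribution by the temperature and
        # renormalize: low temperature sharpens the distribution toward the
        # most likely character, high temperature flattens it toward uniform
        preds = np.asarray(preds).astype('float64')
        preds = np.log(preds) / temperature
        exp_preds = np.exp(preds)
        preds = exp_preds / np.sum(exp_preds)
        # draw one one-hot multinomial sample and return the chosen index
        probas = np.random.multinomial(1, preds, 1)
        return np.argmax(probas)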
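
The second hunk likewise elides the middle of the generation loop, which
presumably predicts a distribution over the next character from x_pred,
samples it at the current diversity (temperature), and slides the seed window
forward. A sketch under the same assumption, using the names defined in the
script above:

    preds = model.predict(x_pred, verbose=0)[0]
    next_index = sample(preds, diversity)
    next_char = indices_char[next_index]
    generated += next_char
    # drop the oldest character and append the new one, keeping the seed
    # window at config.maxlen characters
    sentence = sentence[1:] + next_char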