from keras.layers import CuDNNLSTM
from keras.layers.wrappers import TimeDistributed, Bidirectional
from attention_decoder import AttentionDecoder
-from nmt import simpleNMT
from reader import Data, Vocabulary
import numpy as np
from keras import backend as K
def run_example(model, input_vocabulary, output_vocabulary, text):
+    """Predict a single example"""
    encoded = input_vocabulary.string_to_int(text)
    prediction = model.predict(np.array([encoded]))
    prediction = np.argmax(prediction[0], axis=-1)
-    return "".join([s for s in output_vocabulary.int_to_string(prediction) if s != "<unk>"])
+    return output_vocabulary.int_to_string(prediction)
+
+
+def decode(chars, sanitize=False):
+    """Join a list of chars removing <unk> and invalid utf-8"""
+    string = "".join([c for c in chars if c != "<unk>"])
+    if sanitize:
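+        # keep only characters below code point 2048 before the utf-8 round trip below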
+        string = "".join(i for i in string if ord(i) < 2048)
+    return bytes(string, 'utf-8').decode('utf-8', 'ignore')


class Examples(Callback):
+    """Keras callback to log examples"""
+
    def __init__(self, viz):
        self.visualizer = viz

@@ -53,22 +63,20 @@ def on_epoch_end(self, epoch, logs):
        self.visualizer.proba_model.get_layer(
            "attention_decoder_prob").set_weights(weights)
        for i, o in zip(data_in, data_out):
-            text = "".join(
-                [s for s in input_vocab.int_to_string(i) if s != "<unk>"])
-            truth = "".join([s for s in output_vocab.int_to_string(
-                np.argmax(o, -1)) if s != "<unk>"])
-            out = run_example(self.model, input_vocab, output_vocab, text)
-            print(f"{text} -> {out} ({truth})")
-            examples.append([bytes(text, 'utf-8').decode('utf-8', 'ignore'), bytes(
-                out, 'utf-8').decode('utf-8', 'ignore'), bytes(truth, 'utf-8').decode('utf-8', 'ignore')])
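+            # decode() drops <unk> tokens and invalid utf-8 so the logged examples stay clean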
+            text = decode(input_vocab.int_to_string(i)).replace('<eot>', '')
+            truth = decode(output_vocab.int_to_string(np.argmax(o, -1)), True)
+            pred = run_example(self.model, input_vocab, output_vocab, text)
+            out = decode(pred, True)
+            print(f"{decode(text, True)} -> {out} ({truth})")
+            examples.append([decode(text, True), out, truth])
            amap = self.visualizer.attention_map(text)
            if amap:
-                viz.append(wandb.Image(amap, caption=text))
+                viz.append(wandb.Image(amap))
                amap.close()
        if len(viz) > 0:
            logs["attention_map"] = viz[:5]
-        wandb.log(
-            {"examples": wandb.Table(data=examples), **logs})
+        logs["examples"] = wandb.Table(data=examples)
+        wandb.log(logs)


def all_acc(y_true, y_pred):
@@ -94,7 +102,7 @@ def all_acc(y_true, y_pred):
input_vocab = Vocabulary('./human_vocab.json', padding=config.padding)
output_vocab = Vocabulary('./machine_vocab.json', padding=config.padding)

-print('Loading datasets.')
+print('Loading datasets...')

training = Data(training_data, input_vocab, output_vocab)
validation = Data(validation_data, input_vocab, output_vocab)
@@ -125,7 +133,7 @@ def build_models(pad_length=config.padding, n_chars=input_vocab.size(), n_labels
                              name='attention_decoder_prob',
                              output_dim=n_labels,
                              return_probabilities=True,
-                             trainable=trainable)(rnn_encoded)
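+                             # kept frozen; the Examples callback overwrites these weights via set_weights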
+                             trainable=False)(rnn_encoded)

    y_pred = AttentionDecoder(decoder_units,
                              name='attention_decoder_1',
@@ -137,7 +145,7 @@ def build_models(pad_length=config.padding, n_chars=input_vocab.size(), n_labels
    model.summary()
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
-                 metrics=['accuracy', all_acc])
+                 metrics=['accuracy'])
    prob_model = Model(inputs=input_, outputs=y_prob)
    return model, prob_model