
Commit 9ab0104

committed: new attention example
1 parent 0d14fc5 commit 9ab0104

1 file changed: 169 additions, 0 deletions
@@ -0,0 +1,169 @@

from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.layers import Embedding, CuDNNLSTM
from keras.layers import Conv1D, Flatten, Layer
from keras import initializers, regularizers, constraints

import wandb
from wandb.keras import WandbCallback
import imdb  # local imdb.py helper providing load_imdb() with raw review text and labels
import numpy as np
from keras.preprocessing import text
import keras.backend as K

# from https://gist.github.com/cbaziotis/7ef97ccf71cbc14366835198c09809d2

def dot_product(x, kernel):
    """
    Wrapper for the dot product operation, for compatibility with both
    the Theano and TensorFlow backends.
    Args:
        x (): input
        kernel (): weights
    Returns:
    """
    if K.backend() == 'tensorflow':
        return K.squeeze(K.dot(x, K.expand_dims(kernel)), axis=-1)
    else:
        return K.dot(x, kernel)

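# dot_product() is applied below to both a 2D kernel (x · W) and a 1D kernel
# (uit · u); the expand_dims/squeeze pair is what lets the 1D case go through
# K.dot under the TensorFlow backend.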


class AttentionWithContext(Layer):
    """
    Attention operation, with a context/query vector, for temporal data.
    Supports Masking.
    Follows the work of Yang et al. [https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf]
    "Hierarchical Attention Networks for Document Classification"
    by using a context vector to assist the attention.
    # Input shape
        3D tensor with shape: `(samples, steps, features)`.
    # Output shape
        2D tensor with shape: `(samples, features)`.
    How to use:
        Just put it on top of an RNN layer (GRU/LSTM/SimpleRNN) with return_sequences=True.
        The dimensions are inferred based on the output shape of the RNN.
        Note: The layer has been tested with Keras 2.0.6
        Example:
            model.add(LSTM(64, return_sequences=True))
            model.add(AttentionWithContext())
            # next add a Dense layer (for classification/regression) or whatever...
    """

    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 bias=True, **kwargs):

        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        super(AttentionWithContext, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3

        self.W = self.add_weight((input_shape[-1], input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight((input_shape[-1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)

        self.u = self.add_weight((input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)

        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

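    # call() implements the attention step from Yang et al.:
    #   uit = tanh(x · W + b)    per-timestep hidden representation
    #   ait = uit · u            similarity to the learned context vector u
    #   a   = softmax(ait)       attention weights over timesteps (with optional masking)
    #   out = sum_t a_t * x_t    attention-weighted sum of the input timesteps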
    def call(self, x, mask=None):
        uit = dot_product(x, self.W)

        if self.bias:
            uit += self.b

        uit = K.tanh(uit)
        ait = dot_product(uit, self.u)

        a = K.exp(ait)

        # apply mask after the exp; will be re-normalized next
        if mask is not None:
            # cast the mask to floatX to avoid float64 upcasting in Theano
            a *= K.cast(mask, K.floatx())

        # In some cases, especially early in training, the sum may be almost zero,
        # which results in NaNs. A workaround is to add a very small positive
        # number ε to the sum.
        # a /= K.cast(K.sum(a, axis=1, keepdims=True), K.floatx())
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        return input_shape[0], input_shape[-1]


wandb.init()
config = wandb.config

# set parameters:
config.vocab_size = 1000
config.maxlen = 300
config.batch_size = 32
config.embedding_dims = 50
config.filters = 250
config.kernel_size = 3
config.hidden_dims = 100
config.epochs = 10

(X_train, y_train), (X_test, y_test) = imdb.load_imdb()

tokenizer = text.Tokenizer(num_words=config.vocab_size)
tokenizer.fit_on_texts(X_train)
X_train = tokenizer.texts_to_sequences(X_train)
X_test = tokenizer.texts_to_sequences(X_test)

X_train = sequence.pad_sequences(X_train, maxlen=config.maxlen)
X_test = sequence.pad_sequences(X_test, maxlen=config.maxlen)

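# Each review is now a fixed-length sequence of config.maxlen word indices,
# restricted to the config.vocab_size most frequent words in the training set.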
model = Sequential()
model.add(Embedding(config.vocab_size,
                    config.embedding_dims,
                    input_length=config.maxlen))
model.add(CuDNNLSTM(config.hidden_dims, return_sequences=True))
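# Note: CuDNNLSTM only runs on a CUDA-enabled GPU; on CPU, keras.layers.LSTM with
# the same arguments should work as a drop-in replacement (just slower).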
model.add(AttentionWithContext())
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])
model.summary()

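# WandbCallback logs the per-epoch training and validation metrics to the
# wandb run initialized above.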
model.fit(X_train, y_train,
          batch_size=config.batch_size,
          epochs=config.epochs,
          validation_data=(X_test, y_test), callbacks=[WandbCallback()])
