Skip to content

Commit fc5b51a

Browse files
committed
Refactored sampled softmax
1 parent e215091 commit fc5b51a

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

chatbot/model.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -41,18 +41,19 @@ def __init__(self, shape, scope=None, dtype=None):
4141

4242
# Projection on the keyboard
4343
with tf.variable_scope('weights_' + self.scope):
44-
self.W = tf.get_variable(
44+
self.W_t = tf.get_variable(
4545
'weights',
4646
shape,
4747
# initializer=tf.truncated_normal_initializer() # TODO: Tune value (fct of input size: 1/sqrt(input_dim))
4848
dtype=dtype
4949
)
5050
self.b = tf.get_variable(
5151
'bias',
52-
shape[1],
52+
shape[0],
5353
initializer=tf.constant_initializer(),
5454
dtype=dtype
5555
)
56+
self.W = tf.transpose(self.W_t)
5657

5758
def getWeights(self):
5859
""" Convenience method for some tf arguments
@@ -114,7 +115,7 @@ def buildNetwork(self):
114115
# Sampled softmax only makes sense if we sample less than vocabulary size.
115116
if 0 < self.args.softmaxSamples < self.textData.getVocabularySize():
116117
outputProjection = ProjectionOp(
117-
(self.args.hiddenSize, self.textData.getVocabularySize()),
118+
(self.textData.getVocabularySize(), self.args.hiddenSize),
118119
scope='softmax_projection',
119120
dtype=self.dtype
120121
)
@@ -124,7 +125,7 @@ def sampledSoftmax(labels, inputs):
124125

125126
# We need to compute the sampled_softmax_loss using 32bit floats to
126127
# avoid numerical instabilities.
127-
localWt = tf.cast(tf.transpose(outputProjection.W), tf.float32)
128+
localWt = tf.cast(outputProjection.W_t, tf.float32)
128129
localB = tf.cast(outputProjection.b, tf.float32)
129130
localInputs = tf.cast(inputs, tf.float32)
130131

0 commit comments

Comments (0)