@@ -41,18 +41,19 @@ def __init__(self, shape, scope=None, dtype=None):
 
         # Projection on the keyboard
         with tf.variable_scope('weights_' + self.scope):
-            self.W = tf.get_variable(
+            self.W_t = tf.get_variable(
                 'weights',
                 shape,
                 # initializer=tf.truncated_normal_initializer()  # TODO: Tune value (fct of input size: 1/sqrt(input_dim))
                 dtype=dtype
             )
             self.b = tf.get_variable(
                 'bias',
-                shape[1],
+                shape[0],
                 initializer=tf.constant_initializer(),
                 dtype=dtype
             )
+            self.W = tf.transpose(self.W_t)
 
     def getWeights(self):
         """ Convenience method for some tf arguments
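A note on this hunk, for readers of the diff: tf.nn.sampled_softmax_loss expects its weights argument with shape [num_classes, dim], so the variable is now stored pre-transposed as W_t with shape (vocab_size, hidden_size), and W = tf.transpose(W_t) recovers the (hidden_size, vocab_size) layout used for the decoder's output projection. A minimal sketch of the shapes involved (TF 1.x API; the sizes, placeholders, and num_sampled value are made up for illustration):

    import tensorflow as tf

    hidden_size, vocab_size = 256, 10000  # illustrative sizes only

    # Stored layout [vocab_size, hidden_size]: directly usable by the loss.
    W_t = tf.get_variable('weights', (vocab_size, hidden_size), dtype=tf.float32)
    b = tf.get_variable('bias', [vocab_size],
                        initializer=tf.constant_initializer(), dtype=tf.float32)
    W = tf.transpose(W_t)  # [hidden_size, vocab_size], projects decoder outputs

    labels = tf.placeholder(tf.int64, [None, 1])              # one true class per example
    inputs = tf.placeholder(tf.float32, [None, hidden_size])  # decoder outputs

    loss = tf.nn.sampled_softmax_loss(
        weights=W_t,  # already [num_classes, dim]; no transpose at loss time
        biases=b,
        labels=labels,
        inputs=inputs,
        num_sampled=512,
        num_classes=vocab_size)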
@@ -114,7 +115,7 @@ def buildNetwork(self):
         # Sampled softmax only makes sense if we sample less than vocabulary size.
         if 0 < self.args.softmaxSamples < self.textData.getVocabularySize():
             outputProjection = ProjectionOp(
-                (self.args.hiddenSize, self.textData.getVocabularySize()),
+                (self.textData.getVocabularySize(), self.args.hiddenSize),
                 scope='softmax_projection',
                 dtype=self.dtype
             )
@@ -124,7 +125,7 @@ def sampledSoftmax(labels, inputs):
 
                 # We need to compute the sampled_softmax_loss using 32bit floats to
                 # avoid numerical instabilities.
-                localWt = tf.cast(tf.transpose(outputProjection.W), tf.float32)
+                localWt = tf.cast(outputProjection.W_t, tf.float32)
                 localB = tf.cast(outputProjection.b, tf.float32)
                 localInputs = tf.cast(inputs, tf.float32)
 
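The payoff of the change lands here: because W_t is already stored in the [num_classes, dim] layout that sampled_softmax_loss expects, the per-call tf.transpose can be dropped and the weights are cast to float32 directly.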