@@ -41,18 +41,19 @@ def __init__(self, shape, scope=None, dtype=None):
4141
4242        # Projection on the keyboard 
4343        with  tf .variable_scope ('weights_'  +  self .scope ):
44-             self .W  =  tf .get_variable (
44+             self .W_t  =  tf .get_variable (
4545                'weights' ,
4646                shape ,
4747                # initializer=tf.truncated_normal_initializer()  # TODO: Tune value (fct of input size: 1/sqrt(input_dim)) 
4848                dtype = dtype 
4949            )
5050            self .b  =  tf .get_variable (
5151                'bias' ,
52-                 shape [1 ],
52+                 shape [0 ],
5353                initializer = tf .constant_initializer (),
5454                dtype = dtype 
5555            )
56+             self .W  =  tf .transpose (self .W_t )
5657
5758    def  getWeights (self ):
5859        """ Convenience method for some tf arguments 
@@ -114,7 +115,7 @@ def buildNetwork(self):
114115        # Sampled softmax only makes sense if we sample less than vocabulary size. 
115116        if  0  <  self .args .softmaxSamples  <  self .textData .getVocabularySize ():
116117            outputProjection  =  ProjectionOp (
117-                 (self .args . hiddenSize , self .textData . getVocabularySize () ),
118+                 (self .textData . getVocabularySize () , self .args . hiddenSize ),
118119                scope = 'softmax_projection' ,
119120                dtype = self .dtype 
120121            )
@@ -124,7 +125,7 @@ def sampledSoftmax(labels, inputs):
124125
125126                # We need to compute the sampled_softmax_loss using 32bit floats to 
126127                # avoid numerical instabilities. 
127-                 localWt      =  tf .cast (tf . transpose ( outputProjection .W ),  tf .float32 )
128+                 localWt      =  tf .cast (outputProjection .W_t ,              tf .float32 )
128129                localB       =  tf .cast (outputProjection .b ,               tf .float32 )
129130                localInputs  =  tf .cast (inputs ,                           tf .float32 )
130131
0 commit comments