@@ -83,7 +83,7 @@ def get_num_weights(handle, rnn_desc, x_desc, datatype):
         datatype
     ))
     elem_size = cudnn._sizeofmap[datatype]
-    assert (weight_size.value % elem_size == 0)
+    assert weight_size.value % elem_size == 0
     return weight_size.value // elem_size


@@ -142,10 +142,11 @@ def get_parameters(fn, handle, weight_buf):
                     ctypes.byref(nb_dims),
                     ctypes.c_void_p(filter_dim_a.data_ptr())))

-                filter_dim_a.resize_(nb_dims.value)
+                assert nb_dims.value <= min_dim
+                filter_dim_a = filter_dim_a[:nb_dims.value]
                 elem_size = cudnn._sizeofmap[fn.datatype]
                 offset_bytes = (matrix_pointer.value - weight_buf.data_ptr())
-                assert (offset_bytes % elem_size == 0)
+                assert offset_bytes % elem_size == 0
                 offset = offset_bytes // elem_size

                 # for all the RNN types provided by CUDNN, all the ih weights
@@ -154,13 +155,13 @@ def get_parameters(fn, handle, weight_buf):
                 # Since we're storing all the weights in a single tensor anyway,
                 # might as well merge the CUDNN ones into a single tensor as well
                 if linear_id == 0 or linear_id == num_linear_layers / 2:
-                    assert (filter_dim_a.prod() == filter_dim_a[0])
+                    assert filter_dim_a.prod() == filter_dim_a[0]
                     param = fn.weight_buf.new().set_(
                         weight_buf.storage(), offset,
                         filter_dim_a[0] * num_linear_layers // 2, filter_dim_a[2])
                     layer_params.append(param)
                 else:
-                    assert (cur_offset == offset)
+                    assert cur_offset == offset

                 cur_offset = offset + filter_dim_a[0]

@@ -172,7 +173,7 @@ def get_parameters(fn, handle, weight_buf):
 def _copyParams(params_from, params_to):
     for layer_params_from, layer_params_to in zip(params_from, params_to):
         for param_from, param_to in zip(layer_params_from, layer_params_to):
-            assert (param_from.type() == param_to.type())
+            assert param_from.type() == param_to.type()
             param_to.copy_(param_from)


@@ -206,9 +207,9 @@ def forward(fn, input, hx, weight, output, hy):
         output_size = _output_size(fn)
         x = input.contiguous()
         output.resize_(*output_size)
-        hy.resize_(*hidden_size).zero_()
+        hy.resize_(*hidden_size)
         if cy is not None:
-            cy.resize_(*hidden_size).zero_()
+            cy.resize_(*hidden_size)
         y = output

         # init descriptors
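
For reference, a minimal sketch of the slicing idiom introduced in get_parameters() above, using stand-in values rather than the actual cuDNN descriptor output (filter_dim_a below is an ordinary tensor, not the array filled by cudnnGetFilterNdDescriptor):

    import torch

    # Stand-ins for values the real code reads back from cuDNN
    min_dim = 3
    filter_dim_a = torch.IntTensor([10, 1, 4])  # hypothetical filter dimensions
    nb_dims = 2                                 # hypothetical number of valid dims

    # Old approach: shrink the tensor in place, discarding trailing entries
    # filter_dim_a.resize_(nb_dims)

    # New approach: check the bound, then keep a view of the valid prefix
    assert nb_dims <= min_dim
    filter_dim_a = filter_dim_a[:nb_dims]
    print(filter_dim_a)  # first two entries: 10, 1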