
Commit 0180e63

Remove unnecessary zero_() calls in cuDNN RNN
1 parent: 95c6ae0
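The dropped zero_() calls look like redundant work: hy (and, for LSTM, cy) are pure output buffers that the cuDNN forward call overwrites in full, so resizing them is presumably enough. A minimal standalone sketch of the pattern, with a made-up hidden_size and only a comment standing in for the cuDNN call:

import torch

# Hypothetical shape, for illustration only.
hidden_size = (1, 4, 8)

hy = torch.FloatTensor()

# Before this commit: hy.resize_(*hidden_size).zero_()
# The zero_() launches an extra fill whose result is immediately overwritten.
hy.resize_(*hidden_size)

# ... the cuDNN forward pass would write every element of hy here ...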

1 file changed: +9 −8 lines


torch/backends/cudnn/rnn.py: 9 additions & 8 deletions
@@ -83,7 +83,7 @@ def get_num_weights(handle, rnn_desc, x_desc, datatype):
         datatype
     ))
     elem_size = cudnn._sizeofmap[datatype]
-    assert(weight_size.value % elem_size == 0)
+    assert weight_size.value % elem_size == 0
     return weight_size.value // elem_size


@@ -142,10 +142,11 @@ def get_parameters(fn, handle, weight_buf):
                     ctypes.byref(nb_dims),
                     ctypes.c_void_p(filter_dim_a.data_ptr())))

-                filter_dim_a.resize_(nb_dims.value)
+                assert nb_dims.value <= min_dim
+                filter_dim_a = filter_dim_a[:nb_dims.value]
                 elem_size = cudnn._sizeofmap[fn.datatype]
                 offset_bytes = (matrix_pointer.value - weight_buf.data_ptr())
-                assert(offset_bytes % elem_size == 0)
+                assert offset_bytes % elem_size == 0
                 offset = offset_bytes // elem_size

                 # for all the RNN types provided by CUDNN, all the ih weights
@@ -154,13 +155,13 @@ def get_parameters(fn, handle, weight_buf):
                 # Since we're storing all the weights in a single tensor anyway,
                 # might as well merge the CUDNN ones into a single tensor as well
                 if linear_id == 0 or linear_id == num_linear_layers / 2:
-                    assert(filter_dim_a.prod() == filter_dim_a[0])
+                    assert filter_dim_a.prod() == filter_dim_a[0]
                     param = fn.weight_buf.new().set_(
                         weight_buf.storage(), offset,
                         filter_dim_a[0] * num_linear_layers // 2, filter_dim_a[2])
                     layer_params.append(param)
                 else:
-                    assert(cur_offset == offset)
+                    assert cur_offset == offset

                 cur_offset = offset + filter_dim_a[0]

@@ -172,7 +173,7 @@ def get_parameters(fn, handle, weight_buf):
 def _copyParams(params_from, params_to):
     for layer_params_from, layer_params_to in zip(params_from, params_to):
         for param_from, param_to in zip(layer_params_from, layer_params_to):
-            assert(param_from.type() == param_to.type())
+            assert param_from.type() == param_to.type()
             param_to.copy_(param_from)


@@ -206,9 +207,9 @@ def forward(fn, input, hx, weight, output, hy):
         output_size = _output_size(fn)
         x = input.contiguous()
         output.resize_(*output_size)
-        hy.resize_(*hidden_size).zero_()
+        hy.resize_(*hidden_size)
         if cy is not None:
-            cy.resize_(*hidden_size).zero_()
+            cy.resize_(*hidden_size)
         y = output

         # init descriptors
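The filter_dim_a change swaps an in-place resize_ for a bounds check plus a slice: the min_dim-sized buffer whose pointer is handed to cudnnGetFilterNdDescriptor stays intact, and the Python-side tensor becomes a view of its first nb_dims entries. A minimal standalone sketch, with nb_dims filled in by hand rather than by cuDNN:

import ctypes
import torch

min_dim = 3
filter_dim_a = torch.IntTensor(min_dim)  # buffer cuDNN would fill through its data_ptr()
nb_dims = ctypes.c_int(2)                # in the real code, written back by cuDNN

# Old: filter_dim_a.resize_(nb_dims.value)  -- mutates the tensor in place.
# New: check the reported rank fits the buffer, then view only its valid prefix.
assert nb_dims.value <= min_dim
filter_dim_a = filter_dim_a[:nb_dims.value]
print(filter_dim_a.size())  # torch.Size([2])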
