
Commit 70a4342

JeanKossaifi authored and facebook-github-bot committed
Simplify init._calculate_fan_in_and_fan_out (pytorch#53522)
Summary: This uses the shape of the tensor instead of directly indexing it. This is useful when extending PyTorch's Tensor class, e.g. for lazy access. Since the `init` sub-module doesn't check for `__torch_function__`, it is not possible to override its functions. Explicitly indexing the tensor forces it to be materialized (reconstructing the full tensor or explicitly accessing its elements); simply using the shape avoids that.

Fixes pytorch#53540

Pull Request resolved: pytorch#53522

Reviewed By: anjali411

Differential Revision: D26947794

Pulled By: jbschlosser

fbshipit-source-id: 80cd65efed16383f21363cee2eb404c9bc05971c
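For context, here is a minimal sketch of what `_calculate_fan_in_and_fan_out` in `torch/nn/init.py` looks like after this change. Only the loop over `tensor.shape[2:]` comes from the diff below; the surrounding lines (the dimension check and the fan-map sizes) are assumptions about the rest of the helper, not quoted from the commit.

    def _calculate_fan_in_and_fan_out(tensor):
        # Assumed surrounding code: reject tensors that have no notion of fan-in/fan-out.
        if tensor.dim() < 2:
            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
        num_input_fmaps = tensor.size(1)
        num_output_fmaps = tensor.size(0)
        receptive_field_size = 1
        if tensor.dim() > 2:
            # From the diff: accumulate the product over the trailing dimensions
            # using only the shape, so the tensor's data is never indexed.
            for s in tensor.shape[2:]:
                receptive_field_size *= s
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
        return fan_in, fan_out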
1 parent a76b473 commit 70a4342


torch/nn/init.py

Lines changed: 4 additions & 1 deletion
@@ -274,7 +274,10 @@ def _calculate_fan_in_and_fan_out(tensor):
     num_output_fmaps = tensor.size(0)
     receptive_field_size = 1
     if tensor.dim() > 2:
-        receptive_field_size = tensor[0][0].numel()
+        # math.prod is not always available, accumulate the product manually
+        # we could use functools.reduce but that is not supported by TorchScript
+        for s in tensor.shape[2:]:
+            receptive_field_size *= s
     fan_in = num_input_fmaps * receptive_field_size
     fan_out = num_output_fmaps * receptive_field_size
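As a quick sanity check (a hypothetical snippet, not part of the commit), the shape-based loop yields the same receptive-field size as the old `tensor[0][0].numel()` for a typical 4-D convolution weight, while never touching the tensor's data:

    import torch
    from torch.nn import init

    # 4-D conv weight: (out_channels=16, in_channels=8, kH=3, kW=3)
    w = torch.empty(16, 8, 3, 3)

    fan_in, fan_out = init._calculate_fan_in_and_fan_out(w)
    print(fan_in, fan_out)   # 72 144  (8 * 3 * 3 and 16 * 3 * 3)

    # Old behaviour, for comparison: indexing materializes a sub-tensor first.
    print(w[0][0].numel())   # 9, the same receptive-field size (3 * 3)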
