Skip to content

Commit 9db7787

Browse files
dasguptarapaszke
authored andcommitted
Updating __getitem__ and __len__ for containers (pytorch#1544)
1 parent d1a4467 commit 9db7787

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

torch/nn/modules/container.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,18 @@ def __init__(self, *args):
5050
self.add_module(str(idx), module)
5151

5252
def __getitem__(self, idx):
53-
if idx < 0 or idx >= len(self._modules):
53+
if not (-len(self) <= idx < len(self)):
5454
raise IndexError('index {} is out of range'.format(idx))
55+
if idx < 0:
56+
idx += len(self)
5557
it = iter(self._modules.values())
5658
for i in range(idx):
5759
next(it)
5860
return next(it)
5961

62+
def __len__(self):
63+
return len(self._modules)
64+
6065
def forward(self, input):
6166
for module in self._modules.values():
6267
input = module(input)
@@ -92,6 +97,8 @@ def __init__(self, modules=None):
9297
self += modules
9398

9499
def __getitem__(self, idx):
100+
if not (-len(self) <= idx < len(self)):
101+
raise IndexError('index {} is out of range'.format(idx))
95102
if idx < 0:
96103
idx += len(self)
97104
return self._modules[str(idx)]
@@ -161,6 +168,8 @@ def __init__(self, parameters=None):
161168
self += parameters
162169

163170
def __getitem__(self, idx):
171+
if not (-len(self) <= idx < len(self)):
172+
raise IndexError('index {} is out of range'.format(idx))
164173
if idx < 0:
165174
idx += len(self)
166175
return self._parameters[str(idx)]

0 commit comments

Comments
 (0)