Skip to content

Commit d82cad3

Browse files
chenyuntcsoumith
authored andcommitted
implement nn.Module.__dir__ (pytorch#1142)
1 parent 9504246 commit d82cad3

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

test/test_nn.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,19 @@ def test_children(self):
467467
s = nn.Sequential(l1, l2, l1, l2, subnet)
468468
self.assertEqual(list(s.children()), [l1, l2, subnet])
469469

470+
def test_dir(self):
471+
linear = nn.Linear(2, 2)
472+
linear._test_submodule = nn.Linear(2, 2)
473+
linear._test_parameter = Parameter(torch.Tensor(2, 2))
474+
linear.register_buffer('_test_buffer', torch.Tensor(2, 2))
475+
keys = linear.__dir__()
476+
self.assertIn('_test_submodule', keys)
477+
self.assertIn('_test_parameter', keys)
478+
self.assertIn('_test_buffer', keys)
479+
480+
for key in keys:
481+
self.assertTrue(hasattr(linear, key))
482+
470483
def test_named_children(self):
471484
l1 = nn.Linear(2, 2)
472485
l2 = nn.Linear(2, 2)

torch/nn/modules/module.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,12 @@ def __repr__(self):
462462
tmpstr = tmpstr + ' (' + key + '): ' + modstr + '\n'
463463
tmpstr = tmpstr + ')'
464464
return tmpstr
465+
466+
def __dir__(self):
467+
module_attrs = dir(self.__class__)
468+
attrs = list(self.__dict__.keys())
469+
parameters = list(self._parameters.keys())
470+
modules = list(self._modules.keys())
471+
buffers = list(self._buffers.keys())
472+
keys = module_attrs + attrs + parameters + modules + buffers
473+
return sorted(keys)

0 commit comments

Comments
 (0)