Skip to content

Commit 36c2bc6

Browse files
Update module.py
1 parent b09d7c8 commit 36c2bc6

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

torch/nn/modules/module.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,26 @@ def state_dict(self, destination=None, prefix=''):
365365
if module is not None:
366366
module.state_dict(destination, prefix + name + '.')
367367
return destination
368+
369+
def state_dict_without_buffers(self, destination=None, prefix=''):
370+
"""Returns a dictionary containing a whole state of the module.
371+
372+
Both parameters and persistent buffers (e.g. running averages) are
373+
included. Keys are corresponding parameter and buffer names.
374+
375+
Example:
376+
>>> module.state_dict().keys()
377+
['bias', 'weight']
378+
"""
379+
if destination is None:
380+
destination = OrderedDict()
381+
for name, param in self._parameters.items():
382+
if param is not None:
383+
destination[prefix + name] = param.data
384+
for name, module in self._modules.items():
385+
if module is not None:
386+
module.state_dict_without_buffers(destination, prefix + name + '.')
387+
return destination
368388

369389
def load_state_dict(self, state_dict):
370390
"""Copies parameters and buffers from :attr:`state_dict` into

0 commit comments

Comments
 (0)