Skip to content

Commit 2fcd4a5

Browse files
committed
minor bugfix for forgetting recurrent networks
Signed-off-by: Tom Schaul <[email protected]>
1 parent a4af9ce commit 2fcd4a5

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@
88
build
99
dist
1010
docs/sphinx/.build
11-
.DS_Store
11+
.DS_Store
12+
/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.cpp
13+
/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.pyd

pybrain/structure/networks/recurrent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ class RecurrentNetwork(RecurrentNetworkComponent, Network):
146146

147147
bufferlist = Network.bufferlist
148148

149-
def __init__(self, forget=False, *args, **kwargs):
149+
def __init__(self, *args, **kwargs):
150150
Network.__init__(self, *args, **kwargs)
151+
if 'forget' in kwargs:
152+
forget = kwargs['forget']
153+
else:
154+
forget = False
151155
RecurrentNetworkComponent.__init__(self, forget, *args, **kwargs)

pybrain/tests/unittests/test_peephole_mdlstm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
True
4545
4646
List all the states again, explicitly (buffer size is 8 by now).
47-
>>> fListToString(N['mdlstm'].outputbuffer[:,1], 3)
48-
'[0.4 , 0.4 , 0.814 , 0.407 , -0.152, -0.152, 0 , 0 ]'
47+
>>> fListToString(N['mdlstm'].outputbuffer[:,1], 2)
48+
'[0.4 , 0.4 , 0.81 , 0.41 , -0.15, -0.15, 0 , 0 ]'
4949
5050
"""
5151

0 commit comments

Comments
 (0)