Skip to content

Commit a4af9ce

Browse files
committed
Merge branch 'master' of https://github.com/pybrain/pybrain
2 parents e065926 + e188bce commit a4af9ce

File tree

2 files changed

+38
-13
lines changed

2 files changed

+38
-13
lines changed

pybrain/structure/modules/module.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
__author__ = 'Daan Wierstra and Tom Schaul'
22

3-
from scipy import zeros
3+
from scipy import append, zeros
44

55
from pybrain.utilities import abstractMethod, Named
66

@@ -86,6 +86,21 @@ def reset(self):
8686
buf = getattr(self, buffername)
8787
buf[:] = zeros(l)
8888

89+
def shift(self, items):
90+
"""Shift all buffers up or down a defined number of items on offset axis.
91+
Negative values indicate backward shift."""
92+
if items == 0:
93+
return
94+
self.offset += items
95+
for buffername, l in self.bufferlist:
96+
buf = getattr(self, buffername)
97+
assert abs(items) <= len(buf), "Cannot shift further than length of buffer."
98+
fill = zeros((abs(items), len(buf[0])))
99+
if items < 0:
100+
buf[:] = append(buf[-items:], fill, 0)
101+
else:
102+
buf[:] = append(fill ,buf[0:-items] , 0)
103+
89104
def activateOnDataset(self, dataset):
90105
"""Run the module's forward pass on the given dataset unconditionally
91106
and return the output."""

pybrain/structure/networks/recurrent.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,10 @@ class RecurrentNetworkComponent(object):
1515

1616
sequential = True
1717

18-
def __init__(self, forget, name=None, *args, **kwargs):
18+
def __init__(self, forget=None, name=None, *args, **kwargs):
1919
self.recurrentConns = []
2020
self.maxoffset = 0
21-
if forget:
22-
self.increment = 0
23-
else:
24-
self.increment = 1
21+
self.forget = forget
2522

2623
def __str__(self):
2724
s = super(RecurrentNetworkComponent, self).__str__()
@@ -51,26 +48,29 @@ def activate(self, inpt):
5148
"""Do one transformation of an input and return the result."""
5249
self.inputbuffer[self.offset] = inpt
5350
self.forward()
54-
return self.outputbuffer[self.offset - self.increment].copy()
51+
if self.forget:
52+
return self.outputbuffer[self.offset].copy()
53+
else:
54+
return self.outputbuffer[self.offset - 1].copy()
5555

5656
def backActivate(self, outerr):
5757
"""Do one transformation of an output error outerr backward and return
5858
the error on the input."""
59-
self.outputerror[self.offset - self.increment] = outerr
59+
self.outputerror[self.offset - 1] = outerr
6060
self.backward()
6161
return self.inputerror[self.offset].copy()
6262

6363
def forward(self):
6464
"""Produce the output from the input."""
65-
if not (self.offset + self.increment < self.inputbuffer.shape[0]):
65+
if not (self.offset + 1 < self.inputbuffer.shape[0]):
6666
self._growBuffers()
6767
super(RecurrentNetworkComponent, self).forward()
68-
self.offset += self.increment
68+
self.offset += 1
6969
self.maxoffset = max(self.offset, self.maxoffset)
7070

7171
def backward(self):
7272
"""Produce the input error from the output error."""
73-
self.offset -= self.increment
73+
self.offset -= 1
7474
super(RecurrentNetworkComponent, self).backward()
7575

7676
def _isLastTimestep(self):
@@ -79,6 +79,9 @@ def _isLastTimestep(self):
7979
def _forwardImplementation(self, inbuf, outbuf):
8080
assert self.sorted, ".sortModules() has not been called"
8181

82+
if self.forget:
83+
self.offset += 1
84+
8285
index = 0
8386
offset = self.offset
8487
for m in self.inmodules:
@@ -87,19 +90,26 @@ def _forwardImplementation(self, inbuf, outbuf):
8790

8891
if offset > 0:
8992
for c in self.recurrentConns:
90-
c.forward(offset - self.increment, offset)
93+
c.forward(offset - 1, offset)
9194

9295
for m in self.modulesSorted:
9396
m.forward()
9497
for c in self.connections[m]:
9598
c.forward(offset, offset)
9699

100+
if self.forget:
101+
for m in self.modules:
102+
m.shift(-1)
103+
offset -= 1
104+
self.offset -= 2
105+
97106
index = 0
98107
for m in self.outmodules:
99108
outbuf[index:index + m.outdim] = m.outputbuffer[offset]
100109
index += m.outdim
101110

102111
def _backwardImplementation(self, outerr, inerr, outbuf, inbuf):
112+
assert not self.forget, "Cannot back propagate a forgetful network"
103113
assert self.sorted, ".sortModules() has not been called"
104114
index = 0
105115
offset = self.offset
@@ -109,7 +119,7 @@ def _backwardImplementation(self, outerr, inerr, outbuf, inbuf):
109119

110120
if not self._isLastTimestep():
111121
for c in self.recurrentConns:
112-
c.backward(offset, offset + self.increment)
122+
c.backward(offset, offset + 1)
113123

114124
for m in reversed(self.modulesSorted):
115125
for c in self.connections[m]:

0 commit comments

Comments
 (0)