@@ -15,13 +15,10 @@ class RecurrentNetworkComponent(object):
1515
1616 sequential = True
1717
18- def __init__ (self , forget , name = None , * args , ** kwargs ):
18+ def __init__ (self , forget = None , name = None , * args , ** kwargs ):
1919 self .recurrentConns = []
2020 self .maxoffset = 0
21- if forget :
22- self .increment = 0
23- else :
24- self .increment = 1
21+ self .forget = forget
2522
2623 def __str__ (self ):
2724 s = super (RecurrentNetworkComponent , self ).__str__ ()
@@ -51,26 +48,29 @@ def activate(self, inpt):
5148 """Do one transformation of an input and return the result."""
5249 self .inputbuffer [self .offset ] = inpt
5350 self .forward ()
54- return self .outputbuffer [self .offset - self .increment ].copy ()
51+ if self .forget :
52+ return self .outputbuffer [self .offset ].copy ()
53+ else :
54+ return self .outputbuffer [self .offset - 1 ].copy ()
5555
5656 def backActivate (self , outerr ):
5757 """Do one transformation of an output error outerr backward and return
5858 the error on the input."""
59- self .outputerror [self .offset - self . increment ] = outerr
59+ self .outputerror [self .offset - 1 ] = outerr
6060 self .backward ()
6161 return self .inputerror [self .offset ].copy ()
6262
6363 def forward (self ):
6464 """Produce the output from the input."""
65- if not (self .offset + self . increment < self .inputbuffer .shape [0 ]):
65+ if not (self .offset + 1 < self .inputbuffer .shape [0 ]):
6666 self ._growBuffers ()
6767 super (RecurrentNetworkComponent , self ).forward ()
68- self .offset += self . increment
68+ self .offset += 1
6969 self .maxoffset = max (self .offset , self .maxoffset )
7070
7171 def backward (self ):
7272 """Produce the input error from the output error."""
73- self .offset -= self . increment
73+ self .offset -= 1
7474 super (RecurrentNetworkComponent , self ).backward ()
7575
7676 def _isLastTimestep (self ):
@@ -79,6 +79,9 @@ def _isLastTimestep(self):
7979 def _forwardImplementation (self , inbuf , outbuf ):
8080 assert self .sorted , ".sortModules() has not been called"
8181
82+ if self .forget :
83+ self .offset += 1
84+
8285 index = 0
8386 offset = self .offset
8487 for m in self .inmodules :
@@ -87,19 +90,26 @@ def _forwardImplementation(self, inbuf, outbuf):
8790
8891 if offset > 0 :
8992 for c in self .recurrentConns :
90- c .forward (offset - self . increment , offset )
93+ c .forward (offset - 1 , offset )
9194
9295 for m in self .modulesSorted :
9396 m .forward ()
9497 for c in self .connections [m ]:
9598 c .forward (offset , offset )
9699
100+ if self .forget :
101+ for m in self .modules :
102+ m .shift (- 1 )
103+ offset -= 1
104+ self .offset -= 2
105+
97106 index = 0
98107 for m in self .outmodules :
99108 outbuf [index :index + m .outdim ] = m .outputbuffer [offset ]
100109 index += m .outdim
101110
102111 def _backwardImplementation (self , outerr , inerr , outbuf , inbuf ):
112+ assert not self .forget , "Cannot back propagate a forgetful network"
103113 assert self .sorted , ".sortModules() has not been called"
104114 index = 0
105115 offset = self .offset
@@ -109,7 +119,7 @@ def _backwardImplementation(self, outerr, inerr, outbuf, inbuf):
109119
110120 if not self ._isLastTimestep ():
111121 for c in self .recurrentConns :
112- c .backward (offset , offset + self . increment )
122+ c .backward (offset , offset + 1 )
113123
114124 for m in reversed (self .modulesSorted ):
115125 for c in self .connections [m ]:
0 commit comments