88
99from utils .constants import *
1010from utils .strings import *
11- from utils .util import print_and_log_message , print_and_log_message_list
1211
1312class ReplayMemory :
1413 '''Memory buffer for experiance replay'''
@@ -43,7 +42,8 @@ def __init__(self, logger, config):
4342
4443 def add (self , screen , reward , action , terminal , trade_rem ):
4544 if screen .shape != self .dims :
46- print_and_log_message (INVALID_TIMESTEP , self .logger )
45+ self .logger .error (INVALID_TIMESTEP )
46+
4747 else :
4848 self .actions [self .current ] = action
4949 self .rewards [self .current ] = reward
@@ -55,7 +55,8 @@ def add(self, screen, reward, action, terminal, trade_rem):
5555
5656 def getState (self , index ):
5757 if self .count == 0 :
58- print_and_log_message (REPLAY_MEMORY_ZERO , self .logger )
58+ self .logger .error (REPLAY_MEMORY_ZERO )
59+
5960 else :
6061 index = index % self .count
6162 if index >= self .history_length - 1 :
@@ -68,26 +69,26 @@ def getState(self, index):
6869
6970 def save (self ):
7071 message = "Saving replay memory to {}" .format (self ._model_dir )
71- print_and_log_message ( message , self .logger )
72+ self .logger . info ( message )
7273 for idx , (name , array ) in enumerate (
7374 zip ([ACTIONS , REWARDS , SCREENS , TERMINALS , TRADES_REM , PRESTATES , POSTSTATES ],
7475 [self .actions , self .rewards , self .screens , self .terminals , self .trades_rem , self .prestates , self .poststates ])):
7576 save_npy (array , join (self ._model_dir , name ))
7677
7778 message = "Replay memory successfully saved to {}" .format (self ._model_dir )
78- print_and_log_message ( message , self .logger )
79+ self .logger . info ( message )
7980
8081 def load (self ):
8182 message = "Loading replay memory from {}" .format (self ._model_dir )
82- print_and_log_message ( message , self .logger )
83+ self .logger . info ( message )
8384
8485 for idx , (name , array ) in enumerate (
8586 zip ([ACTIONS , REWARDS , SCREENS , TERMINALS , TRADES_REM , PRESTATES , POSTSTATES ],
8687 [self .actions , self .rewards , self .screens , self .terminals , self .trades_rem , self .prestates , self .poststates ])):
8788 array = load_npy (join (self ._model_dir , name ))
8889
8990 message = "Replay memory successfully loaded from {}" .format (self ._model_dir )
90- print_and_log_message ( message , self .logger )
91+ self .logger . info ( message )
9192
9293 @property
9394 def model_dir (self ):
@@ -96,7 +97,7 @@ def model_dir(self):
9697 @property
9798 def sample (self ):
9899 if self .count <= self .history_length :
99- print_and_log_message ( REPLAY_MEMORY_INSUFFICIENT , self .logger )
100+ self .logger . error ( REPLAY_MEMORY_INSUFFICIENT )
100101
101102 else :
102103 indexes = []
0 commit comments