Commit ff96c51

committed
Complete refactoring of logging into debug, info and error
1 parent 649f5f1 commit ff96c51

9 files changed: +32 additions, -38 deletions
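
At a glance, the commit deletes the print_and_log_message and print_and_log_message_list helpers from utils/util.py and has every call site log directly on the shared logger at an explicit level. A minimal sketch of the before/after pattern (the logger name, messages, and the INVALID_TIMESTEP stand-in are illustrative, not taken from the repository):

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("deep_trading_agent_example")  # illustrative name
INVALID_TIMESTEP = "Given screen does not match the expected dimensions"  # illustrative stand-in

# Before this commit, every message went through a helper:
#     print_and_log_message("Saving checkpoint to /tmp/ckpt", logger)
#     print_and_log_message_list(["first message", "second message"], logger)

# After this commit, the caller picks the level directly:
logger.debug("verbose per-episode detail")        # diagnostic traces
logger.info("Saving checkpoint to /tmp/ckpt")     # normal progress
logger.error(INVALID_TIMESTEP)                    # guard and error conditions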

code/model/agent.py

Lines changed: 1 addition & 2 deletions
@@ -18,7 +18,6 @@

 from utils.constants import *
 from utils.strings import *
-from utils.util import print_and_log_message, print_and_log_message_list

 class Agent(BaseAgent):
     '''Deep Trading Agent based on Deep Q Learning'''
@@ -98,7 +97,7 @@ def train(self):

     message = 'avg_r: %.4f, avg_l: %.6f, avg_q: %3.6f, avg_ep_r: %.4f, max_ep_r: %.4f, min_ep_r: %.4f, # game: %d' \
         % (avg_reward, avg_loss, avg_q, avg_ep_reward, max_ep_reward, min_ep_reward, num_episodes)
-    print_and_log_message(message, self.logger)
+    self.logger.info(message)

     if max_avg_ep_reward * 0.9 <= avg_ep_reward:
         self.sess.run(

code/model/baseagent.py

Lines changed: 4 additions & 5 deletions
@@ -6,7 +6,6 @@

 from utils.constants import *
 from utils.strings import *
-from utils.util import print_and_log_message, print_and_log_message_list

 class BaseAgent(object):
     '''Base class containing all the parameters for reinforcement learning'''
@@ -57,22 +56,22 @@ def saver(self):

     def save_model(self, step=None):
         message = "Saving checkpoint to {}".format(self.checkpoint_dir)
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)
         self.saver.save(self.sess, self.checkpoint_dir, global_step=step)

     def load_model(self):
         message = "Loading checkpoint from {}".format(self.checkpoint_dir)
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)

         ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
         if ckpt and ckpt.model_checkpoint_path:
             ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
             fname = os.path.join(self.checkpoint_dir, ckpt_name)
             self.saver.restore(self.sess, fname)
             message = "Checkpoint successfully loaded from {}".format(fname)
-            print_and_log_message(message, self.logger)
+            self.logger.info(message)
             return True
         else:
             message = "Checkpoint could not be loaded from {}".format(self.checkpoint_dir)
-            print_and_log_message(message, self.logger)
+            self.logger.info(message)
             return False
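
For orientation, a hedged usage sketch of the two methods above (the caller wiring is illustrative, not part of the diff): load_model() reports the outcome through logger.info and returns a boolean, so a caller only needs to branch on the return value.

def restore_or_start(agent):
    # Assumes an agent built from BaseAgent, with the shared logger from utils/util.py.
    if agent.load_model():      # logs "Checkpoint successfully loaded from ..."
        agent.logger.info("Resuming training from the restored checkpoint")
    else:                       # logs "Checkpoint could not be loaded from ..."
        agent.logger.info("No usable checkpoint found, training from scratch")
    agent.save_model(step=0)    # logs "Saving checkpoint to ..."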

code/model/deepsense.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 from os.path import join
 import tensorflow as tf

-from utils.util import print_and_log_message, print_and_log_message_list
 from utils.constants import *
 from utils.strings import *

code/model/environment.py

Lines changed: 4 additions & 2 deletions
@@ -3,7 +3,6 @@

 from utils.constants import *
 from utils.strings import *
-from utils.util import print_and_log_message_list

 class Environment:
     '''Exchange Simulator for Bitcoin based upon per minute historical prices'''
@@ -56,7 +55,7 @@ def new_random_episode(self, history, replay_memory):
         history.add(state)
         replay_memory.add(state, 0.0, 0, False, 0.0)

-    print_and_log_message_list(message_list, self.logger)
+    map(self.logger.debug, message_list)

     return 1.0

@@ -73,6 +72,9 @@ def act(self, action):
         self.short = self.short + 1

     reward = (self.long - self.short) * self.unit * self.diffs[self.current]
+    message = "Reward for timestep {} of episode number {} is {}".format(
+        self.timesteps, self.episode_number, reward
+    )
     self.timesteps = self.timesteps + 1
     if self.timesteps is not self.horizon:
         self.current = self.current + 1
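
One behavioural note on the map(self.logger.debug, message_list) replacement here (code/process/processor.py below uses the same pattern with logger.info): under Python 2, which this codebase targets (utils/util.py imports ConfigParser), map applies the function eagerly and every message is logged; under Python 3, map returns a lazy iterator and nothing is emitted unless the result is consumed. A small, version-independent sketch of the same idea:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("deep_trading_agent_example")  # illustrative name

message_list = ["illustrative message one", "illustrative message two"]

# An explicit loop logs each message under both Python 2 and Python 3,
# whereas map(logger.debug, message_list) only does so eagerly under Python 2.
for message in message_list:
    logger.debug(message)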

code/model/history.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,6 @@

 from utils.constants import *
 from utils.strings import *
-from utils.util import print_and_log_message

 class History:
     '''Experiance buffer of the behaniour policy of the agent'''
@@ -21,7 +20,8 @@ def __init__(self, logger, config):

     def add(self, screen):
         if screen.shape != self.dims:
-            print_and_log_message(INVALID_TIMESTEP, self.logger)
+            self.logger.error(INVALID_TIMESTEP)
+
         self._history[:-1] = self._history[1:]
         self._history[-1] = screen

code/model/replay_memory.py

Lines changed: 9 additions & 8 deletions
@@ -8,7 +8,6 @@

 from utils.constants import *
 from utils.strings import *
-from utils.util import print_and_log_message, print_and_log_message_list

 class ReplayMemory:
     '''Memory buffer for experiance replay'''
@@ -43,7 +42,8 @@ def __init__(self, logger, config):

     def add(self, screen, reward, action, terminal, trade_rem):
         if screen.shape != self.dims:
-            print_and_log_message(INVALID_TIMESTEP, self.logger)
+            self.logger.error(INVALID_TIMESTEP)
+
         else:
             self.actions[self.current] = action
             self.rewards[self.current] = reward
@@ -55,7 +55,8 @@ def add(self, screen, reward, action, terminal, trade_rem):

     def getState(self, index):
         if self.count == 0:
-            print_and_log_message(REPLAY_MEMORY_ZERO, self.logger)
+            self.logger.error(REPLAY_MEMORY_ZERO)
+
         else:
             index = index % self.count
             if index >= self.history_length - 1:
@@ -68,26 +69,26 @@ def getState(self, index):

     def save(self):
         message = "Saving replay memory to {}".format(self._model_dir)
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)
         for idx, (name, array) in enumerate(
             zip([ACTIONS, REWARDS, SCREENS, TERMINALS, TRADES_REM, PRESTATES, POSTSTATES],
                 [self.actions, self.rewards, self.screens, self.terminals, self.trades_rem, self.prestates, self.poststates])):
             save_npy(array, join(self._model_dir, name))

         message = "Replay memory successfully saved to {}".format(self._model_dir)
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)

     def load(self):
         message = "Loading replay memory from {}".format(self._model_dir)
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)

         for idx, (name, array) in enumerate(
             zip([ACTIONS, REWARDS, SCREENS, TERMINALS, TRADES_REM, PRESTATES, POSTSTATES],
                 [self.actions, self.rewards, self.screens, self.terminals, self.trades_rem, self.prestates, self.poststates])):
             array = load_npy(join(self._model_dir, name))

         message = "Replay memory successfully loaded from {}".format(self._model_dir)
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)

     @property
     def model_dir(self):
@@ -96,7 +97,7 @@ def model_dir(self):
     @property
     def sample(self):
         if self.count <= self.history_length:
-            print_and_log_message(REPLAY_MEMORY_INSUFFICIENT, self.logger)
+            self.logger.error(REPLAY_MEMORY_INSUFFICIENT)

         else:
             indexes = []

code/model/util.py

Lines changed: 2 additions & 3 deletions
@@ -7,7 +7,6 @@

 import tensorflow as tf

-from utils.util import print_and_log_message_list, print_and_log_message
 from utils.strings import *

 def clipped_error(x):
@@ -20,10 +19,10 @@ def clipped_error(x):
 def save_npy(obj, path, logger):
     np.save(path, obj)
     message = " [*] saved at {}".format(path)
-    print_and_log_message(message, logger)
+    logger.info(message)

 def load_npy(path, logger):
     obj = np.load(path)
     message = " [*] loaded from {}".format(path)
-    print_and_log_message(message, logger)
+    logger.info(message)
     return obj
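
A short, self-contained usage sketch of the two helpers above with the new logging calls (the array and the path are illustrative):

import logging

import numpy as np

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("deep_trading_agent_example")  # illustrative name

def save_npy(obj, path, logger):
    # Mirrors the helper above: persist the array, then report where it went.
    np.save(path, obj)
    logger.info(" [*] saved at {}".format(path))

def load_npy(path, logger):
    obj = np.load(path)
    logger.info(" [*] loaded from {}".format(path))
    return obj

rewards = np.zeros(10)
save_npy(rewards, "/tmp/rewards.npy", logger)      # illustrative path
restored = load_npy("/tmp/rewards.npy", logger)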

code/process/processor.py

Lines changed: 3 additions & 4 deletions
@@ -4,7 +4,6 @@

 from utils.constants import *
 from utils.strings import *
-from utils.util import print_and_log_message, print_and_log_message_list

 class Processor:
     '''Preprocessor for Bitcoin prices dataset as obtained by following the procedure
@@ -34,15 +33,15 @@ def timestamp_blocks(self):
     def preprocess(self):
         data = pd.read_csv(self.dataset_path)
         message = 'Columns found in the dataset {}'.format(data.columns)
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)
         data = data.dropna()
         start_time_stamp = data['Timestamp'][0]
         timestamps = data['Timestamp'].apply(lambda x: (x - start_time_stamp) / 60)
         timestamps = timestamps - range(timestamps.shape[0])
         data.insert(0, 'blocks', timestamps)
         blocks = data.groupby('blocks')
         message = 'Number of blocks of continuous prices found are {}'.format(len(blocks))
-        print_and_log_message(message, self.logger)
+        self.logger.info(message)

         self._data_blocks = []
         distinct_episodes = 0
@@ -60,7 +59,7 @@ def preprocess(self):
             data = None
         message_list = ['Number of usable blocks obtained from the dataset are {}'.format(len(self._data_blocks))]
         message_list.append('Number of distinct episodes for the current configuration are {}'.format(distinct_episodes))
-        print_and_log_message_list(message_list, self.logger)
+        map(self.logger.info, message_list)

     def generate_attributes(self):
         self._diff_blocks = []

code/utils/util.py

Lines changed: 7 additions & 11 deletions
@@ -1,4 +1,5 @@
 from ConfigParser import ConfigParser
+import sys
 import logging

 from constants import *
@@ -10,25 +11,20 @@ def get_config_parser(filename):
     return config

 def get_logger(config):
-    logging.basicConfig(level=logging.DEBUG)
     formatter = \
         logging.Formatter('%(asctime)s - %(pathname)s - Line No %(lineno)s - Level %(levelname)s - %(message)s')
     info_handler = logging.FileHandler(config[LOG_FILE])
     info_handler.setLevel(logging.INFO)
     info_handler.setFormatter(formatter)

+    out_handler = logging.StreamHandler(sys.stdout)
+    out_handler.setLevel(logging.DEBUG)
+    out_handler.setFormatter(formatter)
+
     logger = logging.getLogger(name=DEEP_TRADING_AGENT)
-    logger.setLevel(logging.DEBUG)
+    logger.setLevel(logging.INFO)
     logger.addHandler(info_handler)
-    logger.propagate = False
+
     return logger

-def print_and_log_message(message, logger):
-    logging.info(message)
-    logger.info(message)
-
-def print_and_log_message_list(message_list, logger):
-    for message in message_list:
-        logging.info(message)
-        logger.info(message)
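
Pulling the new get_logger configuration together, a self-contained sketch of how the configured levels interact (the log file path and logger name are illustrative stand-ins for config[LOG_FILE] and DEEP_TRADING_AGENT):

import logging
import sys

formatter = logging.Formatter(
    '%(asctime)s - %(pathname)s - Line No %(lineno)s - Level %(levelname)s - %(message)s')

info_handler = logging.FileHandler("/tmp/deep_trading_agent.log")  # stand-in for config[LOG_FILE]
info_handler.setLevel(logging.INFO)
info_handler.setFormatter(formatter)

out_handler = logging.StreamHandler(sys.stdout)
out_handler.setLevel(logging.DEBUG)
out_handler.setFormatter(formatter)

logger = logging.getLogger("deep_trading_agent")  # stand-in for DEEP_TRADING_AGENT
logger.setLevel(logging.INFO)
logger.addHandler(info_handler)
logger.addHandler(out_handler)  # attached here for the sketch; the hunk above only attaches info_handler

logger.debug("filtered out: the logger level is INFO, so no handler sees this record")
logger.info("written to the log file and echoed to stdout")
logger.error("also written to both destinations")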