Skip to content

Commit 431f574

Browse files
fix duplicate log (microsoft#1661)
* fix duplicate log * fix unit test * fix log * fix_duplicate_log * fix_duplicate_log * add comments --------- Co-authored-by: Linlang <[email protected]>
1 parent b604fe5 commit 431f574

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

qlib/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,11 @@ def register_from_C(config, skip_register=True):
173173
"filters": ["field_not_found"],
174174
}
175175
},
176-
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
176+
# Normally this should be set to `False` to avoid duplicated logging [1].
177+
# However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2].
178+
# [1] https://github.com/microsoft/qlib/pull/1661
179+
# [2] https://github.com/pytest-dev/pytest/issues/3697
180+
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}},
177181
# To let qlib work with other packages, we shouldn't disable existing loggers.
178182
# Note that this param is default to True according to the documentation of logging.
179183
"disable_existing_loggers": False,

tests/rl/test_logger.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
43
from random import randint, choice
54
from pathlib import Path
5+
import logging
66

77
import re
88
from typing import Any, Tuple
@@ -69,6 +69,10 @@ def learn(self, batch):
6969

7070
def test_simple_env_logger(caplog):
7171
set_log_with_config(C.logging_config)
72+
# In order for caplog to capture log messages, we configure it here:
73+
# allow logs from the qlib logger to be passed to the parent logger.
74+
C.logging_config["loggers"]["qlib"]["propagate"] = True
75+
logging.config.dictConfig(C.logging_config)
7276
for venv_cls_name in ["dummy", "shmem", "subproc"]:
7377
writer = ConsoleWriter()
7478
csv_writer = CsvWriter(Path(__file__).parent / ".output")
@@ -80,13 +84,12 @@ def test_simple_env_logger(caplog):
8084
output_file = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
8185
assert output_file.columns.tolist() == ["reward", "a", "c"]
8286
assert len(output_file) >= 30
83-
8487
line_counter = 0
8588
for line in caplog.text.splitlines():
8689
line = line.strip()
8790
if line:
8891
line_counter += 1
89-
assert re.match(r".*reward .* a .* \((4|5|6)\.\d+\) c .* \((14|15|16)\.\d+\)", line)
92+
assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line)
9093
assert line_counter >= 3
9194

9295

@@ -137,15 +140,17 @@ def learn(self, batch):
137140

138141
def test_logger_with_env_wrapper():
139142
with DataQueue(list(range(20)), shuffle=False) as data_iterator:
140-
env_wrapper_factory = lambda: EnvWrapper(
141-
SimpleSimulator,
142-
DummyStateInterpreter(),
143-
DummyActionInterpreter(),
144-
data_iterator,
145-
logger=LogCollector(LogLevel.DEBUG),
146-
)
147-
148-
# loglevel can be debug here because metrics can all dump into csv
143+
144+
def env_wrapper_factory():
145+
return EnvWrapper(
146+
SimpleSimulator,
147+
DummyStateInterpreter(),
148+
DummyActionInterpreter(),
149+
data_iterator,
150+
logger=LogCollector(LogLevel.DEBUG),
151+
)
152+
153+
# loglevel can be debugged here because metrics can all dump into csv
149154
# otherwise, csv writer might crash
150155
csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG)
151156
venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer)
@@ -155,7 +160,7 @@ def test_logger_with_env_wrapper():
155160

156161
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv")
157162
assert len(output_df) == 20
158-
# obs has a increasing trend
163+
# obs has an increasing trend
159164
assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum()
160165
assert (output_df["test_a"] == 233).all()
161166
assert (output_df["test_b"] == 200).all()

0 commit comments

Comments
 (0)