11# Copyright (c) Microsoft Corporation.
22# Licensed under the MIT License.
3-
43from random import randint , choice
54from pathlib import Path
5+ import logging
66
77import re
88from typing import Any , Tuple
@@ -69,6 +69,10 @@ def learn(self, batch):
6969
7070def test_simple_env_logger (caplog ):
7171 set_log_with_config (C .logging_config )
72+ # In order for caplog to capture log messages, we configure it here:
73+ # allow logs from the qlib logger to be passed to the parent logger.
74+ C .logging_config ["loggers" ]["qlib" ]["propagate" ] = True
75+ logging .config .dictConfig (C .logging_config )
7276 for venv_cls_name in ["dummy" , "shmem" , "subproc" ]:
7377 writer = ConsoleWriter ()
7478 csv_writer = CsvWriter (Path (__file__ ).parent / ".output" )
@@ -80,13 +84,12 @@ def test_simple_env_logger(caplog):
8084 output_file = pd .read_csv (Path (__file__ ).parent / ".output" / "result.csv" )
8185 assert output_file .columns .tolist () == ["reward" , "a" , "c" ]
8286 assert len (output_file ) >= 30
83-
8487 line_counter = 0
8588 for line in caplog .text .splitlines ():
8689 line = line .strip ()
8790 if line :
8891 line_counter += 1
89- assert re .match (r".*reward .* a .* \((4|5|6 )\.\d+\) c .* \((14|15|16)\.\d+\)" , line )
92+ assert re .match (r".*reward .* {2} a .* \(([456] )\.\d+\) {2} c .* \((14|15|16)\.\d+\)" , line )
9093 assert line_counter >= 3
9194
9295
@@ -137,15 +140,17 @@ def learn(self, batch):
137140
138141def test_logger_with_env_wrapper ():
139142 with DataQueue (list (range (20 )), shuffle = False ) as data_iterator :
140- env_wrapper_factory = lambda : EnvWrapper (
141- SimpleSimulator ,
142- DummyStateInterpreter (),
143- DummyActionInterpreter (),
144- data_iterator ,
145- logger = LogCollector (LogLevel .DEBUG ),
146- )
147-
148- # loglevel can be debug here because metrics can all dump into csv
143+
144+ def env_wrapper_factory ():
145+ return EnvWrapper (
146+ SimpleSimulator ,
147+ DummyStateInterpreter (),
148+ DummyActionInterpreter (),
149+ data_iterator ,
150+ logger = LogCollector (LogLevel .DEBUG ),
151+ )
152+
153+ # loglevel can be debugged here because metrics can all dump into csv
149154 # otherwise, csv writer might crash
150155 csv_writer = CsvWriter (Path (__file__ ).parent / ".output" , loglevel = LogLevel .DEBUG )
151156 venv = vectorize_env (env_wrapper_factory , "shmem" , 4 , csv_writer )
@@ -155,7 +160,7 @@ def test_logger_with_env_wrapper():
155160
156161 output_df = pd .read_csv (Path (__file__ ).parent / ".output" / "result.csv" )
157162 assert len (output_df ) == 20
158- # obs has a increasing trend
163+ # obs has an increasing trend
159164 assert output_df ["obs" ].to_numpy ()[:10 ].sum () < output_df ["obs" ].to_numpy ()[10 :].sum ()
160165 assert (output_df ["test_a" ] == 233 ).all ()
161166 assert (output_df ["test_b" ] == 200 ).all ()
0 commit comments