add option to capture stdout

feature/capture-stdout
Taiga Noumi 2020-02-27 20:52:48 +09:00 committed by nyanp
parent 5b5505dc09
commit a8e8f48145
3 changed files with 50 additions and 5 deletions

View File

@ -2,6 +2,7 @@ import json
import numbers
import os
import shutil
import sys
import uuid
import warnings
from logging import getLogger, FileHandler, DEBUG, Logger
@ -62,6 +63,20 @@ def _try_to_get_existing_mlflow_run_id(logging_directory: str) -> Optional[str]:
return None
class LoggerWriter(object):
    """File-like object that tees everything written to it into a logger.

    An instance is meant to stand in for a text stream such as ``sys.stdout``:
    each ``write`` is forwarded unchanged to the wrapped stream and the same
    text is additionally recorded through ``logger.debug``.

    Args:
        logger:
            Logger that receives a DEBUG record for each non-empty message.
        stdout:
            The underlying text stream that keeps receiving the raw output.
    """

    def __init__(self, logger, stdout):
        self.logger = logger
        self.stdout = stdout

    def write(self, message):
        """Write ``message`` to the wrapped stream and mirror it to the logger.

        Trailing newlines are stripped before logging because the logging
        handler appends its own line terminator; messages that are empty
        after stripping (e.g. the bare ``'\\n'`` emitted by ``print``) are
        not logged at all.

        Returns:
            Whatever the wrapped stream's ``write`` returns (the number of
            characters written for standard text streams).
        """
        written = self.stdout.write(message)
        text = message.rstrip('\n')
        if text:
            self.logger.debug(text)
        return written

    def flush(self):
        """Flush the wrapped stream so ``print(..., flush=True)`` keeps working."""
        self.stdout.flush()

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails. Delegating lets
        # callers keep using ``isatty()``, ``encoding``, ``fileno()`` etc. on
        # the replaced stream. The guard avoids infinite recursion when the
        # instance has not been initialised yet (e.g. during copy/unpickle).
        if name in ('logger', 'stdout'):
            raise AttributeError(name)
        return getattr(self.stdout, name)
class Experiment(object):
"""Minimal experiment logger for Kaggle
@ -96,6 +111,10 @@ class Experiment(object):
- replace: Delete logging directory before logging.
- append: Append to existing experiment.
- rename: Rename current directory by adding "_1", "_2"... prefix
capture_stdout:
If True, all messages written to stdout will be captured and recorded to the log file.
Since this has a global side effect (sys.stdout is replaced with a custom object for the duration of the experiment),
it is not recommended for use in threaded applications.
Example:
>>> import numpy as np
>>> import pandas as pd
@ -122,13 +141,16 @@ class Experiment(object):
logging_directory: str,
custom_logger: Optional[Logger] = None,
with_mlflow: bool = False,
if_exists: str = 'error'
if_exists: str = 'error',
capture_stdout: bool = False
):
logging_directory = _check_directory(logging_directory, if_exists)
os.makedirs(logging_directory, exist_ok=True)
self.logging_directory = logging_directory
self.with_mlflow = with_mlflow
self.redirect_stdout = capture_stdout
self.old_stream = None
if custom_logger is not None:
self.logger = custom_logger
@ -183,13 +205,16 @@ class Experiment(object):
with open(os.path.join(self.logging_directory, 'mlflow.json'), 'w') as f:
json.dump(mlflow_metadata, f, indent=4)
if self.redirect_stdout:
self.old_stream = sys.stdout
sys.stdout = LoggerWriter(self.logger, self.old_stream)
def _load_dict(self, filename: str) -> Dict:
    """Read a JSON file from the logging directory.

    Args:
        filename:
            Name of the file, relative to ``self.logging_directory``.

    Returns:
        The decoded JSON content, or an empty dict when the file cannot be
        opened or read (a warning is logged in that case).
    """
    target = os.path.join(self.logging_directory, filename)
    try:
        with open(target, 'r') as fp:
            loaded = json.load(fp)
    except IOError:
        # Best effort: an unreadable/missing file is treated as "nothing stored yet".
        self.logger.warning('failed to load file: {}'.format(filename))
        return {}
    return loaded
def _save_dict(self, obj: Dict, filename: str):
@ -204,6 +229,9 @@ class Experiment(object):
"""
Stop current experiment.
"""
if self.redirect_stdout:
sys.stdout = self.old_stream
self._save_dict(self.metrics, 'metrics.json')
self._save_dict(self.params, 'params.json')

View File

@ -112,8 +112,9 @@ def run_experiment(model_params: Dict[str, Any],
Test data (Optional). If specified, prediction on the test data is performed using ensemble of models.
logging_directory:
Path to directory where output of experiment is stored.
It will be ignored if ``inherit_experiment`` is used.
if_exists:
How to behave if the logging directory already exists.
How to behave if the logging directory already exists. It will be ignored if ``inherit_experiment`` is used.
- error: Raise a ValueError.
- replace: Delete logging directory before logging.
@ -164,8 +165,7 @@ def run_experiment(model_params: Dict[str, Any],
with_mlflow:
If True, `mlflow tracking <https://www.mlflow.org/docs/latest/tracking.html>`_ is used.
One instance of ``nyaggle.experiment.Experiment`` corresponds to one run in mlflow.
Note that all output
mlflow's directory (``mlruns`` by default).
It will be ignored if ``inherit_experiment`` is used.
:return:
Namedtuple with following members

View File

@ -1,6 +1,7 @@
import json
import os
import pytest
import sys
import pandas as pd
import numpy as np
@ -264,3 +265,19 @@ def test_experiment_continue():
with open(metric_file, 'r') as f:
obj = json.load(f)
assert 'Y' in obj
def test_redirect_stdout():
    """With ``capture_stdout=True``, stdout output lands in log.txt; stderr does not."""
    with get_temp_directory() as tmpdir:
        with Experiment(tmpdir, capture_stdout=True) as e:
            e.log('foo')
            print('bar')
            print('buzz', file=sys.stderr)

        log_path = os.path.join(tmpdir, 'log.txt')
        with open(log_path, 'r') as f:
            raw_lines = f.readlines()
        logged = [raw_line.strip() for raw_line in raw_lines]

        assert 'foo' in logged
        assert 'bar' in logged
        assert 'buzz' not in logged  # stderr is not captured