Compare commits
1 Commits
master
...
feature/ca
Author | SHA1 | Date |
---|---|---|
Taiga Noumi | a8e8f48145 |
|
@ -2,6 +2,7 @@ import json
|
|||
import numbers
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import uuid
|
||||
import warnings
|
||||
from logging import getLogger, FileHandler, DEBUG, Logger
|
||||
|
@ -62,6 +63,20 @@ def _try_to_get_existing_mlflow_run_id(logging_directory: str) -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
class LoggerWriter(object):
    """File-like object that tees text to a stream and to a logger.

    Every message passed to :meth:`write` is forwarded unchanged to the
    wrapped stream and, unless it is a bare newline, also emitted through
    ``logger.debug``.

    Args:
        logger:
            Object exposing a ``debug(message)`` method.
        stdout:
            The wrapped stream that keeps receiving the original output.
    """

    def __init__(self, logger, stdout):
        self.logger = logger
        self.stdout = stdout

    def write(self, message):
        """Write ``message`` to the wrapped stream and record it in the logger.

        Returns:
            Whatever the wrapped stream's ``write`` returns (for text streams,
            the number of characters written).
        """
        written = self.stdout.write(message)
        # A lone '\n' is skipped so that a line terminator written on its own
        # (as print() typically does) is not logged as an empty record.
        if message != '\n':
            self.logger.debug(message)
        return written

    def flush(self):
        """Flush the wrapped stream so explicit flush requests are honoured."""
        flush = getattr(self.stdout, 'flush', None)
        if flush is not None:
            flush()

    def __getattr__(self, name):
        """Delegate other stream attributes (``isatty``, ``encoding``, ...)."""
        # __getattr__ only runs when normal lookup fails. Refusing our own
        # attribute names avoids infinite recursion on an instance whose
        # __init__ has not run (e.g. during copy/unpickling).
        if name in ('logger', 'stdout'):
            raise AttributeError(name)
        return getattr(self.stdout, name)
|
||||
|
||||
|
||||
class Experiment(object):
|
||||
"""Minimal experiment logger for Kaggle
|
||||
|
||||
|
@ -96,6 +111,10 @@ class Experiment(object):
|
|||
- replace: Delete logging directory before logging.
|
||||
- append: Append to existing experiment.
|
||||
- rename: Rename current directory by adding "_1", "_2"... prefix
|
||||
capture_stdout:
|
||||
If True, all messages written to stdout will be captured and recorded to the log file.
|
||||
Since this has a global side effect (sys.stdout is replaced with a custom object for the duration of the experiment),
|
||||
it is not recommended to use in threaded applications.
|
||||
Example:
|
||||
>>> import numpy as np
|
||||
>>> import pandas as pd
|
||||
|
@ -122,13 +141,16 @@ class Experiment(object):
|
|||
logging_directory: str,
|
||||
custom_logger: Optional[Logger] = None,
|
||||
with_mlflow: bool = False,
|
||||
if_exists: str = 'error'
|
||||
if_exists: str = 'error',
|
||||
capture_stdout: bool = False
|
||||
):
|
||||
logging_directory = _check_directory(logging_directory, if_exists)
|
||||
os.makedirs(logging_directory, exist_ok=True)
|
||||
|
||||
self.logging_directory = logging_directory
|
||||
self.with_mlflow = with_mlflow
|
||||
self.redirect_stdout = capture_stdout
|
||||
self.old_stream = None
|
||||
|
||||
if custom_logger is not None:
|
||||
self.logger = custom_logger
|
||||
|
@ -183,13 +205,16 @@ class Experiment(object):
|
|||
with open(os.path.join(self.logging_directory, 'mlflow.json'), 'w') as f:
|
||||
json.dump(mlflow_metadata, f, indent=4)
|
||||
|
||||
if self.redirect_stdout:
|
||||
self.old_stream = sys.stdout
|
||||
sys.stdout = LoggerWriter(self.logger, self.old_stream)
|
||||
|
||||
def _load_dict(self, filename: str) -> Dict:
|
||||
try:
|
||||
path = os.path.join(self.logging_directory, filename)
|
||||
with open(path, 'r') as f:
|
||||
return json.load(f)
|
||||
except IOError:
|
||||
self.logger.warning('failed to load file: {}'.format(filename))
|
||||
return {}
|
||||
|
||||
def _save_dict(self, obj: Dict, filename: str):
|
||||
|
@ -204,6 +229,9 @@ class Experiment(object):
|
|||
"""
|
||||
Stop current experiment.
|
||||
"""
|
||||
if self.redirect_stdout:
|
||||
sys.stdout = self.old_stream
|
||||
|
||||
self._save_dict(self.metrics, 'metrics.json')
|
||||
self._save_dict(self.params, 'params.json')
|
||||
|
||||
|
|
|
@ -112,8 +112,9 @@ def run_experiment(model_params: Dict[str, Any],
|
|||
Test data (Optional). If specified, prediction on the test data is performed using ensemble of models.
|
||||
logging_directory:
|
||||
Path to directory where output of experiment is stored.
|
||||
It will be ignored if ``inherit_experiment`` is used.
|
||||
if_exists:
|
||||
How to behave if the logging directory already exists.
|
||||
How to behave if the logging directory already exists. It will be ignored if ``inherit_experiment`` is used.
|
||||
|
||||
- error: Raise a ValueError.
|
||||
- replace: Delete logging directory before logging.
|
||||
|
@ -164,8 +165,7 @@ def run_experiment(model_params: Dict[str, Any],
|
|||
with_mlflow:
|
||||
If True, `mlflow tracking <https://www.mlflow.org/docs/latest/tracking.html>`_ is used.
|
||||
One instance of ``nyaggle.experiment.Experiment`` corresponds to one run in mlflow.
|
||||
Note that all output
|
||||
mlflow's directory (``mlruns`` by default).
|
||||
It will be ignored if ``inherit_experiment`` is used.
|
||||
:return:
|
||||
Namedtuple with following members
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
@ -264,3 +265,19 @@ def test_experiment_continue():
|
|||
with open(metric_file, 'r') as f:
|
||||
obj = json.load(f)
|
||||
assert 'Y' in obj
|
||||
|
||||
|
||||
def test_redirect_stdout():
    """With ``capture_stdout=True``, stdout reaches log.txt but stderr does not."""
    with get_temp_directory() as tmpdir:
        with Experiment(tmpdir, capture_stdout=True) as e:
            e.log('foo')
            print('bar')
            print('buzz', file=sys.stderr)

        # Read the log only after the experiment context has exited.
        log_path = os.path.join(tmpdir, 'log.txt')
        with open(log_path, 'r') as f:
            logged = [row.strip() for row in f.readlines()]

        assert 'foo' in logged
        assert 'bar' in logged
        assert 'buzz' not in logged  # stderr is not captured
|
||||
|
|
Loading…
Reference in New Issue