nyaggle/nyaggle/hyper_parameters/parameters.py

77 lines
2.5 KiB
Python

import copy
from typing import Dict, List, Union

from more_itertools import first_true

from nyaggle.hyper_parameters.catboost import parameters as params_cat
from nyaggle.hyper_parameters.lightgbm import parameters as params_lgb
from nyaggle.hyper_parameters.xgboost import parameters as params_xgb
def _get_hyperparam_byname(param_table: List[Dict], name: str, with_metadata: bool):
    """Look up a single entry of ``param_table`` by its ``'name'`` key.

    Args:
        param_table: Iterable of entry dictionaries; each entry is matched on its ``'name'`` key.
        name: Name of the entry to find.
        with_metadata: If True, return the whole matching entry; otherwise return only the
            value stored under the entry's ``'parameters'`` key.

    Raises:
        RuntimeError: If no entry in ``param_table`` has the requested name.
    """
    # First entry whose name matches, or None when there is no match.
    match = next((entry for entry in param_table if entry['name'] == name), None)
    if match is None:
        raise RuntimeError('Hyperparameter {} not found.'.format(name))
    return match if with_metadata else match['parameters']
def _return(parameter: Union[List[Dict], Dict], with_metadata: bool) -> Union[List[Dict], Dict]:
if with_metadata:
return parameter
if isinstance(parameter, list):
return [p['parameters'] for p in parameter]
else:
return parameter['parameters']
def _get_table(gbdt_type: str = 'lgbm'):
if gbdt_type == 'lgbm':
return params_lgb
elif gbdt_type == 'cat':
return params_cat
elif gbdt_type == 'xgb':
return params_xgb
raise ValueError('gbdt type should be one of (lgbm, cat, xgb)')
def list_hyperparams(gbdt_type: str = 'lgbm', with_metadata: bool = False) -> List[Dict]:
    """
    List every registered hyperparameter set for one GBDT library

    Args:
        gbdt_type:
            Which gbdt library's table to list: ``lgbm``, ``cat`` or ``xgb``.
        with_metadata:
            If True, each parameter set is returned inside its metadata dictionary, which carries
            information such as the source URL and competition name.
    Returns:
        A list of hyper-parameters used in Kaggle gold medal solutions
    Raises:
        ValueError:
            If ``gbdt_type`` is not one of the supported libraries.
    """
    table = _get_table(gbdt_type)
    return _return(table, with_metadata)
def get_hyperparam_byname(name: str, gbdt_type: str = 'lgbm', with_metadata: bool = False) -> Dict:
    """
    Get a hyperparameter by parameter name

    Args:
        name:
            The name of parameter (e.g. "ieee-2019-10th").
        gbdt_type:
            The type of gbdt library. ``lgbm``, ``cat``, ``xgb`` can be used.
        with_metadata:
            When set to True, parameters are wrapped by metadata dictionary which contains information about
            source URL, competition name etc.
    Returns:
        A hyperparameter dictionary.
    Raises:
        ValueError:
            If ``gbdt_type`` is not one of the supported libraries.
        RuntimeError:
            If no hyperparameter set named ``name`` exists for ``gbdt_type``.
    """
    param_table = _get_table(gbdt_type)
    # Reuse the shared lookup helper instead of duplicating its search/raise logic here.
    # It is asked for the full entry, so that ``_return`` alone decides how the result is
    # shaped, exactly as in ``list_hyperparams``.
    found = _get_hyperparam_byname(param_table, name, with_metadata=True)
    return _return(found, with_metadata)