Merge pull request #89 from nyanp/fix/scikit-learn-1.0

Support scikit-learn 1.0
feature/enhance-averaging
nyanp 2021-10-30 15:31:25 +09:00 committed by GitHub
commit 4758343cac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 6 deletions

View File

@ -46,7 +46,7 @@ def check_cv(cv: Union[int, Iterable, BaseCrossValidator] = 5,
else:
return KFold(cv, shuffle=True, random_state=random_state)
return model_selection.check_cv(cv, y, stratified)
return model_selection.check_cv(cv, y, classifier=stratified)
class Take(BaseCrossValidator):
@ -380,8 +380,7 @@ class StratifiedGroupKFold(_BaseKFold):
def __init__(self, n_splits: int = 3, shuffle: bool = False,
random_state: Optional[Union[int, np.random.RandomState]] = None):
super(StratifiedGroupKFold, self).__init__(n_splits, shuffle,
random_state)
super().__init__(n_splits, shuffle=shuffle, random_state=random_state)
def _make_test_folds(self, X, y=None, groups=None):
"""

View File

@ -16,9 +16,7 @@ def _check_parameter_tunes(params, x, y):
def test_regression_problem_parameter_tunes():
dataset = datasets.load_boston()
x = pd.DataFrame(dataset.data, columns=dataset.feature_names)
y = pd.Series(dataset.target)
x, y = datasets.load_diabetes(return_X_y=True, as_frame=True)
params = {
'objective': 'regression',
'metric': 'rmse',