# Source code for biopsykit.classification.model_selection.nested_cv

"""Module with functions for model selection using "nested" cross-validation."""
import warnings
from typing import Any, Dict, Optional

import numpy as np
from biopsykit.classification.utils import split_train_test
from sklearn.metrics import confusion_matrix, get_scorer
from sklearn.model_selection import BaseCrossValidator, GridSearchCV, RandomizedSearchCV
from sklearn.pipeline import Pipeline
from tqdm.auto import tqdm

__all__ = ["nested_cv_param_search"]


def _setup_scoring_dict(scoring, **kwargs):
    """Normalize the ``scoring`` argument into a dict accepted by scikit-learn search objects.

    If ``scoring`` is a single string, it is additionally used as the ``refit`` metric.
    Returns the (possibly updated) keyword arguments and a dict mapping each metric name to itself.
    """
    scoring_dict = {}
    if scoring is not None:
        if isinstance(scoring, str):
            # a single metric is also used to refit the best estimator
            kwargs["refit"] = scoring
            scoring = [scoring]

        for score in scoring:
            scoring_dict.setdefault(score, score)
    return kwargs, scoring_dict

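# A minimal usage sketch of the helper above (doctest-style; the scorer names are just
# illustrative scikit-learn scoring strings, not values required by this module):
#
#     >>> _setup_scoring_dict("accuracy")
#     ({'refit': 'accuracy'}, {'accuracy': 'accuracy'})
#     >>> _setup_scoring_dict(["accuracy", "f1_macro"], refit="f1_macro")
#     ({'refit': 'f1_macro'}, {'accuracy': 'accuracy', 'f1_macro': 'f1_macro'})
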

def _fit_cv_obj_one_fold(cv_obj, x_train, y_train, groups_train):
    """Fit the hyperparameter search object to the training data of one outer CV fold.

    Group labels are only passed to ``fit`` if they are provided, since not every inner
    cross-validator accepts a ``groups`` argument. Known scikit-learn errors are
    re-raised with more descriptive messages.
    """
    try:
        if groups_train is None:
            cv_obj.fit(x_train, y_train)
        else:
            cv_obj.fit(x_train, y_train, groups=groups_train)
    except ValueError as e:
        if "Classification metrics can't handle a mix of multiclass and continuous targets" in e.args[0]:
            raise ValueError(
                "Error when attempting to fit estimator. "
                "It seems that you are trying to fit a regression model "
                "but specified metrics for classification. "
                "Please check your code and provide other evaluation metrics if necessary!"
            ) from e
        if "An empty dict was passed." in e.args[0]:
            raise ValueError("No scoring metric was specified for the estimator!") from e
        # re-raise any other ValueError, chained to the original exception
        raise ValueError from e
    return cv_obj


def _get_param_search_cv_object(
    pipeline: Pipeline,
    param_dict: Dict[str, Any],
    inner_cv: BaseCrossValidator,
    scoring_dict: Dict[str, str],
    hyper_search_config: Dict[str, Any],
    **kwargs,
):
    """Build the hyperparameter search object used in the inner cross-validation loop.

    Depending on ``hyper_search_config["search_method"]``, this returns either a
    ``RandomizedSearchCV`` (``"random"``) or a ``GridSearchCV`` (``"grid"``) instance.
    """
    # ``GridSearchCV`` does not accept a ``random_state`` argument, so it is removed from
    # the keyword arguments here and only passed on to ``RandomizedSearchCV``
    random_state = kwargs.pop("random_state", None)
    if hyper_search_config["search_method"] == "random":
        return RandomizedSearchCV(
            pipeline,
            param_distributions=param_dict,
            cv=inner_cv,
            scoring=scoring_dict,
            n_iter=hyper_search_config["n_iter"],
            random_state=random_state,
            **kwargs,
        )
    if hyper_search_config["search_method"] == "grid":
        return GridSearchCV(pipeline, param_grid=param_dict, cv=inner_cv, scoring=scoring_dict, **kwargs)
    raise ValueError("Unknown search method: {}".format(hyper_search_config["search_method"]))
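

# A sketch of how these private helpers fit together for one outer fold of the nested
# cross-validation (the pipeline, parameter grid, and cross-validator below are
# illustrative placeholders; ``x_train``/``y_train`` stand for the feature matrix and
# labels of the current outer training fold, e.g. as returned by ``split_train_test``):
#
#     >>> from sklearn.model_selection import KFold
#     >>> from sklearn.svm import SVC
#     >>> pipeline = Pipeline([("clf", SVC())])
#     >>> param_dict = {"clf__C": [0.1, 1.0, 10.0]}
#     >>> hyper_search_config = {"search_method": "grid"}
#     >>> kwargs, scoring_dict = _setup_scoring_dict("accuracy")
#     >>> cv_obj = _get_param_search_cv_object(
#     ...     pipeline, param_dict, KFold(5), scoring_dict, hyper_search_config, **kwargs
#     ... )
#     >>> cv_obj = _fit_cv_obj_one_fold(cv_obj, x_train, y_train, groups_train=None)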