diff --git a/blazingai/tabular/model_selection.py b/blazingai/tabular/model_selection.py index 3c190af..78c1d92 100644 --- a/blazingai/tabular/model_selection.py +++ b/blazingai/tabular/model_selection.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Any, Dict, Generator, List, Optional import numpy as np import pandas as pd @@ -53,29 +53,41 @@ def validate_method(method): else: return method - def split(self, X: pd.DataFrame, y: np.ndarray = None, groups: str = None): + def split( + self, + X: pd.DataFrame, + y: Optional[np.ndarray] = None, + groups: Optional[str] = None, + ) -> Generator[ + tuple[ + np.ndarray, + np.ndarray, + ], + None, + None, + ]: """ Args: X (pandas.DataFrame): Input data. method (str): Either `sliding` or `expanding`. - y: to make the function compatible with sklearn cross validation. - groups: to make the function compatible with sklearn cross validation. + y: Unused, but required for compatibility with sklearn cross validation API. + groups: Unused, but required for compatibility with sklearn cross validation API. Returns: (generator) Indexes for the training and Validation set from the data passed """ self._is_valid_df(X) - ar = self._df_to_array(X) - self._has_enough_days(ar) - first = ar.min() + arr = self._df_to_array(X) + self._has_enough_days(arr) + first = arr.min() # yield indexes for train and valid sets for step in range(self.n_splits): train_start, train_end, valid_start, valid_end = self._get_dates( first=first, step=step ) - train_mask = (ar >= train_start) & (ar <= train_end) - valid_mask = (ar >= valid_start) & (ar <= valid_end) + train_mask = (arr >= train_start) & (arr <= train_end) + valid_mask = (arr >= valid_start) & (arr <= valid_end) self.train_date_ranges["period_" + str(step)] = [ train_start.date().isoformat(),