The Python package's tests fail with the latest scikit-learn nightlies (v1.6.dev0).
================= 7 failed, 103 passed, 14 warnings in 19.91s ==================
All the failures appear to be from the estimator checks scikit-learn ships to help projects test compliance with scikit-learn API expectations. Stuff like this:
E AssertionError: XGBRegressor.predict() does not check for consistency between input number
E of features with XGBRegressor.fit(), via the n_features_in_ attribute.
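For context, this check fits on data with several columns, then calls `predict()` on data with a single column. It expects a `ValueError` carrying scikit-learn's own wording, not XGBoost's `Feature shape mismatch` error. Below is a minimal, dependency-free sketch of that bookkeeping. `FeatureCountMixin` and `ToyRegressor` are hypothetical names, and plain lists of rows stand in for NumPy arrays. The message is modeled on scikit-learn's `X has N features, but ... is expecting M features as input` wording, so treat the exact text as an assumption to verify against the installed version. In a real estimator, `validate_data(self, X, reset=False)` from `sklearn.utils.validation` (as the error message suggests) would take care of this.

```python
class FeatureCountMixin:
    """Hypothetical mixin: remember the column count at fit time, enforce it later."""

    def _record_n_features(self, X):
        # X is a non-empty list of rows; the column count comes from the first row.
        self.n_features_in_ = len(X[0])

    def _check_n_features(self, X):
        n_features = len(X[0])
        if n_features != self.n_features_in_:
            # Wording modeled on scikit-learn's n_features_in_ validation message.
            raise ValueError(
                f"X has {n_features} features, but {type(self).__name__} "
                f"is expecting {self.n_features_in_} features as input."
            )


class ToyRegressor(FeatureCountMixin):
    """Hypothetical estimator that predicts the mean of the training targets."""

    def fit(self, X, y):
        self._record_n_features(X)
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        # Validate in Python before any (native) prediction code would run.
        self._check_n_features(X)
        return [self.mean_ for _ in X]
```

With this in place, `ToyRegressor().fit(four_column_rows, y).predict([[1.0]])` fails in Python with the expected message and never reaches the prediction code.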
Full logs:
============================= test session starts ==============================
platform darwin -- Python 3.11.9, pytest-8.2.2, pluggy-1.5.0
rootdir: /Users/jlamb/repos/xgboost/tests
configfile: pytest.ini
plugins: cov-5.0.0, hypothesis-6.115.2
collected 110 items
tests/python/test_with_sklearn.py ...................................... [ 34%]
.....................F.......F.F.F.............F...............FF....... [100%]
=================================== FAILURES ===================================
_ test_estimator_reg[XGBRegressor(base_score=None,booster=None,callbacks=None,colsample_bylevel=None,colsample_bynode=None,colsample_bytree=None,device=None,early_stopping_rounds=None,enable_categorical=False,eval_metric=None,feature_types=None,gamma=None,grow_policy=None,importance_type=None,interaction_constraints=None,learning_rate=None,max_bin=None,max_cat_threshold=None,max_cat_to_onehot=None,max_delta_step=None,max_depth=None,max_leaves=None,min_child_weight=None,missing=nan,monotone_constraints=None,multi_strategy=None,n_estimators=None,n_jobs=None,num_parallel_tree=None,random_state=None,...)-check_n_features_in_after_fitting] _
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:3974: in check_n_features_in_after_fitting
callable_method(X_bad)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/sklearn.py:1225: in predict
predts = self.get_booster().inplace_predict(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:2642: in inplace_predict
raise ValueError(
E ValueError: Feature shape mismatch, expected: 4, got 1
The above exception was the direct cause of the following exception:
tests/python/test_with_sklearn.py:1349: in test_estimator_reg
check(estimator)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:140: in wrapper
return fn(*args, **kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:3971: in check_n_features_in_after_fitting
with raises(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:1076: in __exit__
raise AssertionError(err_msg) from exc_value
E AssertionError: `XGBRegressor.predict()` does not check for consistency between input number
E of features with XGBRegressor.fit(), via the `n_features_in_` attribute.
E You might want to use `sklearn.utils.validation.validate_data` instead
E of `check_array` in `XGBRegressor.fit()` and XGBRegressor.predict()`. This can be done
E like the following:
E from sklearn.utils.validation import validate_data
E ...
E class MyEstimator(BaseEstimator):
E ...
E def fit(self, X, y):
E X, y = validate_data(self, X, y, ...)
E ...
E return self
E ...
E def predict(self, X):
E X = validate_data(self, X, ..., reset=False)
E ...
E return X
_ test_estimator_reg[XGBRegressor(base_score=None,booster=None,callbacks=None,colsample_bylevel=None,colsample_bynode=None,colsample_bytree=None,device=None,early_stopping_rounds=None,enable_categorical=False,eval_metric=None,feature_types=None,gamma=None,grow_policy=None,importance_type=None,interaction_constraints=None,learning_rate=None,max_bin=None,max_cat_threshold=None,max_cat_to_onehot=None,max_delta_step=None,max_depth=None,max_leaves=None,min_child_weight=None,missing=nan,monotone_constraints=None,multi_strategy=None,n_estimators=None,n_jobs=None,num_parallel_tree=None,random_state=None,...)-check_complex_data] _
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:1239: in check_complex_data
estimator.fit(X, y)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/sklearn.py:1118: in fit
train_dmatrix, evals = _wrap_evaluation_matrices(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/sklearn.py:605: in _wrap_evaluation_matrices
train_dmatrix = create_dmatrix(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/sklearn.py:1040: in _create_dmatrix
return QuantileDMatrix(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:1636: in __init__
self._init(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:1695: in _init
it.reraise()
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:618: in reraise
raise exc # pylint: disable=raising-bad-type
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:599: in _handle_exception
return fn()
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:686: in <lambda>
return self._handle_exception(lambda: int(self.next(input_data)), 0)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/data.py:1479: in next
input_data(**self.kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:674: in input_data
dispatch_proxy_set_data(self.proxy, new, cat_codes)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/data.py:1559: in dispatch_proxy_set_data
proxy._ref_data_from_array(data) # pylint: disable=W0212
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:1509: in _ref_data_from_array
_check_call(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:297: in _check_call
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
E xgboost.core.XGBoostError: [23:42:00] /Users/jlamb/repos/xgboost/src/c_api/../data/array_interface.h:499: Complex floating point-1 is not supported.
E Stack trace:
E [bt] (0) 1 libxgboost.dylib 0x0000000148b1e448 dmlc::LogMessageFatal::~LogMessageFatal() + 124
E [bt] (1) 2 libxgboost.dylib 0x0000000148b32684 xgboost::ArrayInterface<2, false>::AssignType(xgboost::StringView) + 1272
E [bt] (2) 3 libxgboost.dylib 0x0000000148b31e64 xgboost::ArrayInterface<2, false>::Initialize(std::__1::map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, xgboost::Json, std::__1::less<void>, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const, xgboost::Json>>> const&) + 408
E [bt] (3) 4 libxgboost.dylib 0x0000000148ccea50 xgboost::data::ArrayAdapter::ArrayAdapter(xgboost::StringView) + 148
E [bt] (4) 5 libxgboost.dylib 0x0000000148cce664 xgboost::data::DMatrixProxy::SetArrayData(xgboost::StringView) + 72
E [bt] (5) 6 libxgboost.dylib 0x0000000148b2a8c0 XGProxyDMatrixSetDataDense + 136
E [bt] (6) 7 libffi.8.dylib 0x0000000105ebc04c ffi_call_SYSV + 76
E [bt] (7) 8 libffi.8.dylib 0x0000000105eb974c ffi_call_int + 1208
E [bt] (8) 9 _ctypes.cpython-311-darwin.so 0x0000000105e38988 _ctypes_callproc + 1208
The above exception was the direct cause of the following exception:
tests/python/test_with_sklearn.py:1349: in test_estimator_reg
check(estimator)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:1238: in check_complex_data
with raises(ValueError, match="Complex data not supported"):
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:1076: in __exit__
raise AssertionError(err_msg) from exc_value
E AssertionError: The error message should contain one of the following patterns:
E Complex data not supported
E Got [23:42:00] /Users/jlamb/repos/xgboost/src/c_api/../data/array_interface.h:499: Complex floating point-1 is not supported.
E Stack trace:
E [bt] (0) 1 libxgboost.dylib 0x0000000148b1e448 dmlc::LogMessageFatal::~LogMessageFatal() + 124
E [bt] (1) 2 libxgboost.dylib 0x0000000148b32684 xgboost::ArrayInterface<2, false>::AssignType(xgboost::StringView) + 1272
E [bt] (2) 3 libxgboost.dylib 0x0000000148b31e64 xgboost::ArrayInterface<2, false>::Initialize(std::__1::map<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>>, xgboost::Json, std::__1::less<void>, std::__1::allocator<std::__1::pair<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char>> const, xgboost::Json>>> const&) + 408
E [bt] (3) 4 libxgboost.dylib 0x0000000148ccea50 xgboost::data::ArrayAdapter::ArrayAdapter(xgboost::StringView) + 148
E [bt] (4) 5 libxgboost.dylib 0x0000000148cce664 xgboost::data::DMatrixProxy::SetArrayData(xgboost::StringView) + 72
E [bt] (5) 6 libxgboost.dylib 0x0000000148b2a8c0 XGProxyDMatrixSetDataDense + 136
E [bt] (6) 7 libffi.8.dylib 0x0000000105ebc04c ffi_call_SYSV + 76
E [bt] (7) 8 libffi.8.dylib 0x0000000105eb974c ffi_call_int + 1208
E [bt] (8) 9 _ctypes.cpython-311-darwin.so 0x0000000105e38988 _ctypes_callproc + 1208
_ test_estimator_reg[XGBRegressor(base_score=None,booster=None,callbacks=None,colsample_bylevel=None,colsample_bynode=None,colsample_bytree=None,device=None,early_stopping_rounds=None,enable_categorical=False,eval_metric=None,feature_types=None,gamma=None,grow_policy=None,importance_type=None,interaction_constraints=None,learning_rate=None,max_bin=None,max_cat_threshold=None,max_cat_to_onehot=None,max_delta_step=None,max_depth=None,max_leaves=None,min_child_weight=None,missing=nan,monotone_constraints=None,multi_strategy=None,n_estimators=None,n_jobs=None,num_parallel_tree=None,random_state=None,...)-check_estimators_empty_data_messages] _
tests/python/test_with_sklearn.py:1349: in test_estimator_reg
check(estimator)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:140: in wrapper
return fn(*args, **kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:1830: in check_estimators_empty_data_messages
with raises(ValueError, err_msg=err_msg):
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:1059: in __exit__
raise AssertionError(err_msg)
E AssertionError: The estimator XGBRegressor does not raise a ValueError when an empty data is used to train. Perhaps use check_array in train.
_ test_estimator_reg[XGBRegressor(base_score=None,booster=None,callbacks=None,colsample_bylevel=None,colsample_bynode=None,colsample_bytree=None,device=None,early_stopping_rounds=None,enable_categorical=False,eval_metric=None,feature_types=None,gamma=None,grow_policy=None,importance_type=None,interaction_constraints=None,learning_rate=None,max_bin=None,max_cat_threshold=None,max_cat_to_onehot=None,max_delta_step=None,max_depth=None,max_leaves=None,min_child_weight=None,missing=nan,monotone_constraints=None,multi_strategy=None,n_estimators=None,n_jobs=None,num_parallel_tree=None,random_state=None,...)-check_estimators_nan_inf] _
tests/python/test_with_sklearn.py:1349: in test_estimator_reg
check(estimator)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:140: in wrapper
return fn(*args, **kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:1867: in check_estimators_nan_inf
with raises(ValueError, match=["inf", "NaN"], err_msg=error_string_fit):
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:1059: in __exit__
raise AssertionError(err_msg)
E AssertionError: Estimator XGBRegressor doesn't check for NaN and inf in fit.
_ test_estimator_reg[XGBRegressor(base_score=None,booster=None,callbacks=None,colsample_bylevel=None,colsample_bynode=None,colsample_bytree=None,device=None,early_stopping_rounds=None,enable_categorical=False,eval_metric=None,feature_types=None,gamma=None,grow_policy=None,importance_type=None,interaction_constraints=None,learning_rate=None,max_bin=None,max_cat_threshold=None,max_cat_to_onehot=None,max_delta_step=None,max_depth=None,max_leaves=None,min_child_weight=None,missing=nan,monotone_constraints=None,multi_strategy=None,n_estimators=None,n_jobs=None,num_parallel_tree=None,random_state=None,...)-check_supervised_y_2d] _
tests/python/test_with_sklearn.py:1349: in test_estimator_reg
check(estimator)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:140: in wrapper
return fn(*args, **kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:2799: in check_supervised_y_2d
assert len(w) > 0, msg
E AssertionError: expected 1 DataConversionWarning, got:
_ test_estimator_reg[XGBRegressor(base_score=None,booster=None,callbacks=None,colsample_bylevel=None,colsample_bynode=None,colsample_bytree=None,device=None,early_stopping_rounds=None,enable_categorical=False,eval_metric=None,feature_types=None,gamma=None,grow_policy=None,importance_type=None,interaction_constraints=None,learning_rate=None,max_bin=None,max_cat_threshold=None,max_cat_to_onehot=None,max_delta_step=None,max_depth=None,max_leaves=None,min_child_weight=None,missing=nan,monotone_constraints=None,multi_strategy=None,n_estimators=None,n_jobs=None,num_parallel_tree=None,random_state=None,...)-check_fit2d_predict1d] _
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:1345: in check_fit2d_predict1d
getattr(estimator, method)(X[0])
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/sklearn.py:1225: in predict
predts = self.get_booster().inplace_predict(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:2651: in inplace_predict
_check_call(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:297: in _check_call
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
E xgboost.core.XGBoostError: [23:42:04] /Users/jlamb/repos/xgboost/src/predictor/cpu_predictor.cc:789: Check failed: m->NumColumns() == model.learner_model_param->num_feature (1 vs. 3) : Number of columns in data must equal to trained model.
E Stack trace:
E [bt] (0) 1 libxgboost.dylib 0x0000000148b1e448 dmlc::LogMessageFatal::~LogMessageFatal() + 124
E [bt] (1) 2 libxgboost.dylib 0x0000000148dd3e10 void xgboost::predictor::CPUPredictor::DispatchedInplacePredict<xgboost::data::ArrayAdapter, 64ul>(std::__1::any const&, std::__1::shared_ptr<xgboost::DMatrix>, xgboost::gbm::GBTreeModel const&, float, xgboost::PredictionCacheEntry*, unsigned int, unsigned int) const + 344
E [bt] (2) 3 libxgboost.dylib 0x0000000148dca8d0 xgboost::predictor::CPUPredictor::InplacePredict(std::__1::shared_ptr<xgboost::DMatrix>, xgboost::gbm::GBTreeModel const&, float, xgboost::PredictionCacheEntry*, unsigned int, unsigned int) const + 1572
E [bt] (3) 4 libxgboost.dylib 0x0000000148cf83d0 xgboost::gbm::GBTree::InplacePredict(std::__1::shared_ptr<xgboost::DMatrix>, float, xgboost::PredictionCacheEntry*, int, int) const + 740
E [bt] (4) 5 libxgboost.dylib 0x0000000148d16c20 xgboost::LearnerImpl::InplacePredict(std::__1::shared_ptr<xgboost::DMatrix>, xgboost::PredictionType, float, xgboost::HostDeviceVector<float>**, int, int) + 164
E [bt] (5) 6 libxgboost.dylib 0x0000000148ba0e58 InplacePredictImpl(std::__1::shared_ptr<xgboost::DMatrix>, char const*, xgboost::Learner*, unsigned long long const**, unsigned long long*, float const**) + 276
E [bt] (6) 7 libxgboost.dylib 0x0000000148ba14b4 XGBoosterPredictFromDense + 420
E [bt] (7) 8 libffi.8.dylib 0x0000000105ebc04c ffi_call_SYSV + 76
E [bt] (8) 9 libffi.8.dylib 0x0000000105eb974c ffi_call_int + 1208
The above exception was the direct cause of the following exception:
tests/python/test_with_sklearn.py:1349: in test_estimator_reg
check(estimator)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:140: in wrapper
return fn(*args, **kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:1344: in check_fit2d_predict1d
with raises(ValueError, match="Reshape your data"):
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/_testing.py:1076: in __exit__
raise AssertionError(err_msg) from exc_value
E AssertionError: The error message should contain one of the following patterns:
E Reshape your data
E Got [23:42:04] /Users/jlamb/repos/xgboost/src/predictor/cpu_predictor.cc:789: Check failed: m->NumColumns() == model.learner_model_param->num_feature (1 vs. 3) : Number of columns in data must equal to trained model.
E Stack trace:
E [bt] (0) 1 libxgboost.dylib 0x0000000148b1e448 dmlc::LogMessageFatal::~LogMessageFatal() + 124
E [bt] (1) 2 libxgboost.dylib 0x0000000148dd3e10 void xgboost::predictor::CPUPredictor::DispatchedInplacePredict<xgboost::data::ArrayAdapter, 64ul>(std::__1::any const&, std::__1::shared_ptr<xgboost::DMatrix>, xgboost::gbm::GBTreeModel const&, float, xgboost::PredictionCacheEntry*, unsigned int, unsigned int) const + 344
E [bt] (2) 3 libxgboost.dylib 0x0000000148dca8d0 xgboost::predictor::CPUPredictor::InplacePredict(std::__1::shared_ptr<xgboost::DMatrix>, xgboost::gbm::GBTreeModel const&, float, xgboost::PredictionCacheEntry*, unsigned int, unsigned int) const + 1572
E [bt] (3) 4 libxgboost.dylib 0x0000000148cf83d0 xgboost::gbm::GBTree::InplacePredict(std::__1::shared_ptr<xgboost::DMatrix>, float, xgboost::PredictionCacheEntry*, int, int) const + 740
E [bt] (4) 5 libxgboost.dylib 0x0000000148d16c20 xgboost::LearnerImpl::InplacePredict(std::__1::shared_ptr<xgboost::DMatrix>, xgboost::PredictionType, float, xgboost::HostDeviceVector<float>**, int, int) + 164
E [bt] (5) 6 libxgboost.dylib 0x0000000148ba0e58 InplacePredictImpl(std::__1::shared_ptr<xgboost::DMatrix>, char const*, xgboost::Learner*, unsigned long long const**, unsigned long long*, float const**) + 276
E [bt] (6) 7 libxgboost.dylib 0x0000000148ba14b4 XGBoosterPredictFromDense + 420
E [bt] (7) 8 libffi.8.dylib 0x0000000105ebc04c ffi_call_SYSV + 76
E [bt] (8) 9 libffi.8.dylib 0x0000000105eb974c ffi_call_int + 1208
_ test_estimator_reg[XGBRegressor(base_score=None,booster=None,callbacks=None,colsample_bylevel=None,colsample_bynode=None,colsample_bytree=None,device=None,early_stopping_rounds=None,enable_categorical=False,eval_metric=None,feature_types=None,gamma=None,grow_policy=None,importance_type=None,interaction_constraints=None,learning_rate=None,max_bin=None,max_cat_threshold=None,max_cat_to_onehot=None,max_delta_step=None,max_depth=None,max_leaves=None,min_child_weight=None,missing=nan,monotone_constraints=None,multi_strategy=None,n_estimators=None,n_jobs=None,num_parallel_tree=None,random_state=None,...)-check_requires_y_none] _
tests/python/test_with_sklearn.py:1349: in test_estimator_reg
check(estimator)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:3889: in check_requires_y_none
raise ve
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/sklearn/utils/estimator_checks.py:3886: in check_requires_y_none
estimator.fit(X, None)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/sklearn.py:1145: in fit
self._Booster = train(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:775: in inner_f
return func(**kwargs)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/training.py:181: in train
bst.update(dtrain, iteration=i, fobj=obj)
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:2218: in update
_check_call(
../../miniforge3/envs/lgb-dev/lib/python3.11/site-packages/xgboost/core.py:297: in _check_call
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
E xgboost.core.XGBoostError: [23:42:04] /Users/jlamb/repos/xgboost/src/objective/init_estimation.h:16: Check failed: info.labels.Shape(0) == info.num_row_ (0 vs. 100) : Invalid shape of labels.
E Stack trace:
E [bt] (0) 1 libxgboost.dylib 0x0000000148b1e448 dmlc::LogMessageFatal::~LogMessageFatal() + 124
E [bt] (1) 2 libxgboost.dylib 0x0000000148d7cc28 xgboost::obj::CheckInitInputs(xgboost::MetaInfo const&) + 200
E [bt] (2) 3 libxgboost.dylib 0x0000000148d7fa14 xgboost::obj::FitIntercept::InitEstimation(xgboost::MetaInfo const&, xgboost::linalg::Tensor<float, 1>*) const + 68
E [bt] (3) 4 libxgboost.dylib 0x0000000148d28f64 xgboost::LearnerConfiguration::InitBaseScore(xgboost::DMatrix const*) + 252
E [bt] (4) 5 libxgboost.dylib 0x0000000148d15a1c xgboost::LearnerImpl::UpdateOneIter(int, std::__1::shared_ptr<xgboost::DMatrix>) + 140
E [bt] (5) 6 libxgboost.dylib 0x0000000148b41b6c XGBoosterUpdateOneIter + 144
E [bt] (6) 7 libffi.8.dylib 0x0000000105ebc04c ffi_call_SYSV + 76
E [bt] (7) 8 libffi.8.dylib 0x0000000105eb974c ffi_call_int + 1208
E [bt] (8) 9 _ctypes.cpython-311-darwin.so 0x0000000105e38988 _ctypes_callproc + 1208
================= 7 failed, 103 passed, 14 warnings in 19.76s ==================
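Four of these failures follow the same pattern: `check_complex_data`, `check_estimators_empty_data_messages`, `check_estimators_nan_inf`, and `check_fit2d_predict1d`. In each case the bad input reaches libxgboost, which either raises an error whose text doesn't match what the check looks for or raises nothing at all. Here is a hedged sketch of validating X in Python first. `validate_features` is a hypothetical helper, lists of rows stand in for arrays, and the messages are modeled on scikit-learn's `check_array` wording (`Complex data not supported`, `Reshape your data`, `Input X contains NaN`). XGBoost treats NaN as a missing value by default (note `missing=nan` in the estimator repr above), so the sketch only rejects NaN when `allow_nan=False`. For XGBoost, declaring NaN support through scikit-learn's `allow_nan` estimator tag may be a better fit than rejecting NaN.

```python
import math


def validate_features(X, allow_nan=False):
    """Hypothetical pre-flight check for a list-of-rows X, run before data
    reaches native code. Messages are modeled on scikit-learn's check_array."""
    # A flat sequence of numbers is 1-D input (e.g. X[0] passed to predict()).
    if len(X) > 0 and not isinstance(X[0], (list, tuple)):
        raise ValueError(
            f"Expected 2D array, got 1D array instead:\narray={list(X)}.\n"
            "Reshape your data either using array.reshape(-1, 1) if your data "
            "has a single feature or array.reshape(1, -1) if it contains a "
            "single sample."
        )
    # Empty input must fail loudly instead of silently training on nothing.
    if len(X) == 0:
        raise ValueError("Found array with 0 sample(s) while a minimum of 1 is required.")
    if len(X[0]) == 0:
        raise ValueError(
            f"Found array with 0 feature(s) (shape=({len(X)}, 0)) "
            "while a minimum of 1 is required."
        )
    for row in X:
        for value in row:
            # Check for complex first: math.isinf/isnan raise TypeError on complex.
            if isinstance(value, complex):
                raise ValueError(f"Complex data not supported\n{X}\n")
            if math.isinf(value):
                raise ValueError("Input X contains infinity.")
            if not allow_nan and math.isnan(value):
                raise ValueError("Input X contains NaN.")
    return X
```

Raising `ValueError` from Python before the native call also keeps the error free of the C++ stack trace that currently ends up inside the assertion messages.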
> I'd be happy to try to help with this over the next week if you'd like

Yes, please let me know if there's anything I can help with. I can handle the C++ changes if needed; some checks are done inside libxgboost, and the error message requirements from sklearn seem to have changed.
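The two remaining failures concern the target, not the features. `check_requires_y_none` calls `fit(X, None)` and only accepts a `ValueError` whose text matches one of a few known messages (libxgboost reports `Invalid shape of labels` instead). `check_supervised_y_2d` expects a `DataConversionWarning` when y is passed as a column vector. Below is a minimal sketch, assuming plain lists for y. `validate_target` is hypothetical, and the messages are modeled on scikit-learn's wording. The local `DataConversionWarning` class is only a stand-in for `sklearn.exceptions.DataConversionWarning`; a real estimator must emit scikit-learn's class, since the check looks for it by name.

```python
import warnings


class DataConversionWarning(UserWarning):
    """Stand-in for sklearn.exceptions.DataConversionWarning (keeps the sketch dependency-free)."""


def validate_target(y, estimator_name="Estimator"):
    """Hypothetical y validation run in Python before building native data structures."""
    # Supervised estimators should refuse y=None with a recognizable ValueError.
    if y is None:
        raise ValueError(
            f"{estimator_name} requires y to be passed, but the target y is None."
        )
    # A column vector (every row holds exactly one value) is flattened, with a warning.
    if len(y) > 0 and isinstance(y[0], (list, tuple)) and all(len(r) == 1 for r in y):
        warnings.warn(
            "A column-vector y was passed when a 1d array was expected. "
            "Please change the shape of y to (n_samples, ), for example using ravel().",
            DataConversionWarning,
            stacklevel=2,
        )
        return [row[0] for row in y]
    return list(y)
```

XGBRegressor can fit multi-output targets (see `multi_strategy` in the repr). scikit-learn only requires this warning from estimators that don't declare multi-output support, so declaring that capability through estimator tags may be preferable to warning and flattening.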
Reproducible Example
On an M2 Mac, in a Python 3.11.9 conda environment, built the Python package from source.
Installed the latest scikit-learn nightlies.

    pip uninstall --yes scikit-learn
    pytest \
        --disable-warnings \
        --tb=short \
        -rs \
        ./tests/python/test_with_sklearn.py \
    | tee ./out.txt

Saw the failures reported above.

Repeated that same process but with the latest release of scikit-learn.

    pip uninstall --yes scikit-learn
    pip install --no-deps 'scikit-learn==1.5.2'

All tests passed.
Notes
I found this while testing against this in-progress scikit-learn branch: scikit-learn/scikit-learn#28901 (comment)

@trivialfis @hcho3 I'd be happy to try to help with this over the next week if you'd like. I'm familiar with some of the changes in scikit-learn from this related work we've been doing in lightgbm: … scikit-learn>=0.24.2, make scikit-learn estimators compatible with scikit-learn>=1.6.0dev (microsoft/LightGBM#6651)