Skip to content

Commit

Permalink
Save model in ubj as the default. (#9947)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jan 5, 2024
1 parent c03a4d5 commit 38dd91f
Show file tree
Hide file tree
Showing 23 changed files with 600 additions and 552 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.DefaultParamsReader.Metadata

abstract class XGBoostWriter extends MLWriter {

/** Currently it's using the "deprecated" format as
* default, which will be changed into `ubj` in future releases. */
def getModelFormat(): String = {
optionMap.getOrElse("format", JBooster.DEFAULT_FORMAT)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -432,28 +432,29 @@ class XGBoostClassifierSuite extends AnyFunSuite with PerTest with TmpFolderPerS
val xgb = new XGBoostClassifier(paramMap)
val model = xgb.fit(trainingDF)

// test json
val modelPath = new File(tempDir.toFile, "xgbc").getPath
model.write.option("format", "json").save(modelPath)
val nativeJsonModelPath = new File(tempDir.toFile, "nativeModel.json").getPath
model.nativeBooster.saveModel(nativeJsonModelPath)
assert(compareTwoFiles(new File(modelPath, "data/XGBoostClassificationModel").getPath,
nativeJsonModelPath))

// test default "deprecated"
// test ubj
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostClassificationModel").getPath,
nativeDeprecatedModelPath))
nativeUbjModelPath))

// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
val nativeUbjModelPath1 = new File(tempDir.toFile, "nativeModel1.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath1)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
nativeUbjModelPath))
nativeUbjModelPath1))
}

test("native json model file should store feature_name and feature_type") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright (c) 2014-2022 by Contributors
Copyright (c) 2014-2024 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -333,21 +333,24 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
assert(compareTwoFiles(new File(modelPath, "data/XGBoostRegressionModel").getPath,
nativeJsonModelPath))

// test default "deprecated"
// test default "ubj"
val modelUbjPath = new File(tempDir.toFile, "xgbcUbj").getPath
model.write.save(modelUbjPath)
val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)
assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))

// json file should be indifferent with ubj file
val modelJsonPath = new File(tempDir.toFile, "xgbcJson").getPath
model.write.option("format", "json").save(modelJsonPath)
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel1.ubj").getPath
val nativeUbjModelPath = new File(tempDir.toFile, "nativeModel.ubj").getPath
model.nativeBooster.saveModel(nativeUbjModelPath)
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostRegressionModel").getPath,

assert(compareTwoFiles(new File(modelUbjPath, "data/XGBoostRegressionModel").getPath,
nativeUbjModelPath))
}

// test the deprecated format
val modelDeprecatedPath = new File(tempDir.toFile, "modelDeprecated").getPath
model.write.option("format", "deprecated").save(modelDeprecatedPath)

val nativeDeprecatedModelPath = new File(tempDir.toFile, "nativeModel.deprecated").getPath
model.nativeBooster.saveModel(nativeDeprecatedModelPath)

assert(compareTwoFiles(new File(modelDeprecatedPath, "data/XGBoostRegressionModel").getPath,
nativeDeprecatedModelPath))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
* Booster for xgboost, this is a model API that support interactive build of a XGBoost Model
*/
public class Booster implements Serializable, KryoSerializable {
public static final String DEFAULT_FORMAT = "deprecated";
public static final String DEFAULT_FORMAT = "ubj";
private static final Log logger = LogFactory.getLog(Booster.class);
// handle to the booster.
private long handle = 0;
Expand Down Expand Up @@ -788,8 +788,7 @@ private Map<String, Double> getFeatureImportanceFromModel(
}

/**
* Save model into raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
* Save model into raw byte array in the UBJSON ("ubj") format.
*
* @return the saved byte array
* @throws XGBoostError native error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,7 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
}

/**
* Save model into a raw byte array. Currently it's using the deprecated format as
* default, which will be changed into `ubj` in future releases.
* Save model into a raw byte array in the UBJSON ("ubj") format.
*/
@throws(classOf[XGBoostError])
def toByteArray: Array[Byte] = {
Expand Down
2 changes: 1 addition & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2613,7 +2613,7 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None:
else:
raise TypeError("fname must be a string or os PathLike")

def save_raw(self, raw_format: str = "deprecated") -> bytearray:
def save_raw(self, raw_format: str = "ubj") -> bytearray:
"""Save the model to a in memory buffer representation instead of file.
Parameters
Expand Down
5 changes: 2 additions & 3 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def random_csc(t_id: int) -> sparse.csc_matrix:

def make_datasets_with_margin(
unweighted_strategy: strategies.SearchStrategy,
) -> Callable:
) -> Callable[[], strategies.SearchStrategy[TestDataset]]:
"""Factory function for creating strategies that generates datasets with weight and
base margin.
Expand Down Expand Up @@ -668,8 +668,7 @@ def weight_margin(draw: Callable) -> TestDataset:

# A strategy for drawing from a set of example datasets. May add random weights to the
# dataset
@memory.cache
def make_dataset_strategy() -> Callable:
def make_dataset_strategy() -> strategies.SearchStrategy[TestDataset]:
_unweighted_datasets_strategy = strategies.sampled_from(
[
TestDataset(
Expand Down
16 changes: 7 additions & 9 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1313,10 +1313,8 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {

namespace {
void WarnOldModel() {
if (XGBOOST_VER_MAJOR >= 2) {
LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or "
"`ubj`. Model format will default to JSON in XGBoost 2.2 if not specified.";
}
LOG(WARNING) << "Saving into deprecated binary model format, please consider using `json` or "
"`ubj`. Model format is default to UBJSON in XGBoost 2.1 if not specified.";
}
} // anonymous namespace

Expand All @@ -1339,14 +1337,14 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) {
save_json(std::ios::out);
} else if (common::FileExtension(fname) == "ubj") {
save_json(std::ios::binary);
} else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) {
LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or "
"`deprecated` to choose between formats.";
save_json(std::ios::out);
} else {
} else if (common::FileExtension(fname) == "deprecated") {
WarnOldModel();
auto *bst = static_cast<Learner *>(handle);
bst->SaveModel(fo.get());
} else {
LOG(WARNING) << "Saving model in the UBJSON format as default. You can use file extension:"
" `json`, `ubj` or `deprecated` to choose between formats.";
save_json(std::ios::binary);
}
API_END();
}
Expand Down
2 changes: 2 additions & 0 deletions tests/ci_build/lint_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class LintersPaths:
"tests/python/test_quantile_dmatrix.py",
"tests/python/test_tree_regularization.py",
"tests/python/test_shap.py",
"tests/python/test_model_io.py",
"tests/python/test_with_pandas.py",
"tests/python-gpu/",
"tests/python-sycl/",
Expand Down Expand Up @@ -83,6 +84,7 @@ class LintersPaths:
"tests/python/test_multi_target.py",
"tests/python-gpu/test_gpu_data_iterator.py",
"tests/python-gpu/load_pickle.py",
"tests/python/test_model_io.py",
"tests/test_distributed/test_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
Expand Down
Loading

0 comments on commit 38dd91f

Please sign in to comment.