Capture an assert failure from cuvs cagra and rephrase error message to include relevant params (#771)

* add test case for converting various dtypes to float32 for cagra

* handle cuvs assert on intermediate_graph_degree of ivf_pq cagra

---------

Signed-off-by: Jinfeng <jinfengl@nvidia.com>
lijinf2 authored Nov 5, 2024
1 parent 453fd94 commit 148d8ec
Showing 3 changed files with 109 additions and 3 deletions.
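The gist of the change: with algorithm="cagra" and build_algo="ivf_pq", an oversized intermediate_graph_degree used to surface as cuvs's internal assert "k must be less than topk::kMaxCapacity (256)". The commit catches that and re-raises a message naming the parameter. A minimal sketch of the new behavior (parameter values borrowed from the tests below; inputCol/idCol omitted for brevity):

    from spark_rapids_ml.knn import ApproximateNearestNeighbors

    ann = ApproximateNearestNeighbors(
        k=2,
        algorithm="cagra",
        algoParams={"build_algo": "ivf_pq", "intermediate_graph_degree": 257},
    )
    # Building the index during kneighbors() now fails with:
    # ValueError: cagra with ivf_pq build_algo expects
    # intermediate_graph_degree (257) to be smaller than 256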
28 changes: 27 additions & 1 deletion python/src/spark_rapids_ml/knn.py
@@ -1099,6 +1099,12 @@ def __init__(
"ivfpq",
"cagra",
}, "currently only ivfflat, ivfpq, and cagra are supported"
if not self._input_kwargs.get("float32_inputs", True):
get_logger(self.__class__).warning(
"This estimator supports only float32 inputs on GPU and will convert all other data types to float32. Setting float32_inputs to False will be ignored."
)
self._input_kwargs.pop("float32_inputs")

self._set_params(**self._input_kwargs)

def _fit(self, item_df: DataFrame) -> "ApproximateNearestNeighborsModel": # type: ignore
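What the new guard does for callers, as a sketch (assuming float32_inputs arrives through the constructor kwargs, which is how _input_kwargs is populated in the Spark ML param pattern):

    # Emits the warning above and drops the kwarg before _set_params,
    # so the estimator always runs in float32 mode.
    ann = ApproximateNearestNeighbors(algorithm="cagra", float32_inputs=False)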
@@ -1508,7 +1514,27 @@ def _transform_internal(
             if isinstance(item, np.ndarray):
                 item = cp.array(item, dtype="float32")

-            index_obj = nn_object.build(build_params, item)
+            try:
+                index_obj = nn_object.build(build_params, item)
+            except Exception as e:
+                if "k must be less than topk::kMaxCapacity (256)" in str(e):
+                    from cuvs.neighbors import cagra
+
+                    assert nn_object == cagra
+                    assert (
+                        "build_algo" not in index_params
+                        or index_params["build_algo"] == "ivf_pq"
+                    )
+
+                    intermediate_graph_degree = (
+                        build_params.intermediate_graph_degree
+                    )
+                    assert intermediate_graph_degree >= 256
+
+                    error_msg = f"cagra with ivf_pq build_algo expects intermediate_graph_degree ({intermediate_graph_degree}) to be smaller than 256"
+                    raise ValueError(error_msg)
+                else:
+                    raise e

             logger.info(
                 f"partition {pid} indexing finished in {time.time() - start_time} seconds."
2 changes: 2 additions & 0 deletions python/src/spark_rapids_ml/utils.py
@@ -273,6 +273,8 @@ def dtype_to_pyspark_type(dtype: Union[np.dtype, str]) -> str:
return "int"
elif dtype == np.int16:
return "short"
elif dtype == np.int64:
return "long"
else:
raise RuntimeError("Unsupported dtype, found ", dtype)

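Usage of the extended mapping (only dtypes visible in this excerpt; the return "int" branch above presumably matches np.int32):

    import numpy as np
    from spark_rapids_ml.utils import dtype_to_pyspark_type

    dtype_to_pyspark_type(np.int64)  # "long" -- newly handled by this commit
    dtype_to_pyspark_type(np.int16)  # "short"
    # any dtype without a branch raises RuntimeError("Unsupported dtype, found ", dtype)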
82 changes: 80 additions & 2 deletions python/tests/test_approximate_nearest_neighbors.py
@@ -365,6 +365,11 @@ def ann_algorithm_test_func(
     n_neighbors: int = 50,
 ) -> None:

+    assert data_type in {
+        np.float32,
+        np.float64,
+    }, "the test function applies to float dataset dtype only, as it scales the dataset by the average norm of rows"
+
     algorithm = combo[0]
     assert algorithm in {"ivfflat", "ivfpq", "cagra"}

@@ -758,6 +763,62 @@ def test_cagra(
     )


+@pytest.mark.parametrize(
+    "feature_type,data_type",
+    [
+        ("vector", np.float64),
+        ("multi_cols", np.float64),
+        ("multi_cols", np.int16),
+        ("array", np.int64),
+    ],
+)
+@pytest.mark.slow
+def test_cagra_dtype(
+    feature_type: str,
+    data_type: np.dtype,
+) -> None:
+
+    algorithm = "cagra"
+    algo_params = {
+        "intermediate_graph_degree": 128,
+        "graph_degree": 64,
+        "build_algo": "ivf_pq",
+    }
+
+    gpu_number = 1
+    n_neighbors = 2
+    metric = "sqeuclidean"
+    X = np.array(
+        [
+            [10.0, 10.0],
+            [20.0, 20.0],
+            [40.0, 40.0],
+            [50.0, 50.0],
+        ],
+        dtype="int32",
+    )
+    X = X.astype(data_type)
+    y = np.array(range(len(X)))
+    with CleanSparkSession() as spark:
+        data_df, features_col, label_col = create_pyspark_dataframe(
+            spark, feature_type, data_type, X, y
+        )
+
+        gpu_knn = ApproximateNearestNeighbors(
+            num_workers=gpu_number,
+            inputCol=features_col,
+            idCol=label_col,
+            k=n_neighbors,
+            metric=metric,
+            algorithm=algorithm,
+            algoParams=algo_params,
+        )
+
+        gpu_model = gpu_knn.fit(data_df)
+        (_, _, knn_df) = gpu_model.kneighbors(data_df)
+        knn_df.show()
+
+
 @pytest.mark.parametrize(
     "algorithm,feature_type,max_records_per_batch,algo_params,metric",
     [
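test_cagra_dtype drives int16, int64, and float64 inputs through the full fit/kneighbors path to confirm they are accepted and coerced. The coercion itself is the one visible in the _transform_internal hunk of knn.py, roughly:

    import numpy as np
    import cupy as cp  # requires a CUDA-capable environment

    item = np.array([[10, 10], [20, 20]], dtype=np.int64)
    item = cp.array(item, dtype="float32")  # every input dtype lands on GPU as float32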
@@ -773,18 +834,18 @@ def test_cagra(
         ),
     ],
 )
-@pytest.mark.parametrize("data_shape", [(10000, 50)], ids=idfn)
 @pytest.mark.parametrize("data_type", [np.float32])
 def test_cagra_params(
     algorithm: str,
     feature_type: str,
     max_records_per_batch: int,
     algo_params: Dict[str, Any],
     metric: str,
-    data_shape: Tuple[int, int],
     data_type: np.dtype,
+    caplog: LogCaptureFixture,
 ) -> None:

+    data_shape = (1000, 20)
     itopk_size = 64 if "itopk_size" not in algo_params else algo_params["itopk_size"]

     internal_topk_size = math.ceil(itopk_size / 32) * 32
@@ -805,6 +866,23 @@ def test_cagra_params(
         n_neighbors=n_neighbors,
     )

+    # test intermediate_graph_degree restriction on ivf_pq
+    algo_params["itopk_size"] = 64
+    algo_params["intermediate_graph_degree"] = 257
+    error_msg = f"cagra with ivf_pq build_algo expects intermediate_graph_degree (257) to be smaller than 256."
+    with pytest.raises(Exception):
+        test_cagra(
+            algorithm,
+            feature_type,
+            max_records_per_batch,
+            algo_params,
+            metric,
+            data_shape,
+            data_type,
+            n_neighbors=n_neighbors,
+        )
+    assert error_msg in caplog.text
+

 @pytest.mark.parametrize(
     "combo",

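Two notes on the added intermediate_graph_degree test: the message is checked against caplog.text (inside pytest.raises) presumably because the ValueError is raised in a Spark task and only its logged text reaches the driver, and the itopk_size arithmetic at the top of test_cagra_params rounds up to a multiple of 32. A small illustration of that rounding, mirroring the test's own formula:

    import math

    def internal_topk(itopk_size: int) -> int:
        # internal_topk_size = math.ceil(itopk_size / 32) * 32, per the test
        return math.ceil(itopk_size / 32) * 32

    internal_topk(64)   # 64
    internal_topk(100)  # 128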