From 2d814c43e2052acd4ec3fc0affb4a87d840fa4e0 Mon Sep 17 00:00:00 2001 From: Jinfeng Li Date: Tue, 5 Nov 2024 22:22:48 -0800 Subject: [PATCH] Revise to cuvs ivf_pq, add cosine for ivf_pq and support long_max indices for all ann algorithms. (#757) * squashed and rebased support derived class and cuvs ivf_pq add testing cosine for ivf_pq replace cuml ivfpq with cuvs ivf_pq fix fewer than k items probed and support long label dtype in create spark dataframe normalize dataset to unit norms for inner_product distances to avoid mg failure increase ivf_pq quantization to make its recall more stable remove normalization as it transforms the dataset, which leads to lower recall add case when fewer than k items are probed * rebased and second squash: improve test case for fewer than k items probed fix bug related to CPUNN revise per comments fix create_pyspark_dataframe to get it to work with cp arrays as input fix bug on label of create_pyspark_dataframe fix bug tested in CPUNearestNeighbors model add refine to the knn.py for ivfpq in progress for checkout add debug info get ivf_pq cosine passed by increasing dataset std to make it separable get ivf_pq working after using refine remove unnecessary test for refine get refine to work for fewer than k items probed replace df.withColumn with df.select to fix slowdown for df that was initialized with wide pd.DataFrame revise comments to make them clearer * ensure spark returns are consistent with cuvs when handling fewer than k items probed listening for future updates to consolidate the behaviors of ivfflat, ivfpq and refine --- python/src/spark_rapids_ml/knn.py | 110 +++++++--- python/src/spark_rapids_ml/utils.py | 2 + .../test_approximate_nearest_neighbors.py | 190 ++++++++++++++++-- python/tests/utils.py | 33 ++- 4 files changed, 278 insertions(+), 57 deletions(-) diff --git a/python/src/spark_rapids_ml/knn.py b/python/src/spark_rapids_ml/knn.py index 96b3c85b..81d3dcaf 100644 --- a/python/src/spark_rapids_ml/knn.py +++ b/python/src/spark_rapids_ml/knn.py @@ -925,6 +925,12 @@ class ApproximateNearestNeighbors( k: int (default = 5) the default number of approximate nearest neighbors to retrieve for each query. + If fewer than k neighbors are found for a query (for example, due to a small nprobe value): + (1) In ivfflat and ivfpq: + (a) If no item vector is probed, the indices are filled with long_max (9,223,372,036,854,775,807) and distances are set to infinity. + (b) If at least one item vector is probed, the indices are filled with the top-1 neighbor's ID, and distances are filled with infinity. + (2) cagra does not have this problem, as at least itopk_size (where itopk_size ≥ k) items are always probed. + algorithm: str (default = 'ivfflat') the algorithm parameter to be passed into cuML. It currently must be 'ivfflat', 'ivfpq' or 'cagra'. Other algorithms are expected to be supported later. 
@@ -1329,6 +1335,30 @@ def _cal_cuvs_ivf_flat_params_and_check( return (ivfflat_index_params, ivfflat_search_params) + @classmethod + def _cal_cuvs_ivf_pq_params_and_check( + cls, algoParams: Optional[Dict[str, Any]], metric: str, topk: int + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + pq_index_params: Dict[str, Any] = {"metric": metric} + pq_search_params: Dict[str, Any] = {} + + if algoParams is not None: + for p in algoParams: + if p in {"n_probes", "nprobe"}: + pq_search_params["n_probes"] = algoParams[p] + elif p in {"lut_dtype", "internal_distance_dtype"}: + pq_search_params[p] = algoParams[p] + elif p in {"n_lists", "nlist"}: + pq_index_params["n_lists"] = algoParams[p] + elif p in {"M", "pq_dim"}: + pq_index_params["pq_dim"] = algoParams[p] + elif p in {"n_bits", "pq_bits"}: + pq_index_params["pq_bits"] = algoParams[p] + else: + pq_index_params[p] = algoParams[p] + + return (pq_index_params, pq_search_params) + def kneighbors( self, query_df: DataFrame, sort_knn_df_by_query_id: bool = True ) -> Tuple[DataFrame, DataFrame, DataFrame]: @@ -1416,12 +1446,17 @@ def _get_cuml_transform_func( "cosine", } - if cuml_alg_params["algorithm"] != "ivfpq": - check_fn = ( - self._cal_cagra_params_and_check - if cuml_alg_params["algorithm"] == "cagra" - else self._cal_cuvs_ivf_flat_params_and_check - ) + if ( + cuml_alg_params["algorithm"] != "brute" + ): # brute links to CPUNearestNeighborsModel of benchmark.bench_nearest_neighbors + if cuml_alg_params["algorithm"] == "cagra": + check_fn = self._cal_cagra_params_and_check + elif cuml_alg_params["algorithm"] in {"ivf_flat", "ivfflat"}: + check_fn = self._cal_cuvs_ivf_flat_params_and_check + else: + assert cuml_alg_params["algorithm"] in {"ivf_pq", "ivfpq"} + check_fn = self._cal_cuvs_ivf_pq_params_and_check + index_params, search_params = check_fn( algoParams=self.cuml_params["algo_params"], metric=self.cuml_params["metric"], @@ -1431,19 +1466,9 @@ def _get_cuml_transform_func( def _construct_sgnn() -> CumlT: if cuml_alg_params["algorithm"] in {"ivf_pq", "ivfpq"}: - from cuml.neighbors import NearestNeighbors as SGNN + from cuvs.neighbors import ivf_pq - # Currently 'usePrecomputedTables' is required by cuml cython API, though the value is ignored in C++. - if ( - cuml_alg_params["algorithm"] == "ivfpq" - and cuml_alg_params["algo_params"] - ): - if "usePrecomputedTables" not in cuml_alg_params["algo_params"]: - cuml_alg_params["algo_params"]["usePrecomputedTables"] = False - - nn_object = SGNN(output_type="cupy", **cuml_alg_params) - - return nn_object + return ivf_pq elif cuml_alg_params["algorithm"] in {"ivfflat" or "ivf_flat"}: from cuvs.neighbors import ivf_flat @@ -1470,7 +1495,7 @@ def _transform_internal( nn_object: CumlT, df: Union[pd.DataFrame, np.ndarray] ) -> pd.DataFrame: - item_row_number = df[row_number_col].to_numpy() + item_row_number = df[row_number_col].to_numpy(dtype=np.int64) item = df.drop(row_number_col, axis=1) # type: ignore if input_col is not None: assert len(item.columns) == 1 @@ -1498,12 +1523,9 @@ def _transform_internal( start_time = time.time() - from cuml.neighbors import NearestNeighbors as cumlSGNN - from cuvs.neighbors import cagra, ivf_flat - if not inspect.ismodule( nn_object - ): # ivfpq and derived class (e.g. benchmark.bench_nearest_neighbors.CPUNearestNeighborsModel) + ): # derived class (e.g. 
benchmark.bench_nearest_neighbors.CPUNearestNeighborsModel) nn_object.fit(item) else: # cuvs ivf_flat or cagra build_params = nn_object.IndexParams(**index_params) @@ -1544,35 +1566,57 @@ def _transform_internal( if not inspect.ismodule( nn_object - ): # ivfpq and derived class (e.g. benchmark.bench_nearest_neighbors.CPUNearestNeighborsModel) + ): # derived class (e.g. benchmark.bench_nearest_neighbors.CPUNearestNeighborsModel) distances, indices = nn_object.kneighbors(bcast_qfeatures.value) - else: # cuvs ivf_flat cagra + else: # cuvs ivf_flat, cagra, or ivf_pq gpu_qfeatures = cp.array( bcast_qfeatures.value, order="C", dtype="float32" ) + assert cuml_alg_params["n_neighbors"] <= len( + item + ), "k is larger than the number of item vectors on a GPU. Please increase the dataset size or use fewer GPUs" + distances, indices = nn_object.search( nn_object.SearchParams(**search_params), index_obj, gpu_qfeatures, cuml_alg_params["n_neighbors"], ) + + if cuml_alg_params["algorithm"] in {"ivf_pq", "ivfpq"}: + from cuvs.neighbors import refine + + distances, indices = refine( + dataset=item, + queries=gpu_qfeatures, + candidates=indices, + k=cuml_alg_params["n_neighbors"], + metric=cuml_alg_params["metric"], + ) + distances = cp.asarray(distances) indices = cp.asarray(indices) - # Note cuML kneighbors applys an extra square root on the l2 distances. - # Here applies square to obtain the actual l2 distances. - if isinstance(nn_object, cumlSGNN): - if ( - cuml_alg_params["metric"] == "euclidean" - or cuml_alg_params["metric"] == "l2" - ): - distances = distances * distances + # in case the refine API resets inf distances to 0. + if cuml_alg_params["algorithm"] in {"ivf_pq", "ivfpq"}: + distances[indices >= len(item)] = float("inf") + + # for the case where the top-1 nn id was filled into the remaining indices + top1_ind = indices[:, 0] + rest_indices = indices[:, 1:] + rest_distances = distances[:, 1:] + rest_distances[rest_indices == top1_ind[:, cp.newaxis]] = float( + "inf" + ) if isinstance(distances, cp.ndarray): distances = distances.get() + # in case a query did not probe any items, indices are filled with int64 max and distances are filled with inf + item_row_number = np.append(item_row_number, np.iinfo("int64").max) if isinstance(indices, cp.ndarray): + indices[indices >= len(item)] = len(item) indices = indices.get() indices_global = item_row_number[indices] diff --git a/python/src/spark_rapids_ml/utils.py b/python/src/spark_rapids_ml/utils.py index acb1a8da..0edcf7fa 100644 --- a/python/src/spark_rapids_ml/utils.py +++ b/python/src/spark_rapids_ml/utils.py @@ -271,6 +271,8 @@ def dtype_to_pyspark_type(dtype: Union[np.dtype, str]) -> str: return "double" elif dtype == np.int32: return "int" + elif dtype == np.int64: + return "long" elif dtype == np.int16: return "short" elif dtype == np.int64: diff --git a/python/tests/test_approximate_nearest_neighbors.py b/python/tests/test_approximate_nearest_neighbors.py index 2d7de8ce..387cf4f8 100644 --- a/python/tests/test_approximate_nearest_neighbors.py +++ b/python/tests/test_approximate_nearest_neighbors.py @@ -1,5 +1,5 @@ import math -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -266,27 +266,22 @@ def compare_with_cuml_or_cuvs_sg( tolerance: float, ) -> None: # compare with cuml sg ANN on avg_recall and avg_dist_gap - if algorithm in {"ivfpq"}: - cumlsg_distances, cumlsg_indices = self.get_cuml_sg_results( - algorithm, algoParams - ) - else: - cumlsg_distances, 
cumlsg_indices = self.get_cuvs_sg_results( - algorithm=algorithm, algoParams=algoParams - ) + cuvssg_distances, cuvssg_indices = self.get_cuvs_sg_results( + algorithm=algorithm, algoParams=algoParams + ) # compare cuml sg with given results - avg_recall_cumlann = self.cal_avg_recall(cumlsg_indices) + avg_recall_cumlann = self.cal_avg_recall(cuvssg_indices) avg_recall = self.cal_avg_recall(given_indices) assert (avg_recall > avg_recall_cumlann) or abs( avg_recall - avg_recall_cumlann - ) < tolerance + ) <= tolerance - avg_dist_gap_cumlann = self.cal_avg_dist_gap(cumlsg_distances) + avg_dist_gap_cumlann = self.cal_avg_dist_gap(cuvssg_distances) avg_dist_gap = self.cal_avg_dist_gap(given_distances) - assert (avg_dist_gap < avg_dist_gap_cumlann) or abs( + assert (avg_dist_gap <= avg_dist_gap_cumlann) or abs( avg_dist_gap - avg_dist_gap_cumlann - ) < tolerance + ) <= tolerance def get_cuml_sg_results( self, @@ -335,6 +330,13 @@ def get_cuvs_sg_results( ) ) from cuvs.neighbors import ivf_flat as cuvs_algo + elif algorithm in {"ivf_pq", "ivfpq"}: + index_params, search_params = ( + ApproximateNearestNeighborsModel._cal_cuvs_ivf_pq_params_and_check( + algoParams=algoParams, metric=self.metric, topk=self.n_neighbors + ) + ) + from cuvs.neighbors import ivf_pq as cuvs_algo else: assert False, f"unrecognized algorithm {algorithm}" @@ -347,6 +349,13 @@ def get_cuvs_sg_results( cuvs_algo.SearchParams(**search_params), index, gpu_X, self.n_neighbors ) + if algorithm in {"ivf_pq", "ivfpq"}: + from cuvs.neighbors import refine + + sg_distances, sg_indices = refine( + gpu_X, gpu_X, sg_indices, self.n_neighbors, metric=self.metric + ) + # convert results to cp array then to np array sg_distances = cp.array(sg_distances).get() sg_indices = cp.array(sg_indices).get() @@ -363,6 +372,7 @@ def ann_algorithm_test_func( distances_are_exact: bool = True, tolerance: float = 1e-4, n_neighbors: int = 50, + cluster_std: float = 1.0, ) -> None: assert data_type in { @@ -399,6 +409,7 @@ def ann_algorithm_test_func( n_features=data_shape[1], centers=n_clusters, random_state=0, + cluster_std=cluster_std, ) # make_blobs creates a random dataset of isotropic gaussian blobs. # set average norm sq to be 1 to allow comparisons with default error thresholds @@ -452,7 +463,7 @@ def ann_algorithm_test_func( # test kneighbors: compare top-1 nn indices(self) and distances(self) - if metric != "inner_product" and distances_are_exact: + if metric != "inner_product": self_index = [knn[0] for knn in indices] assert np.all(self_index == y) @@ -608,7 +619,6 @@ def test_ivfflat( "nprobe": 20, "M": 20, "n_bits": 4, - "usePrecomputedTables": False, }, "euclidean", ), @@ -621,7 +631,6 @@ def test_ivfflat( "nprobe": 20, "M": 40, "n_bits": 4, - "usePrecomputedTables": True, }, "sqeuclidean", ), @@ -634,7 +643,6 @@ def test_ivfflat( "nprobe": 20, "M": 10, "n_bits": 8, - "usePrecomputedTables": False, }, "l2", ), @@ -650,6 +658,18 @@ def test_ivfflat( }, "inner_product", ), + ( + "ivfpq", + "array", + 3000, + { + "nlist": 100, + "nprobe": 20, + "M": 20, + "n_bits": 4, + }, + "cosine", + ), ], ) @pytest.mark.parametrize("data_shape", [(10000, 50)], ids=idfn) @@ -665,20 +685,31 @@ def test_ivfpq( ) -> None: """ (1) Currently the usePrecomputedTables is not used in cuml C++. - (2) ivfpq has become unstable in 24.10. It does not get passed with algoParam {"nlist" : 10, "nprobe" : 2, "M": 2, "n_bits": 4} in ci where test_ivfflat is run beforehand. avg_recall shows large variance, depending on the quantization accuracy. 
This can be fixed by increasing nlist, nprobe, M, and n_bits. + + (2) ivfpq has become unstable in 24.10. It does not pass with algoParams {"nlist": 10, "nprobe": 2, "M": 2, "n_bits": 4} in CI where test_ivfflat is run beforehand. avg_recall shows large variance, depending on the quantization accuracy. This can be fixed by increasing nlist, nprobe, M, and n_bits. Note ivf_pq is non-deterministic; this seems to be due to kmeans initialization depending on runtime GPU memory values. + + (3) In ivfpq, when the dataset itself is used as the queries, the top-1 index is sometimes observed not to be the query itself, and the top-1 distance may not be zero. + This is because ivfpq internally uses an approximated distance, i.e. the distance of the query vector to the center of the quantized item. """ combo = (algorithm, feature_type, max_records_per_batch, algo_params, metric) expected_avg_recall = 0.4 - distances_are_exact = False + distances_are_exact = True + expected_avg_dist_gap = 0.05 tolerance = 0.05 # tolerance increased to be more stable due to quantization and randomness in ivfpq, especially when expected_recall is low. + cluster_std = ( + 1.0 if metric != "cosine" else 10.0 + ) # Increase cluster_std for cosine to make the dataset more randomized and separable. + ann_algorithm_test_func( combo=combo, data_shape=data_shape, data_type=data_type, expected_avg_recall=expected_avg_recall, + expected_avg_dist_gap=expected_avg_dist_gap, distances_are_exact=distances_are_exact, tolerance=tolerance, + cluster_std=cluster_std, ) @@ -900,9 +931,126 @@ def test_ivfflat_wide_matrix( data_shape: Tuple[int, int], data_type: np.dtype, ) -> None: + """ + It seems adding a column with df.withColumn can be very slow if df already has many columns (e.g. 3000). + One strategy is to avoid df.withColumn on a wide df and use df.select instead. + """ import time start = time.time() ann_algorithm_test_func(combo=combo, data_shape=data_shape, data_type=data_type) duration_sec = time.time() - start - assert duration_sec < 10 * 60 + assert duration_sec < 3 * 60 + + +@pytest.mark.parametrize( + "algorithm,feature_type", + [ + ( + "ivfpq", + "array", + ), + ( + "ivfflat", + "vector", + ), + ], +) +@pytest.mark.parametrize("data_type", [np.float32]) +def test_return_fewer_k( + algorithm: str, + feature_type: str, + data_type: np.dtype, +) -> None: + """ + This tests the corner case where fewer than k neighbors are found because nprobe is too small. + More details can be found in the docstring of class ApproximateNearestNeighbors. 
+ """ + assert algorithm in {"ivfpq", "ivfflat"} + metric = "euclidean" + gpu_number = 1 + k = 4 + algo_params = { + "nlist": k, + "nprobe": 1, + } + + if algorithm == "ivfpq": + algo_params.update({"M": 2, "n_bits": 4}) + + X = np.array( + [ + ( + 0.0, + 0.0, + ), + ( + 0.0, + 0.0, + ), + ( + 2.0, + 2.0, + ), + ( + 2.0, + 2.0, + ), + ] + ) + y = np.arange(len(X)) # use label column as id column + + with CleanSparkSession() as spark: + df, features_col, label_col = create_pyspark_dataframe( + spark, feature_type, data_type, X, y, label_dtype=np.dtype(np.int64) + ) + + est = ApproximateNearestNeighbors( + num_workers=gpu_number, + algorithm=algorithm, + algoParams=algo_params, + metric=metric, + k=k, + inputCol="features", + idCol=label_col, + ) + model = est.fit(df) + _, _, knn_df = model.kneighbors(df) + knn_df_collect = knn_df.collect() + + int64_max = np.iinfo("int64").max + float_inf = float("inf") + + # ensure consistency with cuvs for ivfflat, and ivfpq > 24.10 + import cuvs + from packaging import version + + if algorithm == "ivfflat" or version.parse(cuvs.__version__) > version.parse( + "24.10.00" + ): + ann_evaluator = ANNEvaluator(X, k, metric) + spark_indices = np.array([row["indices"] for row in knn_df_collect]) + spark_distances = np.array([row["distances"] for row in knn_df_collect]) + ann_evaluator.compare_with_cuml_or_cuvs_sg( + algorithm, algo_params, spark_indices, spark_distances, tolerance=0.0 + ) + + # check result details + indices_none_probed = [int64_max, int64_max, int64_max, int64_max] + distances_none_probed = [float_inf, float_inf, float_inf, float_inf] + + def check_row_results( + i: int, indices_if_probed: List[int], distances_if_probed: List[float] + ) -> None: + assert i == 0 or i == 2 + j = i + 1 + assert knn_df_collect[i]["indices"] == knn_df_collect[j]["indices"] + assert knn_df_collect[i]["distances"] == knn_df_collect[j]["distances"] + if knn_df_collect[i]["indices"] == indices_none_probed: + assert knn_df_collect[i]["distances"] == distances_none_probed + else: + assert knn_df_collect[i]["indices"] == indices_if_probed + assert knn_df_collect[i]["distances"] == distances_if_probed + + check_row_results(0, [0, 1, 0, 0], [0.0, 0.0, float_inf, float_inf]) + check_row_results(2, [2, 3, 2, 2], [0.0, 0.0, float_inf, float_inf]) diff --git a/python/tests/utils.py b/python/tests/utils.py index 379e0627..c2c61281 100644 --- a/python/tests/utils.py +++ b/python/tests/utils.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar, Union import numpy as np +import pandas as pd import pyspark from pyspark.ml.feature import VectorAssembler from pyspark.sql import SparkSession @@ -80,10 +81,17 @@ def create_pyspark_dataframe( dtype: np.dtype, data: np.ndarray, label: Optional[np.ndarray] = None, + label_dtype: Optional[np.dtype] = None, # type: ignore ) -> Tuple[pyspark.sql.DataFrame, Union[str, List[str]], Optional[str]]: """Construct a dataframe based on features and label data.""" assert feature_type in pyspark_supported_feature_types + # in case cp.ndarray get passed in + if not isinstance(data, np.ndarray): + data = data.get() + if label is not None and not isinstance(label, np.ndarray): + label = label.get() + m, n = data.shape pyspark_type = dtype_to_pyspark_type(dtype) @@ -92,17 +100,31 @@ def create_pyspark_dataframe( label_col = None if label is not None: + label_dtype = dtype if label_dtype is None else label_dtype + label = label.astype(label_dtype) + label_pyspark_type = dtype_to_pyspark_type(label_dtype) + label_col = 
"label_col" - schema.append(f"{label_col} {pyspark_type}") + schema.append(f"{label_col} {label_pyspark_type}") + + pdf = pd.DataFrame(data, dtype=dtype, columns=feature_cols) + pdf[label_col] = label.astype(label_dtype) df = spark.createDataFrame( - np.concatenate((data, label.reshape(m, 1)), axis=1).tolist(), + pdf, ",".join(schema), ) else: df = spark.createDataFrame(data.tolist(), ",".join(schema)) if feature_type == feature_types.array: - df = df.withColumn("features", array(*feature_cols)).drop(*feature_cols) + # avoid calling df.withColumn here because runtime slowdown is observed when df has many columns (e.g. 3000). + from pyspark.sql.functions import col + + selected_col = [array(*feature_cols).alias("features")] + if label_col: + selected_col.append(col(label_col).alias(label_col)) + df = df.select(selected_col) + feature_cols = "features" elif feature_type == feature_types.vector: df = ( @@ -113,6 +135,11 @@ def create_pyspark_dataframe( .drop(*feature_cols) ) feature_cols = "features" + else: + # When df has many columns (e.g. 3000), and was created by calling spark.createDataFrame on a pandas DataFrame, + # calling df.withColumn can lead to noticeable runtime slowdown. + # Using select here can significantly reduce the runtime and improve the performance. + df = df.select("*") return df, feature_cols, label_col