diff --git a/python/benchmark/benchmark/bench_kmeans.py b/python/benchmark/benchmark/bench_kmeans.py index e1a0776f..5c2753fc 100644 --- a/python/benchmark/benchmark/bench_kmeans.py +++ b/python/benchmark/benchmark/bench_kmeans.py @@ -192,17 +192,6 @@ def gpu_cache_df(df: DataFrame) -> DataFrame: cluster_centers = gpu_model.cluster_centers_ - # temporary patch for DB with spark-rapids plugin - # this part is not timed so overhead is not critical, but should be reverted - # once https://github.com/NVIDIA/spark-rapids/issues/10770 is fixed - db_version = os.environ.get("DATABRICKS_RUNTIME_VERSION") - if db_version: - dim = len(cluster_centers[0]) - # inject unsupported expr (slice) that is essentially a noop - df_for_scoring = df_for_scoring.select( - F.slice(feature_col, 1, dim).alias(feature_col), output_col - ) - if num_cpus > 0: from pyspark.ml.clustering import KMeans as SparkKMeans diff --git a/python/benchmark/benchmark/bench_umap.py b/python/benchmark/benchmark/bench_umap.py index 94c33d36..246cb27a 100644 --- a/python/benchmark/benchmark/bench_umap.py +++ b/python/benchmark/benchmark/bench_umap.py @@ -19,10 +19,10 @@ import numpy as np from pandas import DataFrame as PandasDataFrame -from pyspark.ml.feature import VectorAssembler +from pyspark.ml.feature import StandardScaler, VectorAssembler from pyspark.ml.functions import array_to_vector, vector_to_array from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.functions import col, sum +from pyspark.sql.functions import array, col, sum from benchmark.base import BenchmarkBase from benchmark.utils import inspect_default_params_from_func, with_benchmark @@ -105,7 +105,7 @@ def score( pdf: PandasDataFrame = transformed_df.toPandas() embedding = np.array(pdf[transformed_col].to_list()) - input = np.array(pdf[data_col].to_list()) + input = np.array(pdf[data_col].to_list()).astype(np.float32) score = trustworthiness(input, embedding, n_neighbors=15) return score @@ -162,39 +162,45 @@ def gpu_cache_df(df: DataFrame) -> DataFrame: else: gpu_estimator = gpu_estimator.setFeaturesCols(input_cols) - output_col = "embedding" - gpu_estimator = gpu_estimator.setOutputCol(output_col) - gpu_model, fit_time = with_benchmark( "gpu fit", lambda: gpu_estimator.fit(train_df) ) - def transform(model: UMAPModel, df: DataFrame) -> DataFrame: - transformed_df = model.transform(df) - transformed_df.count() - return transformed_df - - transformed_df, transform_time = with_benchmark( - "gpu transform", lambda: transform(gpu_model, train_df) + output_col = "embedding" + transformed_df = gpu_model.setOutputCol(output_col).transform(train_df) + _, transform_time = with_benchmark( + "gpu transform", lambda: transformed_df.foreach(lambda _: None) ) + total_time = round(time.time() - func_start_time, 2) print(f"gpu total took: {total_time} sec") - data_col = "features" + + df_for_scoring = transformed_df + feature_col = first_col + if not is_single_col: + feature_col = "features_array" + df_for_scoring = transformed_df.select( + array(*input_cols).alias("features_array"), output_col + ) + elif is_vector_col: + df_for_scoring = transformed_df.select( + vector_to_array(col(feature_col)).alias(feature_col), output_col + ) if num_cpus > 0: from pyspark.ml.feature import PCA as SparkPCA assert num_gpus <= 0 + if is_array_col: vector_df = train_df.select( array_to_vector(train_df[first_col]).alias(first_col) ) elif not is_vector_col: - vector_assembler = VectorAssembler(outputCol="features").setInputCols( + vector_assembler = 
VectorAssembler(outputCol=first_col).setInputCols( input_cols ) vector_df = vector_assembler.transform(train_df).drop(*input_cols) - first_col = "features" else: vector_df = train_df @@ -209,11 +215,10 @@ def cpu_cache_df(df: DataFrame) -> DataFrame: "prepare dataset", lambda: cpu_cache_df(vector_df) ) - output_col = "pca_features" - params = self.class_params print(f"Passing {params} to SparkPCA") + output_col = "pca_features" cpu_pca = SparkPCA(**params).setInputCol(first_col).setOutputCol(output_col) cpu_model, fit_time = with_benchmark( @@ -233,9 +238,27 @@ def cpu_transform(df: DataFrame) -> None: total_time = round(time.time() - func_start_time, 2) print(f"cpu total took: {total_time} sec") - data_col = first_col - score = self.score(transformed_df, data_col, output_col) + # spark ml does not remove the mean in the transformed features, so do that here + # needed for scoring + standard_scaler = ( + StandardScaler() + .setWithStd(False) + .setWithMean(True) + .setInputCol(output_col) + .setOutputCol(output_col + "_mean_removed") + ) + + scaler_model = standard_scaler.fit(transformed_df) + transformed_df = scaler_model.transform(transformed_df).drop(output_col) + + feature_col = first_col + output_col = output_col + "_mean_removed" + df_for_scoring = transformed_df.select( + vector_to_array(col(output_col)).alias(output_col), feature_col + ) + + score = self.score(df_for_scoring, feature_col, output_col) print(f"trustworthiness score: {score}") report_dict = { diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py index 76d632fc..644c88a7 100644 --- a/python/src/spark_rapids_ml/core.py +++ b/python/src/spark_rapids_ml/core.py @@ -751,7 +751,9 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame: concated_nnz = sum(triplet[0].nnz for triplet in inputs) # type: ignore if concated_nnz > np.iinfo(np.int32).max: logger.warn( - "the number of non-zero values of a partition is larger than the int32 index dtype of cupyx csr_matrix" + f"The number of non-zero values of a partition exceeds the int32 index dtype. \ + cupyx csr_matrix currently does not promote the dtype to int64 when concatenated; \ + keeping as scipy csr_matrix to avoid overflow." 
) else: inputs = [ diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py index cdeb3a2f..2fc68498 100644 --- a/python/src/spark_rapids_ml/umap.py +++ b/python/src/spark_rapids_ml/umap.py @@ -34,6 +34,7 @@ import numpy as np import pandas as pd import pyspark +import scipy from pandas import DataFrame as PandasDataFrame from pyspark.ml.param.shared import ( HasFeaturesCol, @@ -50,6 +51,7 @@ ArrayType, DoubleType, FloatType, + IntegerType, Row, StructField, StructType, @@ -64,19 +66,30 @@ _CumlEstimatorSupervised, _CumlModel, _CumlModelReader, + _CumlModelWithColumns, _CumlModelWriter, _EvaluateFunc, + _read_csr_matrix_from_unwrapped_spark_vec, _TransformFunc, + _use_sparse_in_cuml, alias, param_alias, ) from .metrics import EvalMetricInfo -from .params import DictTypeConverters, HasFeaturesCols, P, _CumlClass, _CumlParams +from .params import ( + DictTypeConverters, + HasEnableSparseDataOptim, + HasFeaturesCols, + P, + _CumlClass, + _CumlParams, +) from .utils import ( _ArrayOrder, _concat_and_free, _get_spark_session, _is_local, + dtype_to_pyspark_type, get_logger, ) @@ -120,7 +133,12 @@ def _pyspark_class(self) -> Optional[ABCMeta]: class _UMAPCumlParams( - _CumlParams, HasFeaturesCol, HasFeaturesCols, HasLabelCol, HasOutputCol + _CumlParams, + HasFeaturesCol, + HasFeaturesCols, + HasLabelCol, + HasOutputCol, + HasEnableSparseDataOptim, ): def __init__(self) -> None: super().__init__() @@ -894,6 +912,9 @@ def __init__( labelCol: Optional[str] = None, outputCol: Optional[str] = None, num_workers: Optional[int] = None, + enable_sparse_data_optim: Optional[ + bool + ] = None, # will enable SparseVector inputs if first row is sparse (for any metric). **kwargs: Any, ) -> None: super().__init__() @@ -908,7 +929,6 @@ def __init__( ) assert max_records_per_batch_str is not None self.max_records_per_batch = int(max_records_per_batch_str) - self.BROADCAST_LIMIT = 8 << 30 def _create_pyspark_model(self, result: Row) -> _CumlModel: raise NotImplementedError("UMAP does not support model creation from Row") @@ -937,54 +957,36 @@ def _fit(self, dataset: DataFrame) -> "UMAPModel": pdf_output: PandasDataFrame = df_output.toPandas() - # Collect and concatenate row-by-row fit results - embeddings = np.array( - list( - pd.concat( - [pd.Series(x) for x in pdf_output["embedding_"]], ignore_index=True - ) - ), - dtype=np.float32, - ) - raw_data = np.array( - list( - pd.concat( - [pd.Series(x) for x in pdf_output["raw_data_"]], ignore_index=True - ) - ), - dtype=np.float32, - ) - del pdf_output - - def _chunk_arr( - arr: np.ndarray, BROADCAST_LIMIT: int = self.BROADCAST_LIMIT - ) -> List[np.ndarray]: - """Chunk an array, if oversized, into smaller arrays that can be broadcasted.""" - if arr.nbytes <= BROADCAST_LIMIT: - return [arr] - - rows_per_chunk = BROADCAST_LIMIT // (arr.nbytes // arr.shape[0]) - num_chunks = (arr.shape[0] + rows_per_chunk - 1) // rows_per_chunk - chunks = [ - arr[i * rows_per_chunk : (i + 1) * rows_per_chunk] - for i in range(num_chunks) - ] - - return chunks + if self._sparse_fit: + embeddings = np.array( + list( + pd.concat( + [pd.Series(x) for x in pdf_output["embedding_"]], + ignore_index=True, + ) + ), + dtype=np.float32, + ) + pdf_output["raw_data_"] = pdf_output.apply( + lambda row: scipy.sparse.csr_matrix( + (row["data"], row["indices"], row["indptr"]), + shape=row["shape"], + ).astype(np.float32), + axis=1, + ) + raw_data = scipy.sparse.vstack(pdf_output["raw_data_"], format="csr") + else: + embeddings = 
np.vstack(pdf_output["embedding_"]).astype(np.float32) + raw_data = np.vstack(pdf_output["raw_data_"]).astype(np.float32) # type: ignore - spark = _get_spark_session() - broadcast_embeddings = [ - spark.sparkContext.broadcast(chunk) for chunk in _chunk_arr(embeddings) - ] - broadcast_raw_data = [ - spark.sparkContext.broadcast(chunk) for chunk in _chunk_arr(raw_data) - ] + del pdf_output model = UMAPModel( - embedding_=broadcast_embeddings, - raw_data_=broadcast_raw_data, - n_cols=len(raw_data[0]), - dtype=type(raw_data[0][0]).__name__, + embedding_=embeddings, + raw_data_=raw_data, + sparse_fit=self._sparse_fit, + n_cols=self._n_cols, + dtype="float32", # UMAP only supports float ) model._num_workers = input_num_workers @@ -1065,7 +1067,8 @@ def _call_cuml_fit_func_dataframe( cls = self.__class__ - select_cols, multi_col_names, _, _ = self._pre_process_data(dataset) + select_cols, multi_col_names, dimension, _ = self._pre_process_data(dataset) + self._n_cols = dimension dataset = dataset.select(*select_cols) @@ -1091,6 +1094,11 @@ def _call_cuml_fit_func_dataframe( cuml_verbose = self.cuml_params.get("verbose", False) + use_sparse_array = _use_sparse_in_cuml(dataset) + self._sparse_fit = use_sparse_array # param stored internally by cuml model + if self.cuml_params.get("metric") == "jaccard" and not use_sparse_array: + raise ValueError("Metric 'jaccard' not supported for dense inputs.") + chunk_size = self.max_records_per_batch def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: @@ -1100,6 +1108,7 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: logger.info("Initializing cuml context") import cupy as cp + import cupyx if cuda_managed_mem_enabled: import rmm @@ -1118,17 +1127,20 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: # handle the input # inputs = [(X, Optional(y)), (X, Optional(y))] logger.info("Loading data into python worker memory") - inputs = [] - sizes = [] + inputs: List[Any] = [] + sizes: List[int] = [] + for pdf in pdf_iter: sizes.append(pdf.shape[0]) if multi_col_names: features = np.array(pdf[multi_col_names], order=array_order) + elif use_sparse_array: + # sparse vector input + features = _read_csr_matrix_from_unwrapped_spark_vec(pdf) else: + # dense input features = np.array(list(pdf[alias.data]), order=array_order) - # experiments indicate it is faster to convert to numpy array and then to cupy array than directly - # invoking cupy array on the list - if cuda_managed_mem_enabled: + if cuda_managed_mem_enabled and not use_sparse_array: features = cp.array(features) label = pdf[alias.label] if alias.label in pdf.columns else None @@ -1137,10 +1149,25 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: ) inputs.append((features, label, row_number)) + if cuda_managed_mem_enabled and use_sparse_array: + concated_nnz = sum(triplet[0].nnz for triplet in inputs) # type: ignore + if concated_nnz > np.iinfo(np.int32).max: + logger.warn( + f"The number of non-zero values of a partition exceeds the int32 index dtype. \ + cupyx csr_matrix currently does not promote the dtype to int64 when concatenated; \ + keeping as scipy csr_matrix to avoid overflow." + ) + else: + inputs = [ + (cupyx.scipy.sparse.csr_matrix(row[0]), row[1], row[2]) + for row in inputs + ] + # call the cuml fit function # *note*: cuml_fit_func may delete components of inputs to free # memory. do not rely on inputs after this call. 
embedding, raw_data = cuml_fit_func(inputs, params).values() + logger.info("Cuml fit complete") num_sections = (len(embedding) + chunk_size - 1) // chunk_size @@ -1148,15 +1175,29 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]: for i in range(num_sections): start = i * chunk_size end = min((i + 1) * chunk_size, len(embedding)) - - yield pd.DataFrame( - data=[ + if use_sparse_array: + csr_chunk = raw_data[start:end] + indices = csr_chunk.indices + indptr = csr_chunk.indptr + data = csr_chunk.data + yield pd.DataFrame( + data=[ + { + "embedding_": embedding[start:end].tolist(), + "indices": indices.tolist(), + "indptr": indptr.tolist(), + "data": data.tolist(), + "shape": [end - start, dimension], + } + ] + ) + else: + yield pd.DataFrame( { "embedding_": embedding[start:end].tolist(), "raw_data_": raw_data[start:end].tolist(), } - ] - ) + ) output_df = dataset.mapInPandas(_train_udf, schema=self._out_schema()) @@ -1166,20 +1207,27 @@ def _require_nccl_ucx(self) -> Tuple[bool, bool]: return (False, False) def _out_schema(self) -> Union[StructType, str]: - return StructType( - [ - StructField( - "embedding_", - ArrayType(ArrayType(FloatType(), False), False), - False, - ), - StructField( - "raw_data_", - ArrayType(ArrayType(FloatType(), False), False), - False, - ), - ] - ) + if self._sparse_fit: + return StructType( + [ + StructField( + "embedding_", + ArrayType(ArrayType(FloatType(), False), False), + False, + ), + StructField("indices", ArrayType(IntegerType(), False), False), + StructField("indptr", ArrayType(IntegerType(), False), False), + StructField("data", ArrayType(FloatType(), False), False), + StructField("shape", ArrayType(IntegerType(), False), False), + ] + ) + else: + return StructType( + [ + StructField("embedding_", ArrayType(FloatType()), False), + StructField("raw_data_", ArrayType(FloatType()), False), + ] + ) def _pre_process_data( self, dataset: DataFrame @@ -1201,36 +1249,47 @@ def _pre_process_data( return select_cols, multi_col_names, dimension, feature_type -class UMAPModel(_CumlModel, UMAPClass, _UMAPCumlParams): +class UMAPModel(_CumlModelWithColumns, UMAPClass, _UMAPCumlParams): def __init__( self, - embedding_: List[pyspark.broadcast.Broadcast], - raw_data_: List[pyspark.broadcast.Broadcast], + embedding_: np.ndarray, + raw_data_: Union[ + np.ndarray, + scipy.sparse.csr_matrix, + ], + sparse_fit: bool, n_cols: int, dtype: str, ) -> None: super(UMAPModel, self).__init__( embedding_=embedding_, raw_data_=raw_data_, + sparse_fit=sparse_fit, n_cols=n_cols, dtype=dtype, ) self.embedding_ = embedding_ self.raw_data_ = raw_data_ + self._sparse_fit = sparse_fit # If true, raw data is a sparse CSR matrix + self.BROADCAST_LIMIT = 8 << 30 # Spark broadcast limit: 8GiB @property - def embedding(self) -> List[List[float]]: - res = [] - for chunk in self.embedding_: - res.extend(chunk.value.tolist()) - return res + def embedding(self) -> np.ndarray: + """ + Returns the model embeddings. + """ + return ( + self.embedding_ + ) # TBD: return a more Spark-like object, e.g. DenseMatrix? @property - def raw_data(self) -> List[List[float]]: - res = [] - for chunk in self.raw_data_: - res.extend(chunk.value.tolist()) - return res + def rawData(self) -> Union[np.ndarray, scipy.sparse.csr_matrix]: + """ + Returns the raw data used to fit the model. If the input data was sparse, this will be a scipy csr matrix. + """ + return ( + self.raw_data_ + ) # TBD: return a more Spark-like object, e.g. DenseMatrix or SparseMatrix? 
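Note on the sparse fit path above: the train UDF emits per-chunk CSR components as list-typed columns (the `indices`/`indptr`/`data`/`shape` branch of `_out_schema`), and `_fit` rebuilds them into one scipy CSR matrix on the driver. Below is a minimal, self-contained sketch of that round trip using toy data (not the actual Arrow batches or benchmark datasets); column names follow the schema in this diff.

import numpy as np
import pandas as pd
import scipy.sparse

# Toy stand-in for two worker-side chunks of a 4-column sparse dataset.
chunks = [
    scipy.sparse.csr_matrix(np.array([[1.0, 0.0, 2.0, 0.0], [0.0, 3.0, 0.0, 0.0]])),
    scipy.sparse.csr_matrix(np.array([[0.0, 0.0, 0.0, 4.0]])),
]

# Worker side: one output row per chunk, with list columns matching the sparse
# branch of _out_schema (indices, indptr, data, shape).
pdf_output = pd.DataFrame(
    [
        {
            "indices": c.indices.tolist(),
            "indptr": c.indptr.tolist(),
            "data": c.data.tolist(),
            "shape": list(c.shape),
        }
        for c in chunks
    ]
)

# Driver side: rebuild each chunk as a CSR matrix, then stack into a single
# matrix, mirroring what _fit does before constructing UMAPModel.
csr_chunks = pdf_output.apply(
    lambda row: scipy.sparse.csr_matrix(
        (row["data"], row["indices"], row["indptr"]), shape=tuple(row["shape"])
    ).astype(np.float32),
    axis=1,
)
raw_data = scipy.sparse.vstack(list(csr_chunks), format="csr")

assert raw_data.shape == (3, 4)
print(raw_data.toarray())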
def _get_cuml_transform_func( self, dataset: DataFrame, eval_metric_info: Optional[EvalMetricInfo] = None @@ -1240,9 +1299,53 @@ def _get_cuml_transform_func( Optional[_EvaluateFunc], ]: cuml_alg_params = self.cuml_params - driver_embedding = self.embedding_ - driver_raw_data = self.raw_data_ - outputCol = self.getOutputCol() + sparse_fit = self._sparse_fit + n_cols = self.n_cols + + def _chunk_and_broadcast( + sc: pyspark.SparkContext, + arr: np.ndarray, + BROADCAST_LIMIT: int, + ) -> List[pyspark.broadcast.Broadcast]: + """ + Broadcast the input array, chunking it into smaller arrays if it exceeds the broadcast limit. + """ + if arr.nbytes < BROADCAST_LIMIT: + return [sc.broadcast(arr)] + + rows_per_chunk = BROADCAST_LIMIT // (arr.nbytes // arr.shape[0]) + if rows_per_chunk == 0: + raise ValueError( + f"Array cannot be chunked into broadcastable pieces: \ + single row exceeds broadcast limit ({BROADCAST_LIMIT} bytes)" + ) + num_chunks = (arr.shape[0] + rows_per_chunk - 1) // rows_per_chunk + return [ + sc.broadcast(arr[i * rows_per_chunk : (i + 1) * rows_per_chunk]) + for i in range(num_chunks) + ] + + spark = _get_spark_session() + broadcast_embeddings = _chunk_and_broadcast( + spark.sparkContext, self.embedding_, self.BROADCAST_LIMIT + ) + + if isinstance(self.raw_data_, scipy.sparse.csr_matrix): + broadcast_raw_data = { + "indices": _chunk_and_broadcast( + spark.sparkContext, self.raw_data_.indices, self.BROADCAST_LIMIT + ), + "indptr": _chunk_and_broadcast( + spark.sparkContext, self.raw_data_.indptr, self.BROADCAST_LIMIT + ), + "data": _chunk_and_broadcast( + spark.sparkContext, self.raw_data_.data, self.BROADCAST_LIMIT + ), + } # NOTE: CSR chunks are not independently meaningful; do not use until recombined. + else: + broadcast_raw_data = _chunk_and_broadcast( + spark.sparkContext, self.raw_data_, self.BROADCAST_LIMIT + ) # type: ignore def _construct_umap() -> CumlT: import cupy as cp @@ -1252,28 +1355,52 @@ def _construct_umap() -> CumlT: from .utils import cudf_to_cuml_array - nonlocal driver_embedding, driver_raw_data + nonlocal broadcast_embeddings, broadcast_raw_data + assert isinstance(broadcast_embeddings, list) embedding = ( - driver_embedding[0].value - if len(driver_embedding) == 1 - else np.concatenate([chunk.value for chunk in driver_embedding]) - ) - raw_data = ( - driver_raw_data[0].value - if len(driver_raw_data) == 1 - else np.concatenate([chunk.value for chunk in driver_raw_data]) + broadcast_embeddings[0].value + if len(broadcast_embeddings) == 1 + else np.concatenate([chunk.value for chunk in broadcast_embeddings]) ) - del driver_embedding - del driver_raw_data + if sparse_fit: + if not isinstance(broadcast_raw_data, dict): + raise ValueError("Expected raw data as a CSR dict for sparse fit.") + indices = np.concatenate( + [chunk.value for chunk in broadcast_raw_data["indices"]] + ) + indptr = np.concatenate( + [chunk.value for chunk in broadcast_raw_data["indptr"]] + ) + data = np.concatenate( + [chunk.value for chunk in broadcast_raw_data["data"]] + ) + raw_data = scipy.sparse.csr_matrix( + (data, indices, indptr), shape=(len(indptr) - 1, n_cols) + ) + else: + if not isinstance(broadcast_raw_data, list): + raise ValueError( + "Expected raw data as list (of lists) for dense fit." 
+ ) + raw_data = ( + broadcast_raw_data[0].value + if len(broadcast_raw_data) == 1 + else np.concatenate([chunk.value for chunk in broadcast_raw_data]) + ) + + del broadcast_embeddings + del broadcast_raw_data if embedding.dtype != np.float32: embedding = embedding.astype(np.float32) raw_data = raw_data.astype(np.float32) if is_sparse(raw_data): - raw_data_cuml = SparseCumlArray(raw_data, convert_format=False) + raw_data_cuml = SparseCumlArray( + raw_data, + ) else: raw_data_cuml = cudf_to_cuml_array( raw_data, @@ -1283,35 +1410,28 @@ def _construct_umap() -> CumlT: internal_model = CumlUMAP(**cuml_alg_params) internal_model.embedding_ = cp.array(embedding).data internal_model._raw_data = raw_data_cuml + internal_model.sparse_fit = sparse_fit return internal_model def _transform_internal( umap: CumlT, - df: Union[pd.DataFrame, np.ndarray], - ) -> pd.Series: - embedding = umap.transform(df) + df: Union[pd.DataFrame, np.ndarray, scipy.sparse._csr.csr_matrix], + ) -> pd.DataFrame: - is_df_np = isinstance(df, np.ndarray) - is_emb_np = isinstance(embedding, np.ndarray) + embedding = umap.transform(df) # Input is either numpy array or pandas dataframe - input_list = [ - df[i, :] if is_df_np else df.iloc[i, :] for i in range(df.shape[0]) # type: ignore - ] emb_list = [ - embedding[i, :] if is_emb_np else embedding.iloc[i, :] + ( + embedding[i, :] + if isinstance(embedding, np.ndarray) + else embedding.iloc[i, :] + ) for i in range(embedding.shape[0]) ] - result = pd.DataFrame( - { - "features": input_list, - outputCol: emb_list, - } - ) - - return result + return pd.Series(emb_list) return _construct_umap, _transform_internal, None @@ -1319,23 +1439,9 @@ def _require_nccl_ucx(self) -> Tuple[bool, bool]: return (False, False) def _out_schema(self, input_schema: StructType) -> Union[StructType, str]: - return StructType( - [ - StructField("features", ArrayType(FloatType(), False), False), - StructField(self.getOutputCol(), ArrayType(FloatType(), False), False), - ] - ) - - def _get_model_attributes(self) -> Optional[Dict[str, Any]]: - """ - Override parent method to bring broadcast variables to driver before JSON serialization. 
- """ - - self._model_attributes["embedding_"] = [ - chunk.value for chunk in self.embedding_ - ] - self._model_attributes["raw_data_"] = [chunk.value for chunk in self.raw_data_] - return self._model_attributes + assert self.dtype is not None + pyspark_type = dtype_to_pyspark_type(self.dtype) + return f"array<{pyspark_type}>" def write(self) -> MLWriter: return _CumlModelWriterNumpy(self) @@ -1367,14 +1473,16 @@ def saveImpl(self, path: str) -> None: if not os.path.exists(data_path): os.makedirs(data_path) assert model_attributes is not None - for key, value in model_attributes.items(): - if isinstance(value, list) and isinstance(value[0], np.ndarray): - paths = [] - for idx, chunk in enumerate(value): - array_path = os.path.join(data_path, f"{key}_{idx}.npy") - np.save(array_path, chunk) - paths.append(array_path) - model_attributes[key] = paths + + for key in ["embedding_", "raw_data_"]: + array = model_attributes[key] + if isinstance(array, scipy.sparse.csr_matrix): + npz_path = os.path.join(data_path, f"{key}csr_.npz") + scipy.sparse.save_npz(npz_path, array) + else: + npz_path = os.path.join(data_path, f"{key}.npz") + np.savez_compressed(npz_path, array) + model_attributes[key] = npz_path metadata_file_path = os.path.join(data_path, "metadata.json") model_attributes_str = json.dumps(model_attributes) @@ -1396,14 +1504,13 @@ def load(self, path: str) -> "_CumlEstimator": model_attr_str = self.sc.textFile(metadata_file_path).collect()[0] model_attr_dict = json.loads(model_attr_str) - for key, value in model_attr_dict.items(): - if isinstance(value, list) and value[0].endswith(".npy"): - arrays = [] - spark = _get_spark_session() - for array_path in value: - array = np.load(array_path) - arrays.append(spark.sparkContext.broadcast(array)) - model_attr_dict[key] = arrays + for key in ["embedding_", "raw_data_"]: + npz_path = model_attr_dict[key] + if npz_path.endswith("csr_.npz"): + model_attr_dict[key] = scipy.sparse.load_npz(npz_path) + else: + with np.load(npz_path) as data: + model_attr_dict[key] = data["arr_0"] instance = self.model_cls(**model_attr_dict) DefaultParamsReader.getAndSetParams(instance, metadata) diff --git a/python/tests/test_umap.py b/python/tests/test_umap.py index f31baea0..a4cd0f17 100644 --- a/python/tests/test_umap.py +++ b/python/tests/test_umap.py @@ -22,7 +22,9 @@ import pytest from _pytest.logging import LogCaptureFixture from cuml.metrics import trustworthiness +from pyspark.ml.linalg import SparseVector from pyspark.sql.functions import array +from scipy.sparse import csr_matrix from sklearn.datasets import load_digits, load_iris from spark_rapids_ml.umap import UMAP, UMAPModel @@ -37,6 +39,57 @@ ) +def _load_sparse_binary_data( + n_rows: int, n_cols: int, nnz: int +) -> Tuple[List[Tuple[SparseVector]], csr_matrix]: + # TODO: Replace this function by adding to SparseDataGen + # Generate binary sparse data compatible with Jaccard, with nnz non-zero values per row. 
+ data = [] + for i in range(n_rows): + indices = [(i + j) % n_cols for j in range(nnz)] + values = [1] * nnz + sparse_vector = SparseVector(n_cols, dict(zip(indices, values))) + data.append((sparse_vector,)) + + csr_data: List[float] = [] + csr_indices: List[int] = [] + csr_indptr: List[int] = [0] + for row in data: + sparse_vector = row[0] + csr_data.extend(sparse_vector.values) + csr_indices.extend(sparse_vector.indices) + csr_indptr.append(csr_indptr[-1] + len(sparse_vector.indices)) + csr_mat = csr_matrix((csr_data, csr_indices, csr_indptr), shape=(n_rows, n_cols)) + + return data, csr_mat + + +def _assert_umap_model( + model: UMAPModel, input_raw_data: Union[np.ndarray, csr_matrix] +) -> None: + embedding = model.embedding + raw_data = model.rawData + assert embedding.shape == ( + input_raw_data.shape[0], + model.cuml_params["n_components"], + ) + assert raw_data.shape == input_raw_data.shape + if isinstance(input_raw_data, csr_matrix): + assert isinstance(raw_data, csr_matrix) + assert model._sparse_fit + assert (raw_data != input_raw_data).nnz == 0 + assert ( + np.all(raw_data.indices == input_raw_data.indices) + and np.all(raw_data.indptr == input_raw_data.indptr) + and np.allclose(raw_data.data, input_raw_data.data) + ) + else: + assert not model._sparse_fit + assert np.array_equal(raw_data, input_raw_data) + assert model.dtype == "float32" + assert model.n_cols == input_raw_data.shape[1] + + def _load_dataset(dataset: str, n_rows: int) -> Tuple[np.ndarray, np.ndarray]: if dataset == "digits": local_X, local_y = load_digits(return_X_y=True) @@ -57,18 +110,29 @@ def _load_dataset(dataset: str, n_rows: int) -> Tuple[np.ndarray, np.ndarray]: def _local_umap_trustworthiness( - local_X: np.ndarray, + local_X: Union[np.ndarray, csr_matrix], local_y: np.ndarray, n_neighbors: int, supervised: bool, + sparse: bool = False, ) -> float: from cuml.manifold import UMAP - local_model = UMAP(n_neighbors=n_neighbors, random_state=42, init="random") + if sparse: + local_model = UMAP( + n_neighbors=n_neighbors, random_state=42, init="random", metric="jaccard" + ) + else: + local_model = UMAP(n_neighbors=n_neighbors, random_state=42, init="random") + y_train = local_y if supervised else None local_model.fit(local_X, y=y_train) embedding = local_model.transform(local_X) + if sparse: + assert isinstance(local_X, csr_matrix) + local_X = local_X.toarray() + return trustworthiness(local_X, embedding, n_neighbors=n_neighbors, batch_size=5000) @@ -91,22 +155,32 @@ def _spark_umap_trustworthiness( with CleanSparkSession() as spark: if supervised: - data_df, features_col, label_col = create_pyspark_dataframe( + data_df, feature_cols, label_col = create_pyspark_dataframe( spark, feature_type, dtype, local_X, local_y ) assert label_col is not None umap_estimator.setLabelCol(label_col) else: - data_df, features_col, _ = create_pyspark_dataframe( + data_df, feature_cols, _ = create_pyspark_dataframe( spark, feature_type, dtype, local_X, None ) data_df = data_df.repartition(n_parts) - umap_estimator.setFeaturesCol(features_col) + if isinstance(feature_cols, list): + umap_estimator.setFeaturesCols(feature_cols) + else: + umap_estimator.setFeaturesCol(feature_cols) + umap_model = umap_estimator.fit(data_df) pdf = umap_model.transform(data_df).toPandas() + embedding = cp.asarray(pdf["embedding"].to_list()).astype(cp.float32) - input = cp.asarray(pdf["features"].to_list()).astype(cp.float32) + if isinstance(feature_cols, list): + input = pdf[feature_cols].to_numpy() + else: + input = pdf[feature_cols].to_list() + + 
input = cp.asarray(input).astype(cp.float32) return trustworthiness(input, embedding, n_neighbors=n_neighbors, batch_size=5000) @@ -295,89 +369,104 @@ def test_params(tmp_path: str, default_params: bool) -> None: _test_input_setter_getter(UMAP) -def test_umap_model_persistence(gpu_number: int, tmp_path: str) -> None: +@pytest.mark.parametrize("sparse_fit", [True, False]) +def test_umap_model_persistence( + sparse_fit: bool, gpu_number: int, tmp_path: str +) -> None: from cuml.datasets import make_blobs - X, _ = make_blobs( - 100, - 20, - centers=42, - cluster_std=0.1, - dtype=np.float32, - random_state=10, - ) - with CleanSparkSession() as spark: - pyspark_type = "float" - feature_cols = [f"c{i}" for i in range(X.shape[1])] - schema = [f"{c} {pyspark_type}" for c in feature_cols] - df = spark.createDataFrame(X.tolist(), ",".join(schema)) - df = df.withColumn("features", array(*feature_cols)).drop(*feature_cols) - umap = UMAP(num_workers=gpu_number).setFeaturesCol("features") + n_rows = 5000 + n_cols = 200 - def assert_umap_model(model: UMAPModel) -> None: - embedding = np.array(model.embedding) - raw_data = np.array(model.raw_data) - assert embedding.shape == (100, 2) - assert raw_data.shape == (100, 20) - assert np.array_equal(raw_data, X.get()) - assert model.dtype == "float32" - assert model.n_cols == X.shape[1] + if sparse_fit: + data, input_raw_data = _load_sparse_binary_data(n_rows, n_cols, 30) + df = spark.createDataFrame(data, ["features"]) + else: + X, _ = make_blobs( + n_rows, + n_cols, + centers=5, + cluster_std=0.1, + dtype=np.float32, + random_state=10, + ) + pyspark_type = "float" + feature_cols = [f"c{i}" for i in range(X.shape[1])] + schema = [f"{c} {pyspark_type}" for c in feature_cols] + df = spark.createDataFrame(X.tolist(), ",".join(schema)) + df = df.withColumn("features", array(*feature_cols)).drop(*feature_cols) + input_raw_data = X.get() + + umap = UMAP(num_workers=gpu_number).setFeaturesCol("features") umap_model = umap.fit(df) - assert_umap_model(model=umap_model) + _assert_umap_model(umap_model, input_raw_data) # Model persistence path = tmp_path + "/umap_tests" model_path = f"{path}/umap_model" umap_model.write().overwrite().save(model_path) umap_model_loaded = UMAPModel.load(model_path) - assert_umap_model(model=umap_model_loaded) + _assert_umap_model(umap_model_loaded, input_raw_data) -@pytest.mark.parametrize("BROADCAST_LIMIT", [8 << 20, 8 << 18]) -def test_umap_broadcast_chunks(gpu_number: int, BROADCAST_LIMIT: int) -> None: +@pytest.mark.parametrize("maxRecordsPerBatch", ["2000"]) +@pytest.mark.parametrize("BROADCAST_LIMIT", [8 << 15]) +@pytest.mark.parametrize("sparse_fit", [True, False]) +def test_umap_chunking( + gpu_number: int, maxRecordsPerBatch: str, BROADCAST_LIMIT: int, sparse_fit: bool +) -> None: from cuml.datasets import make_blobs - X, _ = make_blobs( - 5000, - 3000, - centers=42, - cluster_std=0.1, - dtype=np.float32, - random_state=10, - ) + n_rows = int(int(maxRecordsPerBatch) * 2.5) + n_cols = 3000 with CleanSparkSession() as spark: - pyspark_type = "float" - feature_cols = [f"c{i}" for i in range(X.shape[1])] - schema = [f"{c} {pyspark_type}" for c in feature_cols] - df = spark.createDataFrame(X.tolist(), ",".join(schema)) - df = df.withColumn("features", array(*feature_cols)).drop(*feature_cols) + spark.conf.set( + "spark.sql.execution.arrow.maxRecordsPerBatch", maxRecordsPerBatch + ) + + if sparse_fit: + data, input_raw_data = _load_sparse_binary_data(n_rows, n_cols, 30) + df = spark.createDataFrame(data, ["features"]) + nbytes = 
input_raw_data.data.nbytes + else: + X, _ = make_blobs( + n_rows, + n_cols, + centers=5, + cluster_std=0.1, + dtype=np.float32, + random_state=10, + ) + pyspark_type = "float" + feature_cols = [f"c{i}" for i in range(X.shape[1])] + schema = [f"{c} {pyspark_type}" for c in feature_cols] + df = spark.createDataFrame(X.tolist(), ",".join(schema)) + df = df.withColumn("features", array(*feature_cols)).drop(*feature_cols) + input_raw_data = X.get() + nbytes = input_raw_data.nbytes umap = UMAP(num_workers=gpu_number).setFeaturesCol("features") - umap.BROADCAST_LIMIT = BROADCAST_LIMIT - umap_model = umap.fit(df) + assert umap.max_records_per_batch == int(maxRecordsPerBatch) + assert nbytes > BROADCAST_LIMIT - def assert_umap_model(model: UMAPModel) -> None: - embedding = np.array(model.embedding) - raw_data = np.array(model.raw_data) - assert embedding.shape == (5000, 2) - assert raw_data.shape == (5000, 3000) - assert np.array_equal(raw_data, X.get()) - assert model.dtype == "float32" - assert model.n_cols == X.shape[1] + umap_model = umap.fit(df) + umap_model.BROADCAST_LIMIT = BROADCAST_LIMIT - assert_umap_model(model=umap_model) + _assert_umap_model(umap_model, input_raw_data) pdf = umap_model.transform(df).toPandas() - embedding = cp.asarray(pdf["embedding"].to_list()).astype(cp.float32) - input = cp.asarray(pdf["features"].to_list()).astype(cp.float32) + embedding = np.vstack(pdf["embedding"]).astype(np.float32) + input = np.vstack(pdf["features"]).astype(np.float32) dist_umap = trustworthiness(input, embedding, n_neighbors=15, batch_size=5000) - loc_umap = _local_umap_trustworthiness(X, np.zeros(0), 15, False) + loc_umap = _local_umap_trustworthiness( + input_raw_data, np.zeros(0), 15, False, sparse_fit + ) trust_diff = loc_umap - dist_umap assert trust_diff <= 0.15 @@ -393,7 +482,7 @@ def test_umap_sample_fraction(gpu_number: int) -> None: X, _ = make_blobs( n_rows, 10, - centers=42, + centers=5, cluster_std=0.1, dtype=np.float32, random_state=10, @@ -416,20 +505,14 @@ def test_umap_sample_fraction(gpu_number: int) -> None: umap_model = umap.fit(df) - def assert_umap_model(model: UMAPModel) -> None: - embedding = np.array(model.embedding) - raw_data = np.array(model.raw_data) + threshold = 2 * np.sqrt( + n_rows * sample_fraction * (1 - sample_fraction) + ) # 2 std devs - threshold = 2 * np.sqrt( - n_rows * sample_fraction * (1 - sample_fraction) - ) # 2 std devs - - assert np.abs(n_rows * sample_fraction - embedding.shape[0]) <= threshold - assert np.abs(n_rows * sample_fraction - raw_data.shape[0]) <= threshold - assert model.dtype == "float32" - assert model.n_cols == X.shape[1] - - assert_umap_model(model=umap_model) + embedding = umap_model.embedding + raw_data = umap_model.rawData + assert np.abs(n_rows * sample_fraction - embedding.shape[0]) <= threshold + assert np.abs(n_rows * sample_fraction - raw_data.shape[0]) <= threshold def test_umap_build_algo(gpu_number: int) -> None: @@ -473,16 +556,7 @@ def test_umap_build_algo(gpu_number: int) -> None: umap_model = umap.fit(df) - def assert_umap_model(model: UMAPModel) -> None: - embedding = np.array(model.embedding) - raw_data = np.array(model.raw_data) - assert embedding.shape == (10000, 2) - assert raw_data.shape == (10000, 10) - assert np.array_equal(raw_data, X.get()) - assert model.dtype == "float32" - assert model.n_cols == X.shape[1] - - assert_umap_model(model=umap_model) + _assert_umap_model(umap_model, X.get()) pdf = umap_model.transform(df).toPandas() embedding = cp.asarray(pdf["embedding"].to_list()).astype(cp.float32) @@ 
-493,3 +567,49 @@ def assert_umap_model(model: UMAPModel) -> None: trust_diff = loc_umap - dist_umap assert trust_diff <= 0.15 + + +@pytest.mark.parametrize("n_rows", [3000]) +@pytest.mark.parametrize("n_cols", [64]) +@pytest.mark.parametrize("nnz", [12]) +@pytest.mark.parametrize("metric", ["jaccard", "hamming", "correlation", "cosine"]) +def test_umap_sparse_vector( + n_rows: int, n_cols: int, nnz: int, metric: str, gpu_number: int, tmp_path: str +) -> None: + import pyspark + from cuml.manifold import UMAP as cumlUMAP + from packaging import version + + if version.parse(pyspark.__version__) < version.parse("3.4.0"): + import logging + + err_msg = "pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` function for SparseVector. " + "The test case will be skipped. Please install pyspark>=3.4." + logging.info(err_msg) + return + + with CleanSparkSession() as spark: + data, input_raw_data = _load_sparse_binary_data(n_rows, n_cols, nnz) + df = spark.createDataFrame(data, ["features"]) + + umap_estimator = UMAP( + metric=metric, num_workers=gpu_number, random_state=42 + ).setFeaturesCol("features") + umap_model = umap_estimator.fit(df) + embedding = umap_model.embedding + + # Ensure internal and input CSR data match + _assert_umap_model(umap_model, input_raw_data) + + # Local vs dist trustworthiness check + output = umap_model.transform(df).toPandas() + embedding = cp.asarray(output["embedding"].to_list()) + dist_umap = trustworthiness(input_raw_data.toarray(), embedding, n_neighbors=15) + + local_model = cumlUMAP(n_neighbors=15, random_state=42, metric=metric) + local_model.fit(input_raw_data) + embedding = local_model.transform(input_raw_data) + loc_umap = trustworthiness(input_raw_data.toarray(), embedding, n_neighbors=15) + + trust_diff = loc_umap - dist_umap + assert trust_diff <= 0.15
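For reference, a condensed usage sketch of the new sparse-input path, distilled from test_umap_sparse_vector above. It assumes an active SparkSession bound to `spark`, a single GPU worker (`num_workers=1`), pyspark>=3.4 (needed for `unwrap_udt` on SparseVector), and the default "embedding" output column; the row counts and sparsity pattern are illustrative only.

import numpy as np
from pyspark.ml.linalg import SparseVector
from spark_rapids_ml.umap import UMAP

n_rows, n_cols, nnz = 1000, 64, 12

# Binary sparse rows with nnz non-zeros each, following the same pattern as
# the _load_sparse_binary_data test helper.
data = [
    (SparseVector(n_cols, {(i + j) % n_cols: 1.0 for j in range(nnz)}),)
    for i in range(n_rows)
]
df = spark.createDataFrame(data, ["features"])  # assumes an existing `spark` session

# Sparse input is detected from the first row; "jaccard" requires sparse data.
umap = UMAP(metric="jaccard", num_workers=1, random_state=42).setFeaturesCol("features")
model = umap.fit(df)

print(model.rawData.shape)    # (n_rows, n_cols), held as a scipy CSR matrix
print(model.embedding.shape)  # (n_rows, n_components)

pdf = model.transform(df).toPandas()
embeddings = np.vstack(pdf["embedding"]).astype(np.float32)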