diff --git a/NAMESPACE b/NAMESPACE
index 262aba1..42c62ba 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -4,6 +4,7 @@ S3method(predict,SL.hal)
 S3method(predict,hal)
 export(SL.hal)
 export(hal)
+export(halplus)
 export(makeSparseMat)
 importFrom(Matrix,sparseMatrix)
 importFrom(bit,bit)
diff --git a/R/SL.hal.R b/R/SL.hal.R
index aaa9731..42058ca 100644
--- a/R/SL.hal.R
+++ b/R/SL.hal.R
@@ -27,7 +27,7 @@ SL.hal <- function(Y,
                    ...) {
   halOut <- hal(Y = Y, X = X, newX = newX,
                 verbose = verbose, obsWeights = obsWeights, nfolds = nfolds,
-                nlambda = nlambda, useMin = useMin, ...)
+                nlambda = nlambda, useMin = useMin, family = family, ...)
   out <- list(object = halOut, pred = halOut$pred)
   class(out) <- "SL.hal"
@@ -51,7 +51,7 @@ SL.hal <- function(Y,
 #' @export
 predict.SL.hal <- function(object, newdata, bigDesign = FALSE, chunks = 5000, ...){
   pred <- stats::predict(object$object, newdata = newdata, bigDesign = bigDesign,
-                         chunks = chunks,...)
+                         chunks = chunks, ...)
   return(pred)
 }
diff --git a/R/doPred.R b/R/doPred.R
index c19d8e3..224c534 100755
--- a/R/doPred.R
+++ b/R/doPred.R
@@ -11,7 +11,7 @@
 #'
 #' @importFrom Matrix sparseMatrix

-doPred <- function(object, newdata, verbose = FALSE, s) {
+doPred <- function(object, newdata, verbose = FALSE, s, offset) {

   if (is.vector(newdata))
     newdata <- matrix(newdata)
@@ -99,7 +99,12 @@ doPred <- function(object, newdata, verbose = FALSE, s) {

   # call predict.glmnet to get predictions on new sparseMat with duplicate
   # columns removed.
-  pred <- stats::predict(object$object$glmnet.fit, newx = tmp,
-                         s = s)
+  if (is.null(offset)) {
+    pred <- stats::predict(object$object$glmnet.fit, newx = tmp,
+                           s = s, type = 'response')
+  } else {
+    pred <- stats::predict(object$object$glmnet.fit, newx = tmp,
+                           s = s, type = 'response', newoffset = offset)
+  }
   return(pred)
 }
diff --git a/R/hal.R b/R/hal.R
index 40f1a44..008edb7 100644
--- a/R/hal.R
+++ b/R/hal.R
@@ -23,6 +23,7 @@
 #' @param debug For benchmarking. Setting to \code{TRUE} will run garbage collection to
 #' improve the accuracy of memory monitoring
 #' @param parallel A boolean indicating whether to use a parallel backend, if possible
+#' @param family A \code{family} object, either \code{gaussian()} (the default) or \code{binomial()}
 #' @param ... Not currently used
 #' @importFrom glmnet cv.glmnet
 #' @importFrom bit bit
@@ -44,6 +45,7 @@ hal <- function(Y,
                 useMin = TRUE,
                 debug = TRUE,
                 parallel = FALSE,
+                family = gaussian(),
                 ... # allow extra arguments with no death
                 ) {

@@ -187,7 +189,7 @@ hal <- function(Y,
             lambda.min.ratio = 0.001,
             type.measure = "deviance",
             nfolds = nfolds,
-            family = "gaussian",
+            family = family$family,
             alpha = 1,
             nlambda = nlambda,
             parallel = parallel
@@ -202,7 +204,7 @@ hal <- function(Y,
             lambda.min.ratio = 0.001,
             type.measure = "deviance",
             nfolds = nfolds,
-            family = "gaussian",
+            family = family$family,
             alpha = 1,
             nlambda = nlambda,
             parallel = parallel
@@ -245,7 +247,8 @@ hal <- function(Y,
     pred <- predict(fit,
                     newdata = newX,
                     bigDesign = FALSE,
-                    chunks = 10000)
+                    chunks = 10000
+                    )
   }

   # wrap up the timing
diff --git a/R/halplus.R b/R/halplus.R
new file mode 100644
index 0000000..4b806a0
--- /dev/null
+++ b/R/halplus.R
@@ -0,0 +1,264 @@
+#' halplus
+#'
+#' The highly adaptive lasso fitting function. This function takes a matrix of predictor values
+#' (which can be binary or continuous) and converts it into a set of indicator basis functions
+#' that perfectly fit the data. The function then uses cross-validated lasso (via the \code{glmnet}
+#' package) to select basis functions. The resulting fit is called the highly adaptive lasso.
+#' The process of creating the indicator basis functions can be extremely time and memory intensive
+#' as it involves creating n(2^d - 1) basis functions, where n is the number of observations
+#' and d the number of covariates. The function also must subsequently search over basis functions
+#' for those that are duplicated and store the results. Future implementations will attempt to
+#' streamline this process to the largest extent possible; however, for the time being implementing
+#' with values of n and d such that n(2^d - 1) > 1e7 is not recommended.
+#'
+#' @param Y A \code{numeric} of outcomes
+#' @param X A \code{data.frame} of predictors
+#' @param newX Optional \code{data.frame} on which to return predicted values
+#' @param verbose A \code{boolean} indicating whether to print output on the function's progress
+#' @param obsWeights Optional \code{vector} of observation weights to be passed to \code{cv.glmnet}
+#' @param nfolds Number of CV folds passed to \code{cv.glmnet}
+#' @param nlambda Number of lambda values to search across in \code{cv.glmnet}
+#' @param useMin Option passed to \code{cv.glmnet}, use minimum risk lambda or 1se lambda (more
+#' penalization)
+#' @param debug For benchmarking. Setting to \code{TRUE} will run garbage collection to
+#' improve the accuracy of memory monitoring
+#' @param parallel A boolean indicating whether to use a parallel backend, if possible
+#' @param family A \code{family} object, either \code{gaussian()} (the default) or \code{binomial()}
+#' @param offset Optional \code{vector} of offsets passed to \code{cv.glmnet}; a fit made with an
+#' offset also requires an offset at prediction time
+#' @param ... Not currently used
+#' @importFrom glmnet cv.glmnet
+#' @importFrom bit bit
+#' @importFrom stats gaussian predict
+#' @importFrom utils combn
+#' @importFrom data.table data.table set setkey
+#' @importFrom plyr alply
+#' @importFrom stringr str_c str_replace_na
+#'
+#' @export
+
+halplus <- function(Y,
+                    X,
+                    newX = NULL,
+                    verbose = FALSE,
+                    obsWeights = rep(1, length(Y)),
+                    nfolds = ifelse(length(Y) <= 100, 20, 10),
+                    nlambda = 100,
+                    useMin = TRUE,
+                    debug = TRUE,
+                    parallel = FALSE,
+                    family = gaussian(),
+                    offset = NULL,
+                    ... # allow extra arguments with no death
+                    ) {
+
+
+  #---------------------------------------------------------
+  # Preliminary operations
+  #---------------------------------------------------------
+  # Coerce vector inputs to single-column matrices before taking
+  # dimensions, since ncol() is NULL for a bare vector.
+  if (is.vector(X))
+    X <- matrix(X, ncol = 1)
+
+  if (is.vector(newX))
+    newX <- matrix(newX, ncol = 1)
+
+  d <- ncol(X)
+  n <- length(X[, 1])
+
+  # Run garbage collection if we are in debug mode.
+  if (debug) gc()
+
+  # Initialize prediction object to null in case newX = NULL.
+  pred <- NULL
+  times <- NULL
+
+  #------------------------------------------------------------
+  # Make initial design matrix (including duplicated columns)
+  #------------------------------------------------------------
+  if (verbose) cat("Making sparse matrix \n")
+  time_sparse_start <- proc.time()
+
+  # makeSparseMat to create sparseMatrix design matrix
+  X.init <- makeSparseMat(X = X, newX = X, verbose = verbose)
+
+  time_sparse_end <- proc.time()
+  time_sparse_matrix <- time_sparse_end - time_sparse_start
+
+  # Run garbage collection if we are in debug mode.
+  if (debug) gc()
+
+  #------------------------------------------------------------
+  # Removing duplicated columns
+  # TODO: Should this code be wrapped up in a function, or would
+  # passing all those objects to another function waste memory?
+  #------------------------------------------------------------
+  if (verbose) cat("Finding duplicate columns \n")
+
+  # Number of columns will become the new number of observations in the data.table
+  nIndCols <- ncol(X.init)
+
+  # Pre-allocate a data.table with one column; each row will store a single column from X.init
+  datDT <-
+    data.table(ID = 1:nIndCols,
+               bit_to_int_to_str = rep.int("0", nIndCols))
+  # Each column in X.init will be represented by a unique vector of integers.
+  # Each indicator column in X.init will be converted to a row of integers or
+  # a string of cat'ed integers in the data.table. The number of integers needed to
+  # represent a single column is determined automatically by package "bit" and
+  # it depends on nrow(X.init)
+  nbits <- nrow(X.init)  # number of bits (0/1) used by each column in X.init
+  bitvals <- bit::bit(length = nbits)  # initial allocation (all 0/FALSE)
+  nints_used <- length(unclass(bitvals))  # number of integers needed to represent each column
+
+  # Track which results gave NA in one of the integers
+  ID_withNA <- NULL
+
+  # For loop over columns of X.init
+  for (i in 1:nIndCols) {
+    bitvals <- bit::bit(length = nbits)  # initial allocation (all 0/FALSE)
+    Fidx_base0 <-
+      (X.init@p[i]):(X.init@p[i + 1] - 1)  # zero-based indices of the non-zero rows for column i
+    nonzero_rows <-
+      X.init@i[Fidx_base0 + 1] + 1  # actual row numbers of non-zero elements in column i
+    # print(i); print(nonzero_rows)
+    # X.init@i[i:X.init@p[i]]+1 # row numbers of non-zero elements in first col
+    bitvals[nonzero_rows] <- TRUE
+    # str(bitwhich(bitvals))
+    intval <-
+      unclass(bitvals)  # integer representation of the bit vector
+    # stringval <- str_c(intval, collapse = "")
+    if (any(is.na(intval)))
+      ID_withNA <- c(ID_withNA, i)
+    data.table::set(datDT, i, 2L,
+                    value = stringr::str_c(stringr::str_replace_na(intval),
+                                           collapse = ""))
+  }
+  # create a hash key on the string representation of each column;
+  # this sorts the table by bit_to_int_to_str using radix sort:
+  data.table::setkey(datDT, bit_to_int_to_str)
+  # add a logical column flagging duplicates
+  # (every occurrence after the first in each group):
+  datDT[, duplicates := duplicated(datDT, by = "bit_to_int_to_str")]
+  # inspect just the column IDs and duplicate indicators (result not used):
+  # datDT[, .(ID, duplicates)]
+
+  dupInds <- datDT[, ID][which(datDT[, duplicates])]
+
+  #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+  # OS: NEW FASTER APPROACH TO FIND DUPLICATE IDs
+  # get the number of duplicates in each group; if it's 1 the column is
+  # unique and we are not interested:
+  datDT[, Ngrp := .N, by = bit_to_int_to_str]
+  # collapse each duplicate group into a list of IDs, doing that only
+  # among strings that have duplicates
+  collapsedDT <- datDT[Ngrp > 1, list(list(ID)), by = bit_to_int_to_str]
+  colDups <- collapsedDT[["V1"]]
+  # colDups[[2]]
+  #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+  #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+  # OS: OLD APPROACH, TO BE REMOVED AFTER VALIDATION
+  # uniqDup <- unique(datDT[duplicates == TRUE, bit_to_int_to_str])
+  # colDups.old <- alply(uniqDup, 1, function(l) {
+  #   datDT[, ID][which(datDT[, bit_to_int_to_str] == l)]
+  # })
+  #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+  time_dup_end <- proc.time()
+  time_find_duplicates <- time_dup_end - time_sparse_end
+
+  # Run garbage collection if we are in debug mode.
+  if (debug) gc()
+
+  #------------------------------------------------------------
+  # Fit lasso
+  #------------------------------------------------------------
+
+  if (verbose) cat("Fitting lasso \n")
+  if (length(dupInds) > 0) {
+    notDupInds <- (1:ncol(X.init))[-unlist(colDups, use.names = FALSE)]
+    # keep only the first column from each group of duplicates
+    keepDupInds <-
+      unlist(lapply(colDups, function(x) {
+        x[[1]]
+      }), use.names = FALSE)
+    if (is.null(offset)) {
+      fitCV <-
+        glmnet::cv.glmnet(
+          x = X.init[, c(keepDupInds, notDupInds)],
+          y = Y,
+          weights = obsWeights,
+          lambda = NULL,
+          lambda.min.ratio = 0.001,
+          type.measure = "deviance",
+          nfolds = nfolds,
+          family = family$family,
+          alpha = 1,
+          nlambda = nlambda,
+          parallel = parallel
+        )
+    } else {
+      fitCV <-
+        glmnet::cv.glmnet(
+          x = X.init[, c(keepDupInds, notDupInds)],
+          y = Y,
+          weights = obsWeights,
+          lambda = NULL,
+          lambda.min.ratio = 0.001,
+          type.measure = "deviance",
+          nfolds = nfolds,
+          family = family$family,
+          alpha = 1,
+          nlambda = nlambda,
+          parallel = parallel,
+          offset = offset
+        )
+    }
+  } else {
+    # no duplicated columns, so use the full design matrix
+    if (is.null(offset)) {
+      fitCV <-
+        glmnet::cv.glmnet(
+          x = X.init,
+          y = Y,
+          weights = obsWeights,
+          lambda = NULL,
+          lambda.min.ratio = 0.001,
+          type.measure = "deviance",
+          nfolds = nfolds,
+          family = family$family,
+          alpha = 1,
+          nlambda = nlambda,
+          parallel = parallel
+        )
+    } else {
+      fitCV <-
+        glmnet::cv.glmnet(
+          x = X.init,
+          y = Y,
+          weights = obsWeights,
+          lambda = NULL,
+          lambda.min.ratio = 0.001,
+          type.measure = "deviance",
+          nfolds = nfolds,
+          family = family$family,
+          alpha = 1,
+          nlambda = nlambda,
+          parallel = parallel,
+          offset = offset
+        )
+    }
+  }
+  time_lasso_end <- proc.time()
+  time_lasso <- time_lasso_end - time_dup_end
+  #------------------------------------------------------------
+  # Output object (predictions are not computed here; use
+  # predict() on the returned fit)
+  #------------------------------------------------------------
+  fit <- list(object = fitCV,
+              useMin = useMin,
+              X = X,
+              dupInds = dupInds,
+              colDups = colDups,
+              pred = NULL,
+              times = NULL
+              )
+  class(fit) <- "hal"
+  return(fit)
+}
diff --git a/R/predict.hal.R b/R/predict.hal.R
index 4b7631b..3c11791 100644
--- a/R/predict.hal.R
+++ b/R/predict.hal.R
@@ -8,6 +8,7 @@ predict.hal <-
            verbose = TRUE,
            chunks = 1000,
            s = ifelse(object$useMin, object$object$lambda.min, object$object$lambda.1se),
+           offset = NULL,
            ...) {

     if (!object$sparseMat) {
@@ -49,7 +50,8 @@ predict.hal <-
             object$object$glmnet.fit,
             newx = designNewX,
             s = s,
-            type = "response"
+            type = "response",
+            newoffset = offset
           )
       } else {
@@ -92,7 +94,8 @@ predict.hal <-
               object$object$glmnet.fit,
               newx = matrix(designNewX, nrow = 1),
               s = s,
-              type = "response"
+              type = "response",
+              newoffset = offset
             )
           thispred
         })
@@ -102,7 +105,8 @@ predict.hal <-
       if (bigDesign) {
         pred <- doPred(object = object,
                        newdata = newdata,
-                       verbose = verbose)
+                       verbose = verbose,
+                       offset = offset)
       } else {
         nNew <- length(newdata[, 1])
         nChunks <- floor(nNew / chunks) + ifelse(nNew %% chunks == 0, 0, 1)
@@ -115,7 +119,8 @@ predict.hal <-
             object = object,
             s = s,
             newdata = newdata[minC:maxC, ],
-            verbose = verbose
+            verbose = verbose,
+            offset = offset[minC:maxC] # subset the offset to match this chunk
           )
         }
       }
diff --git a/R/predict_hal.R b/R/predict_hal.R
index 7e07c46..c6f4813 100644
--- a/R/predict_hal.R
+++ b/R/predict_hal.R
@@ -23,13 +23,15 @@ predict.hal <-
            verbose = TRUE,
            chunks = 5000,
            s = ifelse(object$useMin, object$object$lambda.min, object$object$lambda.1se),
+           offset = NULL,
            ...)
   {
     # all predictions at once
     if (bigDesign) {
       pred <- doPred(object = object,
                      newdata = newdata,
-                     verbose = verbose)
+                     verbose = verbose,
+                     offset = offset)
     } else {
       nNew <- length(newdata[, 1])
       nChunks <- floor(nNew / chunks) + ifelse(nNew %% chunks == 0, 0, 1)
@@ -42,7 +44,8 @@ predict.hal <-
           object = object,
           s = s,
           newdata = newdata[minC:maxC, ],
-          verbose = verbose
+          verbose = verbose,
+          offset = offset[minC:maxC] # subset the offset to match this chunk
         )
       }
     }
diff --git a/man/hal.Rd b/man/hal.Rd
index fc7f061..b6535de 100644
--- a/man/hal.Rd
+++ b/man/hal.Rd
@@ -6,7 +6,7 @@
 \usage{
 hal(Y, X, newX = NULL, verbose = FALSE, obsWeights = rep(1, length(Y)),
   nfolds = ifelse(length(Y) <= 100, 20, 10), nlambda = 100, useMin = TRUE,
-  debug = TRUE, parallel = FALSE, ...)
+  debug = TRUE, parallel = FALSE, family = gaussian(), ...)
 }
 \arguments{
 \item{Y}{A \code{numeric} of outcomes}
@@ -26,22 +26,24 @@ hal(Y, X, newX = NULL, verbose = FALSE, obsWeights = rep(1, length(Y)),
 \item{useMin}{Option passed to \code{cv.glmnet}, use minimum risk lambda or 1se lambda (more
 penalization)}

-\item{debug}{For benchmarking. Setting to \code{TRUE} will run garbage collection to
+\item{debug}{For benchmarking. Setting to \code{TRUE} will run garbage collection to
 improve the accuracy of memory monitoring}

 \item{parallel}{A boolean indicating whether to use a parallel backend, if possible}

+\item{family}{A \code{family} object, either \code{gaussian()} (the default) or \code{binomial()}}
+
 \item{...}{Not currently used}
 }
 \description{
 The highly adaptive lasso fitting function. This function takes a matrix of predictor values
 (which can be binary or continuous) and converts it into a set of indicator basis functions
-that perfectly fit the data. The function then uses cross-validated lasso (via the \code{glmnet}
-package) to select basis functions. The resulting fit is called the highly adaptive lasso.
-The process of creating the indicator basis functions can be extremely time and memory intensive
-as it involves creating n(2^d - 1) basis functions, where n is the number of observations
+that perfectly fit the data. The function then uses cross-validated lasso (via the \code{glmnet}
+package) to select basis functions. The resulting fit is called the highly adaptive lasso.
+The process of creating the indicator basis functions can be extremely time and memory intensive
+as it involves creating n(2^d - 1) basis functions, where n is the number of observations
 and d the number of covariates. The function also must subsequently search over basis functions
-for those that are duplicated and store the results. Future implementations will attempt to
+for those that are duplicated and store the results. Future implementations will attempt to
 streamline this process to the largest extent possible; however, for the time being implementing
 with values of n and d such that n(2^d - 1) > 1e7 is not recommended.
 }
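
Usage sketch (illustrative only, not part of the patch): the R snippet below shows how the new family and offset arguments are intended to flow through to glmnet, using simulated data. It assumes the patched package is installed and loaded (the package/library name is not shown in this diff), and that the fixes above are in place; the data and the name off are made up for illustration.

    set.seed(1)
    n <- 200
    X <- data.frame(x1 = rnorm(n), x2 = rbinom(n, 1, 0.5))
    Y <- rbinom(n, 1, plogis(0.5 * X$x1 - X$x2))

    # `family` must be a family object, not a string: hal()/halplus() read
    # family$family before handing it to glmnet::cv.glmnet().
    fit <- hal(Y = Y, X = X, newX = X, family = binomial())

    # halplus() additionally threads `offset` through to cv.glmnet(offset = ...).
    # For binomial fits the offset is on the link (logit) scale. A fit made with
    # an offset needs a matching offset at prediction time, which predict.hal()
    # forwards to predict.glmnet() as `newoffset`.
    off <- rep(0.1, n)
    fit_off <- halplus(Y = Y, X = X, family = binomial(), offset = off)
    pred <- predict(fit_off, newdata = X, offset = off)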