diff --git a/NEWS.md b/NEWS.md index d261acf..2ce4d31 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # tidypredict 0.4.9 +- Fixes issue handling GLM Binomial earth models (#97) + - Adds capability to handle single simple Cubist models (#57) - Fixed parenthesis issue in the creation of the interval formula (#76) diff --git a/R/model-earth.R b/R/model-earth.R index 930f13a..b8ec292 100644 --- a/R/model-earth.R +++ b/R/model-earth.R @@ -14,10 +14,7 @@ parse_model.earth <- function(model) { } is_glm <- !is.null(model$glm.list) - - all_coefs <- model$coefficients - if (is_glm) all_coefs <- model$glm.coefficients - + pm <- list() pm$general$model <- "earth" pm$general$type <- "tree" @@ -29,12 +26,12 @@ parse_model.earth <- function(model) { pm$general$family <- fam$family pm$general$link <- fam$link } - pm$terms <- mars_terms(model) + pm$terms <- mars_terms(model, is_glm) as_parsed_model(pm) } -mars_terms <- function(mod) { +mars_terms <- function(mod, is_glm) { feature_types <- tibble::as_tibble(mod$dirs, rownames = "feature") %>% dplyr::mutate(feature_num = dplyr::row_number()) %>% @@ -48,9 +45,16 @@ mars_terms <- function(mod) { tidyr::pivot_longer(cols = c(-feature, -feature_num), values_to = "value", names_to = "term") + + if (is_glm) { + all_coefs <- mod$glm.coefficients + } else { + all_coefs <- mod$coefficients + } + feature_coefs <- # Note coef(mod) formats data differently for logistic regression - tibble::as_tibble(mod$coefficients, rownames = "feature") %>% + tibble::as_tibble(all_coefs, rownames = "feature") %>% setNames(c("feature", "coefficient")) term_to_column <- diff --git a/R/test-predictions.R b/R/test-predictions.R index 3fb21d6..8d198d5 100644 --- a/R/test-predictions.R +++ b/R/test-predictions.R @@ -68,12 +68,12 @@ tidypredict_test_default <- function(model, df = model$model, threshold = 0.0000 if (is.numeric(max_rows)) df <- head(df, max_rows) - base <- predict(model, df, interval = interval, type = "response") + preds <- predict(model, df, interval = interval, type = "response") if (!include_intervals) { - base <- data.frame(fit = base, row.names = NULL) + base <- data.frame(fit = as.vector(preds), row.names = NULL) } else { - base <- as.data.frame(base) + base <- as.data.frame(preds) } te <- tidypredict_to_column(