Skip to content

Commit

Permalink
Merge pull request #102 from edgararuiz/fix-test
Browse files Browse the repository at this point in the history
Fixes earth GLM models and prediction test routine
  • Loading branch information
topepo authored May 25, 2022
2 parents 6a0a4e7 + eeeddd6 commit de7dbe3
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# tidypredict (development version)

- Fixes issue handling GLM Binomial earth models (#97)

- Adds capability to handle single simple Cubist models (#57)

- Fixed parenthesis issue in the creation of the interval formula (#76)
Expand Down
18 changes: 11 additions & 7 deletions R/model-earth.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ parse_model.earth <- function(model) {
}

is_glm <- !is.null(model$glm.list)

all_coefs <- model$coefficients
if (is_glm) all_coefs <- model$glm.coefficients


pm <- list()
pm$general$model <- "earth"
pm$general$type <- "tree"
Expand All @@ -29,12 +26,12 @@ parse_model.earth <- function(model) {
pm$general$family <- fam$family
pm$general$link <- fam$link
}
pm$terms <- mars_terms(model)
pm$terms <- mars_terms(model, is_glm)
as_parsed_model(pm)
}


mars_terms <- function(mod) {
mars_terms <- function(mod, is_glm) {
feature_types <-
tibble::as_tibble(mod$dirs, rownames = "feature") %>%
dplyr::mutate(feature_num = dplyr::row_number()) %>%
Expand All @@ -48,9 +45,16 @@ mars_terms <- function(mod) {
tidyr::pivot_longer(cols = c(-feature, -feature_num),
values_to = "value",
names_to = "term")

if (is_glm) {
all_coefs <- mod$glm.coefficients
} else {
all_coefs <- mod$coefficients
}

feature_coefs <-
# Note coef(mod) formats data differently for logistic regression
tibble::as_tibble(mod$coefficients, rownames = "feature") %>%
tibble::as_tibble(all_coefs, rownames = "feature") %>%
setNames(c("feature", "coefficient"))

term_to_column <-
Expand Down
6 changes: 3 additions & 3 deletions R/test-predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ tidypredict_test_default <- function(model, df = model$model, threshold = 0.0000

if (is.numeric(max_rows)) df <- head(df, max_rows)

base <- predict(model, df, interval = interval, type = "response")
preds <- predict(model, df, interval = interval, type = "response")

if (!include_intervals) {
base <- data.frame(fit = base, row.names = NULL)
base <- data.frame(fit = as.vector(preds), row.names = NULL)
} else {
base <- as.data.frame(base)
base <- as.data.frame(preds)
}

te <- tidypredict_to_column(
Expand Down

0 comments on commit de7dbe3

Please sign in to comment.