From efc81fa994fa1be983dcf02fddc3317cdabbf9f1 Mon Sep 17 00:00:00 2001 From: Kozodoi Date: Wed, 14 Apr 2021 16:15:59 +0200 Subject: [PATCH] fix: Correct handling of factors - fixed errors for some factor outcomes - more detailed error messages for factor level mismatch --- R/acc_parity.R | 46 +++++++++++++++++++++++++++++--------------- R/dem_parity.R | 18 +++++++++++++++-- R/equal_odds.R | 18 +++++++++++++++-- R/fairness-package.R | 4 ++-- R/fnr_parity.R | 18 +++++++++++++++-- R/fpr_parity.R | 18 +++++++++++++++-- R/mcc_parity.R | 18 +++++++++++++++-- R/npv_parity.R | 18 +++++++++++++++-- R/pred_rate_parity.R | 18 +++++++++++++++-- R/prop_parity.R | 18 +++++++++++++++-- R/spec_parity.R | 18 +++++++++++++++-- 11 files changed, 176 insertions(+), 36 deletions(-) diff --git a/R/acc_parity.R b/R/acc_parity.R index b23c35b..05dffa0 100644 --- a/R/acc_parity.R +++ b/R/acc_parity.R @@ -54,7 +54,7 @@ acc_parity <- function(data, outcome, group, warning(paste0('Converting ', class(data)[1], ' to data.frame')) data <- as.data.frame(data) } - + # convert types, sync levels if (is.null(probs) & is.null(preds)) { stop({'Either probs or preds have to be supplied'}) @@ -68,7 +68,8 @@ acc_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -87,27 +88,40 @@ acc_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] - + # check lengths if ((length(outcome_status) != length(preds_status)) | (length(outcome_status) != - length(group_status))) { + length(group_status))) { stop('Outcomes, predictions/probabilities and group status must be of the same length') } - + # relevel group if (is.null(base)) {base <- levels(group_status)[1]} group_status <- relevel(group_status, base) - + # placeholders val <- rep(NA, length(levels(group_status))) names(val) <- levels(group_status) sample_size <- val - + # compute value for all groups for (i in levels(group_status)) { cm <- caret::confusionMatrix(preds_status[group_status == i], @@ -118,32 +132,32 @@ acc_parity <- function(data, outcome, group, val[i] <- metric_i sample_size[i] <- sum(cm$table) } - + # aggregate results res_table <- rbind(val, val/val[[1]], sample_size) rownames(res_table) <- c('Accuracy', 'Accuracy Parity', 'Group size') - + # conversion of metrics to df val_df <- as.data.frame(res_table[2, ]) colnames(val_df) <- c('val') val_df$groupst <- rownames(val_df) val_df$groupst <- as.factor(val_df$groupst) - + # relevel group if (is.null(base)) { val_df$groupst <- levels(val_df$groupst)[1] } val_df$groupst <- relevel(val_df$groupst, base) - + p <- ggplot(val_df, aes(x = groupst, weight = val, fill = groupst)) + geom_bar(alpha = 0.5) + coord_flip() + theme(legend.position = 'none') + labs(x = '', y = 'Accuracy Parity') - + # plotting if (!is.null(probs)) { q <- ggplot(data, aes(x = probs, fill = group_status)) + geom_density(alpha = 0.5) + labs(x = 'Predicted probabilities') + guides(fill = guide_legend(title = '')) + theme(plot.title = element_text(hjust = 0.5)) + xlim(0, 1) + geom_vline(xintercept = cutoff, - linetype = 'dashed') + linetype = 'dashed') } if (is.null(probs)) { @@ -151,5 +165,5 @@ acc_parity <- function(data, outcome, group, } else { list(Metric = res_table, Metric_plot = p, Probability_plot = q) } - -} + +} \ No newline at end of file diff --git a/R/dem_parity.R b/R/dem_parity.R index cae47e7..6671379 100644 --- a/R/dem_parity.R +++ b/R/dem_parity.R @@ -68,7 +68,8 @@ dem_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -87,8 +88,21 @@ dem_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/equal_odds.R b/R/equal_odds.R index 4d2b2fa..b500e05 100644 --- a/R/equal_odds.R +++ b/R/equal_odds.R @@ -69,7 +69,8 @@ equal_odds <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -88,8 +89,21 @@ equal_odds <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/fairness-package.R b/R/fairness-package.R index d7fbf28..e92ed50 100644 --- a/R/fairness-package.R +++ b/R/fairness-package.R @@ -7,8 +7,8 @@ #' Package: \tab fairness\cr #' Depends: \tab R (>= 3.5.0)\cr #' Type: \tab Package\cr -#' Version: \tab 1.2.1\cr -#' Date: \tab 2021-03-26\cr +#' Version: \tab 1.2.2\cr +#' Date: \tab 2021-04-14\cr #' License: \tab MIT\cr #' LazyLoad: \tab Yes #' } diff --git a/R/fnr_parity.R b/R/fnr_parity.R index 7207971..ae9286a 100644 --- a/R/fnr_parity.R +++ b/R/fnr_parity.R @@ -67,7 +67,8 @@ fnr_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -86,8 +87,21 @@ fnr_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/fpr_parity.R b/R/fpr_parity.R index 776654e..d6500d6 100644 --- a/R/fpr_parity.R +++ b/R/fpr_parity.R @@ -67,7 +67,8 @@ fpr_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -86,8 +87,21 @@ fpr_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/mcc_parity.R b/R/mcc_parity.R index 131d361..549ff9d 100644 --- a/R/mcc_parity.R +++ b/R/mcc_parity.R @@ -66,7 +66,8 @@ mcc_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -85,8 +86,21 @@ mcc_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/npv_parity.R b/R/npv_parity.R index 7255897..861a5d1 100644 --- a/R/npv_parity.R +++ b/R/npv_parity.R @@ -68,7 +68,8 @@ npv_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -87,8 +88,21 @@ npv_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/pred_rate_parity.R b/R/pred_rate_parity.R index 0d2aaf0..53ec57a 100644 --- a/R/pred_rate_parity.R +++ b/R/pred_rate_parity.R @@ -69,7 +69,8 @@ pred_rate_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -88,8 +89,21 @@ pred_rate_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/prop_parity.R b/R/prop_parity.R index 4e5c728..1ca6bdc 100644 --- a/R/prop_parity.R +++ b/R/prop_parity.R @@ -66,7 +66,8 @@ prop_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -85,8 +86,21 @@ prop_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2] diff --git a/R/spec_parity.R b/R/spec_parity.R index 896d8ac..c3bf123 100644 --- a/R/spec_parity.R +++ b/R/spec_parity.R @@ -67,7 +67,8 @@ spec_parity <- function(data, outcome, group, if (length(probs) == 1) { probs <- data[, probs] } - preds_status <- as.factor(as.numeric(probs > cutoff)) + preds_status <- as.factor(as.numeric(probs > cutoff)) + levels(preds_status) <- levels(as.factor(data[, outcome])) } # check group feature and cut if needed @@ -86,8 +87,21 @@ spec_parity <- function(data, outcome, group, group_status <- as.factor(data[, group]) outcome_status <- as.factor(data[, outcome]) + # check levels matching + if (!identical(levels(outcome_status), levels(preds_status))) { + warn_preds <- paste0(levels(preds_status), collapse = ', ') + warn_outcome <- paste0(levels(outcome_status), collapse = ', ') + stop({paste0(c('Levels of predictions and outcome do not match. ', + 'Please relevel predictions or outcome.\n', + 'Outcome levels: ', warn_preds, '\n', + 'Preds levels: ', warn_outcome))})} + # relevel preds & outcomes - if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]} + if (is.null(outcome_base)) { + outcome_base <- levels(outcome_status)[1] + }else{ + outcome_base <- as.character(outcome_base) + } outcome_status <- relevel(outcome_status, outcome_base) preds_status <- relevel(preds_status, outcome_base) outcome_positive <- levels(outcome_status)[2]