Skip to content

Commit

Permalink
fix: Correct handling of factors
Browse files Browse the repository at this point in the history
- fixed errors for some factor outcomes
- more detailed error messages for factor level mismatch
  • Loading branch information
kozodoi committed Apr 14, 2021
1 parent 700b9f7 commit efc81fa
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 36 deletions.
46 changes: 30 additions & 16 deletions R/acc_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ acc_parity <- function(data, outcome, group,
warning(paste0('Converting ', class(data)[1], ' to data.frame'))
data <- as.data.frame(data)
}

# convert types, sync levels
if (is.null(probs) & is.null(preds)) {
stop({'Either probs or preds have to be supplied'})
Expand All @@ -68,7 +68,8 @@ acc_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -87,27 +88,40 @@ acc_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]

# check lengths
if ((length(outcome_status) != length(preds_status)) | (length(outcome_status) !=
length(group_status))) {
length(group_status))) {
stop('Outcomes, predictions/probabilities and group status must be of the same length')
}

# relevel group
if (is.null(base)) {base <- levels(group_status)[1]}
group_status <- relevel(group_status, base)

# placeholders
val <- rep(NA, length(levels(group_status)))
names(val) <- levels(group_status)
sample_size <- val

# compute value for all groups
for (i in levels(group_status)) {
cm <- caret::confusionMatrix(preds_status[group_status == i],
Expand All @@ -118,38 +132,38 @@ acc_parity <- function(data, outcome, group,
val[i] <- metric_i
sample_size[i] <- sum(cm$table)
}

# aggregate results
res_table <- rbind(val, val/val[[1]], sample_size)
rownames(res_table) <- c('Accuracy', 'Accuracy Parity', 'Group size')

# conversion of metrics to df
val_df <- as.data.frame(res_table[2, ])
colnames(val_df) <- c('val')
val_df$groupst <- rownames(val_df)
val_df$groupst <- as.factor(val_df$groupst)

# relevel group
if (is.null(base)) {
val_df$groupst <- levels(val_df$groupst)[1]
}
val_df$groupst <- relevel(val_df$groupst, base)

p <- ggplot(val_df, aes(x = groupst, weight = val, fill = groupst)) + geom_bar(alpha = 0.5) +
coord_flip() + theme(legend.position = 'none') + labs(x = '', y = 'Accuracy Parity')

# plotting
if (!is.null(probs)) {
q <- ggplot(data, aes(x = probs, fill = group_status)) + geom_density(alpha = 0.5) +
labs(x = 'Predicted probabilities') + guides(fill = guide_legend(title = '')) +
theme(plot.title = element_text(hjust = 0.5)) + xlim(0, 1) + geom_vline(xintercept = cutoff,
linetype = 'dashed')
linetype = 'dashed')
}

if (is.null(probs)) {
list(Metric = res_table, Metric_plot = p)
} else {
list(Metric = res_table, Metric_plot = p, Probability_plot = q)
}

}
}
18 changes: 16 additions & 2 deletions R/dem_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ dem_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -87,8 +88,21 @@ dem_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
18 changes: 16 additions & 2 deletions R/equal_odds.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ equal_odds <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -88,8 +89,21 @@ equal_odds <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
4 changes: 2 additions & 2 deletions R/fairness-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
#' Package: \tab fairness\cr
#' Depends: \tab R (>= 3.5.0)\cr
#' Type: \tab Package\cr
#' Version: \tab 1.2.1\cr
#' Date: \tab 2021-03-26\cr
#' Version: \tab 1.2.2\cr
#' Date: \tab 2021-04-14\cr
#' License: \tab MIT\cr
#' LazyLoad: \tab Yes
#' }
Expand Down
18 changes: 16 additions & 2 deletions R/fnr_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ fnr_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -86,8 +87,21 @@ fnr_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
18 changes: 16 additions & 2 deletions R/fpr_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ fpr_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -86,8 +87,21 @@ fpr_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
18 changes: 16 additions & 2 deletions R/mcc_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ mcc_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -85,8 +86,21 @@ mcc_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
18 changes: 16 additions & 2 deletions R/npv_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ npv_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -87,8 +88,21 @@ npv_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
18 changes: 16 additions & 2 deletions R/pred_rate_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ pred_rate_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -88,8 +89,21 @@ pred_rate_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
18 changes: 16 additions & 2 deletions R/prop_parity.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ prop_parity <- function(data, outcome, group,
if (length(probs) == 1) {
probs <- data[, probs]
}
preds_status <- as.factor(as.numeric(probs > cutoff))
preds_status <- as.factor(as.numeric(probs > cutoff))
levels(preds_status) <- levels(as.factor(data[, outcome]))
}

# check group feature and cut if needed
Expand All @@ -85,8 +86,21 @@ prop_parity <- function(data, outcome, group,
group_status <- as.factor(data[, group])
outcome_status <- as.factor(data[, outcome])

# check levels matching
if (!identical(levels(outcome_status), levels(preds_status))) {
warn_preds <- paste0(levels(preds_status), collapse = ', ')
warn_outcome <- paste0(levels(outcome_status), collapse = ', ')
stop({paste0(c('Levels of predictions and outcome do not match. ',
'Please relevel predictions or outcome.\n',
'Outcome levels: ', warn_preds, '\n',
'Preds levels: ', warn_outcome))})}

# relevel preds & outcomes
if (is.null(outcome_base)) {outcome_base <- levels(outcome_status)[1]}
if (is.null(outcome_base)) {
outcome_base <- levels(outcome_status)[1]
}else{
outcome_base <- as.character(outcome_base)
}
outcome_status <- relevel(outcome_status, outcome_base)
preds_status <- relevel(preds_status, outcome_base)
outcome_positive <- levels(outcome_status)[2]
Expand Down
Loading

0 comments on commit efc81fa

Please sign in to comment.