-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added group_by parameter to util_corr_fit() #69
base: main
Are you sure you want to change the base?
Changes from 11 commits
ac10a71
ecf5cd2
ea65fe4
39cf538
543d063
731cb64
e8a292f
6423b3a
b2cccd6
7a57123
8309fc4
e42d840
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,21 +2,28 @@ | |
#' | ||
#' @param postsynth A postsynth object from tidysynthesis or a tibble | ||
#' @param data an original (observed) data set. | ||
#' @param group_by The unquoted name of a (or multiple) grouping variable(s) | ||
#' | ||
#' @return A `list` of fit metrics: | ||
#' - `correlation_original`: correlation matrix of the original data. | ||
#' - `correlation_synthetic`: correlation matrix of the synthetic data. | ||
#' - `correlation_difference`: difference between `correlation_synthetic` and | ||
#' `correlation_original`. | ||
#' - `correlation_data`: A `tibble` of the correlations among the | ||
#' numeric variables for the actual and synthetic data | ||
#' - `correlation_fit`: square root of the sum of squared differences between | ||
#' `correlation_synthetic` and `correlation_original`, divided by the number of | ||
#' the synthetic and original data, divided by the number of | ||
#' cells in the correlation matrix. | ||
#' - `correlation_difference_mae`: the mean of the absolute correlation | ||
#' differences between the actual and synthetic data | ||
#' - `correlation_difference_rmse`: the root mean of the squared correlation | ||
#' differences between the actual and synthetic data | ||
#' | ||
#' @family utility functions | ||
#' | ||
#' @export | ||
|
||
util_corr_fit <- function(postsynth, data) { | ||
util_corr_fit <- function(postsynth, | ||
data, | ||
group_by = NULL) { | ||
|
||
|
||
|
||
if (is_postsynth(postsynth)) { | ||
|
||
|
@@ -25,15 +32,47 @@ util_corr_fit <- function(postsynth, data) { | |
} else { | ||
|
||
synthetic_data <- postsynth | ||
|
||
} | ||
|
||
synthetic_data <- dplyr::select_if(synthetic_data, is.numeric) | ||
data <- dplyr::select_if(data, is.numeric) | ||
|
||
# reorder data names | ||
# reorder data names (this appears to check if the variables are the same) | ||
# issue when the groups in the synthetic data do not match the groups in the og data, and vice versa | ||
# thinking about filling in all of groupings for each dataset first then running everything else | ||
data <- dplyr::select(data, names(synthetic_data)) | ||
|
||
synthetic_data <- dplyr::select(synthetic_data, dplyr::where(is.numeric), {{ group_by }}) |> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We're still using |
||
dplyr::arrange(dplyr::across({{ group_by }})) |> | ||
dplyr::group_split(dplyr::across({{ group_by }})) | ||
|
||
data <- dplyr::select(data, dplyr::where(is.numeric), {{ group_by }}) |> | ||
dplyr::arrange(dplyr::across({{ group_by }})) |> | ||
dplyr::group_split(dplyr::across({{ group_by }})) | ||
|
||
groups <- lapply(data, function(x) dplyr::select(x, {{ group_by }}) |> | ||
slice(1)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you replace this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code is to add the group by variables to the final datasets. I can add additional code to add the Ns to the metric data. I need to think more about how to add it to the corr_data dataset. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
results <- purrr::pmap( | ||
.l = list(synthetic_data, data, groups), | ||
.f = get_correlations | ||
) | ||
|
||
metrics <- dplyr::bind_cols( | ||
correlation_fit = map_dbl(results, "correlation_fit"), | ||
correlation_difference_mae = map_dbl(results, "correlation_difference_mae"), | ||
correlation_difference_rmse = map_dbl(results, "correlation_difference_rmse"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
bind_rows(groups) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
) | ||
|
||
corr_data <- dplyr::bind_rows(map_dfr(results, "correlation_data")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
return(list( | ||
corr_data, | ||
metrics | ||
)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. return(
list(
corr_data,
metrics
)
) |
||
} | ||
|
||
get_correlations <- function(synthetic_data, | ||
data, | ||
groups) { | ||
# helper function to find a correlation matrix with the upper tri set to zeros | ||
lower_triangle <- function(x) { | ||
|
||
|
@@ -43,58 +82,75 @@ util_corr_fit <- function(postsynth, data) { | |
dplyr::select_if(is.numeric) %>% | ||
stats::cor() | ||
|
||
# set the values in the upper triangle to zero to avoid double counting | ||
# set NA values in the lower triangle to "", set the values in the upper triangle to zero to avoid double counting | ||
correlation_matrix[is.na(correlation_matrix[lower.tri(correlation_matrix, diag = FALSE)])] <- "" | ||
correlation_matrix[upper.tri(correlation_matrix, diag = TRUE)] <- NA | ||
|
||
return(correlation_matrix) | ||
} | ||
|
||
# find the lower triangle of the original data linear correlation matrix | ||
original_lt <- lower_triangle(data) | ||
|
||
# find the lower triangle of the synthetic data linear correlation matrix | ||
synthetic_lt <- lower_triangle(synthetic_data) | ||
# find the lower triangle of the linear correlation matrices and add a var column | ||
original_lt <- data.frame(lower_triangle(data)) | ||
original_lt$var2 <- colnames(original_lt) | ||
|
||
# compare names | ||
if (any(rownames(original_lt) != rownames(synthetic_lt))) { | ||
stop("ERROR: rownames are not identical") | ||
} | ||
synthetic_lt <- data.frame(lower_triangle(synthetic_data)) | ||
synthetic_lt$var2 <- colnames(synthetic_lt) | ||
|
||
if (any(colnames(original_lt) != colnames(synthetic_lt))) { | ||
stop("ERROR: colnames are not identical") | ||
} | ||
# restructure the correlation matrix so the cols are var1, var2, original/synthetic | ||
original_lt <- original_lt %>% | ||
tidyr::pivot_longer(cols = !var2, names_to = "var1", values_to = "original") %>% | ||
dplyr::filter(!is.na(original)) %>% | ||
dplyr::arrange(var1) %>% | ||
dplyr::select(var1, var2, original) %>% | ||
dplyr::mutate(original = dplyr::case_when(.data$original == "" ~ NA, | ||
.default = .data$original)) | ||
|
||
# find the difference between the matrices | ||
difference_lt <- synthetic_lt - original_lt | ||
synthetic_lt <- synthetic_lt %>% | ||
tidyr::pivot_longer(cols = !var2, names_to = "var1", values_to = "synthetic") %>% | ||
dplyr::filter(!is.na(synthetic)) %>% | ||
dplyr::arrange(var1) %>% | ||
dplyr::select(var1, var2, synthetic) %>% | ||
dplyr::mutate(synthetic = dplyr::case_when(.data$synthetic == "" ~ NA, | ||
.default = .data$synthetic)) | ||
|
||
# find the length of the nonzero values in the matrices | ||
n <- choose(ncol(difference_lt), 2) | ||
# combine the data and find the difference between the original and synthetic correlations | ||
correlation_data <- original_lt %>% | ||
dplyr::left_join(synthetic_lt, by = c("var1","var2")) %>% | ||
dplyr::mutate(original = as.numeric(.data$original), | ||
synthetic = as.numeric(.data$synthetic), | ||
difference = .data$original - .data$synthetic, | ||
proportion_difference = .data$difference / .data$original) | ||
|
||
correlation_data <- bind_cols(correlation_data, groups) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# find the number of values in the lower triangle | ||
n <- nrow(dplyr::filter(correlation_data, !is.na(difference))) | ||
|
||
# calculate the correlation fit and divide by n | ||
correlation_fit <- sqrt(sum(difference_lt ^ 2, na.rm = TRUE)) / n | ||
correlation_fit <- sqrt(sum(correlation_data$difference ^ 2, na.rm = TRUE)) / n | ||
|
||
difference_vec <- as.numeric(difference_lt)[!is.na(difference_lt)] | ||
difference_vec <- as.numeric(correlation_data$difference) | ||
|
||
# mean absolute error | ||
correlation_difference_mae <- difference_vec %>% | ||
abs() %>% | ||
mean() | ||
|
||
# root mean square error | ||
correlation_difference_rmse <- | ||
difference_vec ^ 2%>% | ||
correlation_difference_rmse <- difference_vec ^ 2 %>% | ||
mean() %>% | ||
sqrt() | ||
|
||
|
||
return( | ||
list( | ||
correlation_original = original_lt, | ||
correlation_synthetic = synthetic_lt, | ||
correlation_difference = difference_lt, | ||
correlation_data = correlation_data, | ||
correlation_fit = correlation_fit, | ||
correlation_difference_mae = correlation_difference_mae, | ||
correlation_difference_rmse = correlation_difference_rmse | ||
) | ||
) | ||
|
||
} | ||
} | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"og data" may be a little casual for our roxygen headers...