Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added group_by parameter to util_corr_fit() #69

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
126 changes: 91 additions & 35 deletions R/util_corr_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,28 @@
#'
#' @param postsynth A postsynth object from tidysynthesis or a tibble
#' @param data an original (observed) data set.
#' @param group_by The unquoted name of a (or multiple) grouping variable(s)
#'
#' @return A `list` of fit metrics:
#' - `correlation_original`: correlation matrix of the original data.
#' - `correlation_synthetic`: correlation matrix of the synthetic data.
#' - `correlation_difference`: difference between `correlation_synthetic` and
#' `correlation_original`.
#' - `correlation_data`: A `tibble` of the correlations among the
#' numeric variables for the actual and synthetic data
#' - `correlation_fit`: square root of the sum of squared differences between
#' `correlation_synthetic` and `correlation_original`, divided by the number of
#' the synthetic and original data, divided by the number of
#' cells in the correlation matrix.
#' - `correlation_difference_mae`: the mean of the absolute correlation
#' differences between the actual and synthetic data
#' - `correlation_difference_rmse`: the root mean of the squared correlation
#' differences between the actual and synthetic data
#'
#' @family utility functions
#'
#' @export

util_corr_fit <- function(postsynth, data) {
util_corr_fit <- function(postsynth,
data,
group_by = NULL) {



if (is_postsynth(postsynth)) {

Expand All @@ -25,15 +32,47 @@ util_corr_fit <- function(postsynth, data) {
} else {

synthetic_data <- postsynth

}

synthetic_data <- dplyr::select_if(synthetic_data, is.numeric)
data <- dplyr::select_if(data, is.numeric)

# reorder data names
# reorder data names (this appears to check if the variables are the same)
# issue when the groups in the synthetic data do not match the groups in the og data, and vice versa
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"og data" may be a little casual for our roxygen headers...

# thinking about filling in all of groupings for each dataset first then running everything else
data <- dplyr::select(data, names(synthetic_data))

synthetic_data <- dplyr::select(synthetic_data, dplyr::where(is.numeric), {{ group_by }}) |>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're still using %>% instead of |> now to make sure the code is backwards compatible with R < 4.0.0.

dplyr::arrange(dplyr::across({{ group_by }})) |>
dplyr::group_split(dplyr::across({{ group_by }}))

data <- dplyr::select(data, dplyr::where(is.numeric), {{ group_by }}) |>
dplyr::arrange(dplyr::across({{ group_by }})) |>
dplyr::group_split(dplyr::across({{ group_by }}))

groups <- lapply(data, function(x) dplyr::select(x, {{ group_by }}) |>
slice(1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dplyr::slice() instead of just slice().

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you replace this with count(data, groups)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is to add the group by variables to the final datasets. I can add additional code to add the Ns to the metric data. I need to think more about how to add it to the corr_data dataset.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

count(data, {{ group_by }}) will return a data frame with the groups and the frequency of the groups that you can plug into bind_cols() below.


results <- purrr::pmap(
.l = list(synthetic_data, data, groups),
.f = get_correlations
)

metrics <- dplyr::bind_cols(
correlation_fit = map_dbl(results, "correlation_fit"),
correlation_difference_mae = map_dbl(results, "correlation_difference_mae"),
correlation_difference_rmse = map_dbl(results, "correlation_difference_rmse"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

purrr::map_dbl()

bind_rows(groups)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dplyr::bind_rows()

)

corr_data <- dplyr::bind_rows(map_dfr(results, "correlation_data"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

purrr::map_dfr()


return(list(
corr_data,
metrics
))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return(
  list(
    corr_data,
    metrics
  )
)

}

get_correlations <- function(synthetic_data,
data,
groups) {
# helper function to find a correlation matrix with the upper tri set to zeros
lower_triangle <- function(x) {

Expand All @@ -43,58 +82,75 @@ util_corr_fit <- function(postsynth, data) {
dplyr::select_if(is.numeric) %>%
stats::cor()

# set the values in the upper triangle to zero to avoid double counting
# set NA values in the lower triangle to "", set the values in the upper triangle to zero to avoid double counting
correlation_matrix[is.na(correlation_matrix[lower.tri(correlation_matrix, diag = FALSE)])] <- ""
correlation_matrix[upper.tri(correlation_matrix, diag = TRUE)] <- NA

return(correlation_matrix)
}

# find the lower triangle of the original data linear correlation matrix
original_lt <- lower_triangle(data)

# find the lower triangle of the synthetic data linear correlation matrix
synthetic_lt <- lower_triangle(synthetic_data)
# find the lower triangle of the linear correlation matrices and add a var column
original_lt <- data.frame(lower_triangle(data))
original_lt$var2 <- colnames(original_lt)

# compare names
if (any(rownames(original_lt) != rownames(synthetic_lt))) {
stop("ERROR: rownames are not identical")
}
synthetic_lt <- data.frame(lower_triangle(synthetic_data))
synthetic_lt$var2 <- colnames(synthetic_lt)

if (any(colnames(original_lt) != colnames(synthetic_lt))) {
stop("ERROR: colnames are not identical")
}
# restructure the correlation matrix so the cols are var1, var2, original/synthetic
original_lt <- original_lt %>%
tidyr::pivot_longer(cols = !var2, names_to = "var1", values_to = "original") %>%
dplyr::filter(!is.na(original)) %>%
dplyr::arrange(var1) %>%
dplyr::select(var1, var2, original) %>%
dplyr::mutate(original = dplyr::case_when(.data$original == "" ~ NA,
.default = .data$original))

# find the difference between the matrices
difference_lt <- synthetic_lt - original_lt
synthetic_lt <- synthetic_lt %>%
tidyr::pivot_longer(cols = !var2, names_to = "var1", values_to = "synthetic") %>%
dplyr::filter(!is.na(synthetic)) %>%
dplyr::arrange(var1) %>%
dplyr::select(var1, var2, synthetic) %>%
dplyr::mutate(synthetic = dplyr::case_when(.data$synthetic == "" ~ NA,
.default = .data$synthetic))

# find the length of the nonzero values in the matrices
n <- choose(ncol(difference_lt), 2)
# combine the data and find the difference between the original and synthetic correlations
correlation_data <- original_lt %>%
dplyr::left_join(synthetic_lt, by = c("var1","var2")) %>%
dplyr::mutate(original = as.numeric(.data$original),
synthetic = as.numeric(.data$synthetic),
difference = .data$original - .data$synthetic,
proportion_difference = .data$difference / .data$original)

correlation_data <- bind_cols(correlation_data, groups)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dplyr::binds_cols()


# find the number of values in the lower triangle
n <- nrow(dplyr::filter(correlation_data, !is.na(difference)))

# calculate the correlation fit and divide by n
correlation_fit <- sqrt(sum(difference_lt ^ 2, na.rm = TRUE)) / n
correlation_fit <- sqrt(sum(correlation_data$difference ^ 2, na.rm = TRUE)) / n

difference_vec <- as.numeric(difference_lt)[!is.na(difference_lt)]
difference_vec <- as.numeric(correlation_data$difference)

# mean absolute error
correlation_difference_mae <- difference_vec %>%
abs() %>%
mean()

# root mean square error
correlation_difference_rmse <-
difference_vec ^ 2%>%
correlation_difference_rmse <- difference_vec ^ 2 %>%
mean() %>%
sqrt()


return(
list(
correlation_original = original_lt,
correlation_synthetic = synthetic_lt,
correlation_difference = difference_lt,
correlation_data = correlation_data,
correlation_fit = correlation_fit,
correlation_difference_mae = correlation_difference_mae,
correlation_difference_rmse = correlation_difference_rmse
)
)

}
}

16 changes: 10 additions & 6 deletions man/util_corr_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading