Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve code #30

Merged
merged 6 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cat2cat
Title: Handling an Inconsistently Coded Categorical Variable in a Longitudinal Dataset
Version: 0.4.6.9008
Version: 0.4.6.9009
Authors@R: person("Maciej", "Nasinski", email = "nasinski.maciej@gmail.com", role = c("aut", "cre"))
Maintainer: Maciej Nasinski <nasinski.maciej@gmail.com>
Description:
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# cat2cat 0.4.6.9008
# cat2cat 0.4.6.9009

* New `cat2cat_ml_run` function to check the performance of the ml models before `cat2cat` is run with the ml option. This makes the ml models more transparent.
* Add tests for the cat2cat-related journal (SoftwareX) paper.
Expand Down
10 changes: 5 additions & 5 deletions R/cat2cat.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#' `data`, `cat_var`, `method`, `features` and optional `args`.
#' @details
#' data args
#' \itemize{
#' \describe{
#' \item{"old"}{ data.frame older time point in a panel}
#' \item{"new"} { data.frame more recent time point in a panel}
#' \item{"new"}{ data.frame more recent time point in a panel}
#' \item{"time_var"}{ character(1) name of the time variable.}
#' \item{"cat_var"}{ character(1) name of the categorical variable.}
#' \item{"cat_var_old"}{
Expand All @@ -51,7 +51,7 @@
#' }
#' }
#' mappings args
#' \itemize{
#' \describe{
#' \item{"trans"}{ data.frame with 2 columns - mapping (transition) table -
#' all categories for cat_var in old and new datasets have to be included.
#' First column contains an old encoding and second a new one.
Expand All @@ -70,7 +70,7 @@
#' }
#' }
#' Optional ml args
#' \itemize{
#' \describe{
#' \item{"data"}{ data.frame - dataset with features and the `cat_var`.}
#' \item{"cat_var"}{ character(1) - the dependent variable name.}
#' \item{"method"}{
Expand Down Expand Up @@ -295,7 +295,7 @@ cat2cat <- function(data = list(
}

#' Validate if the trans table contains all proper mappings
#' @param cats_target vector of unique target period categories
#' @param u_cats_target vector of unique target period categories
#' @param mapp transition (mapping) table processed with `get_mappings`,
#' the "to_base" direction is taken.
#' @keywords internal
Expand Down
38 changes: 17 additions & 21 deletions R/cat2cat_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
target_name <- "old"
}

mapps <- get_mappings(mappings$trans)
mapp <- mapps[[paste0("to_", base_name)]]

cat_var <- ml$data[[ml$cat_var]]
cat_var_vals <- unlist(mappings$trans[, base_name])
if (sum(cat_var %in% cat_var_vals) / length(cat_var) < elargs$min_match) {
Expand All @@ -246,10 +249,6 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
)
}

mapps <- get_mappings(mappings$trans)
mapp <- mapps[[paste0("to_", base_name)]]

nobs <- nrow(ml$data)
features <- unique(ml$features)
methods <- unique(ml$method)

Expand All @@ -259,14 +258,13 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
)

res <- list()
for (cat in unique(names(mapp))) {
for (cat in names(mapp)) {
try(
{
matched_cat <- mapp[[match(cat, names(mapp))]]
g_name <- paste(matched_cat, collapse = "&")

res[[g_name]] <- list(ncat = length(matched_cat), naive = 1 / length(matched_cat),
acc = stats::setNames(rep(NA_real_, length(methods)), methods), freq = NA_real_)
cat_nam <- if (cat == "") " " else cat
res[[cat_nam]] <- list(naive = NA_real_,
acc = stats::setNames(rep(NA_real_, length(methods)), methods), freq = NA_real_)

data_small_g <- do.call(rbind, train_g[matched_cat])

Expand All @@ -275,26 +273,24 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
next
}

res[[cat_nam]][["naive"]] <- 1 / length(matched_cat)

index_tt <- sample(c(0, 1),
nrow(data_small_g),
prob = c(1 - elargs$test_prop, elargs$test_prop), replace = TRUE)
data_test_small <- data_small_g[index_tt == 1, ]
data_train_small <- data_small_g[index_tt == 0, ]
cc <- complete.cases(data_test_small[, features])

gcounts <- table(data_train_small[[ml$cat_var]])
gfreq <- names(gcounts)[which.max(gcounts)]

res[[g_name]][["freq"]] <- mean(gfreq == data_test_small[[ml$cat_var]])

if (isTRUE(nrow(data_test_small) == 0 || nrow(data_train_small) < 5)) {
if (isTRUE(nrow(data_test_small[cc, ]) == 0 || nrow(data_train_small) < 5)) {
next
}

cc <- complete.cases(data_test_small[, features])
gcounts <- table(data_train_small[[ml$cat_var]])
gfreq <- names(gcounts)[which.max(gcounts)]
res[[cat_nam]][["freq"]] <- mean(gfreq == data_test_small[[ml$cat_var]])

for (m in methods) {
ml_name <- paste0("wei_", m, "_c2c")

if (m == "knn") {
group_prediction <- suppressWarnings(
caret::knn3(
Expand Down Expand Up @@ -332,7 +328,7 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
as.matrix(data_test_small[cc, features, drop = FALSE])
)$class
}
res[[g_name]][["acc"]][m] <- mean(pred == data_test_small[[ml$cat_var]])
res[[cat_nam]][["acc"]][m] <- mean(pred == data_test_small[[ml$cat_var]])
}
},
silent = TRUE
Expand Down Expand Up @@ -361,7 +357,7 @@ print.cat2cat_ml_run <- function(x, ...) {
acc <- mean(vapply(x, function(i) i$acc[m], numeric(1)), na.rm = T)
ml_message <- c(
ml_message,
sprintf("Average (groups) accurecy for %s ml models: %f", m, acc)
sprintf("Average (groups) accuracy for %s ml models: %f", m, acc)
)
howaccn <- mean(vapply(x, function(i) i$naive < mean(i$acc[m], na.rm = TRUE), numeric(1)), na.rm = T)
how_ml_message_n <- c(
Expand Down Expand Up @@ -391,7 +387,7 @@ print.cat2cat_ml_run <- function(x, ...) {
"Selected prediction stats:",
"",
sprintf("Average naive (equal probabilities) guess: %f", acc_naive),
sprintf("Average (groups) accurecy for most frequent category solution: %f", acc_freq),
sprintf("Average (groups) accuracy for most frequent category solution: %f", acc_freq),
ml_message,
"",
na_message,
Expand Down
6 changes: 3 additions & 3 deletions R/cat2cat_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#' number of rows
#' @details
#' method - specify a method to reduce number of replications
#' \itemize{
#' \describe{
#' \item{"nonzero"}{ remove zero probabilities}
#' \item{"highest"} {
#' \item{"highest"}{
#' leave only highest probabilities for each subject - accepting ties
#' }
#' \item{"highest1"} {
#' \item{"highest1"}{
#' leave only highest probabilities for each subject -
#' not accepting ties so always one is returned
#' }
Expand Down
10 changes: 0 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,3 @@ all.equal(nrow(ff), sum(final_data_back$wei_freq_c2c))
```

**More complex examples are presented in the "Get Started" vignette.**

## Graph

The graphs present how the `cat2cat::cat2cat` function works, in this case under a panel dataset without the unique identifiers and only two periods.

![Backward Mapping](./man/figures/back_nom.png)

![Forward Mapping](./man/figures/for_nom.png)


8 changes: 4 additions & 4 deletions man/cat2cat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/prune_c2c.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/validate_cover_cats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-cat2cat_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ testthat::test_that("cat2cat_ml_run", {
testthat::expect_equal(res, res2)
testthat::expect_s3_class(res, c("cat2cat_ml_run", "list"))
testthat::expect_output(print(res), "Selected prediction stats:")
testthat::expect_output(print(res), "Percent of failed knn ml models: 32")
testthat::expect_output(print(res), "Percent of failed knn ml models:")
})

testthat::test_that("cat2cat_ml_run wrong direction", {
Expand Down
Loading