Skip to content

Commit

Permalink
improve code (#30)
Browse files Browse the repository at this point in the history
* spelling
* improve ml solution
* fix test
* simplify README file
* latex related update
  • Loading branch information
Polkas authored Jan 22, 2024
1 parent 10c4207 commit 8754a86
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 51 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cat2cat
Title: Handling an Inconsistently Coded Categorical Variable in a Longitudinal Dataset
Version: 0.4.6.9008
Version: 0.4.6.9009
Authors@R: person("Maciej", "Nasinski", email = "nasinski.maciej@gmail.com", role = c("aut", "cre"))
Maintainer: Maciej Nasinski <nasinski.maciej@gmail.com>
Description:
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# cat2cat 0.4.6.9008
# cat2cat 0.4.6.9009

* New `cat2cat_ml_run` function to check the ml models performance before `cat2cat` with ml option is run. Now, the ml models are more transparent.
* Add tests for cat2cat related journal (softwarex) paper.
Expand Down
10 changes: 5 additions & 5 deletions R/cat2cat.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
#' `data`, `cat_var`, `method`, `features` and optional `args`.
#' @details
#' data args
#' \itemize{
#' \describe{
#' \item{"old"}{ data.frame older time point in a panel}
#' \item{"new"} { data.frame more recent time point in a panel}
#' \item{"new"}{ data.frame more recent time point in a panel}
#' \item{"time_var"}{ character(1) name of the time variable.}
#' \item{"cat_var"}{ character(1) name of the categorical variable.}
#' \item{"cat_var_old"}{
Expand All @@ -51,7 +51,7 @@
#' }
#' }
#' mappings args
#' \itemize{
#' \describe{
#' \item{"trans"}{ data.frame with 2 columns - mapping (transition) table -
#' all categories for cat_var in old and new datasets have to be included.
#' First column contains an old encoding and second a new one.
Expand All @@ -70,7 +70,7 @@
#' }
#' }
#' Optional ml args
#' \itemize{
#' \describe{
#' \item{"data"}{ data.frame - dataset with features and the `cat_var`.}
#' \item{"cat_var"}{ character(1) - the dependent variable name.}
#' \item{"method"}{
Expand Down Expand Up @@ -295,7 +295,7 @@ cat2cat <- function(data = list(
}

#' Validate if the trans table contains all proper mappings
#' @param cats_target vector of unique target period categories
#' @param u_cats_target vector of unique target period categories
#' @param mapp transition (mapping) table process with `get_mappings`,
#' the "to_base" direction is taken.
#' @keywords internal
Expand Down
38 changes: 17 additions & 21 deletions R/cat2cat_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
target_name <- "old"
}

mapps <- get_mappings(mappings$trans)
mapp <- mapps[[paste0("to_", base_name)]]

cat_var <- ml$data[[ml$cat_var]]
cat_var_vals <- unlist(mappings$trans[, base_name])
if (sum(cat_var %in% cat_var_vals) / length(cat_var) < elargs$min_match) {
Expand All @@ -246,10 +249,6 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
)
}

mapps <- get_mappings(mappings$trans)
mapp <- mapps[[paste0("to_", base_name)]]

nobs <- nrow(ml$data)
features <- unique(ml$features)
methods <- unique(ml$method)

Expand All @@ -259,14 +258,13 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
)

res <- list()
for (cat in unique(names(mapp))) {
for (cat in names(mapp)) {
try(
{
matched_cat <- mapp[[match(cat, names(mapp))]]
g_name <- paste(matched_cat, collapse = "&")

res[[g_name]] <- list(ncat = length(matched_cat), naive = 1 / length(matched_cat),
acc = stats::setNames(rep(NA_real_, length(methods)), methods), freq = NA_real_)
cat_nam <- if (cat == "") " " else cat
res[[cat_nam]] <- list(naive = NA_real_,
acc = stats::setNames(rep(NA_real_, length(methods)), methods), freq = NA_real_)

data_small_g <- do.call(rbind, train_g[matched_cat])

Expand All @@ -275,26 +273,24 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
next
}

res[[cat_nam]][["naive"]] <- 1 / length(matched_cat)

index_tt <- sample(c(0, 1),
nrow(data_small_g),
prob = c(1 - elargs$test_prop, elargs$test_prop), replace = TRUE)
data_test_small <- data_small_g[index_tt == 1, ]
data_train_small <- data_small_g[index_tt == 0, ]
cc <- complete.cases(data_test_small[, features])

gcounts <- table(data_train_small[[ml$cat_var]])
gfreq <- names(gcounts)[which.max(gcounts)]

res[[g_name]][["freq"]] <- mean(gfreq == data_test_small[[ml$cat_var]])

if (isTRUE(nrow(data_test_small) == 0 || nrow(data_train_small) < 5)) {
if (isTRUE(nrow(data_test_small[cc, ]) == 0 || nrow(data_train_small) < 5)) {
next
}

cc <- complete.cases(data_test_small[, features])
gcounts <- table(data_train_small[[ml$cat_var]])
gfreq <- names(gcounts)[which.max(gcounts)]
res[[cat_nam]][["freq"]] <- mean(gfreq == data_test_small[[ml$cat_var]])

for (m in methods) {
ml_name <- paste0("wei_", m, "_c2c")

if (m == "knn") {
group_prediction <- suppressWarnings(
caret::knn3(
Expand Down Expand Up @@ -332,7 +328,7 @@ cat2cat_ml_run <- function(mappings, ml, ...) {
as.matrix(data_test_small[cc, features, drop = FALSE])
)$class
}
res[[g_name]][["acc"]][m] <- mean(pred == data_test_small[[ml$cat_var]])
res[[cat_nam]][["acc"]][m] <- mean(pred == data_test_small[[ml$cat_var]])
}
},
silent = TRUE
Expand Down Expand Up @@ -361,7 +357,7 @@ print.cat2cat_ml_run <- function(x, ...) {
acc <- mean(vapply(x, function(i) i$acc[m], numeric(1)), na.rm = T)
ml_message <- c(
ml_message,
sprintf("Average (groups) accurecy for %s ml models: %f", m, acc)
sprintf("Average (groups) accuracy for %s ml models: %f", m, acc)
)
howaccn <- mean(vapply(x, function(i) i$naive < mean(i$acc[m], na.rm = TRUE), numeric(1)), na.rm = T)
how_ml_message_n <- c(
Expand Down Expand Up @@ -391,7 +387,7 @@ print.cat2cat_ml_run <- function(x, ...) {
"Selected prediction stats:",
"",
sprintf("Average naive (equal probabilities) guess: %f", acc_naive),
sprintf("Average (groups) accurecy for most frequent category solution: %f", acc_freq),
sprintf("Average (groups) accuracy for most frequent category solution: %f", acc_freq),
ml_message,
"",
na_message,
Expand Down
6 changes: 3 additions & 3 deletions R/cat2cat_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#' number of rows
#' @details
#' method - specify a method to reduce number of replications
#' \itemize{
#' \describe{
#' \item{"nonzero"}{ remove nonzero probabilities}
#' \item{"highest"} {
#' \item{"highest"}{
#' leave only highest probabilities for each subject- accepting ties
#' }
#' \item{"highest1"} {
#' \item{"highest1"}{
#' leave only highest probabilities for each subject -
#' not accepting ties so always one is returned
#' }
Expand Down
10 changes: 0 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,3 @@ all.equal(nrow(ff), sum(final_data_back$wei_freq_c2c))
```

**More complex examples are presented in the "Get Started" vignette.**

## Graph

The graphs present how the `cat2cat::cat2cat` function works, in this case under a panel dataset without the unique identifiers and only two periods.

![Backward Mapping](./man/figures/back_nom.png)

![Forward Mapping](./man/figures/for_nom.png)


8 changes: 4 additions & 4 deletions man/cat2cat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/prune_c2c.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/validate_cover_cats.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-cat2cat_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ testthat::test_that("cat2cat_ml_run", {
testthat::expect_equal(res, res2)
testthat::expect_s3_class(res, c("cat2cat_ml_run", "list"))
testthat::expect_output(print(res), "Selected prediction stats:")
testthat::expect_output(print(res), "Percent of failed knn ml models: 32")
testthat::expect_output(print(res), "Percent of failed knn ml models:")
})

testthat::test_that("cat2cat_ml_run wrong direction", {
Expand Down

0 comments on commit 8754a86

Please sign in to comment.