Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ml transparent #29

Merged
merged 19 commits into from
Oct 22, 2023
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
use-public-rspm: true

- name: Install dependencies
run: Rscript -e "install.packages('rcmdcheck')" -e "install.packages(c('MASS', 'assertthat', 'caret', 'knitr', 'rmarkdown', 'pacman', 'testthat' , 'magrittr', 'dplyr', 'webshot', 'htmlwidgets', 'forcats', 'stargazer', 'tidyr'))"
run: Rscript -e "install.packages('rcmdcheck')" -e "install.packages(c('MASS', 'assertthat', 'caret', 'knitr', 'rmarkdown', 'testthat' , 'magrittr', 'dplyr', 'forcats', 'tidyr'))"

- name: Install randomForest old
if: "${{ matrix.config.r == '3.6'}}"
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ R/preproc_data.R
/doc/
/Meta/
.DS_Store
Rplots
10 changes: 3 additions & 7 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: cat2cat
Title: Handling an Inconsistently Coded Categorical Variable in a Longitudinal Dataset
Version: 0.4.6.9007
Version: 0.4.6.9008
Authors@R: person("Maciej", "Nasinski", email = "nasinski.maciej@gmail.com", role = c("aut", "cre"))
Maintainer: Maciej Nasinski <nasinski.maciej@gmail.com>
Description:
Expand All @@ -19,17 +19,13 @@ Suggests:
caret,
dplyr,
forcats,
htmlwidgets,
knitr,
magrittr,
pacman,
randomForest,
rmarkdown,
stargazer,
testthat (>= 3.0.0),
tidyr,
webshot
tidyr
LazyData: true
VignetteBuilder: knitr
RoxygenNote: 7.2.2
RoxygenNote: 7.2.3
Config/testthat/edition: 3
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Generated by roxygen2: do not edit by hand

S3method(print,cat2cat_ml_run)
export(cat2cat)
export(cat2cat_agg)
export(cat2cat_ml_run)
export(cat_apply_freq)
export(cross_c2c)
export(dummy_c2c)
Expand Down
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# cat2cat 0.4.6.9007
# cat2cat 0.4.6.9008

* New `cat2cat_ml_run` function to check the performance of the ml models before `cat2cat` is run with the ml option. This makes the ml models more transparent.
* Add tests for cat2cat related journal (softwarex) paper.
* Add CITATION file, .
* Add CITATION file.
* Internal changes to make the code base clearer.

# cat2cat 0.4.6
Expand Down
193 changes: 15 additions & 178 deletions R/cat2cat.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,26 @@
#' mappings = list(trans = trans2, direction = "forward")
#' )
#'
#' # additional probabilities from knn
#' occup_ml <- cat2cat(
#' data = list(
#' old = occup_old, new = occup_new, cat_var = "code", time_var = "year"
#' ),
#' mappings = list(trans = trans, direction = "backward"),
#' ml = list(
#' mappings <- list(trans = trans, direction = "backward")
#'
#' ml_setup <- list(
#' data = occup_small[occup_small$year >= 2010, ],
#' cat_var = "code",
#' method = "knn",
#' features = c("age", "sex", "edu", "exp", "parttime", "salary"),
#' args = list(k = 10)
#' )
#' )
#'
#' # ml model performance check
#' print(cat2cat_ml_run(mappings, ml_setup))
#'
#' # additional probabilities from knn
#' occup_ml <- cat2cat(
#' data = list(
#' old = occup_old, new = occup_new, cat_var = "code", time_var = "year"
#' ),
#' mappings = mappings,
#' ml = ml_setup
#' )
#' }
#'
Expand Down Expand Up @@ -287,163 +294,6 @@ cat2cat <- function(data = list(
res
}

#' The internal function used in the cat2cat one
#' @description apply the ml models to the cat2cat data
#' @details For every category of `cat_var_target` found in `target_data`,
#' the rows of `ml$data` belonging to the categories listed for it in `mapp`
#' are used to fit each requested model (`"knn"`, `"rf"`, `"lda"`). The
#' predicted class probabilities are written to the `wei_<method>_c2c`
#' columns, which start as copies of `wei_freq_c2c`.
#' The per-category work is wrapped in `try(..., silent = TRUE)`, so an error
#' in one category is swallowed and the loop continues with the next one.
#' @param ml `list` the same `ml` argument as provided to `cat2cat` function.
#' @param mapp `list` a mapping table
#' @param target_data `data.frame`
#' @param cat_var_target `character(1)` name of the categorical variable
#' in the target period.
#' @return `list` with a single element `target_data`, the input
#' `target_data` extended with one `wei_<method>_c2c` column per requested
#' method and ordered by `index_c2c`.
#' @keywords internal
cat2cat_ml <- function(ml, mapp, target_data, cat_var_target) {

# -- input validation --------------------------------------------------------
stopifnot(all(c("method", "features") %in% names(ml)))
stopifnot(all(ml$method %in% c("knn", "rf", "lda")))

# rf and knn rely on optional packages; check they are available up front
if ("rf" %in% ml$method) {
delayed_package_load("randomForest", "rf")
}

if ("knn" %in% ml$method) {
delayed_package_load("caret", "knn")
}

stopifnot(ml$cat_var %in% colnames(ml$data))
stopifnot(all(ml$features %in% colnames(target_data)))
stopifnot(all(ml$features %in% colnames(ml$data)))
stopifnot(cat_var_target %in% colnames(target_data))

# every feature has to be numeric or logical in both datasets
stopifnot(all(vapply(
target_data[, ml$features, drop = FALSE],
function(x) is.numeric(x) || is.logical(x), logical(1)
)))
stopifnot(all(vapply(
ml$data[, ml$features, drop = FALSE],
function(x) is.numeric(x) || is.logical(x), logical(1)
)))

features <- unique(ml$features)
methods <- unique(ml$method)
# one output weight column per requested method
ml_names <- paste0("wei_", methods, "_c2c")

# start from the frequency based weights; they stay in place wherever
# a model is not fitted below
target_data[, ml_names] <- target_data["wei_freq_c2c"]

# training data split by category; `exclude = NULL` keeps NA as its own level
cat_ml_year_g <- split(
ml$data[, c(features, ml$cat_var), drop = FALSE],
factor(ml$data[[ml$cat_var]], exclude = NULL)
)
target_data_cats <- target_data[[cat_var_target]]
# target data split by its category, NA kept as a level as well
target_data_cat_c2c <- split(
target_data,
factor(target_data_cats, exclude = NULL)
)

for (cat in unique(names(target_data_cat_c2c))) {
# errors raised while processing one category are silenced by `try`,
# the loop then moves on to the next category
try(
{
matched_cat <- match(cat, names(target_data_cat_c2c))
target_data_cat <- target_data_cat_c2c[[matched_cat]]
# training rows of all categories that `mapp` lists for `cat`
dis <- do.call(rbind, cat_ml_year_g[mapp[[match(cat, names(mapp))]]])
udc <- unique(dis[[ml$cat_var]])
# at most one class in the training subset: nothing to discriminate,
# so the frequency weights are used and the category is skipped
if (length(udc) <= 1) {
target_data_cat_c2c[[matched_cat]][ml_names] <-
target_data_cat$wei_freq_c2c
next
}
# fit models only if there is more than one candidate category, there
# are rows to predict for and at least one candidate has training data
if (
length(unique(target_data_cat$g_new_c2c)) > 1 &&
length(udc) >= 1 &&
nrow(target_data_cat) > 0 &&
any(unique(target_data_cat$g_new_c2c) %in% names(cat_ml_year_g))
) {
# one row of features per `index_c2c` (first occurrence is kept)
base_ml <-
target_data_cat[
!duplicated(target_data_cat[["index_c2c"]]),
c("index_c2c", features)
]
# predictions are made only for rows without missing features
cc <- complete.cases(base_ml[, features])

for (m in methods) {

ml_name <- paste0("wei_", m, "_c2c")

if (m == "knn") {
# k is capped at a quarter of the training rows (rounded up)
group_prediction <- suppressWarnings(
caret::knn3(
x = dis[, features, drop = FALSE],
y = factor(dis[[ml$cat_var]]),
k = min(ml$args$k, ceiling(nrow(dis) / 4))
)
)
pp <- as.data.frame(
stats::predict(
group_prediction,
base_ml[cc, features, drop = FALSE],
type = "prob"
)
)
} else if (m == "rf") {
# the number of trees is capped at 100
group_prediction <- suppressWarnings(
randomForest::randomForest(
y = factor(dis[[ml$cat_var]]),
x = dis[, features, drop = FALSE],
ntree = min(ml$args$ntree, 100)
)
)
pp <- as.data.frame(
stats::predict(
group_prediction,
base_ml[cc, features, drop = FALSE],
type = "prob"
)
)
} else if (m == "lda") {
group_prediction <- suppressWarnings(
MASS::lda(
grouping = factor(dis[[ml$cat_var]]),
x = as.matrix(dis[, features, drop = FALSE])
)
)
# for lda the class probabilities are the `posterior` element
pp <- as.data.frame(
stats::predict(
group_prediction,
as.matrix(base_ml[cc, features, drop = FALSE])
)$posterior
)
}
# candidate categories for which the model returned no column
ll <- setdiff(unique(target_data_cat$g_new_c2c), colnames(pp))
# imputing rest of the class to zero prob
if (length(ll)) {
pp[ll] <- 0
}
# wide to long; stacked before `index_c2c` is added so that only
# the probability columns are stacked
pp_stack <- utils::stack(pp)
pp[["index_c2c"]] <- base_ml[["index_c2c"]][cc]
# repeat the ids once per probability column (`ncol(pp) - 1` leaves
# out the `index_c2c` column that was just added)
res <- cbind(pp_stack, index_c2c = rep(pp$index_c2c, ncol(pp) - 1))
colnames(res) <- c("val", "g_new_c2c", "index_c2c")
# left join onto the target rows; rows without a matching prediction
# get NA in `val`
ress <- merge(
target_data_cat[, c("index_c2c", "g_new_c2c")],
res,
by = c("index_c2c", "g_new_c2c"),
all.x = TRUE,
sort = FALSE
)
resso <- ress[order(ress$index_c2c), ]
# NOTE(review): the values are assigned by position, which assumes the
# rows of this category are in the same order as `resso` - confirm
target_data_cat_c2c[[
match(cat, names(target_data_cat_c2c))
]][[ml_name]] <- resso$val
}
}
},
silent = TRUE
)
}

# bind the per-category pieces back together and restore `index_c2c` order
target_data <- do.call(rbind, target_data_cat_c2c)
target_data <- target_data[order(target_data[["index_c2c"]]), ]

list(target_data = target_data)
}

#' Validate if the trans table contains all proper mappings
#' @param cats_target vector of unique target period categories
#' @param mapp transition (mapping) table process with `get_mappings`,
Expand Down Expand Up @@ -539,16 +389,3 @@ validate_mappings <- function(mappings) {
stopifnot(isTRUE(mappings$direction %in% c("forward", "backward")))
stopifnot(is.data.frame(mappings$trans) && ncol(mappings$trans) == 2)
}

#' Delayed load of a package
#' @description check that an optional package is installed before a model
#' that depends on it is used; stop with an installation hint otherwise.
#' @param package `character(1)` name of the package to look for.
#' @param name `character(1)` name of the model, used in the error message.
#' @return `NULL` invisibly when the package is available,
#' otherwise an error is raised.
#' @keywords internal
delayed_package_load <- function(package, name) {
  # requireNamespace() returns FALSE instead of raising an error when the
  # package is missing, so the result has to be checked explicitly.
  is_available <- suppressPackageStartupMessages(
    requireNamespace(package, quietly = TRUE)
  )
  if (isFALSE(is_available)) {
    stop(
      sprintf(
        "Please install %s package to use the %s model in the cat2cat function.",
        package, name
      )
    )
  }
}
Loading
Loading