Skip to content

Commit

Permalink
move calculation of stats to a seperate function and update val and c…
Browse files Browse the repository at this point in the history
…al functions to use it.
  • Loading branch information
ChrisJones687 committed Feb 12, 2024
1 parent cf037ea commit 2f6c595
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 197 deletions.
127 changes: 7 additions & 120 deletions R/calibrate.R
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ calibrate <- function(infected_years_file,
config$county_level_infection_data <- county_level_infection_data
config$pest_host_table <- pest_host_table
config$competency_table <- competency_table
config$point_file <- ""

# call configuration function to perform data checks and transform data into
# format used in pops c++
Expand Down Expand Up @@ -548,46 +549,8 @@ calibrate <- function(infected_years_file,

# calculate comparison metrics for simulation data for each time step in
# the simulation
all_disagreement <-
foreach::foreach(
q = seq_len(length(data$host_pools[[1]]$infected)),
.combine = rbind,
.packages = c("terra", "PoPS"),
.final = colSums
) %do% {
comparison <- terra::rast(config$host_file_list[[1]])[[1]]
reference <- comparison
mask <- comparison
terra::values(comparison) <- 0
infections <- comparison
for (p in seq_len(length(data$host_pools))) {
terra::values(infections) <- data$host_pools[[p]]$infected[[q]]
comparison <- comparison + infections
}
terra::values(mask) <- config$mask_matrix

if (config$county_level_infection_data) {
reference <- terra::vect(config$infected_years_file[[1]])
compare_vect <- reference[, c(1, (q + 1))]
names(compare_vect) <- c("FIPS", "reference")
compare_vect$comparison <- terra::extract(comparison, reference, fun = "sum")[, 2]
ad <- calculated_stats_county_level(compare_vect)
ad$quantity_disagreement <- 0
ad$allocation_disagreement <- 0
ad$allocation_disagreement <- 0
ad$configuration_disagreement <- 0
ad$distance_difference <- 0

} else {
terra::values(reference) <- config$infection_years2[[q]]
ad <- quantity_allocation_disagreement(reference,
comparison,
use_configuration = config$use_configuration,
mask = mask,
use_distance = config$use_distance)
}
ad
}
all_disagreement <- calculate_all_stats(config, data)
all_disagreement <- colSums(all_disagreement)

all_disagreement <- as.data.frame(t(all_disagreement))
all_disagreement <- all_disagreement / length(data$host_pools[[1]]$infected)
Expand Down Expand Up @@ -907,46 +870,8 @@ calibrate <- function(infected_years_file,
proposed_network_max_distance
)

all_disagreement <-
foreach::foreach(
q = seq_len(length(data$host_pools[[1]]$infected)),
.combine = rbind,
.packages = c("terra", "PoPS"),
.final = colSums
) %do% {
comparison <- terra::rast(config$host_file_list[[1]])[[1]]
reference <- comparison
mask <- comparison
terra::values(comparison) <- 0
infections <- comparison
for (p in seq_len(length(data$host_pools))) {
terra::values(infections) <- data$host_pools[[p]]$infected[[q]]
comparison <- comparison + infections
}
terra::values(mask) <- config$mask_matrix

if (config$county_level_infection_data) {
reference <- terra::vect(config$infected_file[[1]])
compare_vect <- reference[, c(1, (q + 1))]
names(compare_vect) <- c("FIPS", "reference")
compare_vect$comparison <- terra::extract(comparison, reference, fun = "sum")[, 2]
ad <- calculated_stats_county_level(compare_vect)
ad$quantity_disagreement <- 0
ad$allocation_disagreement <- 0
ad$allocation_disagreement <- 0
ad$configuration_disagreement <- 0
ad$distance_difference <- 0

} else {
terra::values(reference) <- config$infection_years2[[q]]
ad <- quantity_allocation_disagreement(reference,
comparison,
use_configuration = config$use_configuration,
mask = mask,
use_distance = config$use_distance)
}
ad
}
all_disagreement <- calculate_all_stats(config, data)
all_disagreement <- colSums(all_disagreement)

all_disagreement <- as.data.frame(t(all_disagreement))
all_disagreement <- all_disagreement / length(data$host_pools[[1]]$infected)
Expand Down Expand Up @@ -1102,46 +1027,8 @@ calibrate <- function(infected_years_file,
)

# set up comparison
all_disagreement <-
foreach::foreach(
q = seq_len(length(data$host_pools[[1]]$infected)),
.combine = rbind,
.packages = c("terra", "PoPS"),
.final = colSums
) %do% {
comparison <- terra::rast(config$host_file_list[[1]])[[1]]
reference <- comparison
mask <- comparison
terra::values(comparison) <- 0
infections <- comparison
for (p in seq_len(length(data$host_pools))) {
terra::values(infections) <- data$host_pools[[p]]$infected[[q]]
comparison <- comparison + infections
}
terra::values(mask) <- config$mask_matrix

if (config$county_level_infection_data) {
reference <- terra::vect(config$infected_file[[1]])
compare_vect <- reference[, c(1, (q + 1))]
names(compare_vect) <- c("FIPS", "reference")
compare_vect$comparison <- terra::extract(comparison, reference, fun = "sum")[, 2]
ad <- calculated_stats_county_level(compare_vect)
ad$quantity_disagreement <- 0
ad$allocation_disagreement <- 0
ad$allocation_disagreement <- 0
ad$configuration_disagreement <- 0
ad$distance_difference <- 0

} else {
terra::values(reference) <- config$infection_years2[[q]]
ad <- quantity_allocation_disagreement(reference,
comparison,
use_configuration = config$use_configuration,
mask = mask,
use_distance = config$use_distance)
}
ad
}
all_disagreement <- calculate_all_stats(config, data)
all_disagreement <- colSums(all_disagreement)

all_disagreement <- as.data.frame(t(all_disagreement))
all_disagreement <- all_disagreement / length(data$host_pools[[1]]$infected)
Expand Down
79 changes: 79 additions & 0 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -471,3 +471,82 @@ calculated_stats_county_level <- function(compare_vect) {

return(output)
}


calculate_all_stats <- function(config, data) {
all_disagreement <-
foreach::foreach(
q = seq_len(length(data$host_pools[[1]]$infected)), .combine = rbind,
.packages = c("terra", "PoPS")
) %do% {
# need to assign reference, comparison, and mask in inner loop since
# terra objects are pointers

comparison <- terra::rast(config$host_file_list[[1]])[[1]]
terra::values(comparison) <- 0
reference <- comparison
mask <- comparison
infections <- comparison
for (p in seq_len(length(data$host_pools))) {
terra::values(infections) <- data$host_pools[[p]]$infected[[q]]
comparison <- comparison + infections
}
terra::values(mask) <- config$mask_matrix
if (config$county_level_infection_data) {
reference <- terra::vect(config$infected_years_file[[1]])
compare_vect <- reference[, c(1, (q + 1))]
names(compare_vect) <- c("FIPS", "reference")
compare_vect$comparison <- terra::extract(comparison, reference, fun = "sum")[, 2]
ad <- calculated_stats_county_level(compare_vect)
ad <- calculated_stats_county_level(compare_vect)
ad$quantity_disagreement <- 0
ad$allocation_disagreement <- 0
ad$allocation_disagreement <- 0
ad$configuration_disagreement <- 0
ad$distance_difference <- 0

} else {
terra::values(reference) <- config$infection_years2[[q]]
ad <-
quantity_allocation_disagreement(reference,
comparison,
use_configuration = config$use_configuration,
mask = mask,
use_distance = config$use_distance)
if (file.exists(config$point_file)) {
obs_data <- terra::vect(config$point_file)
obs_data <- terra::project(obs_data, comparison)
s <- extract(comparison, obs_data)
names(s) <- c("ID", paste("sim_value_output_", q, sep = ""))
s <- s[2]
obs_data <- cbind(obs_data, s)
## calculate true positive, true negatives, false positives, false
## negatives, and other statistics and add them to the data frame
## for export
ad$points_true_positive <-
nrow(obs_data[obs_data$positive > 0 & obs_data$sim_value_output_1 > 0, ])
ad$points_false_negative <-
nrow(obs_data[obs_data$positive > 0 & obs_data$sim_value_output_1 == 0, ])
ad$points_false_positive <-
nrow(obs_data[obs_data$positive == 0 & obs_data$sim_value_output_1 > 0, ])
ad$points_true_negative <-
nrow(obs_data[obs_data$positive == 0 & obs_data$sim_value_output_1 == 0, ])
ad$points_total_obs <-
ad$points_true_negative + ad$points_true_positive +
ad$points_false_negative + ad$points_false_positive
ad$points_accuracy <-
(ad$points_true_negative + ad$points_true_positive) / ad$points_total_obs
ad$points_precision <-
ad$points_true_positive / (ad$points_true_positive + ad$points_false_positive)
ad$points_recall <-
ad$points_true_positive / (ad$points_true_positive + ad$points_false_negative)
ad$points_specificiity <-
ad$points_true_negative / (ad$points_true_negative + ad$points_false_positive)
}
}
ad$output <- q
ad
}
all_disagreement <- data.frame(all_disagreement)
return(all_disagreement)
}
72 changes: 2 additions & 70 deletions R/validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -329,76 +329,8 @@ validate <- function(infected_years_file,
dispersers_to_soils_percentage = config$dispersers_to_soils_percentage,
use_soils = config$use_soils)


all_disagreement <-
foreach(
q = seq_len(length(data$host_pools[[1]]$infected)), .combine = rbind,
.packages = c("terra", "PoPS")
) %do% {
# need to assign reference, comparison, and mask in inner loop since
# terra objects are pointers

comparison <- terra::rast(config$host_file_list[[1]])[[1]]
terra::values(comparison) <- 0
reference <- comparison
mask <- comparison
infections <- comparison
for (p in seq_len(length(data$host_pools))) {
terra::values(infections) <- data$host_pools[[p]]$infected[[q]]
comparison <- comparison + infections
}
terra::values(mask) <- config$mask_matrix
if (config$county_level_infection_data) {
reference <- terra::vect(config$infected_years_file[[1]])
compare_vect <- reference[, c(1, (q + 1))]
names(compare_vect) <- c("FIPS", "reference")
compare_vect$comparison <- terra::extract(comparison, reference, fun = "sum")[, 2]
ad <- calculated_stats_county_level(compare_vect)

} else {
terra::values(reference) <- config$infection_years2[[q]]
ad <-
quantity_allocation_disagreement(reference,
comparison,
use_configuration = config$use_configuration,
mask = mask,
use_distance = config$use_distance)
if (file.exists(config$point_file)) {
obs_data <- terra::vect(config$point_file)
obs_data <- terra::project(obs_data, comparison)
s <- extract(comparison, obs_data)
names(s) <- c("ID", paste("sim_value_output_", q, sep = ""))
s <- s[2]
obs_data <- cbind(obs_data, s)
## calculate true positive, true negatives, false positives, false
## negatives, and other statistics and add them to the data frame
## for export
ad$points_true_positive <-
nrow(obs_data[obs_data$positive > 0 & obs_data$sim_value_output_1 > 0, ])
ad$points_false_negative <-
nrow(obs_data[obs_data$positive > 0 & obs_data$sim_value_output_1 == 0, ])
ad$points_false_positive <-
nrow(obs_data[obs_data$positive == 0 & obs_data$sim_value_output_1 > 0, ])
ad$points_true_negative <-
nrow(obs_data[obs_data$positive == 0 & obs_data$sim_value_output_1 == 0, ])
ad$points_total_obs <-
ad$points_true_negative + ad$points_true_positive +
ad$points_false_negative + ad$points_false_positive
ad$points_accuracy <-
(ad$points_true_negative + ad$points_true_positive) / ad$points_total_obs
ad$points_precision <-
ad$points_true_positive / (ad$points_true_positive + ad$points_false_positive)
ad$points_recall <-
ad$points_true_positive / (ad$points_true_positive + ad$points_false_negative)
ad$points_specificiity <-
ad$points_true_negative / (ad$points_true_negative + ad$points_false_positive)
}
}
ad$output <- q
ad
}

data.frame(all_disagreement)
all_disagreement <- calculate_all_stats(config, data)
all_disagreement
}

parallel::stopCluster(cl)
Expand Down
14 changes: 8 additions & 6 deletions tests/testthat/test-calibrate.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ test_that("ABC calibration has correctly formatted returns with multiple output
use_soils <- FALSE
soil_starting_pest_file <- ""
start_with_soil_populations <- FALSE
county_level_infection_data <- FALSE

data <- calibrate(infected_years_file,
number_of_observations,
Expand Down Expand Up @@ -243,7 +244,8 @@ test_that("ABC calibration has correctly formatted returns with multiple output
file_random_seeds,
use_soils,
soil_starting_pest_file,
start_with_soil_populations)
start_with_soil_populations,
county_level_infection_data = county_level_infection_data)

expect_length(data$posterior_means, 8)
expect_vector(data$posterior_means, ptype = double(), size = 8)
Expand Down Expand Up @@ -1043,19 +1045,19 @@ test_that("ABC calibration has correctly formatted returns/runs with county leve
precipitation_coefficient_file <- ""
model_type <- "SI"
latency_period <- 0
time_step <- "month"
time_step <- "week"
season_month_start <- 1
season_month_end <- 12
start_date <- "2003-01-01"
end_date <- "2003-02-11"
end_date <- "2003-01-14"
use_lethal_temperature <- FALSE
temperature_file <- ""
lethal_temperature <- -30
lethal_temperature_month <- 1
mortality_frequency <- "Year"
mortality_frequency_n <- 1
management <- FALSE
treatment_dates <- c("2003-01-24")
treatment_dates <- c("2003-01-07")
treatments_file <- ""
treatment_method <- "ratio"
natural_kernel_type <- "exponential"
Expand All @@ -1067,7 +1069,7 @@ test_that("ABC calibration has correctly formatted returns/runs with county leve
pesticide_duration <- c(0)
pesticide_efficacy <- 1.0
mask <- NULL
output_frequency <- "year"
output_frequency <- "week"
output_frequency_n <- 1
movements_file <- ""
use_movements <- FALSE
Expand Down Expand Up @@ -1096,7 +1098,7 @@ test_that("ABC calibration has correctly formatted returns/runs with county leve
verbose <- TRUE
write_outputs <- "None"
output_folder_path <- ""
success_metric <- "quantity, allocation, and configuration"
success_metric <- "accuracy"
network_filename <-
system.file("extdata", "simple20x20", "segments.csv", package = "PoPS")
use_survival_rates <- FALSE
Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-validate.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ test_that(
use_soils <- FALSE
soil_starting_pest_file <- ""
start_with_soil_populations <- FALSE
county_level_infection_data <- FALSE

outputs <- validate(
infected_years_file,
Expand Down Expand Up @@ -1011,7 +1012,7 @@ test_that(
expect_type(outputs, "list")
expect_length(outputs, 2)
data <- outputs[[1]]
expect_length(data, 24)
expect_length(data, 28)
expect_vector(data$false_negatives, size = number_of_iterations)
expect_vector(data$false_positives, size = number_of_iterations)
expect_vector(data$true_positives, size = number_of_iterations)
Expand Down

0 comments on commit 2f6c595

Please sign in to comment.