diff --git a/R/ame_obs.R b/R/ame_obs.R index 1e8cf64..282198b 100644 --- a/R/ame_obs.R +++ b/R/ame_obs.R @@ -102,7 +102,7 @@ ame_obs.nhmm <- function( X2 <- update(model, newdata)[c("X_pi", "X_A", "X_B")] C <- model$n_channels if (C == 1L) { - times <- colnames(model$observations) + times <- as.numeric(colnames(model$observations)) symbol_names <- list(model$symbol_names) obs <- create_obsArray(model)[1L, , ] out <- ame_obs_nhmm_singlechannel( @@ -121,14 +121,14 @@ ame_obs.nhmm <- function( ) d <- data.frame( observation = model$symbol_names, - time = rep(colnames(model$observations), each = model$n_symbols), + time = rep(as.numeric(colnames(model$observations)), each = model$n_symbols), estimate = c(out$point_estimate) ) for(i in seq_along(probs)) { d[paste0("q", 100 * probs[i])] <- c(out$quantiles[, , i]) } } else { - times <- colnames(model$observations[[1]]) + times <- as.numeric(colnames(model$observations[[1]])) symbol_names <- model$symbol_names obs <- create_obsArray(model) out <- ame_obs_nhmm_multichannel( @@ -147,7 +147,7 @@ ame_obs.nhmm <- function( ) d <- data.frame( observation = model$symbol_names, - time = rep(colnames(model$observations), each = model$n_symbols), + time = rep(as.numeric(colnames(model$observations)), each = model$n_symbols), estimate = c(out$point_estimate) ) for(i in seq_along(probs)) { diff --git a/R/ame_param.R b/R/ame_param.R index 5c409d6..4b214a3 100644 --- a/R/ame_param.R +++ b/R/ame_param.R @@ -86,10 +86,10 @@ ame_param.nhmm <- function( model2 <- update(model, newdata) C <- model$n_channels if (C == 1L) { - times <- colnames(model$observations) + times <- as.numeric(colnames(model$observations)) symbol_names <- list(model$symbol_names) } else { - times <- colnames(model$observations[[1]]) + times <- as.numeric(colnames(model$observations[[1]])) symbol_names <- model$symbol_names } if (attr(model$X_pi, "icpt_only")) { diff --git a/R/get_probs.R b/R/get_probs.R index d0d547f..ac04fed 100644 --- a/R/get_probs.R +++ b/R/get_probs.R @@ -97,10 +97,10 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { model$X_A[attr(model$X_A, "missing")] <- NA if (model$n_channels == 1L) { ids <- rownames(model$observations) - times <- colnames(model$observations) + times <- as.numeric(colnames(model$observations)) } else { ids <- rownames(model$observations[[1]]) - times <- colnames(model$observations[[1]]) + times <- as.numeric(colnames(model$observations[[1]])) } if (!attr(model$X_A, "iv")) { X <- model$X_A[, , 1L, drop = FALSE] @@ -182,12 +182,12 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) { model$X_B[attr(model$X_B, "missing")] <- NA if (C == 1L) { ids <- rownames(model$observations) - times <- colnames(model$observations) + times <- as.numeric(colnames(model$observations)) symbol_names <- list(model$symbol_names) model$gammas$B <- list(model$gammas$B) } else { ids <- rownames(model$observations[[1]]) - times <- colnames(model$observations[[1]]) + times <- as.numeric(colnames(model$observations[[1]])) symbol_names <- model$symbol_names } if (!attr(model$X_B, "iv")) { diff --git a/R/simulate_mnhmm.R b/R/simulate_mnhmm.R index ce88f7b..777cbb4 100644 --- a/R/simulate_mnhmm.R +++ b/R/simulate_mnhmm.R @@ -53,7 +53,11 @@ simulate_mnhmm <- function( ) sequence_lengths <- rep(sequence_lengths, length.out = n_sequences) n_channels <- length(n_symbols) - symbol_names <- lapply(seq_len(n_channels), function(i) seq_len(n_symbols[i])) + symbol_names <- lapply( + seq_len(n_channels), function(i) { + as.character(seq_len(n_symbols[i])) + } + ) T_ <- max(sequence_lengths) obs <- lapply(seq_len(n_channels), function(i) { suppressWarnings(suppressMessages( @@ -98,7 +102,7 @@ simulate_mnhmm <- function( gamma_B <- c(eta_to_gamma_cube_field(unlist(model$etas$B, recursive = FALSE))) model$gammas$B <- split(gamma_B, rep(seq_along(l), l)) } - + model$gammas$omega <- eta_to_gamma_mat( model$etas$omega ) @@ -120,7 +124,7 @@ simulate_mnhmm <- function( model$n_symbols ) } - + for (i in seq_len(model$n_sequences)) { Ti <- sequence_lengths[i] if (Ti < T_) { @@ -143,7 +147,7 @@ simulate_mnhmm <- function( alphabet = state_names, cnames = seq_len(T_) ) )) - + if (n_channels == 1) { dim(out$observations) <- dim(out$observations)[2:3] out$observations[] <- symbol_names[c(out$observations) + 1] diff --git a/R/simulate_nhmm.R b/R/simulate_nhmm.R index 17cd72b..0523218 100644 --- a/R/simulate_nhmm.R +++ b/R/simulate_nhmm.R @@ -47,7 +47,11 @@ simulate_nhmm <- function( ) sequence_lengths <- rep(sequence_lengths, length.out = n_sequences) n_channels <- length(n_symbols) - symbol_names <- lapply(seq_len(n_channels), function(i) seq_len(n_symbols[i])) + symbol_names <- lapply( + seq_len(n_channels), function(i) { + as.character(seq_len(n_symbols[i])) + } + ) T_ <- max(sequence_lengths) obs <- lapply(seq_len(n_channels), function(i) { suppressWarnings(suppressMessages(