Skip to content

Commit

Permalink
time index as numeric, symbol names as character
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed Nov 21, 2024
1 parent cd2e943 commit 60b5af1
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
8 changes: 4 additions & 4 deletions R/ame_obs.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ ame_obs.nhmm <- function(
X2 <- update(model, newdata)[c("X_pi", "X_A", "X_B")]
C <- model$n_channels
if (C == 1L) {
times <- colnames(model$observations)
times <- as.numeric(colnames(model$observations))
symbol_names <- list(model$symbol_names)
obs <- create_obsArray(model)[1L, , ]
out <- ame_obs_nhmm_singlechannel(
Expand All @@ -121,14 +121,14 @@ ame_obs.nhmm <- function(
)
d <- data.frame(
observation = model$symbol_names,
time = rep(colnames(model$observations), each = model$n_symbols),
time = rep(as.numeric(colnames(model$observations)), each = model$n_symbols),
estimate = c(out$point_estimate)
)
for(i in seq_along(probs)) {
d[paste0("q", 100 * probs[i])] <- c(out$quantiles[, , i])
}
} else {
times <- colnames(model$observations[[1]])
times <- as.numeric(colnames(model$observations[[1]]))
symbol_names <- model$symbol_names
obs <- create_obsArray(model)
out <- ame_obs_nhmm_multichannel(
Expand All @@ -147,7 +147,7 @@ ame_obs.nhmm <- function(
)
d <- data.frame(
observation = model$symbol_names,
time = rep(colnames(model$observations), each = model$n_symbols),
time = rep(as.numeric(colnames(model$observations)), each = model$n_symbols),
estimate = c(out$point_estimate)
)
for(i in seq_along(probs)) {
Expand Down
4 changes: 2 additions & 2 deletions R/ame_param.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ ame_param.nhmm <- function(
model2 <- update(model, newdata)
C <- model$n_channels
if (C == 1L) {
times <- colnames(model$observations)
times <- as.numeric(colnames(model$observations))
symbol_names <- list(model$symbol_names)
} else {
times <- colnames(model$observations[[1]])
times <- as.numeric(colnames(model$observations[[1]]))
symbol_names <- model$symbol_names
}
if (attr(model$X_pi, "icpt_only")) {
Expand Down
8 changes: 4 additions & 4 deletions R/get_probs.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ get_transition_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
model$X_A[attr(model$X_A, "missing")] <- NA
if (model$n_channels == 1L) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
times <- as.numeric(colnames(model$observations))
} else {
ids <- rownames(model$observations[[1]])
times <- colnames(model$observations[[1]])
times <- as.numeric(colnames(model$observations[[1]]))
}
if (!attr(model$X_A, "iv")) {
X <- model$X_A[, , 1L, drop = FALSE]
Expand Down Expand Up @@ -182,12 +182,12 @@ get_emission_probs.nhmm <- function(model, probs, remove_voids = TRUE, ...) {
model$X_B[attr(model$X_B, "missing")] <- NA
if (C == 1L) {
ids <- rownames(model$observations)
times <- colnames(model$observations)
times <- as.numeric(colnames(model$observations))
symbol_names <- list(model$symbol_names)
model$gammas$B <- list(model$gammas$B)
} else {
ids <- rownames(model$observations[[1]])
times <- colnames(model$observations[[1]])
times <- as.numeric(colnames(model$observations[[1]]))
symbol_names <- model$symbol_names
}
if (!attr(model$X_B, "iv")) {
Expand Down
12 changes: 8 additions & 4 deletions R/simulate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ simulate_mnhmm <- function(
)
sequence_lengths <- rep(sequence_lengths, length.out = n_sequences)
n_channels <- length(n_symbols)
symbol_names <- lapply(seq_len(n_channels), function(i) seq_len(n_symbols[i]))
symbol_names <- lapply(
seq_len(n_channels), function(i) {
as.character(seq_len(n_symbols[i]))
}
)
T_ <- max(sequence_lengths)
obs <- lapply(seq_len(n_channels), function(i) {
suppressWarnings(suppressMessages(
Expand Down Expand Up @@ -98,7 +102,7 @@ simulate_mnhmm <- function(
gamma_B <- c(eta_to_gamma_cube_field(unlist(model$etas$B, recursive = FALSE)))
model$gammas$B <- split(gamma_B, rep(seq_along(l), l))
}

model$gammas$omega <- eta_to_gamma_mat(
model$etas$omega
)
Expand All @@ -120,7 +124,7 @@ simulate_mnhmm <- function(
model$n_symbols
)
}

for (i in seq_len(model$n_sequences)) {
Ti <- sequence_lengths[i]
if (Ti < T_) {
Expand All @@ -143,7 +147,7 @@ simulate_mnhmm <- function(
alphabet = state_names, cnames = seq_len(T_)
)
))

if (n_channels == 1) {
dim(out$observations) <- dim(out$observations)[2:3]
out$observations[] <- symbol_names[c(out$observations) + 1]
Expand Down
6 changes: 5 additions & 1 deletion R/simulate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ simulate_nhmm <- function(
)
sequence_lengths <- rep(sequence_lengths, length.out = n_sequences)
n_channels <- length(n_symbols)
symbol_names <- lapply(seq_len(n_channels), function(i) seq_len(n_symbols[i]))
symbol_names <- lapply(
seq_len(n_channels), function(i) {
as.character(seq_len(n_symbols[i]))
}
)
T_ <- max(sequence_lengths)
obs <- lapply(seq_len(n_channels), function(i) {
suppressWarnings(suppressMessages(
Expand Down

0 comments on commit 60b5af1

Please sign in to comment.