Skip to content

Commit

Permalink
add pseudocount option
Browse files (browse the repository at this point in the history)
  • Loading branch information
helske committed Nov 9, 2024
1 parent c0c18b6 commit 27dd4f6
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 45 deletions.
9 changes: 6 additions & 3 deletions R/fit_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#'
#' @noRd
fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
save_all_solutions = FALSE,
control_restart = list(), control_mstep = list(), ...) {
save_all_solutions = FALSE, control_restart = list(),
control_mstep = list(), ...) {
stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
"Argument {.arg restarts} must be a single integer."
Expand All @@ -21,7 +21,10 @@ fit_mnhmm <- function(model, inits, init_sd, restarts, lambda, method,
list(...)
)
control_restart <- utils::modifyList(control, control_restart)
control_mstep <- utils::modifyList(control, control_mstep)
control_mstep <- utils::modifyList(
c(control, list(pseudocount = 0)),
control_mstep
)

M <- model$n_symbols
S <- model$n_states
Expand Down
26 changes: 16 additions & 10 deletions R/fit_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#'
#' @noRd
fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
save_all_solutions = FALSE,
control_restart = list(), control_mstep = list(), ...) {
save_all_solutions = FALSE, control_restart = list(),
control_mstep = list(), ...) {

stopifnot_(
checkmate::test_int(x = restarts, lower = 0L),
Expand All @@ -22,7 +22,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
list(...)
)
control_restart <- utils::modifyList(control, control_restart)
control_mstep <- utils::modifyList(control, control_mstep)
control_mstep <- utils::modifyList(
c(control, list(pseudocount = 0)),
control_mstep
)

M <- model$n_symbols
S <- model$n_states
Expand Down Expand Up @@ -149,7 +152,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
}
}
}

start_time <- proc.time()
if (restarts > 0L) {

Expand Down Expand Up @@ -212,7 +215,10 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
all_solutions = all_solutions,
time = end_time - start_time
)
} else {
model$estimation_results$lambda <- lambda
return(model)
}
if (method == "EM") {
start_time <- proc.time()
if (restarts > 0L) {
out <- future.apply::future_lapply(seq_len(restarts), function(i) {
Expand Down Expand Up @@ -268,7 +274,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda)
control_mstep$print_level, lambda, control_mstep$pseudocount)
} else {
out <- EM_LBFGS_nhmm_multichannel(
init$pi, model$X_pi, init$A, model$X_A, init$B, model$X_B, obs,
Expand All @@ -279,7 +285,7 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
control$print_level, control_mstep$maxeval,
control_mstep$ftol_abs, control_mstep$ftol_rel,
control_mstep$xtol_abs, control_mstep$xtol_rel,
control_mstep$print_level, lambda)
control_mstep$print_level, lambda, control_mstep$pseudocount)
}
end_time <- proc.time()
# if (out$status < 0) {
Expand Down Expand Up @@ -309,8 +315,8 @@ fit_nhmm <- function(model, inits, init_sd, restarts, lambda, method,
f_abs_change = out$absolute_f_change,
x_rel_change = out$relative_x_change,
x_abs_change = out$absolute_x_change
)
)
model$estimation_results$lambda <- lambda
return(model)
}
model$estimation_results$lambda <- lambda
model
}
12 changes: 8 additions & 4 deletions R/simulate_mnhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ simulate_mnhmm <- function(
sequence_lengths <- rep(sequence_lengths, length.out = n_sequences)
n_channels <- length(n_symbols)
symbol_names <- lapply(seq_len(n_channels), function(i) seq_len(n_symbols[i]))
T_ <- max(sequence_lengths)
obs <- lapply(seq_len(n_channels), function(i) {
suppressWarnings(suppressMessages(
seqdef(matrix(symbol_names[[i]][1], n_sequences, max(sequence_lengths)),
seqdef(matrix(symbol_names[[i]][1], n_sequences, T_),
alphabet = symbol_names[[i]]
)))
})
Expand Down Expand Up @@ -139,20 +140,23 @@ simulate_mnhmm <- function(
t(out$states),
n_sequences, max(sequence_lengths)
),
alphabet = state_names
alphabet = state_names, cnames = seq_len(T_)
)
))

if (n_channels == 1) {
dim(out$observations) <- dim(out$observations)[2:3]
out$observations[] <- symbol_names[c(out$observations) + 1]
colnames(out$observations) <- seq_len(T_)
model$observations <- suppressWarnings(suppressMessages(
seqdef(t(out$observations), alphabet = symbol_names)
seqdef(t(out$observations), alphabet = symbol_names, cnames = seq_len(T_))
))
} else {
model$observations <- lapply(seq_len(n_channels), function(i) {
out$observations[i, , ] <- symbol_names[[i]][c(out$observations[i, , ]) + 1]
suppressWarnings(suppressMessages(
seqdef(t(out$observations[i, , ]), alphabet = symbol_names[[i]])
seqdef(t(out$observations[i, , ]), alphabet = symbol_names[[i]],
cnames = seq_len(T_))
))
})
names(model$observations) <- model$channel_names
Expand Down
12 changes: 7 additions & 5 deletions R/simulate_nhmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ simulate_nhmm <- function(
sequence_lengths <- rep(sequence_lengths, length.out = n_sequences)
n_channels <- length(n_symbols)
symbol_names <- lapply(seq_len(n_channels), function(i) seq_len(n_symbols[i]))
T_ <- max(sequence_lengths)
obs <- lapply(seq_len(n_channels), function(i) {
suppressWarnings(suppressMessages(
seqdef(matrix(symbol_names[[i]][1], n_sequences, max(sequence_lengths)),
seqdef(matrix(symbol_names[[i]][1], n_sequences, T_),
alphabet = symbol_names[[i]]
)))
})
Expand Down Expand Up @@ -94,7 +95,6 @@ simulate_nhmm <- function(
model$n_symbols
)
}
T_ <- model$length_of_sequences
for (i in seq_len(model$n_sequences)) {
Ti <- sequence_lengths[i]
if (Ti < T_) {
Expand All @@ -111,20 +111,22 @@ simulate_nhmm <- function(
t(out$states),
n_sequences, max(sequence_lengths)
),
alphabet = state_names
alphabet = state_names, cnames = seq_len(T_)
)
))

if (n_channels == 1) {
dim(out$observations) <- dim(out$observations)[2:3]
out$observations[] <- symbol_names[c(out$observations) + 1]
model$observations <- suppressWarnings(suppressMessages(
seqdef(t(out$observations), alphabet = symbol_names)
seqdef(t(out$observations), alphabet = symbol_names, cnames = seq_len(T_))
))
} else {
model$observations <- lapply(seq_len(n_channels), function(i) {
out$observations[i, , ] <- symbol_names[[i]][c(out$observations[i, , ]) + 1]
suppressWarnings(suppressMessages(
seqdef(t(out$observations[i, , ]), alphabet = symbol_names[[i]])
seqdef(t(out$observations[i, , ]), alphabet = symbol_names[[i]],
cnames = seq_len(T_))
))
})
names(model$observations) <- model$channel_names
Expand Down
37 changes: 22 additions & 15 deletions src/nhmm_EM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ double nhmm_base::objective_pi(const arma::vec& x, arma::vec& grad) {
}
void nhmm_base::mstep_pi(const double xtol_abs, const double ftol_abs,
const double xtol_rel, const double ftol_rel,
const arma::uword maxeval, const arma::uword print_level) {
const arma::uword maxeval,
const arma::uword print_level) {

// Use closed form solution
if (icpt_only_pi && lambda < 1e-12) {
Expand Down Expand Up @@ -493,7 +494,7 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
const double xtol_abs, const double xtol_rel, const arma::uword print_level,
const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m,
const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m,
const double lambda) {
const double lambda, const double pseudocount) {

nhmm_sc model(
eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A,
Expand Down Expand Up @@ -546,9 +547,9 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(

double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1));
ll += ll_i;
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i);
model.estep_A(i, log_alpha, log_beta, ll_i);
model.estep_B(i, log_alpha, log_beta, ll_i);
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount);
model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount);
model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount);
}
double penalty_term = 0.5 * lambda * arma::dot(pars, pars);
ll -= penalty_term;
Expand Down Expand Up @@ -601,9 +602,9 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
);
double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1));
ll_new += ll_i;
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i);
model.estep_A(i, log_alpha, log_beta, ll_i);
model.estep_B(i, log_alpha, log_beta, ll_i);
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount);
model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount);
model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount);
}

pars_new.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t();
Expand Down Expand Up @@ -632,6 +633,9 @@ Rcpp::List EM_LBFGS_nhmm_singlechannel(
}
ll = ll_new;
pars = pars_new;
if (absolute_change < -1e6) {
Rcpp::warning("EM algorithm encountered decreasing log-likelihood.");
}
}

return Rcpp::List::create(
Expand Down Expand Up @@ -662,7 +666,7 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
const double xtol_abs, const double xtol_rel, const arma::uword print_level,
const arma::uword maxeval_m, const double ftol_abs_m, const double ftol_rel_m,
const double xtol_abs_m, const double xtol_rel_m, const arma::uword print_level_m,
const double lambda) {
const double lambda, const double pseudocount) {

nhmm_mc model(
eta_A.n_slices, X_pi, X_A, X_B, Ti, icpt_only_pi, icpt_only_A,
Expand Down Expand Up @@ -718,9 +722,9 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
);
double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1));
ll += ll_i;
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i);
model.estep_A(i, log_alpha, log_beta, ll_i);
model.estep_B(i, log_alpha, log_beta, ll_i);
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount);
model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount);
model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount);
}
double penalty_term = 0.5 * lambda * arma::dot(pars, pars);
ll -= penalty_term;
Expand Down Expand Up @@ -773,9 +777,9 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
);
double ll_i = logSumExp(log_alpha.col(model.Ti(i) - 1));
ll_new += ll_i;
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i);
model.estep_A(i, log_alpha, log_beta, ll_i);
model.estep_B(i, log_alpha, log_beta, ll_i);
model.estep_pi(i, log_alpha.col(0), log_beta.col(0), ll_i, pseudocount);
model.estep_A(i, log_alpha, log_beta, ll_i, pseudocount);
model.estep_B(i, log_alpha, log_beta, ll_i, pseudocount);
}
pars_new.cols(0, n_pi - 1) = arma::vectorise(model.eta_pi).t();
pars_new.cols(n_pi, n_pi + n_A - 1) = arma::vectorise(model.eta_A).t();
Expand Down Expand Up @@ -807,6 +811,9 @@ Rcpp::List EM_LBFGS_nhmm_multichannel(
}
ll = ll_new;
pars = pars_new;
if (absolute_change < -1e6) {
Rcpp::warning("EM algorithm encountered decreasing log-likelihood.");
}
}

return Rcpp::List::create(
Expand Down
10 changes: 6 additions & 4 deletions src/nhmm_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,18 @@ struct nhmm_base {
}

void estep_pi(const arma::uword i, const arma::vec& log_alpha,
const arma::vec& log_beta, const double ll) {
E_Pi.col(i) = arma::exp(log_alpha + log_beta - ll);
const arma::vec& log_beta, const double ll,
const double pseudocount = 0) {
E_Pi.col(i) = arma::exp(log_alpha + log_beta - ll) + pseudocount;
}
void estep_A(const arma::uword i, const arma::mat& log_alpha,
const arma::mat& log_beta, const double ll) {
const arma::mat& log_beta, const double ll,
const double pseudocount = 0) {
for (arma::uword k = 0; k < S; k++) { // from
for (arma::uword j = 0; j < S; j++) { // to
for (arma::uword t = 0; t < (Ti(i) - 1); t++) { // time
E_A(k)(j, i, t) = exp(log_alpha(k, t) + log_A(k, j, t) +
log_beta(j, t + 1) + log_py(j, t + 1) - ll);
log_beta(j, t + 1) + log_py(j, t + 1) - ll) + pseudocount;
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/nhmm_mc.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,14 @@ struct nhmm_mc : public nhmm_base {
}
}
void estep_B(const arma::uword i, const arma::mat& log_alpha,
const arma::mat& log_beta, const double ll) {
const arma::mat& log_beta, const double ll,
const double pseudocount = 0) {
for (arma::uword k = 0; k < S; k++) { // state
for (arma::uword t = 0; t < Ti(i); t++) { // time
double pp = exp(log_alpha(k, t) + log_beta(k, t) - ll);
for (arma::uword c = 0; c < C; c++) { // channel
if (obs(c, t, i) < M(c)) {
E_B(c)(t, i, k) = pp;
E_B(c)(t, i, k) = pp + pseudocount;
} else {
E_B(c)(t, i, k) = 0.0;
}
Expand Down
5 changes: 3 additions & 2 deletions src/nhmm_sc.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ struct nhmm_sc : public nhmm_base {
}
}
void estep_B(const arma::uword i, const arma::mat& log_alpha,
const arma::mat& log_beta, const double ll) {
const arma::mat& log_beta, const double ll,
const double pseudocount = 0) {
for (arma::uword k = 0; k < S; k++) { // state
for (arma::uword t = 0; t < Ti(i); t++) { // time
if (obs(t, i) < M) {
E_B(t, i, k) = exp(log_alpha(k, t) + log_beta(k, t) - ll);
E_B(t, i, k) = exp(log_alpha(k, t) + log_beta(k, t) - ll) + pseudocount;
} else {
E_B(t, i, k) = 0.0;
}
Expand Down

0 comments on commit 27dd4f6

Please sign in to comment.