Skip to content

Commit

Permalink
push most current multi state version
Browse files Browse the repository at this point in the history
  • Loading branch information
johannespiller committed Aug 8, 2024
1 parent 30099bd commit d01c17c
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 113 deletions.
4 changes: 2 additions & 2 deletions R/add-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -904,10 +904,10 @@ add_trans_prob <- function(
if (ci) {
newdata <- newdata |>
add_trans_ci(object) |>
add_cumu_hazard(object)
add_cumu_hazard(object, overwrite = T)
} else {
newdata <- newdata |>
add_cumu_hazard(object)
add_cumu_hazard(object, overwrite = T)
}


Expand Down
270 changes: 159 additions & 111 deletions vignettes/multi-state.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,24 @@ For illustration, we use the `prothr` data set from the **`mstate`** package.

```{r, echo = TRUE}
library(survival)
library(mgcv)
library(ggplot2)
library(pammtools)
library(purrr)
library(mstate)
data(prothr, package = "mstate")
prothr |> filter(id == 46) |> knitr::kable() # example patients
```

```{r, echo = FALSE}
data("prothr", package = "mstate")
prothr <- prothr %>%
rename(tstart = Tstart, tstop = Tstop) %>%
filter(tstart != tstop)
prothr <- prothr |>
mutate(transition = as.factor(paste0(from, "->", to))
, treat = as.factor(treat)) |>
rename(tstart = Tstart, tstop = Tstop) |>
filter(tstart != tstop) |>
select(-trans)
```

In general, one has to follow three steps to derive transition probabilities from multi-state survival data.
Expand All @@ -72,11 +78,14 @@ Transforming the survival data `prothr` into piecewise exponential data, we can
# source("C:/Users/ra63liw/Documents/98_git/pammtools-multi-state/pammtools/tmp/add_transition_probabilities.R")
library(checkmate)
library(tidyverse)
library(dplyr)
```
```{r, echo = TRUE, dependson=c("lib-ms-pammtools")}
my.prothr <- prothr |> filter(status == 1) |> add_counterfactual_transitions() # add possible transitions
# not necessary, prothr already contains all possible transitions
# my.prothr <- prothr |> add_counterfactual_transitions() # add possible transitions
ped <- as_ped_multistate(
data = my.prothr,
data = prothr,
formula = Surv(tstart, tstop, status)~ .,
transition = "transition",
id = "id",
Expand All @@ -91,6 +100,12 @@ where `add_counterfactual_transitions` is a helper function, which adds all poss
Estimating the log hazard structure using PAM objects, we can use
```{r echo = TRUE}
pam <- pamm(ped_status ~ s(tend, by=transition) + transition * treat, data = ped)
pam <- bam(ped_status ~ s(tend, by=transition) + transition * treat
, data = ped
, family = poisson()
, offset = offset
, method = "fREML"
, discrete = TRUE)
summary(pam)
```

Expand All @@ -102,9 +117,12 @@ Post-processing the data to include all relevant objects of interest in our data
ndf <- make_newdata(ped, tend = unique(tend), treat = unique(treat), transition = unique(transition))
ndf <- ndf |>
group_by(treat, transition) |> # important!
# arrange(treat, transition, tend) |>
# add_trans_ci(pam) |>
add_trans_prob(pam)
arrange(treat, transition, tend) |>
add_trans_prob(pam, ci=TRUE)
ndf <- ndf |>
group_by(treat, transition) |> # important!
add_cumu_hazard(pam, overwrite = T)
```
where `make_newdata` creates a data set containing all covariates and all their combinations from the PAM object. The convenience function `add_trans_prob` creates a new column `trans_prob`, which can be visualized.

Expand All @@ -114,18 +132,34 @@ ggplot(ndf, aes(x=tend)) +
geom_line(aes(y=trans_prob, col=treat)) +
# geom_ribbon(aes(ymin = trans_lower, ymax = trans_upper, fill=treat), alpha = .3) +
facet_wrap(~transition) +
xlim(c(0, 4000)) +
ylim(c(0,1))+
labs(y = "Transition Probability", x = "time", color = "Treatment", fill= "Treatment")
```



## Comparison of the results with the Aalen-Johansen estimator
```{r, echo = FALSE, warning = FALSE, dependson=c("lib-ms-pammtools")}

Comparing the ``pammtools`` results with the ``mstate`` results, we want to validate that the baselines are indeed correct. The following code shows the comparison between ``pammtools`` and ``mstate``.

First, we compare the cumulative hazards
```{r, echo = TRUE, warning = FALSE, dependson=c("lib-ms-pammtools")}
# pammtools
ndf <- ndf |>
group_by(treat, transition) |> # important!
add_cumu_hazard(pam, overwrite = T) |>
mutate(package = "pammtools")
```
```{r prothr-prep-mstate, eval=TRUE, echo = FALSE, warning = FALSE, dependson=c("lib-ms-pammtools")}
library(mstate)
library(msm)
library(mvna)
library(etm)
# code from mstate documentation
data(prothr)
data(prothr, package = "mstate")
tmat <- attr(prothr, "trans")
pr0 <- subset(prothr, treat=="Placebo")
attr(pr0, "trans") <- tmat
Expand All @@ -138,26 +172,6 @@ c1 <- coxph(Surv(Tstart, Tstop, status) ~ strata(trans), data=pr1)
msf0 <- msfit(c0, trans=tmat)
msf1 <- msfit(c1, trans=tmat)
# Comparison as in Figure 2 of Titman (2015)
# Aalen-Johansen
pt0 <- probtrans(msf0, predt=0)[[2]] # changed predt from 1000 to 0
pt1 <- probtrans(msf1, predt=0)[[2]] # changed predt from 1000 to 0
par(mfrow=c(1,2))
plot(pt0$time, pt0$pstate1, type="s", lwd=2, xlim=c(0,4000), ylim=c(0,1),
xlab="Time since randomisation (days)", ylab="Probability")
lines(pt1$time, pt1$pstate1, type="s", lwd=2, lty=3)
legend("topright", c("Placebo", "Prednisone"), lwd=2, lty=1:2, bty="n")
title(main="Aalen-Johansen")
# Landmark Aalen-Johansen
LMpt0 <- LMAJ(msdata=pr0, s=0, from=2) # changed predt from 1000 to 0
LMpt1 <- LMAJ(msdata=pr1, s=0, from=2) # changed predt from 1000 to 0
plot(LMpt0$time, LMpt0$pstate1, type="s", lwd=2, xlim=c(0,4000), ylim=c(0,1),
xlab="Time since randomisation (days)", ylab="Probability")
lines(LMpt1$time, LMpt1$pstate1, type="s", lwd=2, lty=3)
legend("topright", c("Placebo", "Prednisone"), lwd=2, lty=1:2, bty="n")
title(main="Landmark Aalen-Johansen")
par(mfrow=c(1,1))
# plot hazards
mstate_dat0 <- msf0$Haz %>% mutate(transition = case_when(
trans == 1 ~ "1->2",
Expand All @@ -178,98 +192,132 @@ mstate_dat1 <- msf1$Haz %>% mutate(transition = case_when(
mstate_dat <- rbind(mstate_dat0, mstate_dat1)
long_mstate <- mstate_dat %>%
rename(tend = time, cumu_hazard = Haz) %>%
mutate(package = "mstate") %>%
select(tend, treat, transition, cumu_hazard, package)
long_msm <- ndf %>%
mutate(package = "pammtools") %>%
select(tend, treat, transition, cumu_hazard, package)
long_haz_df <- rbind(long_mstate, long_msm) %>%
mutate(transition = case_when(transition == "1->2" ~ "0->1"
, transition == "1->3" ~ "0->2"
, transition == "2->1" ~ "1->0"
, transition == "2->3" ~ "1->2")
)
ggplot(data = mstate_dat, aes(x=time, y=Haz)) +
geom_line(aes(col=treat)) +
facet_wrap(~transition, ncol = 4, labeller = label_both)
# data analysis / preparation
head(prothr)
table(prothr$from, prothr$to)
# transitions 1->2, 1->3, 2->1, 2->3 possible.
# classical multi-state setup with recurrent events.
# try analysis with etm
# transform days in fractions of year
data <- prothr %>% mutate(Tstart = Tstart / 365.25,
Tstop = Tstop /365.25,
from = from - 1,
to = to - 1)
head(data)
dim(data)
# estimate transition probabilities
my.data <- prothr %>%
mutate(from = from -1, to = to-1) %>%
rename(entry = Tstart, exit = Tstop) %>%
arrange(id, entry, exit) %>%
filter(!(entry == exit)) %>%
distinct()
# build transition matrix for mvna
tra <- matrix(FALSE, ncol = 3, nrow = 3)
dimnames(tra) <- list(c("0", "1", "2"), c("0", "1", "2"))
tra[1, 2:3] <- TRUE
tra[2, c(1,3)] <- TRUE
# print transition matrix
tra
my.data0 <- my.data %>% filter(treat=="Placebo")
my.data1 <- my.data %>% filter(treat=="Prednisone")
etm.prothr0 <- etm(my.data0, c("0", "1", "2"), tra, "cens", s = 0)
etm.prothr1 <- etm(my.data1, c("0", "1", "2"), tra, "cens", s = 0)
mstate_trans_dat0 <- rbind(cbind(tend = etm.prothr0$time, trans_prob = etm.prothr0$est[1,2,], trans = 1)
, tend = cbind(etm.prothr0$time, trans_prob = etm.prothr0$est[1,3,], trans = 2)
, tend = cbind(etm.prothr0$time, trans_prob = etm.prothr0$est[2,1,], trans = 3)
, tend = cbind(etm.prothr0$time, trans_prob = etm.prothr0$est[2,3,], trans = 4))
mstate_trans_dat0 <- data.frame(mstate_trans_dat0) |> mutate(transition = case_when(
trans == 1 ~ "1->2",
trans == 2 ~ "1->3",
trans == 3 ~ "2->1",
trans == 4 ~ "2->3",
.default = "cens"
)
, treat = "Placebo")
my.data <- prothr %>%
mutate(from = from -1, to = to-1) %>%
rename(entry = Tstart, exit = Tstop) %>%
arrange(id, entry, exit) %>%
filter(!(entry == exit)) %>%
distinct()
dim(my.data)
head(my.data)
my.data %>% filter(tstart== tstop)
table(my.data$from, my.data$to)
# calculate cause specific hazards
my.nelaal <- mvna(my.data, c("0", "1", "2"), tra, "cens")
if (require(lattice)){
xyplot(my.nelaal
, strip=strip.custom(bg="white")
, ylab="Cumulative Hazard"
, lwd=2
# , xlim=c(0,5)
# , ylim=c(0,10)
)
}
# differentiate between placebo and not
my.data.placebo <- my.data %>% filter(treat=="Placebo")
table(my.data.placebo$from, my.data.placebo$to)
my.nelaal <- mvna(my.data.placebo, c("0", "1", "2"), tra, "cens")
if (require(lattice)){
xyplot(my.nelaal
, strip=strip.custom(bg="white")
, ylab="Cumulative Hazard"
, lwd=2
# , xlim=c(0,5)
# , ylim=c(0,10)
mstate_trans_dat1 <- rbind(cbind(tend = etm.prothr1$time, trans_prob = etm.prothr1$est[1,2,], trans = 1)
, tend = cbind(etm.prothr1$time, trans_prob = etm.prothr1$est[1,3,], trans = 2)
, tend = cbind(etm.prothr1$time, trans_prob = etm.prothr1$est[2,1,], trans = 3)
, tend = cbind(etm.prothr1$time, trans_prob = etm.prothr1$est[2,3,], trans = 4))
mstate_trans_dat1 <- data.frame(mstate_trans_dat1) |> mutate(transition = case_when(
trans == 1 ~ "1->2",
trans == 2 ~ "1->3",
trans == 3 ~ "2->1",
trans == 4 ~ "2->3",
.default = "cens"
)
}
, treat = "Prednisone")
long_mstate_trans <- rbind(mstate_trans_dat0, mstate_trans_dat1) |> mutate(package = "mstate") |> select(-trans)
# estimate transition probabilities
etm.prothr <- etm(my.data, c("0", "1", "2"), tra, "cens", s = 0)
par(mfrow=c(2,2))
plot(etm.prothr, tr.choice = "0 1", conf.int = TRUE,
lwd = 2, legend = FALSE, ylim = c(0, 1),
xlim = c(0, 4000), xlab = "Days",
ci.fun = "cloglog")
plot(etm.prothr, tr.choice = "0 2", conf.int = TRUE,
lwd = 2, legend = FALSE, ylim = c(0, 1),
xlim = c(0, 4000), xlab = "Days",
ci.fun = "cloglog")
plot(etm.prothr, tr.choice = "1 0", conf.int = TRUE,
lwd = 2, legend = FALSE, ylim = c(0, 1),
xlim = c(0, 4000), xlab = "Days",
ci.fun = "cloglog")
plot(etm.prothr, tr.choice = "1 2", conf.int = TRUE,
lwd = 2, legend = FALSE, ylim = c(0, 1),
xlim = c(0, 4000), xlab = "Days",
ci.fun = "cloglog")
par(mfrow=c(1,1))
# plot hazards
long_msm_trans <- ndf %>%
mutate(package = "pammtools") %>%
select(tend, treat, transition, trans_prob, package)
long_trans_df <- rbind(long_mstate_trans, long_msm_trans)
```

```{r, eval=FALSE, echo = FALSE, fig.width = 8, fig.height = 4, out.width = "600px", dependson=c("prothr-prep-mstate")}
# # plot transitions
# ggplot(test_msm, aes(x=tend, y=trans_prob)) +
# geom_line(aes(col=as.factor(treat))) +
# facet_wrap(~transition, ncol = 2, labeller = label_both) +
# # scale_color_manual(values = c("#1f78b4", "#1f78b4", "#33a02c", "#33a02c"))+
# # scale_linetype_manual(values = c("solid", "dashed", "solid", "dashed")) +
# ylim(c(0,0.8)) +
# xlim(c(0, 4000)) +
# ylab("Transition Probability") +
# xlab("time") +
# theme_bw()
comparison_aaljoh <- ggplot(long_trans_df, aes(x=tend, y=trans_prob, col=treat, linetype = package)) +
geom_line() +
facet_wrap(~transition, ncol = 4, labeller = label_both) +
scale_color_manual(values = c("firebrick2"
, "steelblue")
)+
# scale_linetype_manual(values = c("solid", "dashed", "solid", "dashed")) +18:
ylab("Transition Probabilities") +
xlab("time in days") +
ylim(c(0,1)) +
scale_linetype_manual(values=c("dotted", "solid")) +
theme_bw()
comparison_aaljoh
```

```{r, eval=TRUE, echo = TRUE, fig.width = 8, fig.height = 4, out.width = "600px", dependson=c("prothr-prep-mstate")}
# # plot transitions
# ggplot(test_msm, aes(x=tend, y=trans_prob)) +
# geom_line(aes(col=as.factor(treat))) +
# facet_wrap(~transition, ncol = 2, labeller = label_both) +
# # scale_color_manual(values = c("#1f78b4", "#1f78b4", "#33a02c", "#33a02c"))+
# # scale_linetype_manual(values = c("solid", "dashed", "solid", "dashed")) +
# ylim(c(0,0.8)) +
# xlim(c(0, 4000)) +
# ylab("Transition Probability") +
# xlab("time") +
# theme_bw()
comparison_nelaal <- ggplot(long_haz_df, aes(x=tend, y=cumu_hazard, col=treat, linetype = package)) +
geom_line() +
facet_wrap(~transition, ncol = 4, labeller = label_both) +
scale_color_manual(values = c("firebrick2"
, "steelblue")
)+
# scale_linetype_manual(values = c("solid", "dashed", "solid", "dashed")) +18:
ylab("Cumulative Hazards") +
xlab("time in days") +
ylim(c(0,6)) +
scale_linetype_manual(values=c("dotted", "solid")) +
theme_bw()
comparison_nelaal
```

0 comments on commit d01c17c

Please sign in to comment.