Skip to content

Commit

Permalink
Added jackknife for coxph
Browse files Browse the repository at this point in the history
  • Loading branch information
alexpkeil1 committed Sep 13, 2023
1 parent 3603f6a commit ffd380a
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 131 deletions.
2 changes: 2 additions & 0 deletions src/LSurvival.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ export aic,
modelmatrix, #PValue
fitted,
isfitted,
jackknife,
loglikelihood,
logpartiallikelihood,
lrtest, # re-exported
Expand Down Expand Up @@ -140,6 +141,7 @@ include("residuals.jl")
include("npsurvival.jl")
include("data_generators.jl")
include("bootstrap.jl")
include("jackknife.jl")
include("deprecated.jl")


Expand Down
2 changes: 2 additions & 0 deletions src/coxmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,8 @@ function StatsBase.vcov(m::M; type::Union{String,Nothing} = nothing) where {M<:A
mwarn(m)
if type == "robust"
res = robust_vcov(m)
elseif type == "robust"
res = jackknife_vcov(m)
else
res = -inv(m.P._hess)
end
Expand Down
37 changes: 27 additions & 10 deletions src/docstr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1105,16 +1105,6 @@ ft2w
"""

DOC_VCOV = """
Covariance matrix for Cox proportional hazards models
Keyword arguments
- `type` nothing or "robust": determines whether model based or robust (dfbeta based) variance is returned.
See ?residuals for info on `dfbeta` residuals
$DOC_ROBUST_VCOV
"""


DOC_RESIDUALS = """
Expand Down Expand Up @@ -1245,10 +1235,37 @@ vcov(ft, type="robust")
vcov(ft2, type="robust")
```
####################################################################
## jackknife residuals: influence of individual observations on each parameter
```@example
using LSurvival
dat1 = (
time = [1,1,6,6,8,9],
status = [1,0,1,1,0,1],
x = [1,1,1,0,0,0]
)
ft = coxph(@formula(Surv(time,status)~x),dat1, ties="breslow")
jackknife(ft)
residuals(ft, type="jackknife")
```
"""

DOC_VCOV = """
Covariance matrix for Cox proportional hazards models
Keyword arguments
- `type` nothing or "robust": determines whether model based or robust (dfbeta based) variance is returned.
See ?residuals for info on `dfbeta` residuals
$DOC_ROBUST_VCOV
"""


######## DEPRECATED FUNCTIONS ###############

"""
Expand Down
119 changes: 0 additions & 119 deletions src/indev.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,125 +9,6 @@ using LSurvival, Random, Optim, BenchmarkTools, RCall
######################################################################

#
"""
Remove the last element from an LSurvivalResp object
```julia
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
R = LSurvivalResp(int, outt, d, ID.(id)) # specification with ID only
Ri, Rj, idxi, idxj = pop(R);
```
"""
function pop(R::T) where {T<:LSurvivalResp}
uid = unique(R.id)[end]
idx = findall(getfield.(R.id, :value) .== uid.value)
nidx = setdiff(collect(eachindex(R.id)),idx)
Ri = LSurvivalResp(R.enter[idx], R.exit[idx], R.y[idx], R.wts[idx], R.id[idx]; origintime= R.origin)
Rj = LSurvivalResp(R.enter[nidx], R.exit[nidx], R.y[nidx], R.wts[nidx], R.id[nidx],origintime= R.origin)
Ri, Rj, idx, nidx
end

"""
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
P = PHParms(X)
R = LSurvivalResp(int, outt, d, ID.(id)) # specification with ID only
Ri, Rj, idxi, idxj = pop(R);
Pi = popat!(P, idxi, idxj)
"""
function popat!(P::T, idxi, idxj) where {T<:PHParms}
Pi = PHParms(P.X[idxi,:], P._B, P._r[idxi], P._LL, P._grad, P._hess, 1, P.p)
P.X = P.X[idxj,:]
P._r = P._r[idxj]
P.n -= 1
Pi
end



"""
Insert an observation into the front of an LSurvivalResp object
```julia
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
R = LSurvivalResp(int, outt, d, ID.(id)) # specification with ID only
Ri, Rj, idxi, idxj = pop(R);
R = push(Ri, Rj)
```
"""
function push(Ri::T, Rj::T) where {T<:LSurvivalResp}
Ri = LSurvivalResp(
vcat(Ri.enter, Rj.enter),
vcat(Ri.exit, Rj.exit),
vcat(Ri.y, Rj.y),
vcat(Ri.wts, Rj.wts),
vcat(Ri.id, Rj.id);
origintime= min(Ri.origin, Rj.origin))
end


"""
Insert an observation into the front of an PHParms object
"""
function push!(Pi::T, Pj::T) where {T<:PHParms}
Pj.X = vcat(Pi.X, Pj.X)
Pj._r = vcat(Pi._r, Pj._r)
Pj.n +=1
nothing
end


"""
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
m = coxph(X,int, outt,d, wts=wt)
R::Union{Nothing,G} # Survival response
P::L # parameters
formula::Union{FormulaTerm,Nothing}
ties::String
fit::Bool
bh::Matrix{Float64}
RL::Union{Nothing,Vector{Matrix{Float64}}} # residual matrix
"""
function jackknife(m::M) where {M<:PHmodel}
uid = unique(m.R.id)
coefs = zeros(length(uid), length(m.P._B))
for i in eachindex(uid)
Ri, Rj, idxi, idxj = pop(m.R);
Pi = popat!(m.P, idxi, idxj)
mi= PHModel(Rj, m.P, m.formula, m.ties, false, m.bh[1:length(Rj.eventtimes),:], nothing)
fit!(mi, getbasehaz=false)
coefs[i,:] = mi.P._B
m.R = push(Ri, Rj)
push!(Pi, m.P)
end
end
=#



Expand Down
156 changes: 156 additions & 0 deletions src/jackknife.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""
Remove the last element from an LSurvivalResp object
```julia
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
R = LSurvivalResp(int, outt, d, ID.(id)) # specification with ID only
Ri, Rj, idxi, idxj = pop(R);
```
"""
function pop(R::T) where {T<:LSurvivalResp}
uid = unique(R.id)[end]
idx = findall(getfield.(R.id, :value) .== uid.value)
nidx = setdiff(collect(eachindex(R.id)),idx)
Ri = LSurvivalResp(R.enter[idx], R.exit[idx], R.y[idx], R.wts[idx], R.id[idx]; origintime= R.origin)
Rj = LSurvivalResp(R.enter[nidx], R.exit[nidx], R.y[nidx], R.wts[nidx], R.id[nidx],origintime= R.origin)
Ri, Rj, idx, nidx
end

"""
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
P = PHParms(X)
R = LSurvivalResp(int, outt, d, ID.(id)) # specification with ID only
Ri, Rj, idxi, idxj = pop(R);
Pi = popat!(P, idxi, idxj)
"""
function popat!(P::T, idxi, idxj) where {T<:PHParms}
Pi = PHParms(P.X[idxi,:], P._B, P._r[idxi], P._LL, P._grad, P._hess, 1, P.p)
P.X = P.X[idxj,:]
P._r = P._r[idxj]
P.n -= length(idxi)
Pi
end



"""
Insert an observation into the front of an LSurvivalResp object
```julia
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
R = LSurvivalResp(int, outt, d, ID.(id)) # specification with ID only
Ri, Rj, idxi, idxj = pop(R);
R = push(Ri, Rj)
```
"""
function push(Ri::T, Rj::T) where {T<:LSurvivalResp}
Ri = LSurvivalResp(
vcat(Ri.enter, Rj.enter),
vcat(Ri.exit, Rj.exit),
vcat(Ri.y, Rj.y),
vcat(Ri.wts, Rj.wts),
vcat(Ri.id, Rj.id);
origintime= min(Ri.origin, Rj.origin))
end


"""
Insert an observation into the front of an PHParms object
"""
function push!(Pi::T, Pj::T) where {T<:PHParms}
Pj.X = vcat(Pi.X, Pj.X)
Pj._r = vcat(Pi._r, Pj._r)
Pj.n +=1
nothing
end


"""
id, int, outt, data =
LSurvival.dgm(MersenneTwister(112), 100, 10; afun = LSurvival.int_0)
data[:, 1] = round.(data[:, 1], digits = 3)
d, X = data[:, 4], data[:, 1:3]
wt = rand(length(d))
wt ./= (sum(wt) / length(wt))
m = coxph(X,int, outt,d, wts=wt, id=ID.(id))
jk = jackknife(m);
bs = bootstrap(MersenneTwister(12321), m, 1000);
N = nobs(m)
#comparing estimate with jackknife estimate with bootstrap mean
hcat(coef(m), mean(jk, dims=1)[1,:], mean(bs, dims=1)[1,:])
semb = stderror(m)
sebs = std(bs, dims=1)
sejk = std(jk, dims=1, corrected=false) .* sqrt(N-1)
sero = stderror(m, type="robust")
jackknife_vcov(m)
LSurvival.robust_vcov(m)
hcat(semb, sebs[1,:], sejk[1,:], sero)
dat1 = (time = [1, 1, 6, 6, 8, 9], status = [1, 0, 1, 1, 0, 1], x = [1, 1, 1, 0, 0, 0])
dat1clust = (
id = [1, 2, 3, 3, 4, 4, 5, 5, 6, 6],
enter = [0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
exit = [1, 1, 1, 6, 1, 6, 1, 8, 1, 9],
status = [1, 0, 0, 1, 0, 1, 0, 0, 0, 1],
x = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
)
m = coxph(@formula(Surv(time, status)~x),dat1)
mc = coxph(@formula(Surv(enter, exit, status)~x),dat1clust, id=ID.(dat1clust.id))
jk = jackknife(m);
jkc = jackknife(mc);
bs = bootstrap(mc, 100);
std(bs[:,1])
std(jkc[:,1])
stderror(mc)
@assert jk == jkc
"""
function jackknife(m::M) where {M<:PHModel}
uid = unique(m.R.id)
coefs = zeros(length(uid), length(coef(m)))
R = deepcopy(m.R)
P = deepcopy(m.P)
for i in eachindex(uid)
Ri, Rj, idxi, idxj = pop(m.R);
Pi = popat!(m.P, idxi, idxj)
mi= PHModel(Rj, m.P, m.formula, m.ties, false, m.bh[1:length(Rj.eventtimes),:], nothing)
fit!(mi, getbasehaz=false)
coefs[i,:] = coef(mi)
m.R = push(Ri, Rj)
push!(Pi, m.P)
end
m.R, m.P = R, P
coefs
end


function jackknife_vcov(m::M) where {M<:PHModel}
N = nobs(m)
#comparing estimate with jackknife estimate with bootstrap mean
jk = jackknife(m);
covjk = cov(jk) .* (N-1)
covjk
end

9 changes: 8 additions & 1 deletion src/residuals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ $DOC_RESIDUALS
"""
function StatsBase.residuals(m::M; type = "martingale") where {M<:PHModel}
valid_methods =
["schoenfeld", "score", "martingale", "dfbeta", "dfbetas", "scaled_schoenfeld"]
["schoenfeld", "score", "martingale", "dfbeta", "dfbetas", "scaled_schoenfeld", "jackknife"]
whichmethod = findall(valid_methods .== lowercase(type))
thismethod = valid_methods[whichmethod][1]
if thismethod == "martingale"
Expand All @@ -22,6 +22,8 @@ function StatsBase.residuals(m::M; type = "martingale") where {M<:PHModel}
resid ./= stderror(m)'
elseif thismethod == "scaled_schoenfeld"
resid = resid_schoenfeld(m) * inv(m.P._hess)
elseif thismethod == "jackknife"
resid = resid_jackknife(m)
else
throw("Method $type not supported yet")
end
Expand Down Expand Up @@ -136,6 +138,11 @@ function resid_dfbeta(m::M) where {M<:PHModel}
return dfbeta .* m.R.wts
end

function resid_jackknife(m::M) where {M<:PHModel}
jk = jackknife(m);
permutedims(jk' .- coef(m))
end

"""
$DOC_ROBUST_VCOV
"""
Expand Down
Loading

0 comments on commit ffd380a

Please sign in to comment.