Skip to content

Commit

Permalink
tout en vrac
Browse files Browse the repository at this point in the history
  • Loading branch information
lrnv committed May 14, 2024
1 parent d26a29b commit 4f4babf
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/NPNSEstimator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function _get_rate_predictors(rt,df)
return prd
end

function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:NPNSEstimator}
function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:Union{NPNSEstimator, Nessie}}
rate_predictors = _get_rate_predictors(rt,df)
formula_applied = apply_schema(formula,schema(df))

Expand Down
122 changes: 87 additions & 35 deletions src/Nessie.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,93 @@
function Nessie(formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable)
formula_applied = apply_schema(formula,schema(df))
rate_predictors = String.([RateTables.predictors(rt)...])

nms = StatsModels.termnames(formula_applied.rhs)
if isa(nms, String)
pred_names = [nms]
else
pred_names = nms
struct Nessie
expected_sample_size::Vector{Float64}
expected_life_time::Float64
grid::Vector{Float64}
function Nessie(T, Δ, age, year, rate_preds, ratetable)
grid = mk_grid([1,maximum(T)],1)
# grid = mk_grid(T,1)
expected_sample_size = zero(grid)
for i in eachindex(age)
# Tᵢ = searchsortedlast(grid, T[i])
Λₚ = 0.0
rtᵢ = ratetable[rate_preds[i,:]...]
for j in 1:(length(grid)-1)
λₚ = daily_hazard(rtᵢ, age[i] + grid[j], year[i] + grid[j])
∂Λₚ = λₚ * (grid[j+1]-grid[j]) # λₚ * ∂t
Λₚ += ∂Λₚ
Sₚ = exp(-Λₚ)
expected_sample_size[j] += Sₚ
end
end
expected_life_time = sum(expected_sample_size[1:(end-1)] .* diff(grid)) / length(age)

annual_indices = [searchsortedlast(grid, i) for i in (365.241 * (0:floor(maximum(T)/365.241))).+1]
return new(expected_sample_size[annual_indices], expected_life_time / 365.241, grid[annual_indices])
end
end

times = sort(unique(floor.(df.time ./ 365.241)))
times = unique([0.0; times])
"""
nessie(formula, data, ratetable)
times_d = times .* 365.241
bla bla
new_df = groupby(df, pred_names)
povp = zeros(nrow(unique(df[!,pred_names])))
sit = zeros(length(times))
num_pop = zeros(nrow(unique(df[!,pred_names])), length(times))
"""
function nessie(args...)
r = fit(Nessie,args...)
transform!(r, :estimator => ByRow(x-> (x.grid, x.expected_life_time, x.expected_sample_size)) => [:expected_sample_size,:expected_life_time, :grid])
select!(r, Not(:estimator))

for i in 1:nrow(unique(df[!,pred_names]))
for j in 1:nrow(new_df[i])
Tᵢ = searchsortedlast(times_d, new_df[i].time[j])
rate_preds = select(new_df[i],rate_predictors)
rtᵢ = rt[rate_preds[j,:]...]
Λₚ = 0.0
lt = deepcopy(r)
select!(lt, Not([:expected_sample_size, :grid]))

for m in 1:Tᵢ
λₚ = daily_hazard(rtᵢ, new_df[i].age[j] + times_d[m], new_df[i].year[j] + times_d[m])
∂Λₚ = λₚ
Λₚ += ∂Λₚ
Sₚ = exp(-Λₚ)
num_pop[i,m] += Sₚ
sit[m] += (1-Sₚ) / λₚ
end
end
povp[i] = mean(sit ./ 365.241)
end
return num_pop, povp
end
select!(r, Not(:expected_life_time))
return lt, r
end


expected_life_time(x::Nessie) = x.expected_life_time
expected_sample_size(x::Nessie) = x.expected_sample_size




# function old_Nessie(formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable)
# formula_applied = apply_schema(formula,schema(df))
# rate_predictors = String.([RateTables.predictors(rt)...])

# nms = StatsModels.termnames(formula_applied.rhs)
# if isa(nms, String)
# pred_names = [nms]
# else
# pred_names = nms
# end

# times = sort(unique(floor.(df.time ./ 365.241)))
# times = unique([0.0; times])

# times_d = times .* 365.241

# new_df = groupby(df, pred_names)
# povp = zeros(nrow(unique(df[!,pred_names])))
# sit = zeros(length(times))
# num_pop = zeros(nrow(unique(df[!,pred_names])), length(times))

# for i in 1:nrow(unique(df[!,pred_names]))
# for j in 1:nrow(new_df[i])
# Tᵢ = searchsortedlast(times_d, new_df[i].time[j])
# rate_preds = select(new_df[i],rate_predictors)
# rtᵢ = rt[rate_preds[j,:]...]
# Λₚ = 0.0

# for m in 1:Tᵢ
# λₚ = daily_hazard(rtᵢ, new_df[i].age[j] + times_d[m], new_df[i].year[j] + times_d[m])
# ∂Λₚ = λₚ * 365.241
# Λₚ += ∂Λₚ
# Sₚ = exp(-Λₚ)
# num_pop[i,m] += Sₚ
# sit[m] += (1-Sₚ) / λₚ
# end
# end
# povp[i] = mean(sit ./ 365.241)
# end
# return num_pop, povp
# end
6 changes: 4 additions & 2 deletions src/NetSurvival.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ using RateTables
include("fetch_datasets.jl")
include("Surv_and_Strata.jl")

include("Nessie.jl")

include("NPNSEstimator.jl")
include("PoharPerme.jl")
include("EdererI.jl")
Expand All @@ -22,13 +24,13 @@ include("Hakulinen.jl")

include("CrudeMortality.jl")

include("Nessie.jl")


include("GraffeoTest.jl")

export PoharPerme, EdererI, EdererII, Hakulinen
export CrudeMortality
export Nessie
export Nessie, nessie
export fit, confint
export GraffeoTest
export Surv, Strata
Expand Down

0 comments on commit 4f4babf

Please sign in to comment.