diff --git a/src/NPNSEstimator.jl b/src/NPNSEstimator.jl index b61e96c..7206807 100644 --- a/src/NPNSEstimator.jl +++ b/src/NPNSEstimator.jl @@ -39,7 +39,7 @@ function _get_rate_predictors(rt,df) return prd end -function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:NPNSEstimator} +function StatsBase.fit(::Type{E}, formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) where {E<:Union{NPNSEstimator, Nessie}} rate_predictors = _get_rate_predictors(rt,df) formula_applied = apply_schema(formula,schema(df)) diff --git a/src/Nessie.jl b/src/Nessie.jl index 3ba5be2..86f917b 100644 --- a/src/Nessie.jl +++ b/src/Nessie.jl @@ -1,41 +1,93 @@ -function Nessie(formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) - formula_applied = apply_schema(formula,schema(df)) - rate_predictors = String.([RateTables.predictors(rt)...]) - - nms = StatsModels.termnames(formula_applied.rhs) - if isa(nms, String) - pred_names = [nms] - else - pred_names = nms +struct Nessie + expected_sample_size::Vector{Float64} + expected_life_time::Float64 + grid::Vector{Float64} + function Nessie(T, Δ, age, year, rate_preds, ratetable) + grid = mk_grid([1,maximum(T)],1) + # grid = mk_grid(T,1) + expected_sample_size = zero(grid) + for i in eachindex(age) + # Tᵢ = searchsortedlast(grid, T[i]) + Λₚ = 0.0 + rtᵢ = ratetable[rate_preds[i,:]...] + for j in 1:(length(grid)-1) + λₚ = daily_hazard(rtᵢ, age[i] + grid[j], year[i] + grid[j]) + ∂Λₚ = λₚ * (grid[j+1]-grid[j]) # λₚ * ∂t + Λₚ += ∂Λₚ + Sₚ = exp(-Λₚ) + expected_sample_size[j] += Sₚ + end + end + expected_life_time = sum(expected_sample_size[1:(end-1)] .* diff(grid)) / length(age) + + annual_indices = [searchsortedlast(grid, i) for i in (365.241 * (0:floor(maximum(T)/365.241))).+1] + return new(expected_sample_size[annual_indices], expected_life_time / 365.241, grid[annual_indices]) end +end - times = sort(unique(floor.(df.time ./ 365.241))) - times = unique([0.0; times]) +""" + nessie(formula, data, ratetable) - times_d = times .* 365.241 +bla bla - new_df = groupby(df, pred_names) - povp = zeros(nrow(unique(df[!,pred_names]))) - sit = zeros(length(times)) - num_pop = zeros(nrow(unique(df[!,pred_names])), length(times)) +""" +function nessie(args...) + r = fit(Nessie,args...) + transform!(r, :estimator => ByRow(x-> (x.grid, x.expected_life_time, x.expected_sample_size)) => [:expected_sample_size,:expected_life_time, :grid]) + select!(r, Not(:estimator)) - for i in 1:nrow(unique(df[!,pred_names])) - for j in 1:nrow(new_df[i]) - Tᵢ = searchsortedlast(times_d, new_df[i].time[j]) - rate_preds = select(new_df[i],rate_predictors) - rtᵢ = rt[rate_preds[j,:]...] - Λₚ = 0.0 + lt = deepcopy(r) + select!(lt, Not([:expected_sample_size, :grid])) - for m in 1:Tᵢ - λₚ = daily_hazard(rtᵢ, new_df[i].age[j] + times_d[m], new_df[i].year[j] + times_d[m]) - ∂Λₚ = λₚ - Λₚ += ∂Λₚ - Sₚ = exp(-Λₚ) - num_pop[i,m] += Sₚ - sit[m] += (1-Sₚ) / λₚ - end - end - povp[i] = mean(sit ./ 365.241) - end - return num_pop, povp -end \ No newline at end of file + select!(r, Not(:expected_life_time)) + return lt, r +end + + +expected_life_time(x::Nessie) = x.expected_life_time +expected_sample_size(x::Nessie) = x.expected_sample_size + + + + +# function old_Nessie(formula::FormulaTerm, df::DataFrame, rt::RateTables.AbstractRateTable) +# formula_applied = apply_schema(formula,schema(df)) +# rate_predictors = String.([RateTables.predictors(rt)...]) + +# nms = StatsModels.termnames(formula_applied.rhs) +# if isa(nms, String) +# pred_names = [nms] +# else +# pred_names = nms +# end + +# times = sort(unique(floor.(df.time ./ 365.241))) +# times = unique([0.0; times]) + +# times_d = times .* 365.241 + +# new_df = groupby(df, pred_names) +# povp = zeros(nrow(unique(df[!,pred_names]))) +# sit = zeros(length(times)) +# num_pop = zeros(nrow(unique(df[!,pred_names])), length(times)) + +# for i in 1:nrow(unique(df[!,pred_names])) +# for j in 1:nrow(new_df[i]) +# Tᵢ = searchsortedlast(times_d, new_df[i].time[j]) +# rate_preds = select(new_df[i],rate_predictors) +# rtᵢ = rt[rate_preds[j,:]...] +# Λₚ = 0.0 + +# for m in 1:Tᵢ +# λₚ = daily_hazard(rtᵢ, new_df[i].age[j] + times_d[m], new_df[i].year[j] + times_d[m]) +# ∂Λₚ = λₚ * 365.241 +# Λₚ += ∂Λₚ +# Sₚ = exp(-Λₚ) +# num_pop[i,m] += Sₚ +# sit[m] += (1-Sₚ) / λₚ +# end +# end +# povp[i] = mean(sit ./ 365.241) +# end +# return num_pop, povp +# end \ No newline at end of file diff --git a/src/NetSurvival.jl b/src/NetSurvival.jl index 02fabda..5197429 100644 --- a/src/NetSurvival.jl +++ b/src/NetSurvival.jl @@ -14,6 +14,8 @@ using RateTables include("fetch_datasets.jl") include("Surv_and_Strata.jl") +include("Nessie.jl") + include("NPNSEstimator.jl") include("PoharPerme.jl") include("EdererI.jl") @@ -22,13 +24,13 @@ include("Hakulinen.jl") include("CrudeMortality.jl") -include("Nessie.jl") + include("GraffeoTest.jl") export PoharPerme, EdererI, EdererII, Hakulinen export CrudeMortality -export Nessie +export Nessie, nessie export fit, confint export GraffeoTest export Surv, Strata