diff --git a/Project.toml b/Project.toml index ad83249..9fdfe6e 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ RateTables = "d40fb65e-c2ee-4113-9e14-cb96ca0acb32" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" +SurvivalBase = "9cb6079b-e021-4662-8a75-9f65bfa286f2" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] @@ -19,8 +20,8 @@ CSV = "0.10" DataFrames = "1" Distributions = "0.25" LinearAlgebra = "1.6" -RateTables = "0.1.1" RCall = "0.14" +RateTables = "0.1.1" StatsAPI = "1" StatsBase = "0.34" StatsModels = "0.7" diff --git a/src/NetSurvival.jl b/src/NetSurvival.jl index a0dc006..f14d4b7 100644 --- a/src/NetSurvival.jl +++ b/src/NetSurvival.jl @@ -5,14 +5,14 @@ using Distributions using LinearAlgebra using StatsAPI using StatsBase -using StatsModels using Tables using Base.Cartesian using CSV using RateTables +using StatsModels +using SurvivalBase: Surv, Strata include("fetch_datasets.jl") -include("Surv_and_Strata.jl") include("Nessie.jl") include("NPNSEstimator.jl") include("PoharPerme.jl") diff --git a/src/Surv_and_Strata.jl b/src/Surv_and_Strata.jl deleted file mode 100644 index 29f69ae..0000000 --- a/src/Surv_and_Strata.jl +++ /dev/null @@ -1,74 +0,0 @@ -# struct for behavior -# struct Surv{X, Y} -# T::X -# Δ::Y -# end -struct SurvTerm{X, Y} <: AbstractTerm - T::X - Δ::Y - function SurvTerm(T,Δ) - return new{typeof(T),typeof(Δ)}(T, Δ) - end -end -# Base.show(io::IO, t::Surv) = print(io, string(t.T, t.Δ == 1 ? "+" : "")) -Base.show(io::IO, t::SurvTerm) = print(io, "Surv($((t.T, t.Δ)))") - - -# Surv(T::Symbol, Δ::Symbol) = SurvTerm(term(T), term(Δ)) -Surv(T::Float64, Δ::Bool) = (T, Δ) - -Strata(x) = x -struct StrataTerm{X} <: AbstractTerm - Covariable::X -end -Base.show(io::IO, t::StrataTerm) = print(io, "Strata($((t.Covariable)))") - -Strata(Covariables::Vector) = StrataTerm(term(Covariables)) - -StatsModels.termvars(p::StrataTerm) = StatsModels.termvars(p.Covariable) - - -function StatsModels.apply_schema(t::FunctionTerm{typeof(Surv)}, - sch::StatsModels.Schema, - Mod::Type{<:Any}) - return apply_schema(SurvTerm(t.args...), sch, Mod) -end - -function StatsModels.apply_schema(t::FunctionTerm{typeof(Strata)}, - sch::StatsModels.Schema, - Mod::Type{<:Any}) - return apply_schema(StrataTerm(t.args...), sch, Mod) -end - -function StatsModels.apply_schema(t::SurvTerm{X,Y}, - sch::StatsModels.Schema, - Mod::Type{<:Any}) where {X,Y} - T = apply_schema(t.T, sch, Mod) - Δ = apply_schema(t.Δ, sch, Mod) - isa(T, ContinuousTerm) || throw(ArgumentError("Surv only works with continuous terms (got $T)")) - isa(Δ, ContinuousTerm) || throw(ArgumentError("Surv only works with discrete terms (got $Δ)")) - return SurvTerm(T, Δ) -end - -function StatsModels.apply_schema(t::StrataTerm, - sch::StatsModels.Schema, - Mod::Type{<:Any}) - X = apply_schema(t.Covariable, sch, Mod) - return StrataTerm(X) -end - -function StatsModels.modelcols(t::SurvTerm, d::NamedTuple) - T = modelcols(t.T, d) - Δ = modelcols(t.Δ, d) - return hcat(T,Δ) -end - -function StatsModels.modelcols(t::StrataTerm, d::NamedTuple) - return modelcols(t.Covariable, d) -end - - - - - -