Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixnans bool propagation #47

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "e54bda2e-c571-11ec-9d64-0242ac120002"
license = "MIT"
desc = "Julia implementation of Modal Decision Trees and Random Forest algorithms"
authors = ["Giovanni PAGLIARINI"]
version = "0.5.0"
version = "0.5.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
32 changes: 19 additions & 13 deletions src/ModalCART.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ function generate_relevant_decisions(
idxs,
region,
grouped_featsaggrsnopss,
grouped_featsnaggrss,
grouped_featsnaggrss;
kwargs...
)
out = []
@inbounds for (i_modality,
Expand Down Expand Up @@ -295,6 +296,7 @@ function generate_relevant_decisions(
features_inds,
grouped_featsaggrsnopss[i_modality],
grouped_featsnaggrss[i_modality],
get(kwargs, :fixnans, false),
)
decision_instantiator = _threshold->begin
cond = ScalarCondition(metacondition, _threshold)
Expand Down Expand Up @@ -348,8 +350,8 @@ Base.@propagate_inbounds @inline function optimize_node!(
idxs :: AbstractVector{Int},
n_classes :: Int,
rng :: Random.AbstractRNG,
kwargs...
) where{P,L<:_Label,D<:AbstractDecision,U,NSubRelationsFunction<:Function,S<:MCARTState}

# Region of idxs to use to perform the split
region = node.region
_ninstances = length(region)
Expand Down Expand Up @@ -724,13 +726,15 @@ Base.@propagate_inbounds @inline function optimize_node!(
idxs,
region,
grouped_featsaggrsnopss,
grouped_featsnaggrss,
grouped_featsnaggrss;
kwargs...
)
if isa(_is_classification, Val{true})
thresh_domain, additional_info = limit_threshold_domain(aggr_thresholds, Yf, Wf, loss_function, test_op, min_samples_leaf, perform_domain_optimization; n_classes = n_classes, nc = nc, nt = nt)
else
thresh_domain, additional_info = limit_threshold_domain(aggr_thresholds, Yf, Wf, loss_function, test_op, min_samples_leaf, perform_domain_optimization)
end

# Look for the best threshold 'a', as in atoms like "feature >= a"
for (_threshold, threshold_info) in zip(thresh_domain, additional_info)
decision = decision_instantiator(_threshold)
Expand Down Expand Up @@ -952,6 +956,7 @@ Base.@propagate_inbounds @inline function optimize_node!(
idxs = deepcopy(idxs_copy),
n_classes = n_classes,
rng = copy(rng),
kwargs...
)
end
# TODO: evaluate the goodness of the subtree?
Expand Down Expand Up @@ -1029,7 +1034,7 @@ end
_metaconditions = metaconditions(X)

_grouped_metaconditions = SoleData.grouped_metaconditions(_metaconditions, _features)

# _grouped_metaconditions::AbstractVector{<:AbstractVector{Tuple{<:ScalarMetaCondition}}}
# [[(i_metacond, aggregator, metacondition)...]...]

Expand Down Expand Up @@ -1061,7 +1066,7 @@ end
grouped_featsnaggrss = last.(permodality_groups)

# Process nodes recursively, using multi-threading
function process_node!(node, rng)
function process_node!(node, rng; kwargs...)
# Note: better to spawn rng's beforehand, to preserve reproducibility independently from optimize_node!
rng_l = spawn(rng)
rng_r = spawn(rng)
Expand All @@ -1084,17 +1089,17 @@ end
idxs = idxs,
rng = rng,
lookahead = lookahead,
kwargs...,
kwargs...
)
# !print_progress || ProgressMeter.update!(p, node.purity)
!print_progress || ProgressMeter.next!(p, spinner="⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
if !node.is_leaf
l = Threads.@spawn process_node!(node.l, rng_l)
r = Threads.@spawn process_node!(node.r, rng_r)
l = Threads.@spawn process_node!(node.l, rng_l; kwargs...)
r = Threads.@spawn process_node!(node.r, rng_r; kwargs...)
wait(l), wait(r)
end
end
@sync Threads.@spawn process_node!(root, rng)
@sync Threads.@spawn process_node!(root, rng; kwargs...)

!print_progress || ProgressMeter.finish!(p)

Expand Down Expand Up @@ -1192,9 +1197,10 @@ end
* " lookahead >= 0)")
end

if SoleData.hasnans(Xs)
error("This algorithm does not allow NaN values")
end
# fixnans = get(kwargs, :fixnans, false)
# if !fixnans && SoleData.hasnans(Xs)
# error("This algorithm does not allow NaN values")
# end

if nothing in Y
error("This algorithm does not allow nothing values in Y")
Expand Down Expand Up @@ -1236,7 +1242,7 @@ function fit_tree(
kwargs...,
) where {L<:Union{CLabel,RLabel}, U}
# Check validity of the input
check_input(Xs, Y, initconditions, W; profile = profile, lookahead = lookahead, kwargs...)
check_input(Xs, Y, initconditions, W; profile=profile, lookahead=lookahead, kwargs...)

# Classification-only: transform labels to categorical form (indexed by integers)
n_classes = begin
Expand Down
7 changes: 5 additions & 2 deletions src/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ function build_tree(
##############################################################################
rng :: Random.AbstractRNG = Random.GLOBAL_RNG,
print_progress :: Bool = true,
kwargs...
) where {L<:Label,U}

@assert W isa AbstractVector || W in [nothing, :rebalance, :default]
Expand Down Expand Up @@ -108,8 +109,8 @@ function build_tree(
@assert isnothing(max_depth) || (max_depth >= 0)
@assert isnothing(max_modal_depth) || (max_modal_depth >= 0)

fit_tree(X, Y, initconditions, W
;###########################################################################
fit_tree(X, Y, initconditions, W;
###########################################################################
loss_function = loss_function,
lookahead = lookahead,
max_depth = max_depth,
Expand All @@ -127,6 +128,7 @@ function build_tree(
############################################################################
rng = rng,
print_progress = print_progress,
kwargs...
)
end

Expand Down Expand Up @@ -162,6 +164,7 @@ function build_forest(
rng :: Random.AbstractRNG = Random.GLOBAL_RNG,
print_progress :: Bool = true,
suppress_parity_warning :: Bool = false,
fixnans :: Bool = false,
) where {L<:Label,U}

@assert W isa AbstractVector || W in [nothing, :rebalance, :default]
Expand Down
10 changes: 5 additions & 5 deletions src/interfaces/MLJ.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,13 @@ depth(t::MDT.DTree) = height(t)
############################################################################################
############################################################################################

function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, classes_seen=nothing, w=nothing)
function MMI.fit(m::SymbolicModel, verbosity::Integer, X, y, var_grouping, classes_seen=nothing, w=nothing; kwargs...)
# @show get_kwargs(m, X)
model = begin
if m isa ModalDecisionTree
MDT.build_tree(X, y, w; get_kwargs(m, X)...)
MDT.build_tree(X, y, w; get_kwargs(m, X)..., kwargs...)
elseif m isa ModalRandomForest
MDT.build_forest(X, y, w; get_kwargs(m, X)...)
MDT.build_forest(X, y, w; get_kwargs(m, X)..., kwargs...)
else
error("Unexpected model type: $(typeof(m))")
end
Expand Down Expand Up @@ -171,8 +171,8 @@ end
# DATA FRONT END
############################################################################################

function MMI.reformat(m::SymbolicModel, X, y, w = nothing; passive_mode = false)
X, var_grouping = wrapdataset(X, m; passive_mode = passive_mode)
function MMI.reformat(m::SymbolicModel, X, y, w = nothing; passive_mode = false, kwargs...)
X, var_grouping = wrapdataset(X, m; passive_mode = passive_mode, kwargs...)
y, classes_seen = fix_y(y)
(X, y, var_grouping, classes_seen, w)
end
Expand Down
52 changes: 4 additions & 48 deletions src/interfaces/MLJ/clean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,6 @@ function MMI.clean!(m::SymbolicModel)
########################################################################################
########################################################################################
########################################################################################

if !(isnothing(m.relations) ||
m.relations isa Symbol && m.relations in keys(AVAILABLE_RELATIONS) ||
m.relations isa Vector{<:AbstractRelation} ||
m.relations isa Function
)
warning *= "relations should be in $(collect(keys(AVAILABLE_RELATIONS))) " *
"or a vector of SoleLogics.AbstractRelation's, " *
"but $(m.relations) " *
"was provided. Defaulting to $(mlj_default_relations_str).\n"
m.relations = nothing
end

isnothing(m.relations) && (m.relations = mlj_default_relations)
m.relations isa Vector{<:AbstractRelation} && (m.relations = m.relations)

# Patch name: features -> conditions
if !isnothing(m.features)
if !isnothing(m.conditions)
Expand All @@ -113,24 +97,10 @@ function MMI.clean!(m::SymbolicModel)
m.conditions = m.features
m.features = nothing
end

if !(isnothing(m.conditions) ||
m.conditions isa Vector{<:Union{SoleData.VarFeature,Base.Callable}} ||
m.conditions isa Vector{<:Tuple{Base.Callable,Integer}} ||
m.conditions isa Vector{<:Tuple{TestOperator,<:Union{SoleData.VarFeature,Base.Callable}}} ||
m.conditions isa Vector{<:SoleData.ScalarMetaCondition}
)
warning *= "conditions should be either:" *
"a) a vector of features (i.e., callables to be associated to all variables, or SoleData.VarFeature objects);\n" *
"b) a vector of tuples (callable,var_id);\n" *
"c) a vector of tuples (test_operator,features);\n" *
"d) a vector of SoleData.ScalarMetaCondition;\n" *
"but $(m.conditions) " *
"was provided. Defaulting to $(mlj_default_conditions_str).\n"
m.conditions = nothing
end

isnothing(m.conditions) && (m.conditions = mlj_default_conditions)

m.relations, _w = SoleData.autorelations(m.relations); warning *= _w
m.conditions, _w = SoleData.autoconditions(m.conditions); warning *= _w
m.downsize, _w = SoleData.autodownsize(m); warning *= _w

if !(isnothing(m.initconditions) ||
m.initconditions isa Symbol && m.initconditions in keys(AVAILABLE_INITCONDITIONS) ||
Expand All @@ -148,20 +118,6 @@ function MMI.clean!(m::SymbolicModel)
########################################################################################
########################################################################################

m.downsize = begin
if m.downsize == true
make_downsizing_function(m)
elseif m.downsize == false
identity
elseif m.downsize isa NTuple{N,Integer} where N
make_downsizing_function(m.downsize)
elseif m.downsize isa Function
m.downsize
else
error("Unexpected value for `downsize` encountered: $(m.downsize)")
end
end

if m.rng isa Integer
m.rng = Random.MersenneTwister(m.rng)
end
Expand Down
Loading
Loading