diff --git a/Project.toml b/Project.toml
index cde30189..a78cda25 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "EvoTrees"
 uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
 authors = ["jeremiedb "]
-version = "0.15.1"
+version = "0.15.2"
 
 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
diff --git a/README.md b/README.md
index 785776bf..fcbe6a98 100644
--- a/README.md
+++ b/README.md
@@ -46,27 +46,27 @@ Code to reproduce is availabe in [`benchmarks/regressor.jl`](https://github.com/
 - Julia: v1.9.1.
 - Algorithms
   - XGBoost: v2.3.0 (Using the `hist` algorithm).
-  - EvoTrees: v0.15.0.
+  - EvoTrees: v0.15.2.
 
 ### Training:
 
 | Dimensions / Algo | XGBoost CPU | EvoTrees CPU | XGBoost GPU | EvoTrees GPU |
 |---------------------|:-----------:|:------------:|:-----------:|:------------:|
-| 100K x 100 | 2.33s | 1.09s | 0.90s | 2.72s |
-| 500K x 100 | 10.7s | 2.96s | 1.84s | 3.65s |
-| 1M x 100 | 20.9s | 6.48s | 3.10s | 4.45s |
-| 5M x 100 | 108s | 35.8s | 12.9s | 12.7s |
-| 10M x 100 | 216s | 71.6s | 25.5s | 23.0s |
+| 100K x 100 | 2.34s | 1.01s | 0.90s | 2.61s |
+| 500K x 100 | 10.7s | 3.95s | 1.84s | 3.41s |
+| 1M x 100 | 21.1s | 6.57s | 3.10s | 4.47s |
+| 5M x 100 | 108s | 36.1s | 12.9s | 12.5s |
+| 10M x 100 | 218s | 72.6s | 25.5s | 23.0s |
 
 ### Inference:
 
 | Dimensions / Algo | XGBoost CPU | EvoTrees CPU | XGBoost GPU | EvoTrees GPU |
 |---------------------|:------------:|:------------:|:-----------:|:------------:|
-| 100K x 100 | 0.151s | 0.053s | NA | 0.036s |
-| 500K x 100 | 0.628s | 0.276s | NA | 0.169s |
-| 1M x 100 | 1.26s | 0.558s | NA | 0.334s |
+| 100K x 100 | 0.151s | 0.058s | NA | 0.045s |
+| 500K x 100 | 0.647s | 0.248s | NA | 0.172s |
+| 1M x 100 | 1.26s | 0.573s | NA | 0.327s |
 | 5M x 100 | 6.04s | 2.87s | NA | 1.66s |
-| 10M x 100 | 12.4s | 5.71s | NA | 3.31s |
+| 10M x 100 | 12.4s | 5.71s | NA | 3.40s |
 
 ## MLJ Integration
diff --git a/benchmarks/regressor.jl b/benchmarks/regressor.jl
index b9312f44..b78b38af 100644
--- a/benchmarks/regressor.jl
+++ b/benchmarks/regressor.jl
@@ -8,13 +8,21 @@ using BenchmarkTools
 using Random: seed!
 import CUDA
 
+### v.0.15.1
+# desktop | 1e6 | depth 11 | cpu: 37.2s
+# desktop | 10e6 | depth 11 | cpu
+
+### perf depth
+# desktop | 1e6 | depth 11 | cpu: 28s gpu: 73 sec | xgboost: 26s
+# desktop | 10e6 | depth 11 | cpu 205s gpu: 109 sec | xgboost 260s
+
 nobs = Int(1e6)
 num_feat = Int(100)
 nrounds = 200
+max_depth = 6
 tree_type = "binary"
 T = Float64
 nthread = Base.Threads.nthreads()
-@info "testing with: $nobs observations | $num_feat features. nthread: $nthread | tree_type : $tree_type"
+@info "testing with: $nobs observations | $num_feat features. nthread: $nthread | tree_type : $tree_type | max_depth : $max_depth"
 seed!(123)
 x_train = rand(T, nobs, num_feat)
 y_train = rand(T, size(x_train, 1))
@@ -37,7 +45,7 @@ end
 @info "train"
 params_xgb = Dict(
     :num_round => nrounds,
-    :max_depth => 5,
+    :max_depth => max_depth - 1,
     :eta => 0.05,
     :objective => loss_xgb,
     :print_every_n => 5,
@@ -98,7 +106,7 @@ params_evo = EvoTreeRegressor(;
     lambda=0.0,
     gamma=0.0,
     eta=0.05,
-    max_depth=6,
+    max_depth=max_depth,
     min_weight=1.0,
     rowsample=0.5,
     colsample=0.5,
@@ -117,14 +125,11 @@ device = "cpu"
 # @time m_evo = fit_evotree(params_evo; x_train, y_train, device, verbosity, print_every_n=100);
 @info "train - eval"
 @time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, metric=metric_evo, device, verbosity, print_every_n=100);
-# using Plots
-# plot(m_evo, 2)
-
 @time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_eval=y_train, metric=metric_evo, device, verbosity, print_every_n=100);
 @info "predict"
 @time pred_evo = m_evo(x_train);
 @time pred_evo = m_evo(x_train);
-@btime m_evo($x_train);
+# @btime m_evo($x_train);
 
 @info "EvoTrees GPU"
 device = "gpu"
@@ -139,4 +144,4 @@ CUDA.@time m_evo = fit_evotree(params_evo; x_train, y_train, x_eval=x_train, y_e
 @info "predict"
 CUDA.@time pred_evo = m_evo(x_train; device);
 CUDA.@time pred_evo = m_evo(x_train; device);
-@btime m_evo($x_train; device);
+# @btime m_evo($x_train; device);
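A note on the `max_depth - 1` mapping introduced above: EvoTrees counts the root level in `max_depth` (a depth-1 tree is a single leaf), while XGBoost counts levels of splits, so deriving both settings from one `max_depth` variable keeps the benchmarked tree shapes identical. A minimal sketch of the equivalence; the two helper names are illustrative only, not part of either API:

n_leaves_evo(max_depth) = 2^(max_depth - 1)  # EvoTrees: depth includes the root level
n_leaves_xgb(max_depth) = 2^max_depth        # XGBoost: depth counts splits below the root
@assert n_leaves_evo(6) == n_leaves_xgb(5) == 32  # the pairing used in this benchmark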
diff --git a/docs/src/assets/regression-sinus-binary.png b/docs/src/assets/regression-sinus-binary.png
index b442732c..c81dc366 100644
Binary files a/docs/src/assets/regression-sinus-binary.png and b/docs/src/assets/regression-sinus-binary.png differ
diff --git a/experiments/depth-debug.jl b/experiments/depth-debug.jl
new file mode 100644
index 00000000..4a8df1eb
--- /dev/null
+++ b/experiments/depth-debug.jl
@@ -0,0 +1,129 @@
+using Statistics
+using StatsBase:sample
+using Base.Threads:@threads
+using BenchmarkTools
+using Revise
+using EvoTrees
+using Profile
+
+nobs = Int(1e6)
+num_feat = Int(100)
+nrounds = 200
+nthread = Base.Threads.nthreads()
+x_train = rand(nobs, num_feat)
+y_train = rand(size(x_train, 1))
+
+config = EvoTreeRegressor(;
+    loss=:mse,
+    nrounds=200,
+    lambda=0.0,
+    gamma=0.0,
+    eta=0.05,
+    max_depth=10,
+    min_weight=1.0,
+    rowsample=0.5,
+    colsample=0.5,
+    nbins=64,
+    tree_type="binary",
+    rng=123
+)
+
+################################
+# high-level
+################################
+_device = EvoTrees.GPU
+@time EvoTrees.fit_evotree(config; x_train, y_train, device = "gpu")
+
+@time m, cache = EvoTrees.init(config, x_train, y_train);
+@time EvoTrees.grow_evotree!(m, cache, config)
+@btime EvoTrees.grow_evotree!(m, cache, config)
+
+Profile.clear()
+# Profile.init()
+Profile.init(n = 10^5, delay = 0.01)
+# @profile m, cache = EvoTrees.init(config, x_train, y_train);
+@profile EvoTrees.grow_evotree!(m, cache, config)
+Profile.print()
+
+################################
+# mid-level
+################################
+@time m, cache = EvoTrees.init(config, x_train, y_train);
+@time EvoTrees.grow_evotree!(m, cache, config)
+# compute gradients
+@time m, cache = EvoTrees.init(config, x_train, y_train);
+@time EvoTrees.update_grads!(cache.∇, cache.pred, cache.y, config)
+# subsample rows
+@time cache.nodes[1].is = EvoTrees.subsample(cache.is_in, cache.is_out, cache.mask, config.rowsample, config.rng)
+# subsample cols
+EvoTrees.sample!(config.rng, cache.js_, cache.js, replace=false, ordered=true)
+L = EvoTrees._get_struct_loss(m)
+# instantiate a tree then grow it
+tree = EvoTrees.Tree{L,1}(config.max_depth)
+grow! = config.tree_type == "oblivious" ? EvoTrees.grow_otree! : EvoTrees.grow_tree!
+@time EvoTrees.grow_tree!(
+    tree,
+    cache.nodes,
+    config,
+    cache.∇,
+    cache.edges,
+    cache.js,
+    cache.out,
+    cache.left,
+    cache.right,
+    cache.x_bin,
+    cache.feattypes,
+    cache.monotone_constraints
+)
+
+using ProfileView
+ProfileView.@profview EvoTrees.grow_tree!(
+    tree,
+    cache.nodes,
+    config,
+    cache.∇,
+    cache.edges,
+    cache.js,
+    cache.out,
+    cache.left,
+    cache.right,
+    cache.x_bin,
+    cache.feattypes,
+    cache.monotone_constraints
+)
+
+################################
+# end mid-level
+################################
+
+
+@time m_evo = grow_tree!(params_evo; x_train, y_train, device, print_every_n=100);
+
+@info "train - no eval"
+@time m_evo = fit_evotree(params_evo; x_train, y_train, device, print_every_n=100);
+
+
+offset = 0
+feat = 15
+cond_bin = 32
+@time l, r = split_set_threads!(out, left, right, 𝑖, X_bin, feat, cond_bin, offset, 2^14);
+@btime split_set_threads!($out, $left, $right, $𝑖, $X_bin, $feat, $cond_bin, $offset, 2^14);
+@code_warntype split_set_1!(left, right, 𝑖, X_bin, feat, cond_bin, offset)
+
+offset = 0
+feat = 15
+cond_bin = 32
+lid1, rid1 = split_set_threads!(out, left, right, 𝑖, X_bin, feat, cond_bin, offset)
+offset = 0
+feat = 14
+cond_bin = 12
+lid2, rid2 = split_set_threads!(out, left, right, lid1, X_bin, feat, cond_bin, offset)
+offset = + length(lid1)
+feat = 14
+cond_bin = 12
+lid3, rid3 = split_set_threads!(out, left, right, rid1, X_bin, feat, cond_bin, offset)
+
+lid1_ = deepcopy(lid1)
+
+
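The profiling calls in `depth-debug.jl` follow the standard Julia `Profile` workflow; a condensed sketch of that loop, assuming the same `config`, `x_train`, and `y_train` as in the script (buffer size and sampling delay mirror the values above):

using Profile
m, cache = EvoTrees.init(config, x_train, y_train)
EvoTrees.grow_evotree!(m, cache, config)      # warm-up call keeps compilation out of the profile
Profile.clear()
Profile.init(n = 10^5, delay = 0.01)          # larger sample buffer, coarser 10 ms sampling
@profile for _ in 1:10
    EvoTrees.grow_evotree!(m, cache, config)  # profile several boosting rounds
end
Profile.print(format = :flat, sortedby = :count)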
diff --git a/experiments/learnAPI.jl b/experiments/learnAPI.jl
new file mode 100644
index 00000000..c1dfd037
--- /dev/null
+++ b/experiments/learnAPI.jl
@@ -0,0 +1,67 @@
+using EvoTrees
+
+module LearnAPI
+
+abstract type Config end
+abstract type Learner end
+abstract type Model end
+
+function fit(config::Config; kwargs...)
+    return nothing
+end
+function fit(config::Config, data; kwargs...)
+    return nothing
+end
+function init(config::Config, data; kwargs...)
+    return nothing
+end
+# function fit!(learner::Learner)
+#     return nothing
+# end
+
+function predict(model::Model, x)
+    return x
+end
+function predict!(p, model::Model, x)
+    return nothing
+end
+
+function isiterative(m) end
+
+end #module
+
+struct EvoLearner
+    params
+end
+
+# 1-arg fit: all needed supplemental info passed through kwargs: risk of fragmented naming conventions, hard to follow
+m = LearnAPI.fit(config::Config; kwargs)
+m = LearnAPI.fit(config::EvoTrees.EvoTypes; x_train=xt, y_train=yt)
+m = LearnAPI.fit(config::EvoTrees.EvoTypes; x_train=xt, y_train=yt, x_eval=xe, y_eval=ye)
+
+# 2-arg fit: forces the notion of input data on which training is performed. May facilitate dispatch/specialisation on various supported data types
+m = LearnAPI.fit(config::Config, data; kwargs)
+m = LearnAPI.fit(config::EvoTrees.EvoTypes, (x_train, y_train))
+m = LearnAPI.fit(config::EvoTrees.EvoTypes, (x_train, y_train); x_eval=xe, y_eval=ye)
+m = LearnAPI.fit(config::EvoTrees.EvoTypes, df::DataFrame)
+
+# Iterative models
+import .LearnAPI: isiterative
+LearnAPI.isiterative(m::EvoTree) = true
+
+# 2-arg model initialization
+# Here an EvoTreeLearner is returned: a comprehensive struct that includes the config, the model, and cache/state
+m = LearnAPI.init(config::Config, data::DataFrame; kwargs)
+m = LearnAPI.init(config::EvoTrees.EvoTypes, df::DataFrame; x_eval=xe, y_eval=ye)
+
+LearnAPI.fit!(m::EvoTree)
+LearnAPI.fit!(m::EvoTree, data)
+
+# LearnAPI.fit!(m, config::EvoTrees.EvoTypes; kwargs)
+LearnAPI.predict(m::EvoTrees.EvoTypes, x)
+
+config = EvoTreeRegressor()
+# m, cache = LearnAPI.init()
+
+# should it be possible to have a model that specifies feature treatment upfront at the Config level?
+# Or rather have those passed at the fitted level?
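One way the `init`/`fit!` split sketched in `learnAPI.jl` could look for an iterative learner; `ToyConfig` and `ToyModel` are stand-in names for illustration, not EvoTrees or LearnAPI types:

module ToyAPI

struct ToyConfig
    nrounds::Int
end

mutable struct ToyModel
    config::ToyConfig
    state::Vector{Float64}       # cache/state carried across fit! calls
end

# 2-arg init: returns a comprehensive struct bundling config, model, and state
init(config::ToyConfig, data::Vector{Float64}) = ToyModel(config, zeros(length(data)))

# fit! then iterates in place on the initialized learner
function fit!(m::ToyModel, data::Vector{Float64})
    for _ in 1:m.config.nrounds
        m.state .+= data         # stand-in for one boosting round
    end
    return m
end

isiterative(::ToyModel) = true   # the trait queried for iterative models

end # module

m = ToyAPI.init(ToyAPI.ToyConfig(3), [1.0, 2.0])
ToyAPI.fit!(m, [1.0, 2.0])       # m.state == [3.0, 6.0]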
diff --git a/experiments/readme_plots_cpu.jl b/experiments/readme_plots_cpu.jl
index e431f3d9..7c15e433 100644
--- a/experiments/readme_plots_cpu.jl
+++ b/experiments/readme_plots_cpu.jl
@@ -90,7 +90,7 @@ params1 = EvoTreeRegressor(;
     nbins=64,
     lambda=0.1,
     gamma=0.1,
-    eta=0.05,
+    eta=0.1,
     max_depth=6,
     min_weight=1.0,
     rowsample=0.5,
@@ -132,7 +132,7 @@ params1 = EvoTreeRegressor(;
     nbins=64,
     lambda=0.1,
     gamma=0.1,
-    eta=0.05,
+    eta=0.1,
     max_depth=6,
     min_weight=1.0,
     rowsample=0.5,
@@ -288,7 +288,7 @@ params1 = EvoTreeRegressor(;
     loss=:tweedie,
     nrounds=500,
     nbins=64,
-    lambda=0.5,
+    lambda=0.1,
     gamma=0.1,
     eta=0.1,
     max_depth=6,
@@ -359,7 +359,7 @@ params1 = EvoTreeRegressor(;
     nbins=64,
     lambda=0.1,
     gamma=0.0,
-    eta=0.05,
+    eta=0.1,
     max_depth=6,
     min_weight=1.0,
     rowsample=0.5,
@@ -389,7 +389,7 @@ params1 = EvoTreeRegressor(;
     nbins=64,
     lambda=0.1,
     gamma=0.0,
-    eta=0.05,
+    eta=0.1,
     max_depth=6,
     min_weight=1.0,
     rowsample=0.5,
@@ -408,7 +408,7 @@ params1 = EvoTreeRegressor(;
     nbins=64,
     lambda=0.1,
     gamma=0.0,
-    eta=0.05,
+    eta=0.1,
     max_depth=6,
     min_weight=1.0,
     rowsample=0.5,
@@ -466,10 +466,10 @@ params1 = EvoTreeMLE(;
     nbins=64,
     lambda=0.1,
     gamma=0.1,
-    eta=0.05,
+    eta=0.1,
     max_depth=6,
-    min_weight=10.0,
-    rowsample=1.0,
+    min_weight=10,
+    rowsample=0.5,
     colsample=1.0,
     rng=123,
     tree_type,
@@ -549,12 +549,12 @@ params1 = EvoTrees.EvoTreeMLE(;
     loss=:logistic,
     nrounds=500,
     nbins=64,
-    lambda=1.0,
+    lambda=0.1,
     gamma=0.1,
-    eta=0.03,
+    eta=0.1,
     max_depth=6,
-    min_weight=1.0,
-    rowsample=1.0,
+    min_weight=10,
+    rowsample=0.5,
     colsample=1.0,
     tree_type,
     rng=123,
diff --git a/experiments/readme_plots_gpu.jl b/experiments/readme_plots_gpu.jl
index 7fdeb244..33b89c92 100644
--- a/experiments/readme_plots_gpu.jl
+++ b/experiments/readme_plots_gpu.jl
@@ -249,7 +249,7 @@ params1 = EvoTreeGaussian(;
     gamma=0.1,
     eta=0.1,
     max_depth=6,
-    min_weight=20,
+    min_weight=10,
     rowsample=0.5,
     colsample=1.0,
     rng=123,
diff --git a/figures/regression-sinus-binary-gpu.png b/figures/regression-sinus-binary-gpu.png
index e50b6ef9..1be5632a 100644
Binary files a/figures/regression-sinus-binary-gpu.png and b/figures/regression-sinus-binary-gpu.png differ
diff --git a/figures/regression-sinus-binary.png b/figures/regression-sinus-binary.png
index b442732c..c81dc366 100644
Binary files a/figures/regression-sinus-binary.png and b/figures/regression-sinus-binary.png differ
diff --git a/figures/regression-sinus-oblivious-gpu.png b/figures/regression-sinus-oblivious-gpu.png
index 0a67e8b4..44644279 100644
Binary files a/figures/regression-sinus-oblivious-gpu.png and b/figures/regression-sinus-oblivious-gpu.png differ
diff --git a/figures/regression-sinus-oblivious.png b/figures/regression-sinus-oblivious.png
index 635a35e2..358d6749 100644
Binary files a/figures/regression-sinus-oblivious.png and b/figures/regression-sinus-oblivious.png differ
diff --git a/figures/regression-sinus2-binary.png b/figures/regression-sinus2-binary.png
index 888638f2..964652e9 100644
Binary files a/figures/regression-sinus2-binary.png and b/figures/regression-sinus2-binary.png differ
diff --git a/figures/regression-sinus2-oblivious.png b/figures/regression-sinus2-oblivious.png
index b6c32f1c..2a1eb75f 100644
Binary files a/figures/regression-sinus2-oblivious.png and b/figures/regression-sinus2-oblivious.png differ
diff --git a/src/fit-utils.jl b/src/fit-utils.jl
index 0d3360a0..f49163a4 100644
--- a/src/fit-utils.jl
+++ b/src/fit-utils.jl
@@ -166,45 +166,51 @@ function split_set_threads!(
     offset,
 ) where {S}
 
-    chunk_size = cld(length(is), min(cld(length(is), 1024), Threads.nthreads()))
+    chunk_size = cld(length(is), min(cld(length(is), 16_000), Threads.nthreads()))
     nblocks = cld(length(is), chunk_size)
 
     lefts = zeros(Int, nblocks)
     rights = zeros(Int, nblocks)
 
-    @threads :static for bid = 1:nblocks
-        lefts[bid], rights[bid] = split_set_chunk!(
-            left,
-            right,
-            is,
-            bid,
-            nblocks,
-            x_bin,
-            feat,
-            cond_bin,
-            feattype,
-            offset,
-            chunk_size,
-        )
+    @sync begin
+        for bid = 1:nblocks
+            @spawn begin
+                lefts[bid], rights[bid] = split_set_chunk!(
+                    left,
+                    right,
+                    is,
+                    bid,
+                    nblocks,
+                    x_bin,
+                    feat,
+                    cond_bin,
+                    feattype,
+                    offset,
+                    chunk_size,
+                )
+            end
+        end
     end
 
     sum_lefts = sum(lefts)
     cumsum_lefts = cumsum(lefts)
     cumsum_rights = cumsum(rights)
 
-    @threads :static for bid = 1:nblocks
-        split_views_kernel!(
-            out,
-            left,
-            right,
-            bid,
-            offset,
-            chunk_size,
-            lefts,
-            rights,
-            sum_lefts,
-            cumsum_lefts,
-            cumsum_rights,
-        )
+    @sync begin
+        for bid = 1:nblocks
+            @spawn split_views_kernel!(
+                out,
+                left,
+                right,
+                bid,
+                offset,
+                chunk_size,
+                lefts,
+                rights,
+                sum_lefts,
+                cumsum_lefts,
+                cumsum_rights,
+            )
+        end
     end
 
     return (
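The `16_000` constant introduced above puts a floor on rows per spawned task so that task overhead stays negligible on small nodes, while large nodes still target one task per thread; switching from `@threads :static` to `@sync`/`@spawn` also lets these tasks be scheduled dynamically instead of pinning the loop to the threadpool. A sketch of the chunking arithmetic, assuming 8 threads:

chunk_size(n, nthreads) = cld(n, min(cld(n, 16_000), nthreads))
nblocks(n, nthreads)    = cld(n, chunk_size(n, nthreads))

nblocks(1_000_000, 8)  # -> 8 blocks of 125_000 rows: one task per thread
nblocks(50_000, 8)     # -> 4 blocks of 12_500 rows: capped by the ~16k-row target
nblocks(1_000, 8)      # -> 1 block: tiny nodes skip multi-tasking entirely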
diff --git a/src/fit.jl b/src/fit.jl
index bd1dd5e9..f6d2aac0 100644
--- a/src/fit.jl
+++ b/src/fit.jl
@@ -61,17 +61,19 @@ function grow_tree!(
         end
     end
 
-    # reset
-    n_next = [1]
-    n_current = copy(n_next)
+    # initialize
+    n_current = [1]
     depth = 1
 
     # initialize summary stats
     nodes[1].∑ .= dropdims(sum(Float64, view(∇, :, nodes[1].is), dims=2), dims=2)
     nodes[1].gain = get_gain(params, nodes[1].∑)
 
+    # grow while there are remaining active nodes
     while length(n_current) > 0 && depth <= params.max_depth
         offset = 0 # identifies breakpoint for each node set within a depth
+        n_next = Int[]
+
         if depth < params.max_depth
             for n_id in eachindex(n_current)
                 n = n_current[n_id]
@@ -89,14 +91,15 @@
                     update_hist!(L, nodes[n].h, ∇, x_bin, nodes[n].is, js)
                 end
             end
+            @threads :static for n ∈ sort(n_current)
+                update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
+            end
         end
 
         for n ∈ sort(n_current)
             if depth == params.max_depth || nodes[n].∑[end] <= params.min_weight
                 pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                popfirst!(n_next)
             else
-                update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
                 best = findmax(findmax.(nodes[n].gains))
                 best_gain = best[1][1]
                 best_bin = best[1][2]
@@ -106,12 +109,8 @@
                     tree.cond_bin[n] = best_bin
                     tree.feat[n] = best_feat
                     tree.cond_float[n] = edges[tree.feat[n]][tree.cond_bin[n]]
-                end
-                tree.split[n] = tree.cond_bin[n] != 0
-                if !tree.split[n]
-                    pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                    popfirst!(n_next)
-                else
+                    tree.split[n] = best_bin != 0
+
                     _left, _right = split_set_threads!(
                         out,
                         left,
                         right,
                         nodes[n].is,
                         x_bin,
                         best_feat,
                         best_bin,
                         feattypes[best_feat],
                         offset,
                     )
+
                     offset += length(nodes[n].is)
                     nodes[n<<1].is, nodes[n<<1+1].is = _left, _right
                     nodes[n<<1].∑ .= nodes[n].hL[best_feat][:, best_bin]
                     nodes[n<<1+1].∑ .= nodes[n].hR[best_feat][:, best_bin]
                     nodes[n<<1].gain = get_gain(params, nodes[n<<1].∑)
                     nodes[n<<1+1].gain = get_gain(params, nodes[n<<1+1].∑)
+
                     if length(_right) >= length(_left)
                         push!(n_next, n << 1)
                         push!(n_next, n << 1 + 1)
                     else
                         push!(n_next, n << 1 + 1)
                         push!(n_next, n << 1)
                     end
-                    popfirst!(n_next)
+                else
+                    pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
                 end
             end
         end
@@ -173,17 +175,18 @@ function grow_otree!(
         end
     end
 
-    # reset
-    n_next = [1]
-    n_current = copy(n_next)
+    # initialize
+    n_current = [1]
     depth = 1
 
     # initialize summary stats
     nodes[1].∑ .= dropdims(sum(Float64, view(∇, :, nodes[1].is), dims=2), dims=2)
     nodes[1].gain = get_gain(params, nodes[1].∑)
 
+    # grow while there are remaining active nodes
     while length(n_current) > 0 && depth <= params.max_depth
         offset = 0 # identifies breakpoint for each node set within a depth
+        n_next = Int[]
 
         min_weight_flag = false
         for n in n_current
@@ -193,7 +196,6 @@
             for n in n_current
                 # @info "length(nodes[n].is)" length(nodes[n].is) depth n
                 pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                popfirst!(n_next)
             end
         else
             # update histograms
@@ -213,6 +215,9 @@
                     update_hist!(L, nodes[n].h, ∇, x_bin, nodes[n].is, js)
                 end
             end
+            @threads :static for n ∈ n_current
+                update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
+            end
 
             # initialize gains for node 1 in which all gains of a given depth will be accumulated
             if depth > 1
@@ -223,7 +228,6 @@
             gain = 0
             # update gains based on the aggregation of all nodes of a given depth. One gains matrix per depth (vs one per node in binary trees).
             for n ∈ sort(n_current)
-                update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
                 if n > 1 # accumulate gains in node 1
                     for j in js
                         nodes[1].gains[j] .+= nodes[n].gains[j]
@@ -277,12 +281,10 @@
                         push!(n_next, n << 1 + 1)
                         push!(n_next, n << 1)
                     end
-                    popfirst!(n_next)
                 end
             else
                 for n in n_current
                     pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                    popfirst!(n_next)
                 end
             end
         end
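The refactor above drops the fragile `popfirst!(n_next)` bookkeeping: `n_next` now starts as an empty `Int[]` at each depth and only receives the children of nodes that actually split. Node ids follow a binary-heap layout, which is what the shift arithmetic encodes; pushing the smaller child first presumably preserves the sibling ordering that the histogram-update pass at the next depth relies on. A quick check of the indexing:

left_child(n)  = n << 1       # == 2n
right_child(n) = n << 1 + 1   # == 2n + 1
parent(n)      = n >> 1       # == n ÷ 2

@assert (left_child(1), right_child(1)) == (2, 3)
@assert (left_child(3), right_child(3)) == (6, 7)
@assert parent(6) == parent(7) == 3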
diff --git a/src/gpu/fit.jl b/src/gpu/fit.jl
index 30d9bd00..b6751599 100644
--- a/src/gpu/fit.jl
+++ b/src/gpu/fit.jl
@@ -61,16 +61,18 @@ function grow_tree!(
     end
 
     # initialize
-    n_next = [1]
-    n_current = copy(n_next)
+    n_current = [1]
     depth = 1
 
     # initialize summary stats
     nodes[1].∑ .= Vector(vec(sum(∇[:, nodes[1].is], dims=2)))
     nodes[1].gain = get_gain(params, nodes[1].∑) # should use a GPU version?
-    # grow while there are remaining active nodes - TO DO histogram substraction hits issue on GPU
+
+    # grow while there are remaining active nodes
     while length(n_current) > 0 && depth <= params.max_depth
         offset = 0 # identifies breakpoint for each node set within a depth
+        n_next = Int[]
+
         if depth < params.max_depth
             for n_id in eachindex(n_current)
                 n = n_current[n_id]
@@ -88,34 +90,26 @@
                 update_hist_gpu!(nodes[n].h, h∇, ∇, x_bin, nodes[n].is, jsg, js)
             end
         end
+        @threads :static for n ∈ sort(n_current)
+            update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
+        end
         end
 
-        # grow while there are remaining active nodes
         for n ∈ sort(n_current)
             if depth == params.max_depth || nodes[n].∑[end] <= params.min_weight
                 pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                popfirst!(n_next)
             else
-                # @info "gain & max"
-                update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
                 best = findmax(findmax.(nodes[n].gains))
                 best_gain = best[1][1]
                 best_bin = best[1][2]
                 best_feat = best[2]
-                # if best_gain > nodes[n].gain + params.gamma && best_gain > nodes[n].gains[best_feat][end] + params.gamma
                 if best_gain > nodes[n].gain + params.gamma
                     tree.gain[n] = best_gain - nodes[n].gain
                     tree.cond_bin[n] = best_bin
                     tree.feat[n] = best_feat
                     tree.cond_float[n] = edges[tree.feat[n]][tree.cond_bin[n]]
-                end
-                tree.split[n] = tree.cond_bin[n] != 0
-                if !tree.split[n]
-                    pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                    popfirst!(n_next)
-                else
-                    # @info "split" best_bin typeof(nodes[n].is) length(nodes[n].is)
-                    # @info "split typeof" typeof(out) typeof(left) typeof(nodes[n].is) typeof(x_bin)
+                    tree.split[n] = best_bin != 0
+
                     _left, _right = split_set_threads_gpu!(
                         out,
                         left,
                         right,
                         nodes[n].is,
                         x_bin,
                         best_feat,
                         best_bin,
                         feattypes[best_feat],
                         offset,
                     )
+
                     offset += length(nodes[n].is)
                     nodes[n<<1].is, nodes[n<<1+1].is = _left, _right
                     nodes[n<<1].∑ .= nodes[n].hL[best_feat][:, best_bin]
                     nodes[n<<1+1].∑ .= nodes[n].hR[best_feat][:, best_bin]
                     nodes[n<<1].gain = get_gain(params, nodes[n<<1].∑)
                     nodes[n<<1+1].gain = get_gain(params, nodes[n<<1+1].∑)
+
                     if length(_right) >= length(_left)
                         push!(n_next, n << 1)
                         push!(n_next, n << 1 + 1)
                     else
                         push!(n_next, n << 1 + 1)
                         push!(n_next, n << 1)
                     end
-                    # @info "split post" length(_left) length(_right)
-                    popfirst!(n_next)
+                else
+                    pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
                 end
             end
         end
@@ -182,8 +178,7 @@
     end
 
     # initialize
-    n_next = [1]
-    n_current = copy(n_next)
+    n_current = [1]
     depth = 1
 
     # initialize summary stats
@@ -193,6 +188,7 @@
     # grow while there are remaining active nodes
     while length(n_current) > 0 && depth <= params.max_depth
         offset = 0 # identifies breakpoint for each node set within a depth
+        n_next = Int[]
 
         min_weight_flag = false
         for n in n_current
@@ -202,7 +198,6 @@
             for n in n_current
                 # @info "length(nodes[n].is)" length(nodes[n].is) depth n
                 pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                popfirst!(n_next)
             end
         else
            # update histograms
@@ -222,6 +217,9 @@
                    update_hist_gpu!(nodes[n].h, h∇, ∇, x_bin, nodes[n].is, jsg, js)
                end
            end
+           @threads :static for n ∈ n_current
+               update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
+           end
 
            # initialize gains for node 1 in which all gains of a given depth will be accumulated
            if depth > 1
@@ -232,7 +230,6 @@
            gain = 0
            # update gains based on the aggregation of all nodes of a given depth. One gains matrix per depth (vs one per node in binary trees).
            for n ∈ sort(n_current)
-               update_gains!(nodes[n], js, params, feattypes, monotone_constraints)
               if n > 1 # accumulate gains in node 1
                   for j in js
                       nodes[1].gains[j] .+= nodes[n].gains[j]
@@ -286,12 +283,10 @@
                        push!(n_next, n << 1 + 1)
                        push!(n_next, n << 1)
                    end
-                   popfirst!(n_next)
                end
            else
                for n in n_current
                    pred_leaf_cpu!(tree.pred, n, nodes[n].∑, params, ∇, nodes[n].is)
-                   popfirst!(n_next)
                end
            end
        end
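Unlike the binary grower, `grow_otree!` selects a single (feature, bin) split per depth: the loops above fold every active node's gains into `nodes[1].gains` before the maximum is taken, so all nodes of a depth share one split condition. A scalar sketch of that accumulation, with made-up gain vectors:

# two active nodes at the current depth; each vector holds the gains over bins for one feature
node_gains = Dict(
    2 => [0.1, 0.4, 0.0],
    3 => [0.3, 0.1, 0.5],
)
depth_gains = zeros(3)
for n in (2, 3)
    depth_gains .+= node_gains[n]   # mirrors nodes[1].gains[j] .+= nodes[n].gains[j]
end
best_gain, best_bin = findmax(depth_gains)  # one shared split applied to every node at this depth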