Skip to content

Commit

Permalink
fix cpu repro
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremiedb committed Nov 21, 2022
1 parent 6e0d677 commit 02c8333
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions src/find_split.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#############################################
# Get the braking points
#############################################
function get_edges(X::AbstractMatrix{T}, nbins) where {T}
rng = Random.MersenneTwister(123)
function get_edges(X::AbstractMatrix{T}, nbins, rng = Random.MersenneTwister()) where {T}
nobs = min(size(X, 1), 1000 * nbins)
obs = rand(rng, 1:size(X, 1), nobs)
edges = Vector{Vector{T}}(undef, size(X, 2))
Expand Down
2 changes: 1 addition & 1 deletion src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function init_evotree(
∇[end, :] .= w

# binarize data into quantiles
edges = get_edges(x, params.nbins)
edges = get_edges(x, params.nbins, params.rng)
x_bin = binarize(x, edges)

is_in = zeros(UInt32, x_size[1])
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/fit_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ function init_evotree_gpu(
∇[end, :] .= w

# binarize data into quantiles
edges = get_edges(x, params.nbins)
edges = get_edges(x, params.nbins, params.rng)
x_bin = CuArray(binarize(x, edges))

is_in = CUDA.zeros(UInt32, x_size[1])
Expand Down

0 comments on commit 02c8333

Please sign in to comment.