Skip to content
This repository has been archived by the owner on Apr 26, 2021. It is now read-only.

Commit

Permalink
updated to MLJ 0.14
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaodaigh committed Sep 24, 2020
1 parent 69e98e5 commit 926da7e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JLBoostMLJ"
uuid = "8b86df2c-1bc3-481d-95df-1c4d5a20ed95"
authors = ["Dai ZJ <[email protected]>"]
version = "0.1.8"
version = "0.1.9"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand All @@ -15,7 +15,7 @@ ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
DataFrames = "0.21"
JLBoost = "^0.1.8"
LossFunctions = "0.5, 0.6"
MLJ = "0.10, 0.11, 0.12, 0.13"
MLJ = "0.10, 0.11, 0.12, 0.13, 0.14"
MLJBase = "0.12, 0.13, 0.14, 0.15"
ScientificTypes = "0.7, 0.8, 1.0"
julia = "1"
Expand Down
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ model = JLBoostClassifier()

````
JLBoostClassifier(
loss = LogitLogLoss(),
loss = JLBoost.LogitLogLoss(),
nrounds = 1,
subsample = 1.0,
eta = 1.0,
max_depth = 6,
min_child_weight = 1.0,
lambda = 0.0,
gamma = 0.0,
colsample_bytree = 1) @810
colsample_bytree = 1) @087
````


Expand All @@ -47,11 +47,11 @@ mljmachine = machine(model, X, y)


````
Machine{JLBoostClassifier} @772 trained 0 times.
Machine{JLBoostClassifier} @730 trained 0 times.
args:
1: Source @097 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
1: Source @910 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
ontinuous,1}}`
2: Source @077 ⏎ `AbstractArray{ScientificTypes.Count,1}`
2: Source @954 ⏎ `AbstractArray{ScientificTypes.Count,1}`
````


Expand Down Expand Up @@ -81,11 +81,11 @@ Choosing a split on SepalLength
Choosing a split on SepalWidth
Choosing a split on PetalLength
Choosing a split on PetalWidth
Machine{JLBoostClassifier} @772 trained 1 time.
Machine{JLBoostClassifier} @730 trained 1 time.
args:
1: Source @097 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
1: Source @910 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
ontinuous,1}}`
2: Source @077 ⏎ `AbstractArray{ScientificTypes.Count,1}`
2: Source @954 ⏎ `AbstractArray{ScientificTypes.Count,1}`
````


Expand Down Expand Up @@ -216,11 +216,11 @@ m = machine(tm, X, y_cate)


````
Machine{ProbabilisticTunedModel{Grid,…}} @388 trained 0 times.
Machine{ProbabilisticTunedModel{Grid,…}} @109 trained 0 times.
args:
1: Source @578 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
1: Source @664 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
ontinuous,1}}`
2: Source @226 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
2: Source @788 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
````


Expand All @@ -235,11 +235,11 @@ fit!(m)


````
Machine{ProbabilisticTunedModel{Grid,…}} @388 trained 1 time.
Machine{ProbabilisticTunedModel{Grid,…}} @109 trained 1 time.
args:
1: Source @578 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
1: Source @664 ⏎ `ScientificTypes.Table{AbstractArray{ScientificTypes.C
ontinuous,1}}`
2: Source @226 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
2: Source @788 ⏎ `AbstractArray{ScientificTypes.Multiclass{2},1}`
````


Expand Down Expand Up @@ -287,15 +287,15 @@ Choosing a split on SepalLength
Choosing a split on SepalWidth
Choosing a split on PetalLength
Choosing a split on PetalWidth
(fitresult = (treemodel = JLBoostTreeModel(AbstractJLBoostTree[eta = 1.0 (t
ree weight)
(fitresult = (treemodel = JLBoost.JLBoostTrees.JLBoostTreeModel(JLBoost.JLB
oostTrees.AbstractJLBoostTree[eta = 1.0 (tree weight)
-- PetalLength <= 1.9
---- weight = 2.0
-- PetalLength > 1.9
---- weight = -2.0
], LogitLogLoss(), :__y__),
], JLBoost.LogitLogLoss(), :__y__),
target_levels = Bool[0, 1],),
cache = nothing,
report = (AUC = 0.16666666666666669,
Expand Down
8 changes: 4 additions & 4 deletions src/mlj.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
export fit, predict, fitted_params, JLBoostMLJModel, JLBoostClassifier, JLBoostRegressor, JLBoostCount

#using MLJBase
import MLJBase
import MLJBase: Probabilistic, Deterministic, clean!, fit, predict, fitted_params, load_path, Table
import MLJBase: package_name, package_uuid, package_url, is_pure_julia, package_license
import MLJBase: input_scitype, target_scitype, docstring, UnivariateFinite

using ScientificTypes: Continuous, OrderedFactor, Count, Multiclass, Finite

using LossFunctions: PoissonLoss, L2DistLoss
using JLBoost: LogitLogLoss, jlboost, AUC, gini, feature_importance, predict
using JLBoost: LogitLogLoss, jlboost, AUC, gini, feature_importance

using DataFrames: DataFrame, nrow, levels, categorical

Expand Down Expand Up @@ -194,15 +194,15 @@ fitted_params(model::JLBoostMLJModel, fitresult) = (fitresult = fitresult.treemo


# seehttps://alan-turing-institute.github.io/MLJ.jl/stable/adding_models_for_general_use/#The-predict-method-1
predict(model::JLBoostClassifier, fitresult, Xnew) = begin
function MLJBase.predict(model::JLBoostClassifier, fitresult, Xnew)
res = JLBoost.predict(fitresult.treemodel, Xnew)
p = 1 ./ (1 .+ exp.(-res))
levels_cate = categorical(fitresult.target_levels)
[UnivariateFinite(levels_cate, [p, 1-p]) for p in p]
end


predict(model::JLBoostMLJModel, fitresult, Xnew) = begin
function MLJBase.predict(model::JLBoostMLJModel, fitresult, Xnew)
JLBoost.predict(fitresult, Xnew)
end

Expand Down

2 comments on commit 926da7e

@xiaodaigh
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/21896

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.9 -m "<description of version>" 926da7e6dae849932700206616610e7f9b0f7367
git push origin v0.1.9

Please sign in to comment.