-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #790 from Angel-ML/Auto-ML
AutoML
- Loading branch information
Showing
192 changed files
with
48,593 additions
and
32 deletions.
There are no files selected for viewing
55 changes: 55 additions & 0 deletions
55
angel-ps/mllib/src/main/scala/com/tencent/angel/ml/auto/Example.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
|
||
package com.tencent.angel.ml.auto | ||
|
||
import com.tencent.angel.ml.auto.acquisition.optimizer.{AcqOptimizer, RandomSearch} | ||
import com.tencent.angel.ml.auto.acquisition.{Acquisition, EI} | ||
import com.tencent.angel.ml.auto.config.ConfigurationSpace | ||
import com.tencent.angel.ml.auto.parameter.{ContinuousSpace, DiscreteSpace, ParamSpace} | ||
import com.tencent.angel.ml.auto.setting.Setting | ||
import com.tencent.angel.ml.auto.solver.{Solver, SolverWithTrail} | ||
import com.tencent.angel.ml.auto.surrogate.{RFSurrogate, Surrogate} | ||
import com.tencent.angel.ml.auto.trail.{TestTrail, Trail} | ||
import com.tencent.angel.ml.math2.vector.IntFloatVector | ||
|
||
object Example extends App { | ||
|
||
override def main(args: Array[String]): Unit = { | ||
val param1: ParamSpace[Float] = new ContinuousSpace("param1", 0, 10, 11) | ||
val param2: ParamSpace[Float] = new ContinuousSpace("param2", -5, 5, 11) | ||
val param3: ParamSpace[Float] = new DiscreteSpace[Float]("param3", List(0.0f, 1.0f, 3.0f, 5.0f)) | ||
val param4: ParamSpace[Float] = new DiscreteSpace[Float]("param4", List(-5.0f, -3.0f, 0.0f, 3.0f, 5.0f)) | ||
val cs: ConfigurationSpace = new ConfigurationSpace("cs") | ||
cs.addParam(param1) | ||
cs.addParam(param2) | ||
cs.addParam(param3) | ||
cs.addParam(param4) | ||
Setting.setBatchSize(1) | ||
Setting.setSampleSize(100) | ||
val sur: Surrogate = new RFSurrogate(cs.paramNum, true) | ||
val acq: Acquisition = new EI(sur, 0.1f) | ||
val opt: AcqOptimizer = new RandomSearch(acq, cs) | ||
val solver: Solver = new Solver(cs, sur, acq, opt) | ||
val trail: Trail = new TestTrail() | ||
val runner: SolverWithTrail = new SolverWithTrail(solver, trail) | ||
val result: (IntFloatVector, Float) = runner.run(100) | ||
sur.stop() | ||
println(s"Best configuration ${result._1.getStorage.getValues.mkString(",")}, best performance: ${result._2}") | ||
} | ||
} |
37 changes: 37 additions & 0 deletions
37
angel-ps/mllib/src/main/scala/com/tencent/angel/ml/auto/acquisition/Acquisition.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
|
||
package com.tencent.angel.ml.auto.acquisition | ||
|
||
import com.tencent.angel.ml.auto.surrogate.Surrogate | ||
import com.tencent.angel.ml.math2.vector.IntFloatVector | ||
|
||
/** | ||
* Abstract base class for acquisition function | ||
*/ | ||
abstract class Acquisition(val surrogate: Surrogate) { | ||
|
||
/** | ||
* Computes the acquisition value for a given point X | ||
* | ||
* @param X : (1, D), the input points where the acquisition function should be evaluated. | ||
* @return (1, 1) Expected Improvement of X, (1, D) Derivative of Expected Improvement at X | ||
*/ | ||
def compute(X: IntFloatVector, derivative: Boolean = false): (Float, IntFloatVector) | ||
|
||
} |
59 changes: 59 additions & 0 deletions
59
angel-ps/mllib/src/main/scala/com/tencent/angel/ml/auto/acquisition/EI.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
|
||
package com.tencent.angel.ml.auto.acquisition | ||
|
||
import com.tencent.angel.ml.auto.surrogate.Surrogate | ||
import com.tencent.angel.ml.math2.storage.IntFloatDenseVectorStorage | ||
import com.tencent.angel.ml.math2.vector.IntFloatVector | ||
import org.apache.commons.logging.{Log, LogFactory} | ||
import org.apache.commons.math3.distribution.NormalDistribution | ||
|
||
/** | ||
* Expected improvement. | ||
* @param surrogate | ||
* @param par : Controls the balance between exploration and exploitation of the acquisition function, default=0.0 | ||
* | ||
*/ | ||
class EI(override val surrogate: Surrogate, val par: Float) extends Acquisition(surrogate) { | ||
val LOG: Log = LogFactory.getLog(classOf[Surrogate]) | ||
|
||
override def compute(X: IntFloatVector, derivative: Boolean = false): (Float, IntFloatVector) = { | ||
val pred = surrogate.predict(X) // (mean, variance) | ||
|
||
// Use the best seen observation as incumbent | ||
val eta: Float = surrogate.curBest._2 | ||
//println(s"best seen result: $eta") | ||
|
||
val s: Float = Math.sqrt(pred._2).toFloat | ||
|
||
if (s == 0) { | ||
// if std is zero, we have observed x on all instances | ||
// using a RF, std should be never exactly 0.0 | ||
(0.0f, new IntFloatVector(X.dim().toInt, new IntFloatDenseVectorStorage())) | ||
} else { | ||
val z = (eta - pred._1 - par) / s | ||
val norm: NormalDistribution = new NormalDistribution | ||
val cdf: Double = norm.cumulativeProbability(z) | ||
val pdf: Double = norm.density(z) | ||
val f = s * (z * cdf + pdf) | ||
println(s"cur best: $eta, z: $z, cdf: $cdf, pdf: $pdf, f: $f") | ||
(f.toFloat, new IntFloatVector(X.dim().toInt, new IntFloatDenseVectorStorage())) | ||
} | ||
} | ||
} |
40 changes: 40 additions & 0 deletions
40
...s/mllib/src/main/scala/com/tencent/angel/ml/auto/acquisition/optimizer/AcqOptimizer.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
|
||
package com.tencent.angel.ml.auto.acquisition.optimizer | ||
|
||
import com.tencent.angel.ml.auto.acquisition.Acquisition | ||
import com.tencent.angel.ml.auto.config.{Configuration,ConfigurationSpace} | ||
|
||
/** | ||
* Abstract base class for acquisition maximization. | ||
* @param acqFunc : The acquisition function which will be maximized | ||
* @param configSpace : Configuration space of parameters | ||
*/ | ||
abstract class AcqOptimizer(val acqFunc: Acquisition, val configSpace: ConfigurationSpace) { | ||
|
||
/** | ||
* Maximizes the given acquisition function. | ||
* | ||
* @param numPoints : Number of queried points. | ||
* @return A set of tuple(acquisition value, Configuration). | ||
*/ | ||
def maximize(numPoints: Int, sorted: Boolean = true): List[(Float, Configuration)] | ||
|
||
def maximize: (Float, Configuration) | ||
} |
46 changes: 46 additions & 0 deletions
46
...ps/mllib/src/main/scala/com/tencent/angel/ml/auto/acquisition/optimizer/LocalSearch.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
|
||
package com.tencent.angel.ml.auto.acquisition.optimizer | ||
|
||
import com.tencent.angel.ml.auto.acquisition.Acquisition | ||
import com.tencent.angel.ml.auto.config.{Configuration, ConfigurationSpace} | ||
|
||
/** | ||
* Implementation of local search. | ||
* | ||
* @param acqFunc : The acquisition function which will be maximized | ||
* @param configSpace : Configuration space of parameters | ||
* @param epsilon : In order to perform a local move one of the incumbent's neighbors needs at least an improvement higher than epsilon | ||
* @param numIters : Maximum number of iterations that the local search will perform | ||
*/ | ||
class LocalSearch(override val acqFunc: Acquisition, override val configSpace: ConfigurationSpace, | ||
epsilon: String, numIters: Int) | ||
extends AcqOptimizer(acqFunc, configSpace) { | ||
|
||
/** | ||
* Starts a local search from the given start point and quits if either the max number of steps is reached or | ||
* no neighbor with an higher improvement was found | ||
* | ||
* @param numPoints : Number of queried points. | ||
* @return A set of tuple(acquisition_value, Configuration). | ||
*/ | ||
override def maximize(numPoints: Int, sorted: Boolean = true): List[(Float, Configuration)] = ??? | ||
|
||
override def maximize: (Float, Configuration) = ??? | ||
} |
54 changes: 54 additions & 0 deletions
54
...s/mllib/src/main/scala/com/tencent/angel/ml/auto/acquisition/optimizer/RandomSearch.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
|
||
package com.tencent.angel.ml.auto.acquisition.optimizer | ||
|
||
import com.tencent.angel.ml.auto.acquisition.Acquisition | ||
import com.tencent.angel.ml.auto.config.{Configuration, ConfigurationSpace} | ||
import com.tencent.angel.ml.auto.setting.Setting | ||
import org.apache.commons.logging.{Log, LogFactory} | ||
|
||
import scala.util.Random | ||
|
||
/** | ||
* Get candidate solutions via random sampling of configurations. | ||
* | ||
* @param acqFunc : The acquisition function which will be maximized | ||
* @param configSpace : Configuration space of parameters | ||
* @param seed | ||
*/ | ||
class RandomSearch(override val acqFunc: Acquisition, override val configSpace: ConfigurationSpace, | ||
seed: Int = 100) extends AcqOptimizer(acqFunc, configSpace) { | ||
val LOG: Log = LogFactory.getLog(classOf[RandomSearch]) | ||
|
||
val rd = new Random(seed) | ||
|
||
override def maximize(numPoints: Int, sorted: Boolean = true): List[(Float, Configuration)] = { | ||
//println(s"maximize RandomSearch") | ||
val configs: List[Configuration] = configSpace.sampleConfig(Setting.sampleSize) | ||
configs.foreach( config => println(s"sample a configuration: ${config.getVector.getStorage.getValues.mkString(",")}")) | ||
if (sorted) | ||
configs.map{config => (acqFunc.compute(config.getVector)._1, config)}.sortWith(_._1 > _._1).take(numPoints) | ||
else | ||
rd.shuffle(configs.map{config => (0.0f, config)}).take(numPoints) | ||
} | ||
|
||
override def maximize: (Float, Configuration) = { | ||
maximize(1, true).head | ||
} | ||
} |
42 changes: 42 additions & 0 deletions
42
angel-ps/mllib/src/main/scala/com/tencent/angel/ml/auto/config/Configuration.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/* | ||
* Tencent is pleased to support the open source community by making Angel available. | ||
* | ||
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in | ||
* compliance with the License. You may obtain a copy of the License at | ||
* | ||
* https://opensource.org/licenses/Apache-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software distributed under the License | ||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
* or implied. See the License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
*/ | ||
|
||
|
||
package com.tencent.angel.ml.auto.config | ||
|
||
import com.tencent.angel.ml.math2.vector.IntFloatVector | ||
|
||
/** | ||
* A single configuration | ||
* | ||
* @param configSpace : The configuration space for this configuration | ||
* @param vector : A vector for efficient representation of configuration. | ||
*/ | ||
class Configuration(configSpace: ConfigurationSpace, vector: IntFloatVector) { | ||
|
||
def getVector: IntFloatVector = vector | ||
|
||
def getValues: List[Float] = vector.getStorage.getValues.toList | ||
|
||
def keys: List[String] = configSpace.param2Idx.keys.toList | ||
|
||
def get(name: String): Float = get(configSpace.param2Idx.getOrElse(name, -1)) | ||
|
||
def get(idx: Int): Float = vector.get(idx) | ||
|
||
def contains(name: String): Boolean = configSpace.param2Idx.contains(name) | ||
} |
Oops, something went wrong.