Skip to content

Commit

Permalink
Update default value of AUC to 200 (#238)
Browse files Browse the repository at this point in the history
* update

* typo
  • Loading branch information
zhichao-li authored May 22, 2018
1 parent 8240713 commit a95b938
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def inference(image_path, model_path, sc):
print("Need parameters: <modelPath> <imagePath>")
exit(-1)

sc = get_nncontext()
sc = get_nncontext("image_inference")

model_path = sys.argv[1]
image_path = sys.argv[2]
Expand Down
2 changes: 1 addition & 1 deletion pyzoo/zoo/pipeline/api/keras/metrics/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ class AUC(JavaValue):
>>> meter = AUC(20)
creating: createAUC
"""
def __init__(self, threshold_num, bigdl_type="float"):
def __init__(self, threshold_num=200, bigdl_type="float"):
JavaValue.__init__(self, None, bigdl_type, threshold_num)
2 changes: 1 addition & 1 deletion pyzoo/zoo/pipeline/api/keras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def to_bigdl_metric(metric):
elif metric == "mae":
return MAE()
elif metric == "auc":
return AUC(1000)
return AUC()
elif metric == "loss":
return Loss()
elif metric == "treennaccuracy":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ object Predict {

def main(args: Array[String]): Unit = {
parser.parse(args, TopNClassificationParam()).foreach { params =>
val sc = NNContext.getNNContext()
val sc = NNContext.getNNContext("Image Classification")
val model = ImageClassifier.loadModel[Float](params.model)
val data = ImageSet.read(params.imageFolder, sc, params.nPartition)
val output = model.predictImageSet(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ object ImageInferenceExample {

val defaultParams = Utils.LocalParams()
Utils.parser.parse(args, defaultParams).foreach { params =>
val sc = NNContext.getNNContext()
val sc = NNContext.getNNContext("ImageInference")

val getImageName = udf { row: Row => row.getString(0)}
val imageDF = NNImageReader.readImages(params.folder, sc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ object KerasUtils {
metric.toLowerCase() match {
case "accuracy" => new Top1Accuracy[T]()
case "mae" => new MAE[T]()
case "auc" => new AUC[T](1000)
case "auc" => new AUC[T]()
case "loss" => new Loss[T]()
case "treennaccuracy" => new TreeNNAccuracy[T]()
case _ => throw new IllegalArgumentException(s"Unsupported metric: ${metric}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class AucScore(private val tp: Tensor[Float], private val fp: Tensor[Float],
* @param thresholdNum The number of thresholds. The quality of approximation
* may vary depending on thresholdNum.
*/
class AUC[T](thresholdNum: Int)(implicit ev: TensorNumeric[T])
class AUC[T](thresholdNum: Int = 200)(implicit ev: TensorNumeric[T])
extends ValidationMethod[T] {

override def apply(output: Activity, target: Activity):
Expand Down

0 comments on commit a95b938

Please sign in to comment.