Skip to content

Commit

Permalink
add callZooFunc and change all callBigDlFunc to callZooFunc (intel-an…
Browse files Browse the repository at this point in the history
  • Loading branch information
qiuxin2012 authored Nov 26, 2019
1 parent 88ba569 commit ac66722
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions python/dllib/src/bigdl/dllib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# limitations under the License.
#

from bigdl.util.common import JavaValue, callBigDlFunc
from bigdl.util.common import JavaValue
from zoo.common.utils import callZooFunc


class Estimator(JavaValue):
Expand All @@ -24,9 +25,10 @@ class Estimator(JavaValue):
Estimator wraps a model, and provide an uniform training, evaluation or prediction operation on
both local host and distributed spark environment.
"""

def __init__(self, model, optim_methods=None, model_dir=None, jvalue=None, bigdl_type="float"):
self.bigdl_type = bigdl_type
self.value = jvalue if jvalue else callBigDlFunc(
self.value = jvalue if jvalue else callZooFunc(
bigdl_type, self.jvm_class_constructor(), model, optim_methods, model_dir)

def clear_gradient_clipping(self):
Expand All @@ -35,7 +37,7 @@ def clear_gradient_clipping(self):
In order to take effect, it needs to be called before fit.
:return:
"""
callBigDlFunc(self.bigdl_type, "clearGradientClipping")
callZooFunc(self.bigdl_type, "clearGradientClipping")

def set_constant_gradient_clipping(self, min, max):
"""
Expand All @@ -45,7 +47,7 @@ def set_constant_gradient_clipping(self, min, max):
:param max: The maximum value to clip by.
:return:
"""
callBigDlFunc(self.bigdl_type, "setConstantGradientClipping", self.value, min, max)
callZooFunc(self.bigdl_type, "setConstantGradientClipping", self.value, min, max)

def set_l2_norm_gradient_clipping(self, clip_norm):
"""
Expand All @@ -54,7 +56,7 @@ def set_l2_norm_gradient_clipping(self, clip_norm):
:param clip_norm: Gradient L2-Norm threshold.
:return:
"""
callBigDlFunc(self.bigdl_type, "setGradientClippingByL2Norm", self.value, clip_norm)
callZooFunc(self.bigdl_type, "setGradientClippingByL2Norm", self.value, clip_norm)

def train(self, train_set, criterion, end_trigger=None, checkpoint_trigger=None,
validation_set=None, validation_method=None, batch_size=32):
Expand All @@ -73,9 +75,9 @@ def train(self, train_set, criterion, end_trigger=None, checkpoint_trigger=None,
:param batch_size:
:return: Estimator
"""
callBigDlFunc(self.bigdl_type, "estimatorTrain", self.value, train_set,
criterion, end_trigger, checkpoint_trigger, validation_set,
validation_method, batch_size)
callZooFunc(self.bigdl_type, "estimatorTrain", self.value, train_set,
criterion, end_trigger, checkpoint_trigger, validation_set,
validation_method, batch_size)

def train_imagefeature(self, train_set, criterion, end_trigger=None, checkpoint_trigger=None,
validation_set=None, validation_method=None, batch_size=32):
Expand All @@ -94,9 +96,9 @@ def train_imagefeature(self, train_set, criterion, end_trigger=None, checkpoint_
:param batch_size: Batch size
:return:
"""
callBigDlFunc(self.bigdl_type, "estimatorTrainImageFeature", self.value, train_set,
criterion, end_trigger, checkpoint_trigger, validation_set,
validation_method, batch_size)
callZooFunc(self.bigdl_type, "estimatorTrainImageFeature", self.value, train_set,
criterion, end_trigger, checkpoint_trigger, validation_set,
validation_method, batch_size)

def evaluate(self, validation_set, validation_method, batch_size=32):
"""
Expand All @@ -106,8 +108,8 @@ def evaluate(self, validation_set, validation_method, batch_size=32):
:param batch_size: batch size
:return: validation results
"""
callBigDlFunc(self.bigdl_type, "estimatorEvaluate", self.value,
validation_set, validation_method, batch_size)
callZooFunc(self.bigdl_type, "estimatorEvaluate", self.value,
validation_set, validation_method, batch_size)

def evaluate_imagefeature(self, validation_set, validation_method, batch_size=32):
"""
Expand All @@ -117,5 +119,5 @@ def evaluate_imagefeature(self, validation_set, validation_method, batch_size=32
:param batch_size: batch size
:return: validation results
"""
callBigDlFunc(self.bigdl_type, "estimatorEvaluateImageFeature", self.value,
validation_set, validation_method, batch_size)
callZooFunc(self.bigdl_type, "estimatorEvaluateImageFeature", self.value,
validation_set, validation_method, batch_size)

0 comments on commit ac66722

Please sign in to comment.