diff --git a/python/dllib/src/bigdl/dllib/utils/file_utils.py b/python/dllib/src/bigdl/dllib/utils/file_utils.py index 0e1fa58c3d0..f0f8b484e63 100644 --- a/python/dllib/src/bigdl/dllib/utils/file_utils.py +++ b/python/dllib/src/bigdl/dllib/utils/file_utils.py @@ -13,12 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from bigdl.util.common import Sample as BSample, JTensor as BJTensor, callBigDlFunc +from bigdl.util.common import Sample as BSample, JTensor as BJTensor,\ + JavaCreator, _get_gateway, _java2py, _py2java import numpy as np def to_list_of_numpy(elements): - if isinstance(elements, np.ndarray): return [elements] elif np.isscalar(elements): @@ -39,7 +39,28 @@ def to_list_of_numpy(elements): def set_core_number(num): - callBigDlFunc("float", "setCoreNumber", num) + callZooFunc("float", "setCoreNumber", num) + + +def callZooFunc(bigdl_type, name, *args): + """ Call API in PythonBigDL """ + gateway = _get_gateway() + args = [_py2java(gateway, a) for a in args] + error = Exception("Cannot find function: %s" % name) + for jinvoker in JavaCreator.instance(bigdl_type, gateway).value: + # hasattr(jinvoker, name) always return true here, + # so you need to invoke the method to check if it exist or not + try: + api = getattr(jinvoker, name) + java_result = api(*args) + result = _java2py(gateway, java_result) + except Exception as e: + error = e + if "does not exist" not in str(e): + raise e + else: + return result + raise error class JTensor(BJTensor):