Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[FFI] npx.softmax, npx.activation, npx.batch_norm, npx.fully_connected (
Browse files Browse the repository at this point in the history
  • Loading branch information
barry-jin authored Mar 26, 2021
1 parent 03e7cc2 commit 9645e63
Show file tree
Hide file tree
Showing 12 changed files with 1,223 additions and 3 deletions.
5 changes: 4 additions & 1 deletion python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,9 @@ def write_all_str(module_file, module_all_list):

_NP_EXT_OP_PREFIX = '_npx_'
_NP_EXT_OP_SUBMODULE_LIST = ['_image_', '_random_']
_NP_EXT_OP_IMPLEMENTED_SET = {'_npx_softmax', '_npx_log_softmax', '_npx_masked_softmax',
'_npx_masked_log_softmax', '_npx_activation',
'_npx_batch_norm', '_npx_fully_connected'}

_NP_INTERNAL_OP_PREFIX = '_npi_'

Expand Down Expand Up @@ -855,7 +858,7 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op
elif np_module_name == 'numpy_extension':
op_name_prefix = _NP_EXT_OP_PREFIX
submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST
op_implemented_set = set()
op_implemented_set = _NP_EXT_OP_IMPLEMENTED_SET
elif np_module_name == 'numpy._internal':
op_name_prefix = _NP_INTERNAL_OP_PREFIX
submodule_name_list = []
Expand Down
24 changes: 24 additions & 0 deletions python/mxnet/ndarray/numpy_extension/_api_internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for numpy_extension api."""

from ..._ffi.function import _init_api

__all__ = []

_init_api("_npx", "mxnet.ndarray.numpy_extension._api_internal")
Loading

0 comments on commit 9645e63

Please sign in to comment.