Skip to content

Commit

Permalink
Check Android SDK in runtime instead of compile time
Browse files Browse the repository at this point in the history
  • Loading branch information
daquexian committed Aug 11, 2019
1 parent 156eab7 commit ffe6d9a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 81 deletions.
95 changes: 57 additions & 38 deletions dnnlibrary/ModelBuilderImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ namespace dnn {
using namespace android::nn::wrapper;

// ModelBuilder auto generated methods start
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddConv(
const std::string &input, const std::string &weight,
const dnn::optional<std::string> &bias, int32_t padding_left,
int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
int32_t stride_x, int32_t stride_y, int32_t fuse_code,
const std::string &output,
const dnn::optional<QuantInfo> &output_quant_info) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Conv requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand Down Expand Up @@ -64,14 +66,15 @@ ModelBuilder::Index ModelBuilder::AddConv(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddAvePool(
const std::string &input, int32_t padding_left, int32_t padding_right,
int32_t padding_top, int32_t padding_bottom, int32_t stride_x,
int32_t stride_y, int32_t kernel_width, int32_t kernel_height,
int32_t fuse_code, const std::string &output,
const dnn::optional<QuantInfo> &output_quant_info) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("AvePool requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -90,14 +93,15 @@ ModelBuilder::Index ModelBuilder::AddAvePool(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddMaxPool(
const std::string &input, int32_t padding_left, int32_t padding_right,
int32_t padding_top, int32_t padding_bottom, int32_t stride_x,
int32_t stride_y, int32_t kernel_width, int32_t kernel_height,
int32_t fuse_code, const std::string &output,
const dnn::optional<QuantInfo> &output_quant_info) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("MaxPool requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -116,10 +120,11 @@ ModelBuilder::Index ModelBuilder::AddMaxPool(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddReLU(const std::string &input,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("ReLU requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -133,11 +138,12 @@ ModelBuilder::Index ModelBuilder::AddReLU(const std::string &input,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddSoftmax(const std::string &input,
float beta,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Softmax requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -152,13 +158,14 @@ ModelBuilder::Index ModelBuilder::AddSoftmax(const std::string &input,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddFC(
const std::string &input, const std::string &weight,
const dnn::optional<std::string> &bias, int32_t fuse_code,
const std::string &output,
const dnn::optional<QuantInfo> &output_quant_info) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("FC requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand Down Expand Up @@ -201,12 +208,13 @@ ModelBuilder::Index ModelBuilder::AddFC(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddAdd(
const std::string &input1, const std::string &input2, int32_t fuse_code,
const std::string &output,
const dnn::optional<QuantInfo> &output_quant_info) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Add requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input1);
const auto input1_idx = operand_indexes_.at(input1);
Expand All @@ -224,11 +232,12 @@ ModelBuilder::Index ModelBuilder::AddAdd(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddConcat(
const std::vector<std::string> &inputs, int32_t axis,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Concat requires API 27");
}
IndexSeq input_indexes;
for (const auto &x : inputs) {
imm_blob_inputs_.insert(x);
Expand All @@ -244,15 +253,16 @@ ModelBuilder::Index ModelBuilder::AddConcat(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddDepthwiseConv(
const std::string &input, const std::string &weight,
const dnn::optional<std::string> &bias, int32_t padding_left,
int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
int32_t stride_x, int32_t stride_y, int32_t depth_multiplier,
int32_t fuse_code, const std::string &output,
const dnn::optional<QuantInfo> &output_quant_info) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("DepthwiseConv requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand Down Expand Up @@ -299,11 +309,12 @@ ModelBuilder::Index ModelBuilder::AddDepthwiseConv(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 28
ModelBuilder::Index ModelBuilder::AddBatchToSpaceND(
const std::string &input, const std::vector<int32_t> &block_sizes,
const std::string &output) {
if (nnapi_->android_sdk_version < 28) {
throw std::runtime_error("BatchToSpaceND requires API 28");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -321,11 +332,12 @@ ModelBuilder::Index ModelBuilder::AddBatchToSpaceND(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 28
#if __ANDROID_API__ >= 28
ModelBuilder::Index ModelBuilder::AddSpaceToBatchND(
const std::string &input, const std::vector<int32_t> &block_sizes,
const std::vector<int32_t> &pads, const std::string &output) {
if (nnapi_->android_sdk_version < 28) {
throw std::runtime_error("SpaceToBatchND requires API 28");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -347,13 +359,14 @@ ModelBuilder::Index ModelBuilder::AddSpaceToBatchND(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 28
#if __ANDROID_API__ >= 28
ModelBuilder::Index ModelBuilder::AddStridedSlice(
const std::string &input, const std::vector<int32_t> &starts,
const std::vector<int32_t> &ends, const std::vector<int32_t> &strides,
int32_t begin_mask, int32_t end_mask, int32_t shrink_axis_mask,
const std::string &output) {
if (nnapi_->android_sdk_version < 28) {
throw std::runtime_error("StridedSlice requires API 28");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand Down Expand Up @@ -381,12 +394,13 @@ ModelBuilder::Index ModelBuilder::AddStridedSlice(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 28
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddMul(
const std::string &input1, const std::string &input2, int32_t fuse_code,
const std::string &output,
const dnn::optional<QuantInfo> &output_quant_info) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Mul requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input1);
const auto input1_idx = operand_indexes_.at(input1);
Expand All @@ -404,11 +418,12 @@ ModelBuilder::Index ModelBuilder::AddMul(
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddAdd(const std::string &input, float scalar,
int32_t fuse_code,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Add requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -426,11 +441,12 @@ ModelBuilder::Index ModelBuilder::AddAdd(const std::string &input, float scalar,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddMul(const std::string &input, float scalar,
int32_t fuse_code,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Mul requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -448,10 +464,11 @@ ModelBuilder::Index ModelBuilder::AddMul(const std::string &input, float scalar,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddDequantize(const std::string &input,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Dequantize requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -465,12 +482,13 @@ ModelBuilder::Index ModelBuilder::AddDequantize(const std::string &input,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddLRN(const std::string &input,
int32_t radius, float bias,
float alpha, float beta,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("LRN requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -486,10 +504,11 @@ ModelBuilder::Index ModelBuilder::AddLRN(const std::string &input,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddTanh(const std::string &input,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Tanh requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -503,10 +522,11 @@ ModelBuilder::Index ModelBuilder::AddTanh(const std::string &input,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
#if __ANDROID_API__ >= 27
ModelBuilder::Index ModelBuilder::AddFloor(const std::string &input,
const std::string &output) {
if (nnapi_->android_sdk_version < 27) {
throw std::runtime_error("Floor requires API 27");
}
IndexSeq input_indexes;
imm_blob_inputs_.insert(input);
const auto input_idx = operand_indexes_.at(input);
Expand All @@ -520,7 +540,6 @@ ModelBuilder::Index ModelBuilder::AddFloor(const std::string &input,
imm_blob_outputs_.insert(output);
return output_idx;
}
#endif // __ANDROID_API__ >= 27
// ModelBuilder auto generated methods end

// Methods for backward compatibility
Expand Down
7 changes: 3 additions & 4 deletions generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,15 @@ def generate_model_builder():
for i, op in enumerate(cfg):
if len(op['input']) == 0:
continue
cogoutl('#if __ANDROID_API__ >= {}'.format(op['api']))
ipt_opt = op['input'] + op['output']
params = list(map(get_param, ipt_opt))
if op['support_quant_asymm']:
params.append(('const dnn::optional<QuantInfo> &', 'output_quant_info'))
params_str = ', '.join(map(lambda param: "{} {}".format(*param), params))
cogoutl("ModelBuilder::Index ModelBuilder::Add{}({}) {{".format(op['name'], params_str))
cogoutl(f'if (nnapi_->android_sdk_version < {op["api"]}) {{'
f'throw std::runtime_error("{op["name"]} requires API {op["api"]}");'
f'}}')
tensor_input = list(filter(lambda x: x['nnapi_type'] == 'tensor', op['input']))
scalar_input = list(filter(lambda x: x['nnapi_type'] == 'scalar', op['input']))

Expand Down Expand Up @@ -317,19 +319,16 @@ def generate_model_builder():
}
'''
)
cogoutl('#endif // __ANDROID_API__ >= {}'.format(op['api']))
update_code('dnnlibrary/ModelBuilderImpl.cpp', 'ModelBuilder auto generated methods')
for i, op in enumerate(cfg):
if len(op['input']) == 0:
continue
cogoutl('#if __ANDROID_API__ >= {}'.format(op['api']))
ipt_opt = op['input'] + op['output']
params = list(map(get_param, ipt_opt))
if op['support_quant_asymm']:
params.append(('const dnn::optional<QuantInfo> &', 'output_quant_info'))
params_str = ', '.join(map(lambda param: "{} {}".format(*param), params))
cogoutl("ModelBuilder::Index Add{}({});".format(op['name'], params_str))
cogoutl('#endif // __ANDROID_API__ >= {}'.format(op['api']))
update_code('include/dnnlibrary/ModelBuilder.h', 'ModelBuilder auto generated methods')


Expand Down
Loading

0 comments on commit ffe6d9a

Please sign in to comment.