forked from ARM-software/CMSIS-NN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Treat DW conv with 1 input ch as a regular conv op (ARM-software#156)
Weights are transposed if the optimal output channel threshold condition is met before calling the conv wrapper.
- Loading branch information
Showing
14 changed files
with
289 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <[email protected]> | ||
* SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <[email protected]> | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
|
@@ -22,14 +22,15 @@ | |
* Description: Wrapper API to select appropriate depthwise conv API based | ||
* on dimensions. | ||
* | ||
* $Date: 13 January 2023 | ||
* $Revision: V.2.1.0 | ||
* $Date: 04 November 2024 | ||
* $Revision: V.2.2.0 | ||
* | ||
* Target : Arm(R) M-Profile Architecture | ||
* | ||
* -------------------------------------------------------------------- */ | ||
|
||
#include "arm_nnfunctions.h" | ||
#include "arm_nnsupportfunctions.h" | ||
|
||
/** | ||
* @ingroup Public | ||
|
@@ -40,6 +41,51 @@ | |
* @{ | ||
*/ | ||
|
||
#if defined(ARM_MATH_MVEI) | ||
static arm_cmsis_nn_status arm_depthwise_conv_to_conv_s8(const cmsis_nn_context *ctx, | ||
const cmsis_nn_dw_conv_params *dw_conv_params, | ||
const cmsis_nn_per_channel_quant_params *quant_params, | ||
const cmsis_nn_dims *input_dims, | ||
const int8_t *input, | ||
const cmsis_nn_dims *filter_dims, | ||
const int8_t *filter, | ||
const cmsis_nn_dims *bias_dims, | ||
const int32_t *bias, | ||
const cmsis_nn_dims *output_dims, | ||
int8_t *output) | ||
{ | ||
const cmsis_nn_conv_params conv_params = {dw_conv_params->input_offset, | ||
dw_conv_params->output_offset, | ||
dw_conv_params->stride, | ||
dw_conv_params->padding, | ||
dw_conv_params->dilation, | ||
dw_conv_params->activation}; | ||
const cmsis_nn_dims filter_output_dims = {filter_dims->c, filter_dims->h, filter_dims->w, filter_dims->n}; | ||
int8_t *w_buf = | ||
ctx->buf + arm_convolve_wrapper_s8_get_buffer_size(&conv_params, input_dims, &filter_output_dims, output_dims); | ||
const uint32_t perm[4] = {3, 1, 2, 0}; | ||
const cmsis_nn_transpose_params transpose_params = {4, perm}; | ||
|
||
arm_cmsis_nn_status status = arm_transpose_s8(filter, w_buf, filter_dims, &filter_output_dims, &transpose_params); | ||
|
||
if (status == ARM_CMSIS_NN_SUCCESS) | ||
{ | ||
status = arm_convolve_wrapper_s8(ctx, | ||
&conv_params, | ||
quant_params, | ||
input_dims, | ||
input, | ||
&filter_output_dims, | ||
(const int8_t *)w_buf, | ||
bias_dims, | ||
bias, | ||
output_dims, | ||
output); | ||
} | ||
return status; | ||
} | ||
#endif | ||
|
||
/* | ||
* s8 Depthwise conv wrapper function | ||
* | ||
|
@@ -59,6 +105,24 @@ arm_cmsis_nn_status arm_depthwise_conv_wrapper_s8(const cmsis_nn_context *ctx, | |
int8_t *output) | ||
{ | ||
arm_cmsis_nn_status status = ARM_CMSIS_NN_SUCCESS; | ||
|
||
#if defined(ARM_MATH_MVEI) | ||
if (input_dims->c == 1 && output_dims->c > CONVERT_DW_CONV_WITH_ONE_INPUT_CH_AND_OUTPUT_CH_ABOVE_THRESHOLD) | ||
{ | ||
return arm_depthwise_conv_to_conv_s8(ctx, | ||
dw_conv_params, | ||
quant_params, | ||
input_dims, | ||
input, | ||
filter_dims, | ||
filter, | ||
bias_dims, | ||
bias, | ||
output_dims, | ||
output); | ||
} | ||
#endif | ||
|
||
if (1 == dw_conv_params->ch_mult && input_dims->n == 1 && dw_conv_params->dilation.w == 1 && | ||
dw_conv_params->dilation.h == 1) | ||
{ | ||
|
6 changes: 6 additions & 0 deletions
6
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/biases_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#pragma once | ||
#include <stdint.h> | ||
|
||
const int32_t in_ch_one_out_ch_larger_one_biases[1] = {-4565}; |
25 changes: 25 additions & 0 deletions
25
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/config_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#pragma once | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_OUT_CH 1 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_IN_CH 1 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_INPUT_W 7 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_INPUT_H 7 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_DST_SIZE 16 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_INPUT_SIZE 49 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_OUT_ACTIVATION_MIN -128 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_OUT_ACTIVATION_MAX 127 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_INPUT_BATCHES 1 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_FILTER_X 3 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_FILTER_Y 3 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_STRIDE_X 2 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_STRIDE_Y 2 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_PAD_X 1 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_PAD_Y 1 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_OUTPUT_W 4 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_OUTPUT_H 4 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_CH_MULT 1 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_INPUT_OFFSET 128 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_OUTPUT_OFFSET 127 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_DILATION_X 1 | ||
#define IN_CH_ONE_OUT_CH_LARGER_ONE_DILATION_Y 1 |
9 changes: 9 additions & 0 deletions
9
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/input_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#pragma once | ||
#include <stdint.h> | ||
|
||
const int8_t in_ch_one_out_ch_larger_one_input[49] = { | ||
-65, 36, 56, -82, 109, 99, -113, -63, 47, -83, -100, 123, 46, 125, -52, 65, 12, | ||
-55, 11, -85, 123, 97, -55, 79, 33, 39, -39, 64, -1, 89, -8, 17, -16, -90, | ||
-66, 58, 126, 36, -52, 46, 66, -83, -125, -93, -52, -61, -14, -62, -76}; |
6 changes: 6 additions & 0 deletions
6
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/output_mult_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#pragma once | ||
#include <stdint.h> | ||
|
||
const int32_t in_ch_one_out_ch_larger_one_output_mult[1] = {2129586399}; |
7 changes: 7 additions & 0 deletions
7
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/output_ref_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#pragma once | ||
#include <stdint.h> | ||
|
||
const int8_t in_ch_one_out_ch_larger_one_output_ref[16] = | ||
{97, 22, 11, 36, 70, 24, 5, -68, 35, -27, 33, -2, 121, 38, 72, 72}; |
6 changes: 6 additions & 0 deletions
6
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/output_shift_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#pragma once | ||
#include <stdint.h> | ||
|
||
const int32_t in_ch_one_out_ch_larger_one_output_shift[1] = {-9}; |
9 changes: 9 additions & 0 deletions
9
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/test_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#include "biases_data.h" | ||
#include "config_data.h" | ||
#include "input_data.h" | ||
#include "output_mult_data.h" | ||
#include "output_ref_data.h" | ||
#include "output_shift_data.h" | ||
#include "weights_data.h" |
6 changes: 6 additions & 0 deletions
6
Tests/UnitTest/TestCases/TestData/in_ch_one_out_ch_larger_one/weights_data.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
// Generated by test_settings.py using tensorflow version 2.17.0 (Keras version 3.5.0). | ||
// Interpreter from tensorflow version 2.17.0 and revision v2.17.0-rc1-2-gad6d8cc177d. | ||
#pragma once | ||
#include <stdint.h> | ||
|
||
const int8_t in_ch_one_out_ch_larger_one_weights[9] = {-65, -108, 97, 1, -127, -72, -124, -76, 79}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
/* | ||
* SPDX-FileCopyrightText: Copyright 2010-2023 Arm Limited and/or its affiliates <[email protected]> | ||
* SPDX-FileCopyrightText: Copyright 2010-2024 Arm Limited and/or its affiliates <[email protected]> | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
|
@@ -55,3 +55,8 @@ void test_depthwise_dilation_arm_depthwise_conv_s8(void) { depthwise_dilation_ar | |
void test_buffer_size_mve_arm_depthwise_conv_s8(void) { buffer_size_mve_arm_depthwise_conv_s8(); } | ||
|
||
void test_buffer_size_dsp_arm_depthwise_conv_s8(void) { buffer_size_dsp_arm_depthwise_conv_s8(); } | ||
|
||
void test_in_ch_one_out_ch_larger_one_arm_depthwise_conv_s8(void) | ||
{ | ||
in_ch_one_out_ch_larger_one_arm_depthwise_conv_s8(); | ||
} |
Oops, something went wrong.