Skip to content

Commit

Permalink
Merge branch 'main' into feat-compression
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Nov 5, 2024
2 parents 6d52996 + 8eb6b23 commit 50f8ef5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
1 change: 1 addition & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ py_pkg_cc_deps(
# Declares the external repository "nnlib_hifi4" from a zip archive of the
# foss-xtensa/nnlib-hifi4 GitHub project, pinned to a single upstream commit
# (the same hash appears in both `urls` and `strip_prefix`). `integrity`
# records the expected SHA-256 of the download, and the BUILD file is supplied
# from tflite_micro's third_party/xtensa tree.
http_archive(
name = "nnlib_hifi4",
build_file = "@tflite_micro//third_party/xtensa/nnlib_hifi4:nnlib_hifi4.BUILD",
integrity = "sha256-ulZ+uY4dRsbDUMZbZtD972eghclWQrqYRb0Y4Znfyyc=",
strip_prefix = "nnlib-hifi4-34f5f995f28d298ae2b6e2ba6e76c32a5cb34989",
urls = ["https://github.com/foss-xtensa/nnlib-hifi4/archive/34f5f995f28d298ae2b6e2ba6e76c32a5cb34989.zip"],
)
11 changes: 8 additions & 3 deletions tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
} else if (input->type == kTfLiteInt8) {
buf_size = arm_fully_connected_s8_get_buffer_size(&filter_dims);

int8_t* filter_data = GetTensorData<int8_t>(filter);
data->kernel_sums = nullptr;

#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
const int8_t* filter_data = GetTensorData<const int8_t>(filter);

if (buf_size > 0 && filter_data != nullptr) {
const int32_t input_offset = -data->reference_op_data.input_zero_point;
const int32_t filter_offset =
-data->reference_op_data.filter_zero_point;

data->kernel_sums = static_cast<int32_t*>(
context->AllocatePersistentBuffer(context, buf_size));

int32_t input_offset = -data->reference_op_data.input_zero_point;
int32_t filter_offset = -data->reference_op_data.filter_zero_point;
arm_vector_sum_s8(data->kernel_sums, filter_dims.n, data->output_depth,
filter_data, input_offset, filter_offset,
tflite::GetTensorData<int32_t>(bias));

// The kernel sums live in persistent memory, so no scratch buffer is requested.
buf_size = 0;
}
#endif
}
}

Expand Down
29 changes: 29 additions & 0 deletions tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ struct CmsisNnOpDataSvdf {
int effective_scale_1_b;
int effective_scale_2_b;
int scratch_tensor_index;
#if defined(KERNELS_OPTIMIZED_FOR_SIZE)
int scratch_weight_tensor_index;
#endif
int scratch_output_tensor_index;

// Cached tensor zero point values for quantized operations.
Expand Down Expand Up @@ -189,13 +192,25 @@ TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
const int32_t buf_size = arm_svdf_s8_get_buffer_size(&weights_feature_dims);

if (buf_size > 0) {
#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
data->kernel_sums = static_cast<int32_t*>(
context->AllocatePersistentBuffer(context, buf_size));

arm_vector_sum_s8(data->kernel_sums, input_size, num_filters,
GetTensorData<int8_t>(weights_feature),
-data->input_zero_point,
-data->activation_state_zero_point, nullptr);
#elif defined(KERNELS_OPTIMIZED_FOR_SIZE)
const TfLiteStatus scratch_kernel_status =
context->RequestScratchBufferInArena(
context, buf_size, &(data->scratch_weight_tensor_index));
TF_LITE_ENSURE_OK(context, scratch_kernel_status);
#else
MicroPrintf(
"Either KERNELS_OPTIMIZED_FOR_SIZE or KERNELS_OPTIMIZED_FOR_SPEED "
"must be defined");
return kTfLiteError;
#endif
}

} else {
Expand Down Expand Up @@ -291,7 +306,21 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
switch (weights_time_tensor->type) {
case kTfLiteInt8: {
cmsis_nn_context ctx;

#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
ctx.buf = data.kernel_sums;
#elif defined(KERNELS_OPTIMIZED_FOR_SIZE)
ctx.buf = static_cast<int32_t*>(
context->GetScratchBuffer(context, data.scratch_weight_tensor_index));

const int input_size = input_tensor->dims->data[1];
const int num_filters = weights_feature_tensor->dims->data[0];

arm_vector_sum_s8(
static_cast<int32_t*>(ctx.buf), input_size, num_filters,
tflite::micro::GetTensorData<int8_t>(weights_feature_tensor),
-data.input_zero_point, -data.activation_state_zero_point, nullptr);
#endif

arm_svdf_s8(
&ctx, &scratch_ctx, &scratch_output_ctx, &svdf_params,
Expand Down

0 comments on commit 50f8ef5

Please sign in to comment.