Add MAML model
* Add an external MAML implementation as a submodule
* Load omniglot
* Model rewriting in transform.py
  - Change inputs for KWS to eliminate unsupported operators from the
    final ONNX model.
  - Save models at each step of rewriting for debugging
  - Infer dynamic shapes to get rid of the complex subgraphs produced by
    forward steps like `x = x.view(x.size(0), -1)`
  - Constant folding for Squeeze and Reshape nodes with known new shape
    and constant input (a rough sketch of the idea follows below)
  - Reduce global variables so that the latest model is always used
    after model rewriting
* Other changes in transform.py
  - Move more ONNX helpers to utils.py
  - Make transformation of input samples more robust in terms of input
    shape
  - Don't use a default batch size; make that argument required.
* Implement BatchNormalization
* Use libc_nano instead of libc, as the omniglot-maml model is otherwise too large

[1] onnx/optimizer#38
[2] microsoft/onnxruntime#5577
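
For context, the constant folding mentioned above can be pictured roughly as below: a Reshape whose data and shape inputs are both initializers is collapsed into a single initializer so the node never reaches the final graph. This is a minimal sketch of the idea, not the transform.py implementation; `fold_constant_reshapes` is a hypothetical helper name.

    # Sketch only: fold Reshape nodes whose data and shape inputs are both
    # constant initializers into a single initializer.
    import onnx
    from onnx import numpy_helper

    def fold_constant_reshapes(model: onnx.ModelProto) -> None:
        g = model.graph
        inits = {init.name: init for init in g.initializer}
        for node in list(g.node):
            if node.op_type != 'Reshape':
                continue
            data_name, shape_name = node.input[0], node.input[1]
            if data_name not in inits or shape_name not in inits:
                continue  # not fully constant; leave the node alone
            shape = numpy_helper.to_array(inits[shape_name])
            if (shape == 0).any():
                continue  # 0 means "copy this dim" in ONNX Reshape; skip for simplicity
            data = numpy_helper.to_array(inits[data_name])
            folded = data.reshape(tuple(int(d) for d in shape))
            g.initializer.append(numpy_helper.from_array(folded, name=node.output[0]))
            g.node.remove(node)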
Chih-Hsuan Yen committed Feb 7, 2022
1 parent b11ff62 commit 749fa97
Showing 13 changed files with 188 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -14,3 +14,6 @@
[submodule "ARM-CMSIS"]
path = ARM-CMSIS
url = ssh://[email protected]/EMCLab-Sinica/ARM-CMSIS.git
[submodule "data/MAML-Pytorch"]
path = data/MAML-Pytorch
url = https://github.com/yan12125/MAML-Pytorch
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -47,7 +47,10 @@ set(ARM_CMSIS_PATH ${CMAKE_CURRENT_SOURCE_DIR}/ARM-CMSIS/CMSIS)
add_library(arm_cmsis_dsp
${ARM_CMSIS_PATH}/DSP/Source/BasicMathFunctions/arm_add_q15.c
${ARM_CMSIS_PATH}/DSP/Source/BasicMathFunctions/arm_offset_q15.c
${ARM_CMSIS_PATH}/DSP/Source/BasicMathFunctions/arm_mult_q15.c
${ARM_CMSIS_PATH}/DSP/Source/BasicMathFunctions/arm_scale_q15.c
${ARM_CMSIS_PATH}/DSP/Source/BasicMathFunctions/arm_sub_q15.c
${ARM_CMSIS_PATH}/DSP/Source/FastMathFunctions/arm_sqrt_q15.c
${ARM_CMSIS_PATH}/DSP/Source/MatrixFunctions/arm_mat_init_q15.c
${ARM_CMSIS_PATH}/DSP/Source/MatrixFunctions/arm_mat_mult_fast_q15.c
${ARM_CMSIS_PATH}/DSP/Source/StatisticsFunctions/arm_max_q15.c
@@ -74,7 +77,9 @@ add_library(dsplib
${DSPLIB_PATH}/source/vector/msp_offset_q15.c
${DSPLIB_PATH}/source/vector/msp_max_q15.c
${DSPLIB_PATH}/source/vector/msp_min_q15.c
${DSPLIB_PATH}/source/vector/msp_mpy_q15.c
${DSPLIB_PATH}/source/vector/msp_scale_q15.c
${DSPLIB_PATH}/source/vector/msp_sub_q15.c
${DSPLIB_PATH}/source/utility/msp_deinterleave_q15.c
${DSPLIB_PATH}/source/utility/msp_interleave_q15.c
${DSPLIB_PATH}/source/utility/msp_fill_q15.c
2 changes: 2 additions & 0 deletions common/cnn_common.cpp
@@ -56,6 +56,8 @@ int64_t get_int64_param(const ParameterInfo *param, size_t i) {
MY_ASSERT(param->bitwidth == 64);
MY_ASSERT(param->slot == SLOT_PARAMETERS);
int64_t ret;
// detect mis-aligned memory access on ARM
MY_ASSERT(param->params_offset % 4 == 0);
my_memcpy_from_parameters(&ret, param, i * sizeof(int64_t), sizeof(int64_t));
return ret;

4 changes: 2 additions & 2 deletions common/intermittent-cnn.cpp
@@ -21,8 +21,8 @@ static void handle_node(Model *model, uint16_t node_idx) {
my_printf("op_type = %d" NEWLINE, cur_node->op_type);
#endif

int16_t input_id[3];
const ParameterInfo *input[3];
int16_t input_id[NUM_INPUTS];
const ParameterInfo *input[NUM_INPUTS];
for (uint16_t j = 0; j < cur_node->inputs_len; j++) {
input_id[j] = cur_node->inputs[j];
my_printf_debug("input_id[%d] = %d" NEWLINE, j, input_id[j]);
55 changes: 54 additions & 1 deletion common/my_dsplib.cpp
@@ -170,7 +170,7 @@ void my_matrix_mpy_q15(uint16_t A_rows, uint16_t A_cols, uint16_t B_rows, uint16
// appears to still give correct results when srcARows is odd
// srcBCols should really be even, though
// http://e2e.ti.com/support/microcontrollers/msp430/f/166/t/716353?MSP430FR5992-MSP-DSPLib-msp-matrix-mpy-q15
MY_ASSERT((A_cols & 1) || (B_cols & 1) == 0);
// MY_ASSERT((A_cols & 1) || (B_cols & 1) == 0); // FIXME
MY_ASSERT(B_rows * B_cols <= ARM_PSTATE_LEN);
MY_ASSERT(A_cols == B_rows);
check_buffer_address(pSrcA, A_rows * A_cols);
@@ -200,6 +200,26 @@ void my_matrix_mpy_q15(uint16_t A_rows, uint16_t A_cols, uint16_t B_rows, uint16
#endif
}

void my_mpy_q15(const int16_t *pSrcA, const int16_t *pSrcB, int16_t *pDst, uint32_t blockSize) {
    check_buffer_address(pSrcA, blockSize);
    check_buffer_address(pSrcB, blockSize);
    check_buffer_address(pDst, blockSize);
#if !USE_ARM_CMSIS
    // LEA requires an even vector length; round down and handle the last element separately
    uint32_t blockSizeForLEA = blockSize / 2 * 2;
    if (blockSizeForLEA) {
        msp_mpy_q15_params mpy_params;
        mpy_params.length = blockSizeForLEA;
        msp_status status = msp_mpy_q15(&mpy_params, pSrcA, pSrcB, pDst);
        my_checkStatus(status);
    }
    if (blockSize % 2) {
        // Q15 multiply for the odd tail element not handled by LEA
        pDst[blockSize - 1] = static_cast<int16_t>((static_cast<int32_t>(pSrcA[blockSize - 1]) * pSrcB[blockSize - 1]) >> 15);
    }
#else
    arm_mult_q15(pSrcA, pSrcB, pDst, blockSize);
#endif
}

void my_scale_q15(const int16_t *pSrc, int16_t scaleFract, uint8_t shift, int16_t *pDst, uint32_t blockSize) {
#if !USE_ARM_CMSIS
uint32_t blockSizeForLEA = blockSize / 2 * 2;
@@ -218,6 +238,39 @@ void my_scale_q15(const int16_t *pSrc, int16_t scaleFract, uint8_t shift, int16_
#endif
}

void my_sub_q15(const int16_t *pSrcA, const int16_t *pSrcB, int16_t *pDst, uint32_t blockSize) {
    check_buffer_address(pSrcA, blockSize);
    check_buffer_address(pSrcB, blockSize);
    check_buffer_address(pDst, blockSize);
#if !USE_ARM_CMSIS
    uint32_t blockSizeForLEA = blockSize / 2 * 2;
    if (blockSizeForLEA) {
        msp_sub_q15_params sub_params;
        sub_params.length = blockSizeForLEA;
        msp_status status = msp_sub_q15(&sub_params, pSrcA, pSrcB, pDst);
        my_checkStatus(status);
    }
    if (blockSize % 2) {
        // subtract the odd tail element not handled by LEA
        pDst[blockSize - 1] = pSrcA[blockSize - 1] - pSrcB[blockSize - 1];
    }
#else
    arm_sub_q15(pSrcA, pSrcB, pDst, blockSize);
#endif
}

void my_vsqrt_q15(int16_t* pIn, int16_t* pOut, uint32_t blockSize) {
#if !USE_ARM_CMSIS
    ERROR_OCCURRED();
#else
    // arm_vsqrt_q15 is declared in CMSIS headers but has no implementation,
    // so compute the square root element-wise with arm_sqrt_q15
    for (uint32_t idx = 0; idx < blockSize; idx++) {
        arm_sqrt_q15(*pIn, pOut);
        pIn++;
        pOut++;
    }
#endif
}

void my_interleave_q15(const int16_t *pSrc, uint16_t channel, uint16_t numChannels, int16_t *pDst, uint32_t blockSize) {
MY_ASSERT(channel < numChannels);
// XXX: not using LEA here as pSrc and/or pDst is often unaligned
3 changes: 3 additions & 0 deletions common/my_dsplib.h
@@ -12,7 +12,10 @@ void my_matrix_mpy_q15(uint16_t A_rows, uint16_t A_cols, uint16_t B_rows, uint16
uint16_t mask, int16_t n_keep_state_bits);
void my_max_q15(const int16_t *pSrc, uint32_t blockSize, int16_t *pResult, uint16_t *pIndex);
void my_min_q15(const int16_t *pSrc, uint32_t blockSize, int16_t *pResult, uint16_t *pIndex);
void my_mpy_q15(const int16_t *pSrcA, const int16_t *pSrcB, int16_t *pDst, uint32_t blockSize);
void my_scale_q15(const int16_t *pSrc, int16_t scaleFract, uint8_t shift, int16_t *pDst, uint32_t blockSize);
void my_sub_q15(const int16_t *pSrcA, const int16_t *pSrcB, int16_t *pDst, uint32_t blockSize);
void my_vsqrt_q15(int16_t* pIn, int16_t* pOut, uint32_t blockSize);
void my_interleave_q15(const int16_t *pSrc, uint16_t channel, uint16_t numChannels, int16_t *pDst, uint32_t blockSize);
void my_deinterleave_q15(const int16_t *pSrc, uint16_t channel, uint16_t numChannels, int16_t *pDst, uint32_t blockSize);
int16_t padding_for_lea(int16_t val);
76 changes: 76 additions & 0 deletions common/op_handlers.cpp
@@ -1,4 +1,5 @@
#include <cstdint>
#include <cmath>
#include "cnn_common.h"
#include "data.h"
#include "op_utils.h"
@@ -265,3 +266,78 @@ void handle_add(Model *model, const ParameterInfo *input[], ParameterInfo *outpu
}
dump_params_nhwc_debug(model, output, node->output_name);
}

void alloc_batchnormalization(Model* model, const ParameterInfo* input[], ParameterInfo* output, const Node*) {
    const ParameterInfo* X = input[0];
    output->slot = get_next_slot(model, X);
}

void handle_batchnormalization(Model* model, const ParameterInfo* input[], ParameterInfo* output, const Node*) {
    my_printf_debug("BatchNormalization!" NEWLINE);

    const ParameterInfo *X = input[0], *scale = input[1], *B = input[2], *mean = input[3], *var = input[4];
    const uint16_t CHANNEL = X->dims[1], H = X->dims[2], W = X->dims[3];
    int16_t *buffer_x = lea_buffer,
            *buffer_scale = buffer_x + CHANNEL,
            *buffer_b = buffer_scale + CHANNEL,
            *buffer_mean = buffer_b + CHANNEL,
            *buffer_var = buffer_mean + CHANNEL;
    MY_ASSERT(buffer_var + CHANNEL < lea_buffer + LEA_BUFFER_SIZE);

    my_memcpy_from_param(model, buffer_scale, scale, 0, CHANNEL * sizeof(int16_t));
    my_memcpy_from_param(model, buffer_b, B, 0, CHANNEL * sizeof(int16_t));
    my_memcpy_from_param(model, buffer_mean, mean, 0, CHANNEL * sizeof(int16_t));
    my_memcpy_from_param(model, buffer_var, var, 0, CHANNEL * sizeof(int16_t));

    int16_t scaleFract;
    uint8_t shift;
    float_to_scale_params(&scaleFract, &shift, 1.0f * mean->scale / (X->scale * 2));
    my_scale_q15(buffer_mean, scaleFract, shift, buffer_mean, CHANNEL);

    int16_t var_scale_sqrt = static_cast<int16_t>(sqrtf(1.0f * var->scale));
    MY_ASSERT(var_scale_sqrt * var_scale_sqrt == var->scale);

    float_to_scale_params(&scaleFract, &shift, 1.0f * var_scale_sqrt / (X->scale * 2));
    my_scale_q15(buffer_b, scaleFract, shift, buffer_b, CHANNEL);

    output->scale = scale->scale * (X->scale * 2) / var_scale_sqrt;

    // assume the conventional epsilon (1e-5)
    my_offset_q15(buffer_var, static_cast<int16_t>(0.00001 * 0x8000 / var->scale), buffer_var, CHANNEL);
    my_vsqrt_q15(buffer_var, buffer_var, CHANNEL);

    uint32_t offset = 0;
    for (uint16_t idx = 0; idx < H * W; idx++) {
        my_memcpy_from_param(model, buffer_x, X, offset, CHANNEL * sizeof(int16_t));

        my_printf_debug("(h, w) = (%d, %d)" NEWLINE, idx / W, idx % W);

        my_sub_q15(buffer_x, buffer_mean, buffer_x, CHANNEL);
        my_printf_debug("x - mean" NEWLINE);
        dump_matrix_debug(buffer_x, CHANNEL, ValueInfo(X->scale * 2));

        // XXX: use LEA?
        for (uint16_t channel = 0; channel < CHANNEL; channel++) {
            // Q15 division: https://sestevenson.wordpress.com/2010/09/20/fixed-point-division-2/
            int32_t tmp = (static_cast<int32_t>(buffer_x[channel]) << 15) / static_cast<int32_t>(buffer_var[channel]);
            tmp = MIN_VAL(32767, MAX_VAL(-32768, tmp));
            buffer_x[channel] = static_cast<int16_t>(tmp);
        }
        my_printf_debug("(x - mean)/sqrt(var+epsilon)" NEWLINE);
        dump_matrix_debug(buffer_x, CHANNEL, ValueInfo((X->scale * 2) / var_scale_sqrt));

        my_mpy_q15(buffer_x, buffer_scale, buffer_x, CHANNEL);
        my_printf_debug("(x - mean)/sqrt(var+epsilon)*scale" NEWLINE);
        dump_matrix_debug(buffer_x, CHANNEL, ValueInfo(output, model));

        my_add_q15(buffer_x, buffer_b, buffer_x, CHANNEL);
        my_printf_debug("(x - mean)/sqrt(var+epsilon)*scale+B" NEWLINE);
        dump_matrix_debug(buffer_x, CHANNEL, ValueInfo(output, model));

        my_memcpy_to_param(output, offset, buffer_x, CHANNEL * sizeof(int16_t), 0);
        offset += CHANNEL;
    }

    my_printf_debug("handle_batchnormalization output" NEWLINE);
    dump_params_nhwc_debug(model, output);
}
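
For reference, the fixed-point loop above approximates the standard BatchNormalization formula y = (x - mean) / sqrt(var + epsilon) * scale + B, evaluated per channel. A floating-point sketch of that reference behavior (for clarity only, not code from this commit; `batchnorm_reference` is a hypothetical name):

    import numpy as np

    def batchnorm_reference(x, scale, b, mean, var, epsilon=1e-5):
        # x: (C, H, W); scale, b, mean, var: (C,)
        inv_std = 1.0 / np.sqrt(var + epsilon)
        return (x - mean[:, None, None]) * inv_std[:, None, None] * scale[:, None, None] + b[:, None, None]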
14 changes: 14 additions & 0 deletions configs.py
@@ -3,6 +3,7 @@
load_data_cifar10,
load_data_google_speech,
load_har,
load_data_omniglot,
)

# intermediate_values_size should < 65536, or TI's compiler gets confused
@@ -59,5 +60,18 @@
        'first_sample_outputs': [ -6.194588, 2.2284777, -13.659239, -1.4972568, 13.473643, -10.446839 ],
        'fp32_accuracy': 0.9121,
    },
    'omniglot': {
        'onnx_model': 'data/maml.onnx',
        'scale': 4,
        'input_scale': 4,
        'num_slots': 2,
        'intermediate_values_size': 30000,
        'data_loader': load_data_omniglot,
        'n_all_samples': 5 * 20, # 5-way (classes), each with 20 samples
        'sample_size': [1, 28, 28],
        'op_filters': 4,
        'first_sample_outputs': [ -0.230564, -0.879236, -0.910271, -0.212429, 0.534965 ],
        'fp32_accuracy': 0,
    },
}

1 change: 1 addition & 0 deletions data/MAML-Pytorch
Submodule MAML-Pytorch added at 6de70c
Binary file added data/maml.onnx
3 changes: 2 additions & 1 deletion msp432/.cproject
@@ -84,7 +84,8 @@
<option id="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.linkerID.OUTPUT_FILE.1474882655" name="Output file (-o)" superClass="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.linkerID.OUTPUT_FILE" value="${ProjName}.out" valueType="string"/>
<option id="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.linkerID.MAP_FILE.20271074" name="Write a map file (-Map)" superClass="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.linkerID.MAP_FILE" value="${ProjName}.map" valueType="string"/>
<option IS_BUILTIN_EMPTY="false" IS_VALUE_EMPTY="false" id="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.linkerID.LIBRARY.570316691" name="Libraries (-l, --library)" superClass="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.linkerID.LIBRARY" valueType="libs">
<listOptionValue builtIn="false" value="c"/>
<listOptionValue builtIn="false" value="c_nano"/>
<listOptionValue builtIn="false" value="m"/>
</option>
<inputType id="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.exeLinker.inputType__CMD_SRCS.1683061491" name="Linker Command Files" superClass="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.exeLinker.inputType__CMD_SRCS"/>
<inputType id="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.exeLinker.inputType__CMD2_SRCS.486498849" name="Linker Command Files" superClass="com.ti.ccstudio.buildDefinitions.MSP432_GNU_7.0.exeLinker.inputType__CMD2_SRCS"/>
32 changes: 11 additions & 21 deletions transform.py
@@ -345,8 +345,8 @@ def infer_auto_pad(node):
raise NotImplementedError

for idx, n in enumerate(nodes):
if n.op_type == 'Dropout':
output = n.output[:1] # we don't care the second output `mask`
if n.op_type in ('Dropout', 'BatchNormalization'):
output = n.output[:1] # we don't care about the extra outputs used only for training
else:
output = n.output
if n.op_type == 'Conv':
@@ -452,6 +452,7 @@ def get_memory_usage(output_tile_c, filter_len):

if not input_tile_too_large:
params_len = math.ceil(CHANNEL / node_flags.input_tile_c) * OUTPUT_CHANNEL * OUTPUT_H * OUTPUT_W * 2
logger.debug('Candidate params_len %d', params_len)
if params_len < config['intermediate_values_size']:
break
logger.debug(f'params_len={params_len}, too high!')
@@ -615,13 +616,6 @@ def write_str(buffer: io.BytesIO, data: str):

parameter_info_idx = 0

def decode_raw_data(params):
format_char = {
onnx.TensorProto.FLOAT: 'f',
onnx.TensorProto.INT64: 'q',
}[params.data_type]
return list(map(lambda t: t[0], struct.iter_unpack(format_char, params.raw_data)))

model_parameters_info = outputs['model_parameters_info']
for params in parameters:
if params is None: # input
@@ -643,10 +637,7 @@ def decode_raw_data(params):
param_scale = 0
assert len(params.dims) <= 4
if params.data_type == onnx.TensorProto.FLOAT:
if params.float_data:
float_data = params.float_data
else:
float_data = decode_raw_data(params)
float_data = extract_data(params).flatten()
data_len = len(float_data)
assert data_len > 0
slot = parameters_slot
@@ -658,20 +649,19 @@ def decode_raw_data(params):
param_scale = config['scale']
slot.target.write(to_bytes(_Q15(np.array(float_data) / param_scale, 'Parameter')))
slot.offset += 2 * len(float_data)
if slot.offset % 4:
slot.offset += 2
slot.target.write(to_bytes(0))
model_parameters_info.write(to_bytes(16, size=8)) # bitwidth
elif params.data_type == onnx.TensorProto.INT64:
if params.int64_data:
int64_data = params.int64_data
else:
int64_data = decode_raw_data(params)
data_len = len(int64_data)
int64_data = extract_data(params)
data_len = int(np.prod(np.shape(int64_data)))
assert data_len > 0
slot = parameters_slot
model_parameters_info.write(to_bytes(slot.offset, size=32)) # params_offset
model_parameters_info.write(to_bytes(data_len * 8, size=32))
for param in int64_data:
slot.target.write(to_bytes(param, size=64))
slot.offset += 8
slot.target.write(to_bytes(int64_data, size=64))
slot.offset += 8 * data_len
model_parameters_info.write(to_bytes(64, size=8)) # bitwidth
else:
assert False
15 changes: 15 additions & 0 deletions utils.py
@@ -19,6 +19,7 @@
import onnxoptimizer
import onnxruntime
import onnxruntime.backend as backend
import onnx.helper

logger = logging.getLogger(__name__)

@@ -150,6 +151,13 @@ def load_data_google_speech(start: int, limit: int) -> ModelData:

return ModelData(labels=labels, images=np.array(mfccs, dtype=np.float32), data_layout=DataLayout.NEUTRAL)

def load_data_omniglot(start: int, limit: int) -> ModelData:
    n_way = 5
    data = np.load(os.path.expanduser('~/.cache/omniglot/omniglot.npy'))[:n_way, :, :, :, :]
    labels = list(itertools.chain.from_iterable([[idx]*20 for idx in range(n_way)]))
    images = np.concatenate(data)
    return ModelData(labels=labels, images=images, data_layout=DataLayout.NEUTRAL)

def kws_dnn_model():
return download_file('https://github.com/ARM-software/ML-KWS-for-MCU/raw/master/Pretrained_models/DNN/DNN_S.pb', 'KWS-DNN_S.pb')

@@ -326,8 +334,15 @@ def change_batch_size(onnx_model: onnx.ModelProto):
    for value_info in itertools.chain(g.value_info, g.input, g.output):
        if value_info.name in initializer_names or value_info.name in constant_names:
            continue
        n = find_node_by_output(onnx_model.graph.node, value_info.name)
        if not n and value_info.name not in [inp.name for inp in g.input]:
            continue
        if n and n.op_type == 'Shape':
            continue
        shape = value_info.type.tensor_type.shape
        if shape.dim and shape.dim[0].dim_param:
            if n and n.op_type == 'Concat' and len(shape.dim) == 1:
                continue
            shape.dim[0].dim_value = 1

# make sure above steps did not break the model
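
The `find_node_by_output` helper used in the hunk above lives elsewhere in utils.py and is not shown here; conceptually it behaves roughly like the following sketch (an assumption about its behavior, not the actual definition):

    def find_node_by_output(nodes, output_name):
        # Return the node that produces `output_name`, or None (e.g., for graph inputs).
        for node in nodes:
            if output_name in node.output:
                return node
        return None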
