如何在TensorFlow Lite Micro中添加自定义操作符(1)
相信大家在部署嵌入式AI应用时一定使用过TensorFlow Lite Micro,以下简称TFLm。 TFLm 是一款专为微控制器和嵌入式设备设计的轻量级机器学习推理框架。它通过模块化算子系统支持各种神经网络层的计算。也就是说,我们不仅可以使用内置的算子运算,还可以自己注册一个新的算子,更加灵活。本期将用两篇文章,以`reshape.cpp` 为例,详细讲解如何在TensorFlow Lite Micro 中添加新算子。
算子注册不仅是模型推理的基础,也是优化性能、减少内存占用的关键环节。掌握了这种机制,开发人员可以更灵活地定制算子,以满足特定的硬件和应用需求。
在TFLite Micro中,每个操作者都需要经历以下关键步骤:
1、内核实现:定义算子的具体计算逻辑
2.参数解析:从FlatBuffer格式解析算子参数
3.算子注册:将算子注册到解析器中,以便模型可以调用
4.内存管理:处理张量的内存分配和释放
算子实现的核心组件
1. 文件结构说明
添加新的算子需要修改以下关键文件,每个文件都有其特定的作用:
微/内核/reshape.cpp #
算子的核心计算逻辑实现
微/micro_mutable_op_resolver.h#
用于动态注册运算符的可变运算符解析器
核心/api/flatbuffer_conversions.h #
FlatBuffer参数解析函数声明
核心/api/flatbuffer_conversions.cpp #
FlatBuffer参数解析函数的具体实现
微/all_ops_resolver.cpp #
全局运算符解析器,包含所有支持的运算符
文件功能详细说明:
`micro/kernels/` 目录:
存储所有算子的具体实现,每个算子一个文件
`micro_mutable_op_resolver.h`:
提供灵活的操作员注册接口,允许用户选择性添加操作员
`flatbuffer_conversions.*`:
处理模型文件中的参数解析并将FlatBuffer格式转换为C++结构
`all_ops_resolver.cpp`:
预定义注册所有标准算子,适合需要完整算子支持的场景
2.核心实现文件分析
2.1 头文件介绍
文件位置:`micro/kernels/reshape.cpp`
#include #include 'tensorflow/lite/c/builtin_op_data.h' #include 'tensorflow/lite/c/common.h' #include 'tensorflow/lite/kernels/internal/tensor_ctypes.h' #include 'tensorflow/lite/kernels/kernel_util.h' h'#include'tensorflow/lite/kernels/op_macros.h'#include'tensorflow/lite/micro/kernels/kernel_util.h'#include'tensorflow/lite/micro/memory_helpers.h'#include'tensorflow/lite/micro/micro_utils.h'
头文件说明:
`builtin_op_data.h`:包含所有内置运算符的参数结构定义
`common.h`:TFLite的基本数据类型和状态码定义
`tensor_ctypes.h`:与张量数据类型相关的实用函数
`kernel_util.h`:运算符实现的常用实用函数
`op_macros.h`:运算符实现中常用的宏定义
`micro/kernels/kernel_util.h`:Micro 版本特定的内核实用程序函数
`memory_helpers.h`:内存管理相关的辅助函数
`micro_utils.h`:常用实用函数的微型版本
2.2 命名空间和常量定义
命名空间tflite {namespaceops {namespacemicro {namespacereshape {constexprintkInputTensor=0;constexprintkOutputTensor=0;
命名空间说明:
`tflite:reshape`:四层命名空间确保代码组织并避免命名冲突
常量定义:`kInputTensor`和`kOutputTensor`定义输入和输出张量的索引。 Reshape 操作只有一个输入和一个输出。
2.3 核心功能实现
ReshapeOutput函数-形状计算逻辑
TfLiteStatus ReshapeOutput(TfLiteContext* 上下文,TfLiteNode* 节点) { MicroContext* micro_context=GetMicroContext(context); //获取输入和输出张量- 使用临时分配以避免持久内存使用TfLiteTensor* input=micro_context-AllocateTempInputTensor(node, kInputTensor); TfLiteTensor* 输出=micro_context-AllocateTempOutputTensor(node, kOutputTensor); //计算输入元素总数——用于验证重塑操作的合法性intnum_input_elements=NumElements(input); TfLiteIntArray* 输出形状=输出尺寸; //处理特殊情况:自动计算-1维度//TensorFlow允许将某一维度设置为-1,表示根据其他维度自动推理intnum_output_elements=1;intstretch_dim=-1;for(inti=0; i output_shape-size; ++i) { intvalue=output_shape-data[i]; }如果(值==-1){ TF_LITE_ENSURE_EQ(上下文,stretch_dim,-1); //确保只有一个-1 维stretch_dim=i; }else{ num_output_elements *=值; } }//如果有-1维度,自动计算其大小if(stretch_dim !=-1) { TfLiteEvalTensor* output_eval=tflite:micro:GetEvalOutput(context, node, kOutputTensor); TF_LITE_ENSURE_STATUS(tflite:micro:CreateWritableTensorDimsWithCopy(上下文、输出、output_eval));输出形状=输出尺寸; //更新形状指针output_shape-data[stretch_dim]=num_input_elements/num_output_elements; num_output_elements *=output_shape-data[stretch_dim]; } //确保输入和输出元素的数量一致- Reshape 不会改变元素总数TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); TF_LITE_ENSURE_TYPES_EQ(上下文,输入类型,输出类型); //确保数据类型一致//释放临时张量- 避免内存泄漏micro_context-DeallocateTempTfLiteTensor(input); micro_context-DeallocateTempTfLiteTensor(输出);returnkTfLiteOk;}
功能详细解释:
临时张量分配:使用`AllocateTempInputTensor`和`AllocateTempOutputTensor`来获取张量信息。这些是临时分配,不占用持久内存。
形状验证:保证reshape操作的合法性,输入和输出元素总数必须相等
自动维度推断:处理-1维度的特殊情况,这是TensorFlow的标准功能
内存管理:及时释放临时张量,这在内存受限的微控制器环境中非常重要
准备函数-运算符准备阶段:
TfLiteStatusPrepare(TfLiteContext* context, TfLiteNode* node){//验证输入和输出的数量- Reshape 可以有1 或2 个输入(第二个输入是可选的形状参数) TF_LITE_ENSURE(context, NumInputs(node)==1|| NumInputs(node)==2); TF_LITE_ENSURE_EQ(上下文, NumOutputs(节点),1); //只有一个输出//执行输出重塑——在准备阶段确定最终输出形状TF_LITE_ENSURE_EQ(context, ReshapeOutput(context, node), kTfLiteOk);returnkTfLiteOk;}
准备功能说明:
输入验证:Reshape操作支持1-2个输入,第二个输入是可选的形状张量
形状计算:在准备阶段确定输出形状,避免执行阶段重复计算
错误检查:使用`TF_LITE_ENSURE`宏进行参数验证,失败则返回错误状态。
Eval 函数运算符执行阶段:
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {//获取输入输出张量——使用EvalTensor进行实际计算constTfLiteEvalTensor* input=tflite:GetEvalInput(context, node, kInputTensor); TfLiteEvalTensor* output=tflite:GetEvalOutput(context, node, kOutputTensor);//计算输入数据大小-要复制的字节数size_t input_bytes; TF_LITE_ENSURE_STATUS(TfLiteTypeSizeOf(输入类型,input_bytes)); input_bytes *=ElementCount(*input-dims);//执行数据复制(如果不是就地操作) //就地操作:输入和输出使用相同的内存,不需要复制if(input-data.raw !=output-data.raw) { memcpy(output-data.raw, input-data.raw, input_bytes); } }returnkTfLiteOk;}
评估函数说明:
EvalTensor用法:`TfLiteEvalTensor`用于执行阶段,其中包含实际的数据指针
原地操作优化:检查输入和输出是否共享内存,以避免不必要的数据复制
内存复制:Reshape操作本质上只是改变数据解释的方式,但不改变数据内容。
2.4 算子注册功能
TfLiteRegistration_V1 Register_RESHAPE() {returntflite:micro:RegisterOp(nullptr, reshape:Prepare, reshape:Eval);}} //命名空间重塑} //命名空间micro} //命名空间操作} //命名空间tflite
注册功能说明:
RegisterOp函数:创建操作符注册结构,包括初始化、准备和执行函数指针
nullptr参数:第一个参数是初始化函数。 Reshape不需要特殊的初始化,所以传入nullptr。
函数指针:传入Prepare和Eval函数指针,框架会在适当的时候调用这些函数。
3. 参数分析的实现
3.1 解析函数声明
文件位置:`core/api/flatbuffer_conversions.h`
TfLiteStatusParseReshape(constOperator* op,ErrorReporter* error_reporter,BuiltinDataAllocator* 分配器,voidbuiltin_data);
声明说明:
Operator* op:来自FlatBuffer的运算符定义,包括所有参数信息
ErrorReporter:用于报告解析过程中的错误
builtinDataAllocator:用于分配参数结构的专用内存分配器
builtin_data:输出参数,指向解析后的参数结构体
3.2 分析函数实现
文件位置:`core/api/flatbuffer_conversions.cpp`
TfLiteStatusParseReshape(constOperator * op,ErrorReporter * error_reporter,BuiltinDataAllocator *分配器,voidbuiltin_data);
解析函数详细解释:
参数验证:`CheckParsePointerParams` 确保所有指针参数都有效
安全分配器:`SafeBuiltinDataAllocator`提供异常安全的内存分配
FlatBuffer解析:从序列化模型文件中提取reshape参数
格式转换:将FlatBuffer格式转换为TFLite内部使用的C结构格式
所有权转移:使用release()将参数结构的所有权转移给框架
3.3 在解析开关中添加对应的case
文件位置:`core/api/flatbuffer_conversions.cpp`
在`ParseOpData`函数的switch语句中添加:
caseBuiltinOperator_RESHAPE: {returnParseReshape(op, error_reporter, 分配器,builtin_data);}
switch语句说明:
这个switch语句是TFLite参数解析的核心调度机制
根据运算符类型调用对应的解析函数
`BuiltinOperator_RESHAPE` 是FlatBuffer 模式中定义的枚举值
通过本指南,我们深入了解了TensorFlow Lite Micro的算子注册机制,包括其设计理念、实现方式以及在嵌入式场景中的重要性。
未来,随着边缘计算和微控制器AI的快速发展,理解和应用这些底层机制将成为构建高效、可扩展的AI系统的核心能力。建议读者在实践中尝试自定义算子注册,并结合实际项目进行优化,才能真正释放TensorFlow Lite Micro的潜力。