最近项目遇到了,某些模型架构不一样,需要转换的需求。我分配到进行模型转换(别人负责提到)
ONNX 介绍
什么是 ONNX?
ONNX(Open Neural Network Exchange) 是一个开放的机器学习模型格式标准,用于在不同框架和工具之间交换模型。
核心特点
- 跨框架互操作性
- 支持 PyTorch、TensorFlow、Keras、scikit-learn 等框架
- 模型可在不同框架间转换和运行
- 跨平台部署
- 支持 CPU、GPU、移动设备、边缘设备
- 可在 Windows、Linux、macOS、iOS、Android 等平台运行
- 优化推理
- 提供 ONNX Runtime 等运行时
- 支持多种硬件加速(CUDA、TensorRT、OpenVINO 等)
主要用途
- 模型转换:将训练好的模型转换为 ONNX 格式
- 模型部署:在生产环境中使用 ONNX Runtime 进行推理
- 模型优化:通过 ONNX 工具链进行模型优化和量化
- 跨平台部署:一次训练,多平台部署
工作流程
训练模型 (PyTorch/TensorFlow)
↓
转换为 ONNX 格式
↓
使用 ONNX Runtime 进行推理
常用工具
- ONNX Runtime:高性能推理引擎
- onnx-tf:ONNX 与 TensorFlow 互转
- torch.onnx:PyTorch 导出 ONNX
- onnx-simplifier:简化 ONNX 模型
优势
- ✅ 标准化格式,减少厂商锁定
- ✅ 高性能推理
- ✅ 广泛的硬件支持
- ✅ 活跃的社区和工具生态
介绍 ONNX 模型的主流使用方法,涵盖常见场景和最佳实践。
[2 tools called]
ONNX 模型的主流使用方法详解
一、ONNX 模型的主要使用场景
1. 模型转换与导出
PyTorch → ONNX
import torch
import torch.onnx
# 加载训练好的模型
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()
# 创建示例输入
dummy_input = torch.randn(1, 3, 224, 224)
# 导出为 ONNX
torch.onnx.export(
model, # 模型
dummy_input, # 示例输入
"model.onnx", # 输出路径
export_params=True, # 保存训练参数
opset_version=11, # ONNX 操作集版本
do_constant_folding=True, # 常量折叠优化
input_names=['input'], # 输入名称
output_names=['output'], # 输出名称
dynamic_axes={
'input': {0: 'batch_size'},
'output': {0: 'batch_size'}
} # 支持动态batch size
)
TensorFlow → ONNX
import tf2onnx
import onnx
# 使用 tf2onnx 转换
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
output_path = "model.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)
2. 模型推理(Python)
使用 ONNX Runtime(推荐)
import onnxruntime as ort
import numpy as np
# 创建推理会话
session = ort.InferenceSession(
"model.onnx",
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] # GPU优先,CPU备用
)
# 获取输入输出信息
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 准备输入数据
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 执行推理
outputs = session.run([output_name], {input_name: input_data})
result = outputs[0]
性能优化配置
# 使用优化选项
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
session_options.enable_mem_pattern = True
session_options.enable_cpu_mem_arena = True
session = ort.InferenceSession("model.onnx", session_options, providers=['CPUExecutionProvider'])
3. Node.js 环境使用(适合你的项目)
安装依赖
npm install onnxruntime-node
基本使用
const ort = require('onnxruntime-node');
const fs = require('fs');
async function runInference() {
// 创建推理会话
const session = await ort.InferenceSession.create('./model.onnx', {
executionProviders: ['cpu'], // 或 ['cuda', 'cpu'] 如果有GPU
graphOptimizationLevel: 'all'
});
// 准备输入数据
const inputTensor = new ort.Tensor('float32',
new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9]),
[1, 3, 3] // shape
);
// 执行推理
const feeds = { [session.inputNames[0]]: inputTensor };
const results = await session.run(feeds);
console.log(results[session.outputNames[0]].data);
}
在 Express 服务器中集成
// 在你的 server.js 中添加
const ort = require('onnxruntime-node');
let modelSession = null;
// 初始化模型(启动时加载一次)
async function initModel() {
try {
modelSession = await ort.InferenceSession.create('./models/your_model.onnx', {
executionProviders: ['cpu'],
graphOptimizationLevel: 'all'
});
console.log('✅ ONNX 模型加载成功');
} catch (error) {
console.error('❌ 模型加载失败:', error);
}
}
// API 端点示例
app.post('/api/predict', async (req, res) => {
try {
const { inputData } = req.body;
// 转换为 Float32Array
const inputTensor = new ort.Tensor(
'float32',
new Float32Array(inputData),
[1, inputData.length]
);
// 推理
const feeds = { [modelSession.inputNames[0]]: inputTensor };
const results = await modelSession.run(feeds);
res.json({
success: true,
prediction: Array.from(results[modelSession.outputNames[0]].data)
});
} catch (error) {
res.status(500).json({ success: false, error: error.message });
}
});
// 启动时初始化
initModel();
4. 模型优化
模型量化(减小体积,加速推理)
from onnxruntime.quantization import quantize_dynamic, QuantType
# 动态量化(推荐,简单快速)
quantize_dynamic(
model_input="model.onnx",
model_output="model_quantized.onnx",
weight_type=QuantType.QUInt8 # 或 QuantType.QInt8
)
# 量化后模型通常减小 75%,推理速度提升 2-4 倍
模型简化
import onnx
from onnxsim import simplify
# 加载模型
model = onnx.load("model.onnx")
# 简化模型
simplified_model, check = simplify(model)
# 保存简化后的模型
onnx.save(simplified_model, "model_simplified.onnx")
5. Web 浏览器中使用
使用 ONNX.js
<script src="https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js"></script>
<script>
async function runInference() {
const session = new onnx.InferenceSession();
await session.loadModel("./model.onnx");
const inputTensor = new onnx.Tensor(
new Float32Array([1, 2, 3, 4]),
[1, 4]
);
const outputMap = await session.run([inputTensor]);
const output = outputMap.values().next().value;
console.log(output.data);
}
</script>
6. 移动端部署
iOS (Core ML)
# 使用 coremltools 转换
python -m coremltools.converters.onnx.convert model.onnx --output model.mlmodel
Android (TensorFlow Lite)
# ONNX → TensorFlow → TFLite
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
onnx_model = onnx.load("model.onnx")
tf_rep = prepare(onnx_model)
tf_rep.export_graph("model_tf")
# 转换为 TFLite
converter = tf.lite.TFLiteConverter.from_saved_model("model_tf")
tflite_model = converter.convert()
二、主流使用流程
完整工作流
1. 训练模型 (PyTorch/TensorFlow)
↓
2. 导出为 ONNX 格式
↓
3. 模型优化(量化、简化)
↓
4. 部署到目标平台
- Python 服务:ONNX Runtime
- Node.js 服务:onnxruntime-node
- Web 前端:ONNX.js
- 移动端:Core ML / TFLite
- 边缘设备:ONNX Runtime + 硬件加速
三、性能优化最佳实践
- 选择合适的执行提供者
- GPU:
CUDAExecutionProvider - CPU:
CPUExecutionProvider - 移动端:
CoreMLExecutionProvider(iOS)
- 批量处理
# 批量推理比单次推理效率高
batch_input = np.random.randn(32, 3, 224, 224).astype(np.float32)
outputs = session.run([output_name], {input_name: batch_input})
- 使用 IO Binding(减少数据拷贝)
io_binding = session.io_binding()
io_binding.bind_input('input', 'cuda', 0, np.float32, [1, 3, 224, 224], input_data)
io_binding.bind_output('output', 'cuda')
session.run_with_iobinding(io_binding)
四、常见应用场景
- 图像分类/目标检测:YOLO、ResNet、MobileNet
- 自然语言处理:BERT、GPT 等 Transformer
- 语音识别:Whisper、Wav2Vec
- 推荐系统:深度推荐模型
- 时间序列预测:LSTM、GRU