面试 · 2025年11月27日 0

ONNX 介绍与使用

最近项目遇到了,某些模型架构不一样,需要转换的需求。我分配到进行模型转换(别人负责提到)

ONNX 介绍

什么是 ONNX?

ONNX(Open Neural Network Exchange) 是一个开放的机器学习模型格式标准,用于在不同框架和工具之间交换模型。

核心特点

  1. 跨框架互操作性
  • 支持 PyTorch、TensorFlow、Keras、scikit-learn 等框架
  • 模型可在不同框架间转换和运行
  1. 跨平台部署
  • 支持 CPU、GPU、移动设备、边缘设备
  • 可在 Windows、Linux、macOS、iOS、Android 等平台运行
  1. 优化推理
  • 提供 ONNX Runtime 等运行时
  • 支持多种硬件加速(CUDA、TensorRT、OpenVINO 等)

主要用途

  • 模型转换:将训练好的模型转换为 ONNX 格式
  • 模型部署:在生产环境中使用 ONNX Runtime 进行推理
  • 模型优化:通过 ONNX 工具链进行模型优化和量化
  • 跨平台部署:一次训练,多平台部署

工作流程

训练模型 (PyTorch/TensorFlow) 
    ↓
转换为 ONNX 格式
    ↓
使用 ONNX Runtime 进行推理

常用工具

  • ONNX Runtime:高性能推理引擎
  • onnx-tf:ONNX 与 TensorFlow 互转
  • torch.onnx:PyTorch 导出 ONNX
  • onnx-simplifier:简化 ONNX 模型

优势

  • ✅ 标准化格式,减少厂商锁定
  • ✅ 高性能推理
  • ✅ 广泛的硬件支持
  • ✅ 活跃的社区和工具生态

介绍 ONNX 模型的主流使用方法,涵盖常见场景和最佳实践。

[2 tools called]

ONNX 模型的主流使用方法详解

一、ONNX 模型的主要使用场景

1. 模型转换与导出

PyTorch → ONNX

import torch
import torch.onnx

# 加载训练好的模型
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()

# 创建示例输入
dummy_input = torch.randn(1, 3, 224, 224)

# 导出为 ONNX
torch.onnx.export(
    model,                          # 模型
    dummy_input,                    # 示例输入
    "model.onnx",                  # 输出路径
    export_params=True,            # 保存训练参数
    opset_version=11,              # ONNX 操作集版本
    do_constant_folding=True,      # 常量折叠优化
    input_names=['input'],         # 输入名称
    output_names=['output'],       # 输出名称
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }  # 支持动态batch size
)

TensorFlow → ONNX

import tf2onnx
import onnx

# 使用 tf2onnx 转换
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
output_path = "model.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model, input_signature=spec, opset=13, output_path=output_path)

2. 模型推理(Python)

使用 ONNX Runtime(推荐)

import onnxruntime as ort
import numpy as np

# 创建推理会话
session = ort.InferenceSession(
    "model.onnx",
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider']  # GPU优先,CPU备用
)

# 获取输入输出信息
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# 准备输入数据
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 执行推理
outputs = session.run([output_name], {input_name: input_data})
result = outputs[0]

性能优化配置

# 使用优化选项
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
session_options.enable_mem_pattern = True
session_options.enable_cpu_mem_arena = True

session = ort.InferenceSession("model.onnx", session_options, providers=['CPUExecutionProvider'])

3. Node.js 环境使用(适合你的项目)

安装依赖

npm install onnxruntime-node

基本使用

const ort = require('onnxruntime-node');
const fs = require('fs');

async function runInference() {
    // 创建推理会话
    const session = await ort.InferenceSession.create('./model.onnx', {
        executionProviders: ['cpu'],  // 或 ['cuda', 'cpu'] 如果有GPU
        graphOptimizationLevel: 'all'
    });

    // 准备输入数据
    const inputTensor = new ort.Tensor('float32', 
        new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9]), 
        [1, 3, 3]  // shape
    );

    // 执行推理
    const feeds = { [session.inputNames[0]]: inputTensor };
    const results = await session.run(feeds);

    console.log(results[session.outputNames[0]].data);
}

在 Express 服务器中集成

// 在你的 server.js 中添加
const ort = require('onnxruntime-node');
let modelSession = null;

// 初始化模型(启动时加载一次)
async function initModel() {
    try {
        modelSession = await ort.InferenceSession.create('./models/your_model.onnx', {
            executionProviders: ['cpu'],
            graphOptimizationLevel: 'all'
        });
        console.log('✅ ONNX 模型加载成功');
    } catch (error) {
        console.error('❌ 模型加载失败:', error);
    }
}

// API 端点示例
app.post('/api/predict', async (req, res) => {
    try {
        const { inputData } = req.body;

        // 转换为 Float32Array
        const inputTensor = new ort.Tensor(
            'float32',
            new Float32Array(inputData),
            [1, inputData.length]
        );

        // 推理
        const feeds = { [modelSession.inputNames[0]]: inputTensor };
        const results = await modelSession.run(feeds);

        res.json({
            success: true,
            prediction: Array.from(results[modelSession.outputNames[0]].data)
        });
    } catch (error) {
        res.status(500).json({ success: false, error: error.message });
    }
});

// 启动时初始化
initModel();

4. 模型优化

模型量化(减小体积,加速推理)

from onnxruntime.quantization import quantize_dynamic, QuantType

# 动态量化(推荐,简单快速)
quantize_dynamic(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    weight_type=QuantType.QUInt8  # 或 QuantType.QInt8
)

# 量化后模型通常减小 75%,推理速度提升 2-4 倍

模型简化

import onnx
from onnxsim import simplify

# 加载模型
model = onnx.load("model.onnx")

# 简化模型
simplified_model, check = simplify(model)

# 保存简化后的模型
onnx.save(simplified_model, "model_simplified.onnx")

5. Web 浏览器中使用

使用 ONNX.js

<script src="https://cdn.jsdelivr.net/npm/onnxjs/dist/onnx.min.js"></script>
<script>
async function runInference() {
    const session = new onnx.InferenceSession();
    await session.loadModel("./model.onnx");

    const inputTensor = new onnx.Tensor(
        new Float32Array([1, 2, 3, 4]),
        [1, 4]
    );

    const outputMap = await session.run([inputTensor]);
    const output = outputMap.values().next().value;
    console.log(output.data);
}
</script>

6. 移动端部署

iOS (Core ML)

# 使用 coremltools 转换
python -m coremltools.converters.onnx.convert model.onnx --output model.mlmodel

Android (TensorFlow Lite)

# ONNX → TensorFlow → TFLite
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf

onnx_model = onnx.load("model.onnx")
tf_rep = prepare(onnx_model)
tf_rep.export_graph("model_tf")

# 转换为 TFLite
converter = tf.lite.TFLiteConverter.from_saved_model("model_tf")
tflite_model = converter.convert()

二、主流使用流程

完整工作流

1. 训练模型 (PyTorch/TensorFlow)
   ↓
2. 导出为 ONNX 格式
   ↓
3. 模型优化(量化、简化)
   ↓
4. 部署到目标平台
   - Python 服务:ONNX Runtime
   - Node.js 服务:onnxruntime-node
   - Web 前端:ONNX.js
   - 移动端:Core ML / TFLite
   - 边缘设备:ONNX Runtime + 硬件加速

三、性能优化最佳实践

  1. 选择合适的执行提供者
  • GPU:CUDAExecutionProvider
  • CPU:CPUExecutionProvider
  • 移动端:CoreMLExecutionProvider (iOS)
  1. 批量处理
   # 批量推理比单次推理效率高
   batch_input = np.random.randn(32, 3, 224, 224).astype(np.float32)
   outputs = session.run([output_name], {input_name: batch_input})
  1. 使用 IO Binding(减少数据拷贝)
   io_binding = session.io_binding()
   io_binding.bind_input('input', 'cuda', 0, np.float32, [1, 3, 224, 224], input_data)
   io_binding.bind_output('output', 'cuda')
   session.run_with_iobinding(io_binding)

四、常见应用场景

  • 图像分类/目标检测:YOLO、ResNet、MobileNet
  • 自然语言处理:BERT、GPT 等 Transformer
  • 语音识别:Whisper、Wav2Vec
  • 推荐系统:深度推荐模型
  • 时间序列预测:LSTM、GRU