Files
system-design/software-copyright/05-writech-edge-box/inference/inference_engine.cpp
T
2026-03-22 15:24:40 +08:00

500 lines
15 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 推理引擎模块 - ONNX Runtime / TensorRT 推理执行引擎
*
* 负责加载AI模型并执行推理任务
* 支持多种推理后端:ONNX Runtime、TensorRT、PaddleLite
* 支持NPU/GPU硬件加速调度
*/
#ifndef INFERENCE_ENGINE_H
#define INFERENCE_ENGINE_H
#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
// ==================== 数据结构定义 ====================
/**
 * Hardware device that executes inference.
 * The compute box can dispatch work to several kinds of accelerator.
 */
enum class DeviceType : int {
    CPU         = 0,  // CPU inference (fallback path)
    GPU_CUDA    = 1,  // NVIDIA GPU via CUDA
    GPU_OPENCL  = 2,  // generic GPU via OpenCL
    NPU_RKNN    = 3,  // Rockchip NPU (RKNN)
    NPU_AMLOGIC = 4   // Amlogic NPU
};
/**
 * On-disk model formats.
 */
enum class ModelFormat : int {
    ONNX        = 0,  // ONNX (portable, framework-neutral)
    TENSORRT    = 1,  // TensorRT engine (NVIDIA-optimized)
    PADDLE_LITE = 2,  // PaddleLite (ARM-optimized)
    RKNN        = 3   // RKNN (Rockchip NPU specific)
};
/**
 * Kind of inference work carried by a request.
 */
enum class TaskType : int {
    OCR              = 0,  // text OCR
    MATH_RECOGNITION = 1,  // math expression recognition
    STROKE_ORDER     = 2,  // stroke-order analysis
    WRITING_QUALITY  = 3   // handwriting quality assessment
};
/**
 * Tensor used for inference inputs and outputs.
 * Holds a flat float buffer together with its shape and name.
 */
struct Tensor {
    std::vector<float> data;      // flat float data
    std::vector<int64_t> shape;   // dimensions, e.g. [1, 3, 64, 64]
    std::string name;             // tensor name

    /**
     * Number of elements implied by `shape` (product of all dimensions).
     *
     * An empty shape denotes a scalar and yields 1. A negative dimension
     * (such as a -1 "dynamic" placeholder) has no concrete element count,
     * so 0 is returned instead of letting the signed value wrap around to
     * an enormous size_t.
     */
    size_t size() const {
        size_t count = 1;
        for (int64_t dim : shape) {
            if (dim < 0) {
                return 0;
            }
            count *= static_cast<size_t>(dim);
        }
        return count;
    }
};
/**
 * A single inference request.
 */
struct InferenceRequest {
    std::string request_id;               // unique request id
    // Selects the backend that serves the request. Defaulted so that a
    // default-constructed request (e.g. the one filled in by the engine's
    // worker loop) never holds an indeterminate value.
    TaskType task_type = TaskType::OCR;
    std::vector<Tensor> inputs;           // input tensors
    int priority = 2;                     // 0 = most urgent; smaller values are served first
    int timeout_ms = 500;                 // timeout in ms (NOTE(review): not enforced by the visible code)
    std::string pen_id;                   // id of the originating pen device
    std::string student_id;               // student id
    std::chrono::steady_clock::time_point submit_time;  // stamped by InferenceTaskQueue::enqueue
};
/**
 * Outcome of one inference request.
 */
struct InferenceResult {
    std::string request_id;           // id echoed back from the originating request
    bool success{false};              // true once inference completed
    std::string error_message;        // failure reason, set when success is false
    std::vector<Tensor> outputs;      // output tensors
    float inference_time_ms{0.0f};    // measured inference latency (ms)
    std::string model_version;        // model version used (not populated by the visible backends)
};
// ==================== 推理后端抽象 ====================
/**
 * Abstract inference backend.
 *
 * Common interface implemented by every concrete engine (ONNX Runtime,
 * TensorRT, ...). Backends are owned through std::unique_ptr<InferenceBackend>,
 * which is why the destructor is virtual.
 */
class InferenceBackend {
public:
    virtual ~InferenceBackend() = default;

    /// Load the model stored at `model_path`; returns whether loading succeeded.
    virtual bool load_model(const std::string& model_path) = 0;

    /// Execute `request` and return its result.
    virtual InferenceResult infer(const InferenceRequest& request) = 0;

    /// Unload the model and release the resources it holds.
    virtual void unload() = 0;

    /// Human-readable backend name.
    virtual std::string name() const = 0;
};
/**
 * ONNX Runtime inference backend.
 *
 * Intended to run models through ONNX Runtime execution providers
 * (CPU/GPU/NPU). The real ONNX Runtime calls are still commented out, so
 * infer() currently returns a fixed placeholder tensor.
 */
class OnnxRuntimeBackend : public InferenceBackend {
public:
    /// @param device target device; only consulted by the (commented-out) real load path.
    explicit OnnxRuntimeBackend(DeviceType device) : device_(device) {}

    /**
     * Remember the model path and mark the backend as loaded.
     * @param model_path path to the .onnx model file
     * @return false (state left unchanged) if model_path is empty, true otherwise
     */
    bool load_model(const std::string& model_path) override {
        // An empty path cannot name a model; report failure rather than a
        // bogus successful load.
        if (model_path.empty()) {
            return false;
        }
        model_path_ = model_path;
        // Real implementation:
        // Ort::SessionOptions options;
        // if (device_ == DeviceType::GPU_CUDA) {
        //     OrtCUDAProviderOptions cuda_opts;
        //     cuda_opts.device_id = 0;
        //     options.AppendExecutionProvider_CUDA(cuda_opts);
        // }
        // session_ = std::make_unique<Ort::Session>(env, model_path.c_str(), options);
        loaded_ = true;
        return true;
    }

    /**
     * Run inference for `request`.
     * @return a failed result if no model is loaded; otherwise a successful
     *         result carrying a placeholder 1x10 "output" tensor filled with 0.1.
     */
    InferenceResult infer(const InferenceRequest& request) override {
        InferenceResult result;
        result.request_id = request.request_id;
        if (!loaded_) {
            result.success = false;
            result.error_message = "模型未加载";
            return result;
        }
        const auto start = std::chrono::steady_clock::now();
        // Real ONNX Runtime execution:
        // std::vector<Ort::Value> input_tensors;
        // for (const auto& input : request.inputs) {
        //     auto tensor = Ort::Value::CreateTensor<float>(
        //         memory_info, input.data.data(), input.size(),
        //         input.shape.data(), input.shape.size());
        //     input_tensors.push_back(std::move(tensor));
        // }
        // auto output_tensors = session_->Run(run_options, input_names, input_tensors, output_names);

        // Placeholder output until a real session is wired in.
        Tensor output;
        output.name = "output";
        output.shape = {1, 10};
        output.data.resize(10, 0.1f);
        result.outputs.push_back(std::move(output));
        result.success = true;
        const auto end = std::chrono::steady_clock::now();
        result.inference_time_ms = std::chrono::duration<float, std::milli>(end - start).count();
        return result;
    }

    /// Mark the model as unloaded; subsequent infer() calls fail.
    void unload() override {
        loaded_ = false;
    }

    std::string name() const override { return "ONNXRuntime"; }

private:
    DeviceType device_;
    std::string model_path_;
    bool loaded_ = false;
};
/**
 * TensorRT inference backend (NVIDIA GPUs).
 *
 * Intended to run serialized TensorRT engines, e.g. with FP16/INT8
 * precision. The real TensorRT/CUDA calls are still commented out, so
 * infer() currently returns a fixed placeholder tensor.
 */
class TensorRTBackend : public InferenceBackend {
public:
    TensorRTBackend() = default;

    /**
     * Remember the engine path and mark the backend as loaded.
     * NOTE(review): the real implementation deserializes a TensorRT engine,
     * so the path is expected to point at a serialized engine file.
     * @param engine_path path to the engine file
     * @return false (state left unchanged) if engine_path is empty, true otherwise
     */
    bool load_model(const std::string& engine_path) override {
        // An empty path cannot name an engine; report failure rather than a
        // bogus successful load.
        if (engine_path.empty()) {
            return false;
        }
        engine_path_ = engine_path;
        // Real implementation:
        // std::ifstream file(engine_path, std::ios::binary);
        // file.seekg(0, std::ios::end);
        // size_t size = file.tellg();
        // file.seekg(0, std::ios::beg);
        // std::vector<char> engine_data(size);
        // file.read(engine_data.data(), size);
        //
        // auto runtime = nvinfer1::createInferRuntime(logger);
        // engine_ = runtime->deserializeCudaEngine(engine_data.data(), size);
        // context_ = engine_->createExecutionContext();
        loaded_ = true;
        return true;
    }

    /**
     * Run inference for `request`.
     * @return a failed result if no engine is loaded; otherwise a successful
     *         result carrying a placeholder 1x10 "output" tensor filled with 0.1.
     */
    InferenceResult infer(const InferenceRequest& request) override {
        InferenceResult result;
        result.request_id = request.request_id;
        if (!loaded_) {
            result.success = false;
            result.error_message = "TensorRT引擎未加载";
            return result;
        }
        const auto start = std::chrono::steady_clock::now();
        // Real TensorRT execution:
        // cudaMemcpyAsync(gpu_input, request.inputs[0].data.data(), ...);
        // context_->enqueueV2(buffers, stream, nullptr);
        // cudaMemcpyAsync(cpu_output, gpu_output, ...);
        // cudaStreamSynchronize(stream);

        // Placeholder output until a real engine is wired in.
        Tensor output;
        output.name = "output";
        output.shape = {1, 10};
        output.data.resize(10, 0.1f);
        result.outputs.push_back(std::move(output));
        result.success = true;
        const auto end = std::chrono::steady_clock::now();
        result.inference_time_ms = std::chrono::duration<float, std::milli>(end - start).count();
        return result;
    }

    /// Mark the engine as unloaded; subsequent infer() calls fail.
    void unload() override {
        loaded_ = false;
    }

    std::string name() const override { return "TensorRT"; }

private:
    std::string engine_path_;
    bool loaded_ = false;
};
// ==================== 推理任务队列 ====================
/**
* 优先级推理任务队列
* 按优先级和提交时间排序,高优先级任务优先处理
* 课堂实时场景的推理请求拥有最高优先级
*/
class InferenceTaskQueue {
public:
InferenceTaskQueue(size_t max_size = 1024) : max_size_(max_size) {}
/**
* 提交推理请求到队列
* 如果队列已满,丢弃最低优先级的任务
*/
bool enqueue(InferenceRequest request) {
std::lock_guard<std::mutex> lock(mutex_);
if (queue_.size() >= max_size_) {
// 队列已满,检查是否可以替换低优先级任务
if (!queue_.empty() && queue_.top().priority > request.priority) {
queue_.pop(); // 移除最低优先级任务
} else {
return false; // 无法入队
}
}
request.submit_time = std::chrono::steady_clock::now();
queue_.push(std::move(request));
cv_.notify_one();
return true;
}
/**
* 从队列获取最高优先级的任务
* 如果队列为空则阻塞等待
*/
bool dequeue(InferenceRequest& request, int timeout_ms = 100) {
std::unique_lock<std::mutex> lock(mutex_);
if (cv_.wait_for(lock, std::chrono::milliseconds(timeout_ms),
[this] { return !queue_.empty(); })) {
request = queue_.top();
queue_.pop();
return true;
}
return false;
}
size_t size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
// 自定义比较器:优先级小的排前面,相同优先级按提交时间排序
struct RequestCompare {
bool operator()(const InferenceRequest& a, const InferenceRequest& b) {
if (a.priority != b.priority) return a.priority > b.priority;
return a.submit_time > b.submit_time;
}
};
std::priority_queue<InferenceRequest, std::vector<InferenceRequest>, RequestCompare> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
size_t max_size_;
};
// ==================== 推理引擎(核心类) ====================
/**
* 推理引擎
* 管理多个推理后端,根据模型类型和硬件条件选择最优推理路径
* 支持:
* - 多模型并发推理(OCR、数学、笔顺各独立模型)
* - 动态批处理(攒批提升GPU利用率)
* - 推理结果缓存(相同输入直接返回缓存结果)
* - 超时控制和优雅降级
*/
class InferenceEngine {
public:
InferenceEngine(DeviceType device, const std::string& models_dir)
: device_(device), models_dir_(models_dir), running_(false) {}
/**
* 初始化推理引擎
* 检测硬件设备、创建推理后端、加载模型
*/
bool initialize() {
// 检测硬件加速设备
detect_hardware();
// 为每种任务类型创建专用推理后端
backends_[TaskType::OCR] = create_backend("ocr");
backends_[TaskType::MATH_RECOGNITION] = create_backend("math");
backends_[TaskType::STROKE_ORDER] = create_backend("stroke_order");
backends_[TaskType::WRITING_QUALITY] = create_backend("writing_quality");
// 加载各模型
for (auto& [type, backend] : backends_) {
std::string model_file = get_model_path(type);
if (!backend->load_model(model_file)) {
return false;
}
}
// 启动推理工作线程
running_ = true;
worker_thread_ = std::thread(&InferenceEngine::worker_loop, this);
return true;
}
/**
* 提交推理请求(异步)
*/
std::string submit(InferenceRequest request) {
task_queue_.enqueue(std::move(request));
return request.request_id;
}
/**
* 同步推理(直接执行并返回结果)
*/
InferenceResult infer_sync(const InferenceRequest& request) {
auto it = backends_.find(request.task_type);
if (it == backends_.end()) {
InferenceResult result;
result.request_id = request.request_id;
result.success = false;
result.error_message = "不支持的任务类型";
return result;
}
return it->second->infer(request);
}
/**
* 关闭推理引擎
*/
void shutdown() {
running_ = false;
if (worker_thread_.joinable()) {
worker_thread_.join();
}
for (auto& [type, backend] : backends_) {
backend->unload();
}
}
/**
* 获取推理统计信息
*/
struct Stats {
long total_requests = 0;
long total_success = 0;
long total_failures = 0;
float avg_latency_ms = 0.0f;
float p99_latency_ms = 0.0f;
size_t queue_size = 0;
};
Stats get_stats() const {
Stats stats;
stats.total_requests = total_requests_.load();
stats.total_success = total_success_.load();
stats.total_failures = total_failures_.load();
stats.queue_size = task_queue_.size();
if (stats.total_success > 0) {
stats.avg_latency_ms = total_latency_ms_.load() / stats.total_success;
}
return stats;
}
private:
void detect_hardware() {
// 检测可用的硬件加速设备
// 瑞芯微NPU: 检查/dev/mali0或/dev/rknpu
// NVIDIA GPU: 检查CUDA Runtime
}
std::unique_ptr<InferenceBackend> create_backend(const std::string& model_name) {
// 根据设备类型创建对应的推理后端
if (device_ == DeviceType::GPU_CUDA) {
return std::make_unique<TensorRTBackend>();
}
return std::make_unique<OnnxRuntimeBackend>(device_);
}
std::string get_model_path(TaskType type) {
switch (type) {
case TaskType::OCR: return models_dir_ + "/ocr/model.onnx";
case TaskType::MATH_RECOGNITION: return models_dir_ + "/math/model.onnx";
case TaskType::STROKE_ORDER: return models_dir_ + "/stroke/model.onnx";
case TaskType::WRITING_QUALITY: return models_dir_ + "/quality/model.onnx";
}
return "";
}
/**
* 推理工作线程主循环
* 从任务队列取出请求,执行推理,存储结果
*/
void worker_loop() {
while (running_) {
InferenceRequest request;
if (task_queue_.dequeue(request, 100)) {
total_requests_++;
auto result = infer_sync(request);
if (result.success) {
total_success_++;
total_latency_ms_ += result.inference_time_ms;
} else {
total_failures_++;
}
// 存储结果供查询
std::lock_guard<std::mutex> lock(results_mutex_);
results_[request.request_id] = result;
}
}
}
DeviceType device_;
std::string models_dir_;
std::atomic<bool> running_;
std::thread worker_thread_;
InferenceTaskQueue task_queue_;
std::unordered_map<TaskType, std::unique_ptr<InferenceBackend>> backends_;
std::unordered_map<std::string, InferenceResult> results_;
std::mutex results_mutex_;
// 统计计数器
std::atomic<long> total_requests_{0};
std::atomic<long> total_success_{0};
std::atomic<long> total_failures_{0};
std::atomic<float> total_latency_ms_{0.0f};
};
#endif // INFERENCE_ENGINE_H