Files
system-design/software-copyright/05-writech-edge-box/自然写教室智能算力盒边缘计算软件-源程序.md
T
2026-03-22 15:24:40 +08:00

3042 lines
93 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 自然写教室智能算力盒边缘计算软件 V1.0
## 软件著作权鉴别材料 — 源程序
> **权利人**:深圳自然写科技有限公司
> **版本号**V1.0
---
## 源程序目录结构
```
05-writech-edge-box/
├── main.cpp
├── communication/
│ └── grpc_server.cpp
├── config/
│ └── edge_config.cpp
├── inference/
│ ├── inference_engine.cpp
│ ├── model_manager.cpp
│ └── npu_scheduler.cpp
└── preprocessing/
└── stroke_preprocessor.cpp
```
---
## 源程序文件清单
### (根目录)
#### `main.cpp`
```cpp
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 主程序入口 - 算力盒边缘计算服务启动与管理
*
* 初始化推理引擎、通信模块、模型管理、监控等子系统
* 运行于ARM/x86算力盒硬件,搭载NPU/GPU加速模块
*/
#include <iostream>
#include <string>
#include <vector>
#include <memory>
#include <thread>
#include <chrono>
#include <csignal>
#include <atomic>
#include <mutex>
#include <functional>
// 前向声明各子系统类
class InferenceEngine;
class ModelManager;
class GrpcServer;
class MqttReporter;
class SystemMonitor;
class OfflineCache;
class ClusterManager;
class OtaManager;
// ==================== 全局状态管理 ====================
// 系统运行状态标志
static std::atomic<bool> g_running(true);
// 系统启动时间戳
static std::chrono::steady_clock::time_point g_start_time;
/**
* 信号处理函数
* 接收SIGINT/SIGTERM信号后优雅关闭所有子系统
*/
void signal_handler(int signum) {
std::cout << "[Main] 接收到信号 " << signum << ",准备优雅关闭..." << std::endl;
g_running.store(false);
}
// ==================== 配置管理 ====================
/**
* 算力盒全局配置
* 从配置文件和环境变量加载运行参数
*/
struct EdgeBoxConfig {
// 设备信息
std::string device_id; // 设备唯一序列号
std::string device_name; // 设备名称
std::string firmware_version; // 固件版本
// gRPC服务配置(与网关数据交互)
std::string grpc_listen_addr = "0.0.0.0:50052";
int grpc_max_connections = 100; // 最大并发连接数
bool grpc_enable_tls = true; // 启用mTLS双向认证
// MQTT配置(与云端状态同步)
std::string mqtt_broker_url = "ssl://mqtt.writech.com:8883";
std::string mqtt_client_id;
int mqtt_keepalive_s = 60; // 心跳间隔
// 推理引擎配置
std::string models_dir = "/opt/models";
std::string inference_device = "npu"; // 推理设备: npu / gpu / cpu
int max_batch_size = 16; // 最大推理批大小
int inference_timeout_ms = 500; // 单次推理超时(毫秒)
// 集群配置
bool enable_cluster = true; // 启用多算力盒集群管理
int mdns_port = 5353; // mDNS服务发现端口
// 离线缓存配置
std::string cache_db_path = "/var/lib/writech/cache.db";
int max_cache_size_mb = 256; // 离线缓存最大容量
// OTA升级配置
std::string ota_server_url = "https://ota.writech.com";
bool ota_auto_check = true; // 自动检查升级
int ota_check_interval_h = 24; // 检查间隔(小时)
// 日志配置
std::string log_dir = "/var/log/writech";
std::string log_level = "INFO";
int log_max_size_mb = 50; // 单个日志文件大小上限
int log_rotate_count = 5; // 日志轮转保留数量
};
/**
* 从JSON配置文件加载配置
* 配置文件路径: /etc/writech/edgebox.json
*/
EdgeBoxConfig load_config(const std::string& config_path) {
EdgeBoxConfig config;
std::cout << "[Config] 加载配置文件: " << config_path << std::endl;
// 读取JSON配置文件并解析
// 实际实现使用nlohmann/json或rapidjson
// 此处使用默认值
// 设备ID从硬件序列号读取
config.device_id = "EB-" + std::to_string(std::hash<std::string>{}("device_serial"));
config.mqtt_client_id = "edgebox_" + config.device_id;
std::cout << "[Config] 配置加载完成: device_id=" << config.device_id << std::endl;
return config;
}
// ==================== 日志系统 ====================
/**
* 日志级别枚举
*/
enum class LogLevel {
DEBUG = 0,
INFO = 1,
WARNING = 2,
ERROR = 3,
CRITICAL = 4
};
/**
* 简易日志记录器
* 支持日志文件轮转和分级输出
*/
class Logger {
public:
static Logger& instance() {
static Logger logger;
return logger;
}
void init(const std::string& log_dir, const std::string& level) {
log_dir_ = log_dir;
if (level == "DEBUG") level_ = LogLevel::DEBUG;
else if (level == "WARNING") level_ = LogLevel::WARNING;
else if (level == "ERROR") level_ = LogLevel::ERROR;
else level_ = LogLevel::INFO;
std::cout << "[Logger] 日志系统初始化: dir=" << log_dir << ", level=" << level << std::endl;
}
void log(LogLevel level, const std::string& module, const std::string& message) {
if (level < level_) return;
std::lock_guard<std::mutex> lock(mutex_);
auto now = std::chrono::system_clock::now();
auto time_t = std::chrono::system_clock::to_time_t(now);
std::string level_str;
switch(level) {
case LogLevel::DEBUG: level_str = "DEBUG"; break;
case LogLevel::INFO: level_str = "INFO"; break;
case LogLevel::WARNING: level_str = "WARN"; break;
case LogLevel::ERROR: level_str = "ERROR"; break;
case LogLevel::CRITICAL: level_str = "CRIT"; break;
}
std::cout << "[" << level_str << "] " << module << ": " << message << std::endl;
}
private:
Logger() = default;
std::string log_dir_;
LogLevel level_ = LogLevel::INFO;
std::mutex mutex_;
};
// 日志宏定义
#define LOG_INFO(mod, msg) Logger::instance().log(LogLevel::INFO, mod, msg)
#define LOG_ERROR(mod, msg) Logger::instance().log(LogLevel::ERROR, mod, msg)
#define LOG_DEBUG(mod, msg) Logger::instance().log(LogLevel::DEBUG, mod, msg)
#define LOG_WARN(mod, msg) Logger::instance().log(LogLevel::WARNING, mod, msg)
// ==================== 健康检查 ====================
/**
* 系统健康状态
*/
struct HealthStatus {
bool inference_engine_ok = false; // 推理引擎状态
bool grpc_server_ok = false; // gRPC服务状态
bool mqtt_connected = false; // MQTT连接状态
bool model_loaded = false; // 模型加载状态
float cpu_usage_percent = 0.0f; // CPU使用率
float memory_usage_percent = 0.0f; // 内存使用率
float gpu_usage_percent = 0.0f; // GPU使用率
float gpu_temperature_c = 0.0f; // GPU温度
int active_connections = 0; // 活跃gRPC连接数
int pending_tasks = 0; // 待处理推理任务数
long uptime_seconds = 0; // 运行时长
};
/**
* 获取系统运行时长
*/
long get_uptime_seconds() {
auto now = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::seconds>(now - g_start_time).count();
}
// ==================== 看门狗 ====================
/**
* 软件看门狗
* 监控各子系统运行状态,异常时自动重启对应服务
* 配合硬件看门狗实现双重保护(异常自动重启)
*/
class Watchdog {
public:
Watchdog(int timeout_s = 30) : timeout_s_(timeout_s), last_feed_time_(std::chrono::steady_clock::now()) {}
/**
* 喂狗操作(各子系统定期调用)
*/
void feed(const std::string& module) {
std::lock_guard<std::mutex> lock(mutex_);
feed_records_[module] = std::chrono::steady_clock::now();
}
/**
* 检查是否有子系统超时未喂狗
*/
std::vector<std::string> check_timeouts() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> timed_out;
auto now = std::chrono::steady_clock::now();
for (const auto& [module, last_feed] : feed_records_) {
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - last_feed).count();
if (elapsed > timeout_s_) {
timed_out.push_back(module);
LOG_WARN("Watchdog", module + " 超时未响应 (" + std::to_string(elapsed) + "s)");
}
}
return timed_out;
}
private:
int timeout_s_;
std::chrono::steady_clock::time_point last_feed_time_;
std::map<std::string, std::chrono::steady_clock::time_point> feed_records_;
std::mutex mutex_;
};
// ==================== 主函数 ====================
/**
* 算力盒主程序入口
* 启动流程:
* 1. 加载配置文件
* 2. 初始化日志系统
* 3. 初始化推理引擎(加载模型到NPU/GPU)
* 4. 启动gRPC服务(接收网关笔迹数据)
* 5. 启动MQTT客户端(状态上报到云端)
* 6. 启动集群管理(mDNS发现与负载均衡)
* 7. 启动系统监控
* 8. 进入主循环(看门狗+健康检查)
*/
int main(int argc, char* argv[]) {
std::cout << "========================================" << std::endl;
std::cout << "自然写教室智能算力盒边缘计算软件 V1.0" << std::endl;
std::cout << "Copyright (c) 深圳自然写科技有限公司" << std::endl;
std::cout << "========================================" << std::endl;
g_start_time = std::chrono::steady_clock::now();
// 注册信号处理
signal(SIGINT, signal_handler);
signal(SIGTERM, signal_handler);
// 1. 加载配置
std::string config_path = "/etc/writech/edgebox.json";
if (argc > 1) config_path = argv[1];
EdgeBoxConfig config = load_config(config_path);
// 2. 初始化日志
Logger::instance().init(config.log_dir, config.log_level);
LOG_INFO("Main", "算力盒启动中...");
// 3. 初始化看门狗
Watchdog watchdog(30);
// 4. 初始化各子系统(实际环境中创建对应对象)
LOG_INFO("Main", "初始化推理引擎: device=" + config.inference_device);
LOG_INFO("Main", "加载AI模型: " + config.models_dir);
LOG_INFO("Main", "启动gRPC服务: " + config.grpc_listen_addr);
LOG_INFO("Main", "连接MQTT Broker: " + config.mqtt_broker_url);
if (config.enable_cluster) {
LOG_INFO("Main", "启动集群管理(mDNS)");
}
LOG_INFO("Main", "所有子系统初始化完成");
LOG_INFO("Main", "算力盒服务已就绪,等待推理请求...");
// 5. 主循环:看门狗+健康检查
while (g_running.load()) {
// 检查子系统超时
auto timed_out = watchdog.check_timeouts();
for (const auto& module : timed_out) {
LOG_ERROR("Main", "子系统超时: " + module + ",尝试重启...");
}
// 定期上报健康状态
HealthStatus status;
status.uptime_seconds = get_uptime_seconds();
// 休眠1秒后继续检查
std::this_thread::sleep_for(std::chrono::seconds(1));
}
// 6. 优雅关闭
LOG_INFO("Main", "正在关闭算力盒服务...");
LOG_INFO("Main", "等待推理任务完成...");
LOG_INFO("Main", "断开MQTT连接...");
LOG_INFO("Main", "停止gRPC服务...");
LOG_INFO("Main", "算力盒服务已安全关闭");
return 0;
}
```
### `communication/`
#### `communication/grpc_server.cpp`
```cpp
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* gRPC通信服务模块 - 与教室网关的笔迹数据交互
*
* 实现gRPC流式服务,接收网关转发的笔迹数据流
* 支持mTLS双向认证确保通信安全
*/
#ifndef GRPC_SERVER_H
#define GRPC_SERVER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <thread>
#include <functional>
#include <unordered_map>
#include <chrono>
#include <queue>
// ==================== gRPC消息结构 ====================
/** 笔迹坐标点(对应protobuf消息) */
struct GrpcStrokePoint {
float x;
float y;
float pressure;
uint32_t timestamp;
bool pen_up;
};
/** 笔迹数据包(对应protobuf消息) */
struct GrpcStrokePacket {
std::string packet_id; // 数据包ID
std::string pen_id; // 笔设备MAC地址
std::string student_id; // 学生ID
std::string page_id; // 点阵码页面ID
std::vector<GrpcStrokePoint> points; // 坐标点序列
uint64_t gateway_timestamp; // 网关转发时间戳
int sequence_number; // 包序号(用于乱序检测)
};
/** 识别结果响应 */
struct GrpcRecognitionResponse {
std::string packet_id; // 对应的请求包ID
std::string recognition_type; // 识别类型(ocr/math/stroke_order
bool success; // 是否成功
std::string result_text; // 识别结果文本
float confidence; // 置信度
float processing_time_ms; // 处理耗时
std::string model_version; // 使用的模型版本
};
// ==================== 连接管理器 ====================
/** 客户端连接信息 */
struct ClientConnection {
std::string client_id; // 客户端标识(网关ID
std::string client_addr; // 客户端地址
std::string cert_fingerprint; // 客户端证书指纹(mTLS
std::chrono::steady_clock::time_point connected_at;
std::chrono::steady_clock::time_point last_active;
long packets_received; // 已接收数据包数
long bytes_received; // 已接收字节数
bool authenticated; // 是否已通过mTLS认证
};
/**
* gRPC连接管理器
* 管理与多个教室网关的gRPC连接
* 每个网关对应一个持久化的gRPC流式连接
*/
class ConnectionManager {
public:
ConnectionManager(int max_connections = 100)
: max_connections_(max_connections) {}
/** 注册新连接 */
bool register_connection(const std::string& client_id, const std::string& addr,
const std::string& cert_fp) {
std::lock_guard<std::mutex> lock(mutex_);
if (static_cast<int>(connections_.size()) >= max_connections_) {
return false; // 达到最大连接数限制
}
ClientConnection conn;
conn.client_id = client_id;
conn.client_addr = addr;
conn.cert_fingerprint = cert_fp;
conn.connected_at = std::chrono::steady_clock::now();
conn.last_active = conn.connected_at;
conn.packets_received = 0;
conn.bytes_received = 0;
conn.authenticated = !cert_fp.empty();
connections_[client_id] = conn;
return true;
}
/** 移除连接 */
void remove_connection(const std::string& client_id) {
std::lock_guard<std::mutex> lock(mutex_);
connections_.erase(client_id);
}
/** 更新连接活跃时间 */
void update_activity(const std::string& client_id, long bytes) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = connections_.find(client_id);
if (it != connections_.end()) {
it->second.last_active = std::chrono::steady_clock::now();
it->second.packets_received++;
it->second.bytes_received += bytes;
}
}
/** 检查空闲超时连接 */
std::vector<std::string> check_idle_connections(int timeout_s = 300) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> idle;
auto now = std::chrono::steady_clock::now();
for (const auto& pair : connections_) {
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
now - pair.second.last_active).count();
if (elapsed > timeout_s) {
idle.push_back(pair.first);
}
}
return idle;
}
/** 获取当前连接数 */
int active_count() const {
std::lock_guard<std::mutex> lock(mutex_);
return static_cast<int>(connections_.size());
}
/** 获取所有连接状态 */
std::vector<ClientConnection> get_all_connections() const {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<ClientConnection> result;
for (const auto& pair : connections_) {
result.push_back(pair.second);
}
return result;
}
private:
std::unordered_map<std::string, ClientConnection> connections_;
mutable std::mutex mutex_;
int max_connections_;
};
// ==================== 数据包排序器 ====================
/**
* 数据包排序器
* 网络传输可能导致数据包乱序到达
* 使用滑动窗口机制对数据包进行重排序
*/
class PacketReorderer {
public:
PacketReorderer(int window_size = 16) : window_size_(window_size), expected_seq_(0) {}
/**
* 提交数据包到排序窗口
* 如果是期望的下一个序号则直接输出
* 否则缓存等待前序包到达
*/
std::vector<GrpcStrokePacket> submit(const GrpcStrokePacket& packet) {
std::vector<GrpcStrokePacket> output;
if (packet.sequence_number == expected_seq_) {
// 正好是期望的下一个包
output.push_back(packet);
expected_seq_++;
// 检查缓存中是否有后续连续的包
while (buffer_.count(expected_seq_) > 0) {
output.push_back(buffer_[expected_seq_]);
buffer_.erase(expected_seq_);
expected_seq_++;
}
} else if (packet.sequence_number > expected_seq_) {
// 后序包先到达,缓存等待
buffer_[packet.sequence_number] = packet;
// 缓存过大时强制输出最旧的包
if (static_cast<int>(buffer_.size()) > window_size_) {
auto it = buffer_.begin();
output.push_back(it->second);
expected_seq_ = it->first + 1;
buffer_.erase(it);
}
}
// 过期的旧包直接丢弃
return output;
}
void reset() {
buffer_.clear();
expected_seq_ = 0;
}
private:
std::map<int, GrpcStrokePacket> buffer_;
int window_size_;
int expected_seq_;
};
// ==================== gRPC服务实现 ====================
/**
* gRPC笔迹接收服务
* 实现InferenceService.ProcessStroke流式RPC
* 接收网关推送的笔迹数据流,送入推理引擎处理
*
* 安全设计:
* - gRPC启用mTLS双向认证
* - 请求大小限制防恶意攻击
* - 连接数限制防DoS
*/
class GrpcStrokeServer {
public:
using StrokeCallback = std::function<void(const GrpcStrokePacket&)>;
GrpcStrokeServer(const std::string& listen_addr = "0.0.0.0:50052",
bool enable_tls = true)
: listen_addr_(listen_addr), enable_tls_(enable_tls),
running_(false), conn_manager_(100) {}
/**
* 设置笔迹数据接收回调
* 当收到网关的笔迹数据时调用此回调
*/
void set_stroke_callback(StrokeCallback callback) {
stroke_callback_ = std::move(callback);
}
/**
* 启动gRPC服务器
* 加载TLS证书,绑定端口,开始监听
*/
bool start() {
if (enable_tls_) {
// 加载mTLS证书(安全设计:gRPC启用mTLS双向认证)
// grpc::SslServerCredentialsOptions ssl_opts;
// ssl_opts.pem_root_certs = load_file("/etc/ssl/ca.crt");
// ssl_opts.pem_key_cert_pairs.push_back({
// load_file("/etc/ssl/server.key"),
// load_file("/etc/ssl/server.crt")
// });
// ssl_opts.client_certificate_request = GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
}
// 构建并启动gRPC服务器
// grpc::ServerBuilder builder;
// builder.AddListeningPort(listen_addr_, credentials);
// builder.RegisterService(this);
// builder.SetMaxReceiveMessageSize(10 * 1024 * 1024); // 10MB最大消息
// server_ = builder.BuildAndStart();
running_ = true;
return true;
}
/**
* ProcessStroke RPC实现
* 接收网关的流式笔迹数据,处理后返回识别结果流
*/
void ProcessStroke(const GrpcStrokePacket& packet) {
// 更新连接活跃状态
conn_manager_.update_activity(packet.pen_id, packet.points.size() * 16);
// 数据包排序
auto ordered = reorderer_.submit(packet);
// 处理排序后的数据包
for (const auto& p : ordered) {
total_packets_++;
total_points_ += static_cast<long>(p.points.size());
// 调用回调函数送入推理引擎
if (stroke_callback_) {
stroke_callback_(p);
}
}
}
/** 停止服务器 */
void stop() {
running_ = false;
// if (server_) server_->Shutdown();
}
/** 获取服务器统计信息 */
struct ServerStats {
int active_connections;
long total_packets;
long total_points;
bool is_running;
};
ServerStats get_stats() const {
ServerStats stats;
stats.active_connections = conn_manager_.active_count();
stats.total_packets = total_packets_.load();
stats.total_points = total_points_.load();
stats.is_running = running_.load();
return stats;
}
private:
std::string listen_addr_;
bool enable_tls_;
std::atomic<bool> running_;
ConnectionManager conn_manager_;
PacketReorderer reorderer_;
StrokeCallback stroke_callback_;
std::atomic<long> total_packets_{0};
std::atomic<long> total_points_{0};
};
// ==================== MQTT状态上报客户端 ====================
/**
* MQTT状态上报客户端
* 定期向云平台上报算力盒运行状态
* Topic: edgebox/{id}/status
* 安全设计:MQTT over TLS加密传输
*/
class MqttReporter {
public:
MqttReporter(const std::string& broker_url, const std::string& device_id)
: broker_url_(broker_url), device_id_(device_id), connected_(false) {}
/** 连接MQTT BrokerTLS加密) */
bool connect() {
// 实际环境使用Eclipse Paho MQTT C++ Client
// mqtt::async_client client(broker_url_, device_id_);
// mqtt::ssl_options ssl_opts;
// ssl_opts.set_trust_store("/etc/ssl/ca.crt");
// ssl_opts.set_key_store("/etc/ssl/client.crt");
// ssl_opts.set_private_key("/etc/ssl/client.key");
connected_ = true;
return true;
}
/** 上报设备状态 */
void report_status(float gpu_usage, float temperature, float inference_qps,
int queue_depth, long uptime_s) {
if (!connected_) return;
std::string topic = "edgebox/" + device_id_ + "/status";
// 构造JSON状态消息
// {"gpu_usage": 45.2, "temperature": 62.5, "qps": 120.3, "queue": 5, "uptime": 3600}
}
/** 接收远程指令 */
void subscribe_commands() {
std::string topic = "edgebox/" + device_id_ + "/command";
// 订阅远程管理指令:重启、模型切换、OTA升级等
}
/** 断开连接 */
void disconnect() {
connected_ = false;
}
private:
std::string broker_url_;
std::string device_id_;
bool connected_;
};
// ==================== 离线结果缓存 ====================
/**
* 离线结果缓存
* 断网期间推理结果暂存到本地SQLite数据库
* 网络恢复后自动批量上传至云端
* 安全设计:通信安全保障数据完整性
*/
class OfflineResultCache {
public:
OfflineResultCache(const std::string& db_path, int max_size_mb = 256)
: db_path_(db_path), max_size_mb_(max_size_mb), cached_count_(0) {}
/** 初始化SQLite数据库 */
bool initialize() {
// CREATE TABLE IF NOT EXISTS offline_results (
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// packet_id TEXT NOT NULL,
// result_type TEXT NOT NULL,
// result_json TEXT NOT NULL,
// created_at INTEGER NOT NULL,
// uploaded INTEGER DEFAULT 0
// );
return true;
}
/** 缓存推理结果 */
bool cache_result(const std::string& packet_id, const std::string& type,
const std::string& result_json) {
// INSERT INTO offline_results (packet_id, result_type, result_json, created_at)
// VALUES (?, ?, ?, strftime('%s', 'now'));
cached_count_++;
return true;
}
/** 获取待上传的缓存结果 */
std::vector<std::string> get_pending_results(int limit = 100) {
// SELECT * FROM offline_results WHERE uploaded = 0 ORDER BY created_at LIMIT ?
return {};
}
/** 标记结果已上传 */
void mark_uploaded(const std::vector<int>& ids) {
// UPDATE offline_results SET uploaded = 1 WHERE id IN (...)
}
/** 清理已上传的旧数据 */
void cleanup(int retention_days = 7) {
// DELETE FROM offline_results WHERE uploaded = 1 AND created_at < ?
}
int cached_count() const { return cached_count_; }
private:
std::string db_path_;
int max_size_mb_;
int cached_count_;
};
// ==================== 集群管理器 ====================
/**
* 多算力盒集群管理器
* 通过mDNS服务发现同一校园网内的其他算力盒
* 实现负载均衡调度:当本机推理队列过长时,分发至空闲节点
*/
class ClusterManager {
public:
struct ClusterNode {
std::string node_id; // 节点ID
std::string address; // gRPC地址
float load_factor; // 负载因子(0-1)
bool is_self; // 是否为本机
std::chrono::steady_clock::time_point last_seen;
};
ClusterManager(const std::string& self_id) : self_id_(self_id) {}
/** 启动mDNS服务注册和发现 */
bool start_discovery() {
// 注册本机mDNS服务
// _writech-edgebox._tcp.local.
// 定期扫描同网段其他算力盒
return true;
}
/** 选择最优节点处理推理任务 */
std::string select_best_node() {
std::lock_guard<std::mutex> lock(mutex_);
std::string best_id = self_id_;
float min_load = 1.0f;
for (const auto& pair : nodes_) {
if (pair.second.load_factor < min_load) {
min_load = pair.second.load_factor;
best_id = pair.first;
}
}
return best_id;
}
/** 更新本机负载因子 */
void update_self_load(float load) {
std::lock_guard<std::mutex> lock(mutex_);
if (nodes_.count(self_id_)) {
nodes_[self_id_].load_factor = load;
}
}
int cluster_size() const {
std::lock_guard<std::mutex> lock(mutex_);
return static_cast<int>(nodes_.size());
}
private:
std::string self_id_;
std::unordered_map<std::string, ClusterNode> nodes_;
mutable std::mutex mutex_;
};
#endif // GRPC_SERVER_H
```
### `config/`
#### `config/edge_config.cpp`
```cpp
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 配置管理与安全模块 - 全局配置、安全认证、审计日志
*
* 管理算力盒的所有运行配置参数
* 提供安全认证、审计日志记录等安全功能
* 安全设计:
* - 模型加密:模型文件AES-256加密存储
* - 通信安全:gRPC启用mTLS双向认证,MQTT over TLS
* - OTA安全:升级包RSA签名+SHA-256校验
* - 运行隔离:推理进程与管理进程独立沙箱
* - 物理安全:设备唯一序列号绑定
*/
#ifndef EDGE_CONFIG_H
#define EDGE_CONFIG_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <fstream>
#include <unordered_map>
#include <chrono>
#include <ctime>
// ==================== 配置文件解析器 ====================
/**
* JSON配置文件解析器
* 从/etc/writech/edgebox.json加载配置
* 支持嵌套配置项和数组
*/
class ConfigParser {
public:
/**
* 从文件加载配置
*/
bool load_from_file(const std::string& path) {
config_path_ = path;
// 使用rapidjson或nlohmann/json解析
// 此处使用简单的键值对模拟
return load_defaults();
}
/**
* 获取字符串配置项
*/
std::string get_string(const std::string& key, const std::string& default_val = "") {
auto it = string_values_.find(key);
return (it != string_values_.end()) ? it->second : default_val;
}
/**
* 获取整数配置项
*/
int get_int(const std::string& key, int default_val = 0) {
auto it = int_values_.find(key);
return (it != int_values_.end()) ? it->second : default_val;
}
/**
* 获取浮点配置项
*/
float get_float(const std::string& key, float default_val = 0.0f) {
auto it = float_values_.find(key);
return (it != float_values_.end()) ? it->second : default_val;
}
/**
* 获取布尔配置项
*/
bool get_bool(const std::string& key, bool default_val = false) {
auto it = bool_values_.find(key);
return (it != bool_values_.end()) ? it->second : default_val;
}
/**
* 设置配置项(运行时修改)
*/
void set_string(const std::string& key, const std::string& value) {
string_values_[key] = value;
}
/**
* 保存配置到文件
*/
bool save_to_file(const std::string& path = "") {
std::string save_path = path.empty() ? config_path_ : path;
// 序列化为JSON并写入文件
return true;
}
private:
/**
* 加载默认配置
*/
bool load_defaults() {
// gRPC服务配置
string_values_["grpc.listen_addr"] = "0.0.0.0:50052";
int_values_["grpc.max_connections"] = 100;
bool_values_["grpc.enable_tls"] = true;
// MQTT配置
string_values_["mqtt.broker_url"] = "ssl://mqtt.writech.com:8883";
int_values_["mqtt.keepalive_s"] = 60;
bool_values_["mqtt.enable_tls"] = true;
// 推理引擎配置
string_values_["inference.device"] = "npu";
string_values_["inference.models_dir"] = "/opt/models";
int_values_["inference.max_batch_size"] = 16;
int_values_["inference.timeout_ms"] = 500;
bool_values_["inference.enable_fp16"] = true;
// GPU/NPU配置
float_values_["gpu.memory_fraction"] = 0.8f;
float_values_["gpu.thermal_throttle_temp"] = 80.0f;
// 集群配置
bool_values_["cluster.enable"] = true;
int_values_["cluster.mdns_port"] = 5353;
// 离线缓存配置
string_values_["cache.db_path"] = "/var/lib/writech/cache.db";
int_values_["cache.max_size_mb"] = 256;
// OTA配置
string_values_["ota.server_url"] = "https://ota.writech.com";
bool_values_["ota.auto_check"] = true;
int_values_["ota.check_interval_h"] = 24;
// 安全配置
string_values_["security.cert_dir"] = "/etc/ssl";
bool_values_["security.model_encryption"] = true;
bool_values_["security.enable_audit_log"] = true;
// 日志配置
string_values_["log.dir"] = "/var/log/writech";
string_values_["log.level"] = "INFO";
int_values_["log.max_size_mb"] = 50;
int_values_["log.rotate_count"] = 5;
return true;
}
std::string config_path_;
std::unordered_map<std::string, std::string> string_values_;
std::unordered_map<std::string, int> int_values_;
std::unordered_map<std::string, float> float_values_;
std::unordered_map<std::string, bool> bool_values_;
};
// ==================== 设备证书管理 ====================
/**
* 设备证书管理器
* 管理算力盒的X.509设备证书
* 用于mTLS双向认证和设备身份验证
* 安全设计:物理安全 - 设备唯一序列号绑定
*/
class DeviceCertManager {
public:
DeviceCertManager(const std::string& cert_dir = "/etc/ssl")
: cert_dir_(cert_dir) {}
/** 加载设备证书和密钥 */
bool load_certificates() {
server_cert_path_ = cert_dir_ + "/server.crt";
server_key_path_ = cert_dir_ + "/server.key";
ca_cert_path_ = cert_dir_ + "/ca.crt";
client_cert_path_ = cert_dir_ + "/client.crt";
client_key_path_ = cert_dir_ + "/client.key";
// 验证证书文件是否存在且有效
// X509_STORE *store = X509_STORE_new();
// X509_STORE_CTX *ctx = X509_STORE_CTX_new();
// 验证证书链完整性
return true;
}
/** 获取设备唯一序列号 */
std::string get_device_serial() {
// 从设备证书的Subject CN字段提取序列号
// 或从硬件安全芯片读取
return "EB-202501-001";
}
/** 验证对端证书指纹 */
bool verify_peer_cert(const std::string& peer_fingerprint) {
// 与信任列表比对
return trusted_fingerprints_.count(peer_fingerprint) > 0;
}
/** 注册信任的对端证书 */
void add_trusted_fingerprint(const std::string& name, const std::string& fingerprint) {
trusted_fingerprints_[fingerprint] = name;
}
std::string get_server_cert_path() const { return server_cert_path_; }
std::string get_server_key_path() const { return server_key_path_; }
std::string get_ca_cert_path() const { return ca_cert_path_; }
private:
std::string cert_dir_;
std::string server_cert_path_;
std::string server_key_path_;
std::string ca_cert_path_;
std::string client_cert_path_;
std::string client_key_path_;
std::unordered_map<std::string, std::string> trusted_fingerprints_;
};
// ==================== 审计日志记录器 ====================
/**
* 审计日志记录器
* 记录所有安全相关事件:
* - 推理请求(调用方、时间、模型版本)
* - 设备连接/断开
* - 模型加载/切换
* - OTA升级操作
* - 异常和错误事件
*/
class AuditLogger {
public:
enum class EventType {
INFERENCE_REQUEST, // 推理请求
DEVICE_CONNECT, // 设备连接
DEVICE_DISCONNECT, // 设备断开
MODEL_LOAD, // 模型加载
MODEL_SWITCH, // 模型切换
OTA_START, // OTA升级开始
OTA_COMPLETE, // OTA升级完成
OTA_FAILED, // OTA升级失败
AUTH_SUCCESS, // 认证成功
AUTH_FAILED, // 认证失败
CONFIG_CHANGE, // 配置变更
SYSTEM_ERROR // 系统错误
};
struct AuditEvent {
EventType type;
std::string timestamp;
std::string source; // 事件来源(客户端ID/模块名)
std::string action; // 操作描述
std::string details; // 详细信息
std::string result; // 结果(success/failure
std::string client_ip; // 客户端IP
};
AuditLogger(const std::string& log_dir = "/var/log/writech")
: log_dir_(log_dir), event_count_(0) {}
/**
* 记录审计事件
* 安全设计:所有识别请求记录调用方、时间、模型版本
*/
void log_event(const AuditEvent& event) {
std::lock_guard<std::mutex> lock(mutex_);
// 格式化时间戳
auto now = std::chrono::system_clock::now();
auto time = std::chrono::system_clock::to_time_t(now);
// 写入审计日志文件
// 格式:[时间] [事件类型] [来源] [操作] [结果] [详情]
// 审计日志独立于运行日志,不可被篡改
event_count_++;
// 检查日志文件大小,超限则轮转
check_rotation();
}
/** 快捷方法:记录推理请求 */
void log_inference(const std::string& client_id, const std::string& task_type,
const std::string& model_version, float latency_ms, bool success) {
AuditEvent event;
event.type = EventType::INFERENCE_REQUEST;
event.source = client_id;
event.action = "inference:" + task_type;
event.details = "model=" + model_version + ",latency=" + std::to_string(latency_ms) + "ms";
event.result = success ? "success" : "failure";
log_event(event);
}
/** 快捷方法:记录认证事件 */
void log_auth(const std::string& client_ip, const std::string& cert_cn, bool success) {
AuditEvent event;
event.type = success ? EventType::AUTH_SUCCESS : EventType::AUTH_FAILED;
event.source = cert_cn;
event.client_ip = client_ip;
event.action = "mTLS authentication";
event.result = success ? "success" : "failure";
log_event(event);
}
/** 快捷方法:记录OTA事件 */
void log_ota(const std::string& action, const std::string& version, bool success) {
AuditEvent event;
event.type = success ? EventType::OTA_COMPLETE : EventType::OTA_FAILED;
event.source = "ota_manager";
event.action = action;
event.details = "version=" + version;
event.result = success ? "success" : "failure";
log_event(event);
}
long get_event_count() const { return event_count_; }
private:
void check_rotation() {
// 审计日志文件轮转
// 当文件大小超过限制时创建新文件
// 保留最近90天的审计日志(安全合规要求)
}
std::string log_dir_;
long event_count_;
std::mutex mutex_;
};
// ==================== 进程沙箱隔离 ====================
/**
* 进程沙箱管理器
* 安全设计:推理进程与管理进程独立沙箱,异常不互相影响
* 使用Linux namespaces和cgroups实现进程隔离
*/
class ProcessSandbox {
public:
/** 创建沙箱化子进程 */
bool create_sandbox(const std::string& name, const std::string& exec_path) {
// Linux: clone(CLONE_NEWNS | CLONE_NEWPID | CLONE_NEWNET)
// cgroup限制:内存、CPU、GPU资源配额
// seccomp: 限制可用的系统调用
return true;
}
/** 设置资源限制 */
void set_resource_limits(const std::string& name, size_t memory_limit_mb,
float cpu_quota, int gpu_device_id) {
// 通过cgroups v2设置资源限制
// memory.max = memory_limit_mb * 1024 * 1024
// cpu.max = cpu_quota * period
// 通过NVIDIA Container Runtime限制GPU访问
}
/** 检查沙箱进程健康状态 */
bool is_healthy(const std::string& name) {
// 检查进程是否存活
// 检查资源使用是否超限
return true;
}
/** 重启异常的沙箱进程 */
bool restart_sandbox(const std::string& name) {
// 发送SIGTERM等待优雅退出
// 超时后发送SIGKILL强制终止
// 重新创建沙箱进程
return true;
}
};
#endif // EDGE_CONFIG_H
```
### `inference/`
#### `inference/inference_engine.cpp`
```cpp
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 推理引擎模块 - ONNX Runtime / TensorRT 推理执行引擎
*
* 负责加载AI模型并执行推理任务
* 支持多种推理后端:ONNX Runtime、TensorRT、PaddleLite
* 支持NPU/GPU硬件加速调度
*/
#ifndef INFERENCE_ENGINE_H
#define INFERENCE_ENGINE_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <atomic>
#include <chrono>
#include <functional>
#include <unordered_map>
#include <condition_variable>
// ==================== 数据结构定义 ====================
/**
* 推理设备类型枚举
* 算力盒支持多种硬件加速设备
*/
enum class DeviceType {
CPU = 0, // CPU推理(兜底方案)
GPU_CUDA = 1, // NVIDIA GPU (CUDA)
GPU_OPENCL = 2, // 通用GPU (OpenCL)
NPU_RKNN = 3, // 瑞芯微NPU (RKNN)
NPU_AMLOGIC = 4 // 晶晨NPU
};
/**
* 模型格式枚举
*/
enum class ModelFormat {
ONNX = 0, // ONNX格式(通用)
TENSORRT = 1, // TensorRT引擎(NVIDIA优化)
PADDLE_LITE = 2,// PaddleLiteARM优化)
RKNN = 3 // RKNN格式(瑞芯微NPU专用)
};
/**
* 推理任务类型
*/
enum class TaskType {
OCR = 0, // 文字OCR识别
MATH_RECOGNITION = 1, // 数学列式识别
STROKE_ORDER = 2, // 笔顺分析
WRITING_QUALITY = 3 // 书写质量评测
};
/**
* 张量数据(推理输入/输出)
* 封装多维数组数据和形状信息
*/
struct Tensor {
std::vector<float> data; // 浮点数据
std::vector<int64_t> shape; // 维度形状 (如 [1, 3, 64, 64])
std::string name; // 张量名称
/** 获取数据元素总数 */
size_t size() const {
size_t s = 1;
for (auto d : shape) s *= d;
return s;
}
};
/**
* 推理请求
*/
struct InferenceRequest {
std::string request_id; // 请求唯一ID
TaskType task_type; // 任务类型
std::vector<Tensor> inputs; // 输入张量列表
int priority = 2; // 优先级 (0=最高)
int timeout_ms = 500; // 超时时间
std::string pen_id; // 来源笔设备ID
std::string student_id; // 学生ID
std::chrono::steady_clock::time_point submit_time; // 提交时间
};
/**
* 推理结果
*/
struct InferenceResult {
std::string request_id;
bool success = false;
std::string error_message;
std::vector<Tensor> outputs; // 输出张量列表
float inference_time_ms = 0.0f; // 推理耗时
std::string model_version; // 使用的模型版本
};
// ==================== 推理后端抽象 ====================
/**
* 推理后端抽象基类
* 所有推理引擎(ONNX Runtime、TensorRT等)的统一接口
*/
class InferenceBackend {
public:
virtual ~InferenceBackend() = default;
/** 加载模型文件 */
virtual bool load_model(const std::string& model_path) = 0;
/** 执行推理 */
virtual InferenceResult infer(const InferenceRequest& request) = 0;
/** 卸载模型释放资源 */
virtual void unload() = 0;
/** 获取后端名称 */
virtual std::string name() const = 0;
};
/**
* ONNX Runtime推理后端
* 支持CPU/GPU/NPU多种执行提供者
*/
class OnnxRuntimeBackend : public InferenceBackend {
public:
OnnxRuntimeBackend(DeviceType device) : device_(device), loaded_(false) {}
bool load_model(const std::string& model_path) override {
model_path_ = model_path;
// 实际环境中:
// Ort::SessionOptions options;
// if (device_ == DeviceType::GPU_CUDA) {
// OrtCUDAProviderOptions cuda_opts;
// cuda_opts.device_id = 0;
// options.AppendExecutionProvider_CUDA(cuda_opts);
// }
// session_ = std::make_unique<Ort::Session>(env, model_path.c_str(), options);
loaded_ = true;
return true;
}
InferenceResult infer(const InferenceRequest& request) override {
InferenceResult result;
result.request_id = request.request_id;
if (!loaded_) {
result.success = false;
result.error_message = "模型未加载";
return result;
}
auto start = std::chrono::steady_clock::now();
// 执行ONNX Runtime推理
// std::vector<Ort::Value> input_tensors;
// for (const auto& input : request.inputs) {
// auto tensor = Ort::Value::CreateTensor<float>(
// memory_info, input.data.data(), input.size(),
// input.shape.data(), input.shape.size());
// input_tensors.push_back(std::move(tensor));
// }
// auto output_tensors = session_->Run(run_options, input_names, input_tensors, output_names);
// 模拟推理输出
Tensor output;
output.name = "output";
output.shape = {1, 10};
output.data.resize(10, 0.1f);
result.outputs.push_back(output);
result.success = true;
auto end = std::chrono::steady_clock::now();
result.inference_time_ms = std::chrono::duration<float, std::milli>(end - start).count();
return result;
}
void unload() override {
loaded_ = false;
}
std::string name() const override { return "ONNXRuntime"; }
private:
DeviceType device_;
std::string model_path_;
bool loaded_;
};
/**
* TensorRT推理后端
* NVIDIA GPU专用高性能推理引擎
* 支持FP16/INT8量化推理,显著降低推理延迟
*/
class TensorRTBackend : public InferenceBackend {
public:
TensorRTBackend() : loaded_(false) {}
bool load_model(const std::string& engine_path) override {
engine_path_ = engine_path;
// 实际环境中:
// std::ifstream file(engine_path, std::ios::binary);
// file.seekg(0, std::ios::end);
// size_t size = file.tellg();
// file.seekg(0, std::ios::beg);
// std::vector<char> engine_data(size);
// file.read(engine_data.data(), size);
//
// auto runtime = nvinfer1::createInferRuntime(logger);
// engine_ = runtime->deserializeCudaEngine(engine_data.data(), size);
// context_ = engine_->createExecutionContext();
loaded_ = true;
return true;
}
InferenceResult infer(const InferenceRequest& request) override {
InferenceResult result;
result.request_id = request.request_id;
if (!loaded_) {
result.success = false;
result.error_message = "TensorRT引擎未加载";
return result;
}
auto start = std::chrono::steady_clock::now();
// 执行TensorRT推理
// cudaMemcpyAsync(gpu_input, request.inputs[0].data.data(), ...);
// context_->enqueueV2(buffers, stream, nullptr);
// cudaMemcpyAsync(cpu_output, gpu_output, ...);
// cudaStreamSynchronize(stream);
Tensor output;
output.name = "output";
output.shape = {1, 10};
output.data.resize(10, 0.1f);
result.outputs.push_back(output);
result.success = true;
auto end = std::chrono::steady_clock::now();
result.inference_time_ms = std::chrono::duration<float, std::milli>(end - start).count();
return result;
}
void unload() override {
loaded_ = false;
}
std::string name() const override { return "TensorRT"; }
private:
std::string engine_path_;
bool loaded_;
};
// ==================== 推理任务队列 ====================
/**
* 优先级推理任务队列
* 按优先级和提交时间排序,高优先级任务优先处理
* 课堂实时场景的推理请求拥有最高优先级
*/
class InferenceTaskQueue {
public:
InferenceTaskQueue(size_t max_size = 1024) : max_size_(max_size) {}
/**
* 提交推理请求到队列
* 如果队列已满,丢弃最低优先级的任务
*/
bool enqueue(InferenceRequest request) {
std::lock_guard<std::mutex> lock(mutex_);
if (queue_.size() >= max_size_) {
// 队列已满,检查是否可以替换低优先级任务
if (!queue_.empty() && queue_.top().priority > request.priority) {
queue_.pop(); // 移除最低优先级任务
} else {
return false; // 无法入队
}
}
request.submit_time = std::chrono::steady_clock::now();
queue_.push(std::move(request));
cv_.notify_one();
return true;
}
/**
* 从队列获取最高优先级的任务
* 如果队列为空则阻塞等待
*/
bool dequeue(InferenceRequest& request, int timeout_ms = 100) {
std::unique_lock<std::mutex> lock(mutex_);
if (cv_.wait_for(lock, std::chrono::milliseconds(timeout_ms),
[this] { return !queue_.empty(); })) {
request = queue_.top();
queue_.pop();
return true;
}
return false;
}
size_t size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
// 自定义比较器:优先级小的排前面,相同优先级按提交时间排序
struct RequestCompare {
bool operator()(const InferenceRequest& a, const InferenceRequest& b) {
if (a.priority != b.priority) return a.priority > b.priority;
return a.submit_time > b.submit_time;
}
};
std::priority_queue<InferenceRequest, std::vector<InferenceRequest>, RequestCompare> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
size_t max_size_;
};
// ==================== 推理引擎(核心类) ====================
/**
* 推理引擎
* 管理多个推理后端,根据模型类型和硬件条件选择最优推理路径
* 支持:
* - 多模型并发推理(OCR、数学、笔顺各独立模型)
* - 动态批处理(攒批提升GPU利用率)
* - 推理结果缓存(相同输入直接返回缓存结果)
* - 超时控制和优雅降级
*/
class InferenceEngine {
public:
InferenceEngine(DeviceType device, const std::string& models_dir)
: device_(device), models_dir_(models_dir), running_(false) {}
/**
* 初始化推理引擎
* 检测硬件设备、创建推理后端、加载模型
*/
bool initialize() {
// 检测硬件加速设备
detect_hardware();
// 为每种任务类型创建专用推理后端
backends_[TaskType::OCR] = create_backend("ocr");
backends_[TaskType::MATH_RECOGNITION] = create_backend("math");
backends_[TaskType::STROKE_ORDER] = create_backend("stroke_order");
backends_[TaskType::WRITING_QUALITY] = create_backend("writing_quality");
// 加载各模型
for (auto& [type, backend] : backends_) {
std::string model_file = get_model_path(type);
if (!backend->load_model(model_file)) {
return false;
}
}
// 启动推理工作线程
running_ = true;
worker_thread_ = std::thread(&InferenceEngine::worker_loop, this);
return true;
}
/**
* 提交推理请求(异步)
*/
std::string submit(InferenceRequest request) {
task_queue_.enqueue(std::move(request));
return request.request_id;
}
/**
* 同步推理(直接执行并返回结果)
*/
InferenceResult infer_sync(const InferenceRequest& request) {
auto it = backends_.find(request.task_type);
if (it == backends_.end()) {
InferenceResult result;
result.request_id = request.request_id;
result.success = false;
result.error_message = "不支持的任务类型";
return result;
}
return it->second->infer(request);
}
/**
* 关闭推理引擎
*/
void shutdown() {
running_ = false;
if (worker_thread_.joinable()) {
worker_thread_.join();
}
for (auto& [type, backend] : backends_) {
backend->unload();
}
}
/**
* 获取推理统计信息
*/
struct Stats {
long total_requests = 0;
long total_success = 0;
long total_failures = 0;
float avg_latency_ms = 0.0f;
float p99_latency_ms = 0.0f;
size_t queue_size = 0;
};
Stats get_stats() const {
Stats stats;
stats.total_requests = total_requests_.load();
stats.total_success = total_success_.load();
stats.total_failures = total_failures_.load();
stats.queue_size = task_queue_.size();
if (stats.total_success > 0) {
stats.avg_latency_ms = total_latency_ms_.load() / stats.total_success;
}
return stats;
}
private:
void detect_hardware() {
// 检测可用的硬件加速设备
// 瑞芯微NPU: 检查/dev/mali0或/dev/rknpu
// NVIDIA GPU: 检查CUDA Runtime
}
std::unique_ptr<InferenceBackend> create_backend(const std::string& model_name) {
// 根据设备类型创建对应的推理后端
if (device_ == DeviceType::GPU_CUDA) {
return std::make_unique<TensorRTBackend>();
}
return std::make_unique<OnnxRuntimeBackend>(device_);
}
std::string get_model_path(TaskType type) {
switch (type) {
case TaskType::OCR: return models_dir_ + "/ocr/model.onnx";
case TaskType::MATH_RECOGNITION: return models_dir_ + "/math/model.onnx";
case TaskType::STROKE_ORDER: return models_dir_ + "/stroke/model.onnx";
case TaskType::WRITING_QUALITY: return models_dir_ + "/quality/model.onnx";
}
return "";
}
/**
* 推理工作线程主循环
* 从任务队列取出请求,执行推理,存储结果
*/
void worker_loop() {
while (running_) {
InferenceRequest request;
if (task_queue_.dequeue(request, 100)) {
total_requests_++;
auto result = infer_sync(request);
if (result.success) {
total_success_++;
total_latency_ms_ += result.inference_time_ms;
} else {
total_failures_++;
}
// 存储结果供查询
std::lock_guard<std::mutex> lock(results_mutex_);
results_[request.request_id] = result;
}
}
}
DeviceType device_;
std::string models_dir_;
std::atomic<bool> running_;
std::thread worker_thread_;
InferenceTaskQueue task_queue_;
std::unordered_map<TaskType, std::unique_ptr<InferenceBackend>> backends_;
std::unordered_map<std::string, InferenceResult> results_;
std::mutex results_mutex_;
// 统计计数器
std::atomic<long> total_requests_{0};
std::atomic<long> total_success_{0};
std::atomic<long> total_failures_{0};
std::atomic<float> total_latency_ms_{0.0f};
};
#endif // INFERENCE_ENGINE_H
```
#### `inference/model_manager.cpp`
```cpp
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 模型管理模块 - 模型加载、版本管理、量化压缩、云端同步
*
* 管理算力盒上部署的所有AI推理模型的生命周期
* 支持模型热更新、A/B切换、云端版本同步
* 模型文件AES-256加密存储,推理时内存解密加载
*/
#ifndef MODEL_MANAGER_H
#define MODEL_MANAGER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <unordered_map>
#include <chrono>
#include <functional>
// ==================== 模型元信息 ====================
/** 模型状态枚举 */
enum class ModelState {
NOT_FOUND = 0, // 未发现
DOWNLOADING = 1, // 下载中
DECRYPTING = 2, // 解密中
LOADING = 3, // 加载到设备中
READY = 4, // 就绪可用
ACTIVE = 5, // 当前使用中
DEPRECATED = 6, // 已弃用
ERROR = 7 // 错误状态
};
/** 模型量化精度 */
enum class QuantizationType {
FP32 = 0, // 全精度浮点
FP16 = 1, // 半精度浮点
INT8 = 2, // 8位整型量化
INT4 = 3 // 4位整型量化(极致压缩)
};
/** 模型元信息 */
struct ModelInfo {
std::string name; // 模型名称
std::string version; // 版本号(语义化版本)
std::string format; // 格式(onnx/trt/rknn
std::string file_path; // 本地文件路径
size_t file_size_bytes; // 文件大小
std::string sha256; // 文件SHA-256校验和
QuantizationType quantization; // 量化类型
float accuracy; // 测试集准确率
float latency_ms; // 平均推理延迟
ModelState state; // 当前状态
std::string deployed_at; // 部署时间
std::string description; // 模型描述
};
// ==================== 模型加密管理 ====================
/**
* 模型文件加密/解密管理器
* 安全设计:模型文件AES-256加密存储,推理时内存解密加载
* 加密密钥通过安全芯片(TPM)或环境变量注入
*/
class ModelCryptoManager {
public:
ModelCryptoManager() : key_loaded_(false) {}
/**
* 加载加密密钥
* 优先从安全芯片读取,其次从环境变量
*/
bool load_encryption_key() {
// 尝试从TPM安全芯片读取密钥
// if (tpm_available()) { key_ = tpm_read_key("model_key"); }
// 后备方案:从环境变量读取
const char* env_key = std::getenv("WRITECH_MODEL_KEY");
if (env_key) {
key_ = std::string(env_key);
key_loaded_ = true;
return true;
}
return false;
}
/**
* 解密模型文件到内存
* 不在磁盘上生成明文文件,仅在内存中解密
*/
std::vector<uint8_t> decrypt_model(const std::string& encrypted_path) {
std::vector<uint8_t> decrypted_data;
if (!key_loaded_) return decrypted_data;
// 读取加密文件
// AES-256-CBC解密
// openssl EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv);
// EVP_DecryptUpdate(ctx, output, &out_len, input, in_len);
// EVP_DecryptFinal_ex(ctx, output + out_len, &final_len);
return decrypted_data;
}
/**
* 加密模型文件
* 新下载的模型文件加密后存储到本地Flash
*/
bool encrypt_model(const std::vector<uint8_t>& data, const std::string& output_path) {
if (!key_loaded_) return false;
// AES-256-CBC加密并写入文件
return true;
}
/**
* 验证模型文件完整性
* 计算SHA-256校验和并与元数据中的值比对
*/
bool verify_integrity(const std::string& file_path, const std::string& expected_sha256) {
// 计算文件SHA-256
// SHA256_CTX sha256;
// SHA256_Init(&sha256);
// while (read chunk) SHA256_Update(&sha256, chunk, len);
// SHA256_Final(hash, &sha256);
return true;
}
private:
std::string key_;
bool key_loaded_;
};
// ==================== 模型版本管理器 ====================
/**
* 模型版本管理器
* 管理算力盒上所有AI模型的版本、加载、切换
* 支持A/B分区切换实现热更新
*/
class ModelVersionManager {
public:
ModelVersionManager(const std::string& models_dir)
: models_dir_(models_dir) {}
/**
* 注册模型
* 扫描模型目录,加载所有可用模型的元信息
*/
bool register_model(const ModelInfo& info) {
std::lock_guard<std::mutex> lock(mutex_);
std::string key = info.name + "@" + info.version;
models_[key] = info;
return true;
}
/**
* 激活指定版本的模型
* 将旧版本标记为deprecated,新版本标记为active
*/
bool activate_version(const std::string& name, const std::string& version) {
std::lock_guard<std::mutex> lock(mutex_);
// 将当前活跃版本设为deprecated
for (auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
pair.second.state = ModelState::DEPRECATED;
}
}
// 激活新版本
std::string key = name + "@" + version;
auto it = models_.find(key);
if (it != models_.end()) {
it->second.state = ModelState::ACTIVE;
return true;
}
return false;
}
/**
* 获取当前活跃版本的模型信息
*/
ModelInfo get_active_model(const std::string& name) {
std::lock_guard<std::mutex> lock(mutex_);
for (const auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
return pair.second;
}
}
return ModelInfo{};
}
/**
* 获取所有模型状态列表
*/
std::vector<ModelInfo> get_all_models() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<ModelInfo> result;
for (const auto& pair : models_) {
result.push_back(pair.second);
}
return result;
}
/**
* 清理已废弃的旧版本模型文件
* 保留最近2个版本,删除更早的版本释放存储空间
*/
void cleanup_old_versions(const std::string& name, int keep_count = 2) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> deprecated_keys;
for (const auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::DEPRECATED) {
deprecated_keys.push_back(pair.first);
}
}
// 按版本排序,保留最新的keep_count个
if (static_cast<int>(deprecated_keys.size()) > keep_count) {
for (int i = 0; i < static_cast<int>(deprecated_keys.size()) - keep_count; i++) {
// 删除模型文件并从注册表移除
models_.erase(deprecated_keys[i]);
}
}
}
private:
std::string models_dir_;
std::unordered_map<std::string, ModelInfo> models_;
std::mutex mutex_;
};
// ==================== 云端模型同步器 ====================
/**
* 云端模型同步器
* 定期检查云端是否有新版本模型,自动下载并部署
* 通过HTTPS加密通道下载,下载后RSA签名校验
*/
class CloudModelSyncer {
public:
CloudModelSyncer(const std::string& server_url, const std::string& device_id)
: server_url_(server_url), device_id_(device_id) {}
/**
* 检查云端是否有模型更新
* GET /api/v1/model/check-update?device_id=xxx&models=ocr@1.0,math@1.0
*/
struct UpdateInfo {
std::string model_name;
std::string new_version;
std::string download_url;
size_t file_size;
std::string sha256;
};
std::vector<UpdateInfo> check_updates(const std::vector<ModelInfo>& current_models) {
std::vector<UpdateInfo> updates;
// 向云端API发送当前模型版本列表,获取可更新版本
// HTTPS请求:GET server_url_/api/v1/model/check-update
return updates;
}
/**
* 下载模型文件
* HTTPS下载,支持断点续传
* 下载完成后进行SHA-256校验和RSA签名验证
*/
bool download_model(const UpdateInfo& info, const std::string& save_path) {
// HTTPS下载
// 进度回调上报
// SHA-256校验
// RSA签名验证(OTA安全:升级包RSA签名+SHA-256校验,防篡改)
return true;
}
/**
* 上报模型部署状态
* POST /api/v1/model/deploy-status
*/
void report_deploy_status(const std::string& model_name, const std::string& version,
bool success, const std::string& error = "") {
// 向云端上报模型部署结果
}
private:
std::string server_url_;
std::string device_id_;
};
// ==================== OTA固件升级管理器 ====================
/**
* OTA固件升级管理器
* 管理算力盒固件的远程升级
* 采用A/B双分区方案,升级失败自动回滚
* 安全设计:升级包RSA签名+SHA-256校验,防篡改
*/
class OtaUpgradeManager {
public:
enum class OtaState {
IDLE, // 空闲
CHECKING, // 检查更新中
DOWNLOADING, // 下载中
VERIFYING, // 校验中
INSTALLING, // 安装中
REBOOTING, // 重启中
FAILED // 失败
};
OtaUpgradeManager(const std::string& ota_url, const std::string& device_id)
: ota_url_(ota_url), device_id_(device_id), state_(OtaState::IDLE),
current_partition_("A"), download_progress_(0) {}
/** 检查固件更新 */
bool check_update() {
state_ = OtaState::CHECKING;
// GET ota_url_/api/v1/ota/check?device_id=xxx&version=xxx
return false; // 返回是否有新版本
}
/** 下载固件升级包 */
bool download_firmware(const std::string& download_url) {
state_ = OtaState::DOWNLOADING;
// HTTPS分块下载到非活跃分区
// 支持断点续传
return true;
}
/** 验证固件包完整性和签名 */
bool verify_firmware(const std::string& firmware_path) {
state_ = OtaState::VERIFYING;
// SHA-256校验
// RSA-2048签名验证
return true;
}
/** 安装固件(写入非活跃分区) */
bool install_firmware() {
state_ = OtaState::INSTALLING;
// 写入B分区(如当前运行A分区)
// 设置下次启动从B分区引导
return true;
}
/** 回滚到上一版本 */
bool rollback() {
// 切换回上一个分区
std::string target = (current_partition_ == "A") ? "B" : "A";
// 设置引导分区为target
return true;
}
/** 获取当前OTA状态 */
OtaState get_state() const { return state_; }
int get_progress() const { return download_progress_; }
std::string get_current_partition() const { return current_partition_; }
private:
std::string ota_url_;
std::string device_id_;
OtaState state_;
std::string current_partition_;
int download_progress_;
};
// ==================== 系统监控模块 ====================
/**
* 系统运行状态监控
* 采集CPU、内存、GPU/NPU利用率、温度等硬件指标
* 为云端监控告警和集群调度提供数据支撑
*/
class SystemMonitor {
public:
struct SystemMetrics {
float cpu_usage_percent; // CPU使用率
float memory_usage_percent; // 内存使用率
long memory_total_mb; // 总内存
long memory_used_mb; // 已用内存
float gpu_usage_percent; // GPU/NPU利用率
float gpu_memory_usage_mb; // GPU显存使用
float gpu_temperature_c; // GPU温度
float disk_usage_percent; // 磁盘使用率
float network_rx_mbps; // 网络接收速率
float network_tx_mbps; // 网络发送速率
long uptime_seconds; // 系统运行时长
};
SystemMonitor() : running_(false) {}
/** 启动监控采集线程 */
void start(int interval_ms = 5000) {
running_ = true;
// 定时采集系统指标
}
/** 获取最新系统指标 */
SystemMetrics get_metrics() {
SystemMetrics metrics;
metrics.cpu_usage_percent = read_cpu_usage();
metrics.memory_usage_percent = read_memory_usage();
metrics.gpu_usage_percent = read_gpu_usage();
metrics.gpu_temperature_c = read_gpu_temperature();
metrics.disk_usage_percent = read_disk_usage();
return metrics;
}
void stop() { running_ = false; }
private:
float read_cpu_usage() {
// 读取 /proc/stat 计算CPU使用率
return 0.0f;
}
float read_memory_usage() {
// 读取 /proc/meminfo
return 0.0f;
}
float read_gpu_usage() {
// NVIDIA: nvidia-smi / NVML
// 瑞芯微: /sys/class/devfreq/xxx
return 0.0f;
}
float read_gpu_temperature() {
// 读取GPU温度传感器
return 0.0f;
}
float read_disk_usage() {
// statfs("/")
return 0.0f;
}
std::atomic<bool> running_;
};
#endif // MODEL_MANAGER_H
```
#### `inference/npu_scheduler.cpp`
```cpp
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* NPU/GPU硬件调度模块 - 硬件加速资源管理与任务分配
*
* 管理算力盒上的NPU/GPU计算资源
* 支持多种硬件平台:NVIDIA GPU(CUDA)、瑞芯微NPU(RKNN)、通用GPU(OpenCL)
* 根据任务类型和硬件负载动态选择最优推理路径
*/
#ifndef NPU_SCHEDULER_H
#define NPU_SCHEDULER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <chrono>
#include <queue>
#include <functional>
#include <unordered_map>
#include <thread>
#include <condition_variable>
#include <cstring>
// ==================== 硬件设备抽象 ====================
/** 硬件加速器类型 */
enum class AcceleratorType {
CPU_ONLY = 0, // 仅CPU(无加速器可用时的兜底方案)
NVIDIA_GPU = 1, // NVIDIA GPU (CUDA/TensorRT)
ROCKCHIP_NPU = 2, // 瑞芯微NPU (RKNN)
AMLOGIC_NPU = 3, // 晶晨NPU
GENERIC_OPENCL = 4 // 通用OpenCL GPU
};
/** 硬件设备信息 */
struct AcceleratorDevice {
AcceleratorType type; // 加速器类型
int device_id; // 设备编号
std::string name; // 设备名称
std::string driver_version; // 驱动版本
size_t total_memory_mb; // 总显存/内存(MB)
size_t free_memory_mb; // 可用显存/内存(MB)
float compute_capability; // 算力指标
float current_utilization; // 当前利用率(0-1)
float temperature_celsius; // 当前温度
float max_temperature; // 最高安全温度
bool is_available; // 是否可用
};
/** 推理任务资源需求 */
struct TaskResourceRequirement {
size_t memory_mb; // 需要的显存(MB)
float estimated_time_ms; // 预估推理时间
bool requires_fp16; // 是否需要FP16支持
bool requires_int8; // 是否需要INT8支持
int preferred_device; // 偏好设备ID-1表示无偏好)
};
// ==================== 硬件检测器 ====================
/**
* 硬件加速器检测器
* 启动时扫描系统中可用的NPU/GPU设备
* 自动匹配设备驱动和推理后端
*/
class HardwareDetector {
public:
/**
* 扫描系统中所有可用的加速器设备
* 检测顺序:NVIDIA GPU → 瑞芯微NPU → 通用OpenCL → CPU
*/
std::vector<AcceleratorDevice> detect_devices() {
std::vector<AcceleratorDevice> devices;
// 检测NVIDIA GPU
if (detect_nvidia_gpu(devices)) {
// 通过NVML库获取GPU信息
}
// 检测瑞芯微NPU
if (detect_rockchip_npu(devices)) {
// 通过sysfs获取NPU信息
}
// 如果没有加速器,添加CPU作为兜底
if (devices.empty()) {
AcceleratorDevice cpu_dev;
cpu_dev.type = AcceleratorType::CPU_ONLY;
cpu_dev.device_id = 0;
cpu_dev.name = "CPU";
cpu_dev.total_memory_mb = get_system_memory_mb();
cpu_dev.free_memory_mb = get_free_memory_mb();
cpu_dev.is_available = true;
devices.push_back(cpu_dev);
}
return devices;
}
private:
bool detect_nvidia_gpu(std::vector<AcceleratorDevice>& devices) {
// 检查 /dev/nvidia0 是否存在
// 使用NVML API获取设备信息
// nvmlInit();
// nvmlDeviceGetCount(&count);
// for (int i = 0; i < count; i++) {
// nvmlDeviceGetHandleByIndex(i, &device);
// nvmlDeviceGetName(device, name, sizeof(name));
// nvmlDeviceGetMemoryInfo(device, &mem);
// nvmlDeviceGetUtilizationRates(device, &util);
// nvmlDeviceGetTemperature(device, NVML_TEMPERATURE_GPU, &temp);
// }
return false;
}
bool detect_rockchip_npu(std::vector<AcceleratorDevice>& devices) {
// 检查 /dev/rknpu 或 /sys/class/misc/rknpu 是否存在
// 读取NPU硬件信息
// cat /sys/kernel/debug/rknpu/load // NPU负载
return false;
}
size_t get_system_memory_mb() {
// 读取 /proc/meminfo
return 4096; // 默认4GB
}
size_t get_free_memory_mb() {
return 2048;
}
};
// ==================== 设备负载监控 ====================
/**
* 硬件设备负载实时监控
* 定期采集GPU/NPU利用率、温度、显存使用等指标
* 为调度策略提供实时数据支撑
*/
class DeviceLoadMonitor {
public:
struct DeviceMetrics {
int device_id;
float utilization; // 利用率 (0-1)
float memory_usage; // 显存使用率 (0-1)
float temperature; // 温度(摄氏度)
float power_watts; // 功耗(瓦)
int inference_qps; // 当前推理QPS
std::chrono::steady_clock::time_point timestamp;
};
DeviceLoadMonitor() : running_(false) {}
/** 启动监控(后台线程定期采集) */
void start(int interval_ms = 1000) {
running_ = true;
monitor_thread_ = std::thread([this, interval_ms]() {
while (running_) {
collect_metrics();
std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
}
});
}
/** 获取指定设备的最新指标 */
DeviceMetrics get_metrics(int device_id) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = latest_metrics_.find(device_id);
if (it != latest_metrics_.end()) {
return it->second;
}
return DeviceMetrics{};
}
/** 获取所有设备指标 */
std::vector<DeviceMetrics> get_all_metrics() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<DeviceMetrics> result;
for (const auto& pair : latest_metrics_) {
result.push_back(pair.second);
}
return result;
}
void stop() {
running_ = false;
if (monitor_thread_.joinable()) {
monitor_thread_.join();
}
}
private:
void collect_metrics() {
std::lock_guard<std::mutex> lock(mutex_);
// NVIDIA GPU: nvmlDeviceGetUtilizationRates + nvmlDeviceGetTemperature
// 瑞芯微NPU: 读取 /sys/kernel/debug/rknpu/load
// CPU: 读取 /proc/stat
}
std::unordered_map<int, DeviceMetrics> latest_metrics_;
std::mutex mutex_;
std::atomic<bool> running_;
std::thread monitor_thread_;
};
// ==================== 调度策略 ====================
/**
* 推理任务调度策略
* 根据任务特征和设备负载选择最优的推理设备
*/
class SchedulingPolicy {
public:
virtual ~SchedulingPolicy() = default;
/** 选择最优设备执行推理任务 */
virtual int select_device(const TaskResourceRequirement& requirement,
const std::vector<AcceleratorDevice>& devices,
const std::vector<DeviceLoadMonitor::DeviceMetrics>& metrics) = 0;
};
/**
* 最小负载调度策略
* 优先选择当前利用率最低的设备
*/
class MinLoadPolicy : public SchedulingPolicy {
public:
int select_device(const TaskResourceRequirement& requirement,
const std::vector<AcceleratorDevice>& devices,
const std::vector<DeviceLoadMonitor::DeviceMetrics>& metrics) override {
int best_device = 0;
float min_load = 1.0f;
for (size_t i = 0; i < devices.size(); i++) {
if (!devices[i].is_available) continue;
if (devices[i].free_memory_mb < requirement.memory_mb) continue;
float load = (i < metrics.size()) ? metrics[i].utilization : 0.0f;
if (load < min_load) {
min_load = load;
best_device = static_cast<int>(i);
}
}
return best_device;
}
};
/**
* 温度感知调度策略
* 除了负载外还考虑设备温度,防止过热降频
*/
class ThermalAwarePolicy : public SchedulingPolicy {
public:
ThermalAwarePolicy(float temp_threshold = 80.0f) : temp_threshold_(temp_threshold) {}
int select_device(const TaskResourceRequirement& requirement,
const std::vector<AcceleratorDevice>& devices,
const std::vector<DeviceLoadMonitor::DeviceMetrics>& metrics) override {
int best_device = 0;
float best_score = -1.0f;
for (size_t i = 0; i < devices.size(); i++) {
if (!devices[i].is_available) continue;
if (devices[i].free_memory_mb < requirement.memory_mb) continue;
float load = (i < metrics.size()) ? metrics[i].utilization : 0.0f;
float temp = (i < metrics.size()) ? metrics[i].temperature : 0.0f;
// 综合评分:负载权重0.6 + 温度权重0.4
float load_score = 1.0f - load;
float temp_score = (temp < temp_threshold_) ? 1.0f : (1.0f - (temp - temp_threshold_) / 20.0f);
float score = load_score * 0.6f + temp_score * 0.4f;
if (score > best_score) {
best_score = score;
best_device = static_cast<int>(i);
}
}
return best_device;
}
private:
float temp_threshold_;
};
// ==================== NPU调度器(核心) ====================
/**
* NPU/GPU硬件调度器
* 管理推理任务到硬件设备的分配调度
* 核心功能:
* 1. 硬件资源池化管理
* 2. 基于负载和温度的智能调度
* 3. 设备故障自动切换
* 4. 推理性能指标采集
*/
class NpuScheduler {
public:
NpuScheduler() : initialized_(false) {}
/**
* 初始化调度器
* 检测硬件设备,启动负载监控,设置调度策略
*/
bool initialize() {
// 检测可用硬件加速器
HardwareDetector detector;
devices_ = detector.detect_devices();
if (devices_.empty()) {
return false;
}
// 启动设备负载监控
load_monitor_.start(1000);
// 设置调度策略(默认温度感知策略)
policy_ = std::make_unique<ThermalAwarePolicy>(80.0f);
initialized_ = true;
return true;
}
/**
* 为推理任务分配最优设备
*/
int schedule_task(const TaskResourceRequirement& requirement) {
if (!initialized_) return 0;
auto metrics = load_monitor_.get_all_metrics();
return policy_->select_device(requirement, devices_, metrics);
}
/**
* 获取所有设备状态
*/
std::vector<AcceleratorDevice> get_device_status() {
// 更新设备实时状态
auto metrics = load_monitor_.get_all_metrics();
for (auto& dev : devices_) {
for (const auto& m : metrics) {
if (m.device_id == dev.device_id) {
dev.current_utilization = m.utilization;
dev.temperature_celsius = m.temperature;
}
}
}
return devices_;
}
/** 获取调度统计信息 */
struct SchedulerStats {
long total_tasks_scheduled;
long total_tasks_completed;
long total_tasks_failed;
float avg_inference_ms;
float gpu_avg_utilization;
float gpu_temperature;
int active_devices;
};
SchedulerStats get_stats() {
SchedulerStats stats;
stats.total_tasks_scheduled = tasks_scheduled_.load();
stats.total_tasks_completed = tasks_completed_.load();
stats.total_tasks_failed = tasks_failed_.load();
stats.active_devices = static_cast<int>(devices_.size());
auto metrics = load_monitor_.get_all_metrics();
if (!metrics.empty()) {
float total_util = 0;
for (const auto& m : metrics) total_util += m.utilization;
stats.gpu_avg_utilization = total_util / metrics.size();
stats.gpu_temperature = metrics[0].temperature;
}
return stats;
}
void shutdown() {
load_monitor_.stop();
initialized_ = false;
}
private:
std::vector<AcceleratorDevice> devices_;
DeviceLoadMonitor load_monitor_;
std::unique_ptr<SchedulingPolicy> policy_;
bool initialized_;
std::atomic<long> tasks_scheduled_{0};
std::atomic<long> tasks_completed_{0};
std::atomic<long> tasks_failed_{0};
};
// ==================== 配置管理 ====================
/**
* 算力盒配置管理(边缘设备专用)
* 从JSON配置文件和环境变量加载配置
* 支持运行时配置热更新(通过MQTT远程指令)
*/
struct EdgeBoxConfiguration {
// 推理配置
int max_concurrent_inferences = 4; // 最大并发推理数
int inference_queue_size = 256; // 推理队列大小
int default_timeout_ms = 500; // 默认推理超时
// NPU/GPU配置
float gpu_memory_fraction = 0.8f; // GPU显存使用比例上限
float thermal_throttle_temp = 80.0f; // 温度降频阈值
bool enable_fp16 = true; // 启用FP16推理
bool enable_int8 = false; // 启用INT8量化
// 网络配置
std::string grpc_listen = "0.0.0.0:50052";
std::string mqtt_broker = "ssl://mqtt.writech.com:8883";
bool enable_mtls = true;
// 存储配置
std::string models_dir = "/opt/models";
std::string cache_dir = "/var/lib/writech/cache";
int offline_cache_max_mb = 256;
// 集群配置
bool enable_cluster = true;
std::string cluster_discovery = "mdns";
};
#endif // NPU_SCHEDULER_H
```
### `preprocessing/`
#### `preprocessing/stroke_preprocessor.cpp`
```cpp
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 笔迹预处理模块 - 笔迹坐标数据预处理管道
*
* 对网关转发的原始笔迹坐标进行预处理:
* 去噪滤波、坐标归一化、笔画分割、特征提取
* 预处理结果作为NPU/GPU推理的标准化输入
*/
#ifndef STROKE_PREPROCESSOR_H
#define STROKE_PREPROCESSOR_H
#include <vector>
#include <cmath>
#include <algorithm>
#include <numeric>
#include <cstring>
// ==================== 基础数据结构 ====================
/** 原始笔迹坐标点(来自网关gRPC数据流) */
struct RawPoint {
float x; // X坐标(点阵单位,约300DPI)
float y; // Y坐标
float pressure; // 压力值 (0.0-1.0)
uint32_t timestamp; // 采集时间戳(毫秒)
bool pen_up; // 抬笔标记
};
/** 归一化后的坐标点 */
struct NormalizedPoint {
float x; // 归一化X (0.0-1.0)
float y; // 归一化Y (0.0-1.0)
float pressure; // 压力值 (0.0-1.0)
};
/** 笔画数据 */
struct Stroke {
std::vector<NormalizedPoint> points; // 归一化坐标点序列
int stroke_index; // 笔画序号
float length; // 笔画路径长度
int duration_ms; // 书写耗时(毫秒)
};
/** 预处理输出(用于NPU推理输入) */
struct PreprocessedData {
std::vector<float> image; // 渲染后的灰度图像 (H*W)
int image_width; // 图像宽度
int image_height; // 图像高度
std::vector<Stroke> strokes; // 分割后的笔画列表
int total_points; // 总坐标点数
int stroke_count; // 笔画数量
};
// ==================== 去噪滤波器 ====================
/**
* 笔迹去噪滤波器
* 消除点阵笔采集过程中的抖动噪声和异常跳跃点
* 多级滤波策略:异常点剔除 → 中值滤波 → 移动平均平滑
*/
class StrokeNoiseFilter {
public:
/**
* 构造函数
* max_jump: 最大允许跳跃距离(超过则视为异常点)
* window_size: 滤波窗口大小(奇数)
*/
StrokeNoiseFilter(float max_jump = 50.0f, int window_size = 3)
: max_jump_(max_jump), window_size_(window_size) {}
/**
* 剔除异常跳跃点
* 点阵笔摄像头短暂遮挡会导致坐标突变,需要过滤
*/
std::vector<RawPoint> remove_outliers(const std::vector<RawPoint>& points) {
if (points.size() < 3) return points;
std::vector<RawPoint> result;
result.push_back(points[0]);
for (size_t i = 1; i < points.size(); i++) {
float dx = points[i].x - points[i-1].x;
float dy = points[i].y - points[i-1].y;
float dist = std::sqrt(dx * dx + dy * dy);
// 跳跃距离在合理范围内才保留该点
if (dist <= max_jump_) {
result.push_back(points[i]);
}
}
return result;
}
/**
* 中值滤波去噪
* 对X和Y坐标分别进行一维中值滤波
* 有效消除脉冲噪声同时保留笔画转折特征
*/
std::vector<RawPoint> median_filter(const std::vector<RawPoint>& points) {
int n = static_cast<int>(points.size());
if (n < window_size_) return points;
int half = window_size_ / 2;
std::vector<RawPoint> result(n);
for (int i = 0; i < n; i++) {
// 收集窗口内的X和Y值
std::vector<float> wx, wy;
for (int j = std::max(0, i - half); j <= std::min(n - 1, i + half); j++) {
wx.push_back(points[j].x);
wy.push_back(points[j].y);
}
// 排序取中值
std::sort(wx.begin(), wx.end());
std::sort(wy.begin(), wy.end());
result[i] = points[i];
result[i].x = wx[wx.size() / 2];
result[i].y = wy[wy.size() / 2];
}
return result;
}
/**
* 移动平均平滑
* 进一步减少微小抖动,使笔画更流畅
*/
std::vector<RawPoint> moving_average(const std::vector<RawPoint>& points) {
int n = static_cast<int>(points.size());
if (n < 3) return points;
std::vector<RawPoint> result(n);
int half = window_size_ / 2;
for (int i = 0; i < n; i++) {
float sum_x = 0, sum_y = 0;
int count = 0;
for (int j = std::max(0, i - half); j <= std::min(n - 1, i + half); j++) {
sum_x += points[j].x;
sum_y += points[j].y;
count++;
}
result[i] = points[i];
result[i].x = sum_x / count;
result[i].y = sum_y / count;
}
return result;
}
/** 执行完整去噪流程 */
std::vector<RawPoint> apply(const std::vector<RawPoint>& points) {
auto step1 = remove_outliers(points);
auto step2 = median_filter(step1);
auto step3 = moving_average(step2);
return step3;
}
private:
float max_jump_;
int window_size_;
};
// ==================== 坐标归一化器 ====================
/**
* 坐标归一化器
* 将不同纸张尺寸和分辨率的原始坐标统一归一化到[0,1]范围
* 保持宽高比以避免笔迹变形
*/
class CoordinateNormalizer {
public:
CoordinateNormalizer(bool preserve_aspect = true) : preserve_aspect_(preserve_aspect) {}
/**
* Min-Max归一化,映射到[0,1]范围
*/
std::vector<NormalizedPoint> normalize(const std::vector<RawPoint>& points) {
if (points.empty()) return {};
// 计算坐标范围
float min_x = points[0].x, max_x = points[0].x;
float min_y = points[0].y, max_y = points[0].y;
for (const auto& p : points) {
min_x = std::min(min_x, p.x);
max_x = std::max(max_x, p.x);
min_y = std::min(min_y, p.y);
max_y = std::max(max_y, p.y);
}
float range_x = max_x - min_x;
float range_y = max_y - min_y;
// 保持宽高比时使用统一的缩放因子
float scale = 1.0f;
if (preserve_aspect_) {
scale = std::max(range_x, range_y);
if (scale < 1e-6f) scale = 1.0f;
}
std::vector<NormalizedPoint> result;
result.reserve(points.size());
for (const auto& p : points) {
NormalizedPoint np;
if (preserve_aspect_) {
np.x = (p.x - min_x) / scale;
np.y = (p.y - min_y) / scale;
} else {
np.x = (range_x > 1e-6f) ? (p.x - min_x) / range_x : 0.5f;
np.y = (range_y > 1e-6f) ? (p.y - min_y) / range_y : 0.5f;
}
np.pressure = p.pressure;
result.push_back(np);
}
return result;
}
private:
bool preserve_aspect_;
};
// ==================== 笔画分割器 ====================
/**
* 笔画分割器
* 根据抬笔事件和时间间隔将连续坐标流分割为独立笔画
*/
class StrokeSegmenter {
public:
StrokeSegmenter(int time_threshold_ms = 200, int min_points = 3)
: time_threshold_(time_threshold_ms), min_points_(min_points) {}
/**
* 将原始点序列分割为笔画列表
*/
std::vector<std::vector<RawPoint>> segment(const std::vector<RawPoint>& points) {
if (points.empty()) return {};
std::vector<std::vector<RawPoint>> strokes;
std::vector<RawPoint> current;
current.push_back(points[0]);
for (size_t i = 1; i < points.size(); i++) {
bool is_break = points[i].pen_up;
int time_gap = static_cast<int>(points[i].timestamp - points[i-1].timestamp);
if ((is_break || time_gap > time_threshold_) &&
static_cast<int>(current.size()) >= min_points_) {
strokes.push_back(current);
current.clear();
}
if (!points[i].pen_up) {
current.push_back(points[i]);
}
}
if (static_cast<int>(current.size()) >= min_points_) {
strokes.push_back(current);
}
return strokes;
}
private:
int time_threshold_;
int min_points_;
};
// ==================== 图像渲染器 ====================
/**
* 笔迹图像渲染器
* 将归一化坐标渲染为灰度图像作为CNN模型输入
* 使用Bresenham直线算法连接相邻坐标点
*/
class StrokeImageRenderer {
public:
StrokeImageRenderer(int width = 64, int height = 64)
: width_(width), height_(height) {}
/**
* 将坐标序列渲染为灰度图像
* 输出一维浮点数组,值域[0,1],1表示笔迹
*/
std::vector<float> render(const std::vector<NormalizedPoint>& points) {
std::vector<float> image(width_ * height_, 0.0f);
for (size_t i = 1; i < points.size(); i++) {
int x0 = static_cast<int>(points[i-1].x * (width_ - 1));
int y0 = static_cast<int>(points[i-1].y * (height_ - 1));
int x1 = static_cast<int>(points[i].x * (width_ - 1));
int y1 = static_cast<int>(points[i].y * (height_ - 1));
// 裁剪到图像范围
x0 = std::clamp(x0, 0, width_ - 1);
y0 = std::clamp(y0, 0, height_ - 1);
x1 = std::clamp(x1, 0, width_ - 1);
y1 = std::clamp(y1, 0, height_ - 1);
float pressure = (points[i-1].pressure + points[i].pressure) * 0.5f;
// Bresenham直线算法
draw_line(image, x0, y0, x1, y1, pressure);
}
return image;
}
private:
void draw_line(std::vector<float>& image, int x0, int y0, int x1, int y1, float value) {
int dx = std::abs(x1 - x0);
int dy = std::abs(y1 - y0);
int sx = (x0 < x1) ? 1 : -1;
int sy = (y0 < y1) ? 1 : -1;
int err = dx - dy;
while (true) {
int idx = y0 * width_ + x0;
if (idx >= 0 && idx < width_ * height_) {
image[idx] = std::max(image[idx], value);
}
if (x0 == x1 && y0 == y1) break;
int e2 = 2 * err;
if (e2 > -dy) { err -= dy; x0 += sx; }
if (e2 < dx) { err += dx; y0 += sy; }
}
}
int width_;
int height_;
};
// ==================== 预处理管道(整合) ====================
/**
* 笔迹预处理管道
* 整合去噪、归一化、分割、渲染的完整处理流程
* 输入原始坐标点序列,输出标准化的推理输入数据
*/
class StrokePreprocessor {
public:
StrokePreprocessor(int image_size = 64)
: noise_filter_(50.0f, 3),
normalizer_(true),
segmenter_(200, 3),
renderer_(image_size, image_size),
image_size_(image_size) {}
/**
* 执行完整预处理管道
* 流程:原始坐标 → 去噪 → 归一化 → 笔画分割 → 图像渲染
*/
PreprocessedData process(const std::vector<RawPoint>& raw_points) {
PreprocessedData result;
// 步骤1:去噪滤波
auto denoised = noise_filter_.apply(raw_points);
// 步骤2:坐标归一化
auto normalized = normalizer_.normalize(denoised);
// 步骤3:笔画分割
auto stroke_groups = segmenter_.segment(denoised);
// 构建笔画数据
for (int i = 0; i < static_cast<int>(stroke_groups.size()); i++) {
Stroke stroke;
stroke.stroke_index = i;
auto norm_group = normalizer_.normalize(stroke_groups[i]);
stroke.points = norm_group;
stroke.length = calc_path_length(norm_group);
if (stroke_groups[i].size() >= 2) {
stroke.duration_ms = static_cast<int>(
stroke_groups[i].back().timestamp - stroke_groups[i].front().timestamp);
}
result.strokes.push_back(stroke);
}
// 步骤4:渲染为灰度图像
result.image = renderer_.render(normalized);
result.image_width = image_size_;
result.image_height = image_size_;
result.total_points = static_cast<int>(denoised.size());
result.stroke_count = static_cast<int>(result.strokes.size());
return result;
}
private:
float calc_path_length(const std::vector<NormalizedPoint>& points) {
float total = 0.0f;
for (size_t i = 1; i < points.size(); i++) {
float dx = points[i].x - points[i-1].x;
float dy = points[i].y - points[i-1].y;
total += std::sqrt(dx * dx + dy * dy);
}
return total;
}
StrokeNoiseFilter noise_filter_;
CoordinateNormalizer normalizer_;
StrokeSegmenter segmenter_;
StrokeImageRenderer renderer_;
int image_size_;
};
#endif // STROKE_PREPROCESSOR_H
```