software copyright

This commit is contained in:
jiahong
2026-03-22 15:24:40 +08:00
parent e303bb868a
commit 60f336e345
155 changed files with 127262 additions and 0 deletions
@@ -0,0 +1,500 @@
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* gRPC通信服务模块 - 与教室网关的笔迹数据交互
*
* 实现gRPC流式服务,接收网关转发的笔迹数据流
* 支持mTLS双向认证确保通信安全
*/
#ifndef GRPC_SERVER_H
#define GRPC_SERVER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <thread>
#include <functional>
#include <unordered_map>
#include <chrono>
#include <queue>
// ==================== gRPC消息结构 ====================
/** 笔迹坐标点(对应protobuf消息) */
struct GrpcStrokePoint {
float x;
float y;
float pressure;
uint32_t timestamp;
bool pen_up;
};
/** 笔迹数据包(对应protobuf消息) */
struct GrpcStrokePacket {
std::string packet_id; // 数据包ID
std::string pen_id; // 笔设备MAC地址
std::string student_id; // 学生ID
std::string page_id; // 点阵码页面ID
std::vector<GrpcStrokePoint> points; // 坐标点序列
uint64_t gateway_timestamp; // 网关转发时间戳
int sequence_number; // 包序号(用于乱序检测)
};
/** 识别结果响应 */
struct GrpcRecognitionResponse {
std::string packet_id; // 对应的请求包ID
std::string recognition_type; // 识别类型(ocr/math/stroke_order
bool success; // 是否成功
std::string result_text; // 识别结果文本
float confidence; // 置信度
float processing_time_ms; // 处理耗时
std::string model_version; // 使用的模型版本
};
// ==================== 连接管理器 ====================
/** 客户端连接信息 */
struct ClientConnection {
std::string client_id; // 客户端标识(网关ID
std::string client_addr; // 客户端地址
std::string cert_fingerprint; // 客户端证书指纹(mTLS
std::chrono::steady_clock::time_point connected_at;
std::chrono::steady_clock::time_point last_active;
long packets_received; // 已接收数据包数
long bytes_received; // 已接收字节数
bool authenticated; // 是否已通过mTLS认证
};
/**
* gRPC连接管理器
* 管理与多个教室网关的gRPC连接
* 每个网关对应一个持久化的gRPC流式连接
*/
class ConnectionManager {
public:
ConnectionManager(int max_connections = 100)
: max_connections_(max_connections) {}
/** 注册新连接 */
bool register_connection(const std::string& client_id, const std::string& addr,
const std::string& cert_fp) {
std::lock_guard<std::mutex> lock(mutex_);
if (static_cast<int>(connections_.size()) >= max_connections_) {
return false; // 达到最大连接数限制
}
ClientConnection conn;
conn.client_id = client_id;
conn.client_addr = addr;
conn.cert_fingerprint = cert_fp;
conn.connected_at = std::chrono::steady_clock::now();
conn.last_active = conn.connected_at;
conn.packets_received = 0;
conn.bytes_received = 0;
conn.authenticated = !cert_fp.empty();
connections_[client_id] = conn;
return true;
}
/** 移除连接 */
void remove_connection(const std::string& client_id) {
std::lock_guard<std::mutex> lock(mutex_);
connections_.erase(client_id);
}
/** 更新连接活跃时间 */
void update_activity(const std::string& client_id, long bytes) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = connections_.find(client_id);
if (it != connections_.end()) {
it->second.last_active = std::chrono::steady_clock::now();
it->second.packets_received++;
it->second.bytes_received += bytes;
}
}
/** 检查空闲超时连接 */
std::vector<std::string> check_idle_connections(int timeout_s = 300) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> idle;
auto now = std::chrono::steady_clock::now();
for (const auto& pair : connections_) {
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
now - pair.second.last_active).count();
if (elapsed > timeout_s) {
idle.push_back(pair.first);
}
}
return idle;
}
/** 获取当前连接数 */
int active_count() const {
std::lock_guard<std::mutex> lock(mutex_);
return static_cast<int>(connections_.size());
}
/** 获取所有连接状态 */
std::vector<ClientConnection> get_all_connections() const {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<ClientConnection> result;
for (const auto& pair : connections_) {
result.push_back(pair.second);
}
return result;
}
private:
std::unordered_map<std::string, ClientConnection> connections_;
mutable std::mutex mutex_;
int max_connections_;
};
// ==================== 数据包排序器 ====================
/**
* 数据包排序器
* 网络传输可能导致数据包乱序到达
* 使用滑动窗口机制对数据包进行重排序
*/
class PacketReorderer {
public:
PacketReorderer(int window_size = 16) : window_size_(window_size), expected_seq_(0) {}
/**
* 提交数据包到排序窗口
* 如果是期望的下一个序号则直接输出
* 否则缓存等待前序包到达
*/
std::vector<GrpcStrokePacket> submit(const GrpcStrokePacket& packet) {
std::vector<GrpcStrokePacket> output;
if (packet.sequence_number == expected_seq_) {
// 正好是期望的下一个包
output.push_back(packet);
expected_seq_++;
// 检查缓存中是否有后续连续的包
while (buffer_.count(expected_seq_) > 0) {
output.push_back(buffer_[expected_seq_]);
buffer_.erase(expected_seq_);
expected_seq_++;
}
} else if (packet.sequence_number > expected_seq_) {
// 后序包先到达,缓存等待
buffer_[packet.sequence_number] = packet;
// 缓存过大时强制输出最旧的包
if (static_cast<int>(buffer_.size()) > window_size_) {
auto it = buffer_.begin();
output.push_back(it->second);
expected_seq_ = it->first + 1;
buffer_.erase(it);
}
}
// 过期的旧包直接丢弃
return output;
}
void reset() {
buffer_.clear();
expected_seq_ = 0;
}
private:
std::map<int, GrpcStrokePacket> buffer_;
int window_size_;
int expected_seq_;
};
// ==================== gRPC服务实现 ====================
/**
* gRPC笔迹接收服务
* 实现InferenceService.ProcessStroke流式RPC
* 接收网关推送的笔迹数据流,送入推理引擎处理
*
* 安全设计:
* - gRPC启用mTLS双向认证
* - 请求大小限制防恶意攻击
* - 连接数限制防DoS
*/
class GrpcStrokeServer {
public:
using StrokeCallback = std::function<void(const GrpcStrokePacket&)>;
GrpcStrokeServer(const std::string& listen_addr = "0.0.0.0:50052",
bool enable_tls = true)
: listen_addr_(listen_addr), enable_tls_(enable_tls),
running_(false), conn_manager_(100) {}
/**
* 设置笔迹数据接收回调
* 当收到网关的笔迹数据时调用此回调
*/
void set_stroke_callback(StrokeCallback callback) {
stroke_callback_ = std::move(callback);
}
/**
* 启动gRPC服务器
* 加载TLS证书,绑定端口,开始监听
*/
bool start() {
if (enable_tls_) {
// 加载mTLS证书(安全设计:gRPC启用mTLS双向认证)
// grpc::SslServerCredentialsOptions ssl_opts;
// ssl_opts.pem_root_certs = load_file("/etc/ssl/ca.crt");
// ssl_opts.pem_key_cert_pairs.push_back({
// load_file("/etc/ssl/server.key"),
// load_file("/etc/ssl/server.crt")
// });
// ssl_opts.client_certificate_request = GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
}
// 构建并启动gRPC服务器
// grpc::ServerBuilder builder;
// builder.AddListeningPort(listen_addr_, credentials);
// builder.RegisterService(this);
// builder.SetMaxReceiveMessageSize(10 * 1024 * 1024); // 10MB最大消息
// server_ = builder.BuildAndStart();
running_ = true;
return true;
}
/**
* ProcessStroke RPC实现
* 接收网关的流式笔迹数据,处理后返回识别结果流
*/
void ProcessStroke(const GrpcStrokePacket& packet) {
// 更新连接活跃状态
conn_manager_.update_activity(packet.pen_id, packet.points.size() * 16);
// 数据包排序
auto ordered = reorderer_.submit(packet);
// 处理排序后的数据包
for (const auto& p : ordered) {
total_packets_++;
total_points_ += static_cast<long>(p.points.size());
// 调用回调函数送入推理引擎
if (stroke_callback_) {
stroke_callback_(p);
}
}
}
/** 停止服务器 */
void stop() {
running_ = false;
// if (server_) server_->Shutdown();
}
/** 获取服务器统计信息 */
struct ServerStats {
int active_connections;
long total_packets;
long total_points;
bool is_running;
};
ServerStats get_stats() const {
ServerStats stats;
stats.active_connections = conn_manager_.active_count();
stats.total_packets = total_packets_.load();
stats.total_points = total_points_.load();
stats.is_running = running_.load();
return stats;
}
private:
std::string listen_addr_;
bool enable_tls_;
std::atomic<bool> running_;
ConnectionManager conn_manager_;
PacketReorderer reorderer_;
StrokeCallback stroke_callback_;
std::atomic<long> total_packets_{0};
std::atomic<long> total_points_{0};
};
// ==================== MQTT状态上报客户端 ====================
/**
* MQTT状态上报客户端
* 定期向云平台上报算力盒运行状态
* Topic: edgebox/{id}/status
* 安全设计:MQTT over TLS加密传输
*/
class MqttReporter {
public:
MqttReporter(const std::string& broker_url, const std::string& device_id)
: broker_url_(broker_url), device_id_(device_id), connected_(false) {}
/** 连接MQTT BrokerTLS加密) */
bool connect() {
// 实际环境使用Eclipse Paho MQTT C++ Client
// mqtt::async_client client(broker_url_, device_id_);
// mqtt::ssl_options ssl_opts;
// ssl_opts.set_trust_store("/etc/ssl/ca.crt");
// ssl_opts.set_key_store("/etc/ssl/client.crt");
// ssl_opts.set_private_key("/etc/ssl/client.key");
connected_ = true;
return true;
}
/** 上报设备状态 */
void report_status(float gpu_usage, float temperature, float inference_qps,
int queue_depth, long uptime_s) {
if (!connected_) return;
std::string topic = "edgebox/" + device_id_ + "/status";
// 构造JSON状态消息
// {"gpu_usage": 45.2, "temperature": 62.5, "qps": 120.3, "queue": 5, "uptime": 3600}
}
/** 接收远程指令 */
void subscribe_commands() {
std::string topic = "edgebox/" + device_id_ + "/command";
// 订阅远程管理指令:重启、模型切换、OTA升级等
}
/** 断开连接 */
void disconnect() {
connected_ = false;
}
private:
std::string broker_url_;
std::string device_id_;
bool connected_;
};
// ==================== 离线结果缓存 ====================
/**
* 离线结果缓存
* 断网期间推理结果暂存到本地SQLite数据库
* 网络恢复后自动批量上传至云端
* 安全设计:通信安全保障数据完整性
*/
class OfflineResultCache {
public:
OfflineResultCache(const std::string& db_path, int max_size_mb = 256)
: db_path_(db_path), max_size_mb_(max_size_mb), cached_count_(0) {}
/** 初始化SQLite数据库 */
bool initialize() {
// CREATE TABLE IF NOT EXISTS offline_results (
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// packet_id TEXT NOT NULL,
// result_type TEXT NOT NULL,
// result_json TEXT NOT NULL,
// created_at INTEGER NOT NULL,
// uploaded INTEGER DEFAULT 0
// );
return true;
}
/** 缓存推理结果 */
bool cache_result(const std::string& packet_id, const std::string& type,
const std::string& result_json) {
// INSERT INTO offline_results (packet_id, result_type, result_json, created_at)
// VALUES (?, ?, ?, strftime('%s', 'now'));
cached_count_++;
return true;
}
/** 获取待上传的缓存结果 */
std::vector<std::string> get_pending_results(int limit = 100) {
// SELECT * FROM offline_results WHERE uploaded = 0 ORDER BY created_at LIMIT ?
return {};
}
/** 标记结果已上传 */
void mark_uploaded(const std::vector<int>& ids) {
// UPDATE offline_results SET uploaded = 1 WHERE id IN (...)
}
/** 清理已上传的旧数据 */
void cleanup(int retention_days = 7) {
// DELETE FROM offline_results WHERE uploaded = 1 AND created_at < ?
}
int cached_count() const { return cached_count_; }
private:
std::string db_path_;
int max_size_mb_;
int cached_count_;
};
// ==================== 集群管理器 ====================
/**
* 多算力盒集群管理器
* 通过mDNS服务发现同一校园网内的其他算力盒
* 实现负载均衡调度:当本机推理队列过长时,分发至空闲节点
*/
class ClusterManager {
public:
struct ClusterNode {
std::string node_id; // 节点ID
std::string address; // gRPC地址
float load_factor; // 负载因子(0-1)
bool is_self; // 是否为本机
std::chrono::steady_clock::time_point last_seen;
};
ClusterManager(const std::string& self_id) : self_id_(self_id) {}
/** 启动mDNS服务注册和发现 */
bool start_discovery() {
// 注册本机mDNS服务
// _writech-edgebox._tcp.local.
// 定期扫描同网段其他算力盒
return true;
}
/** 选择最优节点处理推理任务 */
std::string select_best_node() {
std::lock_guard<std::mutex> lock(mutex_);
std::string best_id = self_id_;
float min_load = 1.0f;
for (const auto& pair : nodes_) {
if (pair.second.load_factor < min_load) {
min_load = pair.second.load_factor;
best_id = pair.first;
}
}
return best_id;
}
/** 更新本机负载因子 */
void update_self_load(float load) {
std::lock_guard<std::mutex> lock(mutex_);
if (nodes_.count(self_id_)) {
nodes_[self_id_].load_factor = load;
}
}
int cluster_size() const {
std::lock_guard<std::mutex> lock(mutex_);
return static_cast<int>(nodes_.size());
}
private:
std::string self_id_;
std::unordered_map<std::string, ClusterNode> nodes_;
mutable std::mutex mutex_;
};
#endif // GRPC_SERVER_H
@@ -0,0 +1,365 @@
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 配置管理与安全模块 - 全局配置、安全认证、审计日志
*
* 管理算力盒的所有运行配置参数
* 提供安全认证、审计日志记录等安全功能
* 安全设计:
* - 模型加密:模型文件AES-256加密存储
* - 通信安全:gRPC启用mTLS双向认证,MQTT over TLS
* - OTA安全:升级包RSA签名+SHA-256校验
* - 运行隔离:推理进程与管理进程独立沙箱
* - 物理安全:设备唯一序列号绑定
*/
#ifndef EDGE_CONFIG_H
#define EDGE_CONFIG_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <fstream>
#include <unordered_map>
#include <chrono>
#include <ctime>
// ==================== 配置文件解析器 ====================
/**
* JSON配置文件解析器
* 从/etc/writech/edgebox.json加载配置
* 支持嵌套配置项和数组
*/
class ConfigParser {
public:
/**
* 从文件加载配置
*/
bool load_from_file(const std::string& path) {
config_path_ = path;
// 使用rapidjson或nlohmann/json解析
// 此处使用简单的键值对模拟
return load_defaults();
}
/**
* 获取字符串配置项
*/
std::string get_string(const std::string& key, const std::string& default_val = "") {
auto it = string_values_.find(key);
return (it != string_values_.end()) ? it->second : default_val;
}
/**
* 获取整数配置项
*/
int get_int(const std::string& key, int default_val = 0) {
auto it = int_values_.find(key);
return (it != int_values_.end()) ? it->second : default_val;
}
/**
* 获取浮点配置项
*/
float get_float(const std::string& key, float default_val = 0.0f) {
auto it = float_values_.find(key);
return (it != float_values_.end()) ? it->second : default_val;
}
/**
* 获取布尔配置项
*/
bool get_bool(const std::string& key, bool default_val = false) {
auto it = bool_values_.find(key);
return (it != bool_values_.end()) ? it->second : default_val;
}
/**
* 设置配置项(运行时修改)
*/
void set_string(const std::string& key, const std::string& value) {
string_values_[key] = value;
}
/**
* 保存配置到文件
*/
bool save_to_file(const std::string& path = "") {
std::string save_path = path.empty() ? config_path_ : path;
// 序列化为JSON并写入文件
return true;
}
private:
/**
* 加载默认配置
*/
bool load_defaults() {
// gRPC服务配置
string_values_["grpc.listen_addr"] = "0.0.0.0:50052";
int_values_["grpc.max_connections"] = 100;
bool_values_["grpc.enable_tls"] = true;
// MQTT配置
string_values_["mqtt.broker_url"] = "ssl://mqtt.writech.com:8883";
int_values_["mqtt.keepalive_s"] = 60;
bool_values_["mqtt.enable_tls"] = true;
// 推理引擎配置
string_values_["inference.device"] = "npu";
string_values_["inference.models_dir"] = "/opt/models";
int_values_["inference.max_batch_size"] = 16;
int_values_["inference.timeout_ms"] = 500;
bool_values_["inference.enable_fp16"] = true;
// GPU/NPU配置
float_values_["gpu.memory_fraction"] = 0.8f;
float_values_["gpu.thermal_throttle_temp"] = 80.0f;
// 集群配置
bool_values_["cluster.enable"] = true;
int_values_["cluster.mdns_port"] = 5353;
// 离线缓存配置
string_values_["cache.db_path"] = "/var/lib/writech/cache.db";
int_values_["cache.max_size_mb"] = 256;
// OTA配置
string_values_["ota.server_url"] = "https://ota.writech.com";
bool_values_["ota.auto_check"] = true;
int_values_["ota.check_interval_h"] = 24;
// 安全配置
string_values_["security.cert_dir"] = "/etc/ssl";
bool_values_["security.model_encryption"] = true;
bool_values_["security.enable_audit_log"] = true;
// 日志配置
string_values_["log.dir"] = "/var/log/writech";
string_values_["log.level"] = "INFO";
int_values_["log.max_size_mb"] = 50;
int_values_["log.rotate_count"] = 5;
return true;
}
std::string config_path_;
std::unordered_map<std::string, std::string> string_values_;
std::unordered_map<std::string, int> int_values_;
std::unordered_map<std::string, float> float_values_;
std::unordered_map<std::string, bool> bool_values_;
};
// ==================== 设备证书管理 ====================
/**
* 设备证书管理器
* 管理算力盒的X.509设备证书
* 用于mTLS双向认证和设备身份验证
* 安全设计:物理安全 - 设备唯一序列号绑定
*/
class DeviceCertManager {
public:
DeviceCertManager(const std::string& cert_dir = "/etc/ssl")
: cert_dir_(cert_dir) {}
/** 加载设备证书和密钥 */
bool load_certificates() {
server_cert_path_ = cert_dir_ + "/server.crt";
server_key_path_ = cert_dir_ + "/server.key";
ca_cert_path_ = cert_dir_ + "/ca.crt";
client_cert_path_ = cert_dir_ + "/client.crt";
client_key_path_ = cert_dir_ + "/client.key";
// 验证证书文件是否存在且有效
// X509_STORE *store = X509_STORE_new();
// X509_STORE_CTX *ctx = X509_STORE_CTX_new();
// 验证证书链完整性
return true;
}
/** 获取设备唯一序列号 */
std::string get_device_serial() {
// 从设备证书的Subject CN字段提取序列号
// 或从硬件安全芯片读取
return "EB-202501-001";
}
/** 验证对端证书指纹 */
bool verify_peer_cert(const std::string& peer_fingerprint) {
// 与信任列表比对
return trusted_fingerprints_.count(peer_fingerprint) > 0;
}
/** 注册信任的对端证书 */
void add_trusted_fingerprint(const std::string& name, const std::string& fingerprint) {
trusted_fingerprints_[fingerprint] = name;
}
std::string get_server_cert_path() const { return server_cert_path_; }
std::string get_server_key_path() const { return server_key_path_; }
std::string get_ca_cert_path() const { return ca_cert_path_; }
private:
std::string cert_dir_;
std::string server_cert_path_;
std::string server_key_path_;
std::string ca_cert_path_;
std::string client_cert_path_;
std::string client_key_path_;
std::unordered_map<std::string, std::string> trusted_fingerprints_;
};
// ==================== 审计日志记录器 ====================
/**
* 审计日志记录器
* 记录所有安全相关事件:
* - 推理请求(调用方、时间、模型版本)
* - 设备连接/断开
* - 模型加载/切换
* - OTA升级操作
* - 异常和错误事件
*/
class AuditLogger {
public:
enum class EventType {
INFERENCE_REQUEST, // 推理请求
DEVICE_CONNECT, // 设备连接
DEVICE_DISCONNECT, // 设备断开
MODEL_LOAD, // 模型加载
MODEL_SWITCH, // 模型切换
OTA_START, // OTA升级开始
OTA_COMPLETE, // OTA升级完成
OTA_FAILED, // OTA升级失败
AUTH_SUCCESS, // 认证成功
AUTH_FAILED, // 认证失败
CONFIG_CHANGE, // 配置变更
SYSTEM_ERROR // 系统错误
};
struct AuditEvent {
EventType type;
std::string timestamp;
std::string source; // 事件来源(客户端ID/模块名)
std::string action; // 操作描述
std::string details; // 详细信息
std::string result; // 结果(success/failure
std::string client_ip; // 客户端IP
};
AuditLogger(const std::string& log_dir = "/var/log/writech")
: log_dir_(log_dir), event_count_(0) {}
/**
* 记录审计事件
* 安全设计:所有识别请求记录调用方、时间、模型版本
*/
void log_event(const AuditEvent& event) {
std::lock_guard<std::mutex> lock(mutex_);
// 格式化时间戳
auto now = std::chrono::system_clock::now();
auto time = std::chrono::system_clock::to_time_t(now);
// 写入审计日志文件
// 格式:[时间] [事件类型] [来源] [操作] [结果] [详情]
// 审计日志独立于运行日志,不可被篡改
event_count_++;
// 检查日志文件大小,超限则轮转
check_rotation();
}
/** 快捷方法:记录推理请求 */
void log_inference(const std::string& client_id, const std::string& task_type,
const std::string& model_version, float latency_ms, bool success) {
AuditEvent event;
event.type = EventType::INFERENCE_REQUEST;
event.source = client_id;
event.action = "inference:" + task_type;
event.details = "model=" + model_version + ",latency=" + std::to_string(latency_ms) + "ms";
event.result = success ? "success" : "failure";
log_event(event);
}
/** 快捷方法:记录认证事件 */
void log_auth(const std::string& client_ip, const std::string& cert_cn, bool success) {
AuditEvent event;
event.type = success ? EventType::AUTH_SUCCESS : EventType::AUTH_FAILED;
event.source = cert_cn;
event.client_ip = client_ip;
event.action = "mTLS authentication";
event.result = success ? "success" : "failure";
log_event(event);
}
/** 快捷方法:记录OTA事件 */
void log_ota(const std::string& action, const std::string& version, bool success) {
AuditEvent event;
event.type = success ? EventType::OTA_COMPLETE : EventType::OTA_FAILED;
event.source = "ota_manager";
event.action = action;
event.details = "version=" + version;
event.result = success ? "success" : "failure";
log_event(event);
}
long get_event_count() const { return event_count_; }
private:
void check_rotation() {
// 审计日志文件轮转
// 当文件大小超过限制时创建新文件
// 保留最近90天的审计日志(安全合规要求)
}
std::string log_dir_;
long event_count_;
std::mutex mutex_;
};
// ==================== 进程沙箱隔离 ====================
/**
* 进程沙箱管理器
* 安全设计:推理进程与管理进程独立沙箱,异常不互相影响
* 使用Linux namespaces和cgroups实现进程隔离
*/
class ProcessSandbox {
public:
/** 创建沙箱化子进程 */
bool create_sandbox(const std::string& name, const std::string& exec_path) {
// Linux: clone(CLONE_NEWNS | CLONE_NEWPID | CLONE_NEWNET)
// cgroup限制:内存、CPU、GPU资源配额
// seccomp: 限制可用的系统调用
return true;
}
/** 设置资源限制 */
void set_resource_limits(const std::string& name, size_t memory_limit_mb,
float cpu_quota, int gpu_device_id) {
// 通过cgroups v2设置资源限制
// memory.max = memory_limit_mb * 1024 * 1024
// cpu.max = cpu_quota * period
// 通过NVIDIA Container Runtime限制GPU访问
}
/** 检查沙箱进程健康状态 */
bool is_healthy(const std::string& name) {
// 检查进程是否存活
// 检查资源使用是否超限
return true;
}
/** 重启异常的沙箱进程 */
bool restart_sandbox(const std::string& name) {
// 发送SIGTERM等待优雅退出
// 超时后发送SIGKILL强制终止
// 重新创建沙箱进程
return true;
}
};
#endif // EDGE_CONFIG_H
@@ -0,0 +1,499 @@
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 推理引擎模块 - ONNX Runtime / TensorRT 推理执行引擎
*
* 负责加载AI模型并执行推理任务
* 支持多种推理后端:ONNX Runtime、TensorRT、PaddleLite
* 支持NPU/GPU硬件加速调度
*/
#ifndef INFERENCE_ENGINE_H
#define INFERENCE_ENGINE_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <atomic>
#include <chrono>
#include <functional>
#include <unordered_map>
#include <condition_variable>
// ==================== 数据结构定义 ====================
/**
* 推理设备类型枚举
* 算力盒支持多种硬件加速设备
*/
enum class DeviceType {
CPU = 0, // CPU推理(兜底方案)
GPU_CUDA = 1, // NVIDIA GPU (CUDA)
GPU_OPENCL = 2, // 通用GPU (OpenCL)
NPU_RKNN = 3, // 瑞芯微NPU (RKNN)
NPU_AMLOGIC = 4 // 晶晨NPU
};
/**
* 模型格式枚举
*/
enum class ModelFormat {
ONNX = 0, // ONNX格式(通用)
TENSORRT = 1, // TensorRT引擎(NVIDIA优化)
PADDLE_LITE = 2,// PaddleLiteARM优化)
RKNN = 3 // RKNN格式(瑞芯微NPU专用)
};
/**
* 推理任务类型
*/
enum class TaskType {
OCR = 0, // 文字OCR识别
MATH_RECOGNITION = 1, // 数学列式识别
STROKE_ORDER = 2, // 笔顺分析
WRITING_QUALITY = 3 // 书写质量评测
};
/**
* 张量数据(推理输入/输出)
* 封装多维数组数据和形状信息
*/
struct Tensor {
std::vector<float> data; // 浮点数据
std::vector<int64_t> shape; // 维度形状 (如 [1, 3, 64, 64])
std::string name; // 张量名称
/** 获取数据元素总数 */
size_t size() const {
size_t s = 1;
for (auto d : shape) s *= d;
return s;
}
};
/**
* 推理请求
*/
struct InferenceRequest {
std::string request_id; // 请求唯一ID
TaskType task_type; // 任务类型
std::vector<Tensor> inputs; // 输入张量列表
int priority = 2; // 优先级 (0=最高)
int timeout_ms = 500; // 超时时间
std::string pen_id; // 来源笔设备ID
std::string student_id; // 学生ID
std::chrono::steady_clock::time_point submit_time; // 提交时间
};
/**
* 推理结果
*/
struct InferenceResult {
std::string request_id;
bool success = false;
std::string error_message;
std::vector<Tensor> outputs; // 输出张量列表
float inference_time_ms = 0.0f; // 推理耗时
std::string model_version; // 使用的模型版本
};
// ==================== 推理后端抽象 ====================
/**
* 推理后端抽象基类
* 所有推理引擎(ONNX Runtime、TensorRT等)的统一接口
*/
class InferenceBackend {
public:
virtual ~InferenceBackend() = default;
/** 加载模型文件 */
virtual bool load_model(const std::string& model_path) = 0;
/** 执行推理 */
virtual InferenceResult infer(const InferenceRequest& request) = 0;
/** 卸载模型释放资源 */
virtual void unload() = 0;
/** 获取后端名称 */
virtual std::string name() const = 0;
};
/**
* ONNX Runtime推理后端
* 支持CPU/GPU/NPU多种执行提供者
*/
class OnnxRuntimeBackend : public InferenceBackend {
public:
OnnxRuntimeBackend(DeviceType device) : device_(device), loaded_(false) {}
bool load_model(const std::string& model_path) override {
model_path_ = model_path;
// 实际环境中:
// Ort::SessionOptions options;
// if (device_ == DeviceType::GPU_CUDA) {
// OrtCUDAProviderOptions cuda_opts;
// cuda_opts.device_id = 0;
// options.AppendExecutionProvider_CUDA(cuda_opts);
// }
// session_ = std::make_unique<Ort::Session>(env, model_path.c_str(), options);
loaded_ = true;
return true;
}
InferenceResult infer(const InferenceRequest& request) override {
InferenceResult result;
result.request_id = request.request_id;
if (!loaded_) {
result.success = false;
result.error_message = "模型未加载";
return result;
}
auto start = std::chrono::steady_clock::now();
// 执行ONNX Runtime推理
// std::vector<Ort::Value> input_tensors;
// for (const auto& input : request.inputs) {
// auto tensor = Ort::Value::CreateTensor<float>(
// memory_info, input.data.data(), input.size(),
// input.shape.data(), input.shape.size());
// input_tensors.push_back(std::move(tensor));
// }
// auto output_tensors = session_->Run(run_options, input_names, input_tensors, output_names);
// 模拟推理输出
Tensor output;
output.name = "output";
output.shape = {1, 10};
output.data.resize(10, 0.1f);
result.outputs.push_back(output);
result.success = true;
auto end = std::chrono::steady_clock::now();
result.inference_time_ms = std::chrono::duration<float, std::milli>(end - start).count();
return result;
}
void unload() override {
loaded_ = false;
}
std::string name() const override { return "ONNXRuntime"; }
private:
DeviceType device_;
std::string model_path_;
bool loaded_;
};
/**
* TensorRT推理后端
* NVIDIA GPU专用高性能推理引擎
* 支持FP16/INT8量化推理,显著降低推理延迟
*/
class TensorRTBackend : public InferenceBackend {
public:
TensorRTBackend() : loaded_(false) {}
bool load_model(const std::string& engine_path) override {
engine_path_ = engine_path;
// 实际环境中:
// std::ifstream file(engine_path, std::ios::binary);
// file.seekg(0, std::ios::end);
// size_t size = file.tellg();
// file.seekg(0, std::ios::beg);
// std::vector<char> engine_data(size);
// file.read(engine_data.data(), size);
//
// auto runtime = nvinfer1::createInferRuntime(logger);
// engine_ = runtime->deserializeCudaEngine(engine_data.data(), size);
// context_ = engine_->createExecutionContext();
loaded_ = true;
return true;
}
InferenceResult infer(const InferenceRequest& request) override {
InferenceResult result;
result.request_id = request.request_id;
if (!loaded_) {
result.success = false;
result.error_message = "TensorRT引擎未加载";
return result;
}
auto start = std::chrono::steady_clock::now();
// 执行TensorRT推理
// cudaMemcpyAsync(gpu_input, request.inputs[0].data.data(), ...);
// context_->enqueueV2(buffers, stream, nullptr);
// cudaMemcpyAsync(cpu_output, gpu_output, ...);
// cudaStreamSynchronize(stream);
Tensor output;
output.name = "output";
output.shape = {1, 10};
output.data.resize(10, 0.1f);
result.outputs.push_back(output);
result.success = true;
auto end = std::chrono::steady_clock::now();
result.inference_time_ms = std::chrono::duration<float, std::milli>(end - start).count();
return result;
}
void unload() override {
loaded_ = false;
}
std::string name() const override { return "TensorRT"; }
private:
std::string engine_path_;
bool loaded_;
};
// ==================== 推理任务队列 ====================
/**
* 优先级推理任务队列
* 按优先级和提交时间排序,高优先级任务优先处理
* 课堂实时场景的推理请求拥有最高优先级
*/
class InferenceTaskQueue {
public:
InferenceTaskQueue(size_t max_size = 1024) : max_size_(max_size) {}
/**
* 提交推理请求到队列
* 如果队列已满,丢弃最低优先级的任务
*/
bool enqueue(InferenceRequest request) {
std::lock_guard<std::mutex> lock(mutex_);
if (queue_.size() >= max_size_) {
// 队列已满,检查是否可以替换低优先级任务
if (!queue_.empty() && queue_.top().priority > request.priority) {
queue_.pop(); // 移除最低优先级任务
} else {
return false; // 无法入队
}
}
request.submit_time = std::chrono::steady_clock::now();
queue_.push(std::move(request));
cv_.notify_one();
return true;
}
/**
* 从队列获取最高优先级的任务
* 如果队列为空则阻塞等待
*/
bool dequeue(InferenceRequest& request, int timeout_ms = 100) {
std::unique_lock<std::mutex> lock(mutex_);
if (cv_.wait_for(lock, std::chrono::milliseconds(timeout_ms),
[this] { return !queue_.empty(); })) {
request = queue_.top();
queue_.pop();
return true;
}
return false;
}
size_t size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
private:
// 自定义比较器:优先级小的排前面,相同优先级按提交时间排序
struct RequestCompare {
bool operator()(const InferenceRequest& a, const InferenceRequest& b) {
if (a.priority != b.priority) return a.priority > b.priority;
return a.submit_time > b.submit_time;
}
};
std::priority_queue<InferenceRequest, std::vector<InferenceRequest>, RequestCompare> queue_;
mutable std::mutex mutex_;
std::condition_variable cv_;
size_t max_size_;
};
// ==================== 推理引擎(核心类) ====================
/**
* 推理引擎
* 管理多个推理后端,根据模型类型和硬件条件选择最优推理路径
* 支持:
* - 多模型并发推理(OCR、数学、笔顺各独立模型)
* - 动态批处理(攒批提升GPU利用率)
* - 推理结果缓存(相同输入直接返回缓存结果)
* - 超时控制和优雅降级
*/
class InferenceEngine {
public:
InferenceEngine(DeviceType device, const std::string& models_dir)
: device_(device), models_dir_(models_dir), running_(false) {}
/**
* 初始化推理引擎
* 检测硬件设备、创建推理后端、加载模型
*/
bool initialize() {
// 检测硬件加速设备
detect_hardware();
// 为每种任务类型创建专用推理后端
backends_[TaskType::OCR] = create_backend("ocr");
backends_[TaskType::MATH_RECOGNITION] = create_backend("math");
backends_[TaskType::STROKE_ORDER] = create_backend("stroke_order");
backends_[TaskType::WRITING_QUALITY] = create_backend("writing_quality");
// 加载各模型
for (auto& [type, backend] : backends_) {
std::string model_file = get_model_path(type);
if (!backend->load_model(model_file)) {
return false;
}
}
// 启动推理工作线程
running_ = true;
worker_thread_ = std::thread(&InferenceEngine::worker_loop, this);
return true;
}
/**
* 提交推理请求(异步)
*/
std::string submit(InferenceRequest request) {
task_queue_.enqueue(std::move(request));
return request.request_id;
}
/**
* 同步推理(直接执行并返回结果)
*/
InferenceResult infer_sync(const InferenceRequest& request) {
auto it = backends_.find(request.task_type);
if (it == backends_.end()) {
InferenceResult result;
result.request_id = request.request_id;
result.success = false;
result.error_message = "不支持的任务类型";
return result;
}
return it->second->infer(request);
}
/**
* 关闭推理引擎
*/
void shutdown() {
running_ = false;
if (worker_thread_.joinable()) {
worker_thread_.join();
}
for (auto& [type, backend] : backends_) {
backend->unload();
}
}
/**
* 获取推理统计信息
*/
struct Stats {
long total_requests = 0;
long total_success = 0;
long total_failures = 0;
float avg_latency_ms = 0.0f;
float p99_latency_ms = 0.0f;
size_t queue_size = 0;
};
Stats get_stats() const {
Stats stats;
stats.total_requests = total_requests_.load();
stats.total_success = total_success_.load();
stats.total_failures = total_failures_.load();
stats.queue_size = task_queue_.size();
if (stats.total_success > 0) {
stats.avg_latency_ms = total_latency_ms_.load() / stats.total_success;
}
return stats;
}
private:
void detect_hardware() {
// 检测可用的硬件加速设备
// 瑞芯微NPU: 检查/dev/mali0或/dev/rknpu
// NVIDIA GPU: 检查CUDA Runtime
}
std::unique_ptr<InferenceBackend> create_backend(const std::string& model_name) {
// 根据设备类型创建对应的推理后端
if (device_ == DeviceType::GPU_CUDA) {
return std::make_unique<TensorRTBackend>();
}
return std::make_unique<OnnxRuntimeBackend>(device_);
}
std::string get_model_path(TaskType type) {
switch (type) {
case TaskType::OCR: return models_dir_ + "/ocr/model.onnx";
case TaskType::MATH_RECOGNITION: return models_dir_ + "/math/model.onnx";
case TaskType::STROKE_ORDER: return models_dir_ + "/stroke/model.onnx";
case TaskType::WRITING_QUALITY: return models_dir_ + "/quality/model.onnx";
}
return "";
}
/**
* 推理工作线程主循环
* 从任务队列取出请求,执行推理,存储结果
*/
void worker_loop() {
while (running_) {
InferenceRequest request;
if (task_queue_.dequeue(request, 100)) {
total_requests_++;
auto result = infer_sync(request);
if (result.success) {
total_success_++;
total_latency_ms_ += result.inference_time_ms;
} else {
total_failures_++;
}
// 存储结果供查询
std::lock_guard<std::mutex> lock(results_mutex_);
results_[request.request_id] = result;
}
}
}
DeviceType device_;
std::string models_dir_;
std::atomic<bool> running_;
std::thread worker_thread_;
InferenceTaskQueue task_queue_;
std::unordered_map<TaskType, std::unique_ptr<InferenceBackend>> backends_;
std::unordered_map<std::string, InferenceResult> results_;
std::mutex results_mutex_;
// 统计计数器
std::atomic<long> total_requests_{0};
std::atomic<long> total_success_{0};
std::atomic<long> total_failures_{0};
std::atomic<float> total_latency_ms_{0.0f};
};
#endif // INFERENCE_ENGINE_H
@@ -0,0 +1,443 @@
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 模型管理模块 - 模型加载、版本管理、量化压缩、云端同步
*
* 管理算力盒上部署的所有AI推理模型的生命周期
* 支持模型热更新、A/B切换、云端版本同步
* 模型文件AES-256加密存储,推理时内存解密加载
*/
#ifndef MODEL_MANAGER_H
#define MODEL_MANAGER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <unordered_map>
#include <chrono>
#include <functional>
// ==================== 模型元信息 ====================
/** 模型状态枚举 */
enum class ModelState {
NOT_FOUND = 0, // 未发现
DOWNLOADING = 1, // 下载中
DECRYPTING = 2, // 解密中
LOADING = 3, // 加载到设备中
READY = 4, // 就绪可用
ACTIVE = 5, // 当前使用中
DEPRECATED = 6, // 已弃用
ERROR = 7 // 错误状态
};
/** 模型量化精度 */
enum class QuantizationType {
FP32 = 0, // 全精度浮点
FP16 = 1, // 半精度浮点
INT8 = 2, // 8位整型量化
INT4 = 3 // 4位整型量化(极致压缩)
};
/** 模型元信息 */
struct ModelInfo {
std::string name; // 模型名称
std::string version; // 版本号(语义化版本)
std::string format; // 格式(onnx/trt/rknn
std::string file_path; // 本地文件路径
size_t file_size_bytes; // 文件大小
std::string sha256; // 文件SHA-256校验和
QuantizationType quantization; // 量化类型
float accuracy; // 测试集准确率
float latency_ms; // 平均推理延迟
ModelState state; // 当前状态
std::string deployed_at; // 部署时间
std::string description; // 模型描述
};
// ==================== 模型加密管理 ====================
/**
* 模型文件加密/解密管理器
* 安全设计:模型文件AES-256加密存储,推理时内存解密加载
* 加密密钥通过安全芯片(TPM)或环境变量注入
*/
class ModelCryptoManager {
public:
ModelCryptoManager() : key_loaded_(false) {}
/**
* 加载加密密钥
* 优先从安全芯片读取,其次从环境变量
*/
bool load_encryption_key() {
// 尝试从TPM安全芯片读取密钥
// if (tpm_available()) { key_ = tpm_read_key("model_key"); }
// 后备方案:从环境变量读取
const char* env_key = std::getenv("WRITECH_MODEL_KEY");
if (env_key) {
key_ = std::string(env_key);
key_loaded_ = true;
return true;
}
return false;
}
/**
* 解密模型文件到内存
* 不在磁盘上生成明文文件,仅在内存中解密
*/
std::vector<uint8_t> decrypt_model(const std::string& encrypted_path) {
std::vector<uint8_t> decrypted_data;
if (!key_loaded_) return decrypted_data;
// 读取加密文件
// AES-256-CBC解密
// openssl EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv);
// EVP_DecryptUpdate(ctx, output, &out_len, input, in_len);
// EVP_DecryptFinal_ex(ctx, output + out_len, &final_len);
return decrypted_data;
}
/**
* 加密模型文件
* 新下载的模型文件加密后存储到本地Flash
*/
bool encrypt_model(const std::vector<uint8_t>& data, const std::string& output_path) {
if (!key_loaded_) return false;
// AES-256-CBC加密并写入文件
return true;
}
/**
* 验证模型文件完整性
* 计算SHA-256校验和并与元数据中的值比对
*/
bool verify_integrity(const std::string& file_path, const std::string& expected_sha256) {
// 计算文件SHA-256
// SHA256_CTX sha256;
// SHA256_Init(&sha256);
// while (read chunk) SHA256_Update(&sha256, chunk, len);
// SHA256_Final(hash, &sha256);
return true;
}
private:
std::string key_;
bool key_loaded_;
};
// ==================== 模型版本管理器 ====================
/**
* 模型版本管理器
* 管理算力盒上所有AI模型的版本、加载、切换
* 支持A/B分区切换实现热更新
*/
class ModelVersionManager {
public:
ModelVersionManager(const std::string& models_dir)
: models_dir_(models_dir) {}
/**
* 注册模型
* 扫描模型目录,加载所有可用模型的元信息
*/
bool register_model(const ModelInfo& info) {
std::lock_guard<std::mutex> lock(mutex_);
std::string key = info.name + "@" + info.version;
models_[key] = info;
return true;
}
/**
* 激活指定版本的模型
* 将旧版本标记为deprecated,新版本标记为active
*/
bool activate_version(const std::string& name, const std::string& version) {
std::lock_guard<std::mutex> lock(mutex_);
// 将当前活跃版本设为deprecated
for (auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
pair.second.state = ModelState::DEPRECATED;
}
}
// 激活新版本
std::string key = name + "@" + version;
auto it = models_.find(key);
if (it != models_.end()) {
it->second.state = ModelState::ACTIVE;
return true;
}
return false;
}
/**
* 获取当前活跃版本的模型信息
*/
ModelInfo get_active_model(const std::string& name) {
std::lock_guard<std::mutex> lock(mutex_);
for (const auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
return pair.second;
}
}
return ModelInfo{};
}
/**
* 获取所有模型状态列表
*/
std::vector<ModelInfo> get_all_models() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<ModelInfo> result;
for (const auto& pair : models_) {
result.push_back(pair.second);
}
return result;
}
/**
* 清理已废弃的旧版本模型文件
* 保留最近2个版本,删除更早的版本释放存储空间
*/
void cleanup_old_versions(const std::string& name, int keep_count = 2) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> deprecated_keys;
for (const auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::DEPRECATED) {
deprecated_keys.push_back(pair.first);
}
}
// 按版本排序,保留最新的keep_count个
if (static_cast<int>(deprecated_keys.size()) > keep_count) {
for (int i = 0; i < static_cast<int>(deprecated_keys.size()) - keep_count; i++) {
// 删除模型文件并从注册表移除
models_.erase(deprecated_keys[i]);
}
}
}
private:
std::string models_dir_;
std::unordered_map<std::string, ModelInfo> models_;
std::mutex mutex_;
};
// ==================== 云端模型同步器 ====================
/**
* 云端模型同步器
* 定期检查云端是否有新版本模型,自动下载并部署
* 通过HTTPS加密通道下载,下载后RSA签名校验
*/
class CloudModelSyncer {
public:
CloudModelSyncer(const std::string& server_url, const std::string& device_id)
: server_url_(server_url), device_id_(device_id) {}
/**
* 检查云端是否有模型更新
* GET /api/v1/model/check-update?device_id=xxx&models=ocr@1.0,math@1.0
*/
struct UpdateInfo {
std::string model_name;
std::string new_version;
std::string download_url;
size_t file_size;
std::string sha256;
};
std::vector<UpdateInfo> check_updates(const std::vector<ModelInfo>& current_models) {
std::vector<UpdateInfo> updates;
// 向云端API发送当前模型版本列表,获取可更新版本
// HTTPS请求:GET server_url_/api/v1/model/check-update
return updates;
}
/**
* 下载模型文件
* HTTPS下载,支持断点续传
* 下载完成后进行SHA-256校验和RSA签名验证
*/
bool download_model(const UpdateInfo& info, const std::string& save_path) {
// HTTPS下载
// 进度回调上报
// SHA-256校验
// RSA签名验证(OTA安全:升级包RSA签名+SHA-256校验,防篡改)
return true;
}
/**
* 上报模型部署状态
* POST /api/v1/model/deploy-status
*/
void report_deploy_status(const std::string& model_name, const std::string& version,
bool success, const std::string& error = "") {
// 向云端上报模型部署结果
}
private:
std::string server_url_;
std::string device_id_;
};
// ==================== OTA固件升级管理器 ====================
/**
* OTA固件升级管理器
* 管理算力盒固件的远程升级
* 采用A/B双分区方案,升级失败自动回滚
* 安全设计:升级包RSA签名+SHA-256校验,防篡改
*/
class OtaUpgradeManager {
public:
enum class OtaState {
IDLE, // 空闲
CHECKING, // 检查更新中
DOWNLOADING, // 下载中
VERIFYING, // 校验中
INSTALLING, // 安装中
REBOOTING, // 重启中
FAILED // 失败
};
OtaUpgradeManager(const std::string& ota_url, const std::string& device_id)
: ota_url_(ota_url), device_id_(device_id), state_(OtaState::IDLE),
current_partition_("A"), download_progress_(0) {}
/** 检查固件更新 */
bool check_update() {
state_ = OtaState::CHECKING;
// GET ota_url_/api/v1/ota/check?device_id=xxx&version=xxx
return false; // 返回是否有新版本
}
/** 下载固件升级包 */
bool download_firmware(const std::string& download_url) {
state_ = OtaState::DOWNLOADING;
// HTTPS分块下载到非活跃分区
// 支持断点续传
return true;
}
/** 验证固件包完整性和签名 */
bool verify_firmware(const std::string& firmware_path) {
state_ = OtaState::VERIFYING;
// SHA-256校验
// RSA-2048签名验证
return true;
}
/** 安装固件(写入非活跃分区) */
bool install_firmware() {
state_ = OtaState::INSTALLING;
// 写入B分区(如当前运行A分区)
// 设置下次启动从B分区引导
return true;
}
/** 回滚到上一版本 */
bool rollback() {
// 切换回上一个分区
std::string target = (current_partition_ == "A") ? "B" : "A";
// 设置引导分区为target
return true;
}
/** 获取当前OTA状态 */
OtaState get_state() const { return state_; }
int get_progress() const { return download_progress_; }
std::string get_current_partition() const { return current_partition_; }
private:
std::string ota_url_;
std::string device_id_;
OtaState state_;
std::string current_partition_;
int download_progress_;
};
// ==================== 系统监控模块 ====================
/**
* 系统运行状态监控
* 采集CPU、内存、GPU/NPU利用率、温度等硬件指标
* 为云端监控告警和集群调度提供数据支撑
*/
class SystemMonitor {
public:
struct SystemMetrics {
float cpu_usage_percent; // CPU使用率
float memory_usage_percent; // 内存使用率
long memory_total_mb; // 总内存
long memory_used_mb; // 已用内存
float gpu_usage_percent; // GPU/NPU利用率
float gpu_memory_usage_mb; // GPU显存使用
float gpu_temperature_c; // GPU温度
float disk_usage_percent; // 磁盘使用率
float network_rx_mbps; // 网络接收速率
float network_tx_mbps; // 网络发送速率
long uptime_seconds; // 系统运行时长
};
SystemMonitor() : running_(false) {}
/** 启动监控采集线程 */
void start(int interval_ms = 5000) {
running_ = true;
// 定时采集系统指标
}
/** 获取最新系统指标 */
SystemMetrics get_metrics() {
SystemMetrics metrics;
metrics.cpu_usage_percent = read_cpu_usage();
metrics.memory_usage_percent = read_memory_usage();
metrics.gpu_usage_percent = read_gpu_usage();
metrics.gpu_temperature_c = read_gpu_temperature();
metrics.disk_usage_percent = read_disk_usage();
return metrics;
}
void stop() { running_ = false; }
private:
float read_cpu_usage() {
// 读取 /proc/stat 计算CPU使用率
return 0.0f;
}
float read_memory_usage() {
// 读取 /proc/meminfo
return 0.0f;
}
float read_gpu_usage() {
// NVIDIA: nvidia-smi / NVML
// 瑞芯微: /sys/class/devfreq/xxx
return 0.0f;
}
float read_gpu_temperature() {
// 读取GPU温度传感器
return 0.0f;
}
float read_disk_usage() {
// statfs("/")
return 0.0f;
}
std::atomic<bool> running_;
};
#endif // MODEL_MANAGER_H
@@ -0,0 +1,431 @@
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* NPU/GPU硬件调度模块 - 硬件加速资源管理与任务分配
*
* 管理算力盒上的NPU/GPU计算资源
* 支持多种硬件平台:NVIDIA GPU(CUDA)、瑞芯微NPU(RKNN)、通用GPU(OpenCL)
* 根据任务类型和硬件负载动态选择最优推理路径
*/
#ifndef NPU_SCHEDULER_H
#define NPU_SCHEDULER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <chrono>
#include <queue>
#include <functional>
#include <unordered_map>
#include <thread>
#include <condition_variable>
#include <cstring>
// ==================== 硬件设备抽象 ====================
/** 硬件加速器类型 */
enum class AcceleratorType {
CPU_ONLY = 0, // 仅CPU(无加速器可用时的兜底方案)
NVIDIA_GPU = 1, // NVIDIA GPU (CUDA/TensorRT)
ROCKCHIP_NPU = 2, // 瑞芯微NPU (RKNN)
AMLOGIC_NPU = 3, // 晶晨NPU
GENERIC_OPENCL = 4 // 通用OpenCL GPU
};
/** 硬件设备信息 */
struct AcceleratorDevice {
AcceleratorType type; // 加速器类型
int device_id; // 设备编号
std::string name; // 设备名称
std::string driver_version; // 驱动版本
size_t total_memory_mb; // 总显存/内存(MB)
size_t free_memory_mb; // 可用显存/内存(MB)
float compute_capability; // 算力指标
float current_utilization; // 当前利用率(0-1)
float temperature_celsius; // 当前温度
float max_temperature; // 最高安全温度
bool is_available; // 是否可用
};
/** 推理任务资源需求 */
struct TaskResourceRequirement {
size_t memory_mb; // 需要的显存(MB)
float estimated_time_ms; // 预估推理时间
bool requires_fp16; // 是否需要FP16支持
bool requires_int8; // 是否需要INT8支持
int preferred_device; // 偏好设备ID-1表示无偏好)
};
// ==================== 硬件检测器 ====================
/**
* 硬件加速器检测器
* 启动时扫描系统中可用的NPU/GPU设备
* 自动匹配设备驱动和推理后端
*/
class HardwareDetector {
public:
/**
* 扫描系统中所有可用的加速器设备
* 检测顺序:NVIDIA GPU → 瑞芯微NPU → 通用OpenCL → CPU
*/
std::vector<AcceleratorDevice> detect_devices() {
std::vector<AcceleratorDevice> devices;
// 检测NVIDIA GPU
if (detect_nvidia_gpu(devices)) {
// 通过NVML库获取GPU信息
}
// 检测瑞芯微NPU
if (detect_rockchip_npu(devices)) {
// 通过sysfs获取NPU信息
}
// 如果没有加速器,添加CPU作为兜底
if (devices.empty()) {
AcceleratorDevice cpu_dev;
cpu_dev.type = AcceleratorType::CPU_ONLY;
cpu_dev.device_id = 0;
cpu_dev.name = "CPU";
cpu_dev.total_memory_mb = get_system_memory_mb();
cpu_dev.free_memory_mb = get_free_memory_mb();
cpu_dev.is_available = true;
devices.push_back(cpu_dev);
}
return devices;
}
private:
bool detect_nvidia_gpu(std::vector<AcceleratorDevice>& devices) {
// 检查 /dev/nvidia0 是否存在
// 使用NVML API获取设备信息
// nvmlInit();
// nvmlDeviceGetCount(&count);
// for (int i = 0; i < count; i++) {
// nvmlDeviceGetHandleByIndex(i, &device);
// nvmlDeviceGetName(device, name, sizeof(name));
// nvmlDeviceGetMemoryInfo(device, &mem);
// nvmlDeviceGetUtilizationRates(device, &util);
// nvmlDeviceGetTemperature(device, NVML_TEMPERATURE_GPU, &temp);
// }
return false;
}
bool detect_rockchip_npu(std::vector<AcceleratorDevice>& devices) {
// 检查 /dev/rknpu 或 /sys/class/misc/rknpu 是否存在
// 读取NPU硬件信息
// cat /sys/kernel/debug/rknpu/load // NPU负载
return false;
}
size_t get_system_memory_mb() {
// 读取 /proc/meminfo
return 4096; // 默认4GB
}
size_t get_free_memory_mb() {
return 2048;
}
};
// ==================== 设备负载监控 ====================
/**
* 硬件设备负载实时监控
* 定期采集GPU/NPU利用率、温度、显存使用等指标
* 为调度策略提供实时数据支撑
*/
class DeviceLoadMonitor {
public:
struct DeviceMetrics {
int device_id;
float utilization; // 利用率 (0-1)
float memory_usage; // 显存使用率 (0-1)
float temperature; // 温度(摄氏度)
float power_watts; // 功耗(瓦)
int inference_qps; // 当前推理QPS
std::chrono::steady_clock::time_point timestamp;
};
DeviceLoadMonitor() : running_(false) {}
/** 启动监控(后台线程定期采集) */
void start(int interval_ms = 1000) {
running_ = true;
monitor_thread_ = std::thread([this, interval_ms]() {
while (running_) {
collect_metrics();
std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
}
});
}
/** 获取指定设备的最新指标 */
DeviceMetrics get_metrics(int device_id) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = latest_metrics_.find(device_id);
if (it != latest_metrics_.end()) {
return it->second;
}
return DeviceMetrics{};
}
/** 获取所有设备指标 */
std::vector<DeviceMetrics> get_all_metrics() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<DeviceMetrics> result;
for (const auto& pair : latest_metrics_) {
result.push_back(pair.second);
}
return result;
}
void stop() {
running_ = false;
if (monitor_thread_.joinable()) {
monitor_thread_.join();
}
}
private:
void collect_metrics() {
std::lock_guard<std::mutex> lock(mutex_);
// NVIDIA GPU: nvmlDeviceGetUtilizationRates + nvmlDeviceGetTemperature
// 瑞芯微NPU: 读取 /sys/kernel/debug/rknpu/load
// CPU: 读取 /proc/stat
}
std::unordered_map<int, DeviceMetrics> latest_metrics_;
std::mutex mutex_;
std::atomic<bool> running_;
std::thread monitor_thread_;
};
// ==================== 调度策略 ====================
/**
* 推理任务调度策略
* 根据任务特征和设备负载选择最优的推理设备
*/
class SchedulingPolicy {
public:
virtual ~SchedulingPolicy() = default;
/** 选择最优设备执行推理任务 */
virtual int select_device(const TaskResourceRequirement& requirement,
const std::vector<AcceleratorDevice>& devices,
const std::vector<DeviceLoadMonitor::DeviceMetrics>& metrics) = 0;
};
/**
* 最小负载调度策略
* 优先选择当前利用率最低的设备
*/
class MinLoadPolicy : public SchedulingPolicy {
public:
int select_device(const TaskResourceRequirement& requirement,
const std::vector<AcceleratorDevice>& devices,
const std::vector<DeviceLoadMonitor::DeviceMetrics>& metrics) override {
int best_device = 0;
float min_load = 1.0f;
for (size_t i = 0; i < devices.size(); i++) {
if (!devices[i].is_available) continue;
if (devices[i].free_memory_mb < requirement.memory_mb) continue;
float load = (i < metrics.size()) ? metrics[i].utilization : 0.0f;
if (load < min_load) {
min_load = load;
best_device = static_cast<int>(i);
}
}
return best_device;
}
};
/**
* 温度感知调度策略
* 除了负载外还考虑设备温度,防止过热降频
*/
class ThermalAwarePolicy : public SchedulingPolicy {
public:
ThermalAwarePolicy(float temp_threshold = 80.0f) : temp_threshold_(temp_threshold) {}
int select_device(const TaskResourceRequirement& requirement,
const std::vector<AcceleratorDevice>& devices,
const std::vector<DeviceLoadMonitor::DeviceMetrics>& metrics) override {
int best_device = 0;
float best_score = -1.0f;
for (size_t i = 0; i < devices.size(); i++) {
if (!devices[i].is_available) continue;
if (devices[i].free_memory_mb < requirement.memory_mb) continue;
float load = (i < metrics.size()) ? metrics[i].utilization : 0.0f;
float temp = (i < metrics.size()) ? metrics[i].temperature : 0.0f;
// 综合评分:负载权重0.6 + 温度权重0.4
float load_score = 1.0f - load;
float temp_score = (temp < temp_threshold_) ? 1.0f : (1.0f - (temp - temp_threshold_) / 20.0f);
float score = load_score * 0.6f + temp_score * 0.4f;
if (score > best_score) {
best_score = score;
best_device = static_cast<int>(i);
}
}
return best_device;
}
private:
float temp_threshold_;
};
// ==================== NPU调度器(核心) ====================
/**
* NPU/GPU硬件调度器
* 管理推理任务到硬件设备的分配调度
* 核心功能:
* 1. 硬件资源池化管理
* 2. 基于负载和温度的智能调度
* 3. 设备故障自动切换
* 4. 推理性能指标采集
*/
class NpuScheduler {
public:
NpuScheduler() : initialized_(false) {}
/**
* 初始化调度器
* 检测硬件设备,启动负载监控,设置调度策略
*/
bool initialize() {
// 检测可用硬件加速器
HardwareDetector detector;
devices_ = detector.detect_devices();
if (devices_.empty()) {
return false;
}
// 启动设备负载监控
load_monitor_.start(1000);
// 设置调度策略(默认温度感知策略)
policy_ = std::make_unique<ThermalAwarePolicy>(80.0f);
initialized_ = true;
return true;
}
/**
* 为推理任务分配最优设备
*/
int schedule_task(const TaskResourceRequirement& requirement) {
if (!initialized_) return 0;
auto metrics = load_monitor_.get_all_metrics();
return policy_->select_device(requirement, devices_, metrics);
}
/**
* 获取所有设备状态
*/
std::vector<AcceleratorDevice> get_device_status() {
// 更新设备实时状态
auto metrics = load_monitor_.get_all_metrics();
for (auto& dev : devices_) {
for (const auto& m : metrics) {
if (m.device_id == dev.device_id) {
dev.current_utilization = m.utilization;
dev.temperature_celsius = m.temperature;
}
}
}
return devices_;
}
/** 获取调度统计信息 */
struct SchedulerStats {
long total_tasks_scheduled;
long total_tasks_completed;
long total_tasks_failed;
float avg_inference_ms;
float gpu_avg_utilization;
float gpu_temperature;
int active_devices;
};
SchedulerStats get_stats() {
SchedulerStats stats;
stats.total_tasks_scheduled = tasks_scheduled_.load();
stats.total_tasks_completed = tasks_completed_.load();
stats.total_tasks_failed = tasks_failed_.load();
stats.active_devices = static_cast<int>(devices_.size());
auto metrics = load_monitor_.get_all_metrics();
if (!metrics.empty()) {
float total_util = 0;
for (const auto& m : metrics) total_util += m.utilization;
stats.gpu_avg_utilization = total_util / metrics.size();
stats.gpu_temperature = metrics[0].temperature;
}
return stats;
}
void shutdown() {
load_monitor_.stop();
initialized_ = false;
}
private:
std::vector<AcceleratorDevice> devices_;
DeviceLoadMonitor load_monitor_;
std::unique_ptr<SchedulingPolicy> policy_;
bool initialized_;
std::atomic<long> tasks_scheduled_{0};
std::atomic<long> tasks_completed_{0};
std::atomic<long> tasks_failed_{0};
};
// ==================== 配置管理 ====================
/**
* 算力盒配置管理(边缘设备专用)
* 从JSON配置文件和环境变量加载配置
* 支持运行时配置热更新(通过MQTT远程指令)
*/
struct EdgeBoxConfiguration {
// 推理配置
int max_concurrent_inferences = 4; // 最大并发推理数
int inference_queue_size = 256; // 推理队列大小
int default_timeout_ms = 500; // 默认推理超时
// NPU/GPU配置
float gpu_memory_fraction = 0.8f; // GPU显存使用比例上限
float thermal_throttle_temp = 80.0f; // 温度降频阈值
bool enable_fp16 = true; // 启用FP16推理
bool enable_int8 = false; // 启用INT8量化
// 网络配置
std::string grpc_listen = "0.0.0.0:50052";
std::string mqtt_broker = "ssl://mqtt.writech.com:8883";
bool enable_mtls = true;
// 存储配置
std::string models_dir = "/opt/models";
std::string cache_dir = "/var/lib/writech/cache";
int offline_cache_max_mb = 256;
// 集群配置
bool enable_cluster = true;
std::string cluster_discovery = "mdns";
};
#endif // NPU_SCHEDULER_H
@@ -0,0 +1,324 @@
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 主程序入口 - 算力盒边缘计算服务启动与管理
*
* 初始化推理引擎、通信模块、模型管理、监控等子系统
* 运行于ARM/x86算力盒硬件,搭载NPU/GPU加速模块
*/
#include <iostream>
#include <string>
#include <vector>
#include <memory>
#include <thread>
#include <chrono>
#include <csignal>
#include <atomic>
#include <mutex>
#include <functional>
// 前向声明各子系统类
class InferenceEngine;
class ModelManager;
class GrpcServer;
class MqttReporter;
class SystemMonitor;
class OfflineCache;
class ClusterManager;
class OtaManager;
// ==================== 全局状态管理 ====================
// 系统运行状态标志
static std::atomic<bool> g_running(true);
// 系统启动时间戳
static std::chrono::steady_clock::time_point g_start_time;
/**
* 信号处理函数
* 接收SIGINT/SIGTERM信号后优雅关闭所有子系统
*/
void signal_handler(int signum) {
std::cout << "[Main] 接收到信号 " << signum << ",准备优雅关闭..." << std::endl;
g_running.store(false);
}
// ==================== 配置管理 ====================
/**
* 算力盒全局配置
* 从配置文件和环境变量加载运行参数
*/
struct EdgeBoxConfig {
// 设备信息
std::string device_id; // 设备唯一序列号
std::string device_name; // 设备名称
std::string firmware_version; // 固件版本
// gRPC服务配置(与网关数据交互)
std::string grpc_listen_addr = "0.0.0.0:50052";
int grpc_max_connections = 100; // 最大并发连接数
bool grpc_enable_tls = true; // 启用mTLS双向认证
// MQTT配置(与云端状态同步)
std::string mqtt_broker_url = "ssl://mqtt.writech.com:8883";
std::string mqtt_client_id;
int mqtt_keepalive_s = 60; // 心跳间隔
// 推理引擎配置
std::string models_dir = "/opt/models";
std::string inference_device = "npu"; // 推理设备: npu / gpu / cpu
int max_batch_size = 16; // 最大推理批大小
int inference_timeout_ms = 500; // 单次推理超时(毫秒)
// 集群配置
bool enable_cluster = true; // 启用多算力盒集群管理
int mdns_port = 5353; // mDNS服务发现端口
// 离线缓存配置
std::string cache_db_path = "/var/lib/writech/cache.db";
int max_cache_size_mb = 256; // 离线缓存最大容量
// OTA升级配置
std::string ota_server_url = "https://ota.writech.com";
bool ota_auto_check = true; // 自动检查升级
int ota_check_interval_h = 24; // 检查间隔(小时)
// 日志配置
std::string log_dir = "/var/log/writech";
std::string log_level = "INFO";
int log_max_size_mb = 50; // 单个日志文件大小上限
int log_rotate_count = 5; // 日志轮转保留数量
};
/**
* 从JSON配置文件加载配置
* 配置文件路径: /etc/writech/edgebox.json
*/
EdgeBoxConfig load_config(const std::string& config_path) {
EdgeBoxConfig config;
std::cout << "[Config] 加载配置文件: " << config_path << std::endl;
// 读取JSON配置文件并解析
// 实际实现使用nlohmann/json或rapidjson
// 此处使用默认值
// 设备ID从硬件序列号读取
config.device_id = "EB-" + std::to_string(std::hash<std::string>{}("device_serial"));
config.mqtt_client_id = "edgebox_" + config.device_id;
std::cout << "[Config] 配置加载完成: device_id=" << config.device_id << std::endl;
return config;
}
// ==================== 日志系统 ====================
/**
* 日志级别枚举
*/
enum class LogLevel {
DEBUG = 0,
INFO = 1,
WARNING = 2,
ERROR = 3,
CRITICAL = 4
};
/**
* 简易日志记录器
* 支持日志文件轮转和分级输出
*/
class Logger {
public:
static Logger& instance() {
static Logger logger;
return logger;
}
void init(const std::string& log_dir, const std::string& level) {
log_dir_ = log_dir;
if (level == "DEBUG") level_ = LogLevel::DEBUG;
else if (level == "WARNING") level_ = LogLevel::WARNING;
else if (level == "ERROR") level_ = LogLevel::ERROR;
else level_ = LogLevel::INFO;
std::cout << "[Logger] 日志系统初始化: dir=" << log_dir << ", level=" << level << std::endl;
}
void log(LogLevel level, const std::string& module, const std::string& message) {
if (level < level_) return;
std::lock_guard<std::mutex> lock(mutex_);
auto now = std::chrono::system_clock::now();
auto time_t = std::chrono::system_clock::to_time_t(now);
std::string level_str;
switch(level) {
case LogLevel::DEBUG: level_str = "DEBUG"; break;
case LogLevel::INFO: level_str = "INFO"; break;
case LogLevel::WARNING: level_str = "WARN"; break;
case LogLevel::ERROR: level_str = "ERROR"; break;
case LogLevel::CRITICAL: level_str = "CRIT"; break;
}
std::cout << "[" << level_str << "] " << module << ": " << message << std::endl;
}
private:
Logger() = default;
std::string log_dir_;
LogLevel level_ = LogLevel::INFO;
std::mutex mutex_;
};
// 日志宏定义
#define LOG_INFO(mod, msg) Logger::instance().log(LogLevel::INFO, mod, msg)
#define LOG_ERROR(mod, msg) Logger::instance().log(LogLevel::ERROR, mod, msg)
#define LOG_DEBUG(mod, msg) Logger::instance().log(LogLevel::DEBUG, mod, msg)
#define LOG_WARN(mod, msg) Logger::instance().log(LogLevel::WARNING, mod, msg)
// ==================== 健康检查 ====================
/**
* 系统健康状态
*/
struct HealthStatus {
bool inference_engine_ok = false; // 推理引擎状态
bool grpc_server_ok = false; // gRPC服务状态
bool mqtt_connected = false; // MQTT连接状态
bool model_loaded = false; // 模型加载状态
float cpu_usage_percent = 0.0f; // CPU使用率
float memory_usage_percent = 0.0f; // 内存使用率
float gpu_usage_percent = 0.0f; // GPU使用率
float gpu_temperature_c = 0.0f; // GPU温度
int active_connections = 0; // 活跃gRPC连接数
int pending_tasks = 0; // 待处理推理任务数
long uptime_seconds = 0; // 运行时长
};
/**
* 获取系统运行时长
*/
long get_uptime_seconds() {
auto now = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::seconds>(now - g_start_time).count();
}
// ==================== 看门狗 ====================
/**
* 软件看门狗
* 监控各子系统运行状态,异常时自动重启对应服务
* 配合硬件看门狗实现双重保护(异常自动重启)
*/
class Watchdog {
public:
Watchdog(int timeout_s = 30) : timeout_s_(timeout_s), last_feed_time_(std::chrono::steady_clock::now()) {}
/**
* 喂狗操作(各子系统定期调用)
*/
void feed(const std::string& module) {
std::lock_guard<std::mutex> lock(mutex_);
feed_records_[module] = std::chrono::steady_clock::now();
}
/**
* 检查是否有子系统超时未喂狗
*/
std::vector<std::string> check_timeouts() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> timed_out;
auto now = std::chrono::steady_clock::now();
for (const auto& [module, last_feed] : feed_records_) {
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(now - last_feed).count();
if (elapsed > timeout_s_) {
timed_out.push_back(module);
LOG_WARN("Watchdog", module + " 超时未响应 (" + std::to_string(elapsed) + "s)");
}
}
return timed_out;
}
private:
int timeout_s_;
std::chrono::steady_clock::time_point last_feed_time_;
std::map<std::string, std::chrono::steady_clock::time_point> feed_records_;
std::mutex mutex_;
};
// ==================== 主函数 ====================
/**
* 算力盒主程序入口
* 启动流程:
* 1. 加载配置文件
* 2. 初始化日志系统
* 3. 初始化推理引擎(加载模型到NPU/GPU)
* 4. 启动gRPC服务(接收网关笔迹数据)
* 5. 启动MQTT客户端(状态上报到云端)
* 6. 启动集群管理(mDNS发现与负载均衡)
* 7. 启动系统监控
* 8. 进入主循环(看门狗+健康检查)
*/
int main(int argc, char* argv[]) {
std::cout << "========================================" << std::endl;
std::cout << "自然写教室智能算力盒边缘计算软件 V1.0" << std::endl;
std::cout << "Copyright (c) 深圳自然写科技有限公司" << std::endl;
std::cout << "========================================" << std::endl;
g_start_time = std::chrono::steady_clock::now();
// 注册信号处理
signal(SIGINT, signal_handler);
signal(SIGTERM, signal_handler);
// 1. 加载配置
std::string config_path = "/etc/writech/edgebox.json";
if (argc > 1) config_path = argv[1];
EdgeBoxConfig config = load_config(config_path);
// 2. 初始化日志
Logger::instance().init(config.log_dir, config.log_level);
LOG_INFO("Main", "算力盒启动中...");
// 3. 初始化看门狗
Watchdog watchdog(30);
// 4. 初始化各子系统(实际环境中创建对应对象)
LOG_INFO("Main", "初始化推理引擎: device=" + config.inference_device);
LOG_INFO("Main", "加载AI模型: " + config.models_dir);
LOG_INFO("Main", "启动gRPC服务: " + config.grpc_listen_addr);
LOG_INFO("Main", "连接MQTT Broker: " + config.mqtt_broker_url);
if (config.enable_cluster) {
LOG_INFO("Main", "启动集群管理(mDNS)");
}
LOG_INFO("Main", "所有子系统初始化完成");
LOG_INFO("Main", "算力盒服务已就绪,等待推理请求...");
// 5. 主循环:看门狗+健康检查
while (g_running.load()) {
// 检查子系统超时
auto timed_out = watchdog.check_timeouts();
for (const auto& module : timed_out) {
LOG_ERROR("Main", "子系统超时: " + module + ",尝试重启...");
}
// 定期上报健康状态
HealthStatus status;
status.uptime_seconds = get_uptime_seconds();
// 休眠1秒后继续检查
std::this_thread::sleep_for(std::chrono::seconds(1));
}
// 6. 优雅关闭
LOG_INFO("Main", "正在关闭算力盒服务...");
LOG_INFO("Main", "等待推理任务完成...");
LOG_INFO("Main", "断开MQTT连接...");
LOG_INFO("Main", "停止gRPC服务...");
LOG_INFO("Main", "算力盒服务已安全关闭");
return 0;
}
@@ -0,0 +1,405 @@
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 笔迹预处理模块 - 笔迹坐标数据预处理管道
*
* 对网关转发的原始笔迹坐标进行预处理:
* 去噪滤波、坐标归一化、笔画分割、特征提取
* 预处理结果作为NPU/GPU推理的标准化输入
*/
#ifndef STROKE_PREPROCESSOR_H
#define STROKE_PREPROCESSOR_H
#include <vector>
#include <cmath>
#include <algorithm>
#include <numeric>
#include <cstring>
// ==================== 基础数据结构 ====================
/** 原始笔迹坐标点(来自网关gRPC数据流) */
struct RawPoint {
float x; // X坐标(点阵单位,约300DPI)
float y; // Y坐标
float pressure; // 压力值 (0.0-1.0)
uint32_t timestamp; // 采集时间戳(毫秒)
bool pen_up; // 抬笔标记
};
/** 归一化后的坐标点 */
struct NormalizedPoint {
float x; // 归一化X (0.0-1.0)
float y; // 归一化Y (0.0-1.0)
float pressure; // 压力值 (0.0-1.0)
};
/** 笔画数据 */
struct Stroke {
std::vector<NormalizedPoint> points; // 归一化坐标点序列
int stroke_index; // 笔画序号
float length; // 笔画路径长度
int duration_ms; // 书写耗时(毫秒)
};
/** 预处理输出(用于NPU推理输入) */
struct PreprocessedData {
std::vector<float> image; // 渲染后的灰度图像 (H*W)
int image_width; // 图像宽度
int image_height; // 图像高度
std::vector<Stroke> strokes; // 分割后的笔画列表
int total_points; // 总坐标点数
int stroke_count; // 笔画数量
};
// ==================== 去噪滤波器 ====================
/**
* 笔迹去噪滤波器
* 消除点阵笔采集过程中的抖动噪声和异常跳跃点
* 多级滤波策略:异常点剔除 → 中值滤波 → 移动平均平滑
*/
class StrokeNoiseFilter {
public:
/**
* 构造函数
* max_jump: 最大允许跳跃距离(超过则视为异常点)
* window_size: 滤波窗口大小(奇数)
*/
StrokeNoiseFilter(float max_jump = 50.0f, int window_size = 3)
: max_jump_(max_jump), window_size_(window_size) {}
/**
* 剔除异常跳跃点
* 点阵笔摄像头短暂遮挡会导致坐标突变,需要过滤
*/
std::vector<RawPoint> remove_outliers(const std::vector<RawPoint>& points) {
if (points.size() < 3) return points;
std::vector<RawPoint> result;
result.push_back(points[0]);
for (size_t i = 1; i < points.size(); i++) {
float dx = points[i].x - points[i-1].x;
float dy = points[i].y - points[i-1].y;
float dist = std::sqrt(dx * dx + dy * dy);
// 跳跃距离在合理范围内才保留该点
if (dist <= max_jump_) {
result.push_back(points[i]);
}
}
return result;
}
/**
* 中值滤波去噪
* 对X和Y坐标分别进行一维中值滤波
* 有效消除脉冲噪声同时保留笔画转折特征
*/
std::vector<RawPoint> median_filter(const std::vector<RawPoint>& points) {
int n = static_cast<int>(points.size());
if (n < window_size_) return points;
int half = window_size_ / 2;
std::vector<RawPoint> result(n);
for (int i = 0; i < n; i++) {
// 收集窗口内的X和Y值
std::vector<float> wx, wy;
for (int j = std::max(0, i - half); j <= std::min(n - 1, i + half); j++) {
wx.push_back(points[j].x);
wy.push_back(points[j].y);
}
// 排序取中值
std::sort(wx.begin(), wx.end());
std::sort(wy.begin(), wy.end());
result[i] = points[i];
result[i].x = wx[wx.size() / 2];
result[i].y = wy[wy.size() / 2];
}
return result;
}
/**
* 移动平均平滑
* 进一步减少微小抖动,使笔画更流畅
*/
std::vector<RawPoint> moving_average(const std::vector<RawPoint>& points) {
int n = static_cast<int>(points.size());
if (n < 3) return points;
std::vector<RawPoint> result(n);
int half = window_size_ / 2;
for (int i = 0; i < n; i++) {
float sum_x = 0, sum_y = 0;
int count = 0;
for (int j = std::max(0, i - half); j <= std::min(n - 1, i + half); j++) {
sum_x += points[j].x;
sum_y += points[j].y;
count++;
}
result[i] = points[i];
result[i].x = sum_x / count;
result[i].y = sum_y / count;
}
return result;
}
/** 执行完整去噪流程 */
std::vector<RawPoint> apply(const std::vector<RawPoint>& points) {
auto step1 = remove_outliers(points);
auto step2 = median_filter(step1);
auto step3 = moving_average(step2);
return step3;
}
private:
float max_jump_;
int window_size_;
};
// ==================== 坐标归一化器 ====================
/**
* 坐标归一化器
* 将不同纸张尺寸和分辨率的原始坐标统一归一化到[0,1]范围
* 保持宽高比以避免笔迹变形
*/
class CoordinateNormalizer {
public:
CoordinateNormalizer(bool preserve_aspect = true) : preserve_aspect_(preserve_aspect) {}
/**
* Min-Max归一化,映射到[0,1]范围
*/
std::vector<NormalizedPoint> normalize(const std::vector<RawPoint>& points) {
if (points.empty()) return {};
// 计算坐标范围
float min_x = points[0].x, max_x = points[0].x;
float min_y = points[0].y, max_y = points[0].y;
for (const auto& p : points) {
min_x = std::min(min_x, p.x);
max_x = std::max(max_x, p.x);
min_y = std::min(min_y, p.y);
max_y = std::max(max_y, p.y);
}
float range_x = max_x - min_x;
float range_y = max_y - min_y;
// 保持宽高比时使用统一的缩放因子
float scale = 1.0f;
if (preserve_aspect_) {
scale = std::max(range_x, range_y);
if (scale < 1e-6f) scale = 1.0f;
}
std::vector<NormalizedPoint> result;
result.reserve(points.size());
for (const auto& p : points) {
NormalizedPoint np;
if (preserve_aspect_) {
np.x = (p.x - min_x) / scale;
np.y = (p.y - min_y) / scale;
} else {
np.x = (range_x > 1e-6f) ? (p.x - min_x) / range_x : 0.5f;
np.y = (range_y > 1e-6f) ? (p.y - min_y) / range_y : 0.5f;
}
np.pressure = p.pressure;
result.push_back(np);
}
return result;
}
private:
bool preserve_aspect_;
};
// ==================== 笔画分割器 ====================
/**
* 笔画分割器
* 根据抬笔事件和时间间隔将连续坐标流分割为独立笔画
*/
class StrokeSegmenter {
public:
StrokeSegmenter(int time_threshold_ms = 200, int min_points = 3)
: time_threshold_(time_threshold_ms), min_points_(min_points) {}
/**
* 将原始点序列分割为笔画列表
*/
std::vector<std::vector<RawPoint>> segment(const std::vector<RawPoint>& points) {
if (points.empty()) return {};
std::vector<std::vector<RawPoint>> strokes;
std::vector<RawPoint> current;
current.push_back(points[0]);
for (size_t i = 1; i < points.size(); i++) {
bool is_break = points[i].pen_up;
int time_gap = static_cast<int>(points[i].timestamp - points[i-1].timestamp);
if ((is_break || time_gap > time_threshold_) &&
static_cast<int>(current.size()) >= min_points_) {
strokes.push_back(current);
current.clear();
}
if (!points[i].pen_up) {
current.push_back(points[i]);
}
}
if (static_cast<int>(current.size()) >= min_points_) {
strokes.push_back(current);
}
return strokes;
}
private:
int time_threshold_;
int min_points_;
};
// ==================== 图像渲染器 ====================
/**
* 笔迹图像渲染器
* 将归一化坐标渲染为灰度图像作为CNN模型输入
* 使用Bresenham直线算法连接相邻坐标点
*/
class StrokeImageRenderer {
public:
StrokeImageRenderer(int width = 64, int height = 64)
: width_(width), height_(height) {}
/**
* 将坐标序列渲染为灰度图像
* 输出一维浮点数组,值域[0,1],1表示笔迹
*/
std::vector<float> render(const std::vector<NormalizedPoint>& points) {
std::vector<float> image(width_ * height_, 0.0f);
for (size_t i = 1; i < points.size(); i++) {
int x0 = static_cast<int>(points[i-1].x * (width_ - 1));
int y0 = static_cast<int>(points[i-1].y * (height_ - 1));
int x1 = static_cast<int>(points[i].x * (width_ - 1));
int y1 = static_cast<int>(points[i].y * (height_ - 1));
// 裁剪到图像范围
x0 = std::clamp(x0, 0, width_ - 1);
y0 = std::clamp(y0, 0, height_ - 1);
x1 = std::clamp(x1, 0, width_ - 1);
y1 = std::clamp(y1, 0, height_ - 1);
float pressure = (points[i-1].pressure + points[i].pressure) * 0.5f;
// Bresenham直线算法
draw_line(image, x0, y0, x1, y1, pressure);
}
return image;
}
private:
void draw_line(std::vector<float>& image, int x0, int y0, int x1, int y1, float value) {
int dx = std::abs(x1 - x0);
int dy = std::abs(y1 - y0);
int sx = (x0 < x1) ? 1 : -1;
int sy = (y0 < y1) ? 1 : -1;
int err = dx - dy;
while (true) {
int idx = y0 * width_ + x0;
if (idx >= 0 && idx < width_ * height_) {
image[idx] = std::max(image[idx], value);
}
if (x0 == x1 && y0 == y1) break;
int e2 = 2 * err;
if (e2 > -dy) { err -= dy; x0 += sx; }
if (e2 < dx) { err += dx; y0 += sy; }
}
}
int width_;
int height_;
};
// ==================== 预处理管道(整合) ====================
/**
* 笔迹预处理管道
* 整合去噪、归一化、分割、渲染的完整处理流程
* 输入原始坐标点序列,输出标准化的推理输入数据
*/
class StrokePreprocessor {
public:
StrokePreprocessor(int image_size = 64)
: noise_filter_(50.0f, 3),
normalizer_(true),
segmenter_(200, 3),
renderer_(image_size, image_size),
image_size_(image_size) {}
/**
* 执行完整预处理管道
* 流程:原始坐标 → 去噪 → 归一化 → 笔画分割 → 图像渲染
*/
PreprocessedData process(const std::vector<RawPoint>& raw_points) {
PreprocessedData result;
// 步骤1:去噪滤波
auto denoised = noise_filter_.apply(raw_points);
// 步骤2:坐标归一化
auto normalized = normalizer_.normalize(denoised);
// 步骤3:笔画分割
auto stroke_groups = segmenter_.segment(denoised);
// 构建笔画数据
for (int i = 0; i < static_cast<int>(stroke_groups.size()); i++) {
Stroke stroke;
stroke.stroke_index = i;
auto norm_group = normalizer_.normalize(stroke_groups[i]);
stroke.points = norm_group;
stroke.length = calc_path_length(norm_group);
if (stroke_groups[i].size() >= 2) {
stroke.duration_ms = static_cast<int>(
stroke_groups[i].back().timestamp - stroke_groups[i].front().timestamp);
}
result.strokes.push_back(stroke);
}
// 步骤4:渲染为灰度图像
result.image = renderer_.render(normalized);
result.image_width = image_size_;
result.image_height = image_size_;
result.total_points = static_cast<int>(denoised.size());
result.stroke_count = static_cast<int>(result.strokes.size());
return result;
}
private:
float calc_path_length(const std::vector<NormalizedPoint>& points) {
float total = 0.0f;
for (size_t i = 1; i < points.size(); i++) {
float dx = points[i].x - points[i-1].x;
float dy = points[i].y - points[i-1].y;
total += std::sqrt(dx * dx + dy * dy);
}
return total;
}
StrokeNoiseFilter noise_filter_;
CoordinateNormalizer normalizer_;
StrokeSegmenter segmenter_;
StrokeImageRenderer renderer_;
int image_size_;
};
#endif // STROKE_PREPROCESSOR_H