software copyright
This commit is contained in:
@@ -0,0 +1,500 @@
|
||||
/**
|
||||
* 自然写教室智能算力盒边缘计算软件 V1.0
|
||||
* gRPC通信服务模块 - 与教室网关的笔迹数据交互
|
||||
*
|
||||
* 实现gRPC流式服务,接收网关转发的笔迹数据流
|
||||
* 支持mTLS双向认证确保通信安全
|
||||
*/
|
||||
|
||||
#ifndef GRPC_SERVER_H
|
||||
#define GRPC_SERVER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <functional>
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
#include <queue>
|
||||
|
||||
// ==================== gRPC消息结构 ====================
|
||||
|
||||
/** 笔迹坐标点(对应protobuf消息) */
|
||||
struct GrpcStrokePoint {
|
||||
float x;
|
||||
float y;
|
||||
float pressure;
|
||||
uint32_t timestamp;
|
||||
bool pen_up;
|
||||
};
|
||||
|
||||
/** 笔迹数据包(对应protobuf消息) */
|
||||
struct GrpcStrokePacket {
|
||||
std::string packet_id; // 数据包ID
|
||||
std::string pen_id; // 笔设备MAC地址
|
||||
std::string student_id; // 学生ID
|
||||
std::string page_id; // 点阵码页面ID
|
||||
std::vector<GrpcStrokePoint> points; // 坐标点序列
|
||||
uint64_t gateway_timestamp; // 网关转发时间戳
|
||||
int sequence_number; // 包序号(用于乱序检测)
|
||||
};
|
||||
|
||||
/** 识别结果响应 */
|
||||
struct GrpcRecognitionResponse {
|
||||
std::string packet_id; // 对应的请求包ID
|
||||
std::string recognition_type; // 识别类型(ocr/math/stroke_order)
|
||||
bool success; // 是否成功
|
||||
std::string result_text; // 识别结果文本
|
||||
float confidence; // 置信度
|
||||
float processing_time_ms; // 处理耗时
|
||||
std::string model_version; // 使用的模型版本
|
||||
};
|
||||
|
||||
// ==================== 连接管理器 ====================
|
||||
|
||||
/** 客户端连接信息 */
|
||||
struct ClientConnection {
|
||||
std::string client_id; // 客户端标识(网关ID)
|
||||
std::string client_addr; // 客户端地址
|
||||
std::string cert_fingerprint; // 客户端证书指纹(mTLS)
|
||||
std::chrono::steady_clock::time_point connected_at;
|
||||
std::chrono::steady_clock::time_point last_active;
|
||||
long packets_received; // 已接收数据包数
|
||||
long bytes_received; // 已接收字节数
|
||||
bool authenticated; // 是否已通过mTLS认证
|
||||
};
|
||||
|
||||
/**
|
||||
* gRPC连接管理器
|
||||
* 管理与多个教室网关的gRPC连接
|
||||
* 每个网关对应一个持久化的gRPC流式连接
|
||||
*/
|
||||
class ConnectionManager {
|
||||
public:
|
||||
ConnectionManager(int max_connections = 100)
|
||||
: max_connections_(max_connections) {}
|
||||
|
||||
/** 注册新连接 */
|
||||
bool register_connection(const std::string& client_id, const std::string& addr,
|
||||
const std::string& cert_fp) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (static_cast<int>(connections_.size()) >= max_connections_) {
|
||||
return false; // 达到最大连接数限制
|
||||
}
|
||||
|
||||
ClientConnection conn;
|
||||
conn.client_id = client_id;
|
||||
conn.client_addr = addr;
|
||||
conn.cert_fingerprint = cert_fp;
|
||||
conn.connected_at = std::chrono::steady_clock::now();
|
||||
conn.last_active = conn.connected_at;
|
||||
conn.packets_received = 0;
|
||||
conn.bytes_received = 0;
|
||||
conn.authenticated = !cert_fp.empty();
|
||||
|
||||
connections_[client_id] = conn;
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 移除连接 */
|
||||
void remove_connection(const std::string& client_id) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
connections_.erase(client_id);
|
||||
}
|
||||
|
||||
/** 更新连接活跃时间 */
|
||||
void update_activity(const std::string& client_id, long bytes) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = connections_.find(client_id);
|
||||
if (it != connections_.end()) {
|
||||
it->second.last_active = std::chrono::steady_clock::now();
|
||||
it->second.packets_received++;
|
||||
it->second.bytes_received += bytes;
|
||||
}
|
||||
}
|
||||
|
||||
/** 检查空闲超时连接 */
|
||||
std::vector<std::string> check_idle_connections(int timeout_s = 300) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<std::string> idle;
|
||||
auto now = std::chrono::steady_clock::now();
|
||||
|
||||
for (const auto& pair : connections_) {
|
||||
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
|
||||
now - pair.second.last_active).count();
|
||||
if (elapsed > timeout_s) {
|
||||
idle.push_back(pair.first);
|
||||
}
|
||||
}
|
||||
return idle;
|
||||
}
|
||||
|
||||
/** 获取当前连接数 */
|
||||
int active_count() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return static_cast<int>(connections_.size());
|
||||
}
|
||||
|
||||
/** 获取所有连接状态 */
|
||||
std::vector<ClientConnection> get_all_connections() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<ClientConnection> result;
|
||||
for (const auto& pair : connections_) {
|
||||
result.push_back(pair.second);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, ClientConnection> connections_;
|
||||
mutable std::mutex mutex_;
|
||||
int max_connections_;
|
||||
};
|
||||
|
||||
// ==================== 数据包排序器 ====================
|
||||
|
||||
/**
|
||||
* 数据包排序器
|
||||
* 网络传输可能导致数据包乱序到达
|
||||
* 使用滑动窗口机制对数据包进行重排序
|
||||
*/
|
||||
class PacketReorderer {
|
||||
public:
|
||||
PacketReorderer(int window_size = 16) : window_size_(window_size), expected_seq_(0) {}
|
||||
|
||||
/**
|
||||
* 提交数据包到排序窗口
|
||||
* 如果是期望的下一个序号则直接输出
|
||||
* 否则缓存等待前序包到达
|
||||
*/
|
||||
std::vector<GrpcStrokePacket> submit(const GrpcStrokePacket& packet) {
|
||||
std::vector<GrpcStrokePacket> output;
|
||||
|
||||
if (packet.sequence_number == expected_seq_) {
|
||||
// 正好是期望的下一个包
|
||||
output.push_back(packet);
|
||||
expected_seq_++;
|
||||
|
||||
// 检查缓存中是否有后续连续的包
|
||||
while (buffer_.count(expected_seq_) > 0) {
|
||||
output.push_back(buffer_[expected_seq_]);
|
||||
buffer_.erase(expected_seq_);
|
||||
expected_seq_++;
|
||||
}
|
||||
} else if (packet.sequence_number > expected_seq_) {
|
||||
// 后序包先到达,缓存等待
|
||||
buffer_[packet.sequence_number] = packet;
|
||||
|
||||
// 缓存过大时强制输出最旧的包
|
||||
if (static_cast<int>(buffer_.size()) > window_size_) {
|
||||
auto it = buffer_.begin();
|
||||
output.push_back(it->second);
|
||||
expected_seq_ = it->first + 1;
|
||||
buffer_.erase(it);
|
||||
}
|
||||
}
|
||||
// 过期的旧包直接丢弃
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
void reset() {
|
||||
buffer_.clear();
|
||||
expected_seq_ = 0;
|
||||
}
|
||||
|
||||
private:
|
||||
std::map<int, GrpcStrokePacket> buffer_;
|
||||
int window_size_;
|
||||
int expected_seq_;
|
||||
};
|
||||
|
||||
// ==================== gRPC服务实现 ====================
|
||||
|
||||
/**
|
||||
* gRPC笔迹接收服务
|
||||
* 实现InferenceService.ProcessStroke流式RPC
|
||||
* 接收网关推送的笔迹数据流,送入推理引擎处理
|
||||
*
|
||||
* 安全设计:
|
||||
* - gRPC启用mTLS双向认证
|
||||
* - 请求大小限制防恶意攻击
|
||||
* - 连接数限制防DoS
|
||||
*/
|
||||
class GrpcStrokeServer {
|
||||
public:
|
||||
using StrokeCallback = std::function<void(const GrpcStrokePacket&)>;
|
||||
|
||||
GrpcStrokeServer(const std::string& listen_addr = "0.0.0.0:50052",
|
||||
bool enable_tls = true)
|
||||
: listen_addr_(listen_addr), enable_tls_(enable_tls),
|
||||
running_(false), conn_manager_(100) {}
|
||||
|
||||
/**
|
||||
* 设置笔迹数据接收回调
|
||||
* 当收到网关的笔迹数据时调用此回调
|
||||
*/
|
||||
void set_stroke_callback(StrokeCallback callback) {
|
||||
stroke_callback_ = std::move(callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* 启动gRPC服务器
|
||||
* 加载TLS证书,绑定端口,开始监听
|
||||
*/
|
||||
bool start() {
|
||||
if (enable_tls_) {
|
||||
// 加载mTLS证书(安全设计:gRPC启用mTLS双向认证)
|
||||
// grpc::SslServerCredentialsOptions ssl_opts;
|
||||
// ssl_opts.pem_root_certs = load_file("/etc/ssl/ca.crt");
|
||||
// ssl_opts.pem_key_cert_pairs.push_back({
|
||||
// load_file("/etc/ssl/server.key"),
|
||||
// load_file("/etc/ssl/server.crt")
|
||||
// });
|
||||
// ssl_opts.client_certificate_request = GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
|
||||
}
|
||||
|
||||
// 构建并启动gRPC服务器
|
||||
// grpc::ServerBuilder builder;
|
||||
// builder.AddListeningPort(listen_addr_, credentials);
|
||||
// builder.RegisterService(this);
|
||||
// builder.SetMaxReceiveMessageSize(10 * 1024 * 1024); // 10MB最大消息
|
||||
// server_ = builder.BuildAndStart();
|
||||
|
||||
running_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* ProcessStroke RPC实现
|
||||
* 接收网关的流式笔迹数据,处理后返回识别结果流
|
||||
*/
|
||||
void ProcessStroke(const GrpcStrokePacket& packet) {
|
||||
// 更新连接活跃状态
|
||||
conn_manager_.update_activity(packet.pen_id, packet.points.size() * 16);
|
||||
|
||||
// 数据包排序
|
||||
auto ordered = reorderer_.submit(packet);
|
||||
|
||||
// 处理排序后的数据包
|
||||
for (const auto& p : ordered) {
|
||||
total_packets_++;
|
||||
total_points_ += static_cast<long>(p.points.size());
|
||||
|
||||
// 调用回调函数送入推理引擎
|
||||
if (stroke_callback_) {
|
||||
stroke_callback_(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** 停止服务器 */
|
||||
void stop() {
|
||||
running_ = false;
|
||||
// if (server_) server_->Shutdown();
|
||||
}
|
||||
|
||||
/** 获取服务器统计信息 */
|
||||
struct ServerStats {
|
||||
int active_connections;
|
||||
long total_packets;
|
||||
long total_points;
|
||||
bool is_running;
|
||||
};
|
||||
|
||||
ServerStats get_stats() const {
|
||||
ServerStats stats;
|
||||
stats.active_connections = conn_manager_.active_count();
|
||||
stats.total_packets = total_packets_.load();
|
||||
stats.total_points = total_points_.load();
|
||||
stats.is_running = running_.load();
|
||||
return stats;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string listen_addr_;
|
||||
bool enable_tls_;
|
||||
std::atomic<bool> running_;
|
||||
ConnectionManager conn_manager_;
|
||||
PacketReorderer reorderer_;
|
||||
StrokeCallback stroke_callback_;
|
||||
std::atomic<long> total_packets_{0};
|
||||
std::atomic<long> total_points_{0};
|
||||
};
|
||||
|
||||
// ==================== MQTT状态上报客户端 ====================
|
||||
|
||||
/**
|
||||
* MQTT状态上报客户端
|
||||
* 定期向云平台上报算力盒运行状态
|
||||
* Topic: edgebox/{id}/status
|
||||
* 安全设计:MQTT over TLS加密传输
|
||||
*/
|
||||
class MqttReporter {
|
||||
public:
|
||||
MqttReporter(const std::string& broker_url, const std::string& device_id)
|
||||
: broker_url_(broker_url), device_id_(device_id), connected_(false) {}
|
||||
|
||||
/** 连接MQTT Broker(TLS加密) */
|
||||
bool connect() {
|
||||
// 实际环境使用Eclipse Paho MQTT C++ Client
|
||||
// mqtt::async_client client(broker_url_, device_id_);
|
||||
// mqtt::ssl_options ssl_opts;
|
||||
// ssl_opts.set_trust_store("/etc/ssl/ca.crt");
|
||||
// ssl_opts.set_key_store("/etc/ssl/client.crt");
|
||||
// ssl_opts.set_private_key("/etc/ssl/client.key");
|
||||
connected_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 上报设备状态 */
|
||||
void report_status(float gpu_usage, float temperature, float inference_qps,
|
||||
int queue_depth, long uptime_s) {
|
||||
if (!connected_) return;
|
||||
|
||||
std::string topic = "edgebox/" + device_id_ + "/status";
|
||||
// 构造JSON状态消息
|
||||
// {"gpu_usage": 45.2, "temperature": 62.5, "qps": 120.3, "queue": 5, "uptime": 3600}
|
||||
}
|
||||
|
||||
/** 接收远程指令 */
|
||||
void subscribe_commands() {
|
||||
std::string topic = "edgebox/" + device_id_ + "/command";
|
||||
// 订阅远程管理指令:重启、模型切换、OTA升级等
|
||||
}
|
||||
|
||||
/** 断开连接 */
|
||||
void disconnect() {
|
||||
connected_ = false;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string broker_url_;
|
||||
std::string device_id_;
|
||||
bool connected_;
|
||||
};
|
||||
|
||||
// ==================== 离线结果缓存 ====================
|
||||
|
||||
/**
|
||||
* 离线结果缓存
|
||||
* 断网期间推理结果暂存到本地SQLite数据库
|
||||
* 网络恢复后自动批量上传至云端
|
||||
* 安全设计:通信安全保障数据完整性
|
||||
*/
|
||||
class OfflineResultCache {
|
||||
public:
|
||||
OfflineResultCache(const std::string& db_path, int max_size_mb = 256)
|
||||
: db_path_(db_path), max_size_mb_(max_size_mb), cached_count_(0) {}
|
||||
|
||||
/** 初始化SQLite数据库 */
|
||||
bool initialize() {
|
||||
// CREATE TABLE IF NOT EXISTS offline_results (
|
||||
// id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
// packet_id TEXT NOT NULL,
|
||||
// result_type TEXT NOT NULL,
|
||||
// result_json TEXT NOT NULL,
|
||||
// created_at INTEGER NOT NULL,
|
||||
// uploaded INTEGER DEFAULT 0
|
||||
// );
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 缓存推理结果 */
|
||||
bool cache_result(const std::string& packet_id, const std::string& type,
|
||||
const std::string& result_json) {
|
||||
// INSERT INTO offline_results (packet_id, result_type, result_json, created_at)
|
||||
// VALUES (?, ?, ?, strftime('%s', 'now'));
|
||||
cached_count_++;
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 获取待上传的缓存结果 */
|
||||
std::vector<std::string> get_pending_results(int limit = 100) {
|
||||
// SELECT * FROM offline_results WHERE uploaded = 0 ORDER BY created_at LIMIT ?
|
||||
return {};
|
||||
}
|
||||
|
||||
/** 标记结果已上传 */
|
||||
void mark_uploaded(const std::vector<int>& ids) {
|
||||
// UPDATE offline_results SET uploaded = 1 WHERE id IN (...)
|
||||
}
|
||||
|
||||
/** 清理已上传的旧数据 */
|
||||
void cleanup(int retention_days = 7) {
|
||||
// DELETE FROM offline_results WHERE uploaded = 1 AND created_at < ?
|
||||
}
|
||||
|
||||
int cached_count() const { return cached_count_; }
|
||||
|
||||
private:
|
||||
std::string db_path_;
|
||||
int max_size_mb_;
|
||||
int cached_count_;
|
||||
};
|
||||
|
||||
// ==================== 集群管理器 ====================
|
||||
|
||||
/**
|
||||
* 多算力盒集群管理器
|
||||
* 通过mDNS服务发现同一校园网内的其他算力盒
|
||||
* 实现负载均衡调度:当本机推理队列过长时,分发至空闲节点
|
||||
*/
|
||||
class ClusterManager {
|
||||
public:
|
||||
struct ClusterNode {
|
||||
std::string node_id; // 节点ID
|
||||
std::string address; // gRPC地址
|
||||
float load_factor; // 负载因子(0-1)
|
||||
bool is_self; // 是否为本机
|
||||
std::chrono::steady_clock::time_point last_seen;
|
||||
};
|
||||
|
||||
ClusterManager(const std::string& self_id) : self_id_(self_id) {}
|
||||
|
||||
/** 启动mDNS服务注册和发现 */
|
||||
bool start_discovery() {
|
||||
// 注册本机mDNS服务
|
||||
// _writech-edgebox._tcp.local.
|
||||
// 定期扫描同网段其他算力盒
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 选择最优节点处理推理任务 */
|
||||
std::string select_best_node() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::string best_id = self_id_;
|
||||
float min_load = 1.0f;
|
||||
|
||||
for (const auto& pair : nodes_) {
|
||||
if (pair.second.load_factor < min_load) {
|
||||
min_load = pair.second.load_factor;
|
||||
best_id = pair.first;
|
||||
}
|
||||
}
|
||||
return best_id;
|
||||
}
|
||||
|
||||
/** 更新本机负载因子 */
|
||||
void update_self_load(float load) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (nodes_.count(self_id_)) {
|
||||
nodes_[self_id_].load_factor = load;
|
||||
}
|
||||
}
|
||||
|
||||
int cluster_size() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
return static_cast<int>(nodes_.size());
|
||||
}
|
||||
|
||||
private:
|
||||
std::string self_id_;
|
||||
std::unordered_map<std::string, ClusterNode> nodes_;
|
||||
mutable std::mutex mutex_;
|
||||
};
|
||||
|
||||
#endif // GRPC_SERVER_H
|
||||
Reference in New Issue
Block a user