501 lines
16 KiB
C++
501 lines
16 KiB
C++
/**
|
||
* 自然写教室智能算力盒边缘计算软件 V1.0
|
||
* gRPC通信服务模块 - 与教室网关的笔迹数据交互
|
||
*
|
||
* 实现gRPC流式服务,接收网关转发的笔迹数据流
|
||
* 支持mTLS双向认证确保通信安全
|
||
*/
|
||
|
||
#ifndef GRPC_SERVER_H
|
||
#define GRPC_SERVER_H
|
||
|
||
#include <string>
|
||
#include <vector>
|
||
#include <memory>
|
||
#include <mutex>
|
||
#include <atomic>
|
||
#include <thread>
|
||
#include <functional>
|
||
#include <unordered_map>
|
||
#include <chrono>
|
||
#include <queue>
|
||
|
||
// ==================== gRPC消息结构 ====================
|
||
|
||
/** 笔迹坐标点(对应protobuf消息) */
|
||
struct GrpcStrokePoint {
|
||
float x;
|
||
float y;
|
||
float pressure;
|
||
uint32_t timestamp;
|
||
bool pen_up;
|
||
};
|
||
|
||
/** 笔迹数据包(对应protobuf消息) */
|
||
struct GrpcStrokePacket {
|
||
std::string packet_id; // 数据包ID
|
||
std::string pen_id; // 笔设备MAC地址
|
||
std::string student_id; // 学生ID
|
||
std::string page_id; // 点阵码页面ID
|
||
std::vector<GrpcStrokePoint> points; // 坐标点序列
|
||
uint64_t gateway_timestamp; // 网关转发时间戳
|
||
int sequence_number; // 包序号(用于乱序检测)
|
||
};
|
||
|
||
/** 识别结果响应 */
|
||
struct GrpcRecognitionResponse {
|
||
std::string packet_id; // 对应的请求包ID
|
||
std::string recognition_type; // 识别类型(ocr/math/stroke_order)
|
||
bool success; // 是否成功
|
||
std::string result_text; // 识别结果文本
|
||
float confidence; // 置信度
|
||
float processing_time_ms; // 处理耗时
|
||
std::string model_version; // 使用的模型版本
|
||
};
|
||
|
||
// ==================== 连接管理器 ====================
|
||
|
||
/** 客户端连接信息 */
|
||
struct ClientConnection {
|
||
std::string client_id; // 客户端标识(网关ID)
|
||
std::string client_addr; // 客户端地址
|
||
std::string cert_fingerprint; // 客户端证书指纹(mTLS)
|
||
std::chrono::steady_clock::time_point connected_at;
|
||
std::chrono::steady_clock::time_point last_active;
|
||
long packets_received; // 已接收数据包数
|
||
long bytes_received; // 已接收字节数
|
||
bool authenticated; // 是否已通过mTLS认证
|
||
};
|
||
|
||
/**
|
||
* gRPC连接管理器
|
||
* 管理与多个教室网关的gRPC连接
|
||
* 每个网关对应一个持久化的gRPC流式连接
|
||
*/
|
||
class ConnectionManager {
|
||
public:
|
||
ConnectionManager(int max_connections = 100)
|
||
: max_connections_(max_connections) {}
|
||
|
||
/** 注册新连接 */
|
||
bool register_connection(const std::string& client_id, const std::string& addr,
|
||
const std::string& cert_fp) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
if (static_cast<int>(connections_.size()) >= max_connections_) {
|
||
return false; // 达到最大连接数限制
|
||
}
|
||
|
||
ClientConnection conn;
|
||
conn.client_id = client_id;
|
||
conn.client_addr = addr;
|
||
conn.cert_fingerprint = cert_fp;
|
||
conn.connected_at = std::chrono::steady_clock::now();
|
||
conn.last_active = conn.connected_at;
|
||
conn.packets_received = 0;
|
||
conn.bytes_received = 0;
|
||
conn.authenticated = !cert_fp.empty();
|
||
|
||
connections_[client_id] = conn;
|
||
return true;
|
||
}
|
||
|
||
/** 移除连接 */
|
||
void remove_connection(const std::string& client_id) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
connections_.erase(client_id);
|
||
}
|
||
|
||
/** 更新连接活跃时间 */
|
||
void update_activity(const std::string& client_id, long bytes) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
auto it = connections_.find(client_id);
|
||
if (it != connections_.end()) {
|
||
it->second.last_active = std::chrono::steady_clock::now();
|
||
it->second.packets_received++;
|
||
it->second.bytes_received += bytes;
|
||
}
|
||
}
|
||
|
||
/** 检查空闲超时连接 */
|
||
std::vector<std::string> check_idle_connections(int timeout_s = 300) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
std::vector<std::string> idle;
|
||
auto now = std::chrono::steady_clock::now();
|
||
|
||
for (const auto& pair : connections_) {
|
||
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
|
||
now - pair.second.last_active).count();
|
||
if (elapsed > timeout_s) {
|
||
idle.push_back(pair.first);
|
||
}
|
||
}
|
||
return idle;
|
||
}
|
||
|
||
/** 获取当前连接数 */
|
||
int active_count() const {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
return static_cast<int>(connections_.size());
|
||
}
|
||
|
||
/** 获取所有连接状态 */
|
||
std::vector<ClientConnection> get_all_connections() const {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
std::vector<ClientConnection> result;
|
||
for (const auto& pair : connections_) {
|
||
result.push_back(pair.second);
|
||
}
|
||
return result;
|
||
}
|
||
|
||
private:
|
||
std::unordered_map<std::string, ClientConnection> connections_;
|
||
mutable std::mutex mutex_;
|
||
int max_connections_;
|
||
};
|
||
|
||
// ==================== 数据包排序器 ====================
|
||
|
||
/**
|
||
* 数据包排序器
|
||
* 网络传输可能导致数据包乱序到达
|
||
* 使用滑动窗口机制对数据包进行重排序
|
||
*/
|
||
class PacketReorderer {
|
||
public:
|
||
PacketReorderer(int window_size = 16) : window_size_(window_size), expected_seq_(0) {}
|
||
|
||
/**
|
||
* 提交数据包到排序窗口
|
||
* 如果是期望的下一个序号则直接输出
|
||
* 否则缓存等待前序包到达
|
||
*/
|
||
std::vector<GrpcStrokePacket> submit(const GrpcStrokePacket& packet) {
|
||
std::vector<GrpcStrokePacket> output;
|
||
|
||
if (packet.sequence_number == expected_seq_) {
|
||
// 正好是期望的下一个包
|
||
output.push_back(packet);
|
||
expected_seq_++;
|
||
|
||
// 检查缓存中是否有后续连续的包
|
||
while (buffer_.count(expected_seq_) > 0) {
|
||
output.push_back(buffer_[expected_seq_]);
|
||
buffer_.erase(expected_seq_);
|
||
expected_seq_++;
|
||
}
|
||
} else if (packet.sequence_number > expected_seq_) {
|
||
// 后序包先到达,缓存等待
|
||
buffer_[packet.sequence_number] = packet;
|
||
|
||
// 缓存过大时强制输出最旧的包
|
||
if (static_cast<int>(buffer_.size()) > window_size_) {
|
||
auto it = buffer_.begin();
|
||
output.push_back(it->second);
|
||
expected_seq_ = it->first + 1;
|
||
buffer_.erase(it);
|
||
}
|
||
}
|
||
// 过期的旧包直接丢弃
|
||
|
||
return output;
|
||
}
|
||
|
||
void reset() {
|
||
buffer_.clear();
|
||
expected_seq_ = 0;
|
||
}
|
||
|
||
private:
|
||
std::map<int, GrpcStrokePacket> buffer_;
|
||
int window_size_;
|
||
int expected_seq_;
|
||
};
|
||
|
||
// ==================== gRPC服务实现 ====================
|
||
|
||
/**
|
||
* gRPC笔迹接收服务
|
||
* 实现InferenceService.ProcessStroke流式RPC
|
||
* 接收网关推送的笔迹数据流,送入推理引擎处理
|
||
*
|
||
* 安全设计:
|
||
* - gRPC启用mTLS双向认证
|
||
* - 请求大小限制防恶意攻击
|
||
* - 连接数限制防DoS
|
||
*/
|
||
class GrpcStrokeServer {
|
||
public:
|
||
using StrokeCallback = std::function<void(const GrpcStrokePacket&)>;
|
||
|
||
GrpcStrokeServer(const std::string& listen_addr = "0.0.0.0:50052",
|
||
bool enable_tls = true)
|
||
: listen_addr_(listen_addr), enable_tls_(enable_tls),
|
||
running_(false), conn_manager_(100) {}
|
||
|
||
/**
|
||
* 设置笔迹数据接收回调
|
||
* 当收到网关的笔迹数据时调用此回调
|
||
*/
|
||
void set_stroke_callback(StrokeCallback callback) {
|
||
stroke_callback_ = std::move(callback);
|
||
}
|
||
|
||
/**
|
||
* 启动gRPC服务器
|
||
* 加载TLS证书,绑定端口,开始监听
|
||
*/
|
||
bool start() {
|
||
if (enable_tls_) {
|
||
// 加载mTLS证书(安全设计:gRPC启用mTLS双向认证)
|
||
// grpc::SslServerCredentialsOptions ssl_opts;
|
||
// ssl_opts.pem_root_certs = load_file("/etc/ssl/ca.crt");
|
||
// ssl_opts.pem_key_cert_pairs.push_back({
|
||
// load_file("/etc/ssl/server.key"),
|
||
// load_file("/etc/ssl/server.crt")
|
||
// });
|
||
// ssl_opts.client_certificate_request = GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
|
||
}
|
||
|
||
// 构建并启动gRPC服务器
|
||
// grpc::ServerBuilder builder;
|
||
// builder.AddListeningPort(listen_addr_, credentials);
|
||
// builder.RegisterService(this);
|
||
// builder.SetMaxReceiveMessageSize(10 * 1024 * 1024); // 10MB最大消息
|
||
// server_ = builder.BuildAndStart();
|
||
|
||
running_ = true;
|
||
return true;
|
||
}
|
||
|
||
/**
|
||
* ProcessStroke RPC实现
|
||
* 接收网关的流式笔迹数据,处理后返回识别结果流
|
||
*/
|
||
void ProcessStroke(const GrpcStrokePacket& packet) {
|
||
// 更新连接活跃状态
|
||
conn_manager_.update_activity(packet.pen_id, packet.points.size() * 16);
|
||
|
||
// 数据包排序
|
||
auto ordered = reorderer_.submit(packet);
|
||
|
||
// 处理排序后的数据包
|
||
for (const auto& p : ordered) {
|
||
total_packets_++;
|
||
total_points_ += static_cast<long>(p.points.size());
|
||
|
||
// 调用回调函数送入推理引擎
|
||
if (stroke_callback_) {
|
||
stroke_callback_(p);
|
||
}
|
||
}
|
||
}
|
||
|
||
/** 停止服务器 */
|
||
void stop() {
|
||
running_ = false;
|
||
// if (server_) server_->Shutdown();
|
||
}
|
||
|
||
/** 获取服务器统计信息 */
|
||
struct ServerStats {
|
||
int active_connections;
|
||
long total_packets;
|
||
long total_points;
|
||
bool is_running;
|
||
};
|
||
|
||
ServerStats get_stats() const {
|
||
ServerStats stats;
|
||
stats.active_connections = conn_manager_.active_count();
|
||
stats.total_packets = total_packets_.load();
|
||
stats.total_points = total_points_.load();
|
||
stats.is_running = running_.load();
|
||
return stats;
|
||
}
|
||
|
||
private:
|
||
std::string listen_addr_;
|
||
bool enable_tls_;
|
||
std::atomic<bool> running_;
|
||
ConnectionManager conn_manager_;
|
||
PacketReorderer reorderer_;
|
||
StrokeCallback stroke_callback_;
|
||
std::atomic<long> total_packets_{0};
|
||
std::atomic<long> total_points_{0};
|
||
};
|
||
|
||
// ==================== MQTT状态上报客户端 ====================
|
||
|
||
/**
|
||
* MQTT状态上报客户端
|
||
* 定期向云平台上报算力盒运行状态
|
||
* Topic: edgebox/{id}/status
|
||
* 安全设计:MQTT over TLS加密传输
|
||
*/
|
||
class MqttReporter {
|
||
public:
|
||
MqttReporter(const std::string& broker_url, const std::string& device_id)
|
||
: broker_url_(broker_url), device_id_(device_id), connected_(false) {}
|
||
|
||
/** 连接MQTT Broker(TLS加密) */
|
||
bool connect() {
|
||
// 实际环境使用Eclipse Paho MQTT C++ Client
|
||
// mqtt::async_client client(broker_url_, device_id_);
|
||
// mqtt::ssl_options ssl_opts;
|
||
// ssl_opts.set_trust_store("/etc/ssl/ca.crt");
|
||
// ssl_opts.set_key_store("/etc/ssl/client.crt");
|
||
// ssl_opts.set_private_key("/etc/ssl/client.key");
|
||
connected_ = true;
|
||
return true;
|
||
}
|
||
|
||
/** 上报设备状态 */
|
||
void report_status(float gpu_usage, float temperature, float inference_qps,
|
||
int queue_depth, long uptime_s) {
|
||
if (!connected_) return;
|
||
|
||
std::string topic = "edgebox/" + device_id_ + "/status";
|
||
// 构造JSON状态消息
|
||
// {"gpu_usage": 45.2, "temperature": 62.5, "qps": 120.3, "queue": 5, "uptime": 3600}
|
||
}
|
||
|
||
/** 接收远程指令 */
|
||
void subscribe_commands() {
|
||
std::string topic = "edgebox/" + device_id_ + "/command";
|
||
// 订阅远程管理指令:重启、模型切换、OTA升级等
|
||
}
|
||
|
||
/** 断开连接 */
|
||
void disconnect() {
|
||
connected_ = false;
|
||
}
|
||
|
||
private:
|
||
std::string broker_url_;
|
||
std::string device_id_;
|
||
bool connected_;
|
||
};
|
||
|
||
// ==================== 离线结果缓存 ====================
|
||
|
||
/**
|
||
* 离线结果缓存
|
||
* 断网期间推理结果暂存到本地SQLite数据库
|
||
* 网络恢复后自动批量上传至云端
|
||
* 安全设计:通信安全保障数据完整性
|
||
*/
|
||
class OfflineResultCache {
|
||
public:
|
||
OfflineResultCache(const std::string& db_path, int max_size_mb = 256)
|
||
: db_path_(db_path), max_size_mb_(max_size_mb), cached_count_(0) {}
|
||
|
||
/** 初始化SQLite数据库 */
|
||
bool initialize() {
|
||
// CREATE TABLE IF NOT EXISTS offline_results (
|
||
// id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
// packet_id TEXT NOT NULL,
|
||
// result_type TEXT NOT NULL,
|
||
// result_json TEXT NOT NULL,
|
||
// created_at INTEGER NOT NULL,
|
||
// uploaded INTEGER DEFAULT 0
|
||
// );
|
||
return true;
|
||
}
|
||
|
||
/** 缓存推理结果 */
|
||
bool cache_result(const std::string& packet_id, const std::string& type,
|
||
const std::string& result_json) {
|
||
// INSERT INTO offline_results (packet_id, result_type, result_json, created_at)
|
||
// VALUES (?, ?, ?, strftime('%s', 'now'));
|
||
cached_count_++;
|
||
return true;
|
||
}
|
||
|
||
/** 获取待上传的缓存结果 */
|
||
std::vector<std::string> get_pending_results(int limit = 100) {
|
||
// SELECT * FROM offline_results WHERE uploaded = 0 ORDER BY created_at LIMIT ?
|
||
return {};
|
||
}
|
||
|
||
/** 标记结果已上传 */
|
||
void mark_uploaded(const std::vector<int>& ids) {
|
||
// UPDATE offline_results SET uploaded = 1 WHERE id IN (...)
|
||
}
|
||
|
||
/** 清理已上传的旧数据 */
|
||
void cleanup(int retention_days = 7) {
|
||
// DELETE FROM offline_results WHERE uploaded = 1 AND created_at < ?
|
||
}
|
||
|
||
int cached_count() const { return cached_count_; }
|
||
|
||
private:
|
||
std::string db_path_;
|
||
int max_size_mb_;
|
||
int cached_count_;
|
||
};
|
||
|
||
// ==================== 集群管理器 ====================
|
||
|
||
/**
|
||
* 多算力盒集群管理器
|
||
* 通过mDNS服务发现同一校园网内的其他算力盒
|
||
* 实现负载均衡调度:当本机推理队列过长时,分发至空闲节点
|
||
*/
|
||
class ClusterManager {
|
||
public:
|
||
struct ClusterNode {
|
||
std::string node_id; // 节点ID
|
||
std::string address; // gRPC地址
|
||
float load_factor; // 负载因子(0-1)
|
||
bool is_self; // 是否为本机
|
||
std::chrono::steady_clock::time_point last_seen;
|
||
};
|
||
|
||
ClusterManager(const std::string& self_id) : self_id_(self_id) {}
|
||
|
||
/** 启动mDNS服务注册和发现 */
|
||
bool start_discovery() {
|
||
// 注册本机mDNS服务
|
||
// _writech-edgebox._tcp.local.
|
||
// 定期扫描同网段其他算力盒
|
||
return true;
|
||
}
|
||
|
||
/** 选择最优节点处理推理任务 */
|
||
std::string select_best_node() {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
std::string best_id = self_id_;
|
||
float min_load = 1.0f;
|
||
|
||
for (const auto& pair : nodes_) {
|
||
if (pair.second.load_factor < min_load) {
|
||
min_load = pair.second.load_factor;
|
||
best_id = pair.first;
|
||
}
|
||
}
|
||
return best_id;
|
||
}
|
||
|
||
/** 更新本机负载因子 */
|
||
void update_self_load(float load) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
if (nodes_.count(self_id_)) {
|
||
nodes_[self_id_].load_factor = load;
|
||
}
|
||
}
|
||
|
||
int cluster_size() const {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
return static_cast<int>(nodes_.size());
|
||
}
|
||
|
||
private:
|
||
std::string self_id_;
|
||
std::unordered_map<std::string, ClusterNode> nodes_;
|
||
mutable std::mutex mutex_;
|
||
};
|
||
|
||
#endif // GRPC_SERVER_H
|