Files
system-design/software-copyright/05-writech-edge-box/communication/grpc_server.cpp
T
2026-03-22 15:24:40 +08:00

501 lines
16 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* gRPC通信服务模块 - 与教室网关的笔迹数据交互
*
* 实现gRPC流式服务,接收网关转发的笔迹数据流
* 支持mTLS双向认证确保通信安全
*/
#ifndef GRPC_SERVER_H
#define GRPC_SERVER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <thread>
#include <functional>
#include <unordered_map>
#include <chrono>
#include <queue>
// ==================== gRPC消息结构 ====================
/** 笔迹坐标点(对应protobuf消息) */
struct GrpcStrokePoint {
float x;
float y;
float pressure;
uint32_t timestamp;
bool pen_up;
};
/** 笔迹数据包(对应protobuf消息) */
struct GrpcStrokePacket {
std::string packet_id; // 数据包ID
std::string pen_id; // 笔设备MAC地址
std::string student_id; // 学生ID
std::string page_id; // 点阵码页面ID
std::vector<GrpcStrokePoint> points; // 坐标点序列
uint64_t gateway_timestamp; // 网关转发时间戳
int sequence_number; // 包序号(用于乱序检测)
};
/** 识别结果响应 */
struct GrpcRecognitionResponse {
std::string packet_id; // 对应的请求包ID
std::string recognition_type; // 识别类型(ocr/math/stroke_order
bool success; // 是否成功
std::string result_text; // 识别结果文本
float confidence; // 置信度
float processing_time_ms; // 处理耗时
std::string model_version; // 使用的模型版本
};
// ==================== 连接管理器 ====================
/** 客户端连接信息 */
struct ClientConnection {
std::string client_id; // 客户端标识(网关ID
std::string client_addr; // 客户端地址
std::string cert_fingerprint; // 客户端证书指纹(mTLS
std::chrono::steady_clock::time_point connected_at;
std::chrono::steady_clock::time_point last_active;
long packets_received; // 已接收数据包数
long bytes_received; // 已接收字节数
bool authenticated; // 是否已通过mTLS认证
};
/**
* gRPC连接管理器
* 管理与多个教室网关的gRPC连接
* 每个网关对应一个持久化的gRPC流式连接
*/
class ConnectionManager {
public:
ConnectionManager(int max_connections = 100)
: max_connections_(max_connections) {}
/** 注册新连接 */
bool register_connection(const std::string& client_id, const std::string& addr,
const std::string& cert_fp) {
std::lock_guard<std::mutex> lock(mutex_);
if (static_cast<int>(connections_.size()) >= max_connections_) {
return false; // 达到最大连接数限制
}
ClientConnection conn;
conn.client_id = client_id;
conn.client_addr = addr;
conn.cert_fingerprint = cert_fp;
conn.connected_at = std::chrono::steady_clock::now();
conn.last_active = conn.connected_at;
conn.packets_received = 0;
conn.bytes_received = 0;
conn.authenticated = !cert_fp.empty();
connections_[client_id] = conn;
return true;
}
/** 移除连接 */
void remove_connection(const std::string& client_id) {
std::lock_guard<std::mutex> lock(mutex_);
connections_.erase(client_id);
}
/** 更新连接活跃时间 */
void update_activity(const std::string& client_id, long bytes) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = connections_.find(client_id);
if (it != connections_.end()) {
it->second.last_active = std::chrono::steady_clock::now();
it->second.packets_received++;
it->second.bytes_received += bytes;
}
}
/** 检查空闲超时连接 */
std::vector<std::string> check_idle_connections(int timeout_s = 300) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> idle;
auto now = std::chrono::steady_clock::now();
for (const auto& pair : connections_) {
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
now - pair.second.last_active).count();
if (elapsed > timeout_s) {
idle.push_back(pair.first);
}
}
return idle;
}
/** 获取当前连接数 */
int active_count() const {
std::lock_guard<std::mutex> lock(mutex_);
return static_cast<int>(connections_.size());
}
/** 获取所有连接状态 */
std::vector<ClientConnection> get_all_connections() const {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<ClientConnection> result;
for (const auto& pair : connections_) {
result.push_back(pair.second);
}
return result;
}
private:
std::unordered_map<std::string, ClientConnection> connections_;
mutable std::mutex mutex_;
int max_connections_;
};
// ==================== 数据包排序器 ====================
/**
* 数据包排序器
* 网络传输可能导致数据包乱序到达
* 使用滑动窗口机制对数据包进行重排序
*/
class PacketReorderer {
public:
PacketReorderer(int window_size = 16) : window_size_(window_size), expected_seq_(0) {}
/**
* 提交数据包到排序窗口
* 如果是期望的下一个序号则直接输出
* 否则缓存等待前序包到达
*/
std::vector<GrpcStrokePacket> submit(const GrpcStrokePacket& packet) {
std::vector<GrpcStrokePacket> output;
if (packet.sequence_number == expected_seq_) {
// 正好是期望的下一个包
output.push_back(packet);
expected_seq_++;
// 检查缓存中是否有后续连续的包
while (buffer_.count(expected_seq_) > 0) {
output.push_back(buffer_[expected_seq_]);
buffer_.erase(expected_seq_);
expected_seq_++;
}
} else if (packet.sequence_number > expected_seq_) {
// 后序包先到达,缓存等待
buffer_[packet.sequence_number] = packet;
// 缓存过大时强制输出最旧的包
if (static_cast<int>(buffer_.size()) > window_size_) {
auto it = buffer_.begin();
output.push_back(it->second);
expected_seq_ = it->first + 1;
buffer_.erase(it);
}
}
// 过期的旧包直接丢弃
return output;
}
void reset() {
buffer_.clear();
expected_seq_ = 0;
}
private:
std::map<int, GrpcStrokePacket> buffer_;
int window_size_;
int expected_seq_;
};
// ==================== gRPC服务实现 ====================
/**
* gRPC笔迹接收服务
* 实现InferenceService.ProcessStroke流式RPC
* 接收网关推送的笔迹数据流,送入推理引擎处理
*
* 安全设计:
* - gRPC启用mTLS双向认证
* - 请求大小限制防恶意攻击
* - 连接数限制防DoS
*/
class GrpcStrokeServer {
public:
using StrokeCallback = std::function<void(const GrpcStrokePacket&)>;
GrpcStrokeServer(const std::string& listen_addr = "0.0.0.0:50052",
bool enable_tls = true)
: listen_addr_(listen_addr), enable_tls_(enable_tls),
running_(false), conn_manager_(100) {}
/**
* 设置笔迹数据接收回调
* 当收到网关的笔迹数据时调用此回调
*/
void set_stroke_callback(StrokeCallback callback) {
stroke_callback_ = std::move(callback);
}
/**
* 启动gRPC服务器
* 加载TLS证书,绑定端口,开始监听
*/
bool start() {
if (enable_tls_) {
// 加载mTLS证书(安全设计:gRPC启用mTLS双向认证)
// grpc::SslServerCredentialsOptions ssl_opts;
// ssl_opts.pem_root_certs = load_file("/etc/ssl/ca.crt");
// ssl_opts.pem_key_cert_pairs.push_back({
// load_file("/etc/ssl/server.key"),
// load_file("/etc/ssl/server.crt")
// });
// ssl_opts.client_certificate_request = GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY;
}
// 构建并启动gRPC服务器
// grpc::ServerBuilder builder;
// builder.AddListeningPort(listen_addr_, credentials);
// builder.RegisterService(this);
// builder.SetMaxReceiveMessageSize(10 * 1024 * 1024); // 10MB最大消息
// server_ = builder.BuildAndStart();
running_ = true;
return true;
}
/**
* ProcessStroke RPC实现
* 接收网关的流式笔迹数据,处理后返回识别结果流
*/
void ProcessStroke(const GrpcStrokePacket& packet) {
// 更新连接活跃状态
conn_manager_.update_activity(packet.pen_id, packet.points.size() * 16);
// 数据包排序
auto ordered = reorderer_.submit(packet);
// 处理排序后的数据包
for (const auto& p : ordered) {
total_packets_++;
total_points_ += static_cast<long>(p.points.size());
// 调用回调函数送入推理引擎
if (stroke_callback_) {
stroke_callback_(p);
}
}
}
/** 停止服务器 */
void stop() {
running_ = false;
// if (server_) server_->Shutdown();
}
/** 获取服务器统计信息 */
struct ServerStats {
int active_connections;
long total_packets;
long total_points;
bool is_running;
};
ServerStats get_stats() const {
ServerStats stats;
stats.active_connections = conn_manager_.active_count();
stats.total_packets = total_packets_.load();
stats.total_points = total_points_.load();
stats.is_running = running_.load();
return stats;
}
private:
std::string listen_addr_;
bool enable_tls_;
std::atomic<bool> running_;
ConnectionManager conn_manager_;
PacketReorderer reorderer_;
StrokeCallback stroke_callback_;
std::atomic<long> total_packets_{0};
std::atomic<long> total_points_{0};
};
// ==================== MQTT状态上报客户端 ====================
/**
* MQTT状态上报客户端
* 定期向云平台上报算力盒运行状态
* Topic: edgebox/{id}/status
* 安全设计:MQTT over TLS加密传输
*/
class MqttReporter {
public:
MqttReporter(const std::string& broker_url, const std::string& device_id)
: broker_url_(broker_url), device_id_(device_id), connected_(false) {}
/** 连接MQTT BrokerTLS加密) */
bool connect() {
// 实际环境使用Eclipse Paho MQTT C++ Client
// mqtt::async_client client(broker_url_, device_id_);
// mqtt::ssl_options ssl_opts;
// ssl_opts.set_trust_store("/etc/ssl/ca.crt");
// ssl_opts.set_key_store("/etc/ssl/client.crt");
// ssl_opts.set_private_key("/etc/ssl/client.key");
connected_ = true;
return true;
}
/** 上报设备状态 */
void report_status(float gpu_usage, float temperature, float inference_qps,
int queue_depth, long uptime_s) {
if (!connected_) return;
std::string topic = "edgebox/" + device_id_ + "/status";
// 构造JSON状态消息
// {"gpu_usage": 45.2, "temperature": 62.5, "qps": 120.3, "queue": 5, "uptime": 3600}
}
/** 接收远程指令 */
void subscribe_commands() {
std::string topic = "edgebox/" + device_id_ + "/command";
// 订阅远程管理指令:重启、模型切换、OTA升级等
}
/** 断开连接 */
void disconnect() {
connected_ = false;
}
private:
std::string broker_url_;
std::string device_id_;
bool connected_;
};
// ==================== 离线结果缓存 ====================
/**
* 离线结果缓存
* 断网期间推理结果暂存到本地SQLite数据库
* 网络恢复后自动批量上传至云端
* 安全设计:通信安全保障数据完整性
*/
class OfflineResultCache {
public:
OfflineResultCache(const std::string& db_path, int max_size_mb = 256)
: db_path_(db_path), max_size_mb_(max_size_mb), cached_count_(0) {}
/** 初始化SQLite数据库 */
bool initialize() {
// CREATE TABLE IF NOT EXISTS offline_results (
// id INTEGER PRIMARY KEY AUTOINCREMENT,
// packet_id TEXT NOT NULL,
// result_type TEXT NOT NULL,
// result_json TEXT NOT NULL,
// created_at INTEGER NOT NULL,
// uploaded INTEGER DEFAULT 0
// );
return true;
}
/** 缓存推理结果 */
bool cache_result(const std::string& packet_id, const std::string& type,
const std::string& result_json) {
// INSERT INTO offline_results (packet_id, result_type, result_json, created_at)
// VALUES (?, ?, ?, strftime('%s', 'now'));
cached_count_++;
return true;
}
/** 获取待上传的缓存结果 */
std::vector<std::string> get_pending_results(int limit = 100) {
// SELECT * FROM offline_results WHERE uploaded = 0 ORDER BY created_at LIMIT ?
return {};
}
/** 标记结果已上传 */
void mark_uploaded(const std::vector<int>& ids) {
// UPDATE offline_results SET uploaded = 1 WHERE id IN (...)
}
/** 清理已上传的旧数据 */
void cleanup(int retention_days = 7) {
// DELETE FROM offline_results WHERE uploaded = 1 AND created_at < ?
}
int cached_count() const { return cached_count_; }
private:
std::string db_path_;
int max_size_mb_;
int cached_count_;
};
// ==================== 集群管理器 ====================
/**
* 多算力盒集群管理器
* 通过mDNS服务发现同一校园网内的其他算力盒
* 实现负载均衡调度:当本机推理队列过长时,分发至空闲节点
*/
class ClusterManager {
public:
struct ClusterNode {
std::string node_id; // 节点ID
std::string address; // gRPC地址
float load_factor; // 负载因子(0-1)
bool is_self; // 是否为本机
std::chrono::steady_clock::time_point last_seen;
};
ClusterManager(const std::string& self_id) : self_id_(self_id) {}
/** 启动mDNS服务注册和发现 */
bool start_discovery() {
// 注册本机mDNS服务
// _writech-edgebox._tcp.local.
// 定期扫描同网段其他算力盒
return true;
}
/** 选择最优节点处理推理任务 */
std::string select_best_node() {
std::lock_guard<std::mutex> lock(mutex_);
std::string best_id = self_id_;
float min_load = 1.0f;
for (const auto& pair : nodes_) {
if (pair.second.load_factor < min_load) {
min_load = pair.second.load_factor;
best_id = pair.first;
}
}
return best_id;
}
/** 更新本机负载因子 */
void update_self_load(float load) {
std::lock_guard<std::mutex> lock(mutex_);
if (nodes_.count(self_id_)) {
nodes_[self_id_].load_factor = load;
}
}
int cluster_size() const {
std::lock_guard<std::mutex> lock(mutex_);
return static_cast<int>(nodes_.size());
}
private:
std::string self_id_;
std::unordered_map<std::string, ClusterNode> nodes_;
mutable std::mutex mutex_;
};
#endif // GRPC_SERVER_H