444 lines
13 KiB
C++
444 lines
13 KiB
C++
/**
|
||
* 自然写教室智能算力盒边缘计算软件 V1.0
|
||
* 模型管理模块 - 模型加载、版本管理、量化压缩、云端同步
|
||
*
|
||
* 管理算力盒上部署的所有AI推理模型的生命周期
|
||
* 支持模型热更新、A/B切换、云端版本同步
|
||
* 模型文件AES-256加密存储,推理时内存解密加载
|
||
*/
|
||
|
||
#ifndef MODEL_MANAGER_H
|
||
#define MODEL_MANAGER_H
|
||
|
||
#include <string>
|
||
#include <vector>
|
||
#include <memory>
|
||
#include <mutex>
|
||
#include <atomic>
|
||
#include <unordered_map>
|
||
#include <chrono>
|
||
#include <functional>
|
||
|
||
// ==================== 模型元信息 ====================
|
||
|
||
/** 模型状态枚举 */
|
||
enum class ModelState {
|
||
NOT_FOUND = 0, // 未发现
|
||
DOWNLOADING = 1, // 下载中
|
||
DECRYPTING = 2, // 解密中
|
||
LOADING = 3, // 加载到设备中
|
||
READY = 4, // 就绪可用
|
||
ACTIVE = 5, // 当前使用中
|
||
DEPRECATED = 6, // 已弃用
|
||
ERROR = 7 // 错误状态
|
||
};
|
||
|
||
/** 模型量化精度 */
|
||
enum class QuantizationType {
|
||
FP32 = 0, // 全精度浮点
|
||
FP16 = 1, // 半精度浮点
|
||
INT8 = 2, // 8位整型量化
|
||
INT4 = 3 // 4位整型量化(极致压缩)
|
||
};
|
||
|
||
/** 模型元信息 */
|
||
struct ModelInfo {
|
||
std::string name; // 模型名称
|
||
std::string version; // 版本号(语义化版本)
|
||
std::string format; // 格式(onnx/trt/rknn)
|
||
std::string file_path; // 本地文件路径
|
||
size_t file_size_bytes; // 文件大小
|
||
std::string sha256; // 文件SHA-256校验和
|
||
QuantizationType quantization; // 量化类型
|
||
float accuracy; // 测试集准确率
|
||
float latency_ms; // 平均推理延迟
|
||
ModelState state; // 当前状态
|
||
std::string deployed_at; // 部署时间
|
||
std::string description; // 模型描述
|
||
};
|
||
|
||
// ==================== 模型加密管理 ====================
|
||
|
||
/**
|
||
* 模型文件加密/解密管理器
|
||
* 安全设计:模型文件AES-256加密存储,推理时内存解密加载
|
||
* 加密密钥通过安全芯片(TPM)或环境变量注入
|
||
*/
|
||
class ModelCryptoManager {
|
||
public:
|
||
ModelCryptoManager() : key_loaded_(false) {}
|
||
|
||
/**
|
||
* 加载加密密钥
|
||
* 优先从安全芯片读取,其次从环境变量
|
||
*/
|
||
bool load_encryption_key() {
|
||
// 尝试从TPM安全芯片读取密钥
|
||
// if (tpm_available()) { key_ = tpm_read_key("model_key"); }
|
||
|
||
// 后备方案:从环境变量读取
|
||
const char* env_key = std::getenv("WRITECH_MODEL_KEY");
|
||
if (env_key) {
|
||
key_ = std::string(env_key);
|
||
key_loaded_ = true;
|
||
return true;
|
||
}
|
||
return false;
|
||
}
|
||
|
||
/**
|
||
* 解密模型文件到内存
|
||
* 不在磁盘上生成明文文件,仅在内存中解密
|
||
*/
|
||
std::vector<uint8_t> decrypt_model(const std::string& encrypted_path) {
|
||
std::vector<uint8_t> decrypted_data;
|
||
if (!key_loaded_) return decrypted_data;
|
||
|
||
// 读取加密文件
|
||
// AES-256-CBC解密
|
||
// openssl EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv);
|
||
// EVP_DecryptUpdate(ctx, output, &out_len, input, in_len);
|
||
// EVP_DecryptFinal_ex(ctx, output + out_len, &final_len);
|
||
|
||
return decrypted_data;
|
||
}
|
||
|
||
/**
|
||
* 加密模型文件
|
||
* 新下载的模型文件加密后存储到本地Flash
|
||
*/
|
||
bool encrypt_model(const std::vector<uint8_t>& data, const std::string& output_path) {
|
||
if (!key_loaded_) return false;
|
||
// AES-256-CBC加密并写入文件
|
||
return true;
|
||
}
|
||
|
||
/**
|
||
* 验证模型文件完整性
|
||
* 计算SHA-256校验和并与元数据中的值比对
|
||
*/
|
||
bool verify_integrity(const std::string& file_path, const std::string& expected_sha256) {
|
||
// 计算文件SHA-256
|
||
// SHA256_CTX sha256;
|
||
// SHA256_Init(&sha256);
|
||
// while (read chunk) SHA256_Update(&sha256, chunk, len);
|
||
// SHA256_Final(hash, &sha256);
|
||
return true;
|
||
}
|
||
|
||
private:
|
||
std::string key_;
|
||
bool key_loaded_;
|
||
};
|
||
|
||
// ==================== 模型版本管理器 ====================
|
||
|
||
/**
|
||
* 模型版本管理器
|
||
* 管理算力盒上所有AI模型的版本、加载、切换
|
||
* 支持A/B分区切换实现热更新
|
||
*/
|
||
class ModelVersionManager {
|
||
public:
|
||
ModelVersionManager(const std::string& models_dir)
|
||
: models_dir_(models_dir) {}
|
||
|
||
/**
|
||
* 注册模型
|
||
* 扫描模型目录,加载所有可用模型的元信息
|
||
*/
|
||
bool register_model(const ModelInfo& info) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
std::string key = info.name + "@" + info.version;
|
||
models_[key] = info;
|
||
return true;
|
||
}
|
||
|
||
/**
|
||
* 激活指定版本的模型
|
||
* 将旧版本标记为deprecated,新版本标记为active
|
||
*/
|
||
bool activate_version(const std::string& name, const std::string& version) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
|
||
// 将当前活跃版本设为deprecated
|
||
for (auto& pair : models_) {
|
||
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
|
||
pair.second.state = ModelState::DEPRECATED;
|
||
}
|
||
}
|
||
|
||
// 激活新版本
|
||
std::string key = name + "@" + version;
|
||
auto it = models_.find(key);
|
||
if (it != models_.end()) {
|
||
it->second.state = ModelState::ACTIVE;
|
||
return true;
|
||
}
|
||
return false;
|
||
}
|
||
|
||
/**
|
||
* 获取当前活跃版本的模型信息
|
||
*/
|
||
ModelInfo get_active_model(const std::string& name) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
for (const auto& pair : models_) {
|
||
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
|
||
return pair.second;
|
||
}
|
||
}
|
||
return ModelInfo{};
|
||
}
|
||
|
||
/**
|
||
* 获取所有模型状态列表
|
||
*/
|
||
std::vector<ModelInfo> get_all_models() {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
std::vector<ModelInfo> result;
|
||
for (const auto& pair : models_) {
|
||
result.push_back(pair.second);
|
||
}
|
||
return result;
|
||
}
|
||
|
||
/**
|
||
* 清理已废弃的旧版本模型文件
|
||
* 保留最近2个版本,删除更早的版本释放存储空间
|
||
*/
|
||
void cleanup_old_versions(const std::string& name, int keep_count = 2) {
|
||
std::lock_guard<std::mutex> lock(mutex_);
|
||
std::vector<std::string> deprecated_keys;
|
||
|
||
for (const auto& pair : models_) {
|
||
if (pair.second.name == name && pair.second.state == ModelState::DEPRECATED) {
|
||
deprecated_keys.push_back(pair.first);
|
||
}
|
||
}
|
||
|
||
// 按版本排序,保留最新的keep_count个
|
||
if (static_cast<int>(deprecated_keys.size()) > keep_count) {
|
||
for (int i = 0; i < static_cast<int>(deprecated_keys.size()) - keep_count; i++) {
|
||
// 删除模型文件并从注册表移除
|
||
models_.erase(deprecated_keys[i]);
|
||
}
|
||
}
|
||
}
|
||
|
||
private:
|
||
std::string models_dir_;
|
||
std::unordered_map<std::string, ModelInfo> models_;
|
||
std::mutex mutex_;
|
||
};
|
||
|
||
// ==================== 云端模型同步器 ====================
|
||
|
||
/**
|
||
* 云端模型同步器
|
||
* 定期检查云端是否有新版本模型,自动下载并部署
|
||
* 通过HTTPS加密通道下载,下载后RSA签名校验
|
||
*/
|
||
class CloudModelSyncer {
|
||
public:
|
||
CloudModelSyncer(const std::string& server_url, const std::string& device_id)
|
||
: server_url_(server_url), device_id_(device_id) {}
|
||
|
||
/**
|
||
* 检查云端是否有模型更新
|
||
* GET /api/v1/model/check-update?device_id=xxx&models=ocr@1.0,math@1.0
|
||
*/
|
||
struct UpdateInfo {
|
||
std::string model_name;
|
||
std::string new_version;
|
||
std::string download_url;
|
||
size_t file_size;
|
||
std::string sha256;
|
||
};
|
||
|
||
std::vector<UpdateInfo> check_updates(const std::vector<ModelInfo>& current_models) {
|
||
std::vector<UpdateInfo> updates;
|
||
// 向云端API发送当前模型版本列表,获取可更新版本
|
||
// HTTPS请求:GET server_url_/api/v1/model/check-update
|
||
return updates;
|
||
}
|
||
|
||
/**
|
||
* 下载模型文件
|
||
* HTTPS下载,支持断点续传
|
||
* 下载完成后进行SHA-256校验和RSA签名验证
|
||
*/
|
||
bool download_model(const UpdateInfo& info, const std::string& save_path) {
|
||
// HTTPS下载
|
||
// 进度回调上报
|
||
// SHA-256校验
|
||
// RSA签名验证(OTA安全:升级包RSA签名+SHA-256校验,防篡改)
|
||
return true;
|
||
}
|
||
|
||
/**
|
||
* 上报模型部署状态
|
||
* POST /api/v1/model/deploy-status
|
||
*/
|
||
void report_deploy_status(const std::string& model_name, const std::string& version,
|
||
bool success, const std::string& error = "") {
|
||
// 向云端上报模型部署结果
|
||
}
|
||
|
||
private:
|
||
std::string server_url_;
|
||
std::string device_id_;
|
||
};
|
||
|
||
// ==================== OTA固件升级管理器 ====================
|
||
|
||
/**
|
||
* OTA固件升级管理器
|
||
* 管理算力盒固件的远程升级
|
||
* 采用A/B双分区方案,升级失败自动回滚
|
||
* 安全设计:升级包RSA签名+SHA-256校验,防篡改
|
||
*/
|
||
class OtaUpgradeManager {
|
||
public:
|
||
enum class OtaState {
|
||
IDLE, // 空闲
|
||
CHECKING, // 检查更新中
|
||
DOWNLOADING, // 下载中
|
||
VERIFYING, // 校验中
|
||
INSTALLING, // 安装中
|
||
REBOOTING, // 重启中
|
||
FAILED // 失败
|
||
};
|
||
|
||
OtaUpgradeManager(const std::string& ota_url, const std::string& device_id)
|
||
: ota_url_(ota_url), device_id_(device_id), state_(OtaState::IDLE),
|
||
current_partition_("A"), download_progress_(0) {}
|
||
|
||
/** 检查固件更新 */
|
||
bool check_update() {
|
||
state_ = OtaState::CHECKING;
|
||
// GET ota_url_/api/v1/ota/check?device_id=xxx&version=xxx
|
||
return false; // 返回是否有新版本
|
||
}
|
||
|
||
/** 下载固件升级包 */
|
||
bool download_firmware(const std::string& download_url) {
|
||
state_ = OtaState::DOWNLOADING;
|
||
// HTTPS分块下载到非活跃分区
|
||
// 支持断点续传
|
||
return true;
|
||
}
|
||
|
||
/** 验证固件包完整性和签名 */
|
||
bool verify_firmware(const std::string& firmware_path) {
|
||
state_ = OtaState::VERIFYING;
|
||
// SHA-256校验
|
||
// RSA-2048签名验证
|
||
return true;
|
||
}
|
||
|
||
/** 安装固件(写入非活跃分区) */
|
||
bool install_firmware() {
|
||
state_ = OtaState::INSTALLING;
|
||
// 写入B分区(如当前运行A分区)
|
||
// 设置下次启动从B分区引导
|
||
return true;
|
||
}
|
||
|
||
/** 回滚到上一版本 */
|
||
bool rollback() {
|
||
// 切换回上一个分区
|
||
std::string target = (current_partition_ == "A") ? "B" : "A";
|
||
// 设置引导分区为target
|
||
return true;
|
||
}
|
||
|
||
/** 获取当前OTA状态 */
|
||
OtaState get_state() const { return state_; }
|
||
int get_progress() const { return download_progress_; }
|
||
std::string get_current_partition() const { return current_partition_; }
|
||
|
||
private:
|
||
std::string ota_url_;
|
||
std::string device_id_;
|
||
OtaState state_;
|
||
std::string current_partition_;
|
||
int download_progress_;
|
||
};
|
||
|
||
// ==================== 系统监控模块 ====================
|
||
|
||
/**
|
||
* 系统运行状态监控
|
||
* 采集CPU、内存、GPU/NPU利用率、温度等硬件指标
|
||
* 为云端监控告警和集群调度提供数据支撑
|
||
*/
|
||
class SystemMonitor {
|
||
public:
|
||
struct SystemMetrics {
|
||
float cpu_usage_percent; // CPU使用率
|
||
float memory_usage_percent; // 内存使用率
|
||
long memory_total_mb; // 总内存
|
||
long memory_used_mb; // 已用内存
|
||
float gpu_usage_percent; // GPU/NPU利用率
|
||
float gpu_memory_usage_mb; // GPU显存使用
|
||
float gpu_temperature_c; // GPU温度
|
||
float disk_usage_percent; // 磁盘使用率
|
||
float network_rx_mbps; // 网络接收速率
|
||
float network_tx_mbps; // 网络发送速率
|
||
long uptime_seconds; // 系统运行时长
|
||
};
|
||
|
||
SystemMonitor() : running_(false) {}
|
||
|
||
/** 启动监控采集线程 */
|
||
void start(int interval_ms = 5000) {
|
||
running_ = true;
|
||
// 定时采集系统指标
|
||
}
|
||
|
||
/** 获取最新系统指标 */
|
||
SystemMetrics get_metrics() {
|
||
SystemMetrics metrics;
|
||
metrics.cpu_usage_percent = read_cpu_usage();
|
||
metrics.memory_usage_percent = read_memory_usage();
|
||
metrics.gpu_usage_percent = read_gpu_usage();
|
||
metrics.gpu_temperature_c = read_gpu_temperature();
|
||
metrics.disk_usage_percent = read_disk_usage();
|
||
return metrics;
|
||
}
|
||
|
||
void stop() { running_ = false; }
|
||
|
||
private:
|
||
float read_cpu_usage() {
|
||
// 读取 /proc/stat 计算CPU使用率
|
||
return 0.0f;
|
||
}
|
||
|
||
float read_memory_usage() {
|
||
// 读取 /proc/meminfo
|
||
return 0.0f;
|
||
}
|
||
|
||
float read_gpu_usage() {
|
||
// NVIDIA: nvidia-smi / NVML
|
||
// 瑞芯微: /sys/class/devfreq/xxx
|
||
return 0.0f;
|
||
}
|
||
|
||
float read_gpu_temperature() {
|
||
// 读取GPU温度传感器
|
||
return 0.0f;
|
||
}
|
||
|
||
float read_disk_usage() {
|
||
// statfs("/")
|
||
return 0.0f;
|
||
}
|
||
|
||
std::atomic<bool> running_;
|
||
};
|
||
|
||
#endif // MODEL_MANAGER_H
|