Files
system-design/software-copyright/05-writech-edge-box/inference/model_manager.cpp
T
2026-03-22 15:24:40 +08:00

444 lines
13 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* 自然写教室智能算力盒边缘计算软件 V1.0
* 模型管理模块 - 模型加载、版本管理、量化压缩、云端同步
*
* 管理算力盒上部署的所有AI推理模型的生命周期
* 支持模型热更新、A/B切换、云端版本同步
* 模型文件AES-256加密存储,推理时内存解密加载
*/
#ifndef MODEL_MANAGER_H
#define MODEL_MANAGER_H
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <atomic>
#include <unordered_map>
#include <chrono>
#include <functional>
// ==================== 模型元信息 ====================
/** 模型状态枚举 */
enum class ModelState {
NOT_FOUND = 0, // 未发现
DOWNLOADING = 1, // 下载中
DECRYPTING = 2, // 解密中
LOADING = 3, // 加载到设备中
READY = 4, // 就绪可用
ACTIVE = 5, // 当前使用中
DEPRECATED = 6, // 已弃用
ERROR = 7 // 错误状态
};
/** 模型量化精度 */
enum class QuantizationType {
FP32 = 0, // 全精度浮点
FP16 = 1, // 半精度浮点
INT8 = 2, // 8位整型量化
INT4 = 3 // 4位整型量化(极致压缩)
};
/** 模型元信息 */
struct ModelInfo {
std::string name; // 模型名称
std::string version; // 版本号(语义化版本)
std::string format; // 格式(onnx/trt/rknn
std::string file_path; // 本地文件路径
size_t file_size_bytes; // 文件大小
std::string sha256; // 文件SHA-256校验和
QuantizationType quantization; // 量化类型
float accuracy; // 测试集准确率
float latency_ms; // 平均推理延迟
ModelState state; // 当前状态
std::string deployed_at; // 部署时间
std::string description; // 模型描述
};
// ==================== 模型加密管理 ====================
/**
* 模型文件加密/解密管理器
* 安全设计:模型文件AES-256加密存储,推理时内存解密加载
* 加密密钥通过安全芯片(TPM)或环境变量注入
*/
class ModelCryptoManager {
public:
ModelCryptoManager() : key_loaded_(false) {}
/**
* 加载加密密钥
* 优先从安全芯片读取,其次从环境变量
*/
bool load_encryption_key() {
// 尝试从TPM安全芯片读取密钥
// if (tpm_available()) { key_ = tpm_read_key("model_key"); }
// 后备方案:从环境变量读取
const char* env_key = std::getenv("WRITECH_MODEL_KEY");
if (env_key) {
key_ = std::string(env_key);
key_loaded_ = true;
return true;
}
return false;
}
/**
* 解密模型文件到内存
* 不在磁盘上生成明文文件,仅在内存中解密
*/
std::vector<uint8_t> decrypt_model(const std::string& encrypted_path) {
std::vector<uint8_t> decrypted_data;
if (!key_loaded_) return decrypted_data;
// 读取加密文件
// AES-256-CBC解密
// openssl EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv);
// EVP_DecryptUpdate(ctx, output, &out_len, input, in_len);
// EVP_DecryptFinal_ex(ctx, output + out_len, &final_len);
return decrypted_data;
}
/**
* 加密模型文件
* 新下载的模型文件加密后存储到本地Flash
*/
bool encrypt_model(const std::vector<uint8_t>& data, const std::string& output_path) {
if (!key_loaded_) return false;
// AES-256-CBC加密并写入文件
return true;
}
/**
* 验证模型文件完整性
* 计算SHA-256校验和并与元数据中的值比对
*/
bool verify_integrity(const std::string& file_path, const std::string& expected_sha256) {
// 计算文件SHA-256
// SHA256_CTX sha256;
// SHA256_Init(&sha256);
// while (read chunk) SHA256_Update(&sha256, chunk, len);
// SHA256_Final(hash, &sha256);
return true;
}
private:
std::string key_;
bool key_loaded_;
};
// ==================== 模型版本管理器 ====================
/**
* 模型版本管理器
* 管理算力盒上所有AI模型的版本、加载、切换
* 支持A/B分区切换实现热更新
*/
class ModelVersionManager {
public:
ModelVersionManager(const std::string& models_dir)
: models_dir_(models_dir) {}
/**
* 注册模型
* 扫描模型目录,加载所有可用模型的元信息
*/
bool register_model(const ModelInfo& info) {
std::lock_guard<std::mutex> lock(mutex_);
std::string key = info.name + "@" + info.version;
models_[key] = info;
return true;
}
/**
* 激活指定版本的模型
* 将旧版本标记为deprecated,新版本标记为active
*/
bool activate_version(const std::string& name, const std::string& version) {
std::lock_guard<std::mutex> lock(mutex_);
// 将当前活跃版本设为deprecated
for (auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
pair.second.state = ModelState::DEPRECATED;
}
}
// 激活新版本
std::string key = name + "@" + version;
auto it = models_.find(key);
if (it != models_.end()) {
it->second.state = ModelState::ACTIVE;
return true;
}
return false;
}
/**
* 获取当前活跃版本的模型信息
*/
ModelInfo get_active_model(const std::string& name) {
std::lock_guard<std::mutex> lock(mutex_);
for (const auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
return pair.second;
}
}
return ModelInfo{};
}
/**
* 获取所有模型状态列表
*/
std::vector<ModelInfo> get_all_models() {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<ModelInfo> result;
for (const auto& pair : models_) {
result.push_back(pair.second);
}
return result;
}
/**
* 清理已废弃的旧版本模型文件
* 保留最近2个版本,删除更早的版本释放存储空间
*/
void cleanup_old_versions(const std::string& name, int keep_count = 2) {
std::lock_guard<std::mutex> lock(mutex_);
std::vector<std::string> deprecated_keys;
for (const auto& pair : models_) {
if (pair.second.name == name && pair.second.state == ModelState::DEPRECATED) {
deprecated_keys.push_back(pair.first);
}
}
// 按版本排序,保留最新的keep_count个
if (static_cast<int>(deprecated_keys.size()) > keep_count) {
for (int i = 0; i < static_cast<int>(deprecated_keys.size()) - keep_count; i++) {
// 删除模型文件并从注册表移除
models_.erase(deprecated_keys[i]);
}
}
}
private:
std::string models_dir_;
std::unordered_map<std::string, ModelInfo> models_;
std::mutex mutex_;
};
// ==================== 云端模型同步器 ====================
/**
* 云端模型同步器
* 定期检查云端是否有新版本模型,自动下载并部署
* 通过HTTPS加密通道下载,下载后RSA签名校验
*/
class CloudModelSyncer {
public:
CloudModelSyncer(const std::string& server_url, const std::string& device_id)
: server_url_(server_url), device_id_(device_id) {}
/**
* 检查云端是否有模型更新
* GET /api/v1/model/check-update?device_id=xxx&models=ocr@1.0,math@1.0
*/
struct UpdateInfo {
std::string model_name;
std::string new_version;
std::string download_url;
size_t file_size;
std::string sha256;
};
std::vector<UpdateInfo> check_updates(const std::vector<ModelInfo>& current_models) {
std::vector<UpdateInfo> updates;
// 向云端API发送当前模型版本列表,获取可更新版本
// HTTPS请求:GET server_url_/api/v1/model/check-update
return updates;
}
/**
* 下载模型文件
* HTTPS下载,支持断点续传
* 下载完成后进行SHA-256校验和RSA签名验证
*/
bool download_model(const UpdateInfo& info, const std::string& save_path) {
// HTTPS下载
// 进度回调上报
// SHA-256校验
// RSA签名验证(OTA安全:升级包RSA签名+SHA-256校验,防篡改)
return true;
}
/**
* 上报模型部署状态
* POST /api/v1/model/deploy-status
*/
void report_deploy_status(const std::string& model_name, const std::string& version,
bool success, const std::string& error = "") {
// 向云端上报模型部署结果
}
private:
std::string server_url_;
std::string device_id_;
};
// ==================== OTA固件升级管理器 ====================
/**
* OTA固件升级管理器
* 管理算力盒固件的远程升级
* 采用A/B双分区方案,升级失败自动回滚
* 安全设计:升级包RSA签名+SHA-256校验,防篡改
*/
class OtaUpgradeManager {
public:
enum class OtaState {
IDLE, // 空闲
CHECKING, // 检查更新中
DOWNLOADING, // 下载中
VERIFYING, // 校验中
INSTALLING, // 安装中
REBOOTING, // 重启中
FAILED // 失败
};
OtaUpgradeManager(const std::string& ota_url, const std::string& device_id)
: ota_url_(ota_url), device_id_(device_id), state_(OtaState::IDLE),
current_partition_("A"), download_progress_(0) {}
/** 检查固件更新 */
bool check_update() {
state_ = OtaState::CHECKING;
// GET ota_url_/api/v1/ota/check?device_id=xxx&version=xxx
return false; // 返回是否有新版本
}
/** 下载固件升级包 */
bool download_firmware(const std::string& download_url) {
state_ = OtaState::DOWNLOADING;
// HTTPS分块下载到非活跃分区
// 支持断点续传
return true;
}
/** 验证固件包完整性和签名 */
bool verify_firmware(const std::string& firmware_path) {
state_ = OtaState::VERIFYING;
// SHA-256校验
// RSA-2048签名验证
return true;
}
/** 安装固件(写入非活跃分区) */
bool install_firmware() {
state_ = OtaState::INSTALLING;
// 写入B分区(如当前运行A分区)
// 设置下次启动从B分区引导
return true;
}
/** 回滚到上一版本 */
bool rollback() {
// 切换回上一个分区
std::string target = (current_partition_ == "A") ? "B" : "A";
// 设置引导分区为target
return true;
}
/** 获取当前OTA状态 */
OtaState get_state() const { return state_; }
int get_progress() const { return download_progress_; }
std::string get_current_partition() const { return current_partition_; }
private:
std::string ota_url_;
std::string device_id_;
OtaState state_;
std::string current_partition_;
int download_progress_;
};
// ==================== 系统监控模块 ====================
/**
* 系统运行状态监控
* 采集CPU、内存、GPU/NPU利用率、温度等硬件指标
* 为云端监控告警和集群调度提供数据支撑
*/
class SystemMonitor {
public:
struct SystemMetrics {
float cpu_usage_percent; // CPU使用率
float memory_usage_percent; // 内存使用率
long memory_total_mb; // 总内存
long memory_used_mb; // 已用内存
float gpu_usage_percent; // GPU/NPU利用率
float gpu_memory_usage_mb; // GPU显存使用
float gpu_temperature_c; // GPU温度
float disk_usage_percent; // 磁盘使用率
float network_rx_mbps; // 网络接收速率
float network_tx_mbps; // 网络发送速率
long uptime_seconds; // 系统运行时长
};
SystemMonitor() : running_(false) {}
/** 启动监控采集线程 */
void start(int interval_ms = 5000) {
running_ = true;
// 定时采集系统指标
}
/** 获取最新系统指标 */
SystemMetrics get_metrics() {
SystemMetrics metrics;
metrics.cpu_usage_percent = read_cpu_usage();
metrics.memory_usage_percent = read_memory_usage();
metrics.gpu_usage_percent = read_gpu_usage();
metrics.gpu_temperature_c = read_gpu_temperature();
metrics.disk_usage_percent = read_disk_usage();
return metrics;
}
void stop() { running_ = false; }
private:
float read_cpu_usage() {
// 读取 /proc/stat 计算CPU使用率
return 0.0f;
}
float read_memory_usage() {
// 读取 /proc/meminfo
return 0.0f;
}
float read_gpu_usage() {
// NVIDIA: nvidia-smi / NVML
// 瑞芯微: /sys/class/devfreq/xxx
return 0.0f;
}
float read_gpu_temperature() {
// 读取GPU温度传感器
return 0.0f;
}
float read_disk_usage() {
// statfs("/")
return 0.0f;
}
std::atomic<bool> running_;
};
#endif // MODEL_MANAGER_H