software copyright
This commit is contained in:
@@ -0,0 +1,443 @@
|
||||
/**
|
||||
* 自然写教室智能算力盒边缘计算软件 V1.0
|
||||
* 模型管理模块 - 模型加载、版本管理、量化压缩、云端同步
|
||||
*
|
||||
* 管理算力盒上部署的所有AI推理模型的生命周期
|
||||
* 支持模型热更新、A/B切换、云端版本同步
|
||||
* 模型文件AES-256加密存储,推理时内存解密加载
|
||||
*/
|
||||
|
||||
#ifndef MODEL_MANAGER_H
|
||||
#define MODEL_MANAGER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <atomic>
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
|
||||
// ==================== 模型元信息 ====================
|
||||
|
||||
/** 模型状态枚举 */
|
||||
enum class ModelState {
|
||||
NOT_FOUND = 0, // 未发现
|
||||
DOWNLOADING = 1, // 下载中
|
||||
DECRYPTING = 2, // 解密中
|
||||
LOADING = 3, // 加载到设备中
|
||||
READY = 4, // 就绪可用
|
||||
ACTIVE = 5, // 当前使用中
|
||||
DEPRECATED = 6, // 已弃用
|
||||
ERROR = 7 // 错误状态
|
||||
};
|
||||
|
||||
/** 模型量化精度 */
|
||||
enum class QuantizationType {
|
||||
FP32 = 0, // 全精度浮点
|
||||
FP16 = 1, // 半精度浮点
|
||||
INT8 = 2, // 8位整型量化
|
||||
INT4 = 3 // 4位整型量化(极致压缩)
|
||||
};
|
||||
|
||||
/** 模型元信息 */
|
||||
struct ModelInfo {
|
||||
std::string name; // 模型名称
|
||||
std::string version; // 版本号(语义化版本)
|
||||
std::string format; // 格式(onnx/trt/rknn)
|
||||
std::string file_path; // 本地文件路径
|
||||
size_t file_size_bytes; // 文件大小
|
||||
std::string sha256; // 文件SHA-256校验和
|
||||
QuantizationType quantization; // 量化类型
|
||||
float accuracy; // 测试集准确率
|
||||
float latency_ms; // 平均推理延迟
|
||||
ModelState state; // 当前状态
|
||||
std::string deployed_at; // 部署时间
|
||||
std::string description; // 模型描述
|
||||
};
|
||||
|
||||
// ==================== 模型加密管理 ====================
|
||||
|
||||
/**
|
||||
* 模型文件加密/解密管理器
|
||||
* 安全设计:模型文件AES-256加密存储,推理时内存解密加载
|
||||
* 加密密钥通过安全芯片(TPM)或环境变量注入
|
||||
*/
|
||||
class ModelCryptoManager {
|
||||
public:
|
||||
ModelCryptoManager() : key_loaded_(false) {}
|
||||
|
||||
/**
|
||||
* 加载加密密钥
|
||||
* 优先从安全芯片读取,其次从环境变量
|
||||
*/
|
||||
bool load_encryption_key() {
|
||||
// 尝试从TPM安全芯片读取密钥
|
||||
// if (tpm_available()) { key_ = tpm_read_key("model_key"); }
|
||||
|
||||
// 后备方案:从环境变量读取
|
||||
const char* env_key = std::getenv("WRITECH_MODEL_KEY");
|
||||
if (env_key) {
|
||||
key_ = std::string(env_key);
|
||||
key_loaded_ = true;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* 解密模型文件到内存
|
||||
* 不在磁盘上生成明文文件,仅在内存中解密
|
||||
*/
|
||||
std::vector<uint8_t> decrypt_model(const std::string& encrypted_path) {
|
||||
std::vector<uint8_t> decrypted_data;
|
||||
if (!key_loaded_) return decrypted_data;
|
||||
|
||||
// 读取加密文件
|
||||
// AES-256-CBC解密
|
||||
// openssl EVP_DecryptInit_ex(ctx, EVP_aes_256_cbc(), NULL, key, iv);
|
||||
// EVP_DecryptUpdate(ctx, output, &out_len, input, in_len);
|
||||
// EVP_DecryptFinal_ex(ctx, output + out_len, &final_len);
|
||||
|
||||
return decrypted_data;
|
||||
}
|
||||
|
||||
/**
|
||||
* 加密模型文件
|
||||
* 新下载的模型文件加密后存储到本地Flash
|
||||
*/
|
||||
bool encrypt_model(const std::vector<uint8_t>& data, const std::string& output_path) {
|
||||
if (!key_loaded_) return false;
|
||||
// AES-256-CBC加密并写入文件
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 验证模型文件完整性
|
||||
* 计算SHA-256校验和并与元数据中的值比对
|
||||
*/
|
||||
bool verify_integrity(const std::string& file_path, const std::string& expected_sha256) {
|
||||
// 计算文件SHA-256
|
||||
// SHA256_CTX sha256;
|
||||
// SHA256_Init(&sha256);
|
||||
// while (read chunk) SHA256_Update(&sha256, chunk, len);
|
||||
// SHA256_Final(hash, &sha256);
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
std::string key_;
|
||||
bool key_loaded_;
|
||||
};
|
||||
|
||||
// ==================== 模型版本管理器 ====================
|
||||
|
||||
/**
|
||||
* 模型版本管理器
|
||||
* 管理算力盒上所有AI模型的版本、加载、切换
|
||||
* 支持A/B分区切换实现热更新
|
||||
*/
|
||||
class ModelVersionManager {
|
||||
public:
|
||||
ModelVersionManager(const std::string& models_dir)
|
||||
: models_dir_(models_dir) {}
|
||||
|
||||
/**
|
||||
* 注册模型
|
||||
* 扫描模型目录,加载所有可用模型的元信息
|
||||
*/
|
||||
bool register_model(const ModelInfo& info) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::string key = info.name + "@" + info.version;
|
||||
models_[key] = info;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 激活指定版本的模型
|
||||
* 将旧版本标记为deprecated,新版本标记为active
|
||||
*/
|
||||
bool activate_version(const std::string& name, const std::string& version) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
// 将当前活跃版本设为deprecated
|
||||
for (auto& pair : models_) {
|
||||
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
|
||||
pair.second.state = ModelState::DEPRECATED;
|
||||
}
|
||||
}
|
||||
|
||||
// 激活新版本
|
||||
std::string key = name + "@" + version;
|
||||
auto it = models_.find(key);
|
||||
if (it != models_.end()) {
|
||||
it->second.state = ModelState::ACTIVE;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前活跃版本的模型信息
|
||||
*/
|
||||
ModelInfo get_active_model(const std::string& name) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
for (const auto& pair : models_) {
|
||||
if (pair.second.name == name && pair.second.state == ModelState::ACTIVE) {
|
||||
return pair.second;
|
||||
}
|
||||
}
|
||||
return ModelInfo{};
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有模型状态列表
|
||||
*/
|
||||
std::vector<ModelInfo> get_all_models() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<ModelInfo> result;
|
||||
for (const auto& pair : models_) {
|
||||
result.push_back(pair.second);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理已废弃的旧版本模型文件
|
||||
* 保留最近2个版本,删除更早的版本释放存储空间
|
||||
*/
|
||||
void cleanup_old_versions(const std::string& name, int keep_count = 2) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::vector<std::string> deprecated_keys;
|
||||
|
||||
for (const auto& pair : models_) {
|
||||
if (pair.second.name == name && pair.second.state == ModelState::DEPRECATED) {
|
||||
deprecated_keys.push_back(pair.first);
|
||||
}
|
||||
}
|
||||
|
||||
// 按版本排序,保留最新的keep_count个
|
||||
if (static_cast<int>(deprecated_keys.size()) > keep_count) {
|
||||
for (int i = 0; i < static_cast<int>(deprecated_keys.size()) - keep_count; i++) {
|
||||
// 删除模型文件并从注册表移除
|
||||
models_.erase(deprecated_keys[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::string models_dir_;
|
||||
std::unordered_map<std::string, ModelInfo> models_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
// ==================== 云端模型同步器 ====================
|
||||
|
||||
/**
|
||||
* 云端模型同步器
|
||||
* 定期检查云端是否有新版本模型,自动下载并部署
|
||||
* 通过HTTPS加密通道下载,下载后RSA签名校验
|
||||
*/
|
||||
class CloudModelSyncer {
|
||||
public:
|
||||
CloudModelSyncer(const std::string& server_url, const std::string& device_id)
|
||||
: server_url_(server_url), device_id_(device_id) {}
|
||||
|
||||
/**
|
||||
* 检查云端是否有模型更新
|
||||
* GET /api/v1/model/check-update?device_id=xxx&models=ocr@1.0,math@1.0
|
||||
*/
|
||||
struct UpdateInfo {
|
||||
std::string model_name;
|
||||
std::string new_version;
|
||||
std::string download_url;
|
||||
size_t file_size;
|
||||
std::string sha256;
|
||||
};
|
||||
|
||||
std::vector<UpdateInfo> check_updates(const std::vector<ModelInfo>& current_models) {
|
||||
std::vector<UpdateInfo> updates;
|
||||
// 向云端API发送当前模型版本列表,获取可更新版本
|
||||
// HTTPS请求:GET server_url_/api/v1/model/check-update
|
||||
return updates;
|
||||
}
|
||||
|
||||
/**
|
||||
* 下载模型文件
|
||||
* HTTPS下载,支持断点续传
|
||||
* 下载完成后进行SHA-256校验和RSA签名验证
|
||||
*/
|
||||
bool download_model(const UpdateInfo& info, const std::string& save_path) {
|
||||
// HTTPS下载
|
||||
// 进度回调上报
|
||||
// SHA-256校验
|
||||
// RSA签名验证(OTA安全:升级包RSA签名+SHA-256校验,防篡改)
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 上报模型部署状态
|
||||
* POST /api/v1/model/deploy-status
|
||||
*/
|
||||
void report_deploy_status(const std::string& model_name, const std::string& version,
|
||||
bool success, const std::string& error = "") {
|
||||
// 向云端上报模型部署结果
|
||||
}
|
||||
|
||||
private:
|
||||
std::string server_url_;
|
||||
std::string device_id_;
|
||||
};
|
||||
|
||||
// ==================== OTA固件升级管理器 ====================
|
||||
|
||||
/**
|
||||
* OTA固件升级管理器
|
||||
* 管理算力盒固件的远程升级
|
||||
* 采用A/B双分区方案,升级失败自动回滚
|
||||
* 安全设计:升级包RSA签名+SHA-256校验,防篡改
|
||||
*/
|
||||
class OtaUpgradeManager {
|
||||
public:
|
||||
enum class OtaState {
|
||||
IDLE, // 空闲
|
||||
CHECKING, // 检查更新中
|
||||
DOWNLOADING, // 下载中
|
||||
VERIFYING, // 校验中
|
||||
INSTALLING, // 安装中
|
||||
REBOOTING, // 重启中
|
||||
FAILED // 失败
|
||||
};
|
||||
|
||||
OtaUpgradeManager(const std::string& ota_url, const std::string& device_id)
|
||||
: ota_url_(ota_url), device_id_(device_id), state_(OtaState::IDLE),
|
||||
current_partition_("A"), download_progress_(0) {}
|
||||
|
||||
/** 检查固件更新 */
|
||||
bool check_update() {
|
||||
state_ = OtaState::CHECKING;
|
||||
// GET ota_url_/api/v1/ota/check?device_id=xxx&version=xxx
|
||||
return false; // 返回是否有新版本
|
||||
}
|
||||
|
||||
/** 下载固件升级包 */
|
||||
bool download_firmware(const std::string& download_url) {
|
||||
state_ = OtaState::DOWNLOADING;
|
||||
// HTTPS分块下载到非活跃分区
|
||||
// 支持断点续传
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 验证固件包完整性和签名 */
|
||||
bool verify_firmware(const std::string& firmware_path) {
|
||||
state_ = OtaState::VERIFYING;
|
||||
// SHA-256校验
|
||||
// RSA-2048签名验证
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 安装固件(写入非活跃分区) */
|
||||
bool install_firmware() {
|
||||
state_ = OtaState::INSTALLING;
|
||||
// 写入B分区(如当前运行A分区)
|
||||
// 设置下次启动从B分区引导
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 回滚到上一版本 */
|
||||
bool rollback() {
|
||||
// 切换回上一个分区
|
||||
std::string target = (current_partition_ == "A") ? "B" : "A";
|
||||
// 设置引导分区为target
|
||||
return true;
|
||||
}
|
||||
|
||||
/** 获取当前OTA状态 */
|
||||
OtaState get_state() const { return state_; }
|
||||
int get_progress() const { return download_progress_; }
|
||||
std::string get_current_partition() const { return current_partition_; }
|
||||
|
||||
private:
|
||||
std::string ota_url_;
|
||||
std::string device_id_;
|
||||
OtaState state_;
|
||||
std::string current_partition_;
|
||||
int download_progress_;
|
||||
};
|
||||
|
||||
// ==================== 系统监控模块 ====================
|
||||
|
||||
/**
|
||||
* 系统运行状态监控
|
||||
* 采集CPU、内存、GPU/NPU利用率、温度等硬件指标
|
||||
* 为云端监控告警和集群调度提供数据支撑
|
||||
*/
|
||||
class SystemMonitor {
|
||||
public:
|
||||
struct SystemMetrics {
|
||||
float cpu_usage_percent; // CPU使用率
|
||||
float memory_usage_percent; // 内存使用率
|
||||
long memory_total_mb; // 总内存
|
||||
long memory_used_mb; // 已用内存
|
||||
float gpu_usage_percent; // GPU/NPU利用率
|
||||
float gpu_memory_usage_mb; // GPU显存使用
|
||||
float gpu_temperature_c; // GPU温度
|
||||
float disk_usage_percent; // 磁盘使用率
|
||||
float network_rx_mbps; // 网络接收速率
|
||||
float network_tx_mbps; // 网络发送速率
|
||||
long uptime_seconds; // 系统运行时长
|
||||
};
|
||||
|
||||
SystemMonitor() : running_(false) {}
|
||||
|
||||
/** 启动监控采集线程 */
|
||||
void start(int interval_ms = 5000) {
|
||||
running_ = true;
|
||||
// 定时采集系统指标
|
||||
}
|
||||
|
||||
/** 获取最新系统指标 */
|
||||
SystemMetrics get_metrics() {
|
||||
SystemMetrics metrics;
|
||||
metrics.cpu_usage_percent = read_cpu_usage();
|
||||
metrics.memory_usage_percent = read_memory_usage();
|
||||
metrics.gpu_usage_percent = read_gpu_usage();
|
||||
metrics.gpu_temperature_c = read_gpu_temperature();
|
||||
metrics.disk_usage_percent = read_disk_usage();
|
||||
return metrics;
|
||||
}
|
||||
|
||||
void stop() { running_ = false; }
|
||||
|
||||
private:
|
||||
float read_cpu_usage() {
|
||||
// 读取 /proc/stat 计算CPU使用率
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float read_memory_usage() {
|
||||
// 读取 /proc/meminfo
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float read_gpu_usage() {
|
||||
// NVIDIA: nvidia-smi / NVML
|
||||
// 瑞芯微: /sys/class/devfreq/xxx
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float read_gpu_temperature() {
|
||||
// 读取GPU温度传感器
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
float read_disk_usage() {
|
||||
// statfs("/")
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
std::atomic<bool> running_;
|
||||
};
|
||||
|
||||
#endif // MODEL_MANAGER_H
|
||||
Reference in New Issue
Block a user