# Ziranxie handwriting recognition and AI analysis engine software V1.0
# Model version management module - model loading, version switching,
# hot update and canary release
"""
Model version management service.

Provides version management, dynamic loading, hot update, canary release and
rollback for AI inference models.

The MinIO model-repository integration in this module is stubbed out (the SDK
calls are not wired in). The original header also mentions MLflow experiment
tracking; nothing in this module implements it.
"""

import os
import time
import json
import hashlib
import shutil
import logging
import threading
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from enum import Enum

logger = logging.getLogger(__name__)


# ==================== Data models ====================

class ModelStatus(str, Enum):
    """Lifecycle state of a single model version."""

    DOWNLOADING = "downloading"  # being downloaded
    LOADING = "loading"          # being loaded
    ACTIVE = "active"            # currently serving
    STANDBY = "standby"          # available but not enabled
    DEPRECATED = "deprecated"    # retired
    FAILED = "failed"            # download or load failed


class DeployStrategy(str, Enum):
    """Deployment strategy selector for ``ModelManager.deploy_version``."""

    IMMEDIATE = "immediate"    # switch all traffic at once
    CANARY = "canary"          # canary / gradual release
    BLUE_GREEN = "blue_green"  # blue-green deployment (not implemented here)
    ROLLING = "rolling"        # rolling update (not implemented here)


@dataclass
class ModelVersion:
    """Metadata describing one version of a model."""

    model_name: str                 # model name (e.g. ocr_v1, math_v2)
    version: str                    # semantic version string (e.g. 1.2.3)
    file_path: str                  # local path of the model file
    file_size: int = 0              # file size in bytes
    sha256: str = ""                # SHA-256 checksum of the file
    accuracy: float = 0.0           # accuracy metric (test-set accuracy)
    latency_p99_ms: float = 0.0     # P99 inference latency
    status: ModelStatus = ModelStatus.STANDBY
    created_at: str = ""            # creation time (ISO-8601 string)
    deployed_at: str = ""           # deployment time (ISO-8601 string)
    deploy_ratio: float = 0.0       # canary traffic ratio (0-1)
    metadata: Dict = field(default_factory=dict)  # extra metadata


@dataclass
class ModelRegistry:
    """Registry entry grouping all versions of one model."""

    name: str                               # model name
    description: str                        # model description
    current_version: Optional[str] = None   # currently active version
    previous_version: Optional[str] = None  # previous version (rollback target)
    versions: Dict[str, ModelVersion] = field(default_factory=dict)


# ==================== Model repository client ====================

class ModelRepositoryClient:
    """
    Client for the model file repository.

    Intended to talk to a MinIO object store (upload, download, version
    listing). The SDK calls are not wired in yet, so downloads are simulated.
    The original design notes mention AES-256 encrypted storage; no encryption
    is implemented in this class.
    """

    def __init__(self, endpoint: str = "minio.writech.internal:9000",
                 access_key: str = "", secret_key: str = "",
                 bucket: str = "model-repository",
                 cache_dir: str = "/opt/models/cache"):
        """
        Store connection settings and make sure the local cache directory exists.

        Args:
            endpoint: Host and port of the object store.
            access_key: Access key for the object store.
            secret_key: Secret key for the object store.
            bucket: Bucket that holds the model files.
            cache_dir: Local cache directory; created if missing. Defaults to
                the path that used to be hard-coded.
        """
        self._endpoint = endpoint
        self._bucket = bucket
        self._access_key = access_key
        self._secret_key = secret_key
        # Local cache directory
        self._cache_dir = Path(cache_dir)
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"模型仓库客户端初始化: endpoint={endpoint}, bucket={bucket}")

    def download_model(self, model_name: str, version: str, target_path: str) -> bool:
        """
        Download a model file from the repository to ``target_path``.

        The transfer is currently simulated: only the parent directory of
        ``target_path`` is created and no file is written. No checksum
        verification happens here; callers use ``compute_sha256`` separately.

        Returns:
            True on success, False if any exception was raised.
        """
        object_key = f"{model_name}/{version}/model.onnx"
        logger.info(f"开始下载模型: {object_key} -> {target_path}")

        try:
            # A real implementation would fetch the object with the MinIO SDK,
            # e.g. fget_object(bucket, object_key, target_path).
            # Simulated download: only prepare the destination directory.
            target = Path(target_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"模型文件下载完成: {object_key}")
            return True
        except Exception as e:
            logger.error(f"模型下载失败: {object_key}, 错误: {str(e)}")
            return False

    def list_versions(self, model_name: str) -> List[str]:
        """Return the versions available remotely (stub: always empty)."""
        logger.info(f"查询模型版本列表: {model_name}")
        # A real implementation would list object prefixes in the bucket.
        return []

    def compute_sha256(self, file_path: str) -> str:
        """
        Compute the SHA-256 hex digest of a file.

        Returns:
            The hex digest, or an empty string when the file does not exist.
        """
        sha256_hash = hashlib.sha256()
        try:
            with open(file_path, "rb") as f:
                # Read in 8 KiB chunks so large model files are not held in memory.
                for chunk in iter(lambda: f.read(8192), b""):
                    sha256_hash.update(chunk)
            return sha256_hash.hexdigest()
        except FileNotFoundError:
            return ""


# ==================== Model loader ====================

class ModelLoader:
    """
    Loads model files into the inference engine.

    Recognizes ONNX Runtime, TensorRT and PaddleLite file suffixes. The actual
    backend calls are placeholders: a loaded model is represented by a small
    dict holding its path and load timestamp. The original design notes
    mention decrypting models in memory; that is not implemented here.
    """

    SUPPORTED_FORMATS = ['.onnx', '.trt', '.nb', '.pdmodel']

    def __init__(self, device: str = "gpu"):
        """Remember the target device and start with no loaded models."""
        self._device = device
        self._loaded_models: Dict[str, object] = {}  # loaded model instances by key
        self._load_lock = threading.Lock()
        logger.info(f"模型加载器初始化: device={device}")

    def load(self, model_path: str, model_name: str) -> bool:
        """
        Load a model file and register it under ``model_name``.

        Returns:
            True on success; False if the file is missing, the suffix is not
            in ``SUPPORTED_FORMATS``, or an exception occurred.
        """
        with self._load_lock:
            try:
                path = Path(model_path)
                if not path.exists():
                    logger.error(f"模型文件不存在: {model_path}")
                    return False

                suffix = path.suffix.lower()
                if suffix not in self.SUPPORTED_FORMATS:
                    logger.error(f"不支持的模型格式: {suffix}")
                    return False

                logger.info(f"正在加载模型: {model_name} ({model_path})")

                # Backend selection placeholder: a real implementation would
                # create an ONNX Runtime session for '.onnx', deserialize a
                # TensorRT engine for '.trt', or use PaddleLite for
                # '.pdmodel'. For now every format is recorded the same way.
                self._loaded_models[model_name] = {"path": model_path, "loaded_at": time.time()}
                logger.info(f"模型加载成功: {model_name}")
                return True

            except Exception as e:
                logger.error(f"模型加载失败: {model_name}, 错误: {str(e)}")
                return False

    def unload(self, model_name: str) -> bool:
        """Drop a loaded model; return False when it was not loaded."""
        with self._load_lock:
            if model_name in self._loaded_models:
                del self._loaded_models[model_name]
                logger.info(f"模型已卸载: {model_name}")
                return True
            return False

    def is_loaded(self, model_name: str) -> bool:
        """Return whether a model is registered under ``model_name``."""
        return model_name in self._loaded_models

    def get_loaded_models(self) -> List[str]:
        """Return the keys of all loaded models."""
        return list(self._loaded_models.keys())


# ==================== Model version manager ====================

class ModelManager:
    """
    Model version manager (core class).

    Tracks the version lifecycle of every registered model:
    register -> download -> load -> deploy (immediate or canary) -> deprecate,
    plus rollback to the previously active version.
    """

    def __init__(self, models_dir: str = "/opt/models"):
        """
        Create the models directory and the collaborating client and loader.

        Args:
            models_dir: Root directory for local model files. The repository
                cache lives in ``<models_dir>/cache``; with the default value
                this is the previously hard-coded ``/opt/models/cache``.
        """
        self._models_dir = Path(models_dir)
        self._models_dir.mkdir(parents=True, exist_ok=True)
        self._registry: Dict[str, ModelRegistry] = {}
        # Keep the cache under models_dir so a custom root is fully honoured.
        self._repo_client = ModelRepositoryClient(
            cache_dir=str(self._models_dir / "cache")
        )
        self._loader = ModelLoader()
        self._deploy_lock = threading.Lock()
        logger.info(f"模型版本管理器初始化: models_dir={models_dir}")

    def register_model(self, name: str, description: str) -> ModelRegistry:
        """Register a model name (no-op if it exists) and return its entry."""
        if name not in self._registry:
            self._registry[name] = ModelRegistry(name=name, description=description)
            logger.info(f"注册新模型: {name} - {description}")
        return self._registry[name]

    def add_version(self, model_name: str, version: str,
                    accuracy: float = 0.0, metadata: Optional[Dict] = None) -> Optional[ModelVersion]:
        """
        Download a model version and record it in the registry.

        Args:
            model_name: Name of an already registered model.
            version: Version string to add.
            accuracy: Accuracy metric stored on the version record.
            metadata: Optional extra metadata; an empty dict is used if omitted.

        Returns:
            None when ``model_name`` is not registered. Otherwise the
            ``ModelVersion`` record: STANDBY and stored in the registry on a
            successful download, FAILED and not stored on a failed one.
        """
        if model_name not in self._registry:
            logger.error(f"模型未注册: {model_name}")
            return None

        # Local storage path: <models_dir>/<model_name>/<version>/model.onnx
        version_dir = self._models_dir / model_name / version
        model_file = str(version_dir / "model.onnx")

        mv = ModelVersion(
            model_name=model_name,
            version=version,
            file_path=model_file,
            accuracy=accuracy,
            status=ModelStatus.DOWNLOADING,
            created_at=datetime.now().isoformat(),
            metadata=metadata or {}
        )

        success = self._repo_client.download_model(model_name, version, model_file)
        if success:
            # compute_sha256 returns "" if the file is absent (e.g. simulated download).
            mv.sha256 = self._repo_client.compute_sha256(model_file)
            mv.status = ModelStatus.STANDBY
            self._registry[model_name].versions[version] = mv
            logger.info(f"模型版本添加成功: {model_name}@{version}")
        else:
            mv.status = ModelStatus.FAILED
            logger.error(f"模型版本添加失败: {model_name}@{version}")

        return mv

    def deploy_version(self, model_name: str, version: str,
                       strategy: DeployStrategy = DeployStrategy.IMMEDIATE,
                       canary_ratio: float = 0.1) -> bool:
        """
        Deploy a registered model version.

        Only IMMEDIATE and CANARY are implemented. Other strategies are
        rejected before anything is loaded, rather than loading the model and
        reporting success without deploying it.

        Args:
            model_name: Registered model name.
            version: Version previously added with ``add_version``.
            strategy: Deployment strategy.
            canary_ratio: Traffic share for CANARY; must be within [0, 1].

        Returns:
            True when the version was deployed. False when the model or
            version is unknown, the strategy is unsupported, the canary ratio
            is out of range, or loading the model failed.
        """
        with self._deploy_lock:
            registry = self._registry.get(model_name)
            if not registry or version not in registry.versions:
                logger.error(f"模型版本不存在: {model_name}@{version}")
                return False

            # Validate the request before loading so a rejected call has no
            # side effects on the loader.
            if strategy not in (DeployStrategy.IMMEDIATE, DeployStrategy.CANARY):
                strategy_name = getattr(strategy, "value", strategy)
                logger.error(f"暂不支持的部署策略: {strategy_name}")
                return False

            if strategy == DeployStrategy.CANARY and not 0.0 <= canary_ratio <= 1.0:
                logger.error(f"灰度比例无效: {canary_ratio}, 必须在0到1之间")
                return False

            mv = registry.versions[version]

            # Load the new version under a key unique to model + version.
            load_key = f"{model_name}_v{version}"
            if not self._loader.load(mv.file_path, load_key):
                mv.status = ModelStatus.FAILED
                return False

            if strategy == DeployStrategy.IMMEDIATE:
                # Switch all traffic to the new version.
                old_version = registry.current_version
                # Redeploying the current version is a refresh: it must not
                # become its own rollback target, nor be unloaded/deprecated.
                is_switch = old_version != version
                if is_switch:
                    registry.previous_version = old_version
                registry.current_version = version
                mv.status = ModelStatus.ACTIVE
                mv.deploy_ratio = 1.0
                mv.deployed_at = datetime.now().isoformat()

                # Unload and retire the version that was replaced.
                if old_version and is_switch:
                    old_key = f"{model_name}_v{old_version}"
                    self._loader.unload(old_key)
                    if old_version in registry.versions:
                        registry.versions[old_version].status = ModelStatus.DEPRECATED

                logger.info(f"模型全量部署完成: {model_name}@{version}")

            else:
                # Canary release: the new version takes part of the traffic;
                # current_version is left untouched.
                mv.status = ModelStatus.ACTIVE
                mv.deploy_ratio = canary_ratio
                mv.deployed_at = datetime.now().isoformat()
                logger.info(f"模型灰度发布: {model_name}@{version}, 流量比例={canary_ratio}")

            return True

    def rollback(self, model_name: str) -> bool:
        """
        Roll back to the previous version.

        Performs an IMMEDIATE deployment of ``previous_version``, which marks
        the version being replaced as deprecated.

        Returns:
            False when the model is unknown or has no previous version;
            otherwise the result of ``deploy_version``.
        """
        registry = self._registry.get(model_name)
        if not registry or not registry.previous_version:
            logger.error(f"无法回滚: {model_name}, 没有可回滚的版本")
            return False

        return self.deploy_version(
            model_name,
            registry.previous_version,
            strategy=DeployStrategy.IMMEDIATE
        )

    def get_model_status(self) -> List[Dict]:
        """
        Return one status dict per registered model version.

        Per the original note, this is the data source for the
        ``GET /api/v1/model/status`` endpoint.
        """
        return [
            {
                "model_name": name,
                "version": ver,
                "status": mv.status.value,
                "accuracy": mv.accuracy,
                "latency_p99_ms": mv.latency_p99_ms,
                "deploy_ratio": mv.deploy_ratio,
                "is_current": ver == registry.current_version,
                "deployed_at": mv.deployed_at
            }
            for name, registry in self._registry.items()
            for ver, mv in registry.versions.items()
        ]

    def check_for_updates(self) -> List[Dict]:
        """
        Report remote versions that are not registered locally.

        Returns:
            One dict per model that has new versions, with the keys
            ``model_name``, ``new_versions`` and ``current_version``.
        """
        updates = []
        for name, registry in self._registry.items():
            remote_versions = self._repo_client.list_versions(name)
            local_versions = set(registry.versions.keys())
            new_versions = [v for v in remote_versions if v not in local_versions]

            if new_versions:
                updates.append({
                    "model_name": name,
                    "new_versions": new_versions,
                    "current_version": registry.current_version
                })

        return updates