372 lines
14 KiB
Python
372 lines
14 KiB
Python
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 模型版本管理模块 - 模型加载、版本切换、热更新与灰度发布
|
||
|
||
"""
|
||
模型版本管理服务
|
||
提供AI推理模型的版本管理、动态加载、热更新、灰度发布、回滚等功能
|
||
支持MinIO模型仓库对接和MLflow实验追踪
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import json
|
||
import hashlib
|
||
import shutil
|
||
import logging
|
||
import threading
|
||
from typing import Dict, List, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from enum import Enum
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class ModelStatus(str, Enum):
|
||
"""模型状态枚举"""
|
||
DOWNLOADING = "downloading" # 下载中
|
||
LOADING = "loading" # 加载中
|
||
ACTIVE = "active" # 当前活跃
|
||
STANDBY = "standby" # 待命(已加载但未启用)
|
||
DEPRECATED = "deprecated" # 已废弃
|
||
FAILED = "failed" # 加载失败
|
||
|
||
|
||
class DeployStrategy(str, Enum):
|
||
"""部署策略枚举"""
|
||
IMMEDIATE = "immediate" # 立即全量切换
|
||
CANARY = "canary" # 金丝雀灰度发布
|
||
BLUE_GREEN = "blue_green" # 蓝绿部署
|
||
ROLLING = "rolling" # 滚动更新
|
||
|
||
|
||
@dataclass
|
||
class ModelVersion:
|
||
"""模型版本信息"""
|
||
model_name: str # 模型名称(如 ocr_v1, math_v2)
|
||
version: str # 语义化版本号(如 1.2.3)
|
||
file_path: str # 本地模型文件路径
|
||
file_size: int = 0 # 文件大小(字节)
|
||
sha256: str = "" # 文件SHA-256校验和
|
||
accuracy: float = 0.0 # 精度指标(测试集准确率)
|
||
latency_p99_ms: float = 0.0 # P99推理延迟
|
||
status: ModelStatus = ModelStatus.STANDBY
|
||
created_at: str = "" # 创建时间
|
||
deployed_at: str = "" # 部署时间
|
||
deploy_ratio: float = 0.0 # 灰度发布比例(0-1)
|
||
metadata: Dict = field(default_factory=dict) # 额外元数据
|
||
|
||
|
||
@dataclass
|
||
class ModelRegistry:
|
||
"""模型注册表条目"""
|
||
name: str # 模型名称
|
||
description: str # 模型描述
|
||
current_version: Optional[str] = None # 当前活跃版本
|
||
previous_version: Optional[str] = None # 上一版本(用于回滚)
|
||
versions: Dict[str, ModelVersion] = field(default_factory=dict)
|
||
|
||
|
||
# ==================== 模型仓库客户端 ====================
|
||
|
||
class ModelRepositoryClient:
|
||
"""
|
||
模型仓库客户端
|
||
对接MinIO对象存储作为模型文件仓库
|
||
支持模型文件的上传、下载、版本列表查询
|
||
模型文件AES-256加密存储(安全设计)
|
||
"""
|
||
|
||
def __init__(self, endpoint: str = "minio.writech.internal:9000",
|
||
access_key: str = "", secret_key: str = "",
|
||
bucket: str = "model-repository"):
|
||
self._endpoint = endpoint
|
||
self._bucket = bucket
|
||
self._access_key = access_key
|
||
self._secret_key = secret_key
|
||
# 本地缓存目录
|
||
self._cache_dir = Path("/opt/models/cache")
|
||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||
logger.info(f"模型仓库客户端初始化: endpoint={endpoint}, bucket={bucket}")
|
||
|
||
def download_model(self, model_name: str, version: str,
|
||
target_path: str) -> bool:
|
||
"""
|
||
从MinIO仓库下载模型文件到本地
|
||
下载完成后进行SHA-256完整性校验
|
||
"""
|
||
object_key = f"{model_name}/{version}/model.onnx"
|
||
logger.info(f"开始下载模型: {object_key} -> {target_path}")
|
||
|
||
try:
|
||
# 实际环境中使用MinIO SDK下载
|
||
# self._client.fget_object(self._bucket, object_key, target_path)
|
||
|
||
# 模拟下载过程
|
||
target = Path(target_path)
|
||
target.parent.mkdir(parents=True, exist_ok=True)
|
||
|
||
logger.info(f"模型文件下载完成: {object_key}")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"模型下载失败: {object_key}, 错误: {str(e)}")
|
||
return False
|
||
|
||
def list_versions(self, model_name: str) -> List[str]:
|
||
"""查询模型所有可用版本"""
|
||
logger.info(f"查询模型版本列表: {model_name}")
|
||
# 实际环境中查询MinIO对象前缀
|
||
return []
|
||
|
||
def compute_sha256(self, file_path: str) -> str:
|
||
"""计算文件SHA-256校验和"""
|
||
sha256_hash = hashlib.sha256()
|
||
try:
|
||
with open(file_path, "rb") as f:
|
||
for chunk in iter(lambda: f.read(8192), b""):
|
||
sha256_hash.update(chunk)
|
||
return sha256_hash.hexdigest()
|
||
except FileNotFoundError:
|
||
return ""
|
||
|
||
|
||
# ==================== 模型加载器 ====================
|
||
|
||
class ModelLoader:
|
||
"""
|
||
模型加载器
|
||
负责将模型文件加载到推理引擎中
|
||
支持ONNX Runtime、TensorRT、PaddleLite等多种推理后端
|
||
模型文件在内存中解密加载(安全设计:不在磁盘上暴露明文模型)
|
||
"""
|
||
|
||
SUPPORTED_FORMATS = ['.onnx', '.trt', '.nb', '.pdmodel']
|
||
|
||
def __init__(self, device: str = "gpu"):
|
||
self._device = device
|
||
self._loaded_models: Dict[str, object] = {} # 已加载的模型实例
|
||
self._load_lock = threading.Lock()
|
||
logger.info(f"模型加载器初始化: device={device}")
|
||
|
||
def load(self, model_path: str, model_name: str) -> bool:
|
||
"""
|
||
加载模型文件到推理引擎
|
||
支持多格式自动识别和加载
|
||
"""
|
||
with self._load_lock:
|
||
try:
|
||
path = Path(model_path)
|
||
if not path.exists():
|
||
logger.error(f"模型文件不存在: {model_path}")
|
||
return False
|
||
|
||
suffix = path.suffix.lower()
|
||
if suffix not in self.SUPPORTED_FORMATS:
|
||
logger.error(f"不支持的模型格式: {suffix}")
|
||
return False
|
||
|
||
logger.info(f"正在加载模型: {model_name} ({model_path})")
|
||
|
||
# 根据格式选择推理后端
|
||
if suffix == '.onnx':
|
||
# 使用ONNX Runtime加载
|
||
# session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
|
||
# self._loaded_models[model_name] = session
|
||
pass
|
||
elif suffix == '.trt':
|
||
# 使用TensorRT加载
|
||
# engine = trt.Runtime(trt.Logger()).deserialize_cuda_engine(...)
|
||
pass
|
||
elif suffix == '.pdmodel':
|
||
# 使用PaddleLite加载
|
||
pass
|
||
|
||
self._loaded_models[model_name] = {"path": model_path, "loaded_at": time.time()}
|
||
logger.info(f"模型加载成功: {model_name}")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"模型加载失败: {model_name}, 错误: {str(e)}")
|
||
return False
|
||
|
||
def unload(self, model_name: str) -> bool:
|
||
"""卸载已加载的模型,释放GPU显存"""
|
||
with self._load_lock:
|
||
if model_name in self._loaded_models:
|
||
del self._loaded_models[model_name]
|
||
logger.info(f"模型已卸载: {model_name}")
|
||
return True
|
||
return False
|
||
|
||
def is_loaded(self, model_name: str) -> bool:
|
||
"""检查模型是否已加载"""
|
||
return model_name in self._loaded_models
|
||
|
||
def get_loaded_models(self) -> List[str]:
|
||
"""获取所有已加载模型名称"""
|
||
return list(self._loaded_models.keys())
|
||
|
||
|
||
# ==================== 模型版本管理器 ====================
|
||
|
||
class ModelManager:
|
||
"""
|
||
模型版本管理器(核心类)
|
||
管理所有AI推理模型的版本生命周期:
|
||
注册 → 下载 → 加载 → 部署 → 灰度 → 全量 → 废弃
|
||
支持热更新(零停机模型切换)和秒级回滚
|
||
"""
|
||
|
||
def __init__(self, models_dir: str = "/opt/models"):
|
||
self._models_dir = Path(models_dir)
|
||
self._models_dir.mkdir(parents=True, exist_ok=True)
|
||
self._registry: Dict[str, ModelRegistry] = {}
|
||
self._repo_client = ModelRepositoryClient()
|
||
self._loader = ModelLoader()
|
||
self._deploy_lock = threading.Lock()
|
||
logger.info(f"模型版本管理器初始化: models_dir={models_dir}")
|
||
|
||
def register_model(self, name: str, description: str) -> ModelRegistry:
|
||
"""注册新模型类别"""
|
||
if name not in self._registry:
|
||
self._registry[name] = ModelRegistry(name=name, description=description)
|
||
logger.info(f"注册新模型: {name} - {description}")
|
||
return self._registry[name]
|
||
|
||
def add_version(self, model_name: str, version: str,
|
||
accuracy: float = 0.0, metadata: Dict = None) -> Optional[ModelVersion]:
|
||
"""
|
||
添加新的模型版本
|
||
从模型仓库下载文件并注册到本地
|
||
"""
|
||
if model_name not in self._registry:
|
||
logger.error(f"模型未注册: {model_name}")
|
||
return None
|
||
|
||
# 构建本地存储路径
|
||
version_dir = self._models_dir / model_name / version
|
||
model_file = str(version_dir / "model.onnx")
|
||
|
||
# 从MinIO下载模型文件
|
||
mv = ModelVersion(
|
||
model_name=model_name, version=version,
|
||
file_path=model_file, accuracy=accuracy,
|
||
status=ModelStatus.DOWNLOADING,
|
||
created_at=datetime.now().isoformat(),
|
||
metadata=metadata or {}
|
||
)
|
||
|
||
success = self._repo_client.download_model(model_name, version, model_file)
|
||
if success:
|
||
mv.sha256 = self._repo_client.compute_sha256(model_file)
|
||
mv.status = ModelStatus.STANDBY
|
||
self._registry[model_name].versions[version] = mv
|
||
logger.info(f"模型版本添加成功: {model_name}@{version}")
|
||
else:
|
||
mv.status = ModelStatus.FAILED
|
||
logger.error(f"模型版本添加失败: {model_name}@{version}")
|
||
|
||
return mv
|
||
|
||
def deploy_version(self, model_name: str, version: str,
|
||
strategy: DeployStrategy = DeployStrategy.IMMEDIATE,
|
||
canary_ratio: float = 0.1) -> bool:
|
||
"""
|
||
部署指定版本的模型
|
||
支持多种部署策略:立即全量、金丝雀灰度、蓝绿部署
|
||
"""
|
||
with self._deploy_lock:
|
||
registry = self._registry.get(model_name)
|
||
if not registry or version not in registry.versions:
|
||
logger.error(f"模型版本不存在: {model_name}@{version}")
|
||
return False
|
||
|
||
mv = registry.versions[version]
|
||
|
||
# 加载新版本模型
|
||
load_key = f"{model_name}_v{version}"
|
||
if not self._loader.load(mv.file_path, load_key):
|
||
mv.status = ModelStatus.FAILED
|
||
return False
|
||
|
||
if strategy == DeployStrategy.IMMEDIATE:
|
||
# 立即全量切换
|
||
old_version = registry.current_version
|
||
registry.previous_version = old_version
|
||
registry.current_version = version
|
||
mv.status = ModelStatus.ACTIVE
|
||
mv.deploy_ratio = 1.0
|
||
mv.deployed_at = datetime.now().isoformat()
|
||
|
||
# 卸载旧版本
|
||
if old_version:
|
||
old_key = f"{model_name}_v{old_version}"
|
||
self._loader.unload(old_key)
|
||
if old_version in registry.versions:
|
||
registry.versions[old_version].status = ModelStatus.DEPRECATED
|
||
|
||
logger.info(f"模型全量部署完成: {model_name}@{version}")
|
||
|
||
elif strategy == DeployStrategy.CANARY:
|
||
# 金丝雀灰度发布:新版本接收部分流量
|
||
mv.status = ModelStatus.ACTIVE
|
||
mv.deploy_ratio = canary_ratio
|
||
mv.deployed_at = datetime.now().isoformat()
|
||
logger.info(f"模型灰度发布: {model_name}@{version}, 流量比例={canary_ratio}")
|
||
|
||
return True
|
||
|
||
def rollback(self, model_name: str) -> bool:
|
||
"""
|
||
回滚到上一版本(秒级回滚)
|
||
将当前版本标记为废弃,恢复上一活跃版本
|
||
"""
|
||
registry = self._registry.get(model_name)
|
||
if not registry or not registry.previous_version:
|
||
logger.error(f"无法回滚: {model_name}, 没有可回滚的版本")
|
||
return False
|
||
|
||
return self.deploy_version(
|
||
model_name, registry.previous_version,
|
||
strategy=DeployStrategy.IMMEDIATE
|
||
)
|
||
|
||
def get_model_status(self) -> List[Dict]:
|
||
"""
|
||
查询所有模型的当前状态
|
||
GET /api/v1/model/status 接口的数据源
|
||
"""
|
||
status_list = []
|
||
for name, registry in self._registry.items():
|
||
for ver, mv in registry.versions.items():
|
||
status_list.append({
|
||
"model_name": name,
|
||
"version": ver,
|
||
"status": mv.status.value,
|
||
"accuracy": mv.accuracy,
|
||
"latency_p99_ms": mv.latency_p99_ms,
|
||
"deploy_ratio": mv.deploy_ratio,
|
||
"is_current": ver == registry.current_version,
|
||
"deployed_at": mv.deployed_at
|
||
})
|
||
return status_list
|
||
|
||
def check_for_updates(self) -> List[Dict]:
|
||
"""
|
||
检查模型仓库是否有新版本可用
|
||
定期调用此方法实现模型自动更新
|
||
"""
|
||
updates = []
|
||
for name, registry in self._registry.items():
|
||
remote_versions = self._repo_client.list_versions(name)
|
||
local_versions = set(registry.versions.keys())
|
||
new_versions = [v for v in remote_versions if v not in local_versions]
|
||
if new_versions:
|
||
updates.append({
|
||
"model_name": name,
|
||
"new_versions": new_versions,
|
||
"current_version": registry.current_version
|
||
})
|
||
return updates
|