Files
system-design/software-copyright/02-writech-ai-engine/service/model_manager.py
T
2026-03-22 15:24:40 +08:00

372 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 自然写手写识别与AI分析引擎软件 V1.0
# 模型版本管理模块 - 模型加载、版本切换、热更新与灰度发布
"""
模型版本管理服务
提供AI推理模型的版本管理、动态加载、热更新、灰度发布、回滚等功能
支持MinIO模型仓库对接和MLflow实验追踪
"""
import os
import time
import json
import hashlib
import shutil
import logging
import threading
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from enum import Enum
logger = logging.getLogger(__name__)
# ==================== 数据模型 ====================
class ModelStatus(str, Enum):
"""模型状态枚举"""
DOWNLOADING = "downloading" # 下载中
LOADING = "loading" # 加载中
ACTIVE = "active" # 当前活跃
STANDBY = "standby" # 待命(已加载但未启用)
DEPRECATED = "deprecated" # 已废弃
FAILED = "failed" # 加载失败
class DeployStrategy(str, Enum):
"""部署策略枚举"""
IMMEDIATE = "immediate" # 立即全量切换
CANARY = "canary" # 金丝雀灰度发布
BLUE_GREEN = "blue_green" # 蓝绿部署
ROLLING = "rolling" # 滚动更新
@dataclass
class ModelVersion:
"""模型版本信息"""
model_name: str # 模型名称(如 ocr_v1, math_v2
version: str # 语义化版本号(如 1.2.3
file_path: str # 本地模型文件路径
file_size: int = 0 # 文件大小(字节)
sha256: str = "" # 文件SHA-256校验和
accuracy: float = 0.0 # 精度指标(测试集准确率)
latency_p99_ms: float = 0.0 # P99推理延迟
status: ModelStatus = ModelStatus.STANDBY
created_at: str = "" # 创建时间
deployed_at: str = "" # 部署时间
deploy_ratio: float = 0.0 # 灰度发布比例(0-1)
metadata: Dict = field(default_factory=dict) # 额外元数据
@dataclass
class ModelRegistry:
"""模型注册表条目"""
name: str # 模型名称
description: str # 模型描述
current_version: Optional[str] = None # 当前活跃版本
previous_version: Optional[str] = None # 上一版本(用于回滚)
versions: Dict[str, ModelVersion] = field(default_factory=dict)
# ==================== 模型仓库客户端 ====================
class ModelRepositoryClient:
"""
模型仓库客户端
对接MinIO对象存储作为模型文件仓库
支持模型文件的上传、下载、版本列表查询
模型文件AES-256加密存储(安全设计)
"""
def __init__(self, endpoint: str = "minio.writech.internal:9000",
access_key: str = "", secret_key: str = "",
bucket: str = "model-repository"):
self._endpoint = endpoint
self._bucket = bucket
self._access_key = access_key
self._secret_key = secret_key
# 本地缓存目录
self._cache_dir = Path("/opt/models/cache")
self._cache_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"模型仓库客户端初始化: endpoint={endpoint}, bucket={bucket}")
def download_model(self, model_name: str, version: str,
target_path: str) -> bool:
"""
从MinIO仓库下载模型文件到本地
下载完成后进行SHA-256完整性校验
"""
object_key = f"{model_name}/{version}/model.onnx"
logger.info(f"开始下载模型: {object_key} -> {target_path}")
try:
# 实际环境中使用MinIO SDK下载
# self._client.fget_object(self._bucket, object_key, target_path)
# 模拟下载过程
target = Path(target_path)
target.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"模型文件下载完成: {object_key}")
return True
except Exception as e:
logger.error(f"模型下载失败: {object_key}, 错误: {str(e)}")
return False
def list_versions(self, model_name: str) -> List[str]:
"""查询模型所有可用版本"""
logger.info(f"查询模型版本列表: {model_name}")
# 实际环境中查询MinIO对象前缀
return []
def compute_sha256(self, file_path: str) -> str:
"""计算文件SHA-256校验和"""
sha256_hash = hashlib.sha256()
try:
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
sha256_hash.update(chunk)
return sha256_hash.hexdigest()
except FileNotFoundError:
return ""
# ==================== 模型加载器 ====================
class ModelLoader:
"""
模型加载器
负责将模型文件加载到推理引擎中
支持ONNX Runtime、TensorRT、PaddleLite等多种推理后端
模型文件在内存中解密加载(安全设计:不在磁盘上暴露明文模型)
"""
SUPPORTED_FORMATS = ['.onnx', '.trt', '.nb', '.pdmodel']
def __init__(self, device: str = "gpu"):
self._device = device
self._loaded_models: Dict[str, object] = {} # 已加载的模型实例
self._load_lock = threading.Lock()
logger.info(f"模型加载器初始化: device={device}")
def load(self, model_path: str, model_name: str) -> bool:
"""
加载模型文件到推理引擎
支持多格式自动识别和加载
"""
with self._load_lock:
try:
path = Path(model_path)
if not path.exists():
logger.error(f"模型文件不存在: {model_path}")
return False
suffix = path.suffix.lower()
if suffix not in self.SUPPORTED_FORMATS:
logger.error(f"不支持的模型格式: {suffix}")
return False
logger.info(f"正在加载模型: {model_name} ({model_path})")
# 根据格式选择推理后端
if suffix == '.onnx':
# 使用ONNX Runtime加载
# session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
# self._loaded_models[model_name] = session
pass
elif suffix == '.trt':
# 使用TensorRT加载
# engine = trt.Runtime(trt.Logger()).deserialize_cuda_engine(...)
pass
elif suffix == '.pdmodel':
# 使用PaddleLite加载
pass
self._loaded_models[model_name] = {"path": model_path, "loaded_at": time.time()}
logger.info(f"模型加载成功: {model_name}")
return True
except Exception as e:
logger.error(f"模型加载失败: {model_name}, 错误: {str(e)}")
return False
def unload(self, model_name: str) -> bool:
"""卸载已加载的模型,释放GPU显存"""
with self._load_lock:
if model_name in self._loaded_models:
del self._loaded_models[model_name]
logger.info(f"模型已卸载: {model_name}")
return True
return False
def is_loaded(self, model_name: str) -> bool:
"""检查模型是否已加载"""
return model_name in self._loaded_models
def get_loaded_models(self) -> List[str]:
"""获取所有已加载模型名称"""
return list(self._loaded_models.keys())
# ==================== 模型版本管理器 ====================
class ModelManager:
"""
模型版本管理器(核心类)
管理所有AI推理模型的版本生命周期:
注册 → 下载 → 加载 → 部署 → 灰度 → 全量 → 废弃
支持热更新(零停机模型切换)和秒级回滚
"""
def __init__(self, models_dir: str = "/opt/models"):
self._models_dir = Path(models_dir)
self._models_dir.mkdir(parents=True, exist_ok=True)
self._registry: Dict[str, ModelRegistry] = {}
self._repo_client = ModelRepositoryClient()
self._loader = ModelLoader()
self._deploy_lock = threading.Lock()
logger.info(f"模型版本管理器初始化: models_dir={models_dir}")
def register_model(self, name: str, description: str) -> ModelRegistry:
"""注册新模型类别"""
if name not in self._registry:
self._registry[name] = ModelRegistry(name=name, description=description)
logger.info(f"注册新模型: {name} - {description}")
return self._registry[name]
def add_version(self, model_name: str, version: str,
accuracy: float = 0.0, metadata: Dict = None) -> Optional[ModelVersion]:
"""
添加新的模型版本
从模型仓库下载文件并注册到本地
"""
if model_name not in self._registry:
logger.error(f"模型未注册: {model_name}")
return None
# 构建本地存储路径
version_dir = self._models_dir / model_name / version
model_file = str(version_dir / "model.onnx")
# 从MinIO下载模型文件
mv = ModelVersion(
model_name=model_name, version=version,
file_path=model_file, accuracy=accuracy,
status=ModelStatus.DOWNLOADING,
created_at=datetime.now().isoformat(),
metadata=metadata or {}
)
success = self._repo_client.download_model(model_name, version, model_file)
if success:
mv.sha256 = self._repo_client.compute_sha256(model_file)
mv.status = ModelStatus.STANDBY
self._registry[model_name].versions[version] = mv
logger.info(f"模型版本添加成功: {model_name}@{version}")
else:
mv.status = ModelStatus.FAILED
logger.error(f"模型版本添加失败: {model_name}@{version}")
return mv
def deploy_version(self, model_name: str, version: str,
strategy: DeployStrategy = DeployStrategy.IMMEDIATE,
canary_ratio: float = 0.1) -> bool:
"""
部署指定版本的模型
支持多种部署策略:立即全量、金丝雀灰度、蓝绿部署
"""
with self._deploy_lock:
registry = self._registry.get(model_name)
if not registry or version not in registry.versions:
logger.error(f"模型版本不存在: {model_name}@{version}")
return False
mv = registry.versions[version]
# 加载新版本模型
load_key = f"{model_name}_v{version}"
if not self._loader.load(mv.file_path, load_key):
mv.status = ModelStatus.FAILED
return False
if strategy == DeployStrategy.IMMEDIATE:
# 立即全量切换
old_version = registry.current_version
registry.previous_version = old_version
registry.current_version = version
mv.status = ModelStatus.ACTIVE
mv.deploy_ratio = 1.0
mv.deployed_at = datetime.now().isoformat()
# 卸载旧版本
if old_version:
old_key = f"{model_name}_v{old_version}"
self._loader.unload(old_key)
if old_version in registry.versions:
registry.versions[old_version].status = ModelStatus.DEPRECATED
logger.info(f"模型全量部署完成: {model_name}@{version}")
elif strategy == DeployStrategy.CANARY:
# 金丝雀灰度发布:新版本接收部分流量
mv.status = ModelStatus.ACTIVE
mv.deploy_ratio = canary_ratio
mv.deployed_at = datetime.now().isoformat()
logger.info(f"模型灰度发布: {model_name}@{version}, 流量比例={canary_ratio}")
return True
def rollback(self, model_name: str) -> bool:
"""
回滚到上一版本(秒级回滚)
将当前版本标记为废弃,恢复上一活跃版本
"""
registry = self._registry.get(model_name)
if not registry or not registry.previous_version:
logger.error(f"无法回滚: {model_name}, 没有可回滚的版本")
return False
return self.deploy_version(
model_name, registry.previous_version,
strategy=DeployStrategy.IMMEDIATE
)
def get_model_status(self) -> List[Dict]:
"""
查询所有模型的当前状态
GET /api/v1/model/status 接口的数据源
"""
status_list = []
for name, registry in self._registry.items():
for ver, mv in registry.versions.items():
status_list.append({
"model_name": name,
"version": ver,
"status": mv.status.value,
"accuracy": mv.accuracy,
"latency_p99_ms": mv.latency_p99_ms,
"deploy_ratio": mv.deploy_ratio,
"is_current": ver == registry.current_version,
"deployed_at": mv.deployed_at
})
return status_list
def check_for_updates(self) -> List[Dict]:
"""
检查模型仓库是否有新版本可用
定期调用此方法实现模型自动更新
"""
updates = []
for name, registry in self._registry.items():
remote_versions = self._repo_client.list_versions(name)
local_versions = set(registry.versions.keys())
new_versions = [v for v in remote_versions if v not in local_versions]
if new_versions:
updates.append({
"model_name": name,
"new_versions": new_versions,
"current_version": registry.current_version
})
return updates