software copyright

This commit is contained in:
jiahong
2026-03-22 15:24:40 +08:00
parent e303bb868a
commit 60f336e345
155 changed files with 127262 additions and 0 deletions
@@ -0,0 +1,371 @@
# 自然写手写识别与AI分析引擎软件 V1.0
# 模型版本管理模块 - 模型加载、版本切换、热更新与灰度发布
"""
模型版本管理服务
提供AI推理模型的版本管理、动态加载、热更新、灰度发布、回滚等功能
支持MinIO模型仓库对接和MLflow实验追踪
"""
import os
import time
import json
import hashlib
import shutil
import logging
import threading
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from enum import Enum
logger = logging.getLogger(__name__)
# ==================== 数据模型 ====================
class ModelStatus(str, Enum):
    """Status values a model version can carry.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    DOWNLOADING = "downloading"  # download in progress
    LOADING = "loading"          # load in progress
    ACTIVE = "active"            # currently active
    STANDBY = "standby"          # on standby: available but not enabled
    DEPRECATED = "deprecated"    # retired
    FAILED = "failed"            # download or load failed
class DeployStrategy(str, Enum):
    """Deployment strategies that can be requested for a model version.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    IMMEDIATE = "immediate"    # switch all traffic at once
    CANARY = "canary"          # canary release with a traffic ratio
    BLUE_GREEN = "blue_green"  # blue/green deployment
    ROLLING = "rolling"        # rolling update
@dataclass
class ModelVersion:
    """Metadata record for one version of a model."""

    model_name: str              # model name (e.g. ocr_v1, math_v2)
    version: str                 # semantic version string (e.g. 1.2.3)
    file_path: str               # local path of the model file
    file_size: int = 0           # file size in bytes
    sha256: str = ""             # SHA-256 checksum of the file
    accuracy: float = 0.0        # accuracy metric (test-set accuracy)
    latency_p99_ms: float = 0.0  # P99 inference latency in milliseconds
    status: ModelStatus = ModelStatus.STANDBY
    created_at: str = ""         # creation time
    deployed_at: str = ""        # deployment time
    deploy_ratio: float = 0.0    # canary traffic ratio (0-1)
    metadata: Dict = field(default_factory=dict)  # extra metadata
@dataclass
class ModelRegistry:
    """Registry entry for one named model and all of its known versions."""

    name: str                               # model name
    description: str                        # model description
    current_version: Optional[str] = None   # currently active version
    previous_version: Optional[str] = None  # prior version, used for rollback
    # Known versions keyed by version string.
    versions: Dict[str, ModelVersion] = field(default_factory=dict)
# ==================== 模型仓库客户端 ====================
class ModelRepositoryClient:
    """Client for the model file repository (intended to be backed by MinIO).

    The object-store calls are still placeholders in this class:
    ``download_model`` only prepares the target directory and
    ``list_versions`` always returns an empty list.  ``compute_sha256`` is
    fully implemented.
    """

    def __init__(self, endpoint: str = "minio.writech.internal:9000",
                 access_key: str = "", secret_key: str = "",
                 bucket: str = "model-repository",
                 cache_dir: str = "/opt/models/cache"):
        """Store connection settings and create the local cache directory.

        Args:
            endpoint: Host and port of the object store.
            access_key: Access key credential.
            secret_key: Secret key credential.
            bucket: Bucket that holds the model files.
            cache_dir: Local cache directory; created if it does not exist.
                Defaults to the previously hard-coded location.
        """
        self._endpoint = endpoint
        self._bucket = bucket
        self._access_key = access_key
        self._secret_key = secret_key
        # Local cache directory.
        self._cache_dir = Path(cache_dir)
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"模型仓库客户端初始化: endpoint={endpoint}, bucket={bucket}")

    def download_model(self, model_name: str, version: str,
                       target_path: str) -> bool:
        """Fetch ``<model_name>/<version>/model.onnx`` to ``target_path``.

        The transfer itself is not implemented yet: only the parent directory
        of ``target_path`` is created, so no file exists afterwards and no
        integrity check is performed here.

        Returns:
            True when the (simulated) download completed, False when an
            exception was raised and logged.
        """
        object_key = f"{model_name}/{version}/model.onnx"
        logger.info(f"开始下载模型: {object_key} -> {target_path}")
        try:
            # TODO: replace with the real SDK call, e.g.
            #   self._client.fget_object(self._bucket, object_key, target_path)
            Path(target_path).parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"模型文件下载完成: {object_key}")
            return True
        except Exception as e:  # boundary: report any failure as False
            logger.error(f"模型下载失败: {object_key}, 错误: {e}")
            return False

    def list_versions(self, model_name: str) -> List[str]:
        """Return the versions available remotely for ``model_name``.

        Placeholder: the object-store prefix query is not implemented, so the
        result is always an empty list.
        """
        logger.info(f"查询模型版本列表: {model_name}")
        return []

    def compute_sha256(self, file_path: str) -> str:
        """Return the hex SHA-256 digest of a file, or "" if it is missing.

        The file is read in fixed-size chunks so large model files are never
        held in memory in full.
        """
        digest = hashlib.sha256()
        try:
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(65536), b""):
                    digest.update(chunk)
        except FileNotFoundError:
            return ""
        return digest.hexdigest()
# ==================== 模型加载器 ====================
class ModelLoader:
    """Loads model files and tracks which models are currently loaded.

    Actual inference-backend integration is not implemented yet: ``load``
    validates the file and records a placeholder entry for the model.
    All access to the loaded-model table goes through one lock.
    """

    # File suffixes accepted by ``load`` (compared case-insensitively).
    SUPPORTED_FORMATS = ['.onnx', '.trt', '.nb', '.pdmodel']

    def __init__(self, device: str = "gpu"):
        """Create an empty loader.

        Args:
            device: Target device label; stored but not otherwise used here.
        """
        self._device = device
        self._loaded_models: Dict[str, object] = {}  # loaded model entries
        self._load_lock = threading.Lock()
        logger.info(f"模型加载器初始化: device={device}")

    def load(self, model_path: str, model_name: str) -> bool:
        """Validate ``model_path`` and register it under ``model_name``.

        Returns:
            True when the entry was recorded; False when the file is missing,
            its suffix is not in ``SUPPORTED_FORMATS``, or an exception was
            raised (the exception is logged, not propagated).
        """
        with self._load_lock:
            try:
                path = Path(model_path)
                if not path.exists():
                    logger.error(f"模型文件不存在: {model_path}")
                    return False

                suffix = path.suffix.lower()
                if suffix not in self.SUPPORTED_FORMATS:
                    logger.error(f"不支持的模型格式: {suffix}")
                    return False

                logger.info(f"正在加载模型: {model_name} ({model_path})")
                # TODO: dispatch on ``suffix`` to the real backend
                # (ONNX Runtime for .onnx, TensorRT for .trt, PaddleLite for
                # .pdmodel) and store the resulting session/engine instead of
                # this placeholder record.
                self._loaded_models[model_name] = {
                    "path": model_path,
                    "loaded_at": time.time(),
                }
                logger.info(f"模型加载成功: {model_name}")
                return True
            except Exception as e:  # boundary: report any failure as False
                logger.error(f"模型加载失败: {model_name}, 错误: {e}")
                return False

    def unload(self, model_name: str) -> bool:
        """Drop a loaded model entry; return False if it was not loaded."""
        with self._load_lock:
            if model_name not in self._loaded_models:
                return False
            del self._loaded_models[model_name]
            logger.info(f"模型已卸载: {model_name}")
            return True

    def is_loaded(self, model_name: str) -> bool:
        """Return whether ``model_name`` currently has a loaded entry."""
        with self._load_lock:
            return model_name in self._loaded_models

    def get_loaded_models(self) -> List[str]:
        """Return a snapshot list of the names of all loaded models."""
        with self._load_lock:
            return list(self._loaded_models)
# ==================== 模型版本管理器 ====================
class ModelManager:
    """Tracks registered models, their versions and which version is active.

    Combines a ``ModelRepositoryClient`` (fetching files) and a
    ``ModelLoader`` (loading them) and keeps a per-model ``ModelRegistry``.
    Deployments are serialized with a lock.
    """

    def __init__(self, models_dir: str = "/opt/models",
                 repo_client: Optional[ModelRepositoryClient] = None,
                 loader: Optional[ModelLoader] = None):
        """Create the manager and make sure ``models_dir`` exists.

        Args:
            models_dir: Root directory for local model files.
            repo_client: Repository client to use; a default
                ``ModelRepositoryClient()`` is created when omitted.
            loader: Model loader to use; a default ``ModelLoader()`` is
                created when omitted.
        """
        self._models_dir = Path(models_dir)
        self._models_dir.mkdir(parents=True, exist_ok=True)
        self._registry: Dict[str, ModelRegistry] = {}
        self._repo_client = (
            repo_client if repo_client is not None else ModelRepositoryClient()
        )
        self._loader = loader if loader is not None else ModelLoader()
        self._deploy_lock = threading.Lock()
        logger.info(f"模型版本管理器初始化: models_dir={models_dir}")

    def register_model(self, name: str, description: str) -> ModelRegistry:
        """Register a model name (idempotent) and return its registry entry.

        An already registered name keeps its existing entry and description.
        """
        if name not in self._registry:
            self._registry[name] = ModelRegistry(name=name, description=description)
            logger.info(f"注册新模型: {name} - {description}")
        return self._registry[name]

    def add_version(self, model_name: str, version: str,
                    accuracy: float = 0.0,
                    metadata: Optional[Dict] = None) -> Optional[ModelVersion]:
        """Download a model version and record it in the registry.

        Returns:
            None when ``model_name`` is not registered; otherwise the
            ``ModelVersion``.  On download success it is stored with status
            STANDBY; on failure it is returned with status FAILED and is NOT
            stored in the registry.
        """
        if model_name not in self._registry:
            logger.error(f"模型未注册: {model_name}")
            return None

        # Local storage path: <models_dir>/<model_name>/<version>/model.onnx
        version_dir = self._models_dir / model_name / version
        model_file = str(version_dir / "model.onnx")

        mv = ModelVersion(
            model_name=model_name, version=version,
            file_path=model_file, accuracy=accuracy,
            status=ModelStatus.DOWNLOADING,
            created_at=datetime.now().isoformat(),
            # Copy so later changes to the caller's dict do not leak in.
            metadata=dict(metadata) if metadata else {}
        )

        if self._repo_client.download_model(model_name, version, model_file):
            mv.sha256 = self._repo_client.compute_sha256(model_file)
            mv.status = ModelStatus.STANDBY
            self._registry[model_name].versions[version] = mv
            logger.info(f"模型版本添加成功: {model_name}@{version}")
        else:
            mv.status = ModelStatus.FAILED
            logger.error(f"模型版本添加失败: {model_name}@{version}")
        return mv

    def deploy_version(self, model_name: str, version: str,
                       strategy: DeployStrategy = DeployStrategy.IMMEDIATE,
                       canary_ratio: float = 0.1) -> bool:
        """Deploy a registered version using the given strategy.

        Only IMMEDIATE and CANARY are implemented.  Other strategies are
        rejected before anything is loaded instead of silently reporting
        success.

        Args:
            model_name: Registered model name.
            version: Version previously added with ``add_version``.
            strategy: Deployment strategy (enum member or its string value).
            canary_ratio: Traffic share for CANARY; must lie in [0, 1].

        Returns:
            True on success; False when the version is unknown, the strategy
            is unsupported, ``canary_ratio`` is out of range, or loading
            failed (in which case the version is marked FAILED).
        """
        with self._deploy_lock:
            registry = self._registry.get(model_name)
            if not registry or version not in registry.versions:
                logger.error(f"模型版本不存在: {model_name}@{version}")
                return False

            # Validate the request before loading so a rejected call leaves
            # no stray loaded model behind.
            try:
                strategy = DeployStrategy(strategy)
            except ValueError:
                logger.error(f"不支持的部署策略: {strategy}")
                return False
            if strategy not in (DeployStrategy.IMMEDIATE, DeployStrategy.CANARY):
                logger.error(f"不支持的部署策略: {strategy.value}")
                return False
            if strategy == DeployStrategy.CANARY and not 0.0 <= canary_ratio <= 1.0:
                logger.error(f"灰度比例无效: {canary_ratio}, 取值范围应为0-1")
                return False

            mv = registry.versions[version]
            load_key = f"{model_name}_v{version}"
            if not self._loader.load(mv.file_path, load_key):
                mv.status = ModelStatus.FAILED
                return False

            if strategy == DeployStrategy.IMMEDIATE:
                old_version = registry.current_version
                # Redeploying the version that is already current must not
                # unload or deprecate the model that was just loaded.
                is_switch = old_version != version
                if is_switch:
                    registry.previous_version = old_version
                    registry.current_version = version
                mv.status = ModelStatus.ACTIVE
                mv.deploy_ratio = 1.0
                mv.deployed_at = datetime.now().isoformat()

                # Retire the version that was active before the switch.
                if old_version and is_switch:
                    self._loader.unload(f"{model_name}_v{old_version}")
                    if old_version in registry.versions:
                        registry.versions[old_version].status = ModelStatus.DEPRECATED
                logger.info(f"模型全量部署完成: {model_name}@{version}")
            else:
                # CANARY: the version becomes active with a partial traffic
                # share; the registry's current version is left unchanged.
                mv.status = ModelStatus.ACTIVE
                mv.deploy_ratio = canary_ratio
                mv.deployed_at = datetime.now().isoformat()
                logger.info(f"模型灰度发布: {model_name}@{version}, 流量比例={canary_ratio}")
            return True

    def rollback(self, model_name: str) -> bool:
        """Redeploy the recorded previous version with the IMMEDIATE strategy.

        Returns:
            False when the model is unknown or has no previous version;
            otherwise the result of ``deploy_version``.
        """
        registry = self._registry.get(model_name)
        if not registry or not registry.previous_version:
            logger.error(f"无法回滚: {model_name}, 没有可回滚的版本")
            return False
        return self.deploy_version(
            model_name, registry.previous_version,
            strategy=DeployStrategy.IMMEDIATE
        )

    def get_model_status(self) -> List[Dict]:
        """Return one status dict per registered model version.

        Per the original note this is intended as the data source for
        ``GET /api/v1/model/status``.
        """
        return [
            {
                "model_name": name,
                "version": ver,
                "status": mv.status.value,
                "accuracy": mv.accuracy,
                "latency_p99_ms": mv.latency_p99_ms,
                "deploy_ratio": mv.deploy_ratio,
                "is_current": ver == registry.current_version,
                "deployed_at": mv.deployed_at
            }
            for name, registry in self._registry.items()
            for ver, mv in registry.versions.items()
        ]

    def check_for_updates(self) -> List[Dict]:
        """List remote versions that are not yet known locally.

        Returns:
            One dict per model that has at least one new remote version, with
            keys ``model_name``, ``new_versions`` and ``current_version``.
        """
        updates = []
        for name, registry in self._registry.items():
            remote_versions = self._repo_client.list_versions(name)
            new_versions = [v for v in remote_versions if v not in registry.versions]
            if new_versions:
                updates.append({
                    "model_name": name,
                    "new_versions": new_versions,
                    "current_version": registry.current_version
                })
        return updates
@@ -0,0 +1,314 @@
# 自然写手写识别与AI分析引擎软件 V1.0
# Celery异步任务调度模块 - 识别请求异步处理与优先级调度
"""
Celery任务调度服务
管理AI识别请求的异步任务队列,支持优先级调度、任务重试、
结果回调通知、任务进度追踪等功能
使用Redis作为消息Broker和结果Backend
"""
import time
import json
import logging
import uuid
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import IntEnum
logger = logging.getLogger(__name__)
# ==================== 任务优先级定义 ====================
class TaskPriority(IntEnum):
    """Task priority levels; a smaller number means a higher priority."""

    CRITICAL = 0    # highest: real-time classroom interaction
    HIGH = 1        # high: teacher grading online
    NORMAL = 2      # normal: automatic homework grading
    LOW = 3         # low: batch processing of historical data
    BACKGROUND = 4  # background: model evaluation / training data generation
class TaskStatus:
    """String constants naming the states a task can be in."""

    PENDING = "PENDING"        # waiting to run
    STARTED = "STARTED"        # handed out for execution
    PROCESSING = "PROCESSING"  # being processed
    SUCCESS = "SUCCESS"        # finished successfully
    FAILURE = "FAILURE"        # finished with an error
    RETRY = "RETRY"            # waiting to be retried
    REVOKED = "REVOKED"        # cancelled
@dataclass
class TaskRecord:
    """Record describing one submitted task and its outcome."""

    task_id: str
    task_type: str                      # task type (ocr/math/stroke_order/essay)
    priority: TaskPriority
    status: str = TaskStatus.PENDING
    input_data: Dict = field(default_factory=dict)
    result: Optional[Dict] = None
    error_message: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3
    created_at: str = ""
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    callback_url: Optional[str] = None  # URL notified after completion
    student_id: Optional[str] = None
    assignment_id: Optional[str] = None
# ==================== 任务队列管理器 ====================
class TaskQueueManager:
    """In-memory priority queue of task records.

    Tasks are kept in a dict by id plus a list ordered by
    ``(priority, created_at)``.  The Redis URL is only stored; this class
    does not talk to Redis itself.
    """

    # Default queue name per task type.
    QUEUE_MAPPING = {
        "ocr": "writech.ocr",
        "math": "writech.math",
        "stroke_order": "writech.stroke_order",
        "essay": "writech.essay",
        "batch": "writech.batch"
    }

    # Statuses that ``get_next_task`` may hand out; RETRY is included so a
    # task that failed but has retries left is dispatched again.
    _RUNNABLE_STATUSES = (TaskStatus.PENDING, TaskStatus.RETRY)
    # Statuses after which a task never runs again and leaves the queue.
    _TERMINAL_STATUSES = (TaskStatus.SUCCESS, TaskStatus.FAILURE,
                          TaskStatus.REVOKED)

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """Create an empty queue; ``redis_url`` is stored for later use."""
        self._redis_url = redis_url
        self._tasks: Dict[str, TaskRecord] = {}  # every task, by id
        self._queue: List[TaskRecord] = []       # unfinished tasks, ordered
        logger.info(f"任务队列管理器初始化: redis={redis_url}")

    def submit_task(self, task_type: str, input_data: Dict,
                    priority: TaskPriority = TaskPriority.NORMAL,
                    callback_url: Optional[str] = None,
                    student_id: Optional[str] = None,
                    assignment_id: Optional[str] = None) -> str:
        """Create a task record, enqueue it and return its generated id.

        The id is a random UUID4 string that callers use with
        ``get_task_status``.
        """
        task_id = str(uuid.uuid4())
        record = TaskRecord(
            task_id=task_id,
            task_type=task_type,
            priority=priority,
            input_data=input_data,
            created_at=datetime.now().isoformat(),
            callback_url=callback_url,
            student_id=student_id,
            assignment_id=assignment_id
        )
        self._tasks[task_id] = record
        self._queue.append(record)
        # Smaller priority value first; creation time breaks ties.
        self._queue.sort(key=lambda t: (t.priority, t.created_at))

        queue_name = self.QUEUE_MAPPING.get(task_type, "writech.default")
        logger.info(
            f"任务已提交: id={task_id}, type={task_type}, "
            f"priority={priority.name}, queue={queue_name}"
        )
        return task_id

    def get_next_task(self) -> Optional[TaskRecord]:
        """Hand out the first runnable task in queue order, or None.

        A task is runnable when its status is PENDING or RETRY.  The
        returned task is marked STARTED and stamped with ``started_at``.
        """
        for task in self._queue:
            if task.status in self._RUNNABLE_STATUSES:
                task.status = TaskStatus.STARTED
                task.started_at = datetime.now().isoformat()
                return task
        return None

    def update_task_status(self, task_id: str, status: str,
                           result: Optional[Dict] = None,
                           error: Optional[str] = None):
        """Set a task's status and optionally its result / error message.

        ``result`` and ``error`` are stored whenever they are not None (an
        empty dict or empty string is kept).  SUCCESS and FAILURE stamp
        ``completed_at``.  A task that reaches a terminal status is removed
        from the dispatch queue but stays queryable by id.  Unknown ids are
        logged and otherwise ignored.
        """
        task = self._tasks.get(task_id)
        if task is None:
            logger.warning(f"任务状态更新失败, 任务不存在: id={task_id}")
            return

        task.status = status
        if result is not None:
            task.result = result
        if error is not None:
            task.error_message = error
        if status in (TaskStatus.SUCCESS, TaskStatus.FAILURE):
            task.completed_at = datetime.now().isoformat()
        if status in self._TERMINAL_STATUSES:
            # Identity comparison: dataclass equality would compare fields.
            self._queue = [t for t in self._queue if t is not task]
        logger.info(f"任务状态更新: id={task_id}, status={status}")

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """Return a status/result summary dict for a task, or None if unknown."""
        task = self._tasks.get(task_id)
        if not task:
            return None
        return {
            "task_id": task.task_id,
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority.name,
            "result": task.result,
            "error_message": task.error_message,
            "retry_count": task.retry_count,
            "created_at": task.created_at,
            "started_at": task.started_at,
            "completed_at": task.completed_at
        }

    def get_queue_stats(self) -> Dict:
        """Return the total task count plus a count per status.

        Keys are ``total`` and the lower-cased status names; every status
        is reported so the per-status counts add up to ``total``.
        """
        counts: Dict[str, int] = {}
        for task in self._tasks.values():
            counts[task.status] = counts.get(task.status, 0) + 1

        stats = {"total": len(self._tasks)}
        for status in (TaskStatus.PENDING, TaskStatus.STARTED,
                       TaskStatus.SUCCESS, TaskStatus.FAILURE,
                       TaskStatus.PROCESSING, TaskStatus.RETRY,
                       TaskStatus.REVOKED):
            stats[status.lower()] = counts.get(status, 0)
        return stats
# ==================== Celery任务定义 ====================
class CeleryTaskExecutor:
    """Runs queued tasks through handlers registered per task type.

    Despite the name, this class does not call Celery itself; it looks
    tasks up in the given ``TaskQueueManager`` and invokes plain callables.
    """

    def __init__(self, queue_manager: TaskQueueManager):
        """Bind the executor to the queue manager that owns the tasks."""
        self._queue_manager = queue_manager
        self._task_handlers: Dict[str, callable] = {}
        logger.info("Celery任务执行器初始化")

    def register_handler(self, task_type: str, handler: callable):
        """Register (or replace) the handler callable for ``task_type``.

        The handler is called with the task's ``input_data`` dict and must
        return a dict.
        """
        self._task_handlers[task_type] = handler
        logger.info(f"注册任务处理器: {task_type}")

    def execute_task(self, task_id: str) -> Dict:
        """Run one task and record the outcome on the queue manager.

        Returns:
            The handler's result dict with ``processing_time_ms`` added on
            success, or ``{"error": <message>}`` when the task or its
            handler is unknown or the handler failed.

        A failing handler is marked RETRY while ``retry_count`` does not
        exceed ``max_retries`` (``max_retries=3`` allows three retries after
        the first attempt) and FAILURE afterwards.
        """
        task = self._queue_manager._tasks.get(task_id)
        if not task:
            return {"error": "任务不存在"}

        handler = self._task_handlers.get(task.task_type)
        if not handler:
            self._queue_manager.update_task_status(
                task_id, TaskStatus.FAILURE,
                error=f"未注册的任务类型: {task.task_type}"
            )
            return {"error": f"未注册的任务类型: {task.task_type}"}

        self._queue_manager.update_task_status(task_id, TaskStatus.PROCESSING)
        # Only the handler call and stamping its result count as task
        # failures; an error in the bookkeeping below must not turn an
        # already succeeded task into a retry.
        try:
            start_time = time.perf_counter()  # monotonic, unlike time.time()
            result = handler(task.input_data)
            elapsed = (time.perf_counter() - start_time) * 1000
            result['processing_time_ms'] = round(elapsed, 2)
        except Exception as e:
            return self._handle_failure(task, e)

        self._queue_manager.update_task_status(task_id, TaskStatus.SUCCESS, result=result)
        # Audit log of who/what/when for every executed request.
        logger.info(
            f"任务执行完成: id={task_id}, type={task.task_type}, "
            f"time={elapsed:.1f}ms, student={task.student_id}"
        )
        # Notify the caller when a callback URL was supplied.
        if task.callback_url:
            self._send_callback(task.callback_url, task_id, result)
        return result

    def _handle_failure(self, task, exc: Exception) -> Dict:
        """Record a handler failure as RETRY or final FAILURE."""
        task.retry_count += 1
        if task.retry_count <= task.max_retries:
            # Keep the error message so the cause of the retry is visible.
            self._queue_manager.update_task_status(
                task.task_id, TaskStatus.RETRY, error=str(exc)
            )
            logger.warning(
                f"任务重试: id={task.task_id}, "
                f"retry={task.retry_count}/{task.max_retries}"
            )
        else:
            self._queue_manager.update_task_status(
                task.task_id, TaskStatus.FAILURE, error=str(exc)
            )
            logger.error(f"任务最终失败: id={task.task_id}, error={exc}")
        return {"error": str(exc)}

    def _send_callback(self, url: str, task_id: str, result: Dict):
        """Notify ``url`` that a task finished (best effort).

        Placeholder: the notification is only logged; no HTTP request is
        sent.  Errors are logged and never propagated.
        """
        try:
            # TODO: send a POST request with an HTTP client (httpx/aiohttp).
            logger.info(f"发送任务回调: url={url}, task_id={task_id}")
        except Exception as e:
            logger.error(f"回调通知失败: {e}")
# ==================== 定时调度器 ====================
class ScheduledTaskRunner:
    """Registry of named periodic jobs with run/error bookkeeping.

    This class stores each job's interval and handler and can run a job on
    demand; it contains no timer loop of its own.
    """

    def __init__(self):
        """Create a runner with no registered jobs."""
        self._schedules: Dict[str, Dict] = {}
        self._running = False
        logger.info("定时任务调度器初始化")

    def register_schedule(self, name: str, interval_seconds: int,
                          handler: callable, description: str = ""):
        """Register (or replace) a job.

        Args:
            name: Unique job name; re-registering resets its counters.
            interval_seconds: Intended interval between runs, in seconds.
            handler: Zero-argument callable that performs the job.
            description: Optional human-readable description.
        """
        self._schedules[name] = {
            "interval": interval_seconds,
            "handler": handler,
            "description": description,
            "last_run": None,
            "run_count": 0,
            "error_count": 0
        }
        logger.info(f"注册定时任务: {name}, 间隔={interval_seconds}s")

    def run_task(self, name: str) -> Optional[Dict]:
        """Run a registered job right now.

        Returns:
            None when ``name`` is not registered.  Otherwise a dict with
            ``name`` and ``success``, plus ``elapsed_s`` on success or
            ``error`` on failure.  ``last_run`` and ``run_count`` are only
            updated by successful runs; failures increment ``error_count``.
        """
        schedule = self._schedules.get(name)
        if not schedule:
            return None

        # perf_counter is monotonic, so the duration cannot be distorted by
        # wall-clock adjustments.
        start = time.perf_counter()
        try:
            schedule["handler"]()
        except Exception as e:
            schedule["error_count"] += 1
            logger.error(f"定时任务执行失败: {name}, 错误={e}")
            return {"name": name, "success": False, "error": str(e)}

        elapsed = time.perf_counter() - start
        schedule["last_run"] = datetime.now().isoformat()
        schedule["run_count"] += 1
        logger.info(f"定时任务执行完成: {name}, 耗时={elapsed:.2f}s")
        return {"name": name, "success": True, "elapsed_s": round(elapsed, 2)}

    def get_schedule_status(self) -> List[Dict]:
        """Return one bookkeeping dict per registered job (handler omitted)."""
        return [
            {
                "name": name,
                "interval_seconds": info["interval"],
                "description": info["description"],
                "last_run": info["last_run"],
                "run_count": info["run_count"],
                "error_count": info["error_count"]
            }
            for name, info in self._schedules.items()
        ]