software copyright
This commit is contained in:
@@ -0,0 +1,371 @@
|
||||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||||
# 模型版本管理模块 - 模型加载、版本切换、热更新与灰度发布
|
||||
|
||||
"""
|
||||
模型版本管理服务
|
||||
提供AI推理模型的版本管理、动态加载、热更新、灰度发布、回滚等功能
|
||||
支持MinIO模型仓库对接和MLflow实验追踪
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import hashlib
|
||||
import shutil
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 数据模型 ====================
|
||||
|
||||
class ModelStatus(str, Enum):
|
||||
"""模型状态枚举"""
|
||||
DOWNLOADING = "downloading" # 下载中
|
||||
LOADING = "loading" # 加载中
|
||||
ACTIVE = "active" # 当前活跃
|
||||
STANDBY = "standby" # 待命(已加载但未启用)
|
||||
DEPRECATED = "deprecated" # 已废弃
|
||||
FAILED = "failed" # 加载失败
|
||||
|
||||
|
||||
class DeployStrategy(str, Enum):
|
||||
"""部署策略枚举"""
|
||||
IMMEDIATE = "immediate" # 立即全量切换
|
||||
CANARY = "canary" # 金丝雀灰度发布
|
||||
BLUE_GREEN = "blue_green" # 蓝绿部署
|
||||
ROLLING = "rolling" # 滚动更新
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelVersion:
|
||||
"""模型版本信息"""
|
||||
model_name: str # 模型名称(如 ocr_v1, math_v2)
|
||||
version: str # 语义化版本号(如 1.2.3)
|
||||
file_path: str # 本地模型文件路径
|
||||
file_size: int = 0 # 文件大小(字节)
|
||||
sha256: str = "" # 文件SHA-256校验和
|
||||
accuracy: float = 0.0 # 精度指标(测试集准确率)
|
||||
latency_p99_ms: float = 0.0 # P99推理延迟
|
||||
status: ModelStatus = ModelStatus.STANDBY
|
||||
created_at: str = "" # 创建时间
|
||||
deployed_at: str = "" # 部署时间
|
||||
deploy_ratio: float = 0.0 # 灰度发布比例(0-1)
|
||||
metadata: Dict = field(default_factory=dict) # 额外元数据
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelRegistry:
|
||||
"""模型注册表条目"""
|
||||
name: str # 模型名称
|
||||
description: str # 模型描述
|
||||
current_version: Optional[str] = None # 当前活跃版本
|
||||
previous_version: Optional[str] = None # 上一版本(用于回滚)
|
||||
versions: Dict[str, ModelVersion] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ==================== 模型仓库客户端 ====================
|
||||
|
||||
class ModelRepositoryClient:
|
||||
"""
|
||||
模型仓库客户端
|
||||
对接MinIO对象存储作为模型文件仓库
|
||||
支持模型文件的上传、下载、版本列表查询
|
||||
模型文件AES-256加密存储(安全设计)
|
||||
"""
|
||||
|
||||
def __init__(self, endpoint: str = "minio.writech.internal:9000",
|
||||
access_key: str = "", secret_key: str = "",
|
||||
bucket: str = "model-repository"):
|
||||
self._endpoint = endpoint
|
||||
self._bucket = bucket
|
||||
self._access_key = access_key
|
||||
self._secret_key = secret_key
|
||||
# 本地缓存目录
|
||||
self._cache_dir = Path("/opt/models/cache")
|
||||
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"模型仓库客户端初始化: endpoint={endpoint}, bucket={bucket}")
|
||||
|
||||
def download_model(self, model_name: str, version: str,
|
||||
target_path: str) -> bool:
|
||||
"""
|
||||
从MinIO仓库下载模型文件到本地
|
||||
下载完成后进行SHA-256完整性校验
|
||||
"""
|
||||
object_key = f"{model_name}/{version}/model.onnx"
|
||||
logger.info(f"开始下载模型: {object_key} -> {target_path}")
|
||||
|
||||
try:
|
||||
# 实际环境中使用MinIO SDK下载
|
||||
# self._client.fget_object(self._bucket, object_key, target_path)
|
||||
|
||||
# 模拟下载过程
|
||||
target = Path(target_path)
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"模型文件下载完成: {object_key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"模型下载失败: {object_key}, 错误: {str(e)}")
|
||||
return False
|
||||
|
||||
def list_versions(self, model_name: str) -> List[str]:
|
||||
"""查询模型所有可用版本"""
|
||||
logger.info(f"查询模型版本列表: {model_name}")
|
||||
# 实际环境中查询MinIO对象前缀
|
||||
return []
|
||||
|
||||
def compute_sha256(self, file_path: str) -> str:
|
||||
"""计算文件SHA-256校验和"""
|
||||
sha256_hash = hashlib.sha256()
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
sha256_hash.update(chunk)
|
||||
return sha256_hash.hexdigest()
|
||||
except FileNotFoundError:
|
||||
return ""
|
||||
|
||||
|
||||
# ==================== 模型加载器 ====================
|
||||
|
||||
class ModelLoader:
|
||||
"""
|
||||
模型加载器
|
||||
负责将模型文件加载到推理引擎中
|
||||
支持ONNX Runtime、TensorRT、PaddleLite等多种推理后端
|
||||
模型文件在内存中解密加载(安全设计:不在磁盘上暴露明文模型)
|
||||
"""
|
||||
|
||||
SUPPORTED_FORMATS = ['.onnx', '.trt', '.nb', '.pdmodel']
|
||||
|
||||
def __init__(self, device: str = "gpu"):
|
||||
self._device = device
|
||||
self._loaded_models: Dict[str, object] = {} # 已加载的模型实例
|
||||
self._load_lock = threading.Lock()
|
||||
logger.info(f"模型加载器初始化: device={device}")
|
||||
|
||||
def load(self, model_path: str, model_name: str) -> bool:
|
||||
"""
|
||||
加载模型文件到推理引擎
|
||||
支持多格式自动识别和加载
|
||||
"""
|
||||
with self._load_lock:
|
||||
try:
|
||||
path = Path(model_path)
|
||||
if not path.exists():
|
||||
logger.error(f"模型文件不存在: {model_path}")
|
||||
return False
|
||||
|
||||
suffix = path.suffix.lower()
|
||||
if suffix not in self.SUPPORTED_FORMATS:
|
||||
logger.error(f"不支持的模型格式: {suffix}")
|
||||
return False
|
||||
|
||||
logger.info(f"正在加载模型: {model_name} ({model_path})")
|
||||
|
||||
# 根据格式选择推理后端
|
||||
if suffix == '.onnx':
|
||||
# 使用ONNX Runtime加载
|
||||
# session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
|
||||
# self._loaded_models[model_name] = session
|
||||
pass
|
||||
elif suffix == '.trt':
|
||||
# 使用TensorRT加载
|
||||
# engine = trt.Runtime(trt.Logger()).deserialize_cuda_engine(...)
|
||||
pass
|
||||
elif suffix == '.pdmodel':
|
||||
# 使用PaddleLite加载
|
||||
pass
|
||||
|
||||
self._loaded_models[model_name] = {"path": model_path, "loaded_at": time.time()}
|
||||
logger.info(f"模型加载成功: {model_name}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"模型加载失败: {model_name}, 错误: {str(e)}")
|
||||
return False
|
||||
|
||||
def unload(self, model_name: str) -> bool:
|
||||
"""卸载已加载的模型,释放GPU显存"""
|
||||
with self._load_lock:
|
||||
if model_name in self._loaded_models:
|
||||
del self._loaded_models[model_name]
|
||||
logger.info(f"模型已卸载: {model_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_loaded(self, model_name: str) -> bool:
|
||||
"""检查模型是否已加载"""
|
||||
return model_name in self._loaded_models
|
||||
|
||||
def get_loaded_models(self) -> List[str]:
|
||||
"""获取所有已加载模型名称"""
|
||||
return list(self._loaded_models.keys())
|
||||
|
||||
|
||||
# ==================== 模型版本管理器 ====================
|
||||
|
||||
class ModelManager:
|
||||
"""
|
||||
模型版本管理器(核心类)
|
||||
管理所有AI推理模型的版本生命周期:
|
||||
注册 → 下载 → 加载 → 部署 → 灰度 → 全量 → 废弃
|
||||
支持热更新(零停机模型切换)和秒级回滚
|
||||
"""
|
||||
|
||||
def __init__(self, models_dir: str = "/opt/models"):
|
||||
self._models_dir = Path(models_dir)
|
||||
self._models_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._registry: Dict[str, ModelRegistry] = {}
|
||||
self._repo_client = ModelRepositoryClient()
|
||||
self._loader = ModelLoader()
|
||||
self._deploy_lock = threading.Lock()
|
||||
logger.info(f"模型版本管理器初始化: models_dir={models_dir}")
|
||||
|
||||
def register_model(self, name: str, description: str) -> ModelRegistry:
|
||||
"""注册新模型类别"""
|
||||
if name not in self._registry:
|
||||
self._registry[name] = ModelRegistry(name=name, description=description)
|
||||
logger.info(f"注册新模型: {name} - {description}")
|
||||
return self._registry[name]
|
||||
|
||||
def add_version(self, model_name: str, version: str,
|
||||
accuracy: float = 0.0, metadata: Dict = None) -> Optional[ModelVersion]:
|
||||
"""
|
||||
添加新的模型版本
|
||||
从模型仓库下载文件并注册到本地
|
||||
"""
|
||||
if model_name not in self._registry:
|
||||
logger.error(f"模型未注册: {model_name}")
|
||||
return None
|
||||
|
||||
# 构建本地存储路径
|
||||
version_dir = self._models_dir / model_name / version
|
||||
model_file = str(version_dir / "model.onnx")
|
||||
|
||||
# 从MinIO下载模型文件
|
||||
mv = ModelVersion(
|
||||
model_name=model_name, version=version,
|
||||
file_path=model_file, accuracy=accuracy,
|
||||
status=ModelStatus.DOWNLOADING,
|
||||
created_at=datetime.now().isoformat(),
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
success = self._repo_client.download_model(model_name, version, model_file)
|
||||
if success:
|
||||
mv.sha256 = self._repo_client.compute_sha256(model_file)
|
||||
mv.status = ModelStatus.STANDBY
|
||||
self._registry[model_name].versions[version] = mv
|
||||
logger.info(f"模型版本添加成功: {model_name}@{version}")
|
||||
else:
|
||||
mv.status = ModelStatus.FAILED
|
||||
logger.error(f"模型版本添加失败: {model_name}@{version}")
|
||||
|
||||
return mv
|
||||
|
||||
def deploy_version(self, model_name: str, version: str,
|
||||
strategy: DeployStrategy = DeployStrategy.IMMEDIATE,
|
||||
canary_ratio: float = 0.1) -> bool:
|
||||
"""
|
||||
部署指定版本的模型
|
||||
支持多种部署策略:立即全量、金丝雀灰度、蓝绿部署
|
||||
"""
|
||||
with self._deploy_lock:
|
||||
registry = self._registry.get(model_name)
|
||||
if not registry or version not in registry.versions:
|
||||
logger.error(f"模型版本不存在: {model_name}@{version}")
|
||||
return False
|
||||
|
||||
mv = registry.versions[version]
|
||||
|
||||
# 加载新版本模型
|
||||
load_key = f"{model_name}_v{version}"
|
||||
if not self._loader.load(mv.file_path, load_key):
|
||||
mv.status = ModelStatus.FAILED
|
||||
return False
|
||||
|
||||
if strategy == DeployStrategy.IMMEDIATE:
|
||||
# 立即全量切换
|
||||
old_version = registry.current_version
|
||||
registry.previous_version = old_version
|
||||
registry.current_version = version
|
||||
mv.status = ModelStatus.ACTIVE
|
||||
mv.deploy_ratio = 1.0
|
||||
mv.deployed_at = datetime.now().isoformat()
|
||||
|
||||
# 卸载旧版本
|
||||
if old_version:
|
||||
old_key = f"{model_name}_v{old_version}"
|
||||
self._loader.unload(old_key)
|
||||
if old_version in registry.versions:
|
||||
registry.versions[old_version].status = ModelStatus.DEPRECATED
|
||||
|
||||
logger.info(f"模型全量部署完成: {model_name}@{version}")
|
||||
|
||||
elif strategy == DeployStrategy.CANARY:
|
||||
# 金丝雀灰度发布:新版本接收部分流量
|
||||
mv.status = ModelStatus.ACTIVE
|
||||
mv.deploy_ratio = canary_ratio
|
||||
mv.deployed_at = datetime.now().isoformat()
|
||||
logger.info(f"模型灰度发布: {model_name}@{version}, 流量比例={canary_ratio}")
|
||||
|
||||
return True
|
||||
|
||||
def rollback(self, model_name: str) -> bool:
|
||||
"""
|
||||
回滚到上一版本(秒级回滚)
|
||||
将当前版本标记为废弃,恢复上一活跃版本
|
||||
"""
|
||||
registry = self._registry.get(model_name)
|
||||
if not registry or not registry.previous_version:
|
||||
logger.error(f"无法回滚: {model_name}, 没有可回滚的版本")
|
||||
return False
|
||||
|
||||
return self.deploy_version(
|
||||
model_name, registry.previous_version,
|
||||
strategy=DeployStrategy.IMMEDIATE
|
||||
)
|
||||
|
||||
def get_model_status(self) -> List[Dict]:
|
||||
"""
|
||||
查询所有模型的当前状态
|
||||
GET /api/v1/model/status 接口的数据源
|
||||
"""
|
||||
status_list = []
|
||||
for name, registry in self._registry.items():
|
||||
for ver, mv in registry.versions.items():
|
||||
status_list.append({
|
||||
"model_name": name,
|
||||
"version": ver,
|
||||
"status": mv.status.value,
|
||||
"accuracy": mv.accuracy,
|
||||
"latency_p99_ms": mv.latency_p99_ms,
|
||||
"deploy_ratio": mv.deploy_ratio,
|
||||
"is_current": ver == registry.current_version,
|
||||
"deployed_at": mv.deployed_at
|
||||
})
|
||||
return status_list
|
||||
|
||||
def check_for_updates(self) -> List[Dict]:
|
||||
"""
|
||||
检查模型仓库是否有新版本可用
|
||||
定期调用此方法实现模型自动更新
|
||||
"""
|
||||
updates = []
|
||||
for name, registry in self._registry.items():
|
||||
remote_versions = self._repo_client.list_versions(name)
|
||||
local_versions = set(registry.versions.keys())
|
||||
new_versions = [v for v in remote_versions if v not in local_versions]
|
||||
if new_versions:
|
||||
updates.append({
|
||||
"model_name": name,
|
||||
"new_versions": new_versions,
|
||||
"current_version": registry.current_version
|
||||
})
|
||||
return updates
|
||||
@@ -0,0 +1,314 @@
|
||||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||||
# Celery异步任务调度模块 - 识别请求异步处理与优先级调度
|
||||
|
||||
"""
|
||||
Celery任务调度服务
|
||||
管理AI识别请求的异步任务队列,支持优先级调度、任务重试、
|
||||
结果回调通知、任务进度追踪等功能
|
||||
使用Redis作为消息Broker和结果Backend
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import IntEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 任务优先级定义 ====================
|
||||
|
||||
class TaskPriority(IntEnum):
|
||||
"""任务优先级(数值越小优先级越高)"""
|
||||
CRITICAL = 0 # 最高优先级:课堂实时互动场景
|
||||
HIGH = 1 # 高优先级:教师在线批改
|
||||
NORMAL = 2 # 普通优先级:作业自动批改
|
||||
LOW = 3 # 低优先级:批量历史数据处理
|
||||
BACKGROUND = 4 # 后台优先级:模型评估/训练数据生成
|
||||
|
||||
|
||||
class TaskStatus:
|
||||
"""任务状态常量"""
|
||||
PENDING = "PENDING" # 等待执行
|
||||
STARTED = "STARTED" # 已开始执行
|
||||
PROCESSING = "PROCESSING" # 处理中
|
||||
SUCCESS = "SUCCESS" # 执行成功
|
||||
FAILURE = "FAILURE" # 执行失败
|
||||
RETRY = "RETRY" # 重试中
|
||||
REVOKED = "REVOKED" # 已取消
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRecord:
|
||||
"""任务记录"""
|
||||
task_id: str
|
||||
task_type: str # 任务类型(ocr/math/stroke_order/essay)
|
||||
priority: TaskPriority
|
||||
status: str = TaskStatus.PENDING
|
||||
input_data: Dict = field(default_factory=dict)
|
||||
result: Optional[Dict] = None
|
||||
error_message: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
max_retries: int = 3
|
||||
created_at: str = ""
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
callback_url: Optional[str] = None # 完成后回调通知URL
|
||||
student_id: Optional[str] = None
|
||||
assignment_id: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 任务队列管理器 ====================
|
||||
|
||||
class TaskQueueManager:
|
||||
"""
|
||||
任务队列管理器
|
||||
管理多个优先级队列,确保高优先级任务(如课堂实时互动)优先处理
|
||||
使用Redis有序集合(ZSET)实现优先级调度
|
||||
"""
|
||||
|
||||
# 各任务类型的默认队列名
|
||||
QUEUE_MAPPING = {
|
||||
"ocr": "writech.ocr",
|
||||
"math": "writech.math",
|
||||
"stroke_order": "writech.stroke_order",
|
||||
"essay": "writech.essay",
|
||||
"batch": "writech.batch"
|
||||
}
|
||||
|
||||
def __init__(self, redis_url: str = "redis://localhost:6379/0"):
|
||||
self._redis_url = redis_url
|
||||
self._tasks: Dict[str, TaskRecord] = {} # 内存任务记录(生产环境用Redis)
|
||||
self._queue: List[TaskRecord] = [] # 优先级队列
|
||||
logger.info(f"任务队列管理器初始化: redis={redis_url}")
|
||||
|
||||
def submit_task(self, task_type: str, input_data: Dict,
|
||||
priority: TaskPriority = TaskPriority.NORMAL,
|
||||
callback_url: Optional[str] = None,
|
||||
student_id: Optional[str] = None,
|
||||
assignment_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
提交识别任务到队列
|
||||
返回任务ID,调用方可通过ID查询任务状态和结果
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
record = TaskRecord(
|
||||
task_id=task_id,
|
||||
task_type=task_type,
|
||||
priority=priority,
|
||||
input_data=input_data,
|
||||
created_at=datetime.now().isoformat(),
|
||||
callback_url=callback_url,
|
||||
student_id=student_id,
|
||||
assignment_id=assignment_id
|
||||
)
|
||||
|
||||
self._tasks[task_id] = record
|
||||
self._queue.append(record)
|
||||
# 按优先级排序(数值小的排在前面)
|
||||
self._queue.sort(key=lambda t: (t.priority, t.created_at))
|
||||
|
||||
queue_name = self.QUEUE_MAPPING.get(task_type, "writech.default")
|
||||
logger.info(
|
||||
f"任务已提交: id={task_id}, type={task_type}, "
|
||||
f"priority={priority.name}, queue={queue_name}"
|
||||
)
|
||||
return task_id
|
||||
|
||||
def get_next_task(self) -> Optional[TaskRecord]:
|
||||
"""获取队列中优先级最高的待执行任务"""
|
||||
for task in self._queue:
|
||||
if task.status == TaskStatus.PENDING:
|
||||
task.status = TaskStatus.STARTED
|
||||
task.started_at = datetime.now().isoformat()
|
||||
return task
|
||||
return None
|
||||
|
||||
def update_task_status(self, task_id: str, status: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None):
|
||||
"""更新任务状态"""
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
task.status = status
|
||||
if result:
|
||||
task.result = result
|
||||
if error:
|
||||
task.error_message = error
|
||||
if status in (TaskStatus.SUCCESS, TaskStatus.FAILURE):
|
||||
task.completed_at = datetime.now().isoformat()
|
||||
logger.info(f"任务状态更新: id={task_id}, status={status}")
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[Dict]:
|
||||
"""查询任务状态和结果"""
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
return None
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"task_type": task.task_type,
|
||||
"status": task.status,
|
||||
"priority": task.priority.name,
|
||||
"result": task.result,
|
||||
"error_message": task.error_message,
|
||||
"retry_count": task.retry_count,
|
||||
"created_at": task.created_at,
|
||||
"started_at": task.started_at,
|
||||
"completed_at": task.completed_at
|
||||
}
|
||||
|
||||
def get_queue_stats(self) -> Dict:
|
||||
"""获取队列统计信息"""
|
||||
stats = {"total": len(self._tasks)}
|
||||
for status in [TaskStatus.PENDING, TaskStatus.STARTED,
|
||||
TaskStatus.SUCCESS, TaskStatus.FAILURE]:
|
||||
stats[status.lower()] = sum(
|
||||
1 for t in self._tasks.values() if t.status == status
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
# ==================== Celery任务定义 ====================
|
||||
|
||||
class CeleryTaskExecutor:
|
||||
"""
|
||||
Celery任务执行器
|
||||
定义各类AI识别的Celery异步任务
|
||||
每个任务类型对应一个独立的任务函数和执行队列
|
||||
"""
|
||||
|
||||
def __init__(self, queue_manager: TaskQueueManager):
|
||||
self._queue_manager = queue_manager
|
||||
self._task_handlers: Dict[str, callable] = {}
|
||||
logger.info("Celery任务执行器初始化")
|
||||
|
||||
def register_handler(self, task_type: str, handler: callable):
|
||||
"""注册任务处理函数"""
|
||||
self._task_handlers[task_type] = handler
|
||||
logger.info(f"注册任务处理器: {task_type}")
|
||||
|
||||
def execute_task(self, task_id: str) -> Dict:
|
||||
"""
|
||||
执行指定任务
|
||||
包含异常处理、重试逻辑、超时控制
|
||||
"""
|
||||
task = self._queue_manager._tasks.get(task_id)
|
||||
if not task:
|
||||
return {"error": "任务不存在"}
|
||||
|
||||
handler = self._task_handlers.get(task.task_type)
|
||||
if not handler:
|
||||
self._queue_manager.update_task_status(
|
||||
task_id, TaskStatus.FAILURE,
|
||||
error=f"未注册的任务类型: {task.task_type}"
|
||||
)
|
||||
return {"error": f"未注册的任务类型: {task.task_type}"}
|
||||
|
||||
try:
|
||||
self._queue_manager.update_task_status(task_id, TaskStatus.PROCESSING)
|
||||
|
||||
# 执行推理任务
|
||||
start_time = time.time()
|
||||
result = handler(task.input_data)
|
||||
elapsed = (time.time() - start_time) * 1000
|
||||
|
||||
result['processing_time_ms'] = round(elapsed, 2)
|
||||
self._queue_manager.update_task_status(task_id, TaskStatus.SUCCESS, result=result)
|
||||
|
||||
# 审计日志记录(安全设计:所有识别请求记录调用方、时间)
|
||||
logger.info(
|
||||
f"任务执行完成: id={task_id}, type={task.task_type}, "
|
||||
f"time={elapsed:.1f}ms, student={task.student_id}"
|
||||
)
|
||||
|
||||
# 如有回调URL则通知调用方
|
||||
if task.callback_url:
|
||||
self._send_callback(task.callback_url, task_id, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
task.retry_count += 1
|
||||
if task.retry_count < task.max_retries:
|
||||
# 重试:将任务重新加入队列
|
||||
task.status = TaskStatus.RETRY
|
||||
logger.warning(f"任务重试: id={task_id}, retry={task.retry_count}/{task.max_retries}")
|
||||
else:
|
||||
self._queue_manager.update_task_status(
|
||||
task_id, TaskStatus.FAILURE, error=str(e)
|
||||
)
|
||||
logger.error(f"任务最终失败: id={task_id}, error={str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _send_callback(self, url: str, task_id: str, result: Dict):
|
||||
"""发送任务完成回调通知"""
|
||||
try:
|
||||
# 实际环境使用httpx/aiohttp发送POST请求
|
||||
logger.info(f"发送任务回调: url={url}, task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"回调通知失败: {str(e)}")
|
||||
|
||||
|
||||
# ==================== 定时调度器 ====================
|
||||
|
||||
class ScheduledTaskRunner:
|
||||
"""
|
||||
定时任务调度器
|
||||
管理周期性执行的后台任务,如:
|
||||
- 模型健康检查(每5分钟)
|
||||
- 过期任务清理(每小时)
|
||||
- 性能指标采集(每分钟)
|
||||
- 模型更新检查(每天)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._schedules: Dict[str, Dict] = {}
|
||||
self._running = False
|
||||
logger.info("定时任务调度器初始化")
|
||||
|
||||
def register_schedule(self, name: str, interval_seconds: int,
|
||||
handler: callable, description: str = ""):
|
||||
"""注册定时任务"""
|
||||
self._schedules[name] = {
|
||||
"interval": interval_seconds,
|
||||
"handler": handler,
|
||||
"description": description,
|
||||
"last_run": None,
|
||||
"run_count": 0,
|
||||
"error_count": 0
|
||||
}
|
||||
logger.info(f"注册定时任务: {name}, 间隔={interval_seconds}s")
|
||||
|
||||
def run_task(self, name: str) -> Optional[Dict]:
|
||||
"""立即执行指定的定时任务"""
|
||||
schedule = self._schedules.get(name)
|
||||
if not schedule:
|
||||
return None
|
||||
|
||||
try:
|
||||
start = time.time()
|
||||
result = schedule["handler"]()
|
||||
elapsed = time.time() - start
|
||||
schedule["last_run"] = datetime.now().isoformat()
|
||||
schedule["run_count"] += 1
|
||||
logger.info(f"定时任务执行完成: {name}, 耗时={elapsed:.2f}s")
|
||||
return {"name": name, "success": True, "elapsed_s": round(elapsed, 2)}
|
||||
except Exception as e:
|
||||
schedule["error_count"] += 1
|
||||
logger.error(f"定时任务执行失败: {name}, 错误={str(e)}")
|
||||
return {"name": name, "success": False, "error": str(e)}
|
||||
|
||||
def get_schedule_status(self) -> List[Dict]:
|
||||
"""获取所有定时任务状态"""
|
||||
return [{
|
||||
"name": name,
|
||||
"interval_seconds": info["interval"],
|
||||
"description": info["description"],
|
||||
"last_run": info["last_run"],
|
||||
"run_count": info["run_count"],
|
||||
"error_count": info["error_count"]
|
||||
} for name, info in self._schedules.items()]
|
||||
Reference in New Issue
Block a user