315 lines
11 KiB
Python
315 lines
11 KiB
Python
# 自然写手写识别与AI分析引擎软件 V1.0
# Celery异步任务调度模块 - 识别请求异步处理与优先级调度

"""
Celery任务调度服务
管理AI识别请求的异步任务队列,支持优先级调度、任务重试、
结果回调通知、任务进度追踪等功能
使用Redis作为消息Broker和结果Backend
"""
|
||
|
||
import time
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from typing import Dict, List, Optional, Any
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timedelta
|
||
from enum import IntEnum
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 任务优先级定义 ====================
|
||
|
||
class TaskPriority(IntEnum):
    """Scheduling priority of a task; a smaller value means a higher priority."""

    # Highest priority: real-time classroom interaction.
    CRITICAL = 0

    # High priority: teachers grading online.
    HIGH = 1

    # Normal priority: automatic homework grading.
    NORMAL = 2

    # Low priority: batch processing of historical data.
    LOW = 3

    # Background priority: model evaluation / training-data generation.
    BACKGROUND = 4
|
||
|
||
|
||
class TaskStatus:
    """String constants naming the states a task can be in."""

    # Waiting to be executed.
    PENDING = "PENDING"

    # Execution has started.
    STARTED = "STARTED"

    # Currently being processed.
    PROCESSING = "PROCESSING"

    # Finished successfully.
    SUCCESS = "SUCCESS"

    # Finished with a failure.
    FAILURE = "FAILURE"

    # Waiting for another attempt.
    RETRY = "RETRY"

    # Cancelled.
    REVOKED = "REVOKED"
|
||
|
||
|
||
@dataclass
class TaskRecord:
    """Record describing one submitted task and its execution state."""

    # --- identity and scheduling ---
    task_id: str
    # Task type (ocr / math / stroke_order / essay).
    task_type: str
    priority: TaskPriority

    # --- execution state ---
    status: str = TaskStatus.PENDING
    input_data: Dict = field(default_factory=dict)
    result: Optional[Dict] = None
    error_message: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3

    # --- timestamps (stored as strings) ---
    created_at: str = ""
    started_at: Optional[str] = None
    completed_at: Optional[str] = None

    # --- caller-supplied context ---
    # URL to notify once the task has completed.
    callback_url: Optional[str] = None
    student_id: Optional[str] = None
    assignment_id: Optional[str] = None
|
||
|
||
|
||
# ==================== 任务队列管理器 ====================
|
||
|
||
class TaskQueueManager:
    """
    In-memory priority queue for recognition tasks.

    Submitted tasks are kept in a list ordered by ``(priority, created_at)``
    so that higher-priority work (a smaller ``TaskPriority`` value) is handed
    out first. Task records are held in process memory; the Redis URL given
    to the constructor is only stored and logged by this class.
    """

    # Default queue name for each task type.
    QUEUE_MAPPING = {
        "ocr": "writech.ocr",
        "math": "writech.math",
        "stroke_order": "writech.stroke_order",
        "essay": "writech.essay",
        "batch": "writech.batch"
    }

    # Statuses from which a task may be dispatched: fresh tasks and tasks
    # that are waiting for another attempt.
    _RUNNABLE_STATUSES = (TaskStatus.PENDING, TaskStatus.RETRY)

    # Statuses reported by get_queue_stats. The first four keep the order
    # the statistics dict has always used; the rest are appended after them.
    _REPORTED_STATUSES = (
        TaskStatus.PENDING, TaskStatus.STARTED,
        TaskStatus.SUCCESS, TaskStatus.FAILURE,
        TaskStatus.PROCESSING, TaskStatus.RETRY, TaskStatus.REVOKED,
    )

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """
        Create an empty manager.

        Args:
            redis_url: Redis connection URL. It is stored and logged here;
                this class does not open a connection.
        """
        self._redis_url = redis_url
        # In-memory task records keyed by task id (production would use Redis).
        self._tasks: Dict[str, TaskRecord] = {}
        # Tasks ordered by (priority, created_at).
        self._queue: List[TaskRecord] = []
        logger.info(f"任务队列管理器初始化: redis={redis_url}")

    def submit_task(self, task_type: str, input_data: Dict,
                    priority: TaskPriority = TaskPriority.NORMAL,
                    callback_url: Optional[str] = None,
                    student_id: Optional[str] = None,
                    assignment_id: Optional[str] = None) -> str:
        """
        Submit a recognition task to the queue.

        Args:
            task_type: Task type; also used to look up the queue name that
                appears in the log message.
            input_data: Payload stored on the task record.
            priority: Scheduling priority (smaller value runs first).
            callback_url: Optional URL stored on the record.
            student_id: Optional student identifier stored on the record.
            assignment_id: Optional assignment identifier stored on the record.

        Returns:
            The generated task id, usable with ``get_task_status``.
        """
        task_id = str(uuid.uuid4())

        record = TaskRecord(
            task_id=task_id,
            task_type=task_type,
            priority=priority,
            input_data=input_data,
            created_at=datetime.now().isoformat(),
            callback_url=callback_url,
            student_id=student_id,
            assignment_id=assignment_id
        )

        self._tasks[task_id] = record
        self._queue.append(record)
        # Keep the queue ordered by priority, then by creation time. The sort
        # is stable, so tasks with equal keys stay in submission order.
        self._queue.sort(key=lambda t: (t.priority, t.created_at))

        queue_name = self.QUEUE_MAPPING.get(task_type, "writech.default")
        logger.info(
            f"任务已提交: id={task_id}, type={task_type}, "
            f"priority={priority.name}, queue={queue_name}"
        )
        return task_id

    def get_next_task(self) -> Optional[TaskRecord]:
        """
        Hand out the highest-priority task that is ready to run.

        Tasks in PENDING or RETRY state are eligible, so a task that was
        marked for retry is dispatched again instead of being stranded.
        The returned task is marked STARTED and its ``started_at`` is set to
        the current time.

        Returns:
            The task record, or ``None`` when no task is ready.
        """
        for task in self._queue:
            if task.status in self._RUNNABLE_STATUSES:
                task.status = TaskStatus.STARTED
                task.started_at = datetime.now().isoformat()
                return task
        return None

    def update_task_status(self, task_id: str, status: str,
                           result: Optional[Dict] = None,
                           error: Optional[str] = None):
        """
        Update the status of a task and optionally its result or error.

        Args:
            task_id: Id of the task to update. Unknown ids are logged as a
                warning and otherwise ignored.
            status: New status value.
            result: Result to store; ``None`` leaves the stored result as is.
            error: Error message to store; ``None`` leaves it as is.
        """
        task = self._tasks.get(task_id)
        if task is None:
            logger.warning(f"任务状态更新被忽略: 任务不存在, id={task_id}")
            return

        task.status = status
        # Compare against None so that an empty result dict or an empty
        # error string is still recorded.
        if result is not None:
            task.result = result
        if error is not None:
            task.error_message = error
        if status in (TaskStatus.SUCCESS, TaskStatus.FAILURE):
            task.completed_at = datetime.now().isoformat()
        logger.info(f"任务状态更新: id={task_id}, status={status}")

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """
        Look up the status and result of a task.

        Returns:
            A dict snapshot of the task record, or ``None`` for an unknown id.
        """
        task = self._tasks.get(task_id)
        if not task:
            return None
        return {
            "task_id": task.task_id,
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority.name,
            "result": task.result,
            "error_message": task.error_message,
            "retry_count": task.retry_count,
            "created_at": task.created_at,
            "started_at": task.started_at,
            "completed_at": task.completed_at
        }

    def get_queue_stats(self) -> Dict:
        """
        Count the known tasks by status.

        Returns:
            A dict with ``"total"`` plus one lower-cased key per status
            (pending, started, success, failure, processing, retry, revoked).
        """
        # Single pass over the tasks instead of one pass per status.
        counts = {status: 0 for status in self._REPORTED_STATUSES}
        for task in self._tasks.values():
            if task.status in counts:
                counts[task.status] += 1

        stats = {"total": len(self._tasks)}
        for status, count in counts.items():
            stats[status.lower()] = count
        return stats
|
||
|
||
|
||
# ==================== Celery任务定义 ====================
|
||
|
||
class CeleryTaskExecutor:
    """
    Runs queued recognition tasks through registered handler functions.

    One handler is registered per task type. ``execute_task`` looks up the
    handler for a task, runs it, and records the outcome (success, retry or
    failure) on the queue manager.
    """

    def __init__(self, queue_manager: TaskQueueManager):
        """
        Args:
            queue_manager: Manager holding the task records to execute.
        """
        self._queue_manager = queue_manager
        # Handler function per task type.
        self._task_handlers: Dict[str, callable] = {}
        logger.info("Celery任务执行器初始化")

    def register_handler(self, task_type: str, handler: callable):
        """Register (or replace) the handler function for a task type."""
        self._task_handlers[task_type] = handler
        logger.info(f"注册任务处理器: {task_type}")

    def execute_task(self, task_id: str) -> Dict:
        """
        Execute the task with the given id.

        The handler registered for the task's type is called with the task's
        ``input_data``. On success the elapsed time is added to the result
        under ``processing_time_ms`` and the task is marked SUCCESS. When the
        handler raises, the task is marked RETRY while retries remain
        (at most ``max_retries`` retries) and FAILURE afterwards. No timeout
        is enforced here.

        Args:
            task_id: Id of a task known to the queue manager.

        Returns:
            The handler's result dict on success, otherwise a dict with a
            single ``"error"`` key.
        """
        task = self._queue_manager._tasks.get(task_id)
        if task is None:
            return {"error": "任务不存在"}

        handler = self._task_handlers.get(task.task_type)
        if handler is None:
            message = f"未注册的任务类型: {task.task_type}"
            self._queue_manager.update_task_status(
                task_id, TaskStatus.FAILURE, error=message
            )
            return {"error": message}

        try:
            self._queue_manager.update_task_status(task_id, TaskStatus.PROCESSING)

            # Run the inference handler and time it.
            start_time = time.time()
            result = handler(task.input_data)
            elapsed = (time.time() - start_time) * 1000

            result['processing_time_ms'] = round(elapsed, 2)
            self._queue_manager.update_task_status(task_id, TaskStatus.SUCCESS, result=result)

            # Audit log: record which student's request ran and how long it took.
            logger.info(
                f"任务执行完成: id={task_id}, type={task.task_type}, "
                f"time={elapsed:.1f}ms, student={task.student_id}"
            )

            # Notify the caller when a callback URL was supplied.
            if task.callback_url:
                self._send_callback(task.callback_url, task_id, result)

            return result

        except Exception as e:
            return self._handle_failure(task, e)

    def _handle_failure(self, task, exc: Exception) -> Dict:
        """Record a failed attempt: schedule a retry while any remain, else fail the task."""
        task_id = task.task_id
        error = str(exc)

        # Compare before incrementing so that max_retries really is the
        # number of retries allowed and retry_count only counts retries
        # that were actually scheduled.
        if task.retry_count < task.max_retries:
            task.retry_count += 1
            # Go through the manager so the error message is kept on the record.
            self._queue_manager.update_task_status(
                task_id, TaskStatus.RETRY, error=error
            )
            logger.warning(f"任务重试: id={task_id}, retry={task.retry_count}/{task.max_retries}")
        else:
            self._queue_manager.update_task_status(
                task_id, TaskStatus.FAILURE, error=error
            )
            # exc_info keeps the traceback of the final failure in the log.
            logger.error(f"任务最终失败: id={task_id}, error={error}", exc_info=exc)
        return {"error": error}

    def _send_callback(self, url: str, task_id: str, result: Dict):
        """
        Send the task-completed callback notification.

        Currently this only logs the callback; a real deployment would POST
        to ``url`` (e.g. with httpx/aiohttp). Errors are logged, not raised.
        """
        try:
            logger.info(f"发送任务回调: url={url}, task_id={task_id}")
        except Exception as e:
            logger.error(f"回调通知失败: {str(e)}")
|
||
|
||
|
||
# ==================== 定时调度器 ====================
|
||
|
||
class ScheduledTaskRunner:
    """
    Registry of named periodic background tasks.

    Each schedule stores a handler, its interval and run/error counters.
    Typical uses named by the original author: model health checks, cleanup
    of expired tasks, metrics collection and model update checks. This class
    runs a task only when ``run_task`` is called; it contains no timer loop.
    """

    def __init__(self):
        # Schedule name -> schedule info dict (see register_schedule).
        self._schedules: Dict[str, Dict] = {}
        self._running = False
        logger.info("定时任务调度器初始化")

    def register_schedule(self, name: str, interval_seconds: int,
                          handler: callable, description: str = ""):
        """
        Register (or replace) a scheduled task.

        Args:
            name: Unique schedule name.
            interval_seconds: Interval stored with the schedule.
            handler: Zero-argument callable to run.
            description: Optional human-readable description.

        Raises:
            TypeError: If ``handler`` is not callable.
        """
        # Fail at registration time rather than on the first run.
        if not callable(handler):
            raise TypeError(
                f"handler for schedule {name!r} must be callable, "
                f"got {type(handler).__name__}"
            )

        self._schedules[name] = {
            "interval": interval_seconds,
            "handler": handler,
            "description": description,
            "last_run": None,
            "run_count": 0,
            "error_count": 0
        }
        logger.info(f"注册定时任务: {name}, 间隔={interval_seconds}s")

    def run_task(self, name: str) -> Optional[Dict]:
        """
        Run the named scheduled task immediately.

        Returns:
            ``None`` when no schedule has that name; otherwise a dict with
            ``name`` and ``success``, plus ``elapsed_s`` on success or
            ``error`` on failure.
        """
        schedule = self._schedules.get(name)
        if not schedule:
            return None

        try:
            # perf_counter is monotonic, so the measured duration is not
            # affected by system clock adjustments.
            start = time.perf_counter()
            schedule["handler"]()
            elapsed = time.perf_counter() - start
            schedule["last_run"] = datetime.now().isoformat()
            schedule["run_count"] += 1
            logger.info(f"定时任务执行完成: {name}, 耗时={elapsed:.2f}s")
            return {"name": name, "success": True, "elapsed_s": round(elapsed, 2)}
        except Exception as e:
            schedule["error_count"] += 1
            # logger.exception logs at ERROR level and appends the traceback.
            logger.exception(f"定时任务执行失败: {name}, 错误={str(e)}")
            return {"name": name, "success": False, "error": str(e)}

    def get_schedule_status(self) -> List[Dict]:
        """Return one status dict per registered schedule, in registration order."""
        return [{
            "name": name,
            "interval_seconds": info["interval"],
            "description": info["description"],
            "last_run": info["last_run"],
            "run_count": info["run_count"],
            "error_count": info["error_count"]
        } for name, info in self._schedules.items()]
|