# Natural Writing Handwriting Recognition and AI Analysis Engine V1.0
# Celery async task scheduling module: asynchronous handling and priority
# scheduling of recognition requests.
"""
Celery task scheduling service.

Manages the asynchronous task queue for AI recognition requests: priority
scheduling, task retry, result callback notification and task progress
tracking.

Design note: the intended production setup uses Redis as both the message
broker and the result backend; the classes below currently keep all state in
process memory.
"""
import time
import json
import logging
import uuid
from collections import Counter
from typing import Dict, List, Optional, Any, Callable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import IntEnum

logger = logging.getLogger(__name__)


# ==================== Task priority definitions ====================

class TaskPriority(IntEnum):
    """Task priority (a smaller value means a higher priority)."""

    CRITICAL = 0    # highest: real-time classroom interaction
    HIGH = 1        # high: teacher grading online
    NORMAL = 2      # normal: automatic homework grading
    LOW = 3         # low: batch processing of historical data
    BACKGROUND = 4  # background: model evaluation / training-data generation


class TaskStatus:
    """String constants for the task lifecycle states."""

    PENDING = "PENDING"        # waiting to be executed
    STARTED = "STARTED"        # claimed by get_next_task()
    PROCESSING = "PROCESSING"  # handler is running
    SUCCESS = "SUCCESS"        # finished successfully
    FAILURE = "FAILURE"        # failed permanently
    RETRY = "RETRY"            # failed, waiting to be dispatched again
    REVOKED = "REVOKED"        # cancelled

    # Every state, in lifecycle order (used for statistics).
    ALL = (PENDING, STARTED, PROCESSING, SUCCESS, FAILURE, RETRY, REVOKED)
    # States that end a task; entering one stamps ``completed_at``.
    TERMINAL = (SUCCESS, FAILURE, REVOKED)
    # States from which get_next_task() may (re)dispatch a task.
    RUNNABLE = (PENDING, RETRY)


@dataclass
class TaskRecord:
    """In-memory record of one submitted task and its progress."""

    task_id: str
    task_type: str  # task type (ocr/math/stroke_order/essay)
    priority: TaskPriority
    status: str = TaskStatus.PENDING
    input_data: Dict = field(default_factory=dict)
    result: Optional[Dict] = None
    error_message: Optional[str] = None
    retry_count: int = 0   # number of retries scheduled so far
    max_retries: int = 3   # retries allowed after the first failed attempt
    created_at: str = ""   # ISO-8601 timestamp, set by submit_task()
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    callback_url: Optional[str] = None  # URL notified when the task completes
    student_id: Optional[str] = None
    assignment_id: Optional[str] = None


# ==================== Task queue manager ====================

class TaskQueueManager:
    """
    Task queue manager.

    Keeps tasks ordered by priority so that high-priority work (e.g. real-time
    classroom interaction) is dispatched first. The intended production design
    is a Redis sorted set (ZSET); this implementation keeps the records and the
    ordered queue in process memory.
    """

    # Default queue name per task type (only used for logging here).
    QUEUE_MAPPING = {
        "ocr": "writech.ocr",
        "math": "writech.math",
        "stroke_order": "writech.stroke_order",
        "essay": "writech.essay",
        "batch": "writech.batch",
    }

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        self._redis_url = redis_url
        # task_id -> record (in-memory; production would use Redis)
        self._tasks: Dict[str, TaskRecord] = {}
        # Every submitted record, kept sorted by (priority, created_at).
        self._queue: List[TaskRecord] = []
        logger.info(f"任务队列管理器初始化: redis={redis_url}")

    def submit_task(self, task_type: str, input_data: Dict,
                    priority: TaskPriority = TaskPriority.NORMAL,
                    callback_url: Optional[str] = None,
                    student_id: Optional[str] = None,
                    assignment_id: Optional[str] = None) -> str:
        """
        Submit a recognition task to the queue.

        Returns the new task ID, which callers use to query the task's status
        and result.
        """
        task_id = str(uuid.uuid4())
        record = TaskRecord(
            task_id=task_id,
            task_type=task_type,
            priority=priority,
            input_data=input_data,
            created_at=datetime.now().isoformat(),
            callback_url=callback_url,
            student_id=student_id,
            assignment_id=assignment_id,
        )
        self._tasks[task_id] = record
        self._queue.append(record)
        # Smaller priority value first, then earlier submission. list.sort is
        # stable, so records with equal keys keep their submission order.
        self._queue.sort(key=lambda t: (t.priority, t.created_at))

        queue_name = self.QUEUE_MAPPING.get(task_type, "writech.default")
        logger.info(
            f"任务已提交: id={task_id}, type={task_type}, "
            f"priority={priority.name}, queue={queue_name}"
        )
        return task_id

    def get_task(self, task_id: str) -> Optional[TaskRecord]:
        """Return the raw record for *task_id*, or ``None`` if it is unknown."""
        return self._tasks.get(task_id)

    def get_next_task(self) -> Optional[TaskRecord]:
        """
        Claim the highest-priority runnable task.

        Both PENDING tasks and tasks marked RETRY by a failed attempt are
        eligible. The claimed record is set to STARTED and its ``started_at``
        is stamped. Returns ``None`` when nothing is runnable.
        """
        for task in self._queue:
            if task.status in TaskStatus.RUNNABLE:
                task.status = TaskStatus.STARTED
                task.started_at = datetime.now().isoformat()
                return task
        return None

    def update_task_status(self, task_id: str, status: str,
                           result: Optional[Dict] = None,
                           error: Optional[str] = None):
        """
        Update a task's status and, optionally, its result / error message.

        A terminal status (SUCCESS, FAILURE, REVOKED) also stamps
        ``completed_at``. Unknown task IDs are logged and ignored.
        """
        task = self._tasks.get(task_id)
        if task is None:
            logger.warning(f"任务不存在,忽略状态更新: id={task_id}, status={status}")
            return
        task.status = status
        # ``is not None`` so that a legitimately empty result/error is stored.
        if result is not None:
            task.result = result
        if error is not None:
            task.error_message = error
        if status in TaskStatus.TERMINAL:
            task.completed_at = datetime.now().isoformat()
        logger.info(f"任务状态更新: id={task_id}, status={status}")

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """Return a serializable snapshot of the task, or ``None`` if unknown."""
        task = self._tasks.get(task_id)
        if not task:
            return None
        return {
            "task_id": task.task_id,
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority.name,
            "result": task.result,
            "error_message": task.error_message,
            "retry_count": task.retry_count,
            "created_at": task.created_at,
            "started_at": task.started_at,
            "completed_at": task.completed_at,
        }

    def get_queue_stats(self) -> Dict:
        """
        Return queue statistics.

        ``total`` is the number of known tasks; every status in
        ``TaskStatus.ALL`` gets a lower-cased key with its task count, so the
        per-status counts add up to ``total``.
        """
        counts = Counter(t.status for t in self._tasks.values())
        stats = {"total": len(self._tasks)}
        for status in TaskStatus.ALL:
            stats[status.lower()] = counts.get(status, 0)
        return stats


# ==================== Task execution ====================

class CeleryTaskExecutor:
    """
    Task executor.

    Dispatches each task to the handler registered for its task type, records
    the outcome in the queue manager and applies the retry policy. Despite the
    name, this class does not call Celery itself: handlers run synchronously
    inside execute_task().
    """

    def __init__(self, queue_manager: TaskQueueManager):
        self._queue_manager = queue_manager
        self._task_handlers: Dict[str, Callable[[Dict], Dict]] = {}
        logger.info("Celery任务执行器初始化")

    def register_handler(self, task_type: str, handler: Callable[[Dict], Dict]):
        """Register *handler* (input_data -> result dict) for *task_type*."""
        self._task_handlers[task_type] = handler
        logger.info(f"注册任务处理器: {task_type}")

    def execute_task(self, task_id: str) -> Dict:
        """
        Execute one task.

        Returns the handler's result dict (augmented with
        ``processing_time_ms``) on success, or ``{"error": message}`` when the
        task is unknown, has no handler, or the handler raises. A handler
        exception marks the task RETRY while retries remain (so
        get_next_task() will dispatch it again), otherwise FAILURE.
        """
        task = self._queue_manager.get_task(task_id)
        if task is None:
            return {"error": "任务不存在"}

        handler = self._task_handlers.get(task.task_type)
        if handler is None:
            message = f"未注册的任务类型: {task.task_type}"
            self._queue_manager.update_task_status(
                task_id, TaskStatus.FAILURE, error=message
            )
            return {"error": message}

        self._queue_manager.update_task_status(task_id, TaskStatus.PROCESSING)
        # Only the handler call (and annotating its result) is guarded, so a
        # problem in the bookkeeping below can never masquerade as a retry.
        try:
            start_time = time.perf_counter()  # monotonic, unaffected by clock changes
            result = handler(task.input_data)
            elapsed = (time.perf_counter() - start_time) * 1000
            result['processing_time_ms'] = round(elapsed, 2)
        except Exception as e:
            return self._handle_failure(task, e)

        self._queue_manager.update_task_status(task_id, TaskStatus.SUCCESS, result=result)

        # Audit log (security design: every recognition request is recorded
        # with its caller and time).
        logger.info(
            f"任务执行完成: id={task_id}, type={task.task_type}, "
            f"time={elapsed:.1f}ms, student={task.student_id}"
        )

        # Notify the caller if a callback URL was supplied.
        if task.callback_url:
            self._send_callback(task.callback_url, task_id, result)
        return result

    def _handle_failure(self, task: TaskRecord, exc: Exception) -> Dict:
        """Apply the retry policy after a failed attempt; return the error payload."""
        if task.retry_count < task.max_retries:
            # Schedule another attempt: the record is still in the queue and
            # RETRY is runnable, so get_next_task() will hand it out again.
            task.retry_count += 1
            task.status = TaskStatus.RETRY
            logger.warning(
                f"任务重试: id={task.task_id}, retry={task.retry_count}/{task.max_retries}"
            )
        else:
            self._queue_manager.update_task_status(
                task.task_id, TaskStatus.FAILURE, error=str(exc)
            )
            logger.error(
                f"任务最终失败: id={task.task_id}, error={str(exc)}", exc_info=exc
            )
        return {"error": str(exc)}

    def _send_callback(self, url: str, task_id: str, result: Dict):
        """
        Send the task-completion notification to *url*.

        Currently a stub that only logs; the real deployment is meant to send
        a POST request with httpx/aiohttp (``result`` is presumably the
        payload — confirm). Failures are logged, never raised.
        """
        try:
            # TODO: send the POST request with httpx/aiohttp.
            logger.info(f"发送任务回调: url={url}, task_id={task_id}")
        except Exception as e:
            logger.error(f"回调通知失败: {str(e)}")


# ==================== Scheduled jobs ====================

class ScheduledTaskRunner:
    """
    Registry of periodic background jobs, for example:
      - model health check (every 5 minutes)
      - expired-task cleanup (hourly)
      - performance metric collection (every minute)
      - model update check (daily)

    Note: each job's interval is stored, but this class contains no timing
    loop; a job runs only when run_task() is called.
    """

    def __init__(self):
        self._schedules: Dict[str, Dict] = {}
        self._running = False  # NOTE(review): not used by the code in this class
        logger.info("定时任务调度器初始化")

    def register_schedule(self, name: str, interval_seconds: int,
                          handler: Callable[[], Any], description: str = ""):
        """Register (or replace) the job *name* with zeroed counters."""
        self._schedules[name] = {
            "interval": interval_seconds,
            "handler": handler,
            "description": description,
            "last_run": None,
            "run_count": 0,
            "error_count": 0,
        }
        logger.info(f"注册定时任务: {name}, 间隔={interval_seconds}s")

    def run_task(self, name: str) -> Optional[Dict]:
        """
        Run the job *name* immediately.

        Returns ``None`` for an unknown job; otherwise a dict with ``name``
        and ``success`` plus ``elapsed_s`` (success) or ``error`` (failure).
        Only successful runs update ``last_run`` and ``run_count``.
        """
        schedule = self._schedules.get(name)
        if schedule is None:
            return None
        try:
            start = time.perf_counter()
            schedule["handler"]()  # the job's return value is not used
            elapsed = time.perf_counter() - start
        except Exception as e:
            schedule["error_count"] += 1
            logger.error(f"定时任务执行失败: {name}, 错误={str(e)}")
            return {"name": name, "success": False, "error": str(e)}

        schedule["last_run"] = datetime.now().isoformat()
        schedule["run_count"] += 1
        logger.info(f"定时任务执行完成: {name}, 耗时={elapsed:.2f}s")
        return {"name": name, "success": True, "elapsed_s": round(elapsed, 2)}

    def get_schedule_status(self) -> List[Dict]:
        """Return one status dict per registered job (handler excluded)."""
        return [
            {
                "name": name,
                "interval_seconds": info["interval"],
                "description": info["description"],
                "last_run": info["last_run"],
                "run_count": info["run_count"],
                "error_count": info["error_count"],
            }
            for name, info in self._schedules.items()
        ]