Files
system-design/software-copyright/02-writech-ai-engine/service/task_scheduler.py
T
2026-03-22 15:24:40 +08:00

315 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 自然写手写识别与AI分析引擎软件 V1.0
# Celery异步任务调度模块 - 识别请求异步处理与优先级调度
"""
Celery任务调度服务
管理AI识别请求的异步任务队列,支持优先级调度、任务重试、
结果回调通知、任务进度追踪等功能
使用Redis作为消息Broker和结果Backend
"""
import time
import json
import logging
import uuid
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import IntEnum
logger = logging.getLogger(__name__)
# ==================== 任务优先级定义 ====================
class TaskPriority(IntEnum):
    """Task priority levels (a smaller value means a higher priority)."""
    CRITICAL = 0  # Highest priority: real-time classroom interaction
    HIGH = 1  # High priority: teachers grading online
    NORMAL = 2  # Normal priority: automatic homework grading
    LOW = 3  # Low priority: batch processing of historical data
    BACKGROUND = 4  # Background priority: model evaluation / training-data generation
class TaskStatus:
    """Task status constants (plain strings)."""
    PENDING = "PENDING"  # Waiting to be executed
    STARTED = "STARTED"  # Execution has started
    PROCESSING = "PROCESSING"  # Being processed
    SUCCESS = "SUCCESS"  # Finished successfully
    FAILURE = "FAILURE"  # Finished with a failure
    RETRY = "RETRY"  # Being retried
    REVOKED = "REVOKED"  # Cancelled
@dataclass
class TaskRecord:
    """Record describing a single submitted task and its lifecycle state."""
    task_id: str
    task_type: str  # Task type (ocr/math/stroke_order/essay)
    priority: TaskPriority
    status: str = TaskStatus.PENDING
    input_data: Dict = field(default_factory=dict)
    result: Optional[Dict] = None
    error_message: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3
    created_at: str = ""
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    callback_url: Optional[str] = None  # URL to notify once the task completes
    student_id: Optional[str] = None
    assignment_id: Optional[str] = None
# ==================== 任务队列管理器 ====================
class TaskQueueManager:
    """
    Priority-ordered task queue for AI recognition requests.

    Tasks are kept in an in-memory list ordered by ``(priority, created_at)``
    so that higher-priority work (e.g. real-time classroom interaction) is
    dispatched first. The Redis URL is only stored and logged by this class;
    a production deployment is expected to keep the records in Redis instead.
    """

    # Default queue name for each task type
    QUEUE_MAPPING = {
        "ocr": "writech.ocr",
        "math": "writech.math",
        "stroke_order": "writech.stroke_order",
        "essay": "writech.essay",
        "batch": "writech.batch"
    }

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """Create an empty queue; ``redis_url`` is recorded but not connected to."""
        self._redis_url = redis_url
        self._tasks: Dict[str, TaskRecord] = {}  # In-memory task records, keyed by task id
        self._queue: List[TaskRecord] = []  # Dispatch order: (priority, created_at)
        logger.info(f"任务队列管理器初始化: redis={redis_url}")

    def submit_task(self, task_type: str, input_data: Dict,
                    priority: TaskPriority = TaskPriority.NORMAL,
                    callback_url: Optional[str] = None,
                    student_id: Optional[str] = None,
                    assignment_id: Optional[str] = None) -> str:
        """
        Submit a recognition task to the queue.

        Args:
            task_type: Task type; unknown types are logged under "writech.default".
            input_data: Payload handed to the task handler.
            priority: A ``TaskPriority`` member or its integer value.
            callback_url: Optional URL to notify when the task completes.
            student_id: Optional student identifier stored on the record.
            assignment_id: Optional assignment identifier stored on the record.

        Returns:
            The generated task id, usable to query status and result.

        Raises:
            ValueError: If ``priority`` is not a valid ``TaskPriority`` value.
        """
        # Normalize before touching any state: a plain int used to blow up on
        # ``priority.name`` only after the record had already been enqueued.
        priority = TaskPriority(priority)

        task_id = str(uuid.uuid4())
        record = TaskRecord(
            task_id=task_id,
            task_type=task_type,
            priority=priority,
            input_data=input_data,
            created_at=datetime.now().isoformat(),
            callback_url=callback_url,
            student_id=student_id,
            assignment_id=assignment_id
        )
        self._tasks[task_id] = record
        self._queue.append(record)
        # Smaller priority value first; the stable sort keeps submission order
        # among tasks that share the same key.
        self._queue.sort(key=lambda t: (t.priority, t.created_at))

        queue_name = self.QUEUE_MAPPING.get(task_type, "writech.default")
        logger.info(
            f"任务已提交: id={task_id}, type={task_type}, "
            f"priority={priority.name}, queue={queue_name}"
        )
        return task_id

    def get_next_task(self) -> Optional[TaskRecord]:
        """
        Return the highest-priority runnable task and mark it STARTED.

        Both PENDING and RETRY tasks are runnable, so a task that failed and
        was marked for retry is dispatched again instead of being stranded.
        Returns ``None`` when nothing is runnable.
        """
        finished = (TaskStatus.SUCCESS, TaskStatus.FAILURE, TaskStatus.REVOKED)
        # Drop finished tasks so the scan does not grow with history; the
        # records themselves stay available through ``_tasks``.
        self._queue[:] = [t for t in self._queue if t.status not in finished]

        runnable = (TaskStatus.PENDING, TaskStatus.RETRY)
        for task in self._queue:
            if task.status in runnable:
                task.status = TaskStatus.STARTED
                task.started_at = datetime.now().isoformat()
                return task
        return None

    def update_task_status(self, task_id: str, status: str,
                           result: Optional[Dict] = None,
                           error: Optional[str] = None):
        """
        Update the status of a task, optionally storing a result or error.

        ``result`` and ``error`` are stored whenever they are not ``None``, so
        an empty dict or empty string is kept. ``completed_at`` is stamped
        when the status becomes SUCCESS or FAILURE. Unknown task ids are
        logged and otherwise ignored.
        """
        task = self._tasks.get(task_id)
        if task is None:
            logger.warning(f"任务状态更新失败,任务不存在: id={task_id}")
            return

        task.status = status
        if result is not None:
            task.result = result
        if error is not None:
            task.error_message = error
        if status in (TaskStatus.SUCCESS, TaskStatus.FAILURE):
            task.completed_at = datetime.now().isoformat()
        logger.info(f"任务状态更新: id={task_id}, status={status}")

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """Return a status/result snapshot for ``task_id``, or ``None`` if unknown."""
        task = self._tasks.get(task_id)
        if not task:
            return None
        return {
            "task_id": task.task_id,
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority.name,
            "result": task.result,
            "error_message": task.error_message,
            "retry_count": task.retry_count,
            "created_at": task.created_at,
            "started_at": task.started_at,
            "completed_at": task.completed_at
        }

    def get_queue_stats(self) -> Dict:
        """
        Return task counts: ``total`` plus one lower-cased key per status.

        Every status constant is reported (including processing, retry and
        revoked), computed in a single pass over the task records.
        """
        counts = dict.fromkeys(
            (TaskStatus.PENDING, TaskStatus.STARTED, TaskStatus.PROCESSING,
             TaskStatus.SUCCESS, TaskStatus.FAILURE, TaskStatus.RETRY,
             TaskStatus.REVOKED),
            0
        )
        for task in self._tasks.values():
            if task.status in counts:
                counts[task.status] += 1

        stats = {"total": len(self._tasks)}
        stats.update((status.lower(), count) for status, count in counts.items())
        return stats
# ==================== Celery任务定义 ====================
class CeleryTaskExecutor:
    """
    Task executor for the AI recognition task types.

    Handlers are registered per task type and invoked synchronously by
    ``execute_task``, which records status transitions on the queue manager,
    applies the retry policy and triggers the completion callback.
    No execution timeout is enforced by this class.
    """

    def __init__(self, queue_manager: TaskQueueManager):
        """Bind the executor to the queue manager that owns the task records."""
        self._queue_manager = queue_manager
        self._task_handlers: Dict[str, callable] = {}
        logger.info("Celery任务执行器初始化")

    def register_handler(self, task_type: str, handler: callable):
        """Register (or replace) the handler called for ``task_type``."""
        self._task_handlers[task_type] = handler
        logger.info(f"注册任务处理器: {task_type}")

    def execute_task(self, task_id: str) -> Dict:
        """
        Execute the task identified by ``task_id``.

        Returns:
            On success, a dict holding the handler output plus
            ``processing_time_ms``. A non-dict handler return value is wrapped
            as ``{"result": value}``. On any failure, ``{"error": message}``.
        """
        # NOTE(review): reads the manager's private ``_tasks`` mapping because
        # no public accessor for the TaskRecord itself is visible here.
        task = self._queue_manager._tasks.get(task_id)
        if not task:
            return {"error": "任务不存在"}

        # A cancelled task must not be run.
        if task.status == TaskStatus.REVOKED:
            return {"error": "任务已取消"}

        handler = self._task_handlers.get(task.task_type)
        if not handler:
            message = f"未注册的任务类型: {task.task_type}"
            self._queue_manager.update_task_status(
                task_id, TaskStatus.FAILURE, error=message
            )
            return {"error": message}

        self._queue_manager.update_task_status(task_id, TaskStatus.PROCESSING)

        # Only the handler call is guarded: an error in the bookkeeping below
        # must not turn an already-successful run into a retry.
        start_time = time.perf_counter()  # monotonic clock for durations
        try:
            output = handler(task.input_data)
        except Exception as e:
            return self._record_failure(task, e)
        elapsed = (time.perf_counter() - start_time) * 1000

        # Copy so the handler's own dict is not mutated; wrap non-dict output
        # so adding the timing field cannot raise.
        result = dict(output) if isinstance(output, dict) else {"result": output}
        result['processing_time_ms'] = round(elapsed, 2)
        self._queue_manager.update_task_status(task_id, TaskStatus.SUCCESS, result=result)

        # Audit log: records which task ran, for whom, and how long it took.
        logger.info(
            f"任务执行完成: id={task_id}, type={task.task_type}, "
            f"time={elapsed:.1f}ms, student={task.student_id}"
        )

        # Notify the caller when a callback URL was supplied.
        if task.callback_url:
            self._send_callback(task.callback_url, task_id, result)
        return result

    def _record_failure(self, task, exc: Exception) -> Dict:
        """Apply the retry policy after a handler error and return the error dict."""
        task_id = task.task_id
        error_text = str(exc) or type(exc).__name__

        # ``max_retries`` is the number of retries allowed after the first
        # attempt, so a retry is granted while fewer than that were used.
        if task.retry_count < task.max_retries:
            task.retry_count += 1
            # Going through the manager keeps the error message on the record.
            self._queue_manager.update_task_status(
                task_id, TaskStatus.RETRY, error=error_text
            )
            logger.warning(f"任务重试: id={task_id}, retry={task.retry_count}/{task.max_retries}")
        else:
            self._queue_manager.update_task_status(
                task_id, TaskStatus.FAILURE, error=error_text
            )
            logger.error(f"任务最终失败: id={task_id}, error={error_text}", exc_info=exc)
        return {"error": error_text}

    def _send_callback(self, url: str, task_id: str, result: Dict):
        """
        Send the task-completion callback notification.

        Currently only logs the callback; a real deployment is expected to
        POST ``result`` to ``url`` with an HTTP client. Failures are logged
        and never propagated to the caller.
        """
        try:
            logger.info(f"发送任务回调: url={url}, task_id={task_id}")
        except Exception as e:
            logger.error(f"回调通知失败: {str(e)}")
# ==================== 定时调度器 ====================
class ScheduledTaskRunner:
    """
    Registry and runner for periodic background tasks, such as:
    - model health check (every 5 minutes)
    - expired task cleanup (every hour)
    - performance metric collection (every minute)
    - model update check (every day)

    This class stores each task's interval and run statistics and runs a task
    on demand through ``run_task``; it does not itself start a timer loop.
    """

    def __init__(self):
        """Create a runner with no registered schedules."""
        self._schedules: Dict[str, Dict] = {}
        self._running = False
        logger.info("定时任务调度器初始化")

    def register_schedule(self, name: str, interval_seconds: int,
                          handler: callable, description: str = ""):
        """
        Register a scheduled task under ``name``.

        Registering an existing name replaces the entry and resets its
        run statistics.
        """
        self._schedules[name] = {
            "interval": interval_seconds,
            "handler": handler,
            "description": description,
            "last_run": None,
            "run_count": 0,
            "error_count": 0
        }
        logger.info(f"注册定时任务: {name}, 间隔={interval_seconds}s")

    def run_task(self, name: str) -> Optional[Dict]:
        """
        Run the scheduled task ``name`` immediately.

        Returns:
            ``None`` if no such task is registered; otherwise a dict with
            ``name`` and ``success``, plus ``elapsed_s`` on success or
            ``error`` on failure. Handler exceptions are caught and counted
            in ``error_count``; ``last_run`` and ``run_count`` are only
            updated on success.
        """
        schedule = self._schedules.get(name)
        if not schedule:
            return None

        try:
            # perf_counter is monotonic, so the duration is not affected by
            # system clock adjustments.
            start = time.perf_counter()
            schedule["handler"]()
            elapsed = time.perf_counter() - start
        except Exception as e:
            schedule["error_count"] += 1
            # exc_info keeps the traceback alongside the unchanged message.
            logger.error(f"定时任务执行失败: {name}, 错误={str(e)}", exc_info=e)
            return {"name": name, "success": False, "error": str(e)}

        schedule["last_run"] = datetime.now().isoformat()
        schedule["run_count"] += 1
        logger.info(f"定时任务执行完成: {name}, 耗时={elapsed:.2f}s")
        return {"name": name, "success": True, "elapsed_s": round(elapsed, 2)}

    def get_schedule_status(self) -> List[Dict]:
        """Return the interval, description and run statistics of every task."""
        return [{
            "name": name,
            "interval_seconds": info["interval"],
            "description": info["description"],
            "last_run": info["last_run"],
            "run_count": info["run_count"],
            "error_count": info["error_count"]
        } for name, info in self._schedules.items()]