software copyright
This commit is contained in:
@@ -0,0 +1,314 @@
|
||||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||||
# Celery异步任务调度模块 - 识别请求异步处理与优先级调度
|
||||
|
||||
"""
|
||||
Celery任务调度服务
|
||||
管理AI识别请求的异步任务队列,支持优先级调度、任务重试、
|
||||
结果回调通知、任务进度追踪等功能
|
||||
使用Redis作为消息Broker和结果Backend
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from enum import IntEnum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 任务优先级定义 ====================
|
||||
|
||||
class TaskPriority(IntEnum):
|
||||
"""任务优先级(数值越小优先级越高)"""
|
||||
CRITICAL = 0 # 最高优先级:课堂实时互动场景
|
||||
HIGH = 1 # 高优先级:教师在线批改
|
||||
NORMAL = 2 # 普通优先级:作业自动批改
|
||||
LOW = 3 # 低优先级:批量历史数据处理
|
||||
BACKGROUND = 4 # 后台优先级:模型评估/训练数据生成
|
||||
|
||||
|
||||
class TaskStatus:
|
||||
"""任务状态常量"""
|
||||
PENDING = "PENDING" # 等待执行
|
||||
STARTED = "STARTED" # 已开始执行
|
||||
PROCESSING = "PROCESSING" # 处理中
|
||||
SUCCESS = "SUCCESS" # 执行成功
|
||||
FAILURE = "FAILURE" # 执行失败
|
||||
RETRY = "RETRY" # 重试中
|
||||
REVOKED = "REVOKED" # 已取消
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskRecord:
|
||||
"""任务记录"""
|
||||
task_id: str
|
||||
task_type: str # 任务类型(ocr/math/stroke_order/essay)
|
||||
priority: TaskPriority
|
||||
status: str = TaskStatus.PENDING
|
||||
input_data: Dict = field(default_factory=dict)
|
||||
result: Optional[Dict] = None
|
||||
error_message: Optional[str] = None
|
||||
retry_count: int = 0
|
||||
max_retries: int = 3
|
||||
created_at: str = ""
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = None
|
||||
callback_url: Optional[str] = None # 完成后回调通知URL
|
||||
student_id: Optional[str] = None
|
||||
assignment_id: Optional[str] = None
|
||||
|
||||
|
||||
# ==================== 任务队列管理器 ====================
|
||||
|
||||
class TaskQueueManager:
|
||||
"""
|
||||
任务队列管理器
|
||||
管理多个优先级队列,确保高优先级任务(如课堂实时互动)优先处理
|
||||
使用Redis有序集合(ZSET)实现优先级调度
|
||||
"""
|
||||
|
||||
# 各任务类型的默认队列名
|
||||
QUEUE_MAPPING = {
|
||||
"ocr": "writech.ocr",
|
||||
"math": "writech.math",
|
||||
"stroke_order": "writech.stroke_order",
|
||||
"essay": "writech.essay",
|
||||
"batch": "writech.batch"
|
||||
}
|
||||
|
||||
def __init__(self, redis_url: str = "redis://localhost:6379/0"):
|
||||
self._redis_url = redis_url
|
||||
self._tasks: Dict[str, TaskRecord] = {} # 内存任务记录(生产环境用Redis)
|
||||
self._queue: List[TaskRecord] = [] # 优先级队列
|
||||
logger.info(f"任务队列管理器初始化: redis={redis_url}")
|
||||
|
||||
def submit_task(self, task_type: str, input_data: Dict,
|
||||
priority: TaskPriority = TaskPriority.NORMAL,
|
||||
callback_url: Optional[str] = None,
|
||||
student_id: Optional[str] = None,
|
||||
assignment_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
提交识别任务到队列
|
||||
返回任务ID,调用方可通过ID查询任务状态和结果
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
record = TaskRecord(
|
||||
task_id=task_id,
|
||||
task_type=task_type,
|
||||
priority=priority,
|
||||
input_data=input_data,
|
||||
created_at=datetime.now().isoformat(),
|
||||
callback_url=callback_url,
|
||||
student_id=student_id,
|
||||
assignment_id=assignment_id
|
||||
)
|
||||
|
||||
self._tasks[task_id] = record
|
||||
self._queue.append(record)
|
||||
# 按优先级排序(数值小的排在前面)
|
||||
self._queue.sort(key=lambda t: (t.priority, t.created_at))
|
||||
|
||||
queue_name = self.QUEUE_MAPPING.get(task_type, "writech.default")
|
||||
logger.info(
|
||||
f"任务已提交: id={task_id}, type={task_type}, "
|
||||
f"priority={priority.name}, queue={queue_name}"
|
||||
)
|
||||
return task_id
|
||||
|
||||
def get_next_task(self) -> Optional[TaskRecord]:
|
||||
"""获取队列中优先级最高的待执行任务"""
|
||||
for task in self._queue:
|
||||
if task.status == TaskStatus.PENDING:
|
||||
task.status = TaskStatus.STARTED
|
||||
task.started_at = datetime.now().isoformat()
|
||||
return task
|
||||
return None
|
||||
|
||||
def update_task_status(self, task_id: str, status: str,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None):
|
||||
"""更新任务状态"""
|
||||
if task_id in self._tasks:
|
||||
task = self._tasks[task_id]
|
||||
task.status = status
|
||||
if result:
|
||||
task.result = result
|
||||
if error:
|
||||
task.error_message = error
|
||||
if status in (TaskStatus.SUCCESS, TaskStatus.FAILURE):
|
||||
task.completed_at = datetime.now().isoformat()
|
||||
logger.info(f"任务状态更新: id={task_id}, status={status}")
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[Dict]:
|
||||
"""查询任务状态和结果"""
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
return None
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"task_type": task.task_type,
|
||||
"status": task.status,
|
||||
"priority": task.priority.name,
|
||||
"result": task.result,
|
||||
"error_message": task.error_message,
|
||||
"retry_count": task.retry_count,
|
||||
"created_at": task.created_at,
|
||||
"started_at": task.started_at,
|
||||
"completed_at": task.completed_at
|
||||
}
|
||||
|
||||
def get_queue_stats(self) -> Dict:
|
||||
"""获取队列统计信息"""
|
||||
stats = {"total": len(self._tasks)}
|
||||
for status in [TaskStatus.PENDING, TaskStatus.STARTED,
|
||||
TaskStatus.SUCCESS, TaskStatus.FAILURE]:
|
||||
stats[status.lower()] = sum(
|
||||
1 for t in self._tasks.values() if t.status == status
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
# ==================== Celery任务定义 ====================
|
||||
|
||||
class CeleryTaskExecutor:
|
||||
"""
|
||||
Celery任务执行器
|
||||
定义各类AI识别的Celery异步任务
|
||||
每个任务类型对应一个独立的任务函数和执行队列
|
||||
"""
|
||||
|
||||
def __init__(self, queue_manager: TaskQueueManager):
|
||||
self._queue_manager = queue_manager
|
||||
self._task_handlers: Dict[str, callable] = {}
|
||||
logger.info("Celery任务执行器初始化")
|
||||
|
||||
def register_handler(self, task_type: str, handler: callable):
|
||||
"""注册任务处理函数"""
|
||||
self._task_handlers[task_type] = handler
|
||||
logger.info(f"注册任务处理器: {task_type}")
|
||||
|
||||
def execute_task(self, task_id: str) -> Dict:
|
||||
"""
|
||||
执行指定任务
|
||||
包含异常处理、重试逻辑、超时控制
|
||||
"""
|
||||
task = self._queue_manager._tasks.get(task_id)
|
||||
if not task:
|
||||
return {"error": "任务不存在"}
|
||||
|
||||
handler = self._task_handlers.get(task.task_type)
|
||||
if not handler:
|
||||
self._queue_manager.update_task_status(
|
||||
task_id, TaskStatus.FAILURE,
|
||||
error=f"未注册的任务类型: {task.task_type}"
|
||||
)
|
||||
return {"error": f"未注册的任务类型: {task.task_type}"}
|
||||
|
||||
try:
|
||||
self._queue_manager.update_task_status(task_id, TaskStatus.PROCESSING)
|
||||
|
||||
# 执行推理任务
|
||||
start_time = time.time()
|
||||
result = handler(task.input_data)
|
||||
elapsed = (time.time() - start_time) * 1000
|
||||
|
||||
result['processing_time_ms'] = round(elapsed, 2)
|
||||
self._queue_manager.update_task_status(task_id, TaskStatus.SUCCESS, result=result)
|
||||
|
||||
# 审计日志记录(安全设计:所有识别请求记录调用方、时间)
|
||||
logger.info(
|
||||
f"任务执行完成: id={task_id}, type={task.task_type}, "
|
||||
f"time={elapsed:.1f}ms, student={task.student_id}"
|
||||
)
|
||||
|
||||
# 如有回调URL则通知调用方
|
||||
if task.callback_url:
|
||||
self._send_callback(task.callback_url, task_id, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
task.retry_count += 1
|
||||
if task.retry_count < task.max_retries:
|
||||
# 重试:将任务重新加入队列
|
||||
task.status = TaskStatus.RETRY
|
||||
logger.warning(f"任务重试: id={task_id}, retry={task.retry_count}/{task.max_retries}")
|
||||
else:
|
||||
self._queue_manager.update_task_status(
|
||||
task_id, TaskStatus.FAILURE, error=str(e)
|
||||
)
|
||||
logger.error(f"任务最终失败: id={task_id}, error={str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _send_callback(self, url: str, task_id: str, result: Dict):
|
||||
"""发送任务完成回调通知"""
|
||||
try:
|
||||
# 实际环境使用httpx/aiohttp发送POST请求
|
||||
logger.info(f"发送任务回调: url={url}, task_id={task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"回调通知失败: {str(e)}")
|
||||
|
||||
|
||||
# ==================== 定时调度器 ====================
|
||||
|
||||
class ScheduledTaskRunner:
|
||||
"""
|
||||
定时任务调度器
|
||||
管理周期性执行的后台任务,如:
|
||||
- 模型健康检查(每5分钟)
|
||||
- 过期任务清理(每小时)
|
||||
- 性能指标采集(每分钟)
|
||||
- 模型更新检查(每天)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._schedules: Dict[str, Dict] = {}
|
||||
self._running = False
|
||||
logger.info("定时任务调度器初始化")
|
||||
|
||||
def register_schedule(self, name: str, interval_seconds: int,
|
||||
handler: callable, description: str = ""):
|
||||
"""注册定时任务"""
|
||||
self._schedules[name] = {
|
||||
"interval": interval_seconds,
|
||||
"handler": handler,
|
||||
"description": description,
|
||||
"last_run": None,
|
||||
"run_count": 0,
|
||||
"error_count": 0
|
||||
}
|
||||
logger.info(f"注册定时任务: {name}, 间隔={interval_seconds}s")
|
||||
|
||||
def run_task(self, name: str) -> Optional[Dict]:
|
||||
"""立即执行指定的定时任务"""
|
||||
schedule = self._schedules.get(name)
|
||||
if not schedule:
|
||||
return None
|
||||
|
||||
try:
|
||||
start = time.time()
|
||||
result = schedule["handler"]()
|
||||
elapsed = time.time() - start
|
||||
schedule["last_run"] = datetime.now().isoformat()
|
||||
schedule["run_count"] += 1
|
||||
logger.info(f"定时任务执行完成: {name}, 耗时={elapsed:.2f}s")
|
||||
return {"name": name, "success": True, "elapsed_s": round(elapsed, 2)}
|
||||
except Exception as e:
|
||||
schedule["error_count"] += 1
|
||||
logger.error(f"定时任务执行失败: {name}, 错误={str(e)}")
|
||||
return {"name": name, "success": False, "error": str(e)}
|
||||
|
||||
def get_schedule_status(self) -> List[Dict]:
|
||||
"""获取所有定时任务状态"""
|
||||
return [{
|
||||
"name": name,
|
||||
"interval_seconds": info["interval"],
|
||||
"description": info["description"],
|
||||
"last_run": info["last_run"],
|
||||
"run_count": info["run_count"],
|
||||
"error_count": info["error_count"]
|
||||
} for name, info in self._schedules.items()]
|
||||
Reference in New Issue
Block a user