4401 lines
156 KiB
Markdown
4401 lines
156 KiB
Markdown
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
## 软件著作权鉴别材料 — 源程序
|
||
|
||
> **权利人**:深圳自然写科技有限公司
|
||
> **版本号**:V1.0
|
||
|
||
---
|
||
|
||
## 源程序目录结构
|
||
|
||
```
|
||
02-writech-ai-engine/
|
||
├── main.py
|
||
├── api/
|
||
│ ├── essay_api.py
|
||
│ ├── math_api.py
|
||
│ ├── ocr_api.py
|
||
│ └── stroke_order_api.py
|
||
├── config/
|
||
│ └── settings.py
|
||
├── engine/
|
||
│ ├── essay_scorer.py
|
||
│ └── stroke_analyzer.py
|
||
├── grpc_server/
|
||
│ └── inference_service.py
|
||
├── preprocessing/
|
||
│ └── stroke_processor.py
|
||
└── service/
    ├── model_manager.py
    └── task_scheduler.py
|
||
```
|
||
|
||
---
|
||
|
||
## 源程序文件清单
|
||
|
||
### (根目录)
|
||
|
||
#### `main.py`
|
||
|
||
```python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
自然写手写识别与AI分析引擎软件 V1.0
|
||
|
||
版权所有 (C) 2026
|
||
软件全称:自然写手写识别与AI分析引擎软件
|
||
版本号:V1.0
|
||
|
||
主启动文件 - FastAPI 服务入口
|
||
负责服务初始化、路由注册、中间件配置
|
||
"""
|
||
|
||
from fastapi import FastAPI, Request, HTTPException
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
from contextlib import asynccontextmanager
|
||
import uvicorn
|
||
import logging
|
||
import time
|
||
from typing import Dict, Any
|
||
|
||
# 导入各业务模块路由
|
||
from api.ocr_api import router as ocr_router
|
||
from api.math_api import router as math_router
|
||
from api.stroke_order_api import router as stroke_order_router
|
||
from api.essay_api import router as essay_router
|
||
from service.model_manager import ModelManager
|
||
from config.settings import Settings
|
||
|
||
# 日志配置
|
||
# Root logging: INFO level; each record shows time, level, logger name, message.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("writech-ai-engine")

# Module-level configuration object shared by the middleware below.
settings = Settings()

# Module-level model manager, constructed from the configuration.
model_manager = ModelManager(settings)
|
||
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load models on startup and release them on shutdown.

    Before the application starts serving, awaits
    ``model_manager.load_all_models()``.  When the serving phase ends,
    calls ``model_manager.release_all_models()``.  The release sits in a
    ``finally`` clause so it also runs when the serving phase terminates
    through an exception or a cancellation rather than a clean shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    logger.info("自然写AI引擎启动中,加载模型...")
    await model_manager.load_all_models()
    logger.info("所有模型加载完成,服务就绪")
    try:
        yield
    finally:
        logger.info("服务关闭中,释放模型资源...")
        model_manager.release_all_models()
        logger.info("模型资源已释放")
|
||
|
||
|
||
# 创建 FastAPI 应用实例
|
||
# The application object; ``lifespan`` is run by FastAPI around the serving phase.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="自然写手写识别与AI分析引擎",
    description="对智能点阵笔采集的笔迹数据进行OCR识别、数学列式识别、笔顺分析及AI智能批改",
)
|
||
|
||
# 跨域中间件配置
|
||
# CORS policy.
# A wildcard origin must not be combined with credentialed requests: that
# combination lets any web origin issue authenticated cross-site calls.
# Credentials are therefore enabled only when an explicit origin list is
# configured.
# NOTE(review): ``cors_allow_origins`` is an optional settings attribute
# assumed here; when it is absent the previous wildcard origin list is kept.
_cors_origins = list(getattr(settings, "cors_allow_origins", None) or ["*"])

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
@app.middleware("http")
async def request_logging_middleware(request: Request, call_next):
    """Log each request, time it, and reject oversized declared bodies.

    Behaviour:
      * A request whose ``Content-Length`` header exceeds 10 MiB is answered
        with HTTP 413 without calling the downstream handler.
      * A ``Content-Length`` header that is not an integer is answered with
        HTTP 400 instead of raising out of the middleware.
      * Every response, including the rejections above, carries
        ``X-Process-Time`` and ``X-Request-ID`` headers and is logged.

    The request ID is taken from the incoming ``X-Request-ID`` header when
    present; otherwise a random UUID is generated (a wall-clock timestamp is
    not unique across concurrent requests or worker processes).

    Only the declared ``Content-Length`` is inspected; the body itself is not
    measured here.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response, or a JSON error response for rejected input.
    """
    import uuid  # function-scope import: uuid is not imported at module level

    max_body_bytes = 10 * 1024 * 1024
    started = time.perf_counter()
    request_id = request.headers.get("X-Request-ID") or uuid.uuid4().hex

    response = None
    declared_length = request.headers.get("content-length")
    if declared_length:
        try:
            too_large = int(declared_length) > max_body_bytes
        except ValueError:
            response = JSONResponse(
                status_code=400,
                content={"code": 400, "msg": "Content-Length请求头无效", "data": None},
            )
        else:
            if too_large:
                response = JSONResponse(
                    status_code=413,
                    content={"code": 413, "msg": "请求数据过大,最大支持10MB", "data": None},
                )

    if response is None:
        response = await call_next(request)

    process_time = time.perf_counter() - started
    response.headers["X-Process-Time"] = f"{process_time:.4f}"
    response.headers["X-Request-ID"] = request_id

    logger.info(
        f"{request.method} {request.url.path} "
        f"status={response.status_code} "
        f"time={process_time:.4f}s"
    )

    return response
|
||
|
||
|
||
@app.middleware("http")
async def mtls_authentication_middleware(request: Request, call_next):
    """Require an API key or a client-certificate header on non-public paths.

    Rules, in order:
      1. ``/health``, ``/docs`` and ``/openapi.json`` are always let through.
      2. A request with neither ``X-API-Key`` nor ``X-Client-Cert`` gets 401.
      3. A request whose ``X-API-Key`` does not equal ``settings.api_key``
         gets 403.  The comparison is constant-time so the response timing
         does not reveal how much of the key matched.
      4. Anything else is forwarded.

    NOTE(review): a request that carries only ``X-Client-Cert`` is forwarded
    without the header's value being inspected in this function.  Unless a
    trusted proxy strips and re-sets that header after real mTLS
    verification, any client can set it and bypass authentication — confirm
    the deployment before relying on this path.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response, or a JSON 401/403 response.
    """
    import hmac  # function-scope import: hmac is not imported at module level

    # Public endpoints need no credentials.
    if request.url.path in ("/health", "/docs", "/openapi.json"):
        return await call_next(request)

    client_cert = request.headers.get("X-Client-Cert")
    api_key = request.headers.get("X-API-Key")

    if not api_key and not client_cert:
        return JSONResponse(
            status_code=401,
            content={"code": 401, "msg": "缺少认证凭据", "data": None}
        )

    if api_key:
        expected_key = settings.api_key
        # A non-string configured key can never match, as with the plain
        # inequality test this replaces.
        key_matches = isinstance(expected_key, str) and hmac.compare_digest(
            api_key.encode("utf-8"), expected_key.encode("utf-8")
        )
        if not key_matches:
            return JSONResponse(
                status_code=403,
                content={"code": 403, "msg": "API Key无效", "data": None}
            )

    return await call_next(request)
|
||
|
||
|
||
# 注册各业务路由
|
||
# Mount each feature router under its versioned URL prefix, in this order.
for _router, _prefix, _tag in (
    (ocr_router, "/api/v1/ocr", "OCR识别"),
    (math_router, "/api/v1/math", "数学识别"),
    (stroke_order_router, "/api/v1/stroke-order", "笔顺评分"),
    (essay_router, "/api/v1/essay", "作文批改"),
):
    app.include_router(_router, prefix=_prefix, tags=[_tag])
|
||
|
||
|
||
@app.get("/health")
async def health_check():
    """Health-check endpoint: report model status and the service version."""
    payload = {
        "status": "healthy",
        "models": model_manager.get_all_status(),
        "version": "1.0.0",
    }
    return {"code": 200, "msg": "success", "data": payload}
|
||
|
||
|
||
@app.get("/api/v1/model/status")
async def get_model_status():
    """Return whatever ``model_manager.get_all_status()`` reports.

    GET /api/v1/model/status
    """
    return {
        "code": 200,
        "msg": "success",
        "data": model_manager.get_all_status(),
    }
|
||
|
||
|
||
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Render an ``HTTPException`` in the service's code/msg/data envelope."""
    envelope = {"code": exc.status_code, "msg": exc.detail, "data": None}
    return JSONResponse(status_code=exc.status_code, content=envelope)
|
||
|
||
|
||
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    """Log an unhandled exception and answer with a generic HTTP 500 envelope."""
    logger.error(f"未处理异常: {exc!s}", exc_info=True)
    envelope = {"code": 500, "msg": "AI引擎内部错误", "data": None}
    return JSONResponse(status_code=500, content=envelope)
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve "main:app" on all interfaces, port 8001,
    # with four worker processes.
    uvicorn.run("main:app", host="0.0.0.0", port=8001, workers=4, log_level="info")
|
||
```
|
||
|
||
### `api/`
|
||
|
||
#### `api/essay_api.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 作文批改接口模块 - AI作文评分与批改建议服务
|
||
|
||
"""
|
||
作文批改API接口
|
||
提供AI作文评分、多维度分析(结构/语法/内容/修辞)、批改建议生成等功能
|
||
支持小学至初中阶段作文批改,基于大语言模型与NLP分析管道
|
||
"""
|
||
|
||
import time
|
||
import json
|
||
import logging
|
||
import hashlib
|
||
import re
|
||
from typing import List, Dict, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from pydantic import BaseModel, Field, validator
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 数据模型定义 ====================
|
||
|
||
class EssayReviewRequest(BaseModel):
    """Request payload for an essay review."""

    text: str = Field(..., min_length=10, max_length=5000, description="作文OCR识别文本")
    title: Optional[str] = Field(None, description="作文题目")
    grade: int = Field(3, ge=1, le=9, description="年级(1-9)")
    genre: str = Field("narrative", description="文体类型: narrative/argumentative/expository/descriptive")
    # Must be strictly positive: the scorer divides by this value when it
    # derives the score ratio, so 0 would raise ZeroDivisionError.
    max_score: int = Field(100, gt=0, description="满分值")
    student_id: Optional[str] = Field(None, description="学生ID")
    assignment_id: Optional[str] = Field(None, description="作业ID")
    enable_suggestions: bool = Field(True, description="是否生成修改建议")

    @validator('genre')
    def validate_genre(cls, v):
        """Reject any genre outside the four supported values."""
        valid_genres = ['narrative', 'argumentative', 'expository', 'descriptive']
        if v not in valid_genres:
            raise ValueError(f'文体类型必须为: {valid_genres}')
        return v
|
||
|
||
|
||
class SentenceError(BaseModel):
    """One issue found in the essay text; all four fields are required."""

    sentence: str = Field(default=..., description="原始句子")
    error_type: str = Field(default=..., description="错误类型")
    suggestion: str = Field(default=..., description="修改建议")
    position: int = Field(default=..., description="句子在原文中的位置索引")
|
||
|
||
|
||
class EssayScoreDetail(BaseModel):
    """Per-dimension essay scores; only the handwriting score is optional."""

    structure: float = Field(default=..., description="结构分")
    grammar: float = Field(default=..., description="语法分")
    content: float = Field(default=..., description="内容分")
    rhetoric: float = Field(default=..., description="修辞分")
    handwriting: Optional[float] = Field(default=None, description="书写分(如有)")
|
||
|
||
|
||
# ==================== 文本分析工具 ====================
|
||
|
||
class TextAnalyzer:
    """Rule-based helpers for analysing Chinese essay text.

    Offers sentence and paragraph splitting, character counting, and a
    regex/heuristic detector for simile, parallelism and personification.
    Every method is static; the class holds no instance state.
    """

    # Sentence-terminating punctuation.  Entries may be longer than one
    # character (the double ellipsis); split_sentences expands them to
    # individual characters.  The escaped entries are the full-width
    # forms of "!", "?" and ";".
    SENTENCE_ENDINGS = {'。', '!', '?', '……', ';', '\uff01', '\uff1f', '\uff1b'}
    # Paragraph indentation marker.
    PARAGRAPH_INDENT = ' '

    # Characters counted as punctuation by count_characters.  Kept as one
    # explicit set so that quote characters cannot be lost to adjacent
    # string-literal concatenation.
    _PUNCTUATION = frozenset(
        ",。!?、;:\"'()《》…—"
        "\uff0c\uff01\uff1f\uff1b\uff1a\uff08\uff09"  # full-width , ! ? ; : ( )
        "\u201c\u201d\u2018\u2019"  # curly double and single quotes
    )

    @staticmethod
    def split_sentences(text: str) -> List[str]:
        """Split ``text`` into sentences at terminating punctuation.

        A sentence ends after a run of one or more terminator characters,
        so "……" or "?!" closes a single sentence instead of being ignored
        or producing punctuation-only fragments.  Each sentence is stripped
        of surrounding whitespace; empty pieces are dropped.  Trailing text
        without a terminator is returned as the last sentence.

        Args:
            text: The text to split.

        Returns:
            The sentences in order of appearance.
        """
        terminators = {ch for mark in TextAnalyzer.SENTENCE_ENDINGS for ch in mark}
        sentences: List[str] = []
        buffer: List[str] = []
        in_terminator_run = False

        def flush() -> None:
            piece = "".join(buffer).strip()
            if piece:
                sentences.append(piece)
            buffer.clear()

        for ch in text:
            is_terminator = ch in terminators
            # The first ordinary character after a terminator run starts a
            # new sentence.
            if in_terminator_run and not is_terminator:
                flush()
            buffer.append(ch)
            in_terminator_run = is_terminator
        flush()
        return sentences

    @staticmethod
    def split_paragraphs(text: str) -> List[str]:
        """Split ``text`` on newlines, stripping each piece and dropping blanks."""
        return [p.strip() for p in text.split('\n') if p.strip()]

    @staticmethod
    def count_characters(text: str) -> Dict[str, int]:
        """Count characters in ``text``.

        Returns:
            A dict with ``total`` (length excluding spaces and newlines),
            ``chinese`` (code points U+4E00..U+9FFF) and ``punctuation``
            (members of the class punctuation set).
        """
        chinese_count = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        punctuation_count = sum(1 for c in text if c in TextAnalyzer._PUNCTUATION)
        total_count = len(text.replace(' ', '').replace('\n', ''))
        return {
            "total": total_count,
            "chinese": chinese_count,
            "punctuation": punctuation_count
        }

    @staticmethod
    def detect_rhetoric(text: str) -> List[Dict]:
        """Detect simile, parallelism and personification in ``text``.

        Detection is heuristic:
          * simile — keyword patterns such as "像…一样", "如同", "仿佛";
          * parallelism — three consecutive sentences of similar length
            (each pair differing by fewer than 5 characters, first one
            longer than 5) that start with the same character;
          * personification — a nature noun character followed within three
            characters by a human action character or by one of the words
            "温柔", "轻轻", "悄悄".

        Returns:
            One dict per hit with keys ``type``, ``name``, ``text`` and
            ``position`` (start index in ``text``), grouped by device in the
            order listed above.
        """
        rhetorics = []

        simile_patterns = [
            r'像.{2,10}一样', r'如同.{2,10}', r'仿佛.{2,10}',
            r'好像.{2,10}', r'犹如.{2,10}', r'宛如.{2,10}'
        ]
        for pattern in simile_patterns:
            for m in re.finditer(pattern, text):
                rhetorics.append({
                    "type": "simile", "name": "比喻",
                    "text": m.group(), "position": m.start()
                })

        sentences = TextAnalyzer.split_sentences(text)
        for s1, s2, s3 in zip(sentences, sentences[1:], sentences[2:]):
            if (abs(len(s1) - len(s2)) < 5 and abs(len(s2) - len(s3)) < 5 and
                    len(s1) > 5 and s1[0] == s2[0] == s3[0]):
                rhetorics.append({
                    "type": "parallelism", "name": "排比",
                    "text": f"{s1}{s2}{s3}", "position": text.find(s1)
                })

        personification_patterns = [
            r'[风雨雪花树草月阳光河水山].{0,3}[笑哭唱跳跑走说叫]',
            # Whole words, not a character class: "[温柔轻轻悄悄]" would match
            # any single one of those characters.
            r'[风雨雪花树草月阳光河水山].{0,3}(?:温柔|轻轻|悄悄)'
        ]
        for pattern in personification_patterns:
            for m in re.finditer(pattern, text):
                rhetorics.append({
                    "type": "personification", "name": "拟人",
                    "text": m.group(), "position": m.start()
                })

        return rhetorics
|
||
|
||
|
||
# ==================== 作文评分引擎 ====================
|
||
|
||
class EssayScoringEngine:
    """Rule-based essay scorer.

    Each essay is scored independently on four dimensions, every one on a
    0-100 scale, and the results are combined with fixed weights:
    structure 25%, grammar 25%, content 30%, rhetoric 20%.
    """

    # Expected count of Chinese characters per grade: (minimum, maximum).
    EXPECTED_LENGTH = {
        1: (50, 150), 2: (100, 250), 3: (200, 400),
        4: (300, 500), 5: (350, 600), 6: (400, 700),
        7: (500, 800), 8: (600, 900), 9: (600, 1000)
    }

    # Weight of each dimension in the total score.
    DIMENSION_WEIGHTS = {
        "structure": 0.25,
        "grammar": 0.25,
        "content": 0.30,
        "rhetoric": 0.20
    }

    def __init__(self):
        """Create the text analyzer and load the grammar error patterns."""
        self._text_analyzer = TextAnalyzer()
        self._error_patterns = self._load_error_patterns()
        logger.info("作文评分引擎初始化完成")

    def _load_error_patterns(self) -> List[Dict]:
        """Return the built-in list of regex patterns for common writing errors.

        Each entry has ``pattern`` (regex), ``type`` (error category) and
        ``msg`` (suggestion text shown to the user).
        """
        return [
            {"pattern": r"的的", "type": "repetition", "msg": "重复用字'的的'"},
            {"pattern": r"了了", "type": "repetition", "msg": "重复用字'了了'"},
            {"pattern": r"因为.{5,50}因为", "type": "logic", "msg": "重复使用'因为',建议精简"},
            {"pattern": r"然后.{3,20}然后.{3,20}然后", "type": "style", "msg": "过度使用'然后'连接"},
            {"pattern": r"非常非常", "type": "repetition", "msg": "重复使用'非常'"},
            {"pattern": r"[,]{3,}", "type": "punctuation", "msg": "连续使用多个逗号,建议使用句号断句"},
        ]

    def score_structure(self, text: str, grade: int) -> Tuple[float, List[str]]:
        """Score the essay's structure (0-100).

        Deductions: fewer than 2 paragraphs (-25) or fewer than 3 (-10);
        Chinese character count below the grade minimum (1 point per 10
        missing characters, capped at 30) or above 1.5x the grade maximum
        (-5); first paragraph shorter than 15 characters (-10); last
        paragraph shorter than 10 characters (-10).

        Args:
            text: Essay text.
            grade: Grade level; unknown grades use the range (300, 600).

        Returns:
            ``(score, comments)``.
        """
        comments: List[str] = []
        score = 100.0

        paragraphs = self._text_analyzer.split_paragraphs(text)
        chinese_chars = self._text_analyzer.count_characters(text)["chinese"]

        # Paragraph count.
        if len(paragraphs) < 2:
            score -= 25
            comments.append("文章缺少段落划分,建议分段书写使结构更清晰")
        elif len(paragraphs) < 3:
            score -= 10
            comments.append("段落较少,建议增加过渡段落")

        # Length against the grade's expected range.
        low, high = self.EXPECTED_LENGTH.get(grade, (300, 600))
        if chinese_chars < low:
            score -= min(30, (low - chinese_chars) // 10)
            comments.append(f"字数偏少({chinese_chars}字),该年级建议{low}-{high}字")
        elif chinese_chars > high * 1.5:
            score -= 5
            comments.append("字数偏多,建议精简语句突出重点")

        # Opening and closing paragraphs.
        if paragraphs:
            if len(paragraphs[0]) < 15:
                score -= 10
                comments.append("开头过于简短,建议丰富开篇引入")
            if len(paragraphs[-1]) < 10:
                score -= 10
                comments.append("结尾过于简短,建议加强收束呼应主题")

        return max(0, score), comments

    def score_grammar(self, text: str) -> Tuple[float, List["SentenceError"]]:
        """Score grammar (0-100) from pattern hits and over-long sentences.

        Each error-pattern match costs 5 points; each sentence longer than
        80 characters costs 3 points.

        Args:
            text: Essay text.

        Returns:
            ``(score, errors)`` where each error records the matched text,
            its category, a suggestion and its start index in ``text``.
        """
        errors = []
        score = 100.0

        for ep in self._error_patterns:
            for m in re.finditer(ep["pattern"], text):
                errors.append(SentenceError(
                    sentence=m.group(),
                    error_type=ep["type"],
                    suggestion=ep["msg"],
                    position=m.start()
                ))
                score -= 5

        # Sentences come back in document order, so searching from the end
        # of the previous one gives the right index even when the same
        # sentence text occurs more than once.
        cursor = 0
        for s in self._text_analyzer.split_sentences(text):
            start = text.find(s, cursor)
            if start == -1:
                start = text.find(s)
            else:
                cursor = start + len(s)
            if len(s) > 80:
                errors.append(SentenceError(
                    sentence=s[:30] + "...",
                    error_type="long_sentence",
                    suggestion="句子过长,建议拆分为多个短句以提高可读性",
                    position=start
                ))
                score -= 3

        return max(0, score), errors

    def score_content(self, text: str, title: Optional[str], genre: str, grade: int) -> Tuple[float, List[str]]:
        """Score content quality (0-100), starting from a base of 85.

        Adjustments: ratio of distinct Chinese characters to all Chinese
        characters above 0.6 (+10) or below 0.3 (-10); three or more
        distinct connectives present (+5), or none at all in a text of more
        than five sentences (-5); any listed emotion word present (+3).

        Args:
            text: Essay text.
            title: Essay title (currently not used by the scoring rules).
            genre: Essay genre (currently not used by the scoring rules).
            grade: Grade level (currently not used by the scoring rules).

        Returns:
            ``(score, comments)`` with the score clamped to 0..100.
        """
        comments: List[str] = []
        score = 85.0

        char_stats = self._text_analyzer.count_characters(text)
        sentences = self._text_analyzer.split_sentences(text)

        # Vocabulary richness: distinct Chinese characters / all Chinese characters.
        unique_chars = {c for c in text if '\u4e00' <= c <= '\u9fff'}
        vocab_richness = len(unique_chars) / max(char_stats["chinese"], 1)
        if vocab_richness > 0.6:
            score += 10
            comments.append("词汇丰富,用词多样化")
        elif vocab_richness < 0.3:
            score -= 10
            comments.append("词汇较为单一,建议使用更丰富的词语表达")

        # Coherence: presence of connectives.
        connectors = ['因此', '所以', '但是', '然而', '首先', '其次', '最后', '总之',
                      '不仅', '而且', '虽然', '但', '因为', '于是']
        used_connectors = [c for c in connectors if c in text]
        if len(used_connectors) >= 3:
            score += 5
            comments.append("逻辑衔接词使用恰当,行文连贯")
        elif len(used_connectors) == 0 and len(sentences) > 5:
            score -= 5
            comments.append("缺少逻辑连接词,建议增加过渡衔接使行文更连贯")

        # Emotional expression.
        emotion_words = ['开心', '快乐', '高兴', '感动', '难过', '伤心', '惊讶',
                         '温暖', '幸福', '骄傲', '担心', '紧张']
        if any(w in text for w in emotion_words):
            score += 3
            comments.append("有恰当的情感表达,增强了文章感染力")

        return min(100, max(0, score)), comments

    def score_rhetoric(self, text: str, grade: int) -> Tuple[float, List[str]]:
        """Score the use of rhetorical devices (0-100), starting from 70.

        Three or more distinct device types add 25; one or two add 15.
        From grade 5 upward, fewer than two distinct types costs 10.

        Args:
            text: Essay text.
            grade: Grade level.

        Returns:
            ``(score, comments)`` with the score clamped to 0..100.
        """
        comments: List[str] = []
        score = 70.0

        rhetorics = self._text_analyzer.detect_rhetoric(text)

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # generated comment text is the same on every run (iterating a set
        # of strings is not order-stable across interpreter runs).
        rhetoric_types = list(dict.fromkeys(r["type"] for r in rhetorics))
        if len(rhetoric_types) >= 3:
            score += 25
            comments.append(f"修辞手法运用丰富,使用了{len(rhetoric_types)}种修辞手法")
        elif rhetoric_types:
            score += 15
            used_names = list(dict.fromkeys(r["name"] for r in rhetorics))
            comments.append(f"使用了{'、'.join(used_names)}等修辞手法")
        else:
            comments.append("建议适当使用比喻、排比等修辞手法增强表达效果")

        if grade >= 5 and len(rhetoric_types) < 2:
            score -= 10
            comments.append("该年级建议至少使用2种以上修辞手法")

        return min(100, max(0, score)), comments

    def review_essay(self, request: "EssayReviewRequest") -> Dict:
        """Run all four scorers and assemble the review result.

        The weighted 0-100 score is rescaled to ``request.max_score``.  Each
        entry under ``dimensions`` is that dimension's weighted share of the
        total, on the same scale.  ``grammar_errors`` is empty when
        ``request.enable_suggestions`` is false.

        Args:
            request: The review request.

        Returns:
            A dict with the total score, per-dimension scores, character
            statistics, comments, grammar errors and the elapsed time in ms.
        """
        start_time = time.time()

        struct_score, struct_comments = self.score_structure(request.text, request.grade)
        grammar_score, grammar_errors = self.score_grammar(request.text)
        content_score, content_comments = self.score_content(
            request.text, request.title, request.genre, request.grade)
        rhetoric_score, rhetoric_comments = self.score_rhetoric(request.text, request.grade)

        raw_scores = {
            "structure": struct_score,
            "grammar": grammar_score,
            "content": content_score,
            "rhetoric": rhetoric_score,
        }

        weighted_score = (
            struct_score * self.DIMENSION_WEIGHTS["structure"] +
            grammar_score * self.DIMENSION_WEIGHTS["grammar"] +
            content_score * self.DIMENSION_WEIGHTS["content"] +
            rhetoric_score * self.DIMENSION_WEIGHTS["rhetoric"]
        )
        total_score = round(weighted_score / 100 * request.max_score, 1)

        char_stats = TextAnalyzer.count_characters(request.text)

        overall_comment = self._generate_overall_comment(
            total_score, request.max_score, struct_comments,
            content_comments, rhetoric_comments
        )

        elapsed = (time.time() - start_time) * 1000

        return {
            "total_score": total_score,
            "max_score": request.max_score,
            "dimensions": {
                name: round(raw / 100 * request.max_score * self.DIMENSION_WEIGHTS[name], 1)
                for name, raw in raw_scores.items()
            },
            "character_count": char_stats,
            "overall_comment": overall_comment,
            "structure_analysis": struct_comments,
            "content_analysis": content_comments,
            "rhetoric_analysis": rhetoric_comments,
            "grammar_errors": [e.dict() for e in grammar_errors] if request.enable_suggestions else [],
            "inference_time_ms": round(elapsed, 2)
        }

    def _generate_overall_comment(self, score: float, max_score: int,
                                  struct_comments: List, content_comments: List,
                                  rhetoric_comments: List) -> str:
        """Build the one-line overall comment.

        The prefix depends on ``score / max_score`` (>= 0.9, >= 0.75,
        >= 0.6, otherwise); a non-positive ``max_score`` is treated as a
        ratio of 0 instead of raising ZeroDivisionError.  The first comment
        of each of the three lists, when present, is appended, separated
        by ";".
        """
        ratio = score / max_score if max_score > 0 else 0.0
        if ratio >= 0.9:
            prefix = "优秀!"
        elif ratio >= 0.75:
            prefix = "良好。"
        elif ratio >= 0.6:
            prefix = "中等。"
        else:
            prefix = "需要加强。"

        suggestions = [
            group[0]
            for group in (struct_comments, content_comments, rhetoric_comments)
            if group
        ]

        return f"{prefix}{';'.join(suggestions[:3])}"
|
||
|
||
|
||
# ==================== API路由定义 ====================
|
||
|
||
# Route paths here are relative: main.py mounts this router with
# prefix "/api/v1/essay" and the essay tag.  Declaring "/api/v1" and
# "/essay/review" here as well would publish the endpoint at
# /api/v1/essay/api/v1/essay/review.
router = APIRouter()
_scoring_engine = EssayScoringEngine()


@router.post("/review")
async def review_essay(request: EssayReviewRequest):
    """Score an essay and return the review.

    POST /api/v1/essay/review

    Takes the OCR text of an essay and returns the total score, the
    per-dimension analysis and, when enabled, the grammar suggestions.

    Raises:
        HTTPException: status 500 when the scoring engine raises.
    """
    try:
        result = _scoring_engine.review_essay(request)
    except Exception as e:
        # logger.exception records the traceback together with the message.
        logger.exception(f"作文批改异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"作文批改服务异常: {str(e)}") from e

    # Audit log entry.
    logger.info(
        f"作文批改完成: score={result['total_score']}/{request.max_score}, "
        f"student={request.student_id}, assignment={request.assignment_id}, "
        f"chars={result['character_count']['chinese']}, time={result['inference_time_ms']}ms"
    )
    return {"code": 200, "msg": "success", "data": result}
|
||
```
|
||
|
||
#### `api/math_api.py`
|
||
|
||
```python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
自然写手写识别与AI分析引擎软件 V1.0
|
||
|
||
数学列式与公式识别接口
|
||
支持四则运算、方程式、几何图形公式等数学内容识别
|
||
"""
|
||
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
from typing import List, Optional, Dict, Any
|
||
import numpy as np
|
||
import logging
|
||
import time
|
||
import uuid
|
||
import re
|
||
|
||
logger = logging.getLogger("writech-ai-engine.math")
|
||
router = APIRouter()
|
||
|
||
|
||
class MathStrokePoint(BaseModel):
    """One sampled pen point of a math stroke."""

    # Coordinates are bounded to 0..65535, pressure to 0..255.
    x: int = Field(default=..., ge=0, le=65535)
    y: int = Field(default=..., ge=0, le=65535)
    pressure: int = Field(default=0, ge=0, le=255)
    timestamp: int = Field(default=...)
    pen_up: bool = Field(default=False)
|
||
|
||
|
||
class MathRecognizeRequest(BaseModel):
    """Request payload for math recognition."""

    # Outer list: strokes; inner list: the points of one stroke.
    strokes: List[List[MathStrokePoint]] = Field(default=..., description="笔迹数据")
    math_type: str = Field(default="arithmetic", description="数学类型: arithmetic/equation/geometry")
    grade_level: int = Field(default=3, ge=1, le=6, description="年级(1-6)")
|
||
|
||
|
||
class MathStep(BaseModel):
    """One step of a checked calculation."""

    step_no: int = Field(default=..., description="步骤序号")
    expression: str = Field(default=..., description="表达式")
    result: Optional[str] = Field(default=None, description="计算结果")
    is_correct: bool = Field(default=True, description="是否正确")
    # The two error fields default to None.
    error_type: Optional[str] = Field(default=None, description="错误类型")
    error_detail: Optional[str] = Field(default=None, description="错误详情")
|
||
|
||
|
||
class MathRecognizeResult(BaseModel):
    """Result of a math recognition run."""

    latex: str = Field(default=..., description="LaTeX表达式")
    result: Optional[str] = Field(default=None, description="计算结果")
    is_correct: bool = Field(default=True, description="答案是否正确")
    steps: List[MathStep] = Field(default=[], description="计算步骤")
    confidence: float = Field(default=..., description="识别置信度")
|
||
|
||
|
||
class MathEngine:
    """Handwritten math expression recognition engine.

    Pipeline implemented by ``recognize``:
        strokes -> rendered image -> symbol segmentation -> symbol
        recognition -> structure analysis -> expression reconstruction
        -> calculation check.

    The segmentation and symbol-recognition stages in this class return
    fixed placeholder data; no model inference is performed here.
    """

    def __init__(self):
        """Initialise an engine with no model loaded."""
        self.model = None
        self.is_loaded = False
        # Symbols the engine is meant to handle.
        self.symbol_set = set("0123456789+-×÷=><()/.%")
        logger.info("数学识别引擎初始化完成")

    def load_model(self, model_path: str):
        """Mark the engine as loaded.

        Only logs ``model_path`` and sets ``is_loaded``; nothing is read
        from the path in this method.
        """
        logger.info(f"加载数学识别模型: {model_path}")
        self.is_loaded = True
        logger.info("数学识别模型加载完成")

    def recognize(self, strokes: List[List["MathStrokePoint"]],
                  math_type: str = "arithmetic",
                  grade_level: int = 3) -> "MathRecognizeResult":
        """Run the full recognition pipeline on ``strokes``.

        Args:
            strokes: Pen strokes, each a list of points.
            math_type: "arithmetic", "equation" or anything else (treated
                as unknown structure).
            grade_level: Passed through to the calculation check.

        Returns:
            The recognition result.  ``confidence`` is currently the fixed
            value 0.92.
        """
        start_time = time.time()

        image = self._preprocess_strokes(strokes)                      # 1. render
        segments = self._segment_symbols(image)                        # 2. segment
        symbols = self._recognize_symbols(segments)                    # 3. classify
        structure = self._analyze_structure(symbols, math_type)        # 4. structure
        latex_expr, math_expr = self._reconstruct_expression(structure)  # 5. rebuild
        result, is_correct, steps = self._verify_calculation(math_expr, grade_level)  # 6. check

        inference_time = time.time() - start_time
        logger.info(f"数学识别完成: latex={latex_expr}, correct={is_correct}, "
                    f"time={inference_time:.4f}s")

        return MathRecognizeResult(
            latex=latex_expr,
            result=result,
            is_correct=is_correct,
            steps=steps,
            confidence=0.92
        )

    def _preprocess_strokes(self, strokes: List[List["MathStrokePoint"]]) -> np.ndarray:
        """Render strokes onto a 64x512 float32 canvas (ink = 1.0).

        Points are shifted to the strokes' bounding-box origin, scaled
        uniformly to fit the canvas with a 5-pixel margin, and consecutive
        points of a stroke are joined by a rasterised line.  Segments are
        drawn regardless of direction, and a stroke consisting of a single
        point is drawn as one pixel.

        Args:
            strokes: Pen strokes; each point needs ``x`` and ``y``.

        Returns:
            The canvas; all zeros when there are no points.
        """
        canvas_h, canvas_w = 64, 512
        canvas = np.zeros((canvas_h, canvas_w), dtype=np.float32)

        points = [p for stroke in strokes for p in stroke]
        if not points:
            return canvas

        min_x = min(p.x for p in points)
        max_x = max(p.x for p in points)
        min_y = min(p.y for p in points)
        max_y = max(p.y for p in points)
        w = max(max_x - min_x, 1)
        h = max(max_y - min_y, 1)
        scale = min((canvas_w - 10) / w, (canvas_h - 10) / h)

        def to_pixel(point):
            """Map a stroke point to a clipped (column, row) canvas pixel."""
            col = int(np.clip(int((point.x - min_x) * scale + 5), 0, canvas_w - 1))
            row = int(np.clip(int((point.y - min_y) * scale + 5), 0, canvas_h - 1))
            return col, row

        for stroke in strokes:
            pixels = [to_pixel(p) for p in stroke]
            if len(pixels) == 1:
                col, row = pixels[0]
                canvas[row, col] = 1.0
                continue
            for (c1, r1), (c2, r2) in zip(pixels, pixels[1:]):
                # One sample per pixel along the longer axis; works for any
                # direction, unlike a [lo:hi] slice which is empty when the
                # segment runs right-to-left or bottom-to-top.
                n = max(abs(c2 - c1), abs(r2 - r1)) + 1
                cols = np.rint(np.linspace(c1, c2, n)).astype(int)
                rows = np.rint(np.linspace(r1, r2, n)).astype(int)
                canvas[rows, cols] = 1.0

        return canvas

    def _segment_symbols(self, image: np.ndarray) -> List[Dict]:
        """Split the image into symbol regions.

        Placeholder: returns five fixed bounding boxes ``[x1, y1, x2, y2]``
        with the matching image crops; no connected-component analysis is
        performed.
        """
        boxes = [
            [10, 5, 40, 55],
            [45, 20, 65, 45],
            [70, 5, 100, 55],
            [105, 20, 125, 45],
            [130, 5, 160, 55],
        ]
        return [
            {"bbox": [x1, y1, x2, y2], "image": image[y1:y2, x1:x2]}
            for x1, y1, x2, y2 in boxes
        ]

    def _recognize_symbols(self, segments: List[Dict]) -> List[Dict]:
        """Classify each segment as a digit or operator.

        Placeholder: assigns symbols from a fixed sequence, one per
        segment, with confidence ``0.95 - 0.01 * index``.  Segments beyond
        the length of the fixed sequence are dropped.
        """
        mock_symbols = ["1", "2", "+", "3", "=", "1", "5"]
        return [
            {
                "symbol": symbol,
                "bbox": seg["bbox"],
                "confidence": 0.95 - i * 0.01
            }
            for i, (seg, symbol) in enumerate(zip(segments, mock_symbols))
        ]

    def _analyze_structure(self, symbols: List[Dict], math_type: str) -> Dict:
        """Order symbols left to right and tag the expression structure.

        Returns:
            ``{"type": ..., "symbols": ...}`` where type is "linear" for
            arithmetic, "equation" for equations and "unknown" otherwise.
        """
        sorted_symbols = sorted(symbols, key=lambda s: s["bbox"][0])
        structure_type = {"arithmetic": "linear", "equation": "equation"}.get(
            math_type, "unknown")
        return {"type": structure_type, "symbols": sorted_symbols}

    def _reconstruct_expression(self, structure: Dict) -> tuple:
        """Build the LaTeX string and the computable expression string.

        Returns:
            ``(latex, math_expr)``: the symbols concatenated, with "×"/"÷"
            replaced by LaTeX commands and by "*"/"/" respectively.
        """
        text = "".join(s["symbol"] for s in structure.get("symbols", []))
        latex = text.replace("×", "\\times ").replace("÷", "\\div ")
        math_expr = text.replace("×", "*").replace("÷", "/")
        return latex, math_expr

    def _verify_calculation(self, math_expr: str, grade_level: int) -> tuple:
        """Check whether both sides of ``left = right`` evaluate equally.

        Args:
            math_expr: Expression text; checked only when it contains
                exactly one "=".
            grade_level: Currently not used by the check.

        Returns:
            ``(result, is_correct, steps)``.  On a successful check,
            ``result`` is the left side's value as a string and ``steps``
            holds two entries.  When the expression has no single "=" or a
            side cannot be evaluated, returns ``(None, True, [])``.
        """
        steps = []

        if math_expr.count("=") != 1:
            return None, True, steps

        left, right = (part.strip() for part in math_expr.split("="))

        try:
            left_val = self._safe_eval(left)
            right_val = self._safe_eval(right)
        except (ValueError, SyntaxError, ZeroDivisionError, OverflowError) as err:
            # Best effort: an unparsable side leaves the answer unverified.
            logger.debug(f"表达式无法验证: {math_expr!r} ({err})")
            return None, True, steps

        steps.append(MathStep(
            step_no=1,
            expression=left,
            result=str(left_val),
            is_correct=True
        ))

        is_correct = abs(left_val - right_val) < 1e-9
        steps.append(MathStep(
            step_no=2,
            expression=f"{left} = {right}",
            result=str(right_val),
            is_correct=is_correct,
            error_type=None if is_correct else "calculation",
            error_detail=None if is_correct else f"正确答案应为{left_val}"
        ))

        return str(left_val), is_correct, steps

    def _safe_eval(self, expr: str) -> float:
        """Evaluate an arithmetic expression without using ``eval``.

        Accepts numbers, parentheses, unary +/- and the binary operators
        + - * /.  The expression is parsed to an AST and only those node
        kinds are evaluated, so "**" (which passes a character whitelist
        and allows unbounded exponentiation) and "//" are rejected.

        Args:
            expr: Expression text.

        Returns:
            The numeric value (int or float).

        Raises:
            ValueError: The expression has a disallowed character or
                construct.
            SyntaxError: The expression is not valid syntax.
            ZeroDivisionError: The expression divides by zero.
        """
        import ast  # function-scope imports: not imported at module level
        import operator

        allowed_chars = set("0123456789.+-*/() ")
        if not all(c in allowed_chars for c in expr):
            raise ValueError(f"不安全的表达式: {expr}")

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def evaluate(node):
            if isinstance(node, ast.Expression):
                return evaluate(node.body)
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](evaluate(node.operand))
            raise ValueError(f"不安全的表达式: {expr}")

        return evaluate(ast.parse(expr.strip(), mode="eval"))
|
||
|
||
|
||
# 全局数学引擎实例
|
||
math_engine = MathEngine()


@router.post("/recognize")
async def recognize_math(request: MathRecognizeRequest):
    """Recognise a handwritten math expression.

    POST /api/v1/math/recognize

    Raises:
        HTTPException: status 400 when ``strokes`` is empty.
    """
    if not request.strokes:
        raise HTTPException(status_code=400, detail="笔迹数据不能为空")

    recognized = math_engine.recognize(
        strokes=request.strokes,
        math_type=request.math_type,
        grade_level=request.grade_level
    )

    # A fresh random identifier is attached to each response.
    payload = {
        "request_id": str(uuid.uuid4()),
        "result": recognized.dict()
    }
    return {"code": 200, "msg": "success", "data": payload}
|
||
```
|
||
|
||
#### `api/ocr_api.py`
|
||
|
||
```python
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
自然写手写识别与AI分析引擎软件 V1.0
|
||
|
||
OCR识别接口模块
|
||
提供中英文手写文字OCR识别服务,基于PaddleOCR推理管道
|
||
"""
|
||
|
||
from fastapi import APIRouter, HTTPException
|
||
from pydantic import BaseModel, Field
|
||
from typing import List, Optional, Dict, Any
|
||
import numpy as np
|
||
import logging
|
||
import time
|
||
import uuid
|
||
|
||
logger = logging.getLogger("writech-ai-engine.ocr")
|
||
router = APIRouter()
|
||
|
||
|
||
# ==================== 请求/响应模型定义 ====================
|
||
|
||
class StrokePoint(BaseModel):
    """One sampled pen point of a stroke."""

    x: int = Field(default=..., ge=0, le=65535, description="X坐标")
    y: int = Field(default=..., ge=0, le=65535, description="Y坐标")
    pressure: int = Field(default=0, ge=0, le=255, description="压力值")
    timestamp: int = Field(default=..., description="时间戳(毫秒)")
    pen_up: bool = Field(default=False, description="抬笔标记")
|
||
|
||
|
||
class OCRRequest(BaseModel):
    """Request payload for OCR recognition."""

    # Outer list: strokes; inner list: the points of one stroke.
    strokes: List[List[StrokePoint]] = Field(default=..., description="笔迹数据(按笔画分组)")
    page_id: Optional[str] = Field(default=None, description="点阵码页面ID")
    pen_id: Optional[str] = Field(default=None, description="笔设备ID")
    language: str = Field(default="zh", description="识别语言: zh/en/mixed")
    recognition_mode: str = Field(default="line", description="识别模式: char/word/line/page")
|
||
|
||
|
||
class CharDetail(BaseModel):
    """Recognition detail for a single character."""

    char: str = Field(default=..., description="识别的字符")
    confidence: float = Field(default=..., description="置信度(0-1)")
    bbox: List[int] = Field(default=..., description="包围框[x1,y1,x2,y2]")
    stroke_indices: List[int] = Field(default=[], description="对应的笔画索引")
|
||
|
||
|
||
class OCRResult(BaseModel):
    """One recognised text region."""

    text: str = Field(default=..., description="识别文本")
    confidence: float = Field(default=..., description="整体置信度(0-1)")
    bbox: List[int] = Field(default=[], description="文本区域包围框")
    char_details: List[CharDetail] = Field(default=[], description="逐字详情")
|
||
|
||
|
||
class OCRResponse(BaseModel):
    """Response envelope for OCR recognition: code, message and optional data."""

    code: int = 200
    msg: str = "success"
    data: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
# ==================== OCR 推理引擎 ====================
|
||
|
||
class OCREngine:
    """
    PaddleOCR inference engine.

    Pipeline:
        stroke coordinates -> preprocessing (rasterization / normalization)
        -> model inference -> post-processing (confidence filtering) -> results

    Recognition modes named by the API:
        - char: per-character recognition
        - word: word-level recognition
        - line: line recognition (default)
        - page: whole-page recognition

    NOTE(review): inference is currently mocked (see ``_mock_inference``) and
    ``_postprocess`` only builds results for char / line / page.
    """

    def __init__(self):
        """Set default engine parameters; no model is loaded here."""
        self.model = None
        self.model_version = "1.0.0"
        self.is_loaded = False
        # Model input image size in pixels.
        self.input_height = 48
        self.input_width = 320
        # Predictions below this confidence are dropped in post-processing.
        self.confidence_threshold = 0.5
        logger.info("OCR引擎初始化完成")

    def load_model(self, model_path: str):
        """
        Mark the OCR model at ``model_path`` as loaded.

        TODO: real loading is not implemented yet. The intended flow is to
        decrypt the file with ``self._decrypt_model(model_path)`` and load the
        result with ``paddle.jit.load``.
        """
        logger.info(f"加载OCR模型: {model_path}")
        self.is_loaded = True
        logger.info("OCR模型加载完成")

    def preprocess_strokes(self, strokes: List[List[StrokePoint]]) -> np.ndarray:
        """
        Rasterize strokes into a model input image.

        The bounding box of all points is scaled (aspect ratio preserved, using
        90% of the limiting dimension) onto an ``input_height x input_width``
        canvas anchored at the canvas origin, and every pair of consecutive
        points within a stroke is joined by a line.

        @param strokes: stroke data, one list of points per stroke
        @return: float32 array of shape (1, input_height, input_width); drawn
                 pixels are 1.0, background is 0.0, and the array is all zeros
                 when no points were supplied
        """
        coords = [(pt.x, pt.y) for stroke in strokes for pt in stroke]
        if not coords:
            return np.zeros((1, self.input_height, self.input_width), dtype=np.float32)

        xs, ys = zip(*coords)
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)

        # Clamp the extent to at least 1 so a single point or an axis-aligned
        # line does not divide by zero.
        width = max(max_x - min_x, 1)
        height = max(max_y - min_y, 1)
        scale = min(self.input_width / width, self.input_height / height) * 0.9

        canvas = np.zeros((self.input_height, self.input_width), dtype=np.float32)

        for stroke in strokes:
            for prev, cur in zip(stroke, stroke[1:]):
                x1 = int((prev.x - min_x) * scale)
                y1 = int((prev.y - min_y) * scale)
                x2 = int((cur.x - min_x) * scale)
                y2 = int((cur.y - min_y) * scale)
                # Brush radius derived from pen pressure. The cast keeps the
                # value usable by range() in _draw_line even if pressure is a
                # float. NOTE(review): presumably pressure is 0-255 -- confirm.
                thickness = max(1, int(cur.pressure // 85))
                self._draw_line(canvas, x1, y1, x2, y2, thickness=thickness)

        # Normalize to [0, 1].
        peak = canvas.max()
        if peak > 0:
            canvas = canvas / peak

        return canvas.reshape(1, self.input_height, self.input_width)

    def recognize(self, strokes: List[List[StrokePoint]],
                  mode: str = "line") -> List[OCRResult]:
        """
        Run OCR on the given strokes.

        @param strokes: stroke data, grouped per stroke
        @param mode: recognition mode (char/word/line/page)
        @return: list of recognition results
        """
        started = time.time()

        image = self.preprocess_strokes(strokes)

        # TODO: replace the mock with real inference, i.e. self.model(image).
        predictions = self._mock_inference(image, mode)

        results = self._postprocess(predictions, mode)

        inference_time = time.time() - started
        logger.info(f"OCR识别完成, mode={mode}, time={inference_time:.4f}s, "
                    f"results={len(results)}")

        return results

    @staticmethod
    def _to_char_detail(pred: Dict) -> CharDetail:
        """Build a CharDetail from one per-character prediction dict."""
        return CharDetail(
            char=pred["char"],
            confidence=pred["confidence"],
            bbox=pred["bbox"],
            stroke_indices=pred.get("stroke_indices", []),
        )

    def _postprocess(self, predictions: Dict, mode: str) -> List[OCRResult]:
        """
        Filter predictions by confidence and convert them to OCRResult objects.

        - char mode reads ``predictions["chars"]`` and emits one result per
          character.
        - line / page modes read ``predictions["lines"]`` and emit one result
          per line, carrying that line's per-character details.
        - Any other mode produces no results; a warning is logged so the empty
          answer is not silent.
        """
        results: List[OCRResult] = []

        if mode == "char":
            for char_pred in predictions.get("chars", []):
                if char_pred["confidence"] < self.confidence_threshold:
                    continue
                results.append(OCRResult(
                    text=char_pred["char"],
                    confidence=char_pred["confidence"],
                    bbox=char_pred["bbox"],
                    char_details=[self._to_char_detail(char_pred)],
                ))
        elif mode in ("line", "page"):
            for line_pred in predictions.get("lines", []):
                if line_pred["confidence"] < self.confidence_threshold:
                    continue
                results.append(OCRResult(
                    text=line_pred["text"],
                    confidence=line_pred["confidence"],
                    bbox=line_pred["bbox"],
                    char_details=[
                        self._to_char_detail(cd)
                        for cd in line_pred.get("char_details", [])
                    ],
                ))
        else:
            logger.warning(f"OCR后处理暂不支持的识别模式: {mode}")

        return results

    def _draw_line(self, canvas: np.ndarray, x1: int, y1: int,
                   x2: int, y2: int, thickness: int = 1):
        """
        Draw a line on ``canvas`` in place using Bresenham's algorithm.

        At every step a square of side ``2 * thickness + 1`` centred on the
        current pixel is set to 1.0; parts outside the canvas are clipped.
        """
        h, w = canvas.shape
        dx = abs(x2 - x1)
        dy = abs(y2 - y1)
        sx = 1 if x1 < x2 else -1
        sy = 1 if y1 < y2 else -1
        err = dx - dy

        while True:
            # Stamp the brush square, clipped to the canvas bounds. The
            # emptiness check prevents a negative upper bound from being
            # read as a from-the-end slice index.
            x_lo, x_hi = max(0, x1 - thickness), min(w, x1 + thickness + 1)
            y_lo, y_hi = max(0, y1 - thickness), min(h, y1 + thickness + 1)
            if x_lo < x_hi and y_lo < y_hi:
                canvas[y_lo:y_hi, x_lo:x_hi] = 1.0

            if x1 == x2 and y1 == y2:
                break
            e2 = 2 * err
            if e2 > -dy:
                err -= dy
                x1 += sx
            if e2 < dx:
                err += dx
                y1 += sy

    def _mock_inference(self, image: np.ndarray, mode: str) -> Dict:
        """Return a fixed sample prediction (placeholder for real inference)."""
        return {
            "lines": [{
                "text": "示例文字",
                "confidence": 0.95,
                "bbox": [10, 10, 200, 48],
                "char_details": [
                    {"char": "示", "confidence": 0.96, "bbox": [10, 10, 50, 48]},
                    {"char": "例", "confidence": 0.94, "bbox": [50, 10, 100, 48]},
                    {"char": "文", "confidence": 0.97, "bbox": [100, 10, 150, 48]},
                    {"char": "字", "confidence": 0.93, "bbox": [150, 10, 200, 48]}
                ]
            }],
            "chars": []
        }

    def _decrypt_model(self, model_path: str) -> str:
        """
        Placeholder for AES-256 model decryption.

        TODO: not implemented; currently returns ``model_path`` unchanged. The
        intended flow reads the key from settings and decrypts with AES-CBC.
        """
        return model_path


# Module-level OCR engine instance shared by the endpoints.
ocr_engine = OCREngine()
|
||
|
||
|
||
# ==================== API 路由 ====================
|
||
|
||
@router.post("/recognize", response_model=OCRResponse)
async def recognize_text(request: OCRRequest):
    """
    Handwriting OCR endpoint.

    POST /api/v1/ocr/recognize

    Accepts stroke coordinate data and returns the recognized text together
    with per-character details.

    Raises:
        HTTPException(400): when no strokes are supplied or the request holds
            more than 50000 points in total.
    """
    # Input validation.
    if not request.strokes:
        raise HTTPException(status_code=400, detail="笔迹数据不能为空")

    total_points = sum(len(stroke) for stroke in request.strokes)
    if total_points > 50000:
        raise HTTPException(status_code=400, detail="笔迹点数过多,最大支持50000点")

    results = ocr_engine.recognize(
        strokes=request.strokes,
        mode=request.recognition_mode,
    )

    payload = {
        "request_id": str(uuid.uuid4()),
        "language": request.language,
        "mode": request.recognition_mode,
        "results": [r.dict() for r in results],
        "total_chars": sum(len(r.text) for r in results),
    }
    return OCRResponse(code=200, msg="success", data=payload)
|
||
|
||
|
||
@router.post("/batch-recognize")
async def batch_recognize(requests: List[OCRRequest]):
    """
    Batch OCR endpoint: recognize several stroke sets in one call.

    Every item is validated with the same rules as the single-request
    endpoint before any recognition work starts, so a bad item fails the whole
    batch up front instead of being processed unchecked.

    Raises:
        HTTPException(400): when any item has no strokes or holds more than
            50000 points.
    """
    for req in requests:
        if not req.strokes:
            raise HTTPException(status_code=400, detail="笔迹数据不能为空")
        if sum(len(stroke) for stroke in req.strokes) > 50000:
            raise HTTPException(status_code=400, detail="笔迹点数过多,最大支持50000点")

    batch_results = []
    for req in requests:
        recognized = ocr_engine.recognize(
            strokes=req.strokes,
            mode=req.recognition_mode,
        )
        batch_results.append({
            "page_id": req.page_id,
            "results": [r.dict() for r in recognized],
        })

    return {
        "code": 200,
        "msg": "success",
        "data": {
            "batch_size": len(requests),
            "results": batch_results,
        },
    }
|
||
```
|
||
|
||
#### `api/stroke_order_api.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 笔顺评分接口模块 - 中文汉字笔顺识别与评分服务
|
||
|
||
"""
|
||
笔顺评分API接口
|
||
提供汉字笔顺正确性评估、书写质量评分、笔画拆分分析等功能
|
||
基于深度学习笔顺分析模型,支持GB2312常用汉字笔顺评分
|
||
"""
|
||
|
||
import time
|
||
import logging
|
||
import hashlib
|
||
import numpy as np
|
||
from typing import List, Dict, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||
from pydantic import BaseModel, Field, validator
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 数据模型定义 ====================
|
||
|
||
class StrokePointInput(BaseModel):
    """One sampled pen coordinate point."""

    # X coordinate.
    x: float = Field(..., description="X坐标")
    # Y coordinate.
    y: float = Field(..., description="Y坐标")
    # Pen pressure, constrained to [0.0, 1.0]; defaults to 0.5.
    pressure: float = Field(0.5, ge=0.0, le=1.0, description="压力值")
    # Sample timestamp in milliseconds.
    timestamp: int = Field(..., description="时间戳(毫秒)")
|
||
|
||
|
||
class StrokeOrderRequest(BaseModel):
    """Request body for stroke-order scoring."""

    # Target character; exactly one character long.
    character: str = Field(..., min_length=1, max_length=1, description="目标汉字")
    # Strokes written by the user, one list of points per stroke.
    strokes: List[List[StrokePointInput]] = Field(..., description="用户书写的笔画列表")
    # Optional pen device id.
    pen_id: Optional[str] = Field(None, description="点阵笔设备ID")
    # Optional student id.
    student_id: Optional[str] = Field(None, description="学生ID")
    # Scoring difficulty level, 1-3.
    difficulty_level: int = Field(1, ge=1, le=3, description="评分难度等级1-3")

    @validator('character')
    def validate_chinese_char(cls, v):
        """Reject characters outside the range U+4E00..U+9FFF."""
        if not ('\u4e00' <= v <= '\u9fff'):
            raise ValueError('仅支持中文汉字笔顺评分')
        return v
|
||
|
||
|
||
class WritingQualityRequest(BaseModel):
    """Request body for writing-quality evaluation."""

    # Stroke data, one list of points per stroke.
    strokes: List[List[StrokePointInput]] = Field(..., description="笔迹数据")
    # Optional reference character.
    reference_char: Optional[str] = Field(None, description="参考字符(可选)")
    # Dimensions to evaluate; all four supported dimensions by default.
    eval_dimensions: List[str] = Field(
        default=["structure", "spacing", "normative", "aesthetics"],
        description="评测维度"
    )
|
||
|
||
|
||
class StrokeDirection(str, Enum):
    """Stroke direction categories; members are also plain strings."""

    HORIZONTAL = "horizontal"        # horizontal stroke (heng)
    VERTICAL = "vertical"            # vertical stroke (shu)
    LEFT_FALLING = "left_falling"    # left-falling stroke (pie)
    RIGHT_FALLING = "right_falling"  # right-falling stroke (na)
    DOT = "dot"                      # dot (dian)
    TURNING = "turning"              # turning stroke (zhe)
    HOOK = "hook"                    # hook (gou)
    RISING = "rising"                # rising stroke (ti)
|
||
|
||
|
||
@dataclass
class StrokeFeature:
    """Features extracted from a single stroke."""

    direction: StrokeDirection        # classified stroke direction
    start_point: Tuple[float, float]  # (x, y) of the first point
    end_point: Tuple[float, float]    # (x, y) of the last point
    length: float                     # path length of the stroke
    avg_pressure: float               # mean pen pressure
    curvature: float                  # path length divided by straight-line distance
    speed: float                      # writing speed
|
||
|
||
|
||
# ==================== 标准笔顺数据库 ====================
|
||
|
||
class StrokeOrderDatabase:
    """
    In-memory table of standard stroke orders.

    Maps a character to its standard sequence of stroke directions and to its
    standard stroke count. Only a small sample of common characters is loaded.
    """

    def __init__(self):
        # character -> standard stroke direction sequence
        self._standard_orders: Dict[str, List[StrokeDirection]] = {}
        # character -> standard stroke count
        self._stroke_counts: Dict[str, int] = {}
        self._load_standard_data()

    def _load_standard_data(self):
        """Populate both lookup tables with the built-in sample characters."""
        heng = StrokeDirection.HORIZONTAL
        shu = StrokeDirection.VERTICAL
        pie = StrokeDirection.LEFT_FALLING
        na = StrokeDirection.RIGHT_FALLING
        zhe = StrokeDirection.TURNING

        # Each entry: character -> (direction sequence, stroke count).
        standard_data = {
            "一": ([heng], 1),
            "二": ([heng, heng], 2),
            "三": ([heng, heng, heng], 3),
            "十": ([heng, shu], 2),
            "大": ([heng, pie, na], 3),
            "人": ([pie, na], 2),
            "口": ([shu, zhe, heng], 3),
            "日": ([shu, zhe, heng, heng], 4),
            "月": ([pie, zhe, heng, heng], 4),
            "水": ([shu, zhe, pie, na], 4),
        }
        for char, (order, count) in standard_data.items():
            self._standard_orders[char] = order
            self._stroke_counts[char] = count
        logger.info(f"标准笔顺数据库加载完成,共 {len(self._standard_orders)} 个汉字")

    def get_standard_order(self, char: str) -> Optional[List[StrokeDirection]]:
        """Return the standard stroke order for ``char``, or None if unknown."""
        return self._standard_orders.get(char)

    def get_stroke_count(self, char: str) -> Optional[int]:
        """Return the standard stroke count for ``char``, or None if unknown."""
        return self._stroke_counts.get(char)
|
||
|
||
|
||
# ==================== 笔顺分析引擎 ====================
|
||
|
||
class StrokeOrderAnalyzer:
    """
    Stroke-order analysis engine.

    Classifies the direction of every user stroke from its coordinates and
    compares the sequence with the standard stroke order of the target
    character. The score combines stroke-order correctness and stroke count.
    """

    def __init__(self):
        self._database = StrokeOrderDatabase()
        # Placeholder for a stroke-direction classification model; unused here.
        self._direction_model = None
        logger.info("笔顺分析引擎初始化完成")

    def _extract_stroke_feature(self, points: List[StrokePointInput]) -> StrokeFeature:
        """
        Extract direction, length, curvature and speed from one stroke.

        A single-point stroke is reported as a DOT with zero length.

        Raises:
            ValueError: if ``points`` is empty (there is nothing to measure).
        """
        if not points:
            raise ValueError("stroke contains no points")

        first, last = points[0], points[-1]
        start = (first.x, first.y)

        if len(points) < 2:
            return StrokeFeature(
                direction=StrokeDirection.DOT,
                start_point=start,
                end_point=start,
                length=0.0, avg_pressure=first.pressure,
                curvature=0.0, speed=0.0
            )

        end = (last.x, last.y)

        # Path length: sum of Euclidean distances between consecutive points.
        total_length = 0.0
        for prev, cur in zip(points, points[1:]):
            dx = cur.x - prev.x
            dy = cur.y - prev.y
            total_length += np.sqrt(dx*dx + dy*dy)

        avg_pressure = np.mean([p.pressure for p in points])

        # Timestamps are in milliseconds; clamp the duration to 1 ms to avoid
        # dividing by zero, then convert to units per second.
        time_diff = max(last.timestamp - first.timestamp, 1)
        speed = total_length / time_diff * 1000

        # Curvature: path length over straight-line distance (clamped to 1.0).
        direct_dist = np.sqrt((end[0]-start[0])**2 + (end[1]-start[1])**2)
        curvature = total_length / max(direct_dist, 1.0)

        direction = self._classify_direction(start, end, curvature)

        return StrokeFeature(
            direction=direction, start_point=start, end_point=end,
            length=total_length, avg_pressure=avg_pressure,
            curvature=curvature, speed=speed
        )

    def _classify_direction(self, start: Tuple, end: Tuple, curvature: float) -> StrokeDirection:
        """
        Classify a stroke from its endpoints and curvature.

        Checks are applied in order: very short strokes are dots, strongly
        curved strokes are turning/hook, and the rest are bucketed by the angle
        of the start-to-end vector.
        """
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        distance = np.sqrt(dx*dx + dy*dy)

        # Very short strokes count as dots.
        if distance < 5.0:
            return StrokeDirection.DOT

        # Angle in degrees from atan2(dy, dx); 0 points along +x.
        angle = np.degrees(np.arctan2(dy, dx))

        # Strongly curved strokes: turning when ending lower (dy > 0), else hook.
        if curvature > 1.8:
            return StrokeDirection.TURNING if dy > 0 else StrokeDirection.HOOK

        if -20 <= angle <= 20:
            return StrokeDirection.HORIZONTAL
        elif 70 <= angle <= 110:
            return StrokeDirection.VERTICAL
        elif 120 <= angle <= 170:
            return StrokeDirection.LEFT_FALLING
        elif 20 < angle < 70:
            return StrokeDirection.RIGHT_FALLING
        elif -70 <= angle < -20:
            return StrokeDirection.RISING
        else:
            # Every remaining angle falls back to left-falling.
            return StrokeDirection.LEFT_FALLING

    def evaluate_stroke_order(self, char: str, strokes: List[List[StrokePointInput]],
                              difficulty: int = 1) -> Dict:
        """
        Score the user's stroke order against the standard order of ``char``.

        Strokes with no points are ignored. When the character is not in the
        database, both sub-scores stay at 100.

        Returns:
            A dict with the total / order / count scores, the user's and the
            standard stroke sequences, a list of per-stroke errors and the
            elapsed time in milliseconds.
        """
        start_time = time.time()

        # A stroke without points carries no ink; drop it so it neither
        # crashes feature extraction nor inflates the stroke count.
        strokes = [s for s in strokes if s]

        standard_order = self._database.get_standard_order(char)
        standard_count = self._database.get_stroke_count(char)

        user_features = [self._extract_stroke_feature(s) for s in strokes]
        user_directions = [f.direction for f in user_features]

        # Stroke-count score: 25 points off per missing or extra stroke.
        count_score = 100.0
        if standard_count:
            count_diff = abs(len(strokes) - standard_count)
            count_score = max(0, 100 - count_diff * 25)

        # Order score: position-by-position direction comparison.
        order_score = 100.0
        errors = []
        if standard_order:
            match_count = 0
            pairs = zip(user_directions, standard_order)
            for idx, (actual, expected) in enumerate(pairs, start=1):
                if actual == expected:
                    match_count += 1
                else:
                    errors.append({
                        "stroke_index": idx,
                        "expected": expected.value,
                        "actual": actual.value,
                        "message": f"第{idx}笔方向错误:应为{expected.value},实际为{actual.value}"
                    })
            order_score = (match_count / max(len(standard_order), 1)) * 100

        # Higher difficulty puts more weight on stroke-order correctness.
        weight_order = 0.5 + difficulty * 0.1
        weight_count = 1.0 - weight_order

        total_score = order_score * weight_order + count_score * weight_count
        elapsed = (time.time() - start_time) * 1000

        return {
            "character": char,
            "total_score": round(total_score, 1),
            "order_score": round(order_score, 1),
            "count_score": round(count_score, 1),
            "user_stroke_count": len(strokes),
            "standard_stroke_count": standard_count,
            "stroke_order": [d.value for d in user_directions],
            "correct_order": [d.value for d in standard_order] if standard_order else [],
            "errors": errors,
            "inference_time_ms": round(elapsed, 2)
        }
|
||
|
||
|
||
# ==================== 书写质量评测引擎 ====================
|
||
|
||
class WritingQualityEngine:
    """
    Writing-quality evaluation engine.

    Scores handwriting on up to four dimensions: structure, spacing,
    normative and aesthetics.
    """

    def evaluate(self, strokes: List[List[StrokePointInput]],
                 dimensions: List[str]) -> Dict:
        """
        Evaluate writing quality.

        Only the dimensions named in ``dimensions`` are scored; unknown names
        are ignored. The total is the mean of the computed dimension scores.

        Returns:
            ``{"total_score": ..., "dimensions": {name: score}}``; the total is
            0 with no dimension scores when the strokes contain no points.
        """
        # Flatten every point for the whole-sample statistics.
        all_points = [(p.x, p.y, p.pressure) for stroke in strokes for p in stroke]
        if not all_points:
            return {"total_score": 0, "dimensions": {}}

        xs = [p[0] for p in all_points]
        ys = [p[1] for p in all_points]

        # Bounding box of the written area.
        bbox_width = max(xs) - min(xs)
        bbox_height = max(ys) - min(ys)

        scores = {}

        if "structure" in dimensions:
            # Distance between the centroid of the points and the centre of
            # the bounding box, relative to a quarter of the box diagonal.
            offset = np.sqrt(
                (np.mean(xs) - (min(xs) + bbox_width / 2)) ** 2
                + (np.mean(ys) - (min(ys) + bbox_height / 2)) ** 2
            )
            max_offset = np.sqrt(bbox_width**2 + bbox_height**2) / 4
            scores["structure"] = round(
                float(max(0, 100 - (offset / max(max_offset, 1)) * 60)), 1
            )

        if "spacing" in dimensions:
            if len(strokes) > 1:
                # Uniformity of the distances between consecutive stroke
                # start points, measured by their coefficient of variation.
                starts = [(s[0].x, s[0].y) for s in strokes if s]
                gaps = [
                    np.sqrt((cur[0] - prev[0])**2 + (cur[1] - prev[1])**2)
                    for prev, cur in zip(starts, starts[1:])
                ]
                gap_std = np.std(gaps) if gaps else 0
                gap_mean = np.mean(gaps) if gaps else 1
                cv = gap_std / max(gap_mean, 1)
                scores["spacing"] = round(float(max(0, 100 - cv * 80)), 1)
            else:
                # A single stroke has no gaps; a fixed score is used.
                scores["spacing"] = 80.0

        if "normative" in dimensions:
            # Pressure stability: a larger standard deviation lowers the score.
            pressure_std = np.std([p[2] for p in all_points])
            scores["normative"] = round(float(max(0, 100 - pressure_std * 200)), 1)

        if "aesthetics" in dimensions:
            # Aspect ratio of the bounding box; closer to square scores higher.
            aspect_ratio = bbox_width / max(bbox_height, 1)
            scores["aesthetics"] = round(
                float(max(0, 100 - abs(aspect_ratio - 1.0) * 50)), 1
            )

        total = float(np.mean(list(scores.values()))) if scores else 0
        return {"total_score": round(total, 1), "dimensions": scores}
|
||
|
||
|
||
# ==================== API路由定义 ====================
|
||
|
||
# Router for the stroke-order and writing-quality endpoints, plus the engine
# instances shared by every request.
router = APIRouter(prefix="/api/v1", tags=["笔顺评分"])
_analyzer = StrokeOrderAnalyzer()
_quality_engine = WritingQualityEngine()
|
||
|
||
|
||
@router.post("/stroke-order/evaluate")
async def evaluate_stroke_order(request: StrokeOrderRequest):
    """
    Stroke-order scoring endpoint.

    POST /api/v1/stroke-order/evaluate

    Takes the target character and the user's strokes and returns the
    stroke-order score together with per-stroke error details.

    Raises:
        HTTPException(500): when scoring fails for any reason.
    """
    try:
        result = _analyzer.evaluate_stroke_order(
            char=request.character,
            strokes=request.strokes,
            difficulty=request.difficulty_level
        )
        # Audit log line for the completed request.
        logger.info(
            f"笔顺评分完成: char={request.character}, "
            f"score={result['total_score']}, pen={request.pen_id}, "
            f"student={request.student_id}, time={result['inference_time_ms']}ms"
        )
        return {"code": 200, "msg": "success", "data": result}
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"笔顺评分异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"笔顺评分服务异常: {str(e)}") from e
|
||
|
||
|
||
@router.post("/writing/quality")
async def evaluate_writing_quality(request: WritingQualityRequest):
    """
    Writing-quality evaluation endpoint.

    POST /api/v1/writing/quality

    Scores the handwriting on the requested dimensions (structure, spacing,
    normative, aesthetics).

    Raises:
        HTTPException(500): when evaluation fails for any reason.
    """
    try:
        result = _quality_engine.evaluate(
            strokes=request.strokes,
            dimensions=request.eval_dimensions
        )
        logger.info(f"书写质量评测完成: score={result['total_score']}")
        return {"code": 200, "msg": "success", "data": result}
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"书写质量评测异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"书写质量评测异常: {str(e)}") from e
|
||
```
|
||
|
||
### `config/`
|
||
|
||
#### `config/settings.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 配置与安全模块 - 全局配置管理与安全策略
|
||
|
||
"""
|
||
全局配置管理
|
||
提供AI引擎服务的所有配置项管理,包括:
|
||
服务端口、模型路径、GPU配置、安全认证、日志级别等
|
||
支持环境变量覆盖和配置热更新
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
import hashlib
|
||
import hmac
|
||
import time
|
||
from typing import Dict, List, Optional, Any
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 服务配置 ====================
|
||
|
||
@dataclass
class ServerConfig:
    """HTTP / gRPC server configuration."""

    http_host: str = "0.0.0.0"
    http_port: int = 8000
    grpc_host: str = "0.0.0.0"
    grpc_port: int = 50051
    workers: int = 4                 # number of FastAPI workers
    grpc_max_workers: int = 10       # gRPC thread pool size
    max_request_size_mb: int = 10    # request body size limit
    request_timeout_s: int = 30      # request timeout in seconds
    # NOTE(review): the default allows every origin; restrict in production.
    cors_origins: List[str] = field(default_factory=lambda: ["*"])
    debug: bool = False
|
||
|
||
|
||
@dataclass
class ModelConfig:
    """Model inference configuration."""

    models_dir: str = "/opt/models"                 # root directory of model files
    ocr_model_path: str = "/opt/models/ocr"         # OCR model path
    math_model_path: str = "/opt/models/math"       # math recognition model path
    stroke_model_path: str = "/opt/models/stroke"   # stroke-order model path
    essay_model_path: str = "/opt/models/essay"     # essay scoring model path
    max_batch_size: int = 32                        # maximum inference batch size
    inference_timeout_ms: int = 5000                # timeout of one inference call
    enable_fp16: bool = True                        # FP16 half-precision inference
    model_cache_size_gb: float = 4.0                # in-memory model cache size
|
||
|
||
|
||
@dataclass
class GPUConfig:
    """GPU / NPU hardware acceleration configuration."""

    device: str = "cuda"                                      # inference device: cuda / cpu / npu
    gpu_ids: List[int] = field(default_factory=lambda: [0])   # GPU indices to use
    gpu_memory_fraction: float = 0.8                          # upper bound on GPU memory usage
    enable_tensorrt: bool = True                              # enable TensorRT acceleration
    tensorrt_precision: str = "fp16"                          # TensorRT precision: fp32/fp16/int8
    triton_url: str = "localhost:8001"                        # Triton Inference Server address
|
||
|
||
|
||
@dataclass
class CeleryConfig:
    """Celery task queue configuration."""

    broker_url: str = "redis://localhost:6379/0"       # Redis broker address
    result_backend: str = "redis://localhost:6379/1"   # result storage backend
    task_serializer: str = "json"
    result_serializer: str = "json"
    task_default_queue: str = "writech.default"
    task_time_limit: int = 300                         # hard task time limit (seconds)
    task_soft_time_limit: int = 240                    # soft limit (raises SoftTimeLimitExceeded)
    worker_concurrency: int = 8                        # worker concurrency
    worker_prefetch_multiplier: int = 2                # prefetch multiplier
|
||
|
||
|
||
@dataclass
class DatabaseConfig:
    """Database connection configuration."""

    # NOTE(review): the default URL embeds placeholder credentials; override it
    # (e.g. through WRITECH_MYSQL_URL) in any real deployment.
    mysql_url: str = "mysql+pymysql://user:password@localhost:3306/writech_ai"
    redis_url: str = "redis://localhost:6379/0"
    mongodb_url: str = "mongodb://localhost:27017/writech_stroke"
    pool_size: int = 20        # connection pool size
    pool_recycle: int = 3600   # connection recycle time (seconds)
|
||
|
||
|
||
@dataclass
class LogConfig:
    """Logging configuration."""

    level: str = "INFO"
    format: str = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    log_dir: str = "/var/log/writech-ai"
    max_file_size_mb: int = 100         # size limit of one log file
    backup_count: int = 10              # number of log files to keep
    enable_audit_log: bool = True       # enable the audit log
    audit_log_file: str = "audit.log"   # audit log file name
|
||
|
||
|
||
# ==================== 安全配置 ====================
|
||
|
||
@dataclass
class SecurityConfig:
    """Security configuration."""

    # mTLS mutual authentication between internal services.
    enable_mtls: bool = True
    server_cert_path: str = "/etc/ssl/server.crt"
    server_key_path: str = "/etc/ssl/server.key"
    ca_cert_path: str = "/etc/ssl/ca.crt"

    # Model file encryption; this holds the NAME of the environment variable
    # that carries the key, not the key itself.
    model_encryption_enabled: bool = True
    model_encryption_key_env: str = "WRITECH_MODEL_KEY"

    # Request validation limits.
    max_stroke_points: int = 100000      # maximum coordinate points per request
    max_strokes_per_request: int = 500   # maximum strokes per request
    max_text_length: int = 10000         # maximum essay text length

    # Rate limiting.
    rate_limit_per_minute: int = 600     # maximum requests per minute
    rate_limit_burst: int = 50           # burst request count

    # Audit logging.
    enable_audit: bool = True
    audit_retention_days: int = 90       # audit log retention in days
|
||
|
||
|
||
# ==================== mTLS认证管理 ====================
|
||
|
||
class MTLSAuthenticator:
    """
    mTLS mutual-authentication manager.

    Keeps a registry of trusted client certificate fingerprints and checks
    presented fingerprints against it.
    """

    def __init__(self, config: SecurityConfig):
        self._config = config
        # fingerprint -> client name
        self._trusted_clients: Dict[str, str] = {}
        logger.info("mTLS认证管理器初始化")

    def load_certificates(self) -> bool:
        """
        Check that the server certificate, server key and CA certificate exist.

        Only presence on disk is verified; the files are not parsed.

        Returns:
            True when all three files exist, False when any is missing or the
            check itself fails.
        """
        try:
            required = (
                ("服务端证书", Path(self._config.server_cert_path)),
                ("服务端私钥", Path(self._config.server_key_path)),
                ("CA证书", Path(self._config.ca_cert_path)),
            )
            for label, path in required:
                if not path.exists():
                    logger.warning(f"{label}不存在: {path}")
                    return False

            logger.info("mTLS证书加载完成")
            return True
        except Exception as e:
            # Best effort: any failure while checking is reported as "not loaded".
            logger.error(f"证书加载失败: {str(e)}")
            return False

    def verify_client_cert(self, cert_fingerprint: str) -> bool:
        """
        Return whether ``cert_fingerprint`` belongs to a registered client.

        Always returns True when mTLS is disabled in the configuration.
        """
        if not self._config.enable_mtls:
            return True
        is_trusted = cert_fingerprint in self._trusted_clients
        if not is_trusted:
            logger.warning(f"未授信的客户端证书: {cert_fingerprint}")
        return is_trusted

    def register_trusted_client(self, name: str, fingerprint: str):
        """Register ``fingerprint`` as trusted under the given client name."""
        self._trusted_clients[fingerprint] = name
        logger.info(f"注册授信客户端: {name}")
|
||
|
||
|
||
# ==================== 请求签名校验 ====================
|
||
|
||
class RequestValidator:
    """
    HMAC request-signature validator.

    Signs and verifies ``payload&timestamp=<ts>&nonce=<nonce>`` with
    HMAC-SHA256 and rejects stale timestamps and reused nonces to guard
    against tampering and replay.
    """

    def __init__(self, secret_key: str = ""):
        """
        Args:
            secret_key: signing secret. When empty, the WRITECH_API_SECRET
                environment variable is used; when that is unset too, an
                insecure built-in default is used and a warning is logged.
        """
        configured = secret_key or os.environ.get("WRITECH_API_SECRET")
        if configured is None:
            logger.warning("未配置API签名密钥(WRITECH_API_SECRET),正在使用不安全的默认密钥")
            configured = "default-secret"
        self._secret = configured
        # nonce -> time it was accepted (replay protection)
        self._nonce_cache: Dict[str, float] = {}
        # Accepted clock skew between the request timestamp and now (seconds).
        self._timestamp_window = 300
        # A timestamp can stay acceptable for up to 2 * window seconds after
        # its nonce is first seen (client clock ahead by the full window), so
        # nonces must be remembered at least that long to block replays.
        self._nonce_ttl = 2 * self._timestamp_window

    def generate_signature(self, payload: str, timestamp: int, nonce: str) -> str:
        """Return the hex HMAC-SHA256 signature of payload, timestamp and nonce."""
        message = f"{payload}&timestamp={timestamp}&nonce={nonce}"
        return hmac.new(
            self._secret.encode(), message.encode(), hashlib.sha256
        ).hexdigest()

    def verify_signature(self, payload: str, timestamp: int,
                         nonce: str, signature: str) -> bool:
        """
        Verify a signed request.

        1. The timestamp must be within the allowed window of the current time.
        2. The nonce must not have been accepted before.
        3. The HMAC signature must match (constant-time comparison).

        A nonce is only recorded once its signature has been verified.
        """
        current_time = int(time.time())
        if abs(current_time - timestamp) > self._timestamp_window:
            logger.warning(f"请求时间戳过期: {timestamp}")
            return False

        if nonce in self._nonce_cache:
            logger.warning(f"重复的请求随机数: {nonce}")
            return False

        expected = self.generate_signature(payload, timestamp, nonce)
        is_valid = hmac.compare_digest(expected, signature)

        if is_valid:
            self._nonce_cache[nonce] = time.time()
            self._cleanup_nonce_cache()

        return is_valid

    def _cleanup_nonce_cache(self):
        """Drop nonces that were accepted more than ``_nonce_ttl`` seconds ago."""
        now = time.time()
        expired = [k for k, seen_at in self._nonce_cache.items()
                   if now - seen_at > self._nonce_ttl]
        for k in expired:
            del self._nonce_cache[k]
|
||
|
||
|
||
# ==================== 全局配置管理器 ====================
|
||
|
||
class Settings:
    """
    Global configuration manager (singleton).

    Builds every section from its dataclass defaults and then applies
    environment-variable overrides. Environment variables take precedence
    over values loaded from a configuration file.
    """

    _instance = None

    # Names of the configuration sections that a config file may update.
    _SECTIONS = ('server', 'model', 'gpu', 'celery', 'database', 'log', 'security')

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every Settings() call even though __new__ returns
        # the shared instance, so only the first call initializes.
        if hasattr(self, '_initialized'):
            return
        self._initialized = True

        # Per-module configuration sections.
        self.server = ServerConfig()
        self.model = ModelConfig()
        self.gpu = GPUConfig()
        self.celery = CeleryConfig()
        self.database = DatabaseConfig()
        self.log = LogConfig()
        self.security = SecurityConfig()

        self._load_from_env()

        # Security components.
        self.mtls_auth = MTLSAuthenticator(self.security)
        self.request_validator = RequestValidator()

        logger.info("全局配置加载完成")

    def _load_from_env(self):
        """Override defaults with any WRITECH_* environment variables that are set."""
        # env var -> (section, attribute, converter)
        env_mapping = {
            "WRITECH_HTTP_PORT": ("server", "http_port", int),
            "WRITECH_GRPC_PORT": ("server", "grpc_port", int),
            "WRITECH_WORKERS": ("server", "workers", int),
            "WRITECH_DEBUG": ("server", "debug", lambda x: x.lower() == "true"),
            "WRITECH_MODELS_DIR": ("model", "models_dir", str),
            "WRITECH_GPU_DEVICE": ("gpu", "device", str),
            "WRITECH_GPU_IDS": ("gpu", "gpu_ids", lambda x: [int(i) for i in x.split(",")]),
            "WRITECH_REDIS_URL": ("celery", "broker_url", str),
            "WRITECH_MYSQL_URL": ("database", "mysql_url", str),
            "WRITECH_LOG_LEVEL": ("log", "level", str),
            "WRITECH_ENABLE_MTLS": ("security", "enable_mtls", lambda x: x.lower() == "true"),
        }

        for env_key, (section, attr_name, converter) in env_mapping.items():
            value = os.environ.get(env_key)
            if value is None:
                continue
            config_obj = getattr(self, section)
            try:
                setattr(config_obj, attr_name, converter(value))
                logger.info(f"环境变量覆盖配置: {env_key} -> {section}.{attr_name}")
            except (ValueError, TypeError) as e:
                # A malformed value keeps the existing setting.
                logger.warning(f"环境变量转换失败: {env_key}={value}, 错误: {str(e)}")

    def load_from_file(self, config_path: str):
        """
        Load configuration from a JSON file.

        Only known configuration sections and attributes that already exist
        on them are updated. Environment overrides are re-applied afterwards
        so they keep precedence over the file. A missing or malformed file is
        logged and leaves the configuration unchanged.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
        except FileNotFoundError:
            logger.warning(f"配置文件不存在: {config_path}")
            return
        except json.JSONDecodeError as e:
            logger.error(f"配置文件JSON解析错误: {str(e)}")
            return

        logger.info(f"配置文件加载完成: {config_path}")

        for section_name, section_data in config_data.items():
            # Restricting to the known sections keeps a config file from
            # overwriting other attributes of this object.
            if section_name not in self._SECTIONS or not isinstance(section_data, dict):
                continue
            config_obj = getattr(self, section_name)
            for key, value in section_data.items():
                if hasattr(config_obj, key):
                    setattr(config_obj, key, value)

        # Environment variables outrank the file.
        self._load_from_env()

    def to_dict(self) -> Dict[str, Any]:
        """
        Export the server, model, gpu, celery and log sections as a dict.

        Any field whose name contains password / secret / key / token is
        masked as "***".
        """
        sensitive = ('password', 'secret', 'key', 'token')
        result = {}
        for section in ['server', 'model', 'gpu', 'celery', 'log']:
            config_obj = getattr(self, section)
            result[section] = {
                key: "***" if any(kw in key.lower() for kw in sensitive)
                else getattr(config_obj, key)
                for key in vars(config_obj)
            }
        return result


# Module-level configuration instance.
settings = Settings()
|
||
```
|
||
|
||
### `engine/`
|
||
|
||
#### `engine/essay_scorer.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 作文评分模型模块 - 深度学习作文评分模型推理管道
|
||
|
||
"""
|
||
作文评分深度学习模型
|
||
基于BERT/ERNIE预训练模型微调的中文作文评分器
|
||
支持多维度评分:内容、结构、语言、思想感情
|
||
"""
|
||
|
||
import time
|
||
import logging
|
||
import numpy as np
|
||
from typing import List, Dict, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 模型配置 ====================
|
||
|
||
@dataclass
class EssayModelConfig:
    """Essay scoring model configuration."""

    model_name: str = "writech-essay-scorer-v1"
    model_path: str = "/opt/models/essay_scorer"
    max_seq_length: int = 512                 # maximum input sequence length
    num_labels: int = 4                       # number of scoring dimensions
    score_range: Tuple[int, int] = (0, 100)   # score range
    batch_size: int = 8                       # inference batch size
    use_gpu: bool = True                      # use GPU acceleration
    fp16_inference: bool = True               # use FP16 half-precision inference
|
||
|
||
|
||
# ==================== 文本特征提取器 ====================
|
||
|
||
class TextFeatureExtractor:
    """Extract scoring features from the text of a Chinese essay.

    Two kinds of output are produced:

    * statistical features (character / sentence / paragraph counts,
      vocabulary richness, connective and descriptive word usage,
      punctuation counts);
    * a fixed-size vector built from those statistics, used here as a
      stand-in for the embedding of a pretrained language model.
    """

    # Common connectives grouped by logical relation (used to gauge how
    # logically the essay is organised).
    CONNECTIVES = {
        'causal': ['因为', '所以', '因此', '由于', '于是', '故而'],
        'adversative': ['但是', '然而', '可是', '不过', '虽然', '尽管'],
        'progressive': ['而且', '并且', '不仅', '还', '甚至', '更'],
        'sequential': ['首先', '其次', '然后', '接着', '最后', '总之'],
    }

    # Descriptive words (used to gauge richness of description).
    DESCRIPTIVE_WORDS = [
        '美丽', '壮观', '温柔', '热烈', '寂静', '辽阔', '清澈', '明亮',
        '灿烂', '幽静', '巍峨', '绚丽', '优雅', '淳朴', '恬静', '磅礴',
        '蜿蜒', '苍翠', '碧绿', '湛蓝', '金黄', '洁白', '火红', '嫣红'
    ]

    # Punctuation is matched in both its ASCII and its full-width / CJK
    # form: Chinese essays normally use the latter, and matching only the
    # ASCII form silently reports zero for them.
    _COMMAS = (',', '\uff0c')
    _EXCLAMATIONS = ('!', '\uff01')
    _QUESTIONS = ('?', '\uff1f')
    _QUOTES = ('"', '\u201c', '\u201d')
    _PERIOD = '。'

    # Length of the vector returned by text_to_embedding().
    _EMBEDDING_DIM = 64

    @staticmethod
    def _is_chinese(ch: str) -> bool:
        """Return True for characters in the CJK Unified Ideographs block."""
        return '\u4e00' <= ch <= '\u9fff'

    @staticmethod
    def _count_any(text: str, marks) -> int:
        """Count the occurrences in *text* of every string in *marks*."""
        return sum(text.count(mark) for mark in marks)

    @classmethod
    def _split_sentences(cls, text: str) -> List[str]:
        """Split *text* on sentence-ending punctuation, dropping blank pieces."""
        for mark in cls._EXCLAMATIONS + cls._QUESTIONS:
            text = text.replace(mark, cls._PERIOD)
        return [piece for piece in text.split(cls._PERIOD) if piece.strip()]

    def extract_statistical_features(self, text: str) -> Dict[str, float]:
        """Compute the statistical feature dictionary for *text*.

        The insertion order of the keys is fixed; ``text_to_embedding``
        relies on it to lay out its vector.

        Args:
            text: Essay text; paragraphs are separated by newlines.

        Returns:
            Mapping of feature name to a plain Python number.
        """
        features: Dict[str, float] = {}

        # Basic counts.
        chinese_chars = [ch for ch in text if self._is_chinese(ch)]
        sentences = self._split_sentences(text)
        paragraphs = [p for p in text.split('\n') if p.strip()]

        features['char_count'] = len(chinese_chars)
        features['sentence_count'] = len(sentences)
        features['paragraph_count'] = len(paragraphs)

        # Sentence length statistics (a proxy for sentence complexity).
        if sentences:
            sentence_lengths = [
                sum(1 for ch in sentence if self._is_chinese(ch))
                for sentence in sentences
            ]
            # float() keeps NumPy scalar types out of the result.
            features['avg_sentence_length'] = float(np.mean(sentence_lengths))
            features['sentence_length_std'] = float(np.std(sentence_lengths))
        else:
            features['avg_sentence_length'] = 0.0
            features['sentence_length_std'] = 0.0

        # Vocabulary richness: share of distinct characters.
        features['vocab_richness'] = len(set(chinese_chars)) / max(len(chinese_chars), 1)

        # Connective usage, per category and in total.
        total_connectives = 0
        for category, words in self.CONNECTIVES.items():
            count = self._count_any(text, words)
            features[f'connective_{category}'] = count
            total_connectives += count
        features['total_connectives'] = total_connectives

        # Descriptive word usage.
        features['descriptive_count'] = self._count_any(text, self.DESCRIPTIVE_WORDS)

        # Punctuation usage.
        features['comma_count'] = self._count_any(text, self._COMMAS)
        features['period_count'] = text.count(self._PERIOD)
        features['exclamation_count'] = self._count_any(text, self._EXCLAMATIONS)
        features['question_count'] = self._count_any(text, self._QUESTIONS)
        # Each quotation mark is counted once, whichever form it takes.
        features['quotation_count'] = self._count_any(text, self._QUOTES)

        return features

    def extract_ngram_features(self, text: str, n: int = 2) -> Dict[str, int]:
        """Count character n-grams over the Chinese characters of *text*.

        Non-Chinese characters are removed before the n-grams are formed,
        so grams may span characters that were not adjacent in *text*.

        Args:
            text: Input text.
            n: Gram length; values below 1 yield an empty result.

        Returns:
            Mapping of n-gram to its number of occurrences.
        """
        if n < 1:
            return {}
        chinese_text = ''.join(ch for ch in text if self._is_chinese(ch))
        ngrams: Dict[str, int] = {}
        for start in range(len(chinese_text) - n + 1):
            gram = chinese_text[start:start + n]
            ngrams[gram] = ngrams.get(gram, 0) + 1
        return ngrams

    def text_to_embedding(self, text: str, max_length: int = 512) -> np.ndarray:
        """Return a fixed-size vector representing *text*.

        This is a stand-in for a BERT/ERNIE encoder: the statistical
        features are L2-normalised and zero-padded (or truncated) to 64
        values.

        Args:
            text: Essay text.
            max_length: Accepted for interface compatibility; not used by
                this statistical stand-in.

        Returns:
            A float32 array of shape (64,).
        """
        features = self.extract_statistical_features(text)
        vector = np.array(list(features.values()), dtype=np.float32)

        # L2 normalisation; an all-zero vector is left unchanged.
        norm = np.linalg.norm(vector)
        if norm > 0:
            vector = vector / norm

        # Pad with zeros or truncate to the fixed dimension.
        missing = self._EMBEDDING_DIM - len(vector)
        if missing > 0:
            return np.pad(vector, (0, missing))
        return vector[:self._EMBEDDING_DIM]
|
||
|
||
|
||
# ==================== 评分模型推理器 ====================
|
||
|
||
class EssayScorerModel:
    """Essay-scoring inference wrapper.

    The deep-learning model is not wired in yet (see ``load_model``); all
    scores currently come from the rule-based heuristic in
    ``_rule_based_scoring``, which works on statistical text features.
    """

    # Expected number of Chinese characters per school grade.
    _EXPECTED_CHARS = {1: 100, 2: 150, 3: 250, 4: 350, 5: 450,
                       6: 550, 7: 650, 8: 750, 9: 800}
    # Expected length used for grades missing from _EXPECTED_CHARS.
    _DEFAULT_EXPECTED_CHARS = 500

    def __init__(self, config: EssayModelConfig):
        """Store *config* and create the feature extractor; nothing is loaded yet."""
        self._config = config
        self._model = None
        self._tokenizer = None
        self._feature_extractor = TextFeatureExtractor()
        self._is_loaded = False
        # Names of the scoring dimensions.
        self._dimension_names = ['content', 'structure', 'language', 'emotion']
        logger.info(f"作文评分模型初始化: {config.model_name}")

    def load_model(self) -> bool:
        """Mark the model as loaded.

        The real weight loading is still a placeholder (see the commented
        sketch below); at present this only logs and sets the loaded flag.
        No check is made that the model directory exists.

        Returns:
            True on success, False if an exception was raised.
        """
        try:
            model_dir = Path(self._config.model_path)
            logger.info(f"正在加载作文评分模型: {model_dir}")

            # Production sketch: load PyTorch/ONNX weights and the tokenizer.
            # self._model = onnxruntime.InferenceSession(str(model_dir / "model.onnx"))
            # self._tokenizer = AutoTokenizer.from_pretrained(str(model_dir))

            self._is_loaded = True
            logger.info(f"作文评分模型加载完成: {self._config.model_name}")
            return True
        except Exception as e:
            logger.error(f"模型加载失败: {str(e)}")
            return False

    def predict(self, text: str, grade: int = 6) -> Dict[str, object]:
        """Score one essay.

        Args:
            text: Essay text.
            grade: School grade; selects the expected essay length.

        Returns:
            A dict with keys ``'scores'`` (per-dimension scores),
            ``'features'`` (the statistical features used) and
            ``'inference_time_ms'`` (elapsed time in milliseconds).
        """
        started = time.perf_counter()

        # Only the statistical features feed the rule-based scorer; the
        # embedding is not computed here because nothing consumes it yet.
        features = self._feature_extractor.extract_statistical_features(text)
        scores = self._rule_based_scoring(features, grade)

        elapsed = (time.perf_counter() - started) * 1000
        logger.debug(f"评分推理完成: {elapsed:.1f}ms")

        return {
            'scores': scores,
            'features': features,
            'inference_time_ms': round(elapsed, 2)
        }

    @staticmethod
    def _clamp_score(score: float) -> float:
        """Round *score* to one decimal and clamp it to the 0-100 range."""
        return min(100, max(0, round(score, 1)))

    def _score_content(self, features: Dict, grade: int) -> float:
        """Content score: essay length, vocabulary richness, descriptive words."""
        score = 60.0  # base score
        expected = self._EXPECTED_CHARS.get(grade, self._DEFAULT_EXPECTED_CHARS)
        # Length bonus saturates at 1.5x the expected length.
        char_ratio = min(features.get('char_count', 0) / max(expected, 1), 1.5)
        score += char_ratio * 20

        vocab = features.get('vocab_richness', 0)
        if vocab > 0.5:
            score += 10
        elif vocab > 0.3:
            score += 5

        descriptive = features.get('descriptive_count', 0)
        if descriptive >= 3:
            score += 8
        elif descriptive >= 1:
            score += 4
        return score

    @staticmethod
    def _score_structure(features: Dict) -> float:
        """Structure score: paragraph count and sequential connectives."""
        score = 65.0
        para_count = features.get('paragraph_count', 1)
        if 3 <= para_count <= 7:
            score += 20
        elif 2 <= para_count <= 8:
            score += 10

        # Opening/closing style connectives earn a bonus.
        if features.get('connective_sequential', 0) >= 2:
            score += 10
        return score

    @staticmethod
    def _score_language(features: Dict) -> float:
        """Language score: average sentence length and connective usage."""
        score = 70.0
        avg_sent_len = features.get('avg_sentence_length', 0)
        if 8 <= avg_sent_len <= 25:
            score += 15  # moderate sentence length
        elif avg_sent_len > 40:
            score -= 10  # overly long sentences

        total_conn = features.get('total_connectives', 0)
        if total_conn >= 4:
            score += 10
        elif total_conn >= 2:
            score += 5
        return score

    @staticmethod
    def _score_emotion(features: Dict) -> float:
        """Emotion score: exclamations, questions and quotations."""
        score = 65.0
        if features.get('exclamation_count', 0) >= 1:
            score += 8
        if features.get('question_count', 0) >= 1:
            score += 5
        if features.get('quotation_count', 0) >= 2:
            score += 7  # quotation or dialogue present
        return score

    def _rule_based_scoring(self, features: Dict, grade: int) -> Dict[str, float]:
        """Heuristic scoring from statistical features.

        Args:
            features: Output of ``extract_statistical_features``; missing
                keys fall back to neutral defaults.
            grade: School grade used for the expected essay length.

        Returns:
            Scores in the range 0-100 for ``content``, ``structure``,
            ``language`` and ``emotion``.
        """
        return {
            'content': self._clamp_score(self._score_content(features, grade)),
            'structure': self._clamp_score(self._score_structure(features)),
            'language': self._clamp_score(self._score_language(features)),
            'emotion': self._clamp_score(self._score_emotion(features)),
        }

    def batch_predict(self, texts: List[str], grade: int = 6) -> List[Dict]:
        """Score several essays and return one ``predict`` result per text.

        Texts are walked in chunks of the configured batch size; each text
        is still scored individually via ``predict``.
        """
        results: List[Dict] = []
        started = time.perf_counter()

        # A non-positive batch size would make range() raise; use 1 instead.
        step = max(1, self._config.batch_size)
        for offset in range(0, len(texts), step):
            chunk = texts[offset:offset + step]
            results.extend(self.predict(text, grade) for text in chunk)

        total_time = (time.perf_counter() - started) * 1000
        logger.info(f"批量评分完成: {len(texts)}篇, 总耗时{total_time:.1f}ms")
        return results
|
||
|
||
|
||
# ==================== 评分校准器 ====================
|
||
|
||
class ScoreCalibrator:
    """Map raw model scores onto a per-grade target score distribution.

    A raw score is converted to a z-score under an assumed raw distribution
    and re-expressed with the mean and standard deviation recorded for the
    pupil's grade.
    """

    # Assumed distribution of raw scores (on a 0-100 scale).
    _RAW_MEAN = 50.0
    _RAW_STD = 25.0
    # Target distribution for grades without an entry in the table.
    _DEFAULT_STATS = {'mean': 75, 'std': 12}
    # Share of the total score carried by each dimension.
    _DIMENSION_WEIGHTS = {'content': 0.30, 'structure': 0.25,
                          'language': 0.25, 'emotion': 0.20}
    # Weight used for dimension names missing from _DIMENSION_WEIGHTS.
    _DEFAULT_WEIGHT = 0.25

    def __init__(self):
        # Per-grade mean and standard deviation of scores, on a 0-100 scale.
        self._grade_stats = {
            1: {'mean': 75, 'std': 12},
            2: {'mean': 76, 'std': 11},
            3: {'mean': 78, 'std': 10},
            4: {'mean': 77, 'std': 11},
            5: {'mean': 76, 'std': 12},
            6: {'mean': 75, 'std': 13},
            7: {'mean': 73, 'std': 14},
            8: {'mean': 72, 'std': 15},
            9: {'mean': 71, 'std': 15},
        }

    def calibrate(self, raw_score: float, grade: int, max_score: int = 100) -> float:
        """Calibrate one raw score.

        Args:
            raw_score: Raw score on a 0-100 scale.
            grade: School grade selecting the target distribution.
            max_score: Full marks of the scale the result is expressed on.

        Returns:
            The calibrated score, rounded to one decimal and clamped to
            ``[0.2 * max_score, max_score]``.
        """
        stats = self._grade_stats.get(grade, self._DEFAULT_STATS)

        # Standardise, then re-map onto the grade's distribution.
        z_score = (raw_score - self._RAW_MEAN) / self._RAW_STD
        calibrated = stats['mean'] + z_score * stats['std']

        # The grade statistics are on a 0-100 scale, so the value has to be
        # rescaled before it is compared with max_score; otherwise any
        # max_score well below 100 saturates at full marks.
        calibrated = calibrated * (max_score / 100)

        # Clamp to the valid range.
        calibrated = max(max_score * 0.2, min(max_score, calibrated))
        return round(calibrated, 1)

    def calibrate_dimensions(self, dimension_scores: Dict[str, float],
                             grade: int, max_score: int = 100) -> Dict[str, float]:
        """Calibrate each dimension and convert it to its share of *max_score*.

        Args:
            dimension_scores: Raw 0-100 score per dimension name.
            grade: School grade selecting the target distribution.
            max_score: Full marks of the whole essay.

        Returns:
            Per-dimension points, each out of ``max_score * weight`` and
            rounded to one decimal.
        """
        calibrated = {}
        for dim, score in dimension_scores.items():
            on_percent_scale = self.calibrate(score, grade, 100)
            dim_max = max_score * self._DIMENSION_WEIGHTS.get(dim, self._DEFAULT_WEIGHT)
            calibrated[dim] = round(on_percent_scale / 100 * dim_max, 1)
        return calibrated
|
||
```
|
||
|
||
#### `engine/stroke_analyzer.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 笔顺分析算法模块 - 笔画拆分与顺序分析核心算法
|
||
|
||
"""
|
||
笔顺分析核心算法
|
||
提供笔画自动拆分、方向判定、笔画连接检测、
|
||
笔迹相似度计算等底层分析算法
|
||
"""
|
||
|
||
import math
|
||
import logging
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple, Optional
|
||
from dataclasses import dataclass, field
|
||
from enum import IntEnum
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 常量定义 ====================
|
||
|
||
# Direction-angle range, in degrees, for each basic stroke kind.
# None marks the kinds that have no angle range and need special handling.
DIRECTION_ANGLES = {
    "horizontal": (-15, 15),       # heng (horizontal)
    "vertical": (75, 105),         # shu (vertical)
    "left_falling": (120, 165),    # pie (left-falling)
    "right_falling": (30, 75),     # na (right-falling)
    "dot": None,                   # dian (dot)
    "turning": None,               # zhe (turning)
    "hook": None,                  # gou (hook)
    "rising": (-60, -15),          # ti (rising)
}

# Minimum stroke length in pixels; anything shorter is treated as noise.
MIN_STROKE_LENGTH = 3.0

# Change of direction, in degrees, that splits a stroke into segments.
ANGLE_CHANGE_THRESHOLD = 45.0

# Minimum spacing between sample points.
MIN_POINT_DISTANCE = 1.0
|
||
|
||
|
||
class StrokeType(IntEnum):
    """Kinds of basic stroke, as integer codes."""

    UNKNOWN = 0
    HORIZONTAL = 1     # heng
    VERTICAL = 2       # shu
    LEFT_FALLING = 3   # pie
    RIGHT_FALLING = 4  # na
    DOT = 5            # dian
    TURNING = 6        # zhe
    HOOK = 7           # gou
    RISING = 8         # ti
|
||
|
||
|
||
@dataclass
class Point2D:
    """A two-dimensional sample point of a pen trace."""

    x: float
    y: float
    # Pen pressure of the sample; 0.5 when not supplied.
    pressure: float = 0.5
    # Time of the sample; 0 when not supplied.
    timestamp: int = 0
|
||
|
||
|
||
@dataclass
class StrokeSegment:
    """One stroke segment together with its derived measurements."""

    # Sample points making up the segment.
    points: List[Point2D]
    # Classified stroke kind.
    stroke_type: StrokeType = StrokeType.UNKNOWN
    # Overall direction angle.
    direction_angle: float = 0.0
    # Path length.
    length: float = 0.0
    # Curvature measure.
    curvature: float = 0.0
    # Average writing speed.
    avg_speed: float = 0.0
    # First and last point of the segment, when known.
    start_point: Optional[Point2D] = None
    end_point: Optional[Point2D] = None
|
||
|
||
|
||
# ==================== 笔迹几何工具 ====================
|
||
|
||
class StrokeGeometry:
    """Stateless geometry helpers for pen-trace point sequences."""

    @staticmethod
    def distance(p1: Point2D, p2: Point2D) -> float:
        """Return the Euclidean distance between *p1* and *p2*."""
        return math.hypot(p2.x - p1.x, p2.y - p1.y)

    @staticmethod
    def angle_degrees(p1: Point2D, p2: Point2D) -> float:
        """Return the direction from *p1* to *p2* in degrees, in (-180, 180].

        0 points along +x and positive angles turn towards +y (clockwise
        when the y axis points down, as on a screen). Coincident points
        give 0.
        """
        return math.degrees(math.atan2(p2.y - p1.y, p2.x - p1.x))

    @staticmethod
    def path_length(points: List[Point2D]) -> float:
        """Return the total length of the polyline through *points*."""
        total = 0.0
        for prev, curr in zip(points, points[1:]):
            total += StrokeGeometry.distance(prev, curr)
        return total

    @staticmethod
    def curvature_ratio(points: List[Point2D]) -> float:
        """Return path length divided by the straight start-to-end distance.

        1.0 means a straight line; larger values mean more bending. Fewer
        than two points give 1.0. The divisor is floored at 0.001, so a
        closed path yields a very large value rather than an error.
        """
        if len(points) < 2:
            return 1.0
        path_len = StrokeGeometry.path_length(points)
        direct = StrokeGeometry.distance(points[0], points[-1])
        return path_len / max(direct, 0.001)

    @staticmethod
    def bounding_box(points: List[Point2D]) -> Tuple[float, float, float, float]:
        """Return ``(min_x, min_y, max_x, max_y)`` of *points*.

        Raises:
            ValueError: If *points* is empty.
        """
        xs = [p.x for p in points]
        ys = [p.y for p in points]
        return min(xs), min(ys), max(xs), max(ys)

    @staticmethod
    def centroid(points: List[Point2D]) -> Point2D:
        """Return the arithmetic mean of *points* as a new point.

        Raises:
            ValueError: If *points* is empty (same exception type as
                ``bounding_box``, instead of a bare ZeroDivisionError).
        """
        if not points:
            raise ValueError("centroid() requires at least one point")
        count = len(points)
        return Point2D(sum(p.x for p in points) / count,
                       sum(p.y for p in points) / count)

    @staticmethod
    def resample(points: List[Point2D], n: int) -> List[Point2D]:
        """Resample *points* into *n* points equally spaced along the path.

        Position and pressure are linearly interpolated; timestamps are not
        carried over (new points get the default timestamp).

        Args:
            points: Input polyline.
            n: Number of output points.

        Returns:
            Exactly *n* new points when ``len(points) >= 2`` and ``n >= 2``.
            Otherwise the input is returned truncated to *n* items (an empty
            list for empty input).
        """
        if len(points) <= 1 or n <= 1:
            return points[:n] if points else []

        # Length of each segment points[k] -> points[k + 1], computed once.
        seg_lengths = [StrokeGeometry.distance(a, b)
                       for a, b in zip(points, points[1:])]
        interval = sum(seg_lengths) / (n - 1)

        first, last = points[0], points[-1]
        resampled = [Point2D(first.x, first.y, first.pressure)]

        walked = 0.0  # path length from the start up to points[seg]
        seg = 0       # index of the segment containing the current target
        for i in range(1, n - 1):
            target = i * interval
            while seg < len(seg_lengths) and walked + seg_lengths[seg] < target:
                walked += seg_lengths[seg]
                seg += 1

            if seg >= len(seg_lengths):
                # Floating-point rounding pushed the target past the end of
                # the path: clamp to the final point so that the output
                # still has exactly n points.
                resampled.append(Point2D(last.x, last.y, last.pressure))
                continue

            a, b = points[seg], points[seg + 1]
            # Floor the divisor to avoid dividing by a zero-length segment.
            ratio = (target - walked) / max(seg_lengths[seg], 0.001)
            resampled.append(Point2D(
                a.x + ratio * (b.x - a.x),
                a.y + ratio * (b.y - a.y),
                a.pressure + ratio * (b.pressure - a.pressure),
            ))

        resampled.append(Point2D(last.x, last.y, last.pressure))
        return resampled
|
||
|
||
|
||
# ==================== 笔画拆分器 ====================
|
||
|
||
class StrokeSplitter:
    """Split a continuous stream of pen samples into strokes or segments.

    Three independent criteria are offered:

    1. pen-up events (pressure near zero, or a long time gap);
    2. abrupt changes of direction;
    3. drops in writing speed.
    """

    def __init__(self, angle_threshold: float = ANGLE_CHANGE_THRESHOLD,
                 time_gap_ms: int = 300, speed_ratio: float = 0.3):
        """Configure the thresholds.

        Args:
            angle_threshold: Direction change in degrees that triggers a
                split in ``split_by_direction``.
            time_gap_ms: Timestamp gap between neighbouring samples above
                which ``split_by_penup`` assumes the pen was lifted.
            speed_ratio: Fraction of the average speed below which
                ``split_by_speed`` assumes a pause.
        """
        self._angle_threshold = angle_threshold
        self._time_gap_ms = time_gap_ms
        self._speed_ratio = speed_ratio

    def split_by_penup(self, points: List[Point2D]) -> List[List[Point2D]]:
        """Split *points* into strokes at pen-up events.

        A sample counts as a pen-up when its pressure is at most 0.01 or
        when it arrives more than the configured gap after its predecessor.
        That sample starts the next stroke. Strokes with fewer than two
        points are discarded.
        """
        if not points:
            return []

        strokes = []
        current_stroke = [points[0]]

        for prev, curr in zip(points, points[1:]):
            time_gap = curr.timestamp - prev.timestamp
            is_penup = curr.pressure <= 0.01 or time_gap > self._time_gap_ms

            if is_penup:
                # Always start afresh at a pen-up. A lone point left over
                # from before the lift must not be joined to the next
                # stroke; it is dropped, like any other one-point stroke.
                if len(current_stroke) > 1:
                    strokes.append(current_stroke)
                current_stroke = [curr]
            else:
                current_stroke.append(curr)

        if len(current_stroke) > 1:
            strokes.append(current_stroke)

        return strokes

    def split_by_direction(self, points: List[Point2D]) -> List[List[Point2D]]:
        """Split *points* where the writing direction changes abruptly.

        The direction of each step is compared with that of the previous
        step; when the difference exceeds the angle threshold and the
        current segment holds more than two points, the segment ends there.
        The split point is shared: it ends one segment and starts the next.
        Inputs with fewer than three points come back as a single segment.
        """
        if len(points) < 3:
            return [points] if points else []

        segments = []
        current = [points[0]]
        prev_angle = StrokeGeometry.angle_degrees(points[0], points[1])

        for i in range(1, len(points)):
            current.append(points[i])
            if i + 1 >= len(points):
                continue

            curr_angle = StrokeGeometry.angle_degrees(points[i], points[i + 1])
            angle_diff = abs(curr_angle - prev_angle)
            # Take the shorter way round across the +/-180 degree seam.
            if angle_diff > 180:
                angle_diff = 360 - angle_diff

            if angle_diff > self._angle_threshold and len(current) > 2:
                segments.append(current)
                current = [points[i]]
            prev_angle = curr_angle

        if len(current) > 1:
            segments.append(current)

        return segments

    def split_by_speed(self, points: List[Point2D]) -> List[List[Point2D]]:
        """Split *points* where the writing speed drops sharply.

        A step slower than ``speed_ratio`` times the average speed ends the
        current segment, provided that segment holds more than three
        points. The split point also starts the next segment. Inputs with
        fewer than three points come back as a single segment.
        """
        if len(points) < 3:
            return [points] if points else []

        # Speed of each step; the time delta is floored at 1 to avoid
        # dividing by zero on equal timestamps.
        speeds = []
        for prev, curr in zip(points, points[1:]):
            dist = StrokeGeometry.distance(prev, curr)
            dt = max(curr.timestamp - prev.timestamp, 1)
            speeds.append(dist / dt * 1000)

        avg_speed = np.mean(speeds) if speeds else 0
        threshold = avg_speed * self._speed_ratio

        segments = []
        current = [points[0]]

        for speed, point in zip(speeds, points[1:]):
            current.append(point)
            if speed < threshold and len(current) > 3:
                segments.append(current)
                current = [point]

        if len(current) > 1:
            segments.append(current)

        return segments
|
||
|
||
|
||
# ==================== 笔画类型分类器 ====================
|
||
|
||
class StrokeClassifier:
    """Assign a stroke kind from a segment's length, curvature and direction."""

    @staticmethod
    def classify(segment: List[Point2D]) -> StrokeType:
        """Return the stroke kind of *segment*.

        Checks run in this order: too few points or too short gives DOT;
        strong bending gives HOOK or TURNING; otherwise the start-to-end
        direction is matched against fixed angle bands, and UNKNOWN is
        returned when no band matches.
        """
        if len(segment) < 2:
            return StrokeType.DOT

        path_len = StrokeGeometry.path_length(segment)
        bend = StrokeGeometry.curvature_ratio(segment)

        # Very short strokes are dots.
        if path_len < MIN_STROKE_LENGTH * 2:
            return StrokeType.DOT

        # Strongly bent strokes are hooks or turns.
        if bend > 2.0:
            if len(segment) >= 3:
                # Direction of the final step decides whether it is a hook.
                tail_angle = StrokeGeometry.angle_degrees(segment[-2], segment[-1])
                if -90 < tail_angle < -10:
                    return StrokeType.HOOK
            return StrokeType.TURNING

        heading = StrokeGeometry.angle_degrees(segment[0], segment[-1])

        # Inclusive angle bands, tried in order; the first match wins, which
        # matters at the shared boundaries (e.g. 70 and -20).
        bands = (
            (-20, 20, StrokeType.HORIZONTAL),
            (70, 110, StrokeType.VERTICAL),
            (120, 170, StrokeType.LEFT_FALLING),
            (-170, -150, StrokeType.LEFT_FALLING),
            (25, 70, StrokeType.RIGHT_FALLING),
            (-65, -20, StrokeType.RISING),
        )
        for low, high, kind in bands:
            if low <= heading <= high:
                return kind
        return StrokeType.UNKNOWN
|
||
|
||
|
||
# ==================== 笔迹相似度计算 ====================
|
||
|
||
class StrokeSimilarity:
    """Similarity between two pen traces based on Dynamic Time Warping."""

    @staticmethod
    def dtw_distance(seq1: List[Point2D], seq2: List[Point2D]) -> float:
        """Return the DTW distance between *seq1* and *seq2*.

        This is the minimum accumulated point-to-point Euclidean distance
        over all monotonic alignments of the two sequences.

        Returns:
            The distance as a built-in float; ``inf`` if either sequence
            is empty.
        """
        n = len(seq1)
        m = len(seq2)
        if n == 0 or m == 0:
            return float('inf')

        # Accumulated-cost matrix with a one-cell border of infinities.
        dtw_matrix = np.full((n + 1, m + 1), float('inf'))
        dtw_matrix[0, 0] = 0

        for i in range(1, n + 1):
            for j in range(1, m + 1):
                cost = StrokeGeometry.distance(seq1[i - 1], seq2[j - 1])
                dtw_matrix[i, j] = cost + min(
                    dtw_matrix[i - 1, j],      # advance in seq1 only
                    dtw_matrix[i, j - 1],      # advance in seq2 only
                    dtw_matrix[i - 1, j - 1],  # advance in both
                )

        # Convert the NumPy scalar so callers receive a plain float.
        return float(dtw_matrix[n, m])

    @staticmethod
    def normalized_similarity(seq1: List[Point2D], seq2: List[Point2D],
                              resample_n: int = 32) -> float:
        """Return a similarity in [0, 1]; 1 means identical traces.

        Both traces are resampled to *resample_n* points and scaled by
        their joint bounding box, then compared with DTW. The distance is
        divided by *resample_n* and subtracted from 1.

        Returns:
            A built-in float rounded to four decimals; 0.0 if either trace
            is empty. A built-in type is returned deliberately: with a
            NumPy scalar, comparisons made by callers would produce
            ``numpy.bool_``, which the ``json`` module cannot serialise.
        """
        rs1 = StrokeGeometry.resample(seq1, resample_n)
        rs2 = StrokeGeometry.resample(seq2, resample_n)

        if not rs1 or not rs2:
            return 0.0

        # Scale both traces by their common bounding box; the scale is
        # floored at 1.0 so tiny traces are not blown up.
        min_x, min_y, max_x, max_y = StrokeGeometry.bounding_box(rs1 + rs2)
        scale = max(max_x - min_x, max_y - min_y, 1.0)

        norm1 = [Point2D((p.x - min_x) / scale, (p.y - min_y) / scale) for p in rs1]
        norm2 = [Point2D((p.x - min_x) / scale, (p.y - min_y) / scale) for p in rs2]

        dtw_dist = StrokeSimilarity.dtw_distance(norm1, norm2)
        similarity = max(0.0, 1.0 - dtw_dist / resample_n)
        return round(float(similarity), 4)
|
||
|
||
|
||
# ==================== 笔顺分析器(整合) ====================
|
||
|
||
class StrokeAnalyzer:
    """Facade combining stroke splitting, classification and comparison."""

    # A stroke pair counts as matching above this similarity.
    _MATCH_THRESHOLD = 0.6
    # Deducted from the overall similarity per missing or extra stroke.
    _COUNT_PENALTY = 0.1

    def __init__(self):
        self._splitter = StrokeSplitter()
        self._classifier = StrokeClassifier()
        self._similarity = StrokeSimilarity()
        logger.info("笔顺分析器初始化完成")

    def analyze(self, raw_points: List[Point2D]) -> List[StrokeSegment]:
        """Turn raw samples into a list of classified strokes.

        Steps: split at pen-up events, drop strokes shorter than
        ``MIN_STROKE_LENGTH``, classify each remaining stroke and attach
        its direction, length, curvature and average speed.

        Args:
            raw_points: Samples in writing order.

        Returns:
            One ``StrokeSegment`` per retained stroke, in writing order.
        """
        strokes = self._splitter.split_by_penup(raw_points)

        segments = []
        for stroke_points in strokes:
            # The length is needed both for the noise filter and for the
            # segment itself, so compute it once.
            length = StrokeGeometry.path_length(stroke_points)
            if length < MIN_STROKE_LENGTH:
                continue

            first, last = stroke_points[0], stroke_points[-1]
            seg = StrokeSegment(
                points=stroke_points,
                stroke_type=self._classifier.classify(stroke_points),
                direction_angle=StrokeGeometry.angle_degrees(first, last),
                length=length,
                curvature=StrokeGeometry.curvature_ratio(stroke_points),
                start_point=first,
                end_point=last
            )

            # Average speed; left at its default when timestamps do not
            # advance. Dividing by 1000.0 treats timestamps as milliseconds.
            if last.timestamp > first.timestamp:
                time_s = (last.timestamp - first.timestamp) / 1000.0
                seg.avg_speed = seg.length / max(time_s, 0.001)

            segments.append(seg)

        logger.debug(f"笔迹分析完成: {len(raw_points)}个原始点 → {len(segments)}个笔画")
        return segments

    def compare_stroke_orders(self, user_strokes: List[List[Point2D]],
                              template_strokes: List[List[Point2D]]) -> Dict:
        """Compare the user's strokes with the template, stroke by stroke.

        Strokes are paired by position; surplus strokes on either side are
        not compared but lower the overall score by 0.1 each.

        Returns:
            A dict with ``overall_similarity`` (0-1), ``stroke_matches``
            (one entry per compared pair, with a 1-based ``stroke_index``,
            ``similarity`` and ``match``), ``user_count`` and
            ``template_count``. All values are built-in Python types.
        """
        match_results = []
        total_similarity = 0.0
        compare_count = min(len(user_strokes), len(template_strokes))

        pairs = zip(user_strokes, template_strokes)
        for index, (user, template) in enumerate(pairs, start=1):
            # float()/bool() guard against NumPy scalars, which the json
            # module cannot serialise (numpy.bool_ in particular).
            sim = float(self._similarity.normalized_similarity(user, template))
            match_results.append({
                "stroke_index": index,
                "similarity": sim,
                "match": bool(sim > self._MATCH_THRESHOLD)
            })
            total_similarity += sim

        avg_similarity = total_similarity / max(compare_count, 1)
        count_penalty = abs(len(user_strokes) - len(template_strokes)) * self._COUNT_PENALTY

        return {
            "overall_similarity": round(max(0, avg_similarity - count_penalty), 4),
            "stroke_matches": match_results,
            "user_count": len(user_strokes),
            "template_count": len(template_strokes)
        }
|
||
```
|
||
|
||
### `grpc_server/`
|
||
|
||
#### `grpc_server/inference_service.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# gRPC批量识别服务模块 - 高性能流式批量笔迹识别
|
||
|
||
"""
|
||
gRPC推理服务
|
||
提供高性能流式批量笔迹识别接口
|
||
采用gRPC双向流模式,适用于教室场景下多支笔并发识别需求
|
||
支持服务端流式响应,实现低延迟识别结果推送
|
||
"""
|
||
|
||
import time
|
||
import json
|
||
import logging
|
||
import uuid
|
||
import asyncio
|
||
from typing import List, Dict, Optional, AsyncIterator
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from concurrent import futures
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== gRPC消息定义(等效Proto) ====================
|
||
|
||
class RecognitionType(str, Enum):
    """Kinds of recognition a request can ask for; members are also strings."""

    OCR = "ocr"                    # text recognition
    MATH = "math"                  # maths recognition
    STROKE_ORDER = "stroke_order"  # stroke-order scoring
    ESSAY = "essay"                # essay marking
|
||
|
||
|
||
@dataclass
class StrokePoint:
    """One pen sample (counterpart of the protobuf ``StrokePoint`` message)."""

    x: float
    y: float
    # Pen pressure of the sample; 0.5 when not supplied.
    pressure: float = 0.5
    # Time of the sample; 0 when not supplied.
    timestamp: int = 0
|
||
|
||
|
||
@dataclass
class StrokeData:
    """Handwriting payload (counterpart of the protobuf ``StrokeData`` message)."""

    stroke_id: str = ""
    pen_id: str = ""
    page_id: str = ""
    student_id: str = ""
    # One inner list of points per stroke; each instance gets its own list.
    strokes: List[List[StrokePoint]] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class RecognitionRequest:
    """Recognition request (counterpart of the protobuf ``RecognitionRequest`` message)."""

    request_id: str = ""
    recognition_type: RecognitionType = RecognitionType.OCR
    stroke_data: Optional[StrokeData] = None
    # Priority from 0 (highest) to 4 (lowest).
    priority: int = 2
    # MQTT topic on which the result is to be reported back.
    callback_topic: str = ""
    # Timeout in milliseconds.
    timeout_ms: int = 5000
|
||
|
||
|
||
@dataclass
class RecognitionResult:
    """Recognition result (counterpart of the protobuf ``RecognitionResult`` message)."""

    request_id: str = ""
    recognition_type: str = ""
    # One of "success", "error" or "timeout".
    status: str = "success"
    result_text: str = ""
    confidence: float = 0.0
    # Extra information such as an error message; each instance gets its
    # own dict.
    details: Dict = field(default_factory=dict)
    processing_time_ms: float = 0.0
    model_version: str = ""
|
||
|
||
|
||
# ==================== 批量识别处理器 ====================
|
||
|
||
class BatchRecognitionProcessor:
    """Queue recognition requests per type and process them in batches.

    Requests are kept in one queue per ``RecognitionType``. A queue is
    processed as soon as it reaches the maximum batch size, or when
    ``flush_all`` is called. Finished results are kept in memory and
    fetched with ``get_result``.

    NOTE(review): the class uses no locking; confirm it is only driven
    from one thread at a time.
    """

    # Version string stamped on every result.
    _MODEL_VERSION = "v1.0.0"

    def __init__(self, max_batch_size: int = 32, max_wait_ms: int = 50):
        """Create empty queues.

        Args:
            max_batch_size: Queue length that triggers a batch run, and
                the maximum number of requests handled per run.
            max_wait_ms: Stored but not used by this class.
        """
        self._max_batch_size = max_batch_size
        self._max_wait_ms = max_wait_ms
        self._pending_requests: Dict[str, List[RecognitionRequest]] = {
            rt.value: [] for rt in RecognitionType
        }
        self._results: Dict[str, RecognitionResult] = {}
        logger.info(f"批量识别处理器初始化: batch_size={max_batch_size}, wait_ms={max_wait_ms}")

    def add_request(self, request: RecognitionRequest) -> str:
        """Queue *request* and return its id.

        A request without an id is given a fresh UUID (the request object
        is modified). Reaching the maximum batch size runs one batch
        immediately.
        """
        if not request.request_id:
            request.request_id = str(uuid.uuid4())

        type_key = request.recognition_type.value
        queue = self._pending_requests.setdefault(type_key, [])
        queue.append(request)

        logger.debug(f"请求入队: id={request.request_id}, type={request.recognition_type.value}")

        if len(queue) >= self._max_batch_size:
            self._process_batch(type_key)

        return request.request_id

    def _process_batch(self, recognition_type: str):
        """Process up to one batch from the queue of *recognition_type*.

        A failure while handling a request is recorded as an ``error``
        result for that request and does not stop the batch.
        """
        queue = self._pending_requests.get(recognition_type, [])
        if not queue:
            return

        batch = queue[:self._max_batch_size]
        self._pending_requests[recognition_type] = queue[self._max_batch_size:]

        batch_start = time.perf_counter()
        logger.info(f"批处理开始: type={recognition_type}, batch_size={len(batch)}")

        for req in batch:
            try:
                self._results[req.request_id] = self._process_single(req)
            except Exception as e:
                self._results[req.request_id] = RecognitionResult(
                    request_id=req.request_id,
                    recognition_type=recognition_type,
                    status="error",
                    details={"error": str(e)}
                )

        elapsed = (time.perf_counter() - batch_start) * 1000
        logger.info(f"批处理完成: type={recognition_type}, count={len(batch)}, time={elapsed:.1f}ms")

    def _process_single(self, request: RecognitionRequest) -> RecognitionResult:
        """Run the inference routine matching the request's type.

        Returns:
            A ``success`` result carrying the recognised text, or an
            ``error`` result when no routine exists for the request's type.
        """
        start_time = time.perf_counter()

        # Inference routine and reported confidence per supported type.
        handlers = {
            RecognitionType.OCR: (self._run_ocr_inference, 0.92),
            RecognitionType.MATH: (self._run_math_inference, 0.88),
            RecognitionType.STROKE_ORDER: (self._run_stroke_order_inference, 0.95),
        }
        type_value = request.recognition_type.value
        handler = handlers.get(request.recognition_type)

        if handler is None:
            # Nothing was recognised, so do not report this as a success.
            return RecognitionResult(
                request_id=request.request_id,
                recognition_type=type_value,
                status="error",
                details={"error": f"不支持的识别类型: {type_value}"},
                model_version=self._MODEL_VERSION
            )

        run_inference, confidence = handler
        result_text = run_inference(request.stroke_data)
        elapsed = (time.perf_counter() - start_time) * 1000

        return RecognitionResult(
            request_id=request.request_id,
            recognition_type=type_value,
            status="success",
            result_text=result_text,
            confidence=confidence,
            processing_time_ms=round(elapsed, 2),
            model_version=self._MODEL_VERSION
        )

    def _run_ocr_inference(self, stroke_data: Optional[StrokeData]) -> str:
        """Return the OCR result text; empty when there are no strokes.

        Placeholder: production code is expected to call the PaddleOCR
        pipeline here.
        """
        if not stroke_data or not stroke_data.strokes:
            return ""
        # preprocessed = preprocess(stroke_data)
        # result = ocr_engine.recognize(preprocessed)
        return "[OCR识别结果]"

    def _run_math_inference(self, stroke_data: Optional[StrokeData]) -> str:
        """Return the maths recognition text; empty when there are no strokes."""
        if not stroke_data or not stroke_data.strokes:
            return ""
        return "[数学识别结果]"

    def _run_stroke_order_inference(self, stroke_data: Optional[StrokeData]) -> str:
        """Return the stroke-order analysis text; empty when there are no strokes."""
        if not stroke_data or not stroke_data.strokes:
            return ""
        return "[笔顺分析结果]"

    def get_result(self, request_id: str, remove: bool = False) -> Optional[RecognitionResult]:
        """Return the stored result for *request_id*, or None if there is none.

        Args:
            request_id: Id returned by ``add_request``.
            remove: When True the result is also deleted from the store.
                Results are otherwise kept indefinitely, so long-running
                callers should pass True once they have consumed a result.
        """
        if remove:
            return self._results.pop(request_id, None)
        return self._results.get(request_id)

    def flush_all(self):
        """Process every queued request of every type, batch by batch."""
        for type_key in list(self._pending_requests):
            while self._pending_requests[type_key]:
                self._process_batch(type_key)
|
||
|
||
|
||
# ==================== gRPC服务实现 ====================
|
||
|
||
class RecognitionServiceImpl:
    """Implementation of the gRPC ``RecognitionService``.

    Corresponding protobuf service definition::

        service RecognitionService {
            rpc Recognize(RecognitionRequest) returns (RecognitionResult);
            rpc BatchRecognize(stream RecognitionRequest) returns (stream RecognitionResult);
            rpc GetModelStatus(Empty) returns (ModelStatusResponse);
        }
    """

    def __init__(self):
        self._processor = BatchRecognitionProcessor()
        # Requests received and total handling time, for GetModelStatus.
        self._request_count = 0
        self._total_latency_ms = 0.0
        logger.info("gRPC RecognitionService 初始化完成")

    def Recognize(self, request: RecognitionRequest) -> RecognitionResult:
        """Handle one recognition request and return its result.

        Requests without stroke data are rejected with an ``error``
        result. Otherwise the request is queued and all queues are flushed
        at once, so a single call never waits for a batch to fill up.
        """
        self._request_count += 1
        start_time = time.perf_counter()

        # Validate the request.
        if not request.stroke_data or not request.stroke_data.strokes:
            return RecognitionResult(
                request_id=request.request_id,
                recognition_type=request.recognition_type.value,
                status="error",
                details={"error": "笔迹数据为空"}
            )

        request_id = self._processor.add_request(request)
        self._processor.flush_all()

        result = self._processor.get_result(request_id)
        elapsed = (time.perf_counter() - start_time) * 1000
        self._total_latency_ms += elapsed

        if result:
            # Audit log.
            logger.info(
                f"gRPC Recognize: id={request_id}, type={request.recognition_type.value}, "
                f"time={elapsed:.1f}ms, pen={request.stroke_data.pen_id}"
            )
            return result

        # No result was stored for the request.
        return RecognitionResult(
            request_id=request_id, status="error",
            details={"error": "处理超时"}
        )

    def BatchRecognize(self, request_iterator) -> List[RecognitionResult]:
        """Handle a stream of requests and return their results as a list.

        All requests are read from *request_iterator* first, then every
        queue is flushed. Results come back in request order; a request
        for which no result was stored is left out.
        """
        start_time = time.perf_counter()
        request_ids = []

        # Read every incoming request.
        for request in request_iterator:
            request_ids.append(self._processor.add_request(request))
            self._request_count += 1

        self._processor.flush_all()

        # Collect the results that exist.
        fetched = (self._processor.get_result(rid) for rid in request_ids)
        results = [result for result in fetched if result]

        # Batch requests are counted in _request_count, so their handling
        # time has to be counted too; otherwise the average latency
        # reported by GetModelStatus is too low.
        self._total_latency_ms += (time.perf_counter() - start_time) * 1000

        logger.info(f"BatchRecognize完成: 请求数={len(request_ids)}, 结果数={len(results)}")
        return results

    def GetModelStatus(self) -> Dict:
        """Return request statistics and the list of models.

        ``avg_latency_ms`` is the total handling time divided by the
        number of requests received. The model list is static.
        """
        return {
            "total_requests": self._request_count,
            "avg_latency_ms": round(self._total_latency_ms / max(self._request_count, 1), 2),
            "models": [
                {"name": "ocr_model", "version": "v1.0.0", "status": "active"},
                {"name": "math_model", "version": "v1.0.0", "status": "active"},
                {"name": "stroke_order_model", "version": "v1.0.0", "status": "active"},
            ]
        }
|
||
|
||
|
||
# ==================== gRPC服务器启动 ====================
|
||
|
||
class GrpcServer:
    """Holder for the gRPC inference server and its settings.

    The actual server start-up is still a commented-out sketch inside
    ``start``; at present the class only records its settings, logs, and
    exposes the statistics of the wrapped service.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 50051,
                 max_workers: int = 10, enable_tls: bool = True):
        """Record the listen address, worker count and TLS switch."""
        self._host, self._port = host, port
        self._max_workers = max_workers
        self._enable_tls = enable_tls
        self._service = RecognitionServiceImpl()
        # Stays None until a real server object is created.
        self._server = None
        logger.info(f"gRPC服务器配置: {host}:{port}, workers={max_workers}, tls={enable_tls}")

    def start(self):
        """Log the start-up of the server.

        The intended implementation is kept below as a sketch: create the
        server on a thread pool, register the service, then open either a
        secure port with mutual TLS or an insecure port.
        """
        address = f"{self._host}:{self._port}"
        logger.info(f"启动gRPC服务器: {address}")

        # Sketch of the real start-up code:
        # self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=self._max_workers))
        # inference_pb2_grpc.add_RecognitionServiceServicer_to_server(self._service, self._server)
        #
        # if self._enable_tls:
        #     # Mutual TLS: server key/certificate plus the CA certificate.
        #     with open('/etc/ssl/server.key', 'rb') as f:
        #         server_key = f.read()
        #     with open('/etc/ssl/server.crt', 'rb') as f:
        #         server_cert = f.read()
        #     with open('/etc/ssl/ca.crt', 'rb') as f:
        #         ca_cert = f.read()
        #     credentials = grpc.ssl_server_credentials(
        #         [(server_key, server_cert)],
        #         root_certificates=ca_cert,
        #         require_client_auth=True  # client certificate required
        #     )
        #     self._server.add_secure_port(f'{self._host}:{self._port}', credentials)
        # else:
        #     self._server.add_insecure_port(f'{self._host}:{self._port}')
        #
        # self._server.start()

        logger.info(f"gRPC服务器已启动: {address}")

    def stop(self, grace_seconds: int = 5):
        """Log the shutdown when a server object exists; otherwise do nothing."""
        if not self._server:
            return
        # self._server.stop(grace_seconds)
        logger.info("gRPC服务器已关闭")

    def get_stats(self) -> Dict:
        """Return the status dictionary of the wrapped recognition service."""
        return self._service.GetModelStatus()
|
||
```
|
||
|
||
### `preprocessing/`
|
||
|
||
#### `preprocessing/stroke_processor.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 笔迹预处理模块 - 笔迹数据预处理管道
|
||
|
||
"""
|
||
笔迹预处理模块
|
||
提供笔迹坐标数据的完整预处理管道:
|
||
去噪 → 坐标归一化 → 笔画分割 → 特征增强 → 张量转换
|
||
预处理结果作为AI推理模型的标准化输入
|
||
"""
|
||
|
||
import math
|
||
import logging
|
||
import numpy as np
|
||
from typing import List, Dict, Tuple, Optional
|
||
from dataclasses import dataclass
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 数据结构 ====================
|
||
|
||
@dataclass
class RawStrokePoint:
    """
    One raw coordinate sample as delivered by the dot-pattern pen / gateway.

    Attributes:
        x: X coordinate, in dot-pattern units.
        y: Y coordinate, in dot-pattern units.
        pressure: Pen pressure, documented range 0.0-1.0.
        timestamp: Capture time in milliseconds.
        pen_up: Pen-lift marker; False unless given.
    """

    x: float
    y: float
    pressure: float
    timestamp: int
    pen_up: bool = False
|
||
|
||
|
||
@dataclass
class ProcessedStroke:
    """
    One stroke after preprocessing.

    Attributes:
        points: Coordinate array, documented as shape (N, 3) holding
            [x, y, pressure] per sample.
        stroke_index: Position of the stroke in writing order.
        point_count: Number of samples in the stroke.
        length: Path length of the stroke.
        duration_ms: Writing time of the stroke in milliseconds.
    """

    points: np.ndarray
    stroke_index: int = 0
    point_count: int = 0
    length: float = 0.0
    duration_ms: int = 0
|
||
|
||
|
||
# ==================== 去噪滤波器 ====================
|
||
|
||
class NoiseFilter:
    """
    Multi-stage denoising filter for raw pen samples.

    Stages, applied in this order by :meth:`apply`:
      1. outlier rejection (implausibly large jumps between samples),
      2. median filtering of x/y (impulse noise),
      3. Gaussian smoothing of x/y (jitter).

    The median and Gaussian stages only alter coordinates; pressure,
    timestamp and pen_up are copied through unchanged.
    """

    def __init__(self, max_jump_distance: float = 50.0,
                 median_window: int = 3, gaussian_sigma: float = 1.0):
        """
        Args:
            max_jump_distance: Largest accepted distance between a sample
                and its reference sample, in input coordinate units.
            median_window: Nominal width of the median window.
            gaussian_sigma: Standard deviation of the Gaussian kernel, in
                samples.  0 disables the smoothing stage.
        """
        self._max_jump = max_jump_distance
        self._median_window = median_window
        self._gaussian_sigma = gaussian_sigma

    def remove_outliers(self, points: List[RawStrokePoint]) -> List[RawStrokePoint]:
        """
        Drop samples that jump implausibly far.

        A sample is kept when it lies within ``max_jump_distance`` of its
        raw predecessor, or of the last sample that was kept.  The second
        test is what lets the trace recover after an isolated spike:
        comparing only against the raw predecessor would also discard the
        valid sample that follows the spike, because that sample is far
        from the spike itself.

        Inputs with fewer than 3 samples are returned unchanged.
        """
        if len(points) < 3:
            return points

        kept = [points[0]]
        for i in range(1, len(points)):
            candidate = points[i]
            previous = points[i - 1]
            step = math.hypot(candidate.x - previous.x, candidate.y - previous.y)

            if step > self._max_jump:
                # Far from the raw predecessor: still acceptable if it is
                # close to the last accepted sample (return from a spike).
                anchor = kept[-1]
                back = math.hypot(candidate.x - anchor.x, candidate.y - anchor.y)
                if back > self._max_jump:
                    logger.debug(f"剔除异常点: index={i}, distance={step:.1f}")
                    continue

            kept.append(candidate)

        return kept

    def median_filter(self, points: List[RawStrokePoint]) -> List[RawStrokePoint]:
        """
        Median-filter the x and y sequences independently.

        The window is centred on each sample and truncated at both ends of
        the sequence; for an even-sized (truncated) window the upper of
        the two middle values is used.  Inputs shorter than the configured
        window are returned unchanged.
        """
        total = len(points)
        if total < self._median_window:
            return points

        half_w = self._median_window // 2
        filtered = []
        for i, centre in enumerate(points):
            window = points[max(0, i - half_w):min(total, i + half_w + 1)]
            mid = len(window) // 2
            filtered.append(RawStrokePoint(
                x=sorted(p.x for p in window)[mid],
                y=sorted(p.y for p in window)[mid],
                pressure=centre.pressure,
                timestamp=centre.timestamp,
                pen_up=centre.pen_up
            ))

        return filtered

    def gaussian_smooth(self, points: List[RawStrokePoint]) -> List[RawStrokePoint]:
        """
        Smooth the x and y sequences with a normalised 1-D Gaussian kernel.

        The sequences are edge-padded before convolution so the output has
        one sample per input sample.  Inputs with fewer than 3 samples are
        returned unchanged, as is any input when sigma is 0 (a zero sigma
        would otherwise divide by zero while building the kernel).
        """
        sigma = self._gaussian_sigma
        if len(points) < 3 or sigma == 0:
            return points

        # Odd kernel width of roughly 4 sigma, never narrower than 3.
        kernel_size = max(3, int(sigma * 4) | 1)
        half_k = kernel_size // 2
        kernel = np.array([
            math.exp(-0.5 * ((i - half_k) / sigma) ** 2)
            for i in range(kernel_size)
        ])
        kernel = kernel / kernel.sum()  # weights sum to 1

        xs = np.array([p.x for p in points])
        ys = np.array([p.y for p in points])

        # 'edge' padding repeats the end samples so 'valid' convolution
        # returns exactly len(points) values.
        smooth_x = np.convolve(np.pad(xs, half_k, mode='edge'), kernel, mode='valid')
        smooth_y = np.convolve(np.pad(ys, half_k, mode='edge'), kernel, mode='valid')

        return [
            RawStrokePoint(
                x=float(sx), y=float(sy),
                pressure=p.pressure,
                timestamp=p.timestamp,
                pen_up=p.pen_up
            )
            for p, sx, sy in zip(points, smooth_x, smooth_y)
        ]

    def apply(self, points: List[RawStrokePoint]) -> List[RawStrokePoint]:
        """Run outlier rejection, median filtering and Gaussian smoothing in that order."""
        result = self.remove_outliers(points)
        result = self.median_filter(result)
        result = self.gaussian_smooth(result)
        return result
|
||
|
||
|
||
# ==================== 坐标归一化器 ====================
|
||
|
||
class CoordinateNormalizer:
    """
    Coordinate normaliser.

    Maps pen coordinates captured at different resolutions / paper sizes
    onto a common range.  Two strategies are implemented here: min-max
    scaling into a target interval and centre/standard-deviation scaling.
    """

    def __init__(self, target_range: Tuple[float, float] = (0.0, 1.0),
                 preserve_aspect_ratio: bool = True):
        """
        Args:
            target_range: (low, high) interval used by min-max scaling.
            preserve_aspect_ratio: When True both axes share one scale
                factor; when False each axis is stretched independently.
        """
        self._target_min = target_range[0]
        self._target_max = target_range[1]
        self._preserve_aspect = preserve_aspect_ratio

    def min_max_normalize(self, points: List[RawStrokePoint]) -> List[RawStrokePoint]:
        """
        Scale coordinates into the configured target interval.

        With aspect preservation the larger of the two axis extents is the
        common divisor, so only that axis spans the full interval.  A
        degenerate extent (all samples at one position) falls back to a
        divisor of 1.0.  An empty input is returned unchanged.
        """
        if not points:
            return points

        xs = [p.x for p in points]
        ys = [p.y for p in points]
        min_x, max_x = min(xs), max(xs)
        min_y, max_y = min(ys), max(ys)

        # The divisors do not depend on the individual sample, so they are
        # computed once up front.
        if self._preserve_aspect:
            scale = max(max_x - min_x, max_y - min_y)
            if scale < 1e-6:
                scale = 1.0
            scale_x = scale_y = scale
        else:
            scale_x = max_x - min_x if max_x > min_x else 1.0
            scale_y = max_y - min_y if max_y > min_y else 1.0

        span = self._target_max - self._target_min
        low = self._target_min
        return [
            RawStrokePoint(
                x=low + (p.x - min_x) / scale_x * span,
                y=low + (p.y - min_y) / scale_y * span,
                pressure=p.pressure,
                timestamp=p.timestamp, pen_up=p.pen_up
            )
            for p in points
        ]

    def center_normalize(self, points: List[RawStrokePoint]) -> List[RawStrokePoint]:
        """
        Translate the centroid to the origin and divide by the spread.

        The spread is the standard deviation of the centred x and y
        deviations taken together, so the result does not depend on where
        the writing sits on the page.  (Taking the deviation of the
        un-centred x and y values together would mix in the distance
        between the x and y means.)  The divisor is floored at 1e-6 to
        avoid dividing by zero for a single position.  An empty input is
        returned unchanged.
        """
        if not points:
            return points

        xs = np.array([p.x for p in points])
        ys = np.array([p.y for p in points])

        cx, cy = float(np.mean(xs)), float(np.mean(ys))
        spread = float(np.std(np.concatenate([xs - cx, ys - cy])))
        std = max(spread, 1e-6)

        return [
            RawStrokePoint(
                x=float((p.x - cx) / std),
                y=float((p.y - cy) / std),
                pressure=p.pressure,
                timestamp=p.timestamp,
                pen_up=p.pen_up
            )
            for p in points
        ]
|
||
|
||
|
||
# ==================== 笔画分割器 ====================
|
||
|
||
class StrokeSegmenter:
    """
    Stroke segmenter.

    Splits a continuous sample stream into individual strokes at pen-lift
    markers and at time gaps longer than a threshold.
    """

    def __init__(self, min_stroke_points: int = 3,
                 penup_time_threshold_ms: int = 200):
        """
        Args:
            min_stroke_points: Strokes with fewer samples are discarded.
            penup_time_threshold_ms: A gap between consecutive timestamps
                larger than this many milliseconds starts a new stroke.
        """
        self._min_points = min_stroke_points
        self._penup_threshold = penup_time_threshold_ms

    def segment(self, points: List[RawStrokePoint]) -> List[List[RawStrokePoint]]:
        """
        Split ``points`` into a list of strokes.

        A boundary occurs at every sample flagged ``pen_up`` and wherever
        the timestamp gap to the previous sample exceeds the threshold.
        At a boundary the samples collected so far are emitted if there
        are at least ``min_stroke_points`` of them and discarded
        otherwise; either way collection restarts, so a too-short fragment
        is never merged into the stroke that follows it.  Samples flagged
        ``pen_up`` are never part of a stroke, including the first sample.
        """
        if not points:
            return []

        strokes = []
        current = []
        previous = None

        for point in points:
            timed_out = (
                previous is not None
                and point.timestamp - previous.timestamp > self._penup_threshold
            )

            if point.pen_up or timed_out:
                if len(current) >= self._min_points:
                    strokes.append(current)
                # Always restart, even when the fragment was too short.
                current = []

            if not point.pen_up:
                current.append(point)
            previous = point

        # Flush the trailing stroke, subject to the same minimum length.
        if len(current) >= self._min_points:
            strokes.append(current)

        logger.debug(f"笔画分割完成: {len(points)}点 -> {len(strokes)}笔画")
        return strokes
|
||
|
||
|
||
# ==================== 预处理管道 ====================
|
||
|
||
class StrokePreprocessor:
    """
    Preprocessing pipeline that chains every preprocessing step.

    Flow: raw samples -> denoise -> min-max normalise -> segment into
    strokes -> build per-stroke arrays -> rasterise to an image array.
    """

    def __init__(self):
        """Create the filter, normaliser and segmenter with their default settings."""
        self._noise_filter = NoiseFilter()
        self._normalizer = CoordinateNormalizer()
        self._segmenter = StrokeSegmenter()
        logger.info("笔迹预处理管道初始化完成")

    def process(self, raw_points: List[RawStrokePoint],
                target_size: Tuple[int, int] = (64, 64)) -> Dict:
        """
        Run the full pipeline.

        Args:
            raw_points: Raw samples in capture order.
            target_size: (width, height) of the rendered image.

        Returns:
            A dict with keys ``strokes`` (list of ProcessedStroke),
            ``image`` (float32 array of shape (height, width)),
            ``total_points`` (sample count after denoising) and
            ``stroke_count``.  Empty input yields the same keys with an
            all-zero image, so callers see one result shape on every path.
        """
        if not raw_points:
            return {
                "strokes": [],
                "image": self._render_to_image([], target_size),
                "total_points": 0,
                "stroke_count": 0
            }

        # Step 1: denoise.
        denoised = self._noise_filter.apply(raw_points)

        # Step 2: normalise coordinates.
        normalized = self._normalizer.min_max_normalize(denoised)

        # Step 3: split into strokes.
        stroke_groups = self._segmenter.segment(normalized)

        # Step 4: build one ProcessedStroke per group.
        processed_strokes = []
        for idx, group in enumerate(stroke_groups):
            points_array = np.array([[p.x, p.y, p.pressure] for p in group], dtype=np.float32)
            length = sum(
                math.hypot(b.x - a.x, b.y - a.y)
                for a, b in zip(group, group[1:])
            )
            duration = group[-1].timestamp - group[0].timestamp if len(group) > 1 else 0

            processed_strokes.append(ProcessedStroke(
                points=points_array,
                stroke_index=idx,
                point_count=len(group),
                length=length,
                duration_ms=duration
            ))

        # Step 5: rasterise the normalised samples.
        image = self._render_to_image(normalized, target_size)

        logger.debug(
            f"预处理完成: {len(raw_points)}原始点 → {len(denoised)}去噪 → "
            f"{len(processed_strokes)}笔画 → {target_size}图像"
        )

        return {
            "strokes": processed_strokes,
            "image": image,
            "total_points": len(denoised),
            "stroke_count": len(processed_strokes)
        }

    @staticmethod
    def _to_pixel(value: float, extent: int) -> int:
        """Map a normalised coordinate to a pixel index clamped to [0, extent - 1]."""
        return max(0, min(extent - 1, int(value * (extent - 1))))

    def _render_to_image(self, points: List[RawStrokePoint],
                         size: Tuple[int, int]) -> np.ndarray:
        """
        Rasterise the samples into a grey-scale image.

        Consecutive samples are joined with Bresenham lines whose
        intensity is the mean pressure of the two endpoints; where lines
        overlap the larger intensity wins.  A segment is skipped when
        either endpoint is flagged ``pen_up``, so no line is drawn into a
        pen-lift sample or from it to the start of the next stroke.

        Args:
            points: Samples with x/y expected in [0, 1]; values outside
                are clamped to the image border.
            size: (width, height) of the image.

        Returns:
            float32 array of shape (height, width).
        """
        w, h = size
        image = np.zeros((h, w), dtype=np.float32)

        for i in range(1, len(points)):
            start, end = points[i - 1], points[i]
            if start.pen_up or end.pen_up:
                continue

            x0, y0 = self._to_pixel(start.x, w), self._to_pixel(start.y, h)
            x1, y1 = self._to_pixel(end.x, w), self._to_pixel(end.y, h)

            # Constant along the segment, so computed once.
            pressure = (start.pressure + end.pressure) / 2

            dx = abs(x1 - x0)
            dy = abs(y1 - y0)
            sx = 1 if x0 < x1 else -1
            sy = 1 if y0 < y1 else -1
            err = dx - dy

            while True:
                image[y0, x0] = max(image[y0, x0], pressure)

                if x0 == x1 and y0 == y1:
                    break
                e2 = 2 * err
                if e2 > -dy:
                    err -= dy
                    x0 += sx
                if e2 < dx:
                    err += dx
                    y0 += sy

        return image
|
||
```
|
||
|
||
### `service/`
|
||
|
||
#### `service/model_manager.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 模型版本管理模块 - 模型加载、版本切换、热更新与灰度发布
|
||
|
||
"""
|
||
模型版本管理服务
|
||
提供AI推理模型的版本管理、动态加载、热更新、灰度发布、回滚等功能
|
||
支持MinIO模型仓库对接和MLflow实验追踪
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
import json
|
||
import hashlib
|
||
import shutil
|
||
import logging
|
||
import threading
|
||
from typing import Dict, List, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
from enum import Enum
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 数据模型 ====================
|
||
|
||
class ModelStatus(str, Enum):
    """
    Lifecycle state of a model version.

    Members:
        DOWNLOADING: the model file is being downloaded.
        LOADING: the model is being loaded.
        ACTIVE: the version is currently serving.
        STANDBY: on standby (loaded but not enabled).
        DEPRECATED: the version has been retired.
        FAILED: loading failed.
    """

    DOWNLOADING = "downloading"
    LOADING = "loading"
    ACTIVE = "active"
    STANDBY = "standby"
    DEPRECATED = "deprecated"
    FAILED = "failed"
|
||
|
||
|
||
class DeployStrategy(str, Enum):
    """
    Strategy used when deploying a model version.

    Members:
        IMMEDIATE: switch all traffic at once.
        CANARY: canary release to a fraction of traffic.
        BLUE_GREEN: blue-green deployment.
        ROLLING: rolling update.
    """

    IMMEDIATE = "immediate"
    CANARY = "canary"
    BLUE_GREEN = "blue_green"
    ROLLING = "rolling"
|
||
|
||
|
||
@dataclass
class ModelVersion:
    """
    Metadata describing one version of a model.

    Attributes:
        model_name: Model name (e.g. ocr_v1, math_v2).
        version: Semantic version string (e.g. 1.2.3).
        file_path: Local path of the model file.
        file_size: File size in bytes.
        sha256: SHA-256 checksum of the file.
        accuracy: Accuracy metric (test-set accuracy).
        latency_p99_ms: P99 inference latency in milliseconds.
        status: Lifecycle state; STANDBY unless given.
        created_at: Creation time.
        deployed_at: Deployment time.
        deploy_ratio: Canary traffic share, 0-1.
        metadata: Additional free-form metadata; a fresh dict per instance.
    """

    model_name: str
    version: str
    file_path: str
    file_size: int = 0
    sha256: str = ""
    accuracy: float = 0.0
    latency_p99_ms: float = 0.0
    status: ModelStatus = ModelStatus.STANDBY
    created_at: str = ""
    deployed_at: str = ""
    deploy_ratio: float = 0.0
    metadata: Dict = field(default_factory=dict)
|
||
|
||
|
||
@dataclass
class ModelRegistry:
    """
    Registry entry for one model.

    Attributes:
        name: Model name.
        description: Human-readable description.
        current_version: Currently active version, if any.
        previous_version: Version that was active before the current one
            (used for rollback), if any.
        versions: Known versions keyed by version string; a fresh dict per
            instance.
    """

    name: str
    description: str
    current_version: Optional[str] = None
    previous_version: Optional[str] = None
    versions: Dict[str, ModelVersion] = field(default_factory=dict)
|
||
|
||
|
||
# ==================== 模型仓库客户端 ====================
|
||
|
||
class ModelRepositoryClient:
    """
    Client for the model file repository.

    Described as a front end to a MinIO bucket offering download, version
    listing and checksum calculation.  In this listing the MinIO SDK calls
    are disabled: ``download_model`` only prepares the target directory and
    ``list_versions`` always returns an empty list.

    NOTE(review): the original description also mentions AES-256 encrypted
    storage; nothing in this class performs encryption or decryption.
    """

    def __init__(self, endpoint: str = "minio.writech.internal:9000",
                 access_key: str = "", secret_key: str = "",
                 bucket: str = "model-repository"):
        """Store the connection settings and make sure the local cache directory exists."""
        self._endpoint = endpoint
        self._bucket = bucket
        self._access_key = access_key
        self._secret_key = secret_key
        # Local cache directory; created on construction.
        cache_root = Path("/opt/models/cache")
        self._cache_dir = cache_root
        cache_root.mkdir(parents=True, exist_ok=True)
        logger.info(f"模型仓库客户端初始化: endpoint={endpoint}, bucket={bucket}")

    def download_model(self, model_name: str, version: str,
                       target_path: str) -> bool:
        """
        Fetch ``<model_name>/<version>/model.onnx`` to ``target_path``.

        The transfer itself is simulated: only the parent directory of
        ``target_path`` is created.  Returns True on success and False if
        any exception occurs (the error is logged).

        NOTE(review): no SHA-256 verification happens in this method.
        """
        object_key = f"{model_name}/{version}/model.onnx"
        logger.info(f"开始下载模型: {object_key} -> {target_path}")

        try:
            # Real environment: download through the MinIO SDK
            # self._client.fget_object(self._bucket, object_key, target_path)

            # Simulated download: just prepare the destination directory.
            Path(target_path).parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"模型文件下载完成: {object_key}")
        except Exception as e:
            logger.error(f"模型下载失败: {object_key}, 错误: {str(e)}")
            return False
        return True

    def list_versions(self, model_name: str) -> List[str]:
        """Return the versions available for ``model_name`` (always empty in this listing)."""
        logger.info(f"查询模型版本列表: {model_name}")
        # Real environment: list the MinIO object prefixes.
        return []

    def compute_sha256(self, file_path: str) -> str:
        """
        Return the hex SHA-256 digest of a file, read in 8192-byte chunks.

        Returns an empty string when the file does not exist.
        """
        digest = hashlib.sha256()
        try:
            with open(file_path, "rb") as stream:
                while True:
                    block = stream.read(8192)
                    if not block:
                        break
                    digest.update(block)
        except FileNotFoundError:
            return ""
        return digest.hexdigest()
|
||
|
||
|
||
# ==================== 模型加载器 ====================
|
||
|
||
class ModelLoader:
    """
    Model loader.

    Registers model files under a name after checking that the file exists
    and has a supported extension.  The backend-specific loading calls
    (ONNX Runtime, TensorRT, PaddleLite) are disabled in this listing, so
    what is stored per model is a small dict with its path and load time.

    NOTE(review): the original description mentions in-memory decryption of
    model files; no decryption is performed by this class.
    """

    SUPPORTED_FORMATS = ['.onnx', '.trt', '.nb', '.pdmodel']

    def __init__(self, device: str = "gpu"):
        """Remember the target device and start with no loaded models."""
        self._device = device
        self._loaded_models: Dict[str, object] = {}  # name -> loaded model entry
        self._load_lock = threading.Lock()
        logger.info(f"模型加载器初始化: device={device}")

    def load(self, model_path: str, model_name: str) -> bool:
        """
        Load ``model_path`` and register it as ``model_name``.

        Runs under the loader lock.  Returns False (after logging) when the
        file is missing, the extension is unsupported or any exception is
        raised; returns True once the entry has been stored.
        """
        with self._load_lock:
            try:
                model_file = Path(model_path)
                if not model_file.exists():
                    logger.error(f"模型文件不存在: {model_path}")
                    return False

                extension = model_file.suffix.lower()
                if extension not in self.SUPPORTED_FORMATS:
                    logger.error(f"不支持的模型格式: {extension}")
                    return False

                logger.info(f"正在加载模型: {model_name} ({model_path})")

                # Backend selection by extension (all disabled here):
                #   .onnx    -> ONNX Runtime
                #       session = onnxruntime.InferenceSession(model_path, providers=['CUDAExecutionProvider'])
                #       self._loaded_models[model_name] = session
                #   .trt     -> TensorRT
                #       engine = trt.Runtime(trt.Logger()).deserialize_cuda_engine(...)
                #   .pdmodel -> PaddleLite

                entry = {"path": model_path, "loaded_at": time.time()}
                self._loaded_models[model_name] = entry
                logger.info(f"模型加载成功: {model_name}")
                return True
            except Exception as e:
                logger.error(f"模型加载失败: {model_name}, 错误: {str(e)}")
                return False

    def unload(self, model_name: str) -> bool:
        """Remove a loaded model under the lock; return False if it was not loaded."""
        with self._load_lock:
            if model_name not in self._loaded_models:
                return False
            del self._loaded_models[model_name]
            logger.info(f"模型已卸载: {model_name}")
            return True

    def is_loaded(self, model_name: str) -> bool:
        """Tell whether ``model_name`` is currently registered as loaded."""
        return model_name in self._loaded_models

    def get_loaded_models(self) -> List[str]:
        """Return the names of all loaded models, in load order."""
        return list(self._loaded_models)
|
||
|
||
|
||
# ==================== 模型版本管理器 ====================
|
||
|
||
class ModelManager:
    """
    Model version manager (core class).

    Tracks, per registered model, the known versions and which one is
    current, and drives the life cycle register -> download -> load ->
    deploy (immediate or canary) -> deprecate, including rollback to the
    previously active version.
    """

    _DEFAULT_MODELS_DIR = "/opt/models"

    def __init__(self, models_dir: str = "/opt/models"):
        """
        Args:
            models_dir: Root directory for downloaded model files.  The
                service entry point passes its settings object here, so a
                non-path argument is accepted as well.
        """
        if not isinstance(models_dir, (str, os.PathLike)):
            # NOTE(review): the attribute name "models_dir" on the settings
            # object is an assumption - confirm against config/settings.py.
            # Without this fallback Path() raises TypeError at start-up.
            models_dir = getattr(models_dir, "models_dir", None) or self._DEFAULT_MODELS_DIR

        self._models_dir = Path(models_dir)
        self._models_dir.mkdir(parents=True, exist_ok=True)
        self._registry: Dict[str, ModelRegistry] = {}
        self._repo_client = ModelRepositoryClient()
        self._loader = ModelLoader()
        self._deploy_lock = threading.Lock()
        logger.info(f"模型版本管理器初始化: models_dir={models_dir}")

    def register_model(self, name: str, description: str) -> ModelRegistry:
        """Register a model family if it is new and return its registry entry."""
        if name not in self._registry:
            self._registry[name] = ModelRegistry(name=name, description=description)
            logger.info(f"注册新模型: {name} - {description}")
        return self._registry[name]

    def add_version(self, model_name: str, version: str,
                    accuracy: float = 0.0, metadata: Optional[Dict] = None) -> Optional[ModelVersion]:
        """
        Download a model version and record it for ``model_name``.

        Returns None when the model is not registered.  Otherwise returns
        the ModelVersion: STANDBY and stored in the registry when the
        download succeeded, FAILED and not stored when it did not.
        """
        if model_name not in self._registry:
            logger.error(f"模型未注册: {model_name}")
            return None

        # Local layout: <models_dir>/<model_name>/<version>/model.onnx
        version_dir = self._models_dir / model_name / version
        model_file = str(version_dir / "model.onnx")

        mv = ModelVersion(
            model_name=model_name, version=version,
            file_path=model_file, accuracy=accuracy,
            status=ModelStatus.DOWNLOADING,
            created_at=datetime.now().isoformat(),
            metadata=metadata or {}
        )

        if self._repo_client.download_model(model_name, version, model_file):
            mv.sha256 = self._repo_client.compute_sha256(model_file)
            mv.status = ModelStatus.STANDBY
            self._registry[model_name].versions[version] = mv
            logger.info(f"模型版本添加成功: {model_name}@{version}")
        else:
            mv.status = ModelStatus.FAILED
            logger.error(f"模型版本添加失败: {model_name}@{version}")

        return mv

    def deploy_version(self, model_name: str, version: str,
                       strategy: DeployStrategy = DeployStrategy.IMMEDIATE,
                       canary_ratio: float = 0.1) -> bool:
        """
        Deploy a recorded version of a model.

        IMMEDIATE switches all traffic to the version; CANARY marks it
        active for ``canary_ratio`` of traffic without changing the
        current version.  Other strategies load the model and return True
        as before, but now log a warning because no traffic switch is
        implemented for them.

        Returns False when the version is unknown, the canary ratio is
        outside [0, 1], or the model cannot be loaded.
        """
        with self._deploy_lock:
            registry = self._registry.get(model_name)
            if not registry or version not in registry.versions:
                logger.error(f"模型版本不存在: {model_name}@{version}")
                return False

            # Reject a bad ratio before loading anything.
            if strategy == DeployStrategy.CANARY and not 0.0 <= canary_ratio <= 1.0:
                logger.error(f"灰度比例无效: {canary_ratio},应在0到1之间")
                return False

            mv = registry.versions[version]

            load_key = f"{model_name}_v{version}"
            if not self._loader.load(mv.file_path, load_key):
                mv.status = ModelStatus.FAILED
                return False

            if strategy == DeployStrategy.IMMEDIATE:
                self._switch_all_traffic(model_name, registry, mv)
            elif strategy == DeployStrategy.CANARY:
                mv.status = ModelStatus.ACTIVE
                mv.deploy_ratio = canary_ratio
                mv.deployed_at = datetime.now().isoformat()
                logger.info(f"模型灰度发布: {model_name}@{version}, 流量比例={canary_ratio}")
            else:
                logger.warning(
                    f"部署策略暂未实现,模型已加载但未切换流量: "
                    f"{getattr(strategy, 'value', strategy)}"
                )

            return True

    def _switch_all_traffic(self, model_name: str, registry: ModelRegistry,
                            mv: ModelVersion) -> None:
        """Make ``mv`` the current version and retire the one it replaces."""
        old_version = registry.current_version
        replacing = old_version != mv.version

        if replacing:
            registry.previous_version = old_version
            registry.current_version = mv.version
        mv.status = ModelStatus.ACTIVE
        mv.deploy_ratio = 1.0
        mv.deployed_at = datetime.now().isoformat()

        # Retire the old version only when it really is a different one;
        # otherwise the model that was just loaded would be unloaded and
        # marked deprecated.
        if old_version and replacing:
            self._loader.unload(f"{model_name}_v{old_version}")
            if old_version in registry.versions:
                registry.versions[old_version].status = ModelStatus.DEPRECATED

        logger.info(f"模型全量部署完成: {model_name}@{mv.version}")

    def rollback(self, model_name: str) -> bool:
        """
        Roll back to the previously active version.

        Re-deploys ``previous_version`` with the IMMEDIATE strategy, which
        also marks the version rolled back from as deprecated.  Returns
        False when there is nothing to roll back to.
        """
        registry = self._registry.get(model_name)
        if not registry or not registry.previous_version:
            logger.error(f"无法回滚: {model_name}, 没有可回滚的版本")
            return False

        return self.deploy_version(
            model_name, registry.previous_version,
            strategy=DeployStrategy.IMMEDIATE
        )

    def get_model_status(self) -> List[Dict]:
        """
        Return one status dict per recorded version of every model.

        Described in the original as the data source of
        GET /api/v1/model/status.
        """
        return [
            {
                "model_name": name,
                "version": ver,
                "status": mv.status.value,
                "accuracy": mv.accuracy,
                "latency_p99_ms": mv.latency_p99_ms,
                "deploy_ratio": mv.deploy_ratio,
                "is_current": ver == registry.current_version,
                "deployed_at": mv.deployed_at
            }
            for name, registry in self._registry.items()
            for ver, mv in registry.versions.items()
        ]

    def check_for_updates(self) -> List[Dict]:
        """
        Report repository versions that are not recorded locally.

        Returns one dict per model that has new versions, carrying the
        model name, the new version strings and the current version.
        """
        updates = []
        for name, registry in self._registry.items():
            remote_versions = self._repo_client.list_versions(name)
            new_versions = [v for v in remote_versions if v not in registry.versions]
            if new_versions:
                updates.append({
                    "model_name": name,
                    "new_versions": new_versions,
                    "current_version": registry.current_version
                })
        return updates
|
||
```
|
||
|
||
#### `service/task_scheduler.py`
|
||
|
||
```python
|
||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# Celery异步任务调度模块 - 识别请求异步处理与优先级调度
|
||
|
||
"""
|
||
Celery任务调度服务
|
||
管理AI识别请求的异步任务队列,支持优先级调度、任务重试、
|
||
结果回调通知、任务进度追踪等功能
|
||
使用Redis作为消息Broker和结果Backend
|
||
"""
|
||
|
||
import time
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from typing import Dict, List, Optional, Any
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timedelta
|
||
from enum import IntEnum
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 任务优先级定义 ====================
|
||
|
||
class TaskPriority(IntEnum):
    """
    Task priority; a smaller number means a higher priority.

    Members:
        CRITICAL: highest priority - real-time classroom interaction.
        HIGH: teacher grading online.
        NORMAL: automatic homework grading.
        LOW: bulk processing of historical data.
        BACKGROUND: model evaluation / training-data generation.
    """

    CRITICAL = 0
    HIGH = 1
    NORMAL = 2
    LOW = 3
    BACKGROUND = 4
|
||
|
||
|
||
class TaskStatus:
    """
    String constants naming the task states.

    PENDING: waiting to run; STARTED: picked up for execution;
    PROCESSING: being processed; SUCCESS: finished successfully;
    FAILURE: finished with an error; RETRY: being retried;
    REVOKED: cancelled.
    """

    PENDING = "PENDING"
    STARTED = "STARTED"
    PROCESSING = "PROCESSING"
    SUCCESS = "SUCCESS"
    FAILURE = "FAILURE"
    RETRY = "RETRY"
    REVOKED = "REVOKED"
|
||
|
||
|
||
@dataclass
class TaskRecord:
    """
    Bookkeeping record for one submitted task.

    Attributes:
        task_id: Unique identifier of the task.
        task_type: Kind of task (ocr/math/stroke_order/essay).
        priority: Scheduling priority.
        status: Current state; PENDING on creation.
        input_data: Payload handed to the task handler.
        result: Handler result once available.
        error_message: Failure description, if any.
        retry_count: Number of failed attempts so far.
        max_retries: Retry limit.
        created_at: Submission time.
        started_at: Time the task was picked up.
        completed_at: Time the task finished.
        callback_url: URL to notify on completion.
        student_id: Associated student, if any.
        assignment_id: Associated assignment, if any.
    """

    task_id: str
    task_type: str
    priority: TaskPriority
    status: str = TaskStatus.PENDING
    input_data: Dict = field(default_factory=dict)
    result: Optional[Dict] = None
    error_message: Optional[str] = None
    retry_count: int = 0
    max_retries: int = 3
    created_at: str = ""
    started_at: Optional[str] = None
    completed_at: Optional[str] = None
    callback_url: Optional[str] = None
    student_id: Optional[str] = None
    assignment_id: Optional[str] = None
|
||
|
||
|
||
# ==================== 任务队列管理器 ====================
|
||
|
||
class TaskQueueManager:
    """
    Task queue manager.

    Keeps every submitted task in memory and hands tasks out in priority
    order, so that high-priority work (e.g. real-time classroom
    interaction) runs first.  The original description names a Redis
    sorted set as the production backend; this implementation uses a
    sorted Python list and only stores the Redis URL.
    """

    # Default queue name per task type.
    QUEUE_MAPPING = {
        "ocr": "writech.ocr",
        "math": "writech.math",
        "stroke_order": "writech.stroke_order",
        "essay": "writech.essay",
        "batch": "writech.batch"
    }

    # States from which a task may be handed to a worker.  RETRY is
    # included so tasks that failed and are awaiting another attempt are
    # picked up again instead of staying in the queue forever.
    _RUNNABLE_STATES = (TaskStatus.PENDING, TaskStatus.RETRY)

    # States counted by get_queue_stats().
    _REPORTED_STATES = (
        TaskStatus.PENDING, TaskStatus.STARTED, TaskStatus.PROCESSING,
        TaskStatus.SUCCESS, TaskStatus.FAILURE, TaskStatus.RETRY,
        TaskStatus.REVOKED,
    )

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        """Store the Redis URL and start with empty in-memory structures."""
        self._redis_url = redis_url
        self._tasks: Dict[str, TaskRecord] = {}  # task_id -> record (in memory)
        self._queue: List[TaskRecord] = []       # records sorted by priority, then age
        logger.info(f"任务队列管理器初始化: redis={redis_url}")

    def submit_task(self, task_type: str, input_data: Dict,
                    priority: TaskPriority = TaskPriority.NORMAL,
                    callback_url: Optional[str] = None,
                    student_id: Optional[str] = None,
                    assignment_id: Optional[str] = None) -> str:
        """
        Queue a recognition task and return its generated task id.

        The id can be used to query status and result later.  Task types
        without an entry in QUEUE_MAPPING are logged under the queue name
        "writech.default".
        """
        task_id = str(uuid.uuid4())

        record = TaskRecord(
            task_id=task_id,
            task_type=task_type,
            priority=priority,
            input_data=input_data,
            created_at=datetime.now().isoformat(),
            callback_url=callback_url,
            student_id=student_id,
            assignment_id=assignment_id
        )

        self._tasks[task_id] = record
        self._queue.append(record)
        # Lower priority value first; ties resolved by creation time.
        self._queue.sort(key=lambda t: (t.priority, t.created_at))

        queue_name = self.QUEUE_MAPPING.get(task_type, "writech.default")
        logger.info(
            f"任务已提交: id={task_id}, type={task_type}, "
            f"priority={priority.name}, queue={queue_name}"
        )
        return task_id

    def get_next_task(self) -> Optional[TaskRecord]:
        """
        Hand out the highest-priority runnable task, or None if there is none.

        A task is runnable while it is PENDING or RETRY.  The returned
        task is marked STARTED and stamped with the current time.
        """
        for task in self._queue:
            if task.status in self._RUNNABLE_STATES:
                task.status = TaskStatus.STARTED
                task.started_at = datetime.now().isoformat()
                return task
        return None

    def update_task_status(self, task_id: str, status: str,
                           result: Optional[Dict] = None,
                           error: Optional[str] = None):
        """
        Set a task's status and optionally its result / error message.

        ``result`` and ``error`` are stored whenever they are not None, so
        an empty result dict is recorded too.  SUCCESS and FAILURE also
        stamp the completion time.  Unknown task ids are ignored.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return

        task.status = status
        if result is not None:
            task.result = result
        if error is not None:
            task.error_message = error
        if status in (TaskStatus.SUCCESS, TaskStatus.FAILURE):
            task.completed_at = datetime.now().isoformat()
        logger.info(f"任务状态更新: id={task_id}, status={status}")

    def get_task_status(self, task_id: str) -> Optional[Dict]:
        """Return a status/result snapshot for a task, or None if the id is unknown."""
        task = self._tasks.get(task_id)
        if not task:
            return None
        return {
            "task_id": task.task_id,
            "task_type": task.task_type,
            "status": task.status,
            "priority": task.priority.name,
            "result": task.result,
            "error_message": task.error_message,
            "retry_count": task.retry_count,
            "created_at": task.created_at,
            "started_at": task.started_at,
            "completed_at": task.completed_at
        }

    def get_queue_stats(self) -> Dict:
        """
        Return task counts: ``total`` plus one lower-cased key per state.

        Every state is reported, so the per-state counts add up to
        ``total`` as long as only TaskStatus values are used.
        """
        stats = {"total": len(self._tasks)}
        for state in self._REPORTED_STATES:
            stats[state.lower()] = sum(
                1 for t in self._tasks.values() if t.status == state
            )
        return stats
|
||
|
||
|
||
# ==================== Celery任务定义 ====================
|
||
|
||
class CeleryTaskExecutor:
    """
    Task executor.

    Maps task types to handler callables and runs a queued task through
    its handler, recording status, timing and failures on the queue
    manager.  No Celery API is called in this listing.
    """

    def __init__(self, queue_manager: TaskQueueManager):
        """Bind the executor to a queue manager and start with no handlers."""
        self._queue_manager = queue_manager
        self._task_handlers: Dict[str, callable] = {}
        logger.info("Celery任务执行器初始化")

    def register_handler(self, task_type: str, handler: callable):
        """Register (or replace) the handler used for ``task_type``."""
        self._task_handlers[task_type] = handler
        logger.info(f"注册任务处理器: {task_type}")

    def execute_task(self, task_id: str) -> Dict:
        """
        Run the task identified by ``task_id``.

        Returns the handler's result dict (with ``processing_time_ms``
        added) on success, or ``{"error": ...}`` when the task is unknown,
        no handler is registered for its type, or the handler raised.  A
        raising handler increments the retry counter; the task is marked
        RETRY until the retry limit is reached, then FAILURE.
        """
        manager = self._queue_manager
        task = manager._tasks.get(task_id)
        if not task:
            return {"error": "任务不存在"}

        handler = self._task_handlers.get(task.task_type)
        if not handler:
            reason = f"未注册的任务类型: {task.task_type}"
            manager.update_task_status(task_id, TaskStatus.FAILURE, error=reason)
            return {"error": reason}

        try:
            manager.update_task_status(task_id, TaskStatus.PROCESSING)

            # Time the handler call itself.
            began = time.time()
            outcome = handler(task.input_data)
            elapsed = (time.time() - began) * 1000

            outcome['processing_time_ms'] = round(elapsed, 2)
            manager.update_task_status(task_id, TaskStatus.SUCCESS, result=outcome)

            # Audit log: which task ran, how long it took, for which student.
            logger.info(
                f"任务执行完成: id={task_id}, type={task.task_type}, "
                f"time={elapsed:.1f}ms, student={task.student_id}"
            )

            # Notify the caller when a callback URL was supplied.
            if task.callback_url:
                self._send_callback(task.callback_url, task_id, outcome)

            return outcome
        except Exception as exc:
            return self._record_failure(task_id, task, exc)

    def _record_failure(self, task_id: str, task, exc: Exception) -> Dict:
        """Count a failed attempt, mark the task RETRY or FAILURE, and build the error dict."""
        task.retry_count += 1
        if task.retry_count < task.max_retries:
            # Still within the retry budget.
            task.status = TaskStatus.RETRY
            logger.warning(f"任务重试: id={task_id}, retry={task.retry_count}/{task.max_retries}")
        else:
            self._queue_manager.update_task_status(
                task_id, TaskStatus.FAILURE, error=str(exc)
            )
            logger.error(f"任务最终失败: id={task_id}, error={str(exc)}")
        return {"error": str(exc)}

    def _send_callback(self, url: str, task_id: str, result: Dict):
        """
        Send the task-completion notification.

        Only logs in this listing; a real deployment would POST the result
        with an HTTP client.  Failures are logged, never raised.
        """
        try:
            # Real environment: POST via httpx/aiohttp.
            logger.info(f"发送任务回调: url={url}, task_id={task_id}")
        except Exception as e:
            logger.error(f"回调通知失败: {str(e)}")
|
||
|
||
|
||
# ==================== 定时调度器 ====================
|
||
|
||
class ScheduledTaskRunner:
    """
    Registry and runner for periodic background jobs.

    Examples named in the original description: model health check (every
    5 minutes), cleanup of expired tasks (hourly), metrics collection
    (every minute) and model update check (daily).  This class only
    stores the schedules and runs a job on request; it contains no timer
    loop of its own.
    """

    def __init__(self):
        """Start with no schedules registered."""
        self._schedules: Dict[str, Dict] = {}
        self._running = False
        logger.info("定时任务调度器初始化")

    def register_schedule(self, name: str, interval_seconds: int,
                          handler: callable, description: str = ""):
        """Register (or replace) a job with its interval, handler and fresh counters."""
        entry = {
            "interval": interval_seconds,
            "handler": handler,
            "description": description,
            "last_run": None,
            "run_count": 0,
            "error_count": 0
        }
        self._schedules[name] = entry
        logger.info(f"注册定时任务: {name}, 间隔={interval_seconds}s")

    def run_task(self, name: str) -> Optional[Dict]:
        """
        Run the named job immediately.

        Returns None for an unknown name.  Otherwise returns a dict with
        ``name`` and ``success``, plus ``elapsed_s`` on success or
        ``error`` when the handler raised.  Run and error counters are
        updated accordingly.
        """
        job = self._schedules.get(name)
        if not job:
            return None

        try:
            began = time.time()
            job["handler"]()  # the handler's return value is not used
            elapsed = time.time() - began
            job["last_run"] = datetime.now().isoformat()
            job["run_count"] += 1
            logger.info(f"定时任务执行完成: {name}, 耗时={elapsed:.2f}s")
            return {"name": name, "success": True, "elapsed_s": round(elapsed, 2)}
        except Exception as e:
            job["error_count"] += 1
            logger.error(f"定时任务执行失败: {name}, 错误={str(e)}")
            return {"name": name, "success": False, "error": str(e)}

    def get_schedule_status(self) -> List[Dict]:
        """Return one status dict per registered job, in registration order."""
        report = []
        for job_name, job in self._schedules.items():
            report.append({
                "name": job_name,
                "interval_seconds": job["interval"],
                "description": job["description"],
                "last_run": job["last_run"],
                "run_count": job["run_count"],
                "error_count": job["error_count"]
            })
        return report
|
||
```
|
||
|