software copyright
This commit is contained in:
@@ -0,0 +1,446 @@
|
||||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||||
# 作文批改接口模块 - AI作文评分与批改建议服务
|
||||
|
||||
"""
|
||||
作文批改API接口
|
||||
提供AI作文评分、多维度分析(结构/语法/内容/修辞)、批改建议生成等功能
|
||||
支持小学至初中阶段作文批改,基于大语言模型与NLP分析管道
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
import re
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 数据模型定义 ====================
|
||||
|
||||
class EssayReviewRequest(BaseModel):
    """Request payload for the essay review endpoint.

    ``text`` carries the OCR-recognised essay. ``grade``, ``genre`` and
    ``max_score`` tune the scoring; ``student_id`` and ``assignment_id`` are
    optional tags, and ``enable_suggestions`` controls whether grammar
    errors are included in the response.
    """

    text: str = Field(..., min_length=10, max_length=5000, description="作文OCR识别文本")
    title: Optional[str] = Field(None, description="作文题目")
    grade: int = Field(3, ge=1, le=9, description="年级(1-9)")
    genre: str = Field("narrative", description="文体类型: narrative/argumentative/expository/descriptive")
    # The total score is divided by max_score when the overall comment is
    # built, so zero or negative values are rejected at validation time
    # instead of surfacing later as a ZeroDivisionError.
    max_score: int = Field(100, gt=0, description="满分值")
    student_id: Optional[str] = Field(None, description="学生ID")
    assignment_id: Optional[str] = Field(None, description="作业ID")
    enable_suggestions: bool = Field(True, description="是否生成修改建议")

    @validator('genre')
    def validate_genre(cls, v):
        """Reject any genre outside the four supported values."""
        valid_genres = ['narrative', 'argumentative', 'expository', 'descriptive']
        if v not in valid_genres:
            raise ValueError(f'文体类型必须为: {valid_genres}')
        return v
|
||||
|
||||
|
||||
class SentenceError(BaseModel):
    """One sentence-level problem found in an essay.

    Holds the offending text, an error category, a suggested fix and the
    character offset of the problem in the original text.
    """

    sentence: str = Field(default=..., description="原始句子")
    error_type: str = Field(default=..., description="错误类型")
    suggestion: str = Field(default=..., description="修改建议")
    position: int = Field(default=..., description="句子在原文中的位置索引")
|
||||
|
||||
|
||||
class EssayScoreDetail(BaseModel):
    """Per-dimension score breakdown for an essay.

    The four core dimensions are required; the handwriting score is
    optional and defaults to ``None``.
    """

    structure: float = Field(default=..., description="结构分")
    grammar: float = Field(default=..., description="语法分")
    content: float = Field(default=..., description="内容分")
    rhetoric: float = Field(default=..., description="修辞分")
    handwriting: Optional[float] = Field(default=None, description="书写分(如有)")
|
||||
|
||||
|
||||
# ==================== 文本分析工具 ====================
|
||||
|
||||
class TextAnalyzer:
    """Stateless helpers for basic Chinese text analysis.

    Provides sentence splitting, paragraph splitting, character statistics
    and a heuristic rhetoric detector. Every method is a static method.
    """

    # Sentence-ending punctuation. Both the ASCII and the full-width forms
    # of "!", "?" and ";" are listed, since Chinese text normally uses the
    # full-width ones. "……" is the conventional two-character ellipsis; the
    # single "…" is listed too so the character-wise splitter can match it.
    SENTENCE_ENDINGS = {
        '。', '!', '?', '……', ';',
        '\uff01', '\uff1f', '\uff1b', '…',
    }
    # Paragraph indent marker
    PARAGRAPH_INDENT = ' '

    # Every single character that may terminate a sentence.
    _TERMINATOR_CHARS = frozenset(''.join(SENTENCE_ENDINGS))

    # Characters counted as punctuation: ASCII forms, CJK forms, and the
    # full-width / curly-quote equivalents.
    _PUNCTUATION = frozenset(
        ',。!?、;:"\'()《》…—'
        '\uff0c\uff01\uff1f\uff1b\uff1a\u201c\u201d\u2018\u2019\uff08\uff09'
    )

    # Simile markers such as "像...一样", "如同", "仿佛".
    _SIMILE_PATTERNS = (
        r'像.{2,10}一样', r'如同.{2,10}', r'仿佛.{2,10}',
        r'好像.{2,10}', r'犹如.{2,10}', r'宛如.{2,10}',
    )
    # Personification: a natural object followed closely by a human action
    # or manner word.
    _PERSONIFICATION_PATTERNS = (
        r'[风雨雪花树草月阳光河水山].{0,3}[笑哭唱跳跑走说叫]',
        r'[风雨雪花树草月阳光河水山].{0,3}[温柔轻轻悄悄]',
    )

    @staticmethod
    def split_sentences(text: str) -> List[str]:
        """Split ``text`` into a list of stripped, non-empty sentences.

        A sentence ends after a run of one or more terminator characters,
        so "……" or "!!" stays attached to the sentence it closes instead
        of becoming a sentence of its own. Trailing text without a
        terminator is returned as the last sentence.
        """
        terminators = TextAnalyzer._TERMINATOR_CHARS
        sentences: List[str] = []
        buffer: List[str] = []
        in_terminator_run = False

        def flush() -> None:
            sentence = ''.join(buffer).strip()
            if sentence:
                sentences.append(sentence)
            buffer.clear()

        for char in text:
            is_terminator = char in terminators
            # The first non-terminator after a terminator run starts a new
            # sentence.
            if in_terminator_run and not is_terminator:
                flush()
            buffer.append(char)
            in_terminator_run = is_terminator
        flush()
        return sentences

    @staticmethod
    def split_paragraphs(text: str) -> List[str]:
        """Split ``text`` on newlines and return the stripped, non-empty parts."""
        stripped = (line.strip() for line in text.split('\n'))
        return [paragraph for paragraph in stripped if paragraph]

    @staticmethod
    def count_characters(text: str) -> Dict[str, int]:
        """Return character statistics for ``text``.

        Keys: ``total`` (length with spaces and newlines removed),
        ``chinese`` (code points in U+4E00..U+9FFF) and ``punctuation``
        (characters found in the punctuation set).
        """
        punctuation = TextAnalyzer._PUNCTUATION
        return {
            "total": len(text.replace(' ', '').replace('\n', '')),
            "chinese": sum(1 for c in text if '\u4e00' <= c <= '\u9fff'),
            "punctuation": sum(1 for c in text if c in punctuation),
        }

    @staticmethod
    def detect_rhetoric(text: str) -> List[Dict]:
        """Detect simile, parallelism and personification heuristically.

        Returns a list of dicts with keys ``type``, ``name``, ``text`` and
        ``position``. Similes come first, then parallelism, then
        personification.
        """

        def regex_hits(patterns, kind: str, label: str) -> List[Dict]:
            return [
                {"type": kind, "name": label,
                 "text": m.group(), "position": m.start()}
                for pattern in patterns
                for m in re.finditer(pattern, text)
            ]

        rhetorics = regex_hits(TextAnalyzer._SIMILE_PATTERNS, "simile", "比喻")

        # Parallelism (simplified): three consecutive sentences of similar
        # length that start with the same character.
        sentences = TextAnalyzer.split_sentences(text)
        for s1, s2, s3 in zip(sentences, sentences[1:], sentences[2:]):
            similar_length = (abs(len(s1) - len(s2)) < 5
                              and abs(len(s2) - len(s3)) < 5)
            if similar_length and len(s1) > 5 and s1[0] == s2[0] == s3[0]:
                rhetorics.append({
                    "type": "parallelism", "name": "排比",
                    "text": f"{s1}{s2}{s3}", "position": text.find(s1)
                })

        rhetorics.extend(regex_hits(
            TextAnalyzer._PERSONIFICATION_PATTERNS, "personification", "拟人"))
        return rhetorics
|
||||
|
||||
|
||||
# ==================== 作文评分引擎 ====================
|
||||
|
||||
class EssayScoringEngine:
    """Rule-based essay scoring engine.

    Each essay is scored on four dimensions, every one on a 0-100 scale,
    which are then combined with the weights in ``DIMENSION_WEIGHTS``:
    structure (25%), grammar (25%), content (30%) and rhetoric (20%).
    """

    # Expected count of Chinese characters per grade: (minimum, maximum)
    EXPECTED_LENGTH = {
        1: (50, 150), 2: (100, 250), 3: (200, 400),
        4: (300, 500), 5: (350, 600), 6: (400, 700),
        7: (500, 800), 8: (600, 900), 9: (600, 1000)
    }

    # Weight of each scoring dimension in the total score
    DIMENSION_WEIGHTS = {
        "structure": 0.25,
        "grammar": 0.25,
        "content": 0.30,
        "rhetoric": 0.20
    }

    def __init__(self):
        """Create the text analyzer and load the grammar error patterns."""
        self._text_analyzer = TextAnalyzer()
        self._error_patterns = self._load_error_patterns()
        logger.info("作文评分引擎初始化完成")

    def _load_error_patterns(self) -> List[Dict]:
        """Return the built-in list of common grammar error patterns.

        Each entry has a regex ``pattern``, an error ``type`` and a
        human-readable ``msg`` used as the suggestion text.
        """
        return [
            {"pattern": r"的的", "type": "repetition", "msg": "重复用字'的的'"},
            {"pattern": r"了了", "type": "repetition", "msg": "重复用字'了了'"},
            {"pattern": r"因为.{5,50}因为", "type": "logic", "msg": "重复使用'因为',建议精简"},
            {"pattern": r"然后.{3,20}然后.{3,20}然后", "type": "style", "msg": "过度使用'然后'连接"},
            {"pattern": r"非常非常", "type": "repetition", "msg": "重复使用'非常'"},
            # Matches runs of ASCII and full-width (U+FF0C) commas alike.
            {"pattern": r"[,\uff0c]{3,}", "type": "punctuation", "msg": "连续使用多个逗号,建议使用句号断句"},
        ]

    def score_structure(self, text: str, grade: int) -> Tuple[float, List[str]]:
        """Score the essay structure on a 0-100 scale.

        Deductions are applied for too few paragraphs, a length outside the
        range expected for ``grade``, and a very short opening or closing
        paragraph.

        Returns:
            ``(score, comments)`` where ``comments`` explains each deduction.
        """
        comments: List[str] = []
        score = 100.0

        paragraphs = self._text_analyzer.split_paragraphs(text)
        chinese_count = self._text_analyzer.count_characters(text)["chinese"]

        # Paragraph count
        if len(paragraphs) < 2:
            score -= 25
            comments.append("文章缺少段落划分,建议分段书写使结构更清晰")
        elif len(paragraphs) < 3:
            score -= 10
            comments.append("段落较少,建议增加过渡段落")

        # Length; unknown grades fall back to a 300-600 range
        low, high = self.EXPECTED_LENGTH.get(grade, (300, 600))
        if chinese_count < low:
            # One point per ten missing characters, capped at 30
            score -= min(30, (low - chinese_count) // 10)
            comments.append(f"字数偏少({chinese_count}字),该年级建议{low}-{high}字")
        elif chinese_count > high * 1.5:
            score -= 5
            comments.append("字数偏多,建议精简语句突出重点")

        # Opening and closing paragraphs
        if paragraphs:
            if len(paragraphs[0]) < 15:
                score -= 10
                comments.append("开头过于简短,建议丰富开篇引入")
            if len(paragraphs[-1]) < 10:
                score -= 10
                comments.append("结尾过于简短,建议加强收束呼应主题")

        return max(0, score), comments

    def score_grammar(self, text: str) -> Tuple[float, List[SentenceError]]:
        """Score grammatical correctness on a 0-100 scale.

        Every match of a known error pattern costs 5 points; every sentence
        longer than 80 characters costs 3 points.

        Returns:
            ``(score, errors)`` with one ``SentenceError`` per finding.
        """
        errors: List[SentenceError] = []
        score = 100.0

        for rule in self._error_patterns:
            for match in re.finditer(rule["pattern"], text):
                errors.append(SentenceError(
                    sentence=match.group(),
                    error_type=rule["type"],
                    suggestion=rule["msg"],
                    position=match.start()
                ))
                score -= 5

        # Overly long sentences are likely to contain errors. The search
        # cursor advances past each sentence so that a repeated sentence
        # reports its own offset rather than that of its first occurrence.
        cursor = 0
        for sentence in self._text_analyzer.split_sentences(text):
            offset = text.find(sentence, cursor)
            if offset >= 0:
                cursor = offset + len(sentence)
            if len(sentence) > 80:
                errors.append(SentenceError(
                    sentence=sentence[:30] + "...",
                    error_type="long_sentence",
                    suggestion="句子过长,建议拆分为多个短句以提高可读性",
                    position=offset
                ))
                score -= 3

        return max(0, score), errors

    def score_content(self, text: str, title: Optional[str], genre: str, grade: int) -> Tuple[float, List[str]]:
        """Score content quality on a 0-100 scale.

        Starts from a base of 85 and adjusts for vocabulary richness, use
        of logical connectors and presence of emotion words. ``title``,
        ``genre`` and ``grade`` are accepted but not used by the current
        rules.

        Returns:
            ``(score, comments)`` with the score clamped to 0-100.
        """
        comments: List[str] = []
        score = 85.0  # content is hard to quantify, so the base is generous

        chinese_count = self._text_analyzer.count_characters(text)["chinese"]
        sentences = self._text_analyzer.split_sentences(text)

        # Vocabulary richness: distinct Chinese characters / all Chinese characters
        unique_chars = {c for c in text if '\u4e00' <= c <= '\u9fff'}
        vocab_richness = len(unique_chars) / max(chinese_count, 1)
        if vocab_richness > 0.6:
            score += 10
            comments.append("词汇丰富,用词多样化")
        elif vocab_richness < 0.3:
            score -= 10
            comments.append("词汇较为单一,建议使用更丰富的词语表达")

        # Logical coherence: count distinct connectors. Longer connectors
        # are matched first and masked out, so a single "但是" is not also
        # counted as "但".
        connectors = ['因此', '所以', '但是', '然而', '首先', '其次', '最后', '总之',
                      '不仅', '而且', '虽然', '但', '因为', '于是']
        remaining = text
        used_connectors = []
        for connector in sorted(connectors, key=len, reverse=True):
            if connector in remaining:
                used_connectors.append(connector)
                remaining = remaining.replace(connector, '\0')
        if len(used_connectors) >= 3:
            score += 5
            comments.append("逻辑衔接词使用恰当,行文连贯")
        elif not used_connectors and len(sentences) > 5:
            score -= 5
            comments.append("缺少逻辑连接词,建议增加过渡衔接使行文更连贯")

        # Emotional expression
        emotion_words = ['开心', '快乐', '高兴', '感动', '难过', '伤心', '惊讶',
                         '温暖', '幸福', '骄傲', '担心', '紧张']
        if any(word in text for word in emotion_words):
            score += 3
            comments.append("有恰当的情感表达,增强了文章感染力")

        return min(100, max(0, score)), comments

    def score_rhetoric(self, text: str, grade: int) -> Tuple[float, List[str]]:
        """Score the use of rhetorical devices on a 0-100 scale.

        Starts from a base of 70, adds points by the number of distinct
        device types detected, and deducts 10 for grade 5 and above when
        fewer than two types are used.

        Returns:
            ``(score, comments)`` with the score clamped to 0-100.
        """
        comments: List[str] = []
        score = 70.0

        rhetorics = self._text_analyzer.detect_rhetoric(text)
        rhetoric_types = {r["type"] for r in rhetorics}

        if len(rhetoric_types) >= 3:
            score += 25
            comments.append(f"修辞手法运用丰富,使用了{len(rhetoric_types)}种修辞手法")
        elif len(rhetoric_types) >= 1:
            score += 15
            # Order-preserving de-duplication keeps the comment text
            # deterministic; joining a set would vary between runs.
            used_names = list(dict.fromkeys(r["name"] for r in rhetorics))
            comments.append(f"使用了{'、'.join(used_names)}等修辞手法")
        else:
            comments.append("建议适当使用比喻、排比等修辞手法增强表达效果")

        # Higher grades are expected to use more rhetoric
        if grade >= 5 and len(rhetoric_types) < 2:
            score -= 10
            comments.append("该年级建议至少使用2种以上修辞手法")

        return min(100, max(0, score)), comments

    def review_essay(self, request: EssayReviewRequest) -> Dict:
        """Run all four scorers and assemble the review result.

        The weighted 0-100 score is rescaled to ``request.max_score``.
        Grammar errors are included only when ``request.enable_suggestions``
        is true.

        Returns:
            A dict with the total score, per-dimension scores, character
            statistics, comments and the elapsed time in milliseconds.
        """
        start_time = time.time()

        # Score each dimension independently
        struct_score, struct_comments = self.score_structure(request.text, request.grade)
        grammar_score, grammar_errors = self.score_grammar(request.text)
        content_score, content_comments = self.score_content(
            request.text, request.title, request.genre, request.grade)
        rhetoric_score, rhetoric_comments = self.score_rhetoric(request.text, request.grade)

        raw_scores = {
            "structure": struct_score,
            "grammar": grammar_score,
            "content": content_score,
            "rhetoric": rhetoric_score,
        }
        weights = self.DIMENSION_WEIGHTS

        # Weighted total on the 0-100 scale, rescaled to the full mark
        weighted_score = sum(raw_scores[name] * weights[name] for name in raw_scores)
        total_score = round(weighted_score / 100 * request.max_score, 1)

        char_stats = TextAnalyzer.count_characters(request.text)

        overall_comment = self._generate_overall_comment(
            total_score, request.max_score, struct_comments,
            content_comments, rhetoric_comments
        )

        elapsed = (time.time() - start_time) * 1000

        return {
            "total_score": total_score,
            "max_score": request.max_score,
            # Each dimension's contribution to the total, on the max_score scale
            "dimensions": {
                name: round(raw / 100 * request.max_score * weights[name], 1)
                for name, raw in raw_scores.items()
            },
            "character_count": char_stats,
            "overall_comment": overall_comment,
            "structure_analysis": struct_comments,
            "content_analysis": content_comments,
            "rhetoric_analysis": rhetoric_comments,
            "grammar_errors": [e.dict() for e in grammar_errors] if request.enable_suggestions else [],
            "inference_time_ms": round(elapsed, 2)
        }

    def _generate_overall_comment(self, score: float, max_score: int,
                                  struct_comments: List, content_comments: List,
                                  rhetoric_comments: List) -> str:
        """Build the overall comment from the score ratio and top comments.

        The prefix reflects ``score / max_score``; it is followed by the
        first comment (if any) of the structure, content and rhetoric
        lists, joined with ";".
        """
        # A non-positive full mark would divide by zero or flip the ratio;
        # treat it as the lowest band.
        ratio = score / max_score if max_score > 0 else 0.0
        if ratio >= 0.9:
            prefix = "优秀!"
        elif ratio >= 0.75:
            prefix = "良好。"
        elif ratio >= 0.6:
            prefix = "中等。"
        else:
            prefix = "需要加强。"

        suggestions = [
            group[0]
            for group in (struct_comments, content_comments, rhetoric_comments)
            if group
        ]
        return f"{prefix}{';'.join(suggestions)}"
|
||||
|
||||
|
||||
# ==================== API路由定义 ====================
|
||||
|
||||
router = APIRouter(prefix="/api/v1", tags=["作文批改"])
# Module-level scoring engine used by the endpoint below
_scoring_engine = EssayScoringEngine()


@router.post("/essay/review")
async def review_essay(request: EssayReviewRequest):
    """Score and review an essay.

    POST /api/v1/essay/review

    Takes the OCR-recognised essay text and returns the overall score,
    per-dimension analysis and revision suggestions wrapped in a
    ``{"code", "msg", "data"}`` envelope.

    Raises:
        HTTPException: status 500 when scoring fails unexpectedly.
    """
    try:
        result = _scoring_engine.review_essay(request)

        # Audit log entry
        logger.info(
            f"作文批改完成: score={result['total_score']}/{request.max_score}, "
            f"student={request.student_id}, assignment={request.assignment_id}, "
            f"chars={result['character_count']['chinese']}, time={result['inference_time_ms']}ms"
        )
        return {"code": 200, "msg": "success", "data": result}
    except HTTPException:
        # Deliberate HTTP errors keep their status instead of becoming a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback together with the message.
        logger.exception("作文批改异常: %s", e)
        raise HTTPException(status_code=500, detail=f"作文批改服务异常: {str(e)}") from e
|
||||
@@ -0,0 +1,295 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自然写手写识别与AI分析引擎软件 V1.0
|
||||
|
||||
数学列式与公式识别接口
|
||||
支持四则运算、方程式、几何图形公式等数学内容识别
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import re
|
||||
|
||||
logger = logging.getLogger("writech-ai-engine.math")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MathStrokePoint(BaseModel):
    """A single sampled point of a handwritten math stroke.

    Coordinates are limited to 0-65535 and pressure to 0-255; ``pen_up``
    flags a pen-lift sample.
    """

    x: int = Field(default=..., ge=0, le=65535)
    y: int = Field(default=..., ge=0, le=65535)
    pressure: int = Field(default=0, ge=0, le=255)
    timestamp: int = Field(default=...)
    pen_up: bool = Field(default=False)
|
||||
|
||||
|
||||
class MathRecognizeRequest(BaseModel):
    """Request payload for math expression recognition.

    ``strokes`` is a list of strokes, each a list of points. ``math_type``
    selects the kind of expression and ``grade_level`` is limited to 1-6.
    """

    strokes: List[List[MathStrokePoint]] = Field(default=..., description="笔迹数据")
    math_type: str = Field(default="arithmetic", description="数学类型: arithmetic/equation/geometry")
    grade_level: int = Field(default=3, ge=1, le=6, description="年级(1-6)")
|
||||
|
||||
|
||||
class MathStep(BaseModel):
    """One step of a checked calculation.

    Carries the step number, the expression, its computed result and, when
    the step is wrong, an error type and detail.
    """

    step_no: int = Field(default=..., description="步骤序号")
    expression: str = Field(default=..., description="表达式")
    result: Optional[str] = Field(default=None, description="计算结果")
    is_correct: bool = Field(default=True, description="是否正确")
    error_type: Optional[str] = Field(default=None, description="错误类型")
    error_detail: Optional[str] = Field(default=None, description="错误详情")
|
||||
|
||||
|
||||
class MathRecognizeResult(BaseModel):
    """Result of recognising one handwritten math expression.

    Contains the LaTeX form, the computed result, whether the written
    answer is correct, the verification steps and a confidence value.
    """

    latex: str = Field(default=..., description="LaTeX表达式")
    result: Optional[str] = Field(default=None, description="计算结果")
    is_correct: bool = Field(default=True, description="答案是否正确")
    steps: List[MathStep] = Field(default=[], description="计算步骤")
    confidence: float = Field(default=..., description="识别置信度")
|
||||
|
||||
|
||||
class MathEngine:
    """Handwritten math expression recognition engine.

    Pipeline: strokes -> image rendering -> symbol segmentation -> symbol
    recognition -> structure analysis -> expression reconstruction ->
    calculation check.

    NOTE: segmentation and symbol recognition currently return fixed mock
    data; no model is actually loaded or run.
    """

    def __init__(self):
        """Initialise the engine without loading a model."""
        self.model = None
        self.is_loaded = False
        # Set of supported math symbols
        self.symbol_set = set("0123456789+-×÷=><()/.%")
        logger.info("数学识别引擎初始化完成")

    def load_model(self, model_path: str):
        """Mark the model as loaded.

        Currently a stub: it logs ``model_path`` and sets ``is_loaded``
        without reading any file.
        """
        logger.info(f"加载数学识别模型: {model_path}")
        self.is_loaded = True
        logger.info("数学识别模型加载完成")

    def recognize(self, strokes: List[List[MathStrokePoint]],
                  math_type: str = "arithmetic",
                  grade_level: int = 3) -> MathRecognizeResult:
        """Run the full recognition pipeline on ``strokes``.

        Args:
            strokes: Strokes, each a list of sampled points.
            math_type: Expression kind forwarded to structure analysis.
            grade_level: Forwarded to the calculation check.

        Returns:
            The recognised expression together with its verification.
        """
        start_time = time.time()

        image = self._preprocess_strokes(strokes)                        # 1. render
        segments = self._segment_symbols(image)                          # 2. segment
        symbols = self._recognize_symbols(segments)                      # 3. classify
        structure = self._analyze_structure(symbols, math_type)          # 4. structure
        latex_expr, math_expr = self._reconstruct_expression(structure)  # 5. rebuild
        result, is_correct, steps = self._verify_calculation(math_expr, grade_level)  # 6. verify

        inference_time = time.time() - start_time
        logger.info(f"数学识别完成: latex={latex_expr}, correct={is_correct}, "
                    f"time={inference_time:.4f}s")

        return MathRecognizeResult(
            latex=latex_expr,
            result=result,
            is_correct=is_correct,
            steps=steps,
            # NOTE(review): fixed placeholder; the per-symbol confidences
            # are not aggregated here.
            confidence=0.92
        )

    def _preprocess_strokes(self, strokes: List[List[MathStrokePoint]]) -> np.ndarray:
        """Render strokes onto a 64x512 float32 canvas.

        Coordinates are shifted to the bounding box origin, scaled
        uniformly to fit with a 5-pixel margin, and each consecutive pair
        of points is joined by a rasterised line of value 1.0. A stroke
        with a single point is drawn as one pixel. An empty input yields
        an all-zero canvas.
        """
        canvas_h, canvas_w = 64, 512
        canvas = np.zeros((canvas_h, canvas_w), dtype=np.float32)

        points = [(p.x, p.y) for stroke in strokes for p in stroke]
        if not points:
            return canvas

        xs, ys = zip(*points)
        min_x, min_y = min(xs), min(ys)
        w = max(max(xs) - min_x, 1)
        h = max(max(ys) - min_y, 1)
        scale = min((canvas_w - 10) / w, (canvas_h - 10) / h)

        def to_pixel(point):
            col = int((point.x - min_x) * scale + 5)
            row = int((point.y - min_y) * scale + 5)
            return (min(max(col, 0), canvas_w - 1),
                    min(max(row, 0), canvas_h - 1))

        for stroke in strokes:
            pixels = [to_pixel(p) for p in stroke]
            if len(pixels) == 1:
                col, row = pixels[0]
                canvas[row, col] = 1.0
                continue
            for (x1, y1), (x2, y2) in zip(pixels, pixels[1:]):
                # Sample one pixel per step along the longer axis. This
                # works in every direction, whereas slicing y1:y2 / x1:x2
                # is empty for right-to-left or upward segments.
                steps = max(abs(x2 - x1), abs(y2 - y1)) + 1
                cols = np.rint(np.linspace(x1, x2, steps)).astype(int)
                rows = np.rint(np.linspace(y1, y2, steps)).astype(int)
                canvas[rows, cols] = 1.0

        return canvas

    def _segment_symbols(self, image: np.ndarray) -> List[Dict]:
        """Split the image into per-symbol regions.

        Currently returns five fixed mock regions, each a dict with
        ``bbox`` ([x1, y1, x2, y2]) and the matching ``image`` crop.
        Connected-component analysis is not implemented yet.
        """
        mock_boxes = [
            [10, 5, 40, 55],
            [45, 20, 65, 45],
            [70, 5, 100, 55],
            [105, 20, 125, 45],
            [130, 5, 160, 55],
        ]
        return [
            {"bbox": [x1, y1, x2, y2], "image": image[y1:y2, x1:x2]}
            for x1, y1, x2, y2 in mock_boxes
        ]

    def _recognize_symbols(self, segments: List[Dict]) -> List[Dict]:
        """Classify each segment as a digit or operator.

        Currently assigns symbols from a fixed mock sequence; segments
        beyond the length of that sequence are dropped. Each result has
        ``symbol``, ``bbox`` and ``confidence``.
        """
        mock_symbols = ["1", "2", "+", "3", "=", "1", "5"]
        return [
            {
                "symbol": symbol,
                "bbox": seg["bbox"],
                "confidence": 0.95 - i * 0.01
            }
            for i, (seg, symbol) in enumerate(zip(segments, mock_symbols))
        ]

    def _analyze_structure(self, symbols: List[Dict], math_type: str) -> Dict:
        """Order symbols left to right and tag the structure type.

        ``arithmetic`` maps to ``linear``, ``equation`` to ``equation`` and
        anything else to ``unknown``.
        """
        # Sort by the left edge of each bounding box (reading order)
        sorted_symbols = sorted(symbols, key=lambda s: s["bbox"][0])
        structure_type = {
            "arithmetic": "linear",
            "equation": "equation",
        }.get(math_type, "unknown")
        return {"type": structure_type, "symbols": sorted_symbols}

    def _reconstruct_expression(self, structure: Dict) -> tuple:
        """Build the LaTeX string and the computable expression string.

        Returns:
            ``(latex, math_expr)``; "×" and "÷" become ``\\times`` and
            ``\\div`` in LaTeX, and ``*`` and ``/`` in the computable form.
        """
        text = "".join(s["symbol"] for s in structure.get("symbols", []))
        latex = text.replace("×", "\\times ").replace("÷", "\\div ")
        math_expr = text.replace("×", "*").replace("÷", "/")
        return latex, math_expr

    def _verify_calculation(self, math_expr: str, grade_level: int) -> tuple:
        """Check whether both sides of ``left = right`` are equal.

        ``grade_level`` is accepted but not used by the current check.

        Returns:
            ``(result, is_correct, steps)``. When the expression has no
            single "=" or either side cannot be evaluated, the result is
            ``(None, True, [])``.
        """
        steps = []

        parts = math_expr.split("=")
        if len(parts) != 2:
            return None, True, steps

        left, right = parts[0].strip(), parts[1].strip()
        try:
            left_val = self._safe_eval(left)
            right_val = self._safe_eval(right)
        except (ValueError, ArithmeticError, RecursionError) as err:
            # Unverifiable input is reported as "no result" rather than
            # failing the request, but the reason is now logged.
            logger.warning("表达式校验失败: %s (%s)", math_expr, err)
            return None, True, steps

        steps.append(MathStep(
            step_no=1,
            expression=left,
            result=str(left_val),
            is_correct=True
        ))

        is_correct = abs(left_val - right_val) < 1e-9
        steps.append(MathStep(
            step_no=2,
            expression=f"{left} = {right}",
            result=str(right_val),
            is_correct=is_correct,
            error_type=None if is_correct else "calculation",
            error_detail=None if is_correct else f"正确答案应为{left_val}"
        ))

        return str(left_val), is_correct, steps

    def _safe_eval(self, expr: str) -> float:
        """Evaluate an arithmetic expression without calling ``eval``.

        Only numeric literals, parentheses, unary +/- and the binary
        operators ``+ - * /`` are accepted. ``**`` and ``//`` are rejected,
        although their characters pass the whitelist, because ``**`` lets
        a tiny input such as ``9**9**9**9`` consume unbounded CPU and
        memory.

        Returns:
            An ``int`` for integer-only arithmetic, otherwise a ``float``.

        Raises:
            ValueError: on a disallowed character, a syntax error or an
                unsupported construct.
            ZeroDivisionError: on division by zero.
        """
        import ast
        import operator

        allowed_chars = set("0123456789.+-*/() ")
        if not all(c in allowed_chars for c in expr):
            raise ValueError(f"不安全的表达式: {expr}")

        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def evaluate(node):
            if isinstance(node, ast.Constant) and type(node.value) in (int, float):
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                return binary_ops[type(node.op)](evaluate(node.left), evaluate(node.right))
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](evaluate(node.operand))
            raise ValueError(f"不安全的表达式: {expr}")

        try:
            tree = ast.parse(expr.strip(), mode="eval")
        except SyntaxError as err:
            raise ValueError(f"不安全的表达式: {expr}") from err
        return evaluate(tree.body)
|
||||
|
||||
|
||||
# Module-level math engine instance
math_engine = MathEngine()


@router.post("/recognize")
async def recognize_math(request: MathRecognizeRequest):
    """Recognise a handwritten math expression.

    POST /api/v1/math/recognize

    Raises:
        HTTPException: status 400 when ``strokes`` is empty.
    """
    if not request.strokes:
        raise HTTPException(status_code=400, detail="笔迹数据不能为空")

    recognized = math_engine.recognize(
        strokes=request.strokes,
        math_type=request.math_type,
        grade_level=request.grade_level
    )

    payload = {
        "request_id": str(uuid.uuid4()),
        "result": recognized.dict()
    }
    return {"code": 200, "msg": "success", "data": payload}
|
||||
@@ -0,0 +1,352 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自然写手写识别与AI分析引擎软件 V1.0
|
||||
|
||||
OCR识别接口模块
|
||||
提供中英文手写文字OCR识别服务,基于PaddleOCR推理管道
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger("writech-ai-engine.ocr")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ==================== 请求/响应模型定义 ====================
|
||||
|
||||
class StrokePoint(BaseModel):
    """A single sampled point of a handwriting stroke.

    Coordinates are limited to 0-65535 and pressure to 0-255; the
    timestamp is in milliseconds and ``pen_up`` flags a pen-lift sample.
    """

    x: int = Field(default=..., ge=0, le=65535, description="X坐标")
    y: int = Field(default=..., ge=0, le=65535, description="Y坐标")
    pressure: int = Field(default=0, ge=0, le=255, description="压力值")
    timestamp: int = Field(default=..., description="时间戳(毫秒)")
    pen_up: bool = Field(default=False, description="抬笔标记")
|
||||
|
||||
|
||||
class OCRRequest(BaseModel):
    """Request payload for handwriting OCR.

    ``strokes`` is grouped per stroke. The page and pen identifiers are
    optional; ``language`` and ``recognition_mode`` select how the input
    is recognised.
    """

    strokes: List[List[StrokePoint]] = Field(default=..., description="笔迹数据(按笔画分组)")
    page_id: Optional[str] = Field(default=None, description="点阵码页面ID")
    pen_id: Optional[str] = Field(default=None, description="笔设备ID")
    language: str = Field(default="zh", description="识别语言: zh/en/mixed")
    recognition_mode: str = Field(default="line", description="识别模式: char/word/line/page")
|
||||
|
||||
|
||||
class CharDetail(BaseModel):
    """Recognition detail for a single character.

    Holds the character, its confidence, its bounding box and the indices
    of the strokes it was built from.
    """

    char: str = Field(default=..., description="识别的字符")
    confidence: float = Field(default=..., description="置信度(0-1)")
    bbox: List[int] = Field(default=..., description="包围框[x1,y1,x2,y2]")
    stroke_indices: List[int] = Field(default=[], description="对应的笔画索引")
|
||||
|
||||
|
||||
class OCRResult(BaseModel):
    """One recognised text unit.

    Holds the text, its overall confidence, the bounding box of the text
    region and the per-character details.
    """

    text: str = Field(default=..., description="识别文本")
    confidence: float = Field(default=..., description="整体置信度(0-1)")
    bbox: List[int] = Field(default=[], description="文本区域包围框")
    char_details: List[CharDetail] = Field(default=[], description="逐字详情")
|
||||
|
||||
|
||||
class OCRResponse(BaseModel):
    """Response envelope of the OCR endpoint: status code, message, payload."""

    code: int = 200
    msg: str = "success"
    data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# ==================== OCR 推理引擎 ====================
|
||||
|
||||
class OCREngine:
    """Handwriting OCR inference engine.

    Pipeline: stroke coordinates -> preprocessing (normalise and render)
    -> inference -> post-processing (confidence filter, result assembly).

    Recognition modes named by the API are ``char``, ``word``, ``line``
    (default) and ``page``. Post-processing currently has branches only
    for ``char``, ``line`` and ``page``; any other mode yields an empty
    result and a warning.

    NOTE: inference is mocked; no model is actually loaded or run.
    """

    def __init__(self):
        """Initialise the engine without loading a model."""
        self.model = None
        self.model_version = "1.0.0"
        self.is_loaded = False
        # Input image size expected by the model
        self.input_height = 48
        self.input_width = 320
        # Results below this confidence are discarded
        self.confidence_threshold = 0.5
        logger.info("OCR引擎初始化完成")

    def load_model(self, model_path: str):
        """Mark the OCR model as loaded.

        Currently a stub: it logs ``model_path`` and sets ``is_loaded``.
        Decrypting and loading the model file is not implemented yet.
        """
        logger.info(f"加载OCR模型: {model_path}")
        self.is_loaded = True
        logger.info("OCR模型加载完成")

    def preprocess_strokes(self, strokes: List[List[StrokePoint]]) -> np.ndarray:
        """Render strokes into a ``(1, input_height, input_width)`` image.

        Coordinates are shifted to the bounding box origin and scaled
        uniformly (keeping the aspect ratio, with a 0.9 margin factor) to
        fit the model input size. Consecutive points are joined with
        lines whose thickness follows the pen pressure. A stroke with a
        single point is drawn as a dot. An empty input yields an all-zero
        image.
        """
        blank_shape = (1, self.input_height, self.input_width)
        points = [(p.x, p.y) for stroke in strokes for p in stroke]
        if not points:
            return np.zeros(blank_shape, dtype=np.float32)

        xs, ys = zip(*points)
        min_x, min_y = min(xs), min(ys)
        width = max(max(xs) - min_x, 1)
        height = max(max(ys) - min_y, 1)
        scale = min(self.input_width / width, self.input_height / height) * 0.9

        canvas = np.zeros((self.input_height, self.input_width), dtype=np.float32)

        def to_pixel(point):
            return (int((point.x - min_x) * scale),
                    int((point.y - min_y) * scale))

        for stroke in strokes:
            if len(stroke) == 1:
                # A lone sample has no segment to draw, so stamp it as a dot.
                px, py = to_pixel(stroke[0])
                self._draw_line(canvas, px, py, px, py,
                                thickness=max(1, stroke[0].pressure // 85))
                continue
            for prev, curr in zip(stroke, stroke[1:]):
                x1, y1 = to_pixel(prev)
                x2, y2 = to_pixel(curr)
                self._draw_line(canvas, x1, y1, x2, y2,
                                thickness=max(1, curr.pressure // 85))

        # Normalise to [0, 1]
        peak = canvas.max()
        if peak > 0:
            canvas = canvas / peak

        return canvas.reshape(blank_shape)

    def recognize(self, strokes: List[List[StrokePoint]],
                  mode: str = "line") -> List[OCRResult]:
        """Run OCR on ``strokes``.

        Args:
            strokes: Strokes, each a list of sampled points.
            mode: Recognition mode (char/word/line/page).

        Returns:
            The list of recognition results that passed the confidence filter.
        """
        start_time = time.time()

        image = self.preprocess_strokes(strokes)
        # Inference is mocked for now
        predictions = self._mock_inference(image, mode)
        results = self._postprocess(predictions, mode)

        inference_time = time.time() - start_time
        logger.info(f"OCR识别完成, mode={mode}, time={inference_time:.4f}s, "
                    f"results={len(results)}")

        return results

    @staticmethod
    def _to_char_detail(pred: Dict) -> CharDetail:
        """Convert one raw character prediction into a ``CharDetail``."""
        return CharDetail(
            char=pred["char"],
            confidence=pred["confidence"],
            bbox=pred["bbox"],
            stroke_indices=pred.get("stroke_indices", [])
        )

    def _postprocess(self, predictions: Dict, mode: str) -> List[OCRResult]:
        """Filter predictions by confidence and build ``OCRResult`` objects.

        ``char`` mode reads ``predictions["chars"]`` and emits one result
        per character. ``line`` and ``page`` modes read
        ``predictions["lines"]`` and emit one result per line with its
        per-character details. Any other mode logs a warning and returns
        an empty list.
        """
        threshold = self.confidence_threshold
        results: List[OCRResult] = []

        if mode == "char":
            for char_pred in predictions.get("chars", []):
                if char_pred["confidence"] >= threshold:
                    results.append(OCRResult(
                        text=char_pred["char"],
                        confidence=char_pred["confidence"],
                        bbox=char_pred["bbox"],
                        char_details=[self._to_char_detail(char_pred)]
                    ))
        elif mode in ("line", "page"):
            for line_pred in predictions.get("lines", []):
                if line_pred["confidence"] >= threshold:
                    results.append(OCRResult(
                        text=line_pred["text"],
                        confidence=line_pred["confidence"],
                        bbox=line_pred["bbox"],
                        char_details=[
                            self._to_char_detail(cd)
                            for cd in line_pred.get("char_details", [])
                        ]
                    ))
        else:
            # Previously such modes produced an empty result with no trace.
            logger.warning("不支持的识别模式: %s, 返回空结果", mode)

        return results

    def _draw_line(self, canvas: np.ndarray, x1: int, y1: int,
                   x2: int, y2: int, thickness: int = 1):
        """Draw a line on ``canvas`` with Bresenham's algorithm.

        At every step a square of half-width ``thickness`` centred on the
        current pixel is set to 1.0, clipped to the canvas bounds.
        """
        h, w = canvas.shape
        dx = abs(x2 - x1)
        dy = abs(y2 - y1)
        sx = 1 if x1 < x2 else -1
        sy = 1 if y1 < y2 else -1
        err = dx - dy

        while True:
            # Stamp the brush as one clipped slice. The emptiness check
            # also stops a negative upper bound from wrapping around.
            top, bottom = max(0, y1 - thickness), min(h, y1 + thickness + 1)
            left, right = max(0, x1 - thickness), min(w, x1 + thickness + 1)
            if top < bottom and left < right:
                canvas[top:bottom, left:right] = 1.0

            if x1 == x2 and y1 == y2:
                break
            e2 = 2 * err
            if e2 > -dy:
                err -= dy
                x1 += sx
            if e2 < dx:
                err += dx
                y1 += sy

    def _mock_inference(self, image: np.ndarray, mode: str) -> Dict:
        """Return a fixed sample prediction (for demonstration).

        The result has one line with four characters and an empty
        ``chars`` list; ``image`` and ``mode`` are ignored.
        """
        return {
            "lines": [{
                "text": "示例文字",
                "confidence": 0.95,
                "bbox": [10, 10, 200, 48],
                "char_details": [
                    {"char": "示", "confidence": 0.96, "bbox": [10, 10, 50, 48]},
                    {"char": "例", "confidence": 0.94, "bbox": [50, 10, 100, 48]},
                    {"char": "文", "confidence": 0.97, "bbox": [100, 10, 150, 48]},
                    {"char": "字", "confidence": 0.93, "bbox": [150, 10, 200, 48]}
                ]
            }],
            "chars": []
        }

    def _decrypt_model(self, model_path: str) -> str:
        """Return the path of the decrypted model file.

        Currently a stub that returns ``model_path`` unchanged; decryption
        is not implemented yet.
        """
        return model_path
|
||||
|
||||
|
||||
# 全局OCR引擎实例
|
||||
# Single module-level engine instance, created at import time and shared by
# the /recognize and /batch-recognize handlers below.
ocr_engine = OCREngine()
|
||||
|
||||
|
||||
# ==================== API 路由 ====================
|
||||
|
||||
@router.post("/recognize", response_model=OCRResponse)
async def recognize_text(request: OCRRequest):
    """Recognize handwritten text from pen-stroke coordinate data.

    The original notes give the full path as ``POST /api/v1/ocr/recognize``
    (the router prefix is defined outside this excerpt) and describe support
    for Chinese, English and mixed input.

    Args:
        request: Stroke data plus ``recognition_mode`` and ``language``.

    Returns:
        ``OCRResponse`` with code 200 whose ``data`` carries a fresh request
        id, the echoed language and mode, the serialized results and the
        total number of recognized characters.

    Raises:
        HTTPException: 400 when ``strokes`` is empty or when the strokes
            together contain more than 50000 points.
    """
    strokes = request.strokes
    if not strokes:
        raise HTTPException(status_code=400, detail="笔迹数据不能为空")

    # Cap the payload size before handing it to the engine.
    point_total = 0
    for stroke in strokes:
        point_total += len(stroke)
    if point_total > 50000:
        raise HTTPException(status_code=400, detail="笔迹点数过多,最大支持50000点")

    recognized = ocr_engine.recognize(
        strokes=strokes,
        mode=request.recognition_mode,
    )

    payload = {
        "request_id": str(uuid.uuid4()),
        "language": request.language,
        "mode": request.recognition_mode,
        "results": [item.dict() for item in recognized],
        "total_chars": sum(len(item.text) for item in recognized),
    }
    return OCRResponse(code=200, msg="success", data=payload)
|
||||
|
||||
|
||||
@router.post("/batch-recognize")
async def batch_recognize(requests: List[OCRRequest]):
    """Recognize several stroke sets in one call.

    Every item is run through the OCR engine in request order and reported
    under its own ``page_id``.

    The per-request point cap enforced by the single ``/recognize`` endpoint
    (50000 points) is applied to each item here as well, and is checked for
    the whole batch before any recognition starts, so an oversized item
    cannot slip past the limit through the batch route or waste work already
    done on earlier items. Items with empty ``strokes`` are still passed to
    the engine unchanged.

    Args:
        requests: OCR requests to process.

    Returns:
        Dict with ``code``, ``msg`` and ``data``; ``data`` holds
        ``batch_size`` and one ``{"page_id", "results"}`` entry per request.

    Raises:
        HTTPException: 400 when any item contains more than 50000 points.
    """
    for index, req in enumerate(requests, start=1):
        if sum(len(stroke) for stroke in req.strokes) > 50000:
            raise HTTPException(
                status_code=400,
                detail=f"第{index}组笔迹点数过多,最大支持50000点",
            )

    results = []
    for req in requests:
        recognized = ocr_engine.recognize(
            strokes=req.strokes,
            mode=req.recognition_mode,
        )
        results.append({
            "page_id": req.page_id,
            "results": [item.dict() for item in recognized],
        })

    return {
        "code": 200,
        "msg": "success",
        "data": {
            "batch_size": len(requests),
            "results": results,
        },
    }
|
||||
@@ -0,0 +1,400 @@
|
||||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||||
# 笔顺评分接口模块 - 中文汉字笔顺识别与评分服务
|
||||
|
||||
"""
|
||||
笔顺评分API接口
|
||||
提供汉字笔顺正确性评估、书写质量评分、笔画拆分分析等功能
|
||||
基于深度学习笔顺分析模型,支持GB2312常用汉字笔顺评分
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import hashlib
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 数据模型定义 ====================
|
||||
|
||||
class StrokePointInput(BaseModel):
    """One sampled pen point inside a stroke, as submitted by the client."""

    # Position of the sample; both coordinates are required.
    x: float = Field(..., description="X坐标")
    y: float = Field(..., description="Y坐标")
    # Pen pressure constrained to [0.0, 1.0]; 0.5 is used when omitted.
    pressure: float = Field(default=0.5, ge=0.0, le=1.0, description="压力值")
    # Sample time in milliseconds (required).
    timestamp: int = Field(..., description="时间戳(毫秒)")
|
||||
|
||||
|
||||
class StrokeOrderRequest(BaseModel):
    """Request body for stroke-order scoring of a single character."""

    # Exactly one character; further restricted by the validator below.
    character: str = Field(..., min_length=1, max_length=1, description="目标汉字")
    # Strokes in the order written; each stroke is a list of sampled points.
    strokes: List[List[StrokePointInput]] = Field(..., description="用户书写的笔画列表")
    # Optional identifiers; the scoring route only writes them to its log.
    pen_id: Optional[str] = Field(default=None, description="点阵笔设备ID")
    student_id: Optional[str] = Field(default=None, description="学生ID")
    # 1..3; higher levels weight stroke-order correctness more heavily.
    difficulty_level: int = Field(default=1, ge=1, le=3, description="评分难度等级1-3")

    @validator('character')
    def validate_chinese_char(cls, v):
        """Accept only characters in the range U+4E00..U+9FFF.

        Raises:
            ValueError: if ``v`` lies outside that range.
        """
        inside_cjk_block = '\u4e00' <= v <= '\u9fff'
        if inside_cjk_block:
            return v
        raise ValueError('仅支持中文汉字笔顺评分')
|
||||
|
||||
|
||||
class WritingQualityRequest(BaseModel):
    """Request body for handwriting-quality evaluation."""

    # Strokes to evaluate; each stroke is a list of sampled points.
    strokes: List[List[StrokePointInput]] = Field(..., description="笔迹数据")
    # Optional reference character; the quality route in this file does not
    # forward it to the engine.
    reference_char: Optional[str] = Field(default=None, description="参考字符(可选)")
    # Dimensions to score. The engine recognizes exactly these four names and
    # ignores any other entry.
    eval_dimensions: List[str] = Field(
        default=["structure", "spacing", "normative", "aesthetics"],
        description="评测维度",
    )
|
||||
|
||||
|
||||
class StrokeDirection(str, Enum):
    """Coarse stroke classes used when comparing stroke orders.

    Members are also ``str`` instances, so each one compares equal to its
    value string.
    """

    HORIZONTAL = "horizontal"        # heng  (横)
    VERTICAL = "vertical"            # shu   (竖)
    LEFT_FALLING = "left_falling"    # pie   (撇)
    RIGHT_FALLING = "right_falling"  # na    (捺)
    DOT = "dot"                      # dian  (点)
    TURNING = "turning"              # zhe   (折)
    HOOK = "hook"                    # gou   (钩)
    RISING = "rising"                # ti    (提)
|
||||
|
||||
|
||||
@dataclass
class StrokeFeature:
    """Features derived from the sampled points of one stroke."""

    # Direction class assigned to the stroke.
    direction: StrokeDirection
    # (x, y) of the first sampled point.
    start_point: Tuple[float, float]
    # (x, y) of the last sampled point.
    end_point: Tuple[float, float]
    # Path length: sum of the distances between consecutive points.
    length: float
    # Mean of the per-point pressure values.
    avg_pressure: float
    # Path length divided by the start-to-end distance.
    curvature: float
    # Path length per unit of elapsed time.
    speed: float
|
||||
|
||||
|
||||
# ==================== 标准笔顺数据库 ====================
|
||||
|
||||
class StrokeOrderDatabase:
    """In-memory table of reference stroke orders.

    Holds, per character, the expected sequence of ``StrokeDirection`` values
    and the expected stroke count. Only a small built-in sample of ten
    characters is loaded. The original notes name the national stroke-order
    standard for common characters as the data source.
    """

    def __init__(self):
        # character -> expected stroke direction sequence
        self._standard_orders: Dict[str, List[StrokeDirection]] = {}
        # character -> expected number of strokes
        self._stroke_counts: Dict[str, int] = {}
        self._load_standard_data()

    def _load_standard_data(self):
        """Fill both lookup tables from the built-in sample and log the size."""
        heng = StrokeDirection.HORIZONTAL
        shu = StrokeDirection.VERTICAL
        pie = StrokeDirection.LEFT_FALLING
        na = StrokeDirection.RIGHT_FALLING
        zhe = StrokeDirection.TURNING

        # (character, expected direction sequence, expected stroke count)
        sample_table = (
            ("一", [heng], 1),
            ("二", [heng, heng], 2),
            ("三", [heng, heng, heng], 3),
            ("十", [heng, shu], 2),
            ("大", [heng, pie, na], 3),
            ("人", [pie, na], 2),
            ("口", [shu, zhe, heng], 3),
            ("日", [shu, zhe, heng, heng], 4),
            ("月", [pie, zhe, heng, heng], 4),
            ("水", [shu, zhe, pie, na], 4),
        )
        for glyph, sequence, stroke_total in sample_table:
            self._standard_orders[glyph] = sequence
            self._stroke_counts[glyph] = stroke_total
        logger.info(f"标准笔顺数据库加载完成,共 {len(self._standard_orders)} 个汉字")

    def get_standard_order(self, char: str) -> Optional[List[StrokeDirection]]:
        """Return the reference direction sequence for ``char``, or None if unknown."""
        return self._standard_orders.get(char)

    def get_stroke_count(self, char: str) -> Optional[int]:
        """Return the reference stroke count for ``char``, or None if unknown."""
        return self._stroke_counts.get(char)
|
||||
|
||||
|
||||
# ==================== 笔顺分析引擎 ====================
|
||||
|
||||
class StrokeOrderAnalyzer:
    """Scores a handwritten character's stroke sequence against a reference.

    Each stroke is reduced to a coarse direction class computed from its end
    points and path length. The resulting sequence is compared position by
    position with the reference order held in ``StrokeOrderDatabase``, and
    the stroke count is compared with the reference count.
    """

    def __init__(self):
        self._database = StrokeOrderDatabase()
        # Placeholder for a learned direction classifier. Nothing in this
        # class assigns or reads it; classification below is rule-based.
        self._direction_model = None
        logger.info("笔顺分析引擎初始化完成")

    def _extract_stroke_feature(self, points: List[StrokePointInput]) -> StrokeFeature:
        """Compute direction, length, pressure, curvature and speed of a stroke.

        Args:
            points: Sampled points of one stroke, in writing order.

        Returns:
            The stroke's ``StrokeFeature``. A single-point stroke is reported
            as a zero-length DOT.

        Raises:
            ValueError: if ``points`` is empty (previously an IndexError).
        """
        if not points:
            raise ValueError("cannot extract features from an empty stroke")

        first, last = points[0], points[-1]
        start = (first.x, first.y)
        if len(points) < 2:
            return StrokeFeature(
                direction=StrokeDirection.DOT,
                start_point=start,
                end_point=start,
                length=0.0, avg_pressure=first.pressure,
                curvature=0.0, speed=0.0
            )
        end = (last.x, last.y)

        # Path length: sum of Euclidean distances between consecutive points.
        total_length = 0.0
        for prev, cur in zip(points, points[1:]):
            dx = cur.x - prev.x
            dy = cur.y - prev.y
            total_length += np.sqrt(dx * dx + dy * dy)

        avg_pressure = np.mean([p.pressure for p in points])

        # Timestamps are milliseconds; clamp the span to at least 1 ms to
        # avoid dividing by zero, then scale to "per second".
        time_diff = max(last.timestamp - first.timestamp, 1)
        speed = total_length / time_diff * 1000

        # Curvature is path length over straight-line distance, with the
        # denominator clamped to 1.0 so near-closed strokes do not explode.
        direct_dist = np.sqrt((end[0] - start[0]) ** 2 + (end[1] - start[1]) ** 2)
        curvature = total_length / max(direct_dist, 1.0)

        direction = self._classify_direction(start, end, curvature)

        return StrokeFeature(
            direction=direction, start_point=start, end_point=end,
            length=total_length, avg_pressure=avg_pressure,
            curvature=curvature, speed=speed
        )

    def _classify_direction(self, start: Tuple, end: Tuple, curvature: float) -> StrokeDirection:
        """Map a stroke's end points and curvature to a ``StrokeDirection``.

        Rules, applied in order:
          1. start-to-end distance below 5.0            -> DOT
          2. curvature above 1.8                        -> TURNING if the
             stroke ends with a larger y than it starts, else HOOK
          3. otherwise by the angle of the start->end vector, in degrees from
             ``arctan2(dy, dx)``: [-20, 20] HORIZONTAL, [70, 110] VERTICAL,
             (20, 70) RIGHT_FALLING, [-70, -20) RISING, anything else
             (including [120, 170]) LEFT_FALLING.

        The angle bands presume coordinates where y grows downward, as the
        original comments state; TODO confirm against the client.
        """
        dx = end[0] - start[0]
        dy = end[1] - start[1]

        if np.sqrt(dx * dx + dy * dy) < 5.0:
            return StrokeDirection.DOT

        if curvature > 1.8:
            return StrokeDirection.TURNING if dy > 0 else StrokeDirection.HOOK

        angle = np.degrees(np.arctan2(dy, dx))
        if -20 <= angle <= 20:
            return StrokeDirection.HORIZONTAL
        if 70 <= angle <= 110:
            return StrokeDirection.VERTICAL
        if 20 < angle < 70:
            return StrokeDirection.RIGHT_FALLING
        if -70 <= angle < -20:
            return StrokeDirection.RISING
        # [120, 170] is the intended left-falling band; every remaining angle
        # falls back to the same class.
        return StrokeDirection.LEFT_FALLING

    def evaluate_stroke_order(self, char: str, strokes: List[List[StrokePointInput]],
                              difficulty: int = 1) -> Dict:
        """Score the user's strokes for ``char`` against the reference order.

        Empty strokes (lists without points) are dropped before scoring, so
        they neither crash feature extraction nor count as strokes.

        Scoring:
          * count score: 100 minus 25 per stroke of difference from the
            reference count, floored at 0.
          * order score: share of reference positions whose direction
            matches, times 100. Positions beyond the shorter sequence are
            not listed in ``errors``.
          * total: order score weighted ``0.5 + 0.1 * difficulty``, count
            score weighted with the remainder.

        NOTE(review): when ``char`` has no reference entry both partial
        scores stay at 100, so the total is 100 regardless of the input.

        Args:
            char: Target character.
            strokes: User strokes in writing order.
            difficulty: Weighting level; the request model restricts it to 1..3.

        Returns:
            Dict with the scores, stroke counts, detected and reference
            direction sequences, per-position errors and the elapsed time in
            milliseconds.
        """
        # perf_counter is monotonic, unlike wall-clock time.time().
        started = time.perf_counter()

        strokes = [stroke for stroke in strokes if stroke]

        standard_order = self._database.get_standard_order(char)
        standard_count = self._database.get_stroke_count(char)

        user_directions = [
            self._extract_stroke_feature(stroke).direction for stroke in strokes
        ]

        count_score = 100.0
        if standard_count:
            count_diff = abs(len(strokes) - standard_count)
            count_score = max(0, 100 - count_diff * 25)

        order_score = 100.0
        errors = []
        if standard_order:
            match_count = 0
            pairs = zip(user_directions, standard_order)
            for index, (actual, expected) in enumerate(pairs, start=1):
                if actual == expected:
                    match_count += 1
                    continue
                errors.append({
                    "stroke_index": index,
                    "expected": expected.value,
                    "actual": actual.value,
                    "message": f"第{index}笔方向错误:应为{expected.value},实际为{actual.value}"
                })
            order_score = (match_count / max(len(standard_order), 1)) * 100

        # Higher difficulty shifts weight from stroke count to stroke order.
        weight_order = 0.5 + difficulty * 0.1
        weight_count = 1.0 - weight_order

        total_score = order_score * weight_order + count_score * weight_count
        elapsed = (time.perf_counter() - started) * 1000

        return {
            "character": char,
            "total_score": round(total_score, 1),
            "order_score": round(order_score, 1),
            "count_score": round(count_score, 1),
            "user_stroke_count": len(strokes),
            "standard_stroke_count": standard_count,
            "stroke_order": [d.value for d in user_directions],
            "correct_order": [d.value for d in standard_order] if standard_order else [],
            "errors": errors,
            "inference_time_ms": round(elapsed, 2)
        }
|
||||
|
||||
|
||||
# ==================== 书写质量评测引擎 ====================
|
||||
|
||||
class WritingQualityEngine:
    """Heuristic handwriting-quality scorer.

    Supports four dimensions, each scored 0..100: ``structure``, ``spacing``,
    ``normative`` and ``aesthetics``. Unrecognized dimension names are
    ignored.
    """

    def evaluate(self, strokes: List[List[StrokePointInput]],
                 dimensions: List[str]) -> Dict:
        """Score ``strokes`` on the requested ``dimensions``.

        Dimension formulas:
          * structure: distance between the mean of all points and the centre
            of their bounding box, relative to a quarter of the box diagonal.
          * spacing: coefficient of variation of the distances between the
            start points of consecutive non-empty strokes. When there are
            fewer than two non-empty strokes no gap exists and the fixed
            value 80.0 is used. Previously an input with several strokes of
            which at most one was non-empty scored 100 here.
          * normative: standard deviation of the pressure values.
          * aesthetics: deviation of the bounding-box aspect ratio from 1.

        Args:
            strokes: Strokes to evaluate; empty strokes are tolerated.
            dimensions: Names of the dimensions to compute.

        Returns:
            ``{"total_score": <mean of computed dimensions>, "dimensions":
            {name: score}}``. With no points at all the result is
            ``{"total_score": 0, "dimensions": {}}``.
        """
        all_points = [(p.x, p.y, p.pressure) for stroke in strokes for p in stroke]
        if not all_points:
            return {"total_score": 0, "dimensions": {}}

        xs = [pt[0] for pt in all_points]
        ys = [pt[1] for pt in all_points]
        min_x, min_y = min(xs), min(ys)
        bbox_width = max(xs) - min_x
        bbox_height = max(ys) - min_y

        scores = {}

        if "structure" in dimensions:
            box_center_x = min_x + bbox_width / 2
            box_center_y = min_y + bbox_height / 2
            offset = np.sqrt((np.mean(xs) - box_center_x) ** 2
                             + (np.mean(ys) - box_center_y) ** 2)
            max_offset = np.sqrt(bbox_width ** 2 + bbox_height ** 2) / 4
            scores["structure"] = round(max(0, 100 - (offset / max(max_offset, 1)) * 60), 1)

        if "spacing" in dimensions:
            start_points = [(s[0].x, s[0].y) for s in strokes if s]
            gaps = [
                np.sqrt((bx - ax) ** 2 + (by - ay) ** 2)
                for (ax, ay), (bx, by) in zip(start_points, start_points[1:])
            ]
            if gaps:
                # Coefficient of variation; the mean is clamped to 1 to avoid
                # dividing by (near) zero when start points coincide.
                cv = np.std(gaps) / max(np.mean(gaps), 1)
                scores["spacing"] = round(max(0, 100 - cv * 80), 1)
            else:
                scores["spacing"] = 80.0

        if "normative" in dimensions:
            pressure_std = np.std([pt[2] for pt in all_points])
            scores["normative"] = round(max(0, 100 - pressure_std * 200), 1)

        if "aesthetics" in dimensions:
            # A square bounding box (ratio 1.0) gets the full score.
            aspect_ratio = bbox_width / max(bbox_height, 1)
            scores["aesthetics"] = round(max(0, 100 - abs(aspect_ratio - 1.0) * 50), 1)

        total = np.mean(list(scores.values())) if scores else 0
        return {"total_score": round(total, 1), "dimensions": scores}
|
||||
|
||||
|
||||
# ==================== API路由定义 ====================
|
||||
|
||||
# Router for the stroke-order and writing-quality endpoints defined below;
# every route registered on it is served under the "/api/v1" prefix.
router = APIRouter(prefix="/api/v1", tags=["笔顺评分"])
|
||||
# Module-level analyzer instance, built once at import time and used by the
# stroke-order scoring route.
_analyzer = StrokeOrderAnalyzer()
|
||||
# Module-level quality engine instance used by the writing-quality route.
_quality_engine = WritingQualityEngine()
|
||||
|
||||
|
||||
@router.post("/stroke-order/evaluate")
async def evaluate_stroke_order(request: StrokeOrderRequest):
    """Score the stroke order of one handwritten character.

    Served at ``POST /api/v1/stroke-order/evaluate``.

    Args:
        request: Target character, user strokes, difficulty level and
            optional pen / student identifiers.

    Returns:
        ``{"code": 200, "msg": "success", "data": <analyzer result>}``.

    Raises:
        HTTPException: 500 when the analyzer raises; the original exception
            is kept as the cause and logged with its traceback.
    """
    try:
        result = _analyzer.evaluate_stroke_order(
            char=request.character,
            strokes=request.strokes,
            difficulty=request.difficulty_level
        )
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"笔顺评分异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"笔顺评分服务异常: {str(e)}") from e

    # Audit trail: character, score, pen, student and timing of every request.
    logger.info(
        f"笔顺评分完成: char={request.character}, "
        f"score={result['total_score']}, pen={request.pen_id}, "
        f"student={request.student_id}, time={result['inference_time_ms']}ms"
    )
    return {"code": 200, "msg": "success", "data": result}
|
||||
|
||||
|
||||
@router.post("/writing/quality")
async def evaluate_writing_quality(request: WritingQualityRequest):
    """Evaluate handwriting quality on the requested dimensions.

    Served at ``POST /api/v1/writing/quality``. Only ``strokes`` and
    ``eval_dimensions`` are forwarded to the engine.

    Args:
        request: Stroke data and the list of dimensions to score.

    Returns:
        ``{"code": 200, "msg": "success", "data": <engine result>}``.

    Raises:
        HTTPException: 500 when the engine raises; the original exception is
            kept as the cause and logged with its traceback.
    """
    try:
        result = _quality_engine.evaluate(
            strokes=request.strokes,
            dimensions=request.eval_dimensions
        )
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"书写质量评测异常: {str(e)}")
        raise HTTPException(status_code=500, detail=f"书写质量评测异常: {str(e)}") from e

    logger.info(f"书写质量评测完成: score={result['total_score']}")
    return {"code": 200, "msg": "success", "data": result}
|
||||
Reference in New Issue
Block a user