401 lines
16 KiB
Python
401 lines
16 KiB
Python
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# 笔顺评分接口模块 - 中文汉字笔顺识别与评分服务
|
||
|
||
"""
|
||
笔顺评分API接口
|
||
提供汉字笔顺正确性评估、书写质量评分、笔画拆分分析等功能
|
||
基于深度学习笔顺分析模型,支持GB2312常用汉字笔顺评分
|
||
"""
|
||
|
||
import time
|
||
import logging
|
||
import hashlib
|
||
import numpy as np
|
||
from typing import List, Dict, Optional, Tuple
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from fastapi import APIRouter, HTTPException, Depends, Request
|
||
from pydantic import BaseModel, Field, validator
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 数据模型定义 ====================
|
||
|
||
class StrokePointInput(BaseModel):
    """One sampled point of a handwritten stroke.

    Fields:
        x: X coordinate of the sample (required).
        y: Y coordinate of the sample (required).
        pressure: Pen pressure, constrained to [0.0, 1.0]; defaults to 0.5.
        timestamp: Sample time in milliseconds (required).
    """

    x: float = Field(
        ...,
        description="X坐标",
    )
    y: float = Field(
        ...,
        description="Y坐标",
    )
    pressure: float = Field(
        0.5,
        ge=0.0,
        le=1.0,
        description="压力值",
    )
    timestamp: int = Field(
        ...,
        description="时间戳(毫秒)",
    )
|
||
|
||
|
||
class StrokeOrderRequest(BaseModel):
    """Request body for stroke-order scoring.

    Fields:
        character: The target character; exactly one character long and
            checked by ``validate_chinese_char``.
        strokes: The user's strokes, each a list of sampled points.
        pen_id: Optional identifier of the pen device.
        student_id: Optional identifier of the student.
        difficulty_level: Scoring difficulty, an integer from 1 to 3
            (default 1).
    """

    character: str = Field(
        ...,
        min_length=1,
        max_length=1,
        description="目标汉字",
    )
    strokes: List[List[StrokePointInput]] = Field(
        ...,
        description="用户书写的笔画列表",
    )
    pen_id: Optional[str] = Field(
        None,
        description="点阵笔设备ID",
    )
    student_id: Optional[str] = Field(
        None,
        description="学生ID",
    )
    difficulty_level: int = Field(
        1,
        ge=1,
        le=3,
        description="评分难度等级1-3",
    )

    @validator('character')
    def validate_chinese_char(cls, v):
        """Accept only values within the U+4E00..U+9FFF range.

        Raises:
            ValueError: If ``v`` sorts below U+4E00 or above U+9FFF.
        """
        below_range = v < '\u4e00'
        above_range = v > '\u9fff'
        if below_range or above_range:
            raise ValueError('仅支持中文汉字笔顺评分')
        return v
|
||
|
||
|
||
class WritingQualityRequest(BaseModel):
    """Request body for writing-quality evaluation.

    Fields:
        strokes: The handwriting data, one list of sampled points per stroke.
        reference_char: Optional reference character.
        eval_dimensions: Names of the dimensions to score; defaults to
            structure, spacing, normative and aesthetics.
    """

    strokes: List[List[StrokePointInput]] = Field(
        ...,
        description="笔迹数据",
    )
    reference_char: Optional[str] = Field(
        None,
        description="参考字符(可选)",
    )
    eval_dimensions: List[str] = Field(
        default=[
            "structure",
            "spacing",
            "normative",
            "aesthetics",
        ],
        description="评测维度",
    )
|
||
|
||
|
||
class StrokeDirection(str, Enum):
    """Closed set of stroke-direction labels.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    # Heng: horizontal stroke
    HORIZONTAL = "horizontal"
    # Shu: vertical stroke
    VERTICAL = "vertical"
    # Pie: left-falling stroke
    LEFT_FALLING = "left_falling"
    # Na: right-falling stroke
    RIGHT_FALLING = "right_falling"
    # Dian: dot
    DOT = "dot"
    # Zhe: turning stroke
    TURNING = "turning"
    # Gou: hook
    HOOK = "hook"
    # Ti: rising stroke
    RISING = "rising"
|
||
|
||
|
||
@dataclass
class StrokeFeature:
    """Feature record describing a single stroke."""

    # Direction label assigned to the stroke.
    direction: StrokeDirection
    # (x, y) of the first sampled point.
    start_point: Tuple[float, float]
    # (x, y) of the last sampled point.
    end_point: Tuple[float, float]
    # Length of the stroke.
    length: float
    # Mean pen pressure over the stroke.
    avg_pressure: float
    # Curvature measure of the stroke.
    curvature: float
    # Writing speed.
    speed: float
|
||
|
||
|
||
# ==================== 标准笔顺数据库 ====================
|
||
|
||
class StrokeOrderDatabase:
    """In-memory lookup of reference stroke orders.

    Holds, per character, the reference sequence of stroke directions and the
    reference stroke count, for comparison against user input. Only the small
    sample table built in ``_load_standard_data`` is loaded. The original
    author's note cites the national stroke-order standard for commonly used
    characters as the data source; that cannot be verified from this file.
    """

    def __init__(self):
        # character -> reference stroke-direction sequence
        self._standard_orders: Dict[str, List[StrokeDirection]] = {}
        # character -> reference stroke count
        self._stroke_counts: Dict[str, int] = {}
        # Fill both tables from the built-in sample data.
        self._load_standard_data()

    def _load_standard_data(self):
        """Fill both lookup tables from the built-in sample characters."""
        # Short aliases keep the table below readable.
        heng = StrokeDirection.HORIZONTAL
        shu = StrokeDirection.VERTICAL
        pie = StrokeDirection.LEFT_FALLING
        na = StrokeDirection.RIGHT_FALLING
        zhe = StrokeDirection.TURNING

        # character -> (direction sequence, stroke count)
        sample_table = {
            "一": ([heng], 1),
            "二": ([heng, heng], 2),
            "三": ([heng, heng, heng], 3),
            "十": ([heng, shu], 2),
            "大": ([heng, pie, na], 3),
            "人": ([pie, na], 2),
            "口": ([shu, zhe, heng], 3),
            "日": ([shu, zhe, heng, heng], 4),
            "月": ([pie, zhe, heng, heng], 4),
            "水": ([shu, zhe, pie, na], 4),
        }

        for glyph, entry in sample_table.items():
            sequence, stroke_total = entry
            self._standard_orders[glyph] = sequence
            self._stroke_counts[glyph] = stroke_total

        logger.info(f"标准笔顺数据库加载完成,共 {len(self._standard_orders)} 个汉字")

    def get_standard_order(self, char: str) -> Optional[List[StrokeDirection]]:
        """Return the reference direction sequence for ``char``, or None if absent."""
        return self._standard_orders.get(char)

    def get_stroke_count(self, char: str) -> Optional[int]:
        """Return the reference stroke count for ``char``, or None if absent."""
        return self._stroke_counts.get(char)
|
||
|
||
|
||
# ==================== 笔顺分析引擎 ====================
|
||
|
||
class StrokeOrderAnalyzer:
    """Heuristic stroke-order scorer.

    Each user stroke is reduced to a ``StrokeFeature`` whose direction is
    derived from the stroke's end points and curvature. The resulting
    direction sequence is compared position by position with the reference
    order from ``StrokeOrderDatabase``. The final score blends stroke-order
    correctness with stroke-count agreement.
    """

    def __init__(self):
        self._database = StrokeOrderDatabase()
        # Placeholder for a stroke-direction classification model (the
        # original note says CNN); nothing in this class reads it.
        self._direction_model = None
        logger.info("笔顺分析引擎初始化完成")

    def _extract_stroke_feature(self, points: List[StrokePointInput]) -> StrokeFeature:
        """Summarise one stroke as a ``StrokeFeature``.

        Args:
            points: Sampled points of the stroke, in writing order.

        Returns:
            The stroke's direction, end points, path length, mean pressure,
            curvature and speed. A single-point stroke is reported as a
            zero-length DOT.

        Raises:
            ValueError: If ``points`` is empty.
        """
        if not points:
            # Without this guard the code below indexes points[0] and fails
            # with an uninformative IndexError.
            raise ValueError("笔画数据为空")

        first = points[0]
        if len(points) < 2:
            origin = (first.x, first.y)
            return StrokeFeature(
                direction=StrokeDirection.DOT,
                start_point=origin,
                end_point=origin,
                length=0.0,
                avg_pressure=first.pressure,
                curvature=0.0,
                speed=0.0,
            )

        last = points[-1]
        start = (first.x, first.y)
        end = (last.x, last.y)

        # Path length: sum of Euclidean distances between consecutive points.
        total_length = 0.0
        for prev, cur in zip(points, points[1:]):
            dx = cur.x - prev.x
            dy = cur.y - prev.y
            total_length += np.sqrt(dx * dx + dy * dy)

        avg_pressure = np.mean([p.pressure for p in points])

        # Timestamps are in milliseconds. The 1 ms floor prevents a zero or
        # negative divisor when timestamps are equal or out of order.
        time_diff = max(last.timestamp - first.timestamp, 1)
        # Coordinate units per second (the original comment calls them pixels).
        speed = total_length / time_diff * 1000

        # Curvature: path length over straight-line distance between the end
        # points; the divisor is floored at 1.0 for near-closed strokes.
        direct_dist = np.sqrt((end[0] - start[0]) ** 2 + (end[1] - start[1]) ** 2)
        curvature = total_length / max(direct_dist, 1.0)

        direction = self._classify_direction(start, end, curvature)

        return StrokeFeature(
            direction=direction,
            start_point=start,
            end_point=end,
            length=total_length,
            avg_pressure=avg_pressure,
            curvature=curvature,
            speed=speed,
        )

    def _classify_direction(self, start: Tuple, end: Tuple, curvature: float) -> StrokeDirection:
        """Classify a stroke from its end points and curvature.

        Args:
            start: (x, y) of the first point.
            end: (x, y) of the last point.
            curvature: Path length divided by end-to-end distance.

        Returns:
            DOT when the end points are less than 5.0 apart. Otherwise
            TURNING or HOOK when curvature exceeds 1.8, chosen by the sign of
            the vertical displacement. Otherwise a label picked by angle
            range; any angle outside the explicit ranges is LEFT_FALLING.
        """
        dx = end[0] - start[0]
        dy = end[1] - start[1]
        distance = np.sqrt(dx * dx + dy * dy)

        # Very short strokes are treated as dots.
        if distance < 5.0:
            return StrokeDirection.DOT

        # Strongly bent strokes are a turn or a hook, whatever the net angle.
        if curvature > 1.8:
            return StrokeDirection.TURNING if dy > 0 else StrokeDirection.HOOK

        # Angle of the end-to-end vector in degrees: 0 is the +x direction and
        # positive values mean dy > 0. The original comment describes this as
        # clockwise-positive, which presumably assumes y grows downward.
        angle = np.degrees(np.arctan2(dy, dx))

        if -20 <= angle <= 20:
            return StrokeDirection.HORIZONTAL
        elif 70 <= angle <= 110:
            return StrokeDirection.VERTICAL
        elif 120 <= angle <= 170:
            return StrokeDirection.LEFT_FALLING
        elif 20 < angle < 70:
            return StrokeDirection.RIGHT_FALLING
        elif -70 <= angle < -20:
            return StrokeDirection.RISING
        else:
            # Everything else (110..120, above 170, below -70) defaults to
            # left-falling.
            return StrokeDirection.LEFT_FALLING

    def evaluate_stroke_order(self, char: str, strokes: List[List[StrokePointInput]],
                              difficulty: int = 1) -> Dict:
        """Score the user's strokes against the reference order for ``char``.

        Strokes that contain no sampled points are ignored. Stroke counts and
        the 1-based ``stroke_index`` in error entries refer to the remaining
        strokes.

        Args:
            char: The target character.
            strokes: The user's strokes, each a list of sampled points.
            difficulty: Higher values give stroke-order correctness more
                weight (weight = 0.5 + 0.1 * difficulty).

        Returns:
            A dict with the total, order and count scores, the user and
            reference stroke counts, the user and reference direction
            sequences, the per-stroke errors and the elapsed time in
            milliseconds.
        """
        start_time = time.time()

        standard_order = self._database.get_standard_order(char)
        standard_count = self._database.get_stroke_count(char)

        # A stroke with no points carries no ink. Skip it rather than fail on
        # it, as the quality engine in this module also does.
        inked_strokes = [stroke for stroke in strokes if stroke]

        user_features = [self._extract_stroke_feature(stroke) for stroke in inked_strokes]
        user_directions = [feature.direction for feature in user_features]

        # Stroke-count score: full marks minus 25 per missing or extra stroke.
        count_score = 100.0
        if standard_count:
            count_diff = abs(len(inked_strokes) - standard_count)
            count_score = max(0, 100 - count_diff * 25)

        # NOTE(review): when the character has no reference data, both scores
        # stay at 100, so an unsupported character receives a perfect total.
        # Confirm whether that is intended.
        order_score = 100.0
        errors = []
        if standard_order:
            match_count = 0
            # zip stops at the shorter sequence. Missing strokes still lower
            # the score because the divisor is the reference length.
            pairs = zip(user_directions, standard_order)
            for position, (actual, expected) in enumerate(pairs, start=1):
                if actual == expected:
                    match_count += 1
                else:
                    errors.append({
                        "stroke_index": position,
                        "expected": expected.value,
                        "actual": actual.value,
                        "message": f"第{position}笔方向错误:应为{expected.value},实际为{actual.value}",
                    })
            order_score = (match_count / max(len(standard_order), 1)) * 100

        # Higher difficulty shifts weight from stroke count to stroke order.
        weight_order = 0.5 + difficulty * 0.1
        weight_count = 1.0 - weight_order

        total_score = order_score * weight_order + count_score * weight_count
        elapsed = (time.time() - start_time) * 1000

        return {
            "character": char,
            "total_score": round(total_score, 1),
            "order_score": round(order_score, 1),
            "count_score": round(count_score, 1),
            "user_stroke_count": len(inked_strokes),
            "standard_stroke_count": standard_count,
            "stroke_order": [d.value for d in user_directions],
            "correct_order": [d.value for d in standard_order] if standard_order else [],
            "errors": errors,
            "inference_time_ms": round(elapsed, 2),
        }
|
||
|
||
|
||
# ==================== 书写质量评测引擎 ====================
|
||
|
||
class WritingQualityEngine:
    """Writing-quality scorer.

    Scores handwriting on up to four dimensions: structure, spacing,
    normative and aesthetics. The total is the mean of the dimensions that
    were scored.
    """

    def evaluate(self, strokes: List[List[StrokePointInput]],
                 dimensions: List[str]) -> Dict:
        """Score the handwriting on the requested dimensions.

        Args:
            strokes: The handwriting data, one list of sampled points per
                stroke. Strokes without points contribute nothing.
            dimensions: Names of the dimensions to score. Names other than
                "structure", "spacing", "normative" and "aesthetics" are
                ignored.

        Returns:
            ``{"total_score": ..., "dimensions": {...}}``. The total is the
            mean of the scored dimensions, or 0 when nothing was scored. When
            there are no points at all, the total is 0 and the dimensions
            dict is empty.
        """
        scores = {}

        # Flatten every sampled point to (x, y, pressure) for whole-character
        # statistics.
        all_points = [(p.x, p.y, p.pressure) for stroke in strokes for p in stroke]
        if not all_points:
            return {"total_score": 0, "dimensions": {}}

        xs = [pt[0] for pt in all_points]
        ys = [pt[1] for pt in all_points]

        # Bounding box of the written area.
        bbox_width = max(xs) - min(xs)
        bbox_height = max(ys) - min(ys)

        if "structure" in dimensions:
            # Balance: distance between the centroid of all points and the
            # centre of the bounding box, relative to a quarter of the box
            # diagonal.
            center_x = np.mean(xs)
            center_y = np.mean(ys)
            expected_center_x = min(xs) + bbox_width / 2
            expected_center_y = min(ys) + bbox_height / 2
            offset = np.sqrt((center_x - expected_center_x) ** 2
                             + (center_y - expected_center_y) ** 2)
            max_offset = np.sqrt(bbox_width ** 2 + bbox_height ** 2) / 4
            scores["structure"] = round(max(0, 100 - (offset / max(max_offset, 1)) * 60), 1)

        if "spacing" in dimensions:
            # Spacing uniformity: coefficient of variation of the distances
            # between the start points of consecutive strokes.
            start_points = [(s[0].x, s[0].y) for s in strokes if s]
            # Branch on the strokes that actually have points. Testing
            # len(strokes) gave one real stroke plus an empty one a score of
            # 100 instead of the single-stroke default.
            if len(start_points) > 1:
                gaps = [
                    np.sqrt((cur[0] - prev[0]) ** 2 + (cur[1] - prev[1]) ** 2)
                    for prev, cur in zip(start_points, start_points[1:])
                ]
                gap_std = np.std(gaps)
                gap_mean = np.mean(gaps)
                # The divisor is floored at 1 to avoid dividing by a tiny mean.
                cv = gap_std / max(gap_mean, 1)
                scores["spacing"] = round(max(0, 100 - cv * 80), 1)
            else:
                # Fewer than two strokes: nothing to measure, fixed default.
                scores["spacing"] = 80.0

        if "normative" in dimensions:
            # Only pressure stability is measured: the standard deviation of
            # pressure over all points.
            pressures = [pt[2] for pt in all_points]
            pressure_std = np.std(pressures)
            scores["normative"] = round(max(0, 100 - pressure_std * 200), 1)

        if "aesthetics" in dimensions:
            # Overall proportion: the closer the bounding box is to a square,
            # the higher the score.
            aspect_ratio = bbox_width / max(bbox_height, 1)
            ratio_score = max(0, 100 - abs(aspect_ratio - 1.0) * 50)
            scores["aesthetics"] = round(ratio_score, 1)

        total = np.mean(list(scores.values())) if scores else 0
        return {"total_score": round(total, 1), "dimensions": scores}
|
||
|
||
|
||
# ==================== API路由定义 ====================
|
||
|
||
# Router for the endpoints in this module; every path is prefixed with /api/v1.
router = APIRouter(prefix="/api/v1", tags=["笔顺评分"])
|
||
# Shared stroke-order analyser, created once when the module is imported.
_analyzer = StrokeOrderAnalyzer()
|
||
# Shared writing-quality engine, created once when the module is imported.
_quality_engine = WritingQualityEngine()
|
||
|
||
|
||
@router.post("/stroke-order/evaluate")
async def evaluate_stroke_order(request: StrokeOrderRequest):
    """Score the stroke order of one handwritten character.

    POST /api/v1/stroke-order/evaluate

    Takes the target character and the user's stroke data, and returns the
    stroke-order score together with per-stroke error details.

    Raises:
        HTTPException: Status 500 when scoring fails for any reason.
    """
    try:
        result = _analyzer.evaluate_stroke_order(
            char=request.character,
            strokes=request.strokes,
            difficulty=request.difficulty_level,
        )
        # Audit log entry: character, score, pen, student and elapsed time.
        logger.info(
            f"笔顺评分完成: char={request.character}, "
            f"score={result['total_score']}, pen={request.pen_id}, "
            f"student={request.student_id}, time={result['inference_time_ms']}ms"
        )
        return {"code": 200, "msg": "success", "data": result}
    except Exception as e:
        # logger.exception logs the same message at ERROR level and adds the
        # traceback, which logger.error dropped.
        logger.exception("笔顺评分异常: %s", e)
        # Chain the cause so the original error stays attached.
        raise HTTPException(status_code=500, detail=f"笔顺评分服务异常: {str(e)}") from e
|
||
|
||
|
||
@router.post("/writing/quality")
async def evaluate_writing_quality(request: WritingQualityRequest):
    """Evaluate the quality of a piece of handwriting.

    POST /api/v1/writing/quality

    Scores the handwriting on the requested dimensions (structure, spacing,
    normative, aesthetics) and returns the per-dimension and total scores.

    Raises:
        HTTPException: Status 500 when evaluation fails for any reason.
    """
    try:
        result = _quality_engine.evaluate(
            strokes=request.strokes,
            dimensions=request.eval_dimensions,
        )
        logger.info(f"书写质量评测完成: score={result['total_score']}")
        return {"code": 200, "msg": "success", "data": result}
    except Exception as e:
        # logger.exception logs the same message at ERROR level and adds the
        # traceback, which logger.error dropped.
        logger.exception("书写质量评测异常: %s", e)
        # Chain the cause so the original error stays attached.
        raise HTTPException(status_code=500, detail=f"书写质量评测异常: {str(e)}") from e
|