Files
2026-03-22 15:24:40 +08:00

460 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 自然写手写识别与AI分析引擎软件 V1.0
# 笔顺分析算法模块 - 笔画拆分与顺序分析核心算法
"""
笔顺分析核心算法
提供笔画自动拆分、方向判定、笔画连接检测、
笔迹相似度计算等底层分析算法
"""
import math
import logging
import numpy as np
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass, field
from enum import IntEnum
logger = logging.getLogger(__name__)

# ==================== Constants ====================
# Direction-angle ranges in degrees, keyed by stroke kind. A value of None marks
# a kind that needs special-case detection rather than an angle range.
# NOTE(review): presumably in the same convention as StrokeGeometry.angle_degrees;
# StrokeClassifier.classify in this file uses its own hard-coded ranges instead
# of this table - confirm which set is authoritative.
DIRECTION_ANGLES = dict(
    horizontal=(-15, 15),      # 横 horizontal
    vertical=(75, 105),        # 竖 vertical
    left_falling=(120, 165),   # 撇 left-falling
    right_falling=(30, 75),    # 捺 right-falling
    dot=None,                  # 点 dot (special-cased)
    turning=None,              # 折 turning (special-cased)
    hook=None,                 # 钩 hook (special-cased)
    rising=(-60, -15),         # 提 rising
)

# Strokes whose path length (pixels) is below this value are treated as noise.
MIN_STROKE_LENGTH = 3.0

# Direction change (degrees) used as StrokeSplitter's default split threshold.
ANGLE_CHANGE_THRESHOLD = 45.0

# Minimum spacing between sampled points.
# NOTE(review): not referenced by the code visible here - confirm it is used.
MIN_POINT_DISTANCE = 1.0
class StrokeType(IntEnum):
    """Integer codes for the basic stroke categories.

    UNKNOWN (0) is the fallback used when a stroke matches no other category.
    """
    UNKNOWN = 0         # unclassified
    HORIZONTAL = 1      # 横 horizontal
    VERTICAL = 2        # 竖 vertical
    LEFT_FALLING = 3    # 撇 left-falling
    RIGHT_FALLING = 4   # 捺 right-falling
    DOT = 5             # 点 dot
    TURNING = 6         # 折 turning / fold
    HOOK = 7            # 钩 hook
    RISING = 8          # 提 rising
@dataclass
class Point2D:
    """A single pen sample: 2-D position plus pressure and timestamp.

    ``pressure`` defaults to 0.5 and ``timestamp`` to 0. StrokeSplitter treats
    pressure <= 0.01 as a pen-up and compares timestamp differences against
    its ``time_gap_ms`` setting (so timestamps are presumably milliseconds).
    """
    x: float
    y: float
    pressure: float = field(default=0.5)
    timestamp: int = field(default=0)
@dataclass
class StrokeSegment:
    """One stroke together with its derived geometric features.

    Only ``points`` is required; StrokeAnalyzer.analyze populates the
    remaining fields.
    """
    points: List[Point2D]                                         # samples of this stroke
    stroke_type: StrokeType = field(default=StrokeType.UNKNOWN)   # classifier verdict
    direction_angle: float = field(default=0.0)                   # start->end angle, degrees
    length: float = field(default=0.0)                            # path length
    curvature: float = field(default=0.0)                         # path length / chord ratio
    avg_speed: float = field(default=0.0)                         # length per second
    start_point: Optional[Point2D] = field(default=None)          # first sample
    end_point: Optional[Point2D] = field(default=None)            # last sample
# ==================== 笔迹几何工具 ====================
class StrokeGeometry:
    """Stateless geometry helpers for pen-stroke point sequences."""

    @staticmethod
    def distance(p1: Point2D, p2: Point2D) -> float:
        """Return the Euclidean distance between ``p1`` and ``p2``."""
        # math.hypot avoids the intermediate overflow/underflow of sqrt(dx**2 + dy**2).
        return math.hypot(p2.x - p1.x, p2.y - p1.y)

    @staticmethod
    def angle_degrees(p1: Point2D, p2: Point2D) -> float:
        """Return the direction from ``p1`` to ``p2`` in degrees, within [-180, 180].

        0 degrees points along +x. With y growing downward (screen
        coordinates), positive angles turn clockwise.
        """
        return math.degrees(math.atan2(p2.y - p1.y, p2.x - p1.x))

    @staticmethod
    def path_length(points: List[Point2D]) -> float:
        """Return the summed length of consecutive segments (0.0 for < 2 points)."""
        return sum((StrokeGeometry.distance(a, b) for a, b in zip(points, points[1:])), 0.0)

    @staticmethod
    def curvature_ratio(points: List[Point2D]) -> float:
        """Return path length divided by the straight start-to-end distance.

        1.0 means a perfectly straight path; larger values mean more bending.
        The chord is floored at 0.001, so closed or near-closed paths yield
        very large ratios instead of dividing by zero. Returns 1.0 for fewer
        than two points.
        """
        if len(points) < 2:
            return 1.0
        path_len = StrokeGeometry.path_length(points)
        direct = StrokeGeometry.distance(points[0], points[-1])
        return path_len / max(direct, 0.001)

    @staticmethod
    def bounding_box(points: List[Point2D]) -> Tuple[float, float, float, float]:
        """Return ``(min_x, min_y, max_x, max_y)`` of ``points``.

        Raises:
            ValueError: if ``points`` is empty.
        """
        xs = [p.x for p in points]
        ys = [p.y for p in points]
        return min(xs), min(ys), max(xs), max(ys)

    @staticmethod
    def centroid(points: List[Point2D]) -> Point2D:
        """Return the arithmetic mean position of ``points``.

        Raises:
            ZeroDivisionError: if ``points`` is empty.
        """
        cx = sum(p.x for p in points) / len(points)
        cy = sum(p.y for p in points) / len(points)
        return Point2D(cx, cy)

    @staticmethod
    def resample(points: List[Point2D], n: int) -> List[Point2D]:
        """Resample a polyline into ``n`` points evenly spaced along its path.

        This is the preprocessing step used before comparing strokes.

        Args:
            points: The input polyline.
            n: Number of output points.

        Returns:
            ``n`` new points whose first and last coincide with the input's
            ends, with position and pressure linearly interpolated (timestamps
            are not carried over and default to 0). Degenerate inputs are not
            interpolated: ``[]`` when ``points`` is empty or ``n <= 0``, and
            ``points[:n]`` (the original objects) when there is a single input
            point or ``n == 1``.
        """
        if not points or n <= 0:
            return []
        if len(points) == 1 or n == 1:
            return points[:n]

        # Segment lengths computed once; seg_lens[k] is the length from
        # points[k] to points[k + 1].
        seg_lens = [StrokeGeometry.distance(a, b) for a, b in zip(points, points[1:])]
        total_len = sum(seg_lens, 0.0)
        interval = total_len / (n - 1)
        last = len(points) - 1

        first = points[0]
        resampled = [Point2D(first.x, first.y, first.pressure)]
        accumulated = 0.0  # path length from points[0] to points[j - 1]
        j = 1
        for i in range(1, n - 1):
            target = i * interval
            # Advance to the segment containing `target`, but never past the final
            # segment, so float rounding cannot run out of segments and return
            # fewer than n points.
            while j < last and accumulated + seg_lens[j - 1] < target:
                accumulated += seg_lens[j - 1]
                j += 1
            a, b = points[j - 1], points[j]
            ratio = (target - accumulated) / max(seg_lens[j - 1], 0.001)
            ratio = min(max(ratio, 0.0), 1.0)  # only matters for the rounding case above
            # Linear interpolation of position and pressure within the segment.
            resampled.append(Point2D(
                a.x + ratio * (b.x - a.x),
                a.y + ratio * (b.y - a.y),
                a.pressure + ratio * (b.pressure - a.pressure),
            ))
        end = points[-1]
        resampled.append(Point2D(end.x, end.y, end.pressure))
        return resampled
# ==================== 笔画拆分器 ====================
class StrokeSplitter:
    """Split a continuous stream of pen samples into stroke pieces.

    Three independent strategies are offered:

    1. Pen-up events (pressure <= 0.01 or a large timestamp gap).
    2. Abrupt direction changes (angle change above a threshold).
    3. Pauses (instantaneous speed below a fraction of the mean speed).
    """

    def __init__(self, angle_threshold: float = ANGLE_CHANGE_THRESHOLD,
                 time_gap_ms: int = 300, speed_ratio: float = 0.3):
        """
        Args:
            angle_threshold: Direction change in degrees that triggers a split
                in ``split_by_direction``.
            time_gap_ms: Timestamp gap above which ``split_by_penup`` treats
                two consecutive samples as belonging to different strokes.
            speed_ratio: Fraction of the mean speed below which
                ``split_by_speed`` treats a step as a pause.
        """
        self._angle_threshold = angle_threshold
        self._time_gap_ms = time_gap_ms
        self._speed_ratio = speed_ratio

    def split_by_penup(self, points: List[Point2D]) -> List[List[Point2D]]:
        """Split ``points`` into strokes at pen-up events.

        A sample starts a new stroke when its pressure is <= 0.01 or when its
        timestamp is more than ``time_gap_ms`` after the previous sample; that
        sample becomes the first point of the new stroke. Strokes with fewer
        than two points are dropped.
        """
        if not points:
            return []
        strokes = []
        current_stroke = [points[0]]
        for prev, point in zip(points, points[1:]):
            time_gap = point.timestamp - prev.timestamp
            is_penup = point.pressure <= 0.01 or time_gap > self._time_gap_ms
            if is_penup:
                # Always break here. A single-point stroke (e.g. a tap) must not be
                # glued to the next stroke across the pen-up gap; it is dropped,
                # consistent with the final length check below.
                if len(current_stroke) > 1:
                    strokes.append(current_stroke)
                current_stroke = [point]
            else:
                current_stroke.append(point)
        if len(current_stroke) > 1:
            strokes.append(current_stroke)
        return strokes

    def split_by_direction(self, points: List[Point2D]) -> List[List[Point2D]]:
        """Split ``points`` where the travel direction turns sharply (fold detection).

        The direction of each step is compared with that of the previous step.
        When the change exceeds ``angle_threshold`` degrees and the current
        piece already has more than two points, the piece is closed. The split
        point is shared: it ends one piece and starts the next.

        Inputs with fewer than three points are returned as a single piece
        (or ``[]`` when empty).
        """
        if len(points) < 3:
            return [points] if points else []
        segments = []
        current = [points[0]]
        prev_angle = StrokeGeometry.angle_degrees(points[0], points[1])
        for i in range(1, len(points)):
            current.append(points[i])
            if i + 1 < len(points):
                curr_angle = StrokeGeometry.angle_degrees(points[i], points[i + 1])
                angle_diff = abs(curr_angle - prev_angle)
                # Take the shorter way round, e.g. 179 vs -179 degrees differ by 2.
                if angle_diff > 180:
                    angle_diff = 360 - angle_diff
                if angle_diff > self._angle_threshold and len(current) > 2:
                    segments.append(current)
                    current = [points[i]]  # split point also starts the next piece
                prev_angle = curr_angle
        if len(current) > 1:
            segments.append(current)
        return segments

    def split_by_speed(self, points: List[Point2D]) -> List[List[Point2D]]:
        """Split ``points`` at pauses, i.e. where writing speed drops sharply.

        Each step's speed is distance / time gap * 1000 (distance per second if
        timestamps are milliseconds); the gap is floored at 1 so equal
        timestamps cannot divide by zero. A step slower than ``speed_ratio``
        times the mean speed closes the current piece, provided it already has
        more than three points; the pause point is shared by both pieces.

        Inputs with fewer than three points are returned as a single piece
        (or ``[]`` when empty).
        """
        if len(points) < 3:
            return [points] if points else []
        speeds = []
        for i in range(1, len(points)):
            dist = StrokeGeometry.distance(points[i - 1], points[i])
            dt = max(points[i].timestamp - points[i - 1].timestamp, 1)
            speeds.append(dist / dt * 1000)
        avg_speed = np.mean(speeds) if speeds else 0
        threshold = avg_speed * self._speed_ratio
        segments = []
        current = [points[0]]
        for i, speed in enumerate(speeds):
            current.append(points[i + 1])
            if speed < threshold and len(current) > 3:
                segments.append(current)
                current = [points[i + 1]]
        if len(current) > 1:
            segments.append(current)
        return segments
# ==================== 笔画类型分类器 ====================
class StrokeClassifier:
    """Rule-based stroke classifier.

    Chooses a StrokeType from a stroke's path length, its curvature ratio
    (path length / chord) and its overall start-to-end direction angle.
    """

    @staticmethod
    def classify(segment: List[Point2D]) -> StrokeType:
        """Return the StrokeType for one stroke segment.

        Rules are tried in order and the first match wins:

        * fewer than 2 points, or path shorter than 2 * MIN_STROKE_LENGTH -> DOT;
        * curvature ratio above 2.0 -> HOOK when the segment has at least three
          points and its last step's angle lies strictly between -90 and -10
          degrees (up-right if y grows downward), otherwise TURNING;
        * otherwise the start-to-end angle is matched against fixed inclusive
          ranges, falling back to UNKNOWN.
        """
        if len(segment) < 2:
            return StrokeType.DOT
        if StrokeGeometry.path_length(segment) < MIN_STROKE_LENGTH * 2:
            return StrokeType.DOT

        if StrokeGeometry.curvature_ratio(segment) > 2.0:
            tail_angle = (StrokeGeometry.angle_degrees(segment[-2], segment[-1])
                          if len(segment) >= 3 else None)
            is_hook = tail_angle is not None and -90 < tail_angle < -10
            return StrokeType.HOOK if is_hook else StrokeType.TURNING

        # (type, inclusive angle ranges) checked in order; order matters where
        # ranges touch, e.g. 70 degrees is VERTICAL and -20 degrees is HORIZONTAL.
        angle_rules = (
            (StrokeType.HORIZONTAL, ((-20, 20),)),
            (StrokeType.VERTICAL, ((70, 110),)),
            (StrokeType.LEFT_FALLING, ((120, 170), (-170, -150))),
            (StrokeType.RIGHT_FALLING, ((25, 70),)),
            (StrokeType.RISING, ((-65, -20),)),
        )
        overall = StrokeGeometry.angle_degrees(segment[0], segment[-1])
        for kind, ranges in angle_rules:
            if any(low <= overall <= high for low, high in ranges):
                return kind
        return StrokeType.UNKNOWN
# ==================== 笔迹相似度计算 ====================
class StrokeSimilarity:
    """Trajectory similarity based on DTW (Dynamic Time Warping).

    Used to compare written strokes against template strokes (stroke-order
    checking and template matching).
    """

    @staticmethod
    def dtw_distance(seq1: List[Point2D], seq2: List[Point2D]) -> float:
        """Return the DTW distance between two point sequences.

        The local cost is the Euclidean distance between two points, and the
        result is the minimum accumulated cost over all monotonic alignments
        of ``seq1`` with ``seq2``. Returns ``inf`` if either sequence is empty.
        """
        n, m = len(seq1), len(seq2)
        if n == 0 or m == 0:
            return float('inf')
        # Pairwise cost matrix in one vectorised step instead of n * m
        # Python-level distance calls: cost[i, j] = |seq1[i] - seq2[j]|.
        xy1 = np.array([(p.x, p.y) for p in seq1], dtype=float)
        xy2 = np.array([(p.x, p.y) for p in seq2], dtype=float)
        diff = xy1[:, None, :] - xy2[None, :, :]
        cost = np.hypot(diff[..., 0], diff[..., 1])

        # acc[i, j]: best cost aligning seq1[:i] with seq2[:j]; row/col 0 are
        # sentinels so only acc[0, 0] is a valid start.
        acc = np.full((n + 1, m + 1), np.inf)
        acc[0, 0] = 0.0
        for i in range(1, n + 1):
            for j in range(1, m + 1):
                acc[i, j] = cost[i - 1, j - 1] + min(
                    acc[i - 1, j],      # advance seq1 only
                    acc[i, j - 1],      # advance seq2 only
                    acc[i - 1, j - 1],  # advance both (match)
                )
        return float(acc[n, m])

    @staticmethod
    def normalized_similarity(seq1: List[Point2D], seq2: List[Point2D],
                              resample_n: int = 32) -> float:
        """Return a similarity score in [0, 1], where 1 means identical trajectories.

        Both sequences are resampled to ``resample_n`` evenly spaced points,
        then scaled together by their shared bounding box, so relative position
        and size still count. They are then compared with DTW; the score is
        ``1 - dtw / resample_n`` clamped at 0 and rounded to 4 decimals.

        Returns 0.0 when ``resample_n < 1`` or when resampling yields no points.
        """
        # A non-positive point count would otherwise divide by zero / a negative
        # number and produce scores above 1.
        if resample_n < 1:
            return 0.0
        rs1 = StrokeGeometry.resample(seq1, resample_n)
        rs2 = StrokeGeometry.resample(seq2, resample_n)
        if not rs1 or not rs2:
            return 0.0
        min_x, min_y, max_x, max_y = StrokeGeometry.bounding_box(rs1 + rs2)
        # The larger extent maps to [0, 1]. The scale is at least 1.0, so boxes
        # smaller than one unit are not enlarged and a zero-size box cannot
        # divide by zero.
        scale = max(max_x - min_x, max_y - min_y, 1.0)
        norm1 = [Point2D((p.x - min_x) / scale, (p.y - min_y) / scale) for p in rs1]
        norm2 = [Point2D((p.x - min_x) / scale, (p.y - min_y) / scale) for p in rs2]
        dtw_dist = StrokeSimilarity.dtw_distance(norm1, norm2)
        similarity = max(0.0, 1.0 - dtw_dist / resample_n)
        return round(similarity, 4)
# ==================== 笔顺分析器(整合) ====================
class StrokeAnalyzer:
    """End-to-end stroke analysis pipeline.

    Combines StrokeSplitter, StrokeClassifier and StrokeSimilarity to split
    raw pen samples into strokes, classify them, and compare stroke
    sequences against a template.
    """

    def __init__(self):
        self._splitter = StrokeSplitter()
        self._classifier = StrokeClassifier()
        self._similarity = StrokeSimilarity()
        logger.info("笔顺分析器初始化完成")

    def analyze(self, raw_points: List[Point2D]) -> List[StrokeSegment]:
        """Turn raw pen samples into classified StrokeSegment objects.

        Steps: split at pen-up events, drop strokes whose path length is below
        MIN_STROKE_LENGTH (noise), classify each remaining stroke, and record
        its direction angle, length, curvature and average speed.

        Returns:
            One StrokeSegment per kept stroke, in input order. ``avg_speed``
            keeps its default unless the stroke's last timestamp is later than
            its first.
        """
        strokes = self._splitter.split_by_penup(raw_points)
        segments = []
        for stroke_points in strokes:
            length = StrokeGeometry.path_length(stroke_points)  # computed once, reused below
            if length < MIN_STROKE_LENGTH:
                continue  # too short: treated as noise
            first, last = stroke_points[0], stroke_points[-1]
            seg = StrokeSegment(
                points=stroke_points,
                stroke_type=self._classifier.classify(stroke_points),
                direction_angle=StrokeGeometry.angle_degrees(first, last),
                length=length,
                curvature=StrokeGeometry.curvature_ratio(stroke_points),
                start_point=first,
                end_point=last,
            )
            if last.timestamp > first.timestamp:
                # The /1000.0 converts to seconds, assuming millisecond timestamps.
                time_s = (last.timestamp - first.timestamp) / 1000.0
                seg.avg_speed = seg.length / max(time_s, 0.001)
            segments.append(seg)
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug("笔迹分析完成: %d个原始点 → %d个笔画", len(raw_points), len(segments))
        return segments

    def compare_stroke_orders(self, user_strokes: List[List[Point2D]],
                              template_strokes: List[List[Point2D]]) -> Dict:
        """Compare user strokes with template strokes position by position.

        User stroke i is compared with template stroke i. Strokes beyond the
        shorter list are not compared, but every stroke of count difference
        subtracts 0.1 from the overall score.

        Returns:
            A dict with:
              ``overall_similarity``: mean per-stroke similarity minus the
                  count penalty, clamped at 0 and rounded to 4 decimals;
              ``stroke_matches``: list of dicts with ``stroke_index``
                  (1-based), ``similarity`` and ``match`` (similarity > 0.6);
              ``user_count`` / ``template_count``: number of strokes given.
        """
        match_results = []
        for index, (user, template) in enumerate(zip(user_strokes, template_strokes), start=1):
            sim = self._similarity.normalized_similarity(user, template)
            match_results.append({
                "stroke_index": index,
                "similarity": sim,
                "match": sim > 0.6,
            })
        compare_count = len(match_results)
        total_similarity = sum((r["similarity"] for r in match_results), 0.0)
        avg_similarity = total_similarity / max(compare_count, 1)
        count_penalty = abs(len(user_strokes) - len(template_strokes)) * 0.1
        return {
            "overall_similarity": round(max(0.0, avg_similarity - count_penalty), 4),
            "stroke_matches": match_results,
            "user_count": len(user_strokes),
            "template_count": len(template_strokes),
        }