software copyright

This commit is contained in:
jiahong
2026-03-22 15:24:40 +08:00
parent e303bb868a
commit 60f336e345
155 changed files with 127262 additions and 0 deletions
@@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
"""
自然写手写识别与AI分析引擎软件 V1.0
数学列式与公式识别接口
支持四则运算、方程式、几何图形公式等数学内容识别
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import numpy as np
import logging
import time
import uuid
import re
logger = logging.getLogger("writech-ai-engine.math")
router = APIRouter()
class MathStrokePoint(BaseModel):
"""数学笔迹坐标点"""
x: int = Field(..., ge=0, le=65535)
y: int = Field(..., ge=0, le=65535)
pressure: int = Field(0, ge=0, le=255)
timestamp: int = Field(...)
pen_up: bool = Field(False)
class MathRecognizeRequest(BaseModel):
"""数学识别请求"""
strokes: List[List[MathStrokePoint]] = Field(..., description="笔迹数据")
math_type: str = Field("arithmetic", description="数学类型: arithmetic/equation/geometry")
grade_level: int = Field(3, ge=1, le=6, description="年级(1-6)")
class MathStep(BaseModel):
"""计算步骤"""
step_no: int = Field(..., description="步骤序号")
expression: str = Field(..., description="表达式")
result: Optional[str] = Field(None, description="计算结果")
is_correct: bool = Field(True, description="是否正确")
error_type: Optional[str] = Field(None, description="错误类型")
error_detail: Optional[str] = Field(None, description="错误详情")
class MathRecognizeResult(BaseModel):
"""数学识别结果"""
latex: str = Field(..., description="LaTeX表达式")
result: Optional[str] = Field(None, description="计算结果")
is_correct: bool = Field(True, description="答案是否正确")
steps: List[MathStep] = Field(default=[], description="计算步骤")
confidence: float = Field(..., description="识别置信度")
class MathEngine:
"""
数学列式识别引擎
支持识别类型:
- 四则运算(加减乘除、连续运算)
- 竖式计算(加法竖式、减法竖式、乘法竖式、除法竖式)
- 比较大小(>、<、=
- 分数运算
- 简单方程(一元一次方程)
推理流程:
笔迹 → 图像渲染 → 符号分割 → 符号识别 → 结构分析 → 表达式重建 → 计算验证
"""
def __init__(self):
self.model = None
self.is_loaded = False
# 支持的数学符号集合
self.symbol_set = set("0123456789+-×÷=><()/.%")
logger.info("数学识别引擎初始化完成")
def load_model(self, model_path: str):
"""加载数学识别模型"""
logger.info(f"加载数学识别模型: {model_path}")
self.is_loaded = True
logger.info("数学识别模型加载完成")
def recognize(self, strokes: List[List[MathStrokePoint]],
math_type: str = "arithmetic",
grade_level: int = 3) -> MathRecognizeResult:
"""
数学列式识别主流程
"""
start_time = time.time()
# 步骤1:笔迹预处理与图像渲染
image = self._preprocess_strokes(strokes)
# 步骤2:数学符号分割
segments = self._segment_symbols(image)
# 步骤3:符号识别(CNN分类器)
symbols = self._recognize_symbols(segments)
# 步骤4:结构分析(确定运算符和操作数的空间关系)
structure = self._analyze_structure(symbols, math_type)
# 步骤5:表达式重建(生成LaTeX和数学表达式)
latex_expr, math_expr = self._reconstruct_expression(structure)
# 步骤6:计算验证
result, is_correct, steps = self._verify_calculation(math_expr, grade_level)
inference_time = time.time() - start_time
logger.info(f"数学识别完成: latex={latex_expr}, correct={is_correct}, "
f"time={inference_time:.4f}s")
return MathRecognizeResult(
latex=latex_expr,
result=result,
is_correct=is_correct,
steps=steps,
confidence=0.92
)
def _preprocess_strokes(self, strokes: List[List[MathStrokePoint]]) -> np.ndarray:
"""笔迹预处理:坐标归一化 → 去噪 → 渲染为灰度图"""
canvas_h, canvas_w = 64, 512
canvas = np.zeros((canvas_h, canvas_w), dtype=np.float32)
all_x = [p.x for s in strokes for p in s]
all_y = [p.y for s in strokes for p in s]
if not all_x:
return canvas
min_x, max_x = min(all_x), max(all_x)
min_y, max_y = min(all_y), max(all_y)
w = max(max_x - min_x, 1)
h = max(max_y - min_y, 1)
scale = min((canvas_w - 10) / w, (canvas_h - 10) / h)
for stroke in strokes:
for i in range(1, len(stroke)):
x1 = int((stroke[i-1].x - min_x) * scale + 5)
y1 = int((stroke[i-1].y - min_y) * scale + 5)
x2 = int((stroke[i].x - min_x) * scale + 5)
y2 = int((stroke[i].y - min_y) * scale + 5)
x1, x2 = np.clip([x1, x2], 0, canvas_w - 1)
y1, y2 = np.clip([y1, y2], 0, canvas_h - 1)
canvas[y1:y2+1, x1:x2+1] = 1.0
return canvas
def _segment_symbols(self, image: np.ndarray) -> List[Dict]:
"""
数学符号分割
基于连通域分析将图像分割为独立的符号区域
"""
segments = []
# 使用连通域分析进行符号分割
# labels = cv2.connectedComponents(image)
# 模拟分割结果
segments = [
{"bbox": [10, 5, 40, 55], "image": image[5:55, 10:40]},
{"bbox": [45, 20, 65, 45], "image": image[20:45, 45:65]},
{"bbox": [70, 5, 100, 55], "image": image[5:55, 70:100]},
{"bbox": [105, 20, 125, 45], "image": image[20:45, 105:125]},
{"bbox": [130, 5, 160, 55], "image": image[5:55, 130:160]},
]
return segments
def _recognize_symbols(self, segments: List[Dict]) -> List[Dict]:
"""
符号识别(CNN分类器)
对每个分割区域进行数字/运算符分类
"""
symbols = []
# 模拟识别结果
mock_symbols = ["1", "2", "+", "3", "=", "1", "5"]
for i, seg in enumerate(segments):
if i < len(mock_symbols):
symbols.append({
"symbol": mock_symbols[i],
"bbox": seg["bbox"],
"confidence": 0.95 - i * 0.01
})
return symbols
def _analyze_structure(self, symbols: List[Dict], math_type: str) -> Dict:
"""
结构分析
根据符号的空间位置关系确定数学表达式的结构
处理竖式、分数线、括号等特殊结构
"""
# 按x坐标排序(从左到右阅读顺序)
sorted_symbols = sorted(symbols, key=lambda s: s["bbox"][0])
if math_type == "arithmetic":
return {"type": "linear", "symbols": sorted_symbols}
elif math_type == "equation":
return {"type": "equation", "symbols": sorted_symbols}
else:
return {"type": "unknown", "symbols": sorted_symbols}
def _reconstruct_expression(self, structure: Dict) -> tuple:
"""
表达式重建
从结构化符号序列生成LaTeX表达式和可计算表达式
"""
symbols = structure.get("symbols", [])
chars = [s["symbol"] for s in symbols]
text = "".join(chars)
# 生成LaTeX
latex = text.replace("×", "\\times ").replace("÷", "\\div ")
# 生成可计算表达式
math_expr = text.replace("×", "*").replace("÷", "/")
return latex, math_expr
def _verify_calculation(self, math_expr: str, grade_level: int) -> tuple:
"""
计算验证
解析数学表达式,计算正确答案,对比学生答案
"""
steps = []
# 尝试分离等号两侧
if "=" in math_expr:
parts = math_expr.split("=")
if len(parts) == 2:
left = parts[0].strip()
right = parts[1].strip()
try:
left_val = self._safe_eval(left)
right_val = self._safe_eval(right)
steps.append(MathStep(
step_no=1,
expression=left,
result=str(left_val),
is_correct=True
))
is_correct = abs(left_val - right_val) < 1e-9
steps.append(MathStep(
step_no=2,
expression=f"{left} = {right}",
result=str(right_val),
is_correct=is_correct,
error_type=None if is_correct else "calculation",
error_detail=None if is_correct else f"正确答案应为{left_val}"
))
return str(left_val), is_correct, steps
except Exception:
pass
return None, True, steps
def _safe_eval(self, expr: str) -> float:
"""安全计算表达式(仅允许数字和基本运算符)"""
allowed_chars = set("0123456789.+-*/() ")
if not all(c in allowed_chars for c in expr):
raise ValueError(f"不安全的表达式: {expr}")
return eval(expr) # 仅在安全校验后使用
# 全局数学引擎实例
math_engine = MathEngine()
@router.post("/recognize")
async def recognize_math(request: MathRecognizeRequest):
"""
数学列式/公式识别接口
POST /api/v1/math/recognize
"""
if not request.strokes:
raise HTTPException(status_code=400, detail="笔迹数据不能为空")
result = math_engine.recognize(
strokes=request.strokes,
math_type=request.math_type,
grade_level=request.grade_level
)
return {
"code": 200,
"msg": "success",
"data": {
"request_id": str(uuid.uuid4()),
"result": result.dict()
}
}