software copyright
This commit is contained in:
@@ -0,0 +1,295 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自然写手写识别与AI分析引擎软件 V1.0
|
||||
|
||||
数学列式与公式识别接口
|
||||
支持四则运算、方程式、几何图形公式等数学内容识别
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import re
|
||||
|
||||
logger = logging.getLogger("writech-ai-engine.math")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class MathStrokePoint(BaseModel):
|
||||
"""数学笔迹坐标点"""
|
||||
x: int = Field(..., ge=0, le=65535)
|
||||
y: int = Field(..., ge=0, le=65535)
|
||||
pressure: int = Field(0, ge=0, le=255)
|
||||
timestamp: int = Field(...)
|
||||
pen_up: bool = Field(False)
|
||||
|
||||
|
||||
class MathRecognizeRequest(BaseModel):
|
||||
"""数学识别请求"""
|
||||
strokes: List[List[MathStrokePoint]] = Field(..., description="笔迹数据")
|
||||
math_type: str = Field("arithmetic", description="数学类型: arithmetic/equation/geometry")
|
||||
grade_level: int = Field(3, ge=1, le=6, description="年级(1-6)")
|
||||
|
||||
|
||||
class MathStep(BaseModel):
|
||||
"""计算步骤"""
|
||||
step_no: int = Field(..., description="步骤序号")
|
||||
expression: str = Field(..., description="表达式")
|
||||
result: Optional[str] = Field(None, description="计算结果")
|
||||
is_correct: bool = Field(True, description="是否正确")
|
||||
error_type: Optional[str] = Field(None, description="错误类型")
|
||||
error_detail: Optional[str] = Field(None, description="错误详情")
|
||||
|
||||
|
||||
class MathRecognizeResult(BaseModel):
|
||||
"""数学识别结果"""
|
||||
latex: str = Field(..., description="LaTeX表达式")
|
||||
result: Optional[str] = Field(None, description="计算结果")
|
||||
is_correct: bool = Field(True, description="答案是否正确")
|
||||
steps: List[MathStep] = Field(default=[], description="计算步骤")
|
||||
confidence: float = Field(..., description="识别置信度")
|
||||
|
||||
|
||||
class MathEngine:
|
||||
"""
|
||||
数学列式识别引擎
|
||||
|
||||
支持识别类型:
|
||||
- 四则运算(加减乘除、连续运算)
|
||||
- 竖式计算(加法竖式、减法竖式、乘法竖式、除法竖式)
|
||||
- 比较大小(>、<、=)
|
||||
- 分数运算
|
||||
- 简单方程(一元一次方程)
|
||||
|
||||
推理流程:
|
||||
笔迹 → 图像渲染 → 符号分割 → 符号识别 → 结构分析 → 表达式重建 → 计算验证
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.is_loaded = False
|
||||
# 支持的数学符号集合
|
||||
self.symbol_set = set("0123456789+-×÷=><()/.%")
|
||||
logger.info("数学识别引擎初始化完成")
|
||||
|
||||
def load_model(self, model_path: str):
|
||||
"""加载数学识别模型"""
|
||||
logger.info(f"加载数学识别模型: {model_path}")
|
||||
self.is_loaded = True
|
||||
logger.info("数学识别模型加载完成")
|
||||
|
||||
def recognize(self, strokes: List[List[MathStrokePoint]],
|
||||
math_type: str = "arithmetic",
|
||||
grade_level: int = 3) -> MathRecognizeResult:
|
||||
"""
|
||||
数学列式识别主流程
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# 步骤1:笔迹预处理与图像渲染
|
||||
image = self._preprocess_strokes(strokes)
|
||||
|
||||
# 步骤2:数学符号分割
|
||||
segments = self._segment_symbols(image)
|
||||
|
||||
# 步骤3:符号识别(CNN分类器)
|
||||
symbols = self._recognize_symbols(segments)
|
||||
|
||||
# 步骤4:结构分析(确定运算符和操作数的空间关系)
|
||||
structure = self._analyze_structure(symbols, math_type)
|
||||
|
||||
# 步骤5:表达式重建(生成LaTeX和数学表达式)
|
||||
latex_expr, math_expr = self._reconstruct_expression(structure)
|
||||
|
||||
# 步骤6:计算验证
|
||||
result, is_correct, steps = self._verify_calculation(math_expr, grade_level)
|
||||
|
||||
inference_time = time.time() - start_time
|
||||
logger.info(f"数学识别完成: latex={latex_expr}, correct={is_correct}, "
|
||||
f"time={inference_time:.4f}s")
|
||||
|
||||
return MathRecognizeResult(
|
||||
latex=latex_expr,
|
||||
result=result,
|
||||
is_correct=is_correct,
|
||||
steps=steps,
|
||||
confidence=0.92
|
||||
)
|
||||
|
||||
def _preprocess_strokes(self, strokes: List[List[MathStrokePoint]]) -> np.ndarray:
|
||||
"""笔迹预处理:坐标归一化 → 去噪 → 渲染为灰度图"""
|
||||
canvas_h, canvas_w = 64, 512
|
||||
canvas = np.zeros((canvas_h, canvas_w), dtype=np.float32)
|
||||
|
||||
all_x = [p.x for s in strokes for p in s]
|
||||
all_y = [p.y for s in strokes for p in s]
|
||||
if not all_x:
|
||||
return canvas
|
||||
|
||||
min_x, max_x = min(all_x), max(all_x)
|
||||
min_y, max_y = min(all_y), max(all_y)
|
||||
w = max(max_x - min_x, 1)
|
||||
h = max(max_y - min_y, 1)
|
||||
scale = min((canvas_w - 10) / w, (canvas_h - 10) / h)
|
||||
|
||||
for stroke in strokes:
|
||||
for i in range(1, len(stroke)):
|
||||
x1 = int((stroke[i-1].x - min_x) * scale + 5)
|
||||
y1 = int((stroke[i-1].y - min_y) * scale + 5)
|
||||
x2 = int((stroke[i].x - min_x) * scale + 5)
|
||||
y2 = int((stroke[i].y - min_y) * scale + 5)
|
||||
x1, x2 = np.clip([x1, x2], 0, canvas_w - 1)
|
||||
y1, y2 = np.clip([y1, y2], 0, canvas_h - 1)
|
||||
canvas[y1:y2+1, x1:x2+1] = 1.0
|
||||
|
||||
return canvas
|
||||
|
||||
def _segment_symbols(self, image: np.ndarray) -> List[Dict]:
|
||||
"""
|
||||
数学符号分割
|
||||
基于连通域分析将图像分割为独立的符号区域
|
||||
"""
|
||||
segments = []
|
||||
# 使用连通域分析进行符号分割
|
||||
# labels = cv2.connectedComponents(image)
|
||||
# 模拟分割结果
|
||||
segments = [
|
||||
{"bbox": [10, 5, 40, 55], "image": image[5:55, 10:40]},
|
||||
{"bbox": [45, 20, 65, 45], "image": image[20:45, 45:65]},
|
||||
{"bbox": [70, 5, 100, 55], "image": image[5:55, 70:100]},
|
||||
{"bbox": [105, 20, 125, 45], "image": image[20:45, 105:125]},
|
||||
{"bbox": [130, 5, 160, 55], "image": image[5:55, 130:160]},
|
||||
]
|
||||
return segments
|
||||
|
||||
def _recognize_symbols(self, segments: List[Dict]) -> List[Dict]:
|
||||
"""
|
||||
符号识别(CNN分类器)
|
||||
对每个分割区域进行数字/运算符分类
|
||||
"""
|
||||
symbols = []
|
||||
# 模拟识别结果
|
||||
mock_symbols = ["1", "2", "+", "3", "=", "1", "5"]
|
||||
for i, seg in enumerate(segments):
|
||||
if i < len(mock_symbols):
|
||||
symbols.append({
|
||||
"symbol": mock_symbols[i],
|
||||
"bbox": seg["bbox"],
|
||||
"confidence": 0.95 - i * 0.01
|
||||
})
|
||||
return symbols
|
||||
|
||||
def _analyze_structure(self, symbols: List[Dict], math_type: str) -> Dict:
|
||||
"""
|
||||
结构分析
|
||||
根据符号的空间位置关系确定数学表达式的结构
|
||||
处理竖式、分数线、括号等特殊结构
|
||||
"""
|
||||
# 按x坐标排序(从左到右阅读顺序)
|
||||
sorted_symbols = sorted(symbols, key=lambda s: s["bbox"][0])
|
||||
|
||||
if math_type == "arithmetic":
|
||||
return {"type": "linear", "symbols": sorted_symbols}
|
||||
elif math_type == "equation":
|
||||
return {"type": "equation", "symbols": sorted_symbols}
|
||||
else:
|
||||
return {"type": "unknown", "symbols": sorted_symbols}
|
||||
|
||||
def _reconstruct_expression(self, structure: Dict) -> tuple:
|
||||
"""
|
||||
表达式重建
|
||||
从结构化符号序列生成LaTeX表达式和可计算表达式
|
||||
"""
|
||||
symbols = structure.get("symbols", [])
|
||||
chars = [s["symbol"] for s in symbols]
|
||||
text = "".join(chars)
|
||||
|
||||
# 生成LaTeX
|
||||
latex = text.replace("×", "\\times ").replace("÷", "\\div ")
|
||||
|
||||
# 生成可计算表达式
|
||||
math_expr = text.replace("×", "*").replace("÷", "/")
|
||||
|
||||
return latex, math_expr
|
||||
|
||||
def _verify_calculation(self, math_expr: str, grade_level: int) -> tuple:
|
||||
"""
|
||||
计算验证
|
||||
解析数学表达式,计算正确答案,对比学生答案
|
||||
"""
|
||||
steps = []
|
||||
|
||||
# 尝试分离等号两侧
|
||||
if "=" in math_expr:
|
||||
parts = math_expr.split("=")
|
||||
if len(parts) == 2:
|
||||
left = parts[0].strip()
|
||||
right = parts[1].strip()
|
||||
|
||||
try:
|
||||
left_val = self._safe_eval(left)
|
||||
right_val = self._safe_eval(right)
|
||||
|
||||
steps.append(MathStep(
|
||||
step_no=1,
|
||||
expression=left,
|
||||
result=str(left_val),
|
||||
is_correct=True
|
||||
))
|
||||
|
||||
is_correct = abs(left_val - right_val) < 1e-9
|
||||
steps.append(MathStep(
|
||||
step_no=2,
|
||||
expression=f"{left} = {right}",
|
||||
result=str(right_val),
|
||||
is_correct=is_correct,
|
||||
error_type=None if is_correct else "calculation",
|
||||
error_detail=None if is_correct else f"正确答案应为{left_val}"
|
||||
))
|
||||
|
||||
return str(left_val), is_correct, steps
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None, True, steps
|
||||
|
||||
def _safe_eval(self, expr: str) -> float:
|
||||
"""安全计算表达式(仅允许数字和基本运算符)"""
|
||||
allowed_chars = set("0123456789.+-*/() ")
|
||||
if not all(c in allowed_chars for c in expr):
|
||||
raise ValueError(f"不安全的表达式: {expr}")
|
||||
return eval(expr) # 仅在安全校验后使用
|
||||
|
||||
|
||||
# 全局数学引擎实例
|
||||
math_engine = MathEngine()
|
||||
|
||||
|
||||
@router.post("/recognize")
|
||||
async def recognize_math(request: MathRecognizeRequest):
|
||||
"""
|
||||
数学列式/公式识别接口
|
||||
POST /api/v1/math/recognize
|
||||
"""
|
||||
if not request.strokes:
|
||||
raise HTTPException(status_code=400, detail="笔迹数据不能为空")
|
||||
|
||||
result = math_engine.recognize(
|
||||
strokes=request.strokes,
|
||||
math_type=request.math_type,
|
||||
grade_level=request.grade_level
|
||||
)
|
||||
|
||||
return {
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
"data": {
|
||||
"request_id": str(uuid.uuid4()),
|
||||
"result": result.dict()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user