353 lines
12 KiB
Python
353 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
自然写手写识别与AI分析引擎软件 V1.0
|
|
|
|
OCR识别接口模块
|
|
提供中英文手写文字OCR识别服务,基于PaddleOCR推理管道
|
|
"""
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from typing import List, Optional, Dict, Any
|
|
import numpy as np
|
|
import logging
|
|
import time
|
|
import uuid
|
|
|
|
# Router that the OCR endpoints in this module are registered on.
router = APIRouter()
# Module logger, named under the engine's logger hierarchy.
logger = logging.getLogger("writech-ai-engine.ocr")
|
|
|
|
|
|
# ==================== 请求/响应模型定义 ====================
|
|
|
|
class StrokePoint(BaseModel):
    """One sampled pen point: position, pressure, timestamp and pen-up flag."""

    # Both coordinates are validated to the unsigned 16-bit range 0..65535.
    x: int = Field(default=..., description="X坐标", ge=0, le=65535)
    y: int = Field(default=..., description="Y坐标", ge=0, le=65535)
    # Optional pressure sample, validated to 0..255; defaults to 0.
    pressure: int = Field(default=0, description="压力值", ge=0, le=255)
    # Required timestamp (milliseconds, per the field description); not range-checked.
    timestamp: int = Field(default=..., description="时间戳(毫秒)")
    # Pen-up marker; defaults to False.
    pen_up: bool = Field(default=False, description="抬笔标记")
|
|
|
|
|
|
class OCRRequest(BaseModel):
    """Body of an OCR request: stroke data plus optional metadata and options."""

    # Required; outer list = strokes, inner list = points of one stroke.
    strokes: List[List[StrokePoint]] = Field(default=..., description="笔迹数据(按笔画分组)")
    # Optional identifiers of the dot-pattern page and the pen device.
    page_id: Optional[str] = Field(default=None, description="点阵码页面ID")
    pen_id: Optional[str] = Field(default=None, description="笔设备ID")
    # Free-form strings; the expected values are listed in the descriptions
    # but are not enforced by this model.
    language: str = Field(default="zh", description="识别语言: zh/en/mixed")
    recognition_mode: str = Field(default="line", description="识别模式: char/word/line/page")
|
|
|
|
|
|
class CharDetail(BaseModel):
    """Recognition details for a single character."""

    char: str = Field(default=..., description="识别的字符")
    # Confidence score, described as lying in 0..1 (not validated here).
    confidence: float = Field(default=..., description="置信度(0-1)")
    # Bounding box as [x1, y1, x2, y2]; length is not validated here.
    bbox: List[int] = Field(default=..., description="包围框[x1,y1,x2,y2]")
    # Indices of the strokes that make up this character; empty by default.
    stroke_indices: List[int] = Field(description="对应的笔画索引", default=[])
|
|
|
|
|
|
class OCRResult(BaseModel):
    """One recognized text unit (character, word or line) with its details."""

    text: str = Field(default=..., description="识别文本")
    # Overall confidence, described as lying in 0..1 (not validated here).
    confidence: float = Field(default=..., description="整体置信度(0-1)")
    # Bounding box of the text region; empty list when not provided.
    bbox: List[int] = Field(description="文本区域包围框", default=[])
    # Per-character breakdown; empty list when not provided.
    char_details: List[CharDetail] = Field(description="逐字详情", default=[])
|
|
|
|
|
|
class OCRResponse(BaseModel):
    """Response envelope: status code, message and an optional JSON-object payload."""

    code: int = Field(default=200)
    msg: str = Field(default="success")
    # Arbitrary string-keyed payload; None when there is nothing to return.
    data: Optional[Dict[str, Any]] = Field(default=None)
|
|
|
|
|
|
# ==================== OCR 推理引擎 ====================
|
|
|
|
class OCREngine:
    """
    PaddleOCR inference engine.

    Intended pipeline:
        stroke coordinates -> preprocessing (normalization + rendering)
        -> model inference (OCR) -> post-processing (confidence filtering)
        -> results

    NOTE(review): model loading and inference are still placeholders:
    ``load_model`` does not load a model, and ``recognize`` post-processes the
    fixed output of ``_mock_inference``.

    Recognition modes handled by post-processing:
        - char: one result per entry of the ``"chars"`` predictions
        - word: one result per entry of the ``"words"`` predictions, falling
          back to the ``"lines"`` predictions when no word grouping exists
        - line: one result per entry of the ``"lines"`` predictions (default)
        - page: handled exactly like ``line``
    """

    def __init__(self):
        """Set engine defaults; a model is expected to be attached via ``load_model``."""
        self.model = None
        self.model_version = "1.0.0"
        self.is_loaded = False
        # Size (height x width) of the image that strokes are rendered into.
        self.input_height = 48
        self.input_width = 320
        # Predictions with a confidence below this value are dropped.
        self.confidence_threshold = 0.5
        logger.info("OCR引擎初始化完成")

    def load_model(self, model_path: str):
        """
        Load the PaddleOCR model from ``model_path``.

        The model file is meant to be stored AES-256 encrypted and decrypted
        in memory at load time.

        NOTE(review): decryption and loading are not implemented yet. This
        method only sets ``is_loaded``; ``self.model`` stays None.
        """
        logger.info(f"加载OCR模型: {model_path}")
        # TODO: decrypt and load the real model, e.g.
        #   decrypted_model = self._decrypt_model(model_path)
        #   self.model = paddle.jit.load(decrypted_model)
        self.is_loaded = True
        logger.info("OCR模型加载完成")

    def preprocess_strokes(self, strokes: "List[List[StrokePoint]]") -> np.ndarray:
        """
        Render strokes into a normalized grayscale image for the model.

        Steps:
            1. Compute the bounding box of all points.
            2. Scale uniformly (aspect ratio preserved, 10% margin) so the
               ink fits the ``input_height`` x ``input_width`` canvas. The
               ink is anchored at the top-left corner.
            3. Draw each segment between consecutive points of a stroke. A
               stroke with a single point (a tap or dot) is drawn as a dot
               instead of being dropped.
            4. Scale pixel values into [0, 1].

        NOTE(review): no denoising / jitter filtering is performed yet.

        Args:
            strokes: points grouped by stroke; each point exposes ``x``,
                ``y`` and ``pressure``.

        Returns:
            float32 array of shape (1, input_height, input_width). It is all
            zeros when there are no points.
        """
        all_points = [(p.x, p.y) for stroke in strokes for p in stroke]
        if not all_points:
            return np.zeros((1, self.input_height, self.input_width), dtype=np.float32)

        xs, ys = zip(*all_points)
        min_x, min_y = min(xs), min(ys)
        # Clamp to 1 so degenerate (flat or single-point) input does not divide by zero.
        width = max(max(xs) - min_x, 1)
        height = max(max(ys) - min_y, 1)
        # One scale for both axes keeps the aspect ratio; 0.9 leaves a margin.
        scale = min(self.input_width / width, self.input_height / height) * 0.9

        canvas = np.zeros((self.input_height, self.input_width), dtype=np.float32)

        def to_canvas(point) -> tuple:
            """Map a raw point to integer canvas coordinates."""
            return int((point.x - min_x) * scale), int((point.y - min_y) * scale)

        for stroke in strokes:
            # A one-point stroke has no segment of its own. Pair the point with
            # itself so it is still drawn (as a dot).
            segments = zip(stroke, stroke[1:]) if len(stroke) > 1 else zip(stroke, stroke)
            for start, end in segments:
                x1, y1 = to_canvas(start)
                x2, y2 = to_canvas(end)
                # Pressure 0..255 maps to a brush radius of 1..3 pixels.
                self._draw_line(canvas, x1, y1, x2, y2,
                                thickness=max(1, end.pressure // 85))

        peak = canvas.max()
        if peak > 0:
            canvas /= peak

        return canvas.reshape(1, self.input_height, self.input_width)

    def recognize(self, strokes: "List[List[StrokePoint]]",
                  mode: str = "line") -> "List[OCRResult]":
        """
        Run OCR on stroke data.

        Args:
            strokes: stroke data grouped by stroke.
            mode: recognition mode (char/word/line/page).

        Returns:
            List of ``OCRResult`` objects that passed the confidence filter.
            The list is empty for an unknown mode.
        """
        # perf_counter is monotonic, so the logged duration cannot go negative
        # or jump when the wall clock is adjusted.
        start_time = time.perf_counter()

        image = self.preprocess_strokes(strokes)

        # TODO: replace with real inference once load_model attaches a model:
        #   predictions = self.model(image)
        predictions = self._mock_inference(image, mode)

        results = self._postprocess(predictions, mode)

        inference_time = time.perf_counter() - start_time
        logger.info(f"OCR识别完成, mode={mode}, time={inference_time:.4f}s, "
                    f"results={len(results)}")

        return results

    def _postprocess(self, predictions: Dict, mode: str) -> "List[OCRResult]":
        """
        Convert raw prediction dicts into ``OCRResult`` objects.

        Entries whose ``confidence`` is below ``self.confidence_threshold``
        are dropped.
        - ``char``: one result per ``predictions["chars"]`` entry; its single
          char detail is the entry itself.
        - ``word``: one result per ``predictions["words"]`` entry. When the
          model provides no ``"words"`` key, ``predictions["lines"]`` is used.
        - ``line`` / ``page``: one result per ``predictions["lines"]`` entry.
        - Any other mode yields an empty list.
        """
        threshold = self.confidence_threshold

        if mode == "char":
            return [
                self._to_result(pred["char"], pred, [pred])
                for pred in predictions.get("chars", [])
                if pred["confidence"] >= threshold
            ]

        if mode == "word":
            # Previously "word" matched no branch and always produced nothing.
            groups = predictions.get("words")
            if groups is None:
                groups = predictions.get("lines", [])
        elif mode in ("line", "page"):
            groups = predictions.get("lines", [])
        else:
            return []

        return [
            self._to_result(pred["text"], pred, pred.get("char_details", []))
            for pred in groups
            if pred["confidence"] >= threshold
        ]

    @staticmethod
    def _to_result(text: str, pred: Dict, char_preds: List[Dict]) -> "OCRResult":
        """Build one OCRResult from a prediction dict and its per-character dicts."""
        return OCRResult(
            text=text,
            confidence=pred["confidence"],
            bbox=pred["bbox"],
            char_details=[
                CharDetail(
                    char=cd["char"],
                    confidence=cd["confidence"],
                    bbox=cd["bbox"],
                    stroke_indices=cd.get("stroke_indices", []),
                )
                for cd in char_preds
            ],
        )

    def _draw_line(self, canvas: np.ndarray, x1: int, y1: int,
                   x2: int, y2: int, thickness: int = 1):
        """
        Draw a line from (x1, y1) to (x2, y2) on ``canvas`` (Bresenham).

        Each visited pixel is stamped with a square brush of side
        ``2 * thickness + 1`` set to 1.0. The parts outside the canvas are
        clipped. ``canvas`` is modified in place.
        """
        h, w = canvas.shape
        dx = abs(x2 - x1)
        dy = abs(y2 - y1)
        sx = 1 if x1 < x2 else -1
        sy = 1 if y1 < y2 else -1
        err = dx - dy

        while True:
            # Clip the brush square to the canvas and write it as one slice.
            # The explicit emptiness check avoids negative slice bounds
            # wrapping around to the far side of the array.
            top, bottom = max(y1 - thickness, 0), min(y1 + thickness + 1, h)
            left, right = max(x1 - thickness, 0), min(x1 + thickness + 1, w)
            if top < bottom and left < right:
                canvas[top:bottom, left:right] = 1.0

            if x1 == x2 and y1 == y2:
                break
            e2 = 2 * err
            if e2 > -dy:
                err -= dy
                x1 += sx
            if e2 < dx:
                err += dx
                y1 += sy

    def _mock_inference(self, image: np.ndarray, mode: str) -> Dict:
        """Return fixed placeholder predictions; ``image`` and ``mode`` are ignored."""
        return {
            "lines": [{
                "text": "示例文字",
                "confidence": 0.95,
                "bbox": [10, 10, 200, 48],
                "char_details": [
                    {"char": "示", "confidence": 0.96, "bbox": [10, 10, 50, 48]},
                    {"char": "例", "confidence": 0.94, "bbox": [50, 10, 100, 48]},
                    {"char": "文", "confidence": 0.97, "bbox": [100, 10, 150, 48]},
                    {"char": "字", "confidence": 0.93, "bbox": [150, 10, 200, 48]},
                ],
            }],
            "chars": [],
        }

    def _decrypt_model(self, model_path: str) -> str:
        """
        Decrypt an AES-256 encrypted model file (placeholder).

        NOTE(review): not implemented; returns ``model_path`` unchanged.
        """
        # TODO: decrypt with the pre-configured key, e.g.
        #   key = settings.model_encryption_key
        #   cipher = AES.new(key, AES.MODE_CBC, iv)
        return model_path
|
|
|
|
|
|
# Module-wide OCR engine instance shared by the API routes.
ocr_engine: OCREngine = OCREngine()
|
|
|
|
|
|
# ==================== API 路由 ====================
|
|
|
|
@router.post("/recognize", response_model=OCRResponse)
async def recognize_text(request: OCRRequest):
    """
    Handwritten-text OCR endpoint: POST /api/v1/ocr/recognize.

    Accepts stroke coordinates and returns the recognized text with
    per-character details. Chinese, English and mixed input are the declared
    languages. ``language`` is echoed back in the response but is not passed
    to the engine.

    Raises:
        HTTPException(400): the request contains no stroke points, more than
            50000 points, or an unsupported ``recognition_mode``.
    """
    max_points = 50000
    supported_modes = ("char", "word", "line", "page")

    # Counting points (not strokes) also rejects e.g. [[]], which the old
    # "empty list" check let through.
    total_points = sum(len(stroke) for stroke in request.strokes)
    if total_points == 0:
        raise HTTPException(status_code=400, detail="笔迹数据不能为空")
    if total_points > max_points:
        raise HTTPException(status_code=400, detail=f"笔迹点数过多,最大支持{max_points}点")
    # An unknown mode would otherwise "succeed" with an empty result list.
    if request.recognition_mode not in supported_modes:
        raise HTTPException(status_code=400,
                            detail=f"不支持的识别模式: {request.recognition_mode}")

    # NOTE(review): recognize() runs synchronously inside an async handler,
    # so it blocks the event loop while it runs. Consider a plain `def`
    # endpoint or a threadpool once real inference is wired in.
    results = ocr_engine.recognize(
        strokes=request.strokes,
        mode=request.recognition_mode,
    )

    return OCRResponse(
        code=200,
        msg="success",
        data={
            "request_id": str(uuid.uuid4()),
            "language": request.language,
            "mode": request.recognition_mode,
            "results": [r.dict() for r in results],
            "total_chars": sum(len(r.text) for r in results),
        },
    )
|
|
|
|
|
|
@router.post("/batch-recognize")
async def batch_recognize(requests: List[OCRRequest]):
    """
    Batch OCR endpoint: recognize several independent stroke sets in one call.

    Each item is validated like a single /recognize request: at least one
    stroke point, at most 50000 points, and a supported ``recognition_mode``.
    All items are validated before any inference runs.

    Raises:
        HTTPException(400): an item fails validation. The detail names the
            item (1-based).
    """
    max_points = 50000
    supported_modes = ("char", "word", "line", "page")

    # Previously no per-item validation was done here at all.
    for number, req in enumerate(requests, start=1):
        total_points = sum(len(stroke) for stroke in req.strokes)
        if total_points == 0:
            raise HTTPException(status_code=400, detail=f"第{number}组笔迹数据不能为空")
        if total_points > max_points:
            raise HTTPException(status_code=400,
                                detail=f"第{number}组笔迹点数过多,最大支持{max_points}点")
        if req.recognition_mode not in supported_modes:
            raise HTTPException(status_code=400,
                                detail=f"第{number}组不支持的识别模式: {req.recognition_mode}")

    # NOTE(review): the batch size itself is unbounded; consider a limit.
    results = []
    for req in requests:
        recognized = ocr_engine.recognize(
            strokes=req.strokes,
            mode=req.recognition_mode,
        )
        results.append({
            "page_id": req.page_id,
            "results": [r.dict() for r in recognized],
        })

    return {
        "code": 200,
        "msg": "success",
        "data": {
            "batch_size": len(requests),
            "results": results,
        },
    }
|