2493 lines
93 KiB
Markdown
2493 lines
93 KiB
Markdown
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
## 软件著作权鉴别材料(技术设计说明书)
|
||
|
||
| 项目 | 内容 |
|
||
|------|------|
|
||
| 软件全称 | 自然写手写识别与AI分析引擎软件 |
|
||
| 软件简称 | 自然写AI引擎 |
|
||
| 版本号 | V1.0 |
|
||
| 权利人 | 深圳自然写科技有限公司 |
|
||
| 开发语言 | Python / C++ |
|
||
| 运行环境 | Linux服务器(GPU加速) |
|
||
| 文档类型 | 技术设计说明书 |
|
||
| 编制日期 | 2026年2月 |
|
||
|
||
---
|
||
|
||
## 目录
|
||
|
||
- 第一章 软件整体概述
|
||
- 1.1 软件简介与功能综述
|
||
- 1.2 软件用途与适用场景
|
||
- 1.3 运行环境与系统要求
|
||
- 1.4 开发语言与技术框架
|
||
- 1.5 版本说明
|
||
- 第二章 系统架构与设计思路
|
||
- 2.1 总体架构设计
|
||
- 2.2 推理管道设计
|
||
- 2.3 各层次模块说明
|
||
- 2.4 数据结构设计
|
||
- 2.5 接口设计
|
||
- 2.6 安全设计
|
||
- 2.7 部署架构
|
||
- 第三章 核心模块功能详细说明
|
||
- 3.1 中英文手写文字OCR识别模块
|
||
- 3.2 数学列式与公式识别模块
|
||
- 3.3 中文汉字笔顺识别与评分模块
|
||
- 3.4 书写质量评测模块
|
||
- 3.5 AI作文评分与批改模块
|
||
- 3.6 自动批改引擎模块
|
||
- 3.7 识别置信度评估模块
|
||
- 3.8 模型版本管理与热更新模块
|
||
- 3.9 推理任务调度模块
|
||
- 第四章 操作流程与使用步骤
|
||
- 4.1 服务部署与启动
|
||
- 4.2 模型加载与初始化
|
||
- 4.3 识别任务提交流程
|
||
- 4.4 模型更新与灰度发布流程
|
||
- 4.5 性能调优与故障排除
|
||
- 第五章 与源代码的对应关系
|
||
- 5.1 模块与源代码文件对应表
|
||
- 5.2 核心函数与方法说明
|
||
- 5.3 命名规范
|
||
- 附录
|
||
- 附录A 术语表
|
||
- 附录B 版本历史
|
||
|
||
---
|
||
|
||
# 第一章 软件整体概述
|
||
|
||
## 1.1 软件简介与功能综述
|
||
|
||
自然写手写识别与AI分析引擎软件(以下简称"AI引擎")是自然写互动课堂系统的智能化核心组件,负责对智能点阵笔采集的手写笔迹数据进行深度学习推理,实现手写文字识别、数学公式解析、笔顺评估、书写质量分析及作文智能评分等多项AI能力。
|
||
|
||
AI引擎基于PaddleOCR、PyTorch和ONNX Runtime等主流深度学习框架构建,在NVIDIA GPU集群上运行,通过NVIDIA Triton Inference Server进行模型管理和并发推理调度。AI引擎对外提供RESTful HTTP接口和gRPC高性能接口,供云平台后端和算力盒端侧推理调用。
|
||
|
||
**主要功能模块概述:**
|
||
|
||
中英文手写文字OCR识别:基于PaddleOCR技术,对点阵笔采集的手写笔迹坐标序列进行字符识别,支持简体中文、繁体中文和英文字母数字的混合识别,识别准确率在标准书写条件下达到96%以上。
|
||
|
||
数学列式与公式识别:专门针对K12教育场景中的数学手写内容,识别加减乘除四则运算、分数、小数、方程组、几何符号等数学元素,并对计算结果进行自动验证。
|
||
|
||
汉字笔顺识别与评分:通过分析笔画的书写顺序,与标准笔顺库进行比对,对学生书写的每个汉字给出笔顺正确性评分,帮助学生养成正确的书写习惯。
|
||
|
||
书写质量评测:从字体结构、笔画间距、整体规范性等多个维度对书写质量进行综合评分,提供具体的改进建议。
|
||
|
||
AI作文评分与批改:基于NLP模型对手写作文内容进行评分,从结构完整性、语言表达、内容丰富性、书写规范性四个维度给出综合评分,并标注典型错误位置。
|
||
|
||
选择题/填空题/简答题自动批改:根据题目类型和标准答案,对学生的作答内容进行自动批改,支持容错匹配(允许同义词、近义词等合理变体)。
|
||
|
||
## 1.2 软件用途与适用场景
|
||
|
||
**主要适用场景:**
|
||
|
||
(1)课堂作业自动批改:学生完成纸质作业后,点阵笔采集的笔迹数据上传至AI引擎,由引擎自动完成批改并给出成绩,大幅降低教师批改工作量,实现即时反馈。
|
||
|
||
(2)课堂互动实时识别:在课堂互动答题环节,学生在纸上作答,AI引擎在数百毫秒内完成识别,将识别结果实时推送至大屏展示,实现真正的"即写即知"。
|
||
|
||
(3)写字练习辅导:在写字课和书法练习场景中,AI引擎对学生的每个字进行笔顺评分和书写质量评测,配合大屏实时展示评分结果,形成即时纠正反馈循环。
|
||
|
||
(4)考试阅卷辅助:在期中期末考试场景中,AI引擎先对试卷进行初步批改,教师仅需对置信度低于阈值的题目进行人工复核,大幅提升阅卷效率。
|
||
|
||
(5)第三方教育平台赋能:通过SDK和API,将AI引擎能力输出至其他教育软件平台,使其获得手写识别和智能批改能力。
|
||
|
||
## 1.3 运行环境与系统要求
|
||
|
||
**服务端运行环境:**
|
||
|
||
| 组件 | 要求 |
|
||
|------|------|
|
||
| 操作系统 | Ubuntu 20.04 LTS / CentOS 7.6+ |
|
||
| Python版本 | Python 3.9+ |
|
||
| CUDA版本 | CUDA 11.8+ / CUDA 12.0+ |
|
||
| cuDNN版本 | cuDNN 8.6+ |
|
||
| GPU型号 | NVIDIA T4 / A10 / A100(推荐) |
|
||
| 内存要求 | 最低32GB,推荐64GB+(加载多个大模型) |
|
||
| 存储要求 | 200GB+ SSD(存储模型文件和临时推理数据) |
|
||
|
||
**主要依赖软件版本:**
|
||
|
||
| 依赖包 | 版本 | 用途 |
|
||
|-------|------|------|
|
||
| PaddlePaddle-GPU | 2.5.x | OCR基础框架 |
|
||
| PaddleOCR | 2.7.x | 手写文字识别 |
|
||
| PyTorch | 2.1.x | 数学识别和作文评分模型 |
|
||
| ONNX Runtime | 1.16.x | 跨框架模型推理 |
|
||
| FastAPI | 0.104.x | HTTP REST接口框架 |
|
||
| gRPC | 1.59.x | 高性能流式接口 |
|
||
| Celery | 5.3.x | 异步任务队列 |
|
||
| Redis | 7.0.x | 消息代理和结果缓存 |
|
||
| NVIDIA Triton | 23.10 | 模型服务化部署 |
|
||
| MLflow | 2.8.x | 模型版本管理 |
|
||
|
||
## 1.4 开发语言与技术框架
|
||
|
||
**Python技术栈(主要业务逻辑):**
|
||
|
||
- FastAPI:构建高性能异步HTTP接口,支持OpenAPI自动文档生成
|
||
- gRPC + protobuf:实现高吞吐量的流式识别接口
|
||
- Celery + Redis:构建分布式异步任务队列,处理批量识别请求
|
||
- PaddleOCR:手写OCR识别核心框架,基于PP-OCR v4模型
|
||
- PyTorch:数学公式识别和作文评分模型的推理框架
|
||
- ONNX Runtime:将训练好的模型转换为跨平台格式进行高效推理
|
||
|
||
**C++ 扩展模块(性能关键路径):**
|
||
|
||
- 笔迹坐标预处理(去噪、归一化)使用C++扩展实现,通过Python ctypes调用
|
||
- 图像渲染(将坐标序列渲染为图像供模型输入)使用OpenCV C++ API实现
|
||
|
||
**代码规范:**
|
||
|
||
- Python代码遵循PEP 8规范,使用Black格式化,flake8进行静态检查
|
||
- 函数注解使用Python类型标注(Type Hints),保证代码可读性
|
||
- 所有公开函数和类均编写docstring文档
|
||
|
||
## 1.5 版本说明
|
||
|
||
| 版本号 | 发布日期 | 说明 |
|
||
|-------|---------|------|
|
||
| V1.0 | 2026年2月 | 初始版本,包含OCR识别、数学识别、笔顺评分、书写质量评测、作文评分全功能 |
|
||
|
||
---
|
||
|
||
# 第二章 系统架构与设计思路
|
||
|
||
## 2.1 总体架构设计
|
||
|
||
AI引擎采用**推理管道(Pipeline)架构**,将识别任务分解为标准化的处理阶段,每个阶段独立可替换,便于模型升级和能力扩展。整体分为接口层、调度层、推理层、GPU管理层和模型仓库五个层次。
|
||
|
||
总体架构示意:
|
||
|
||
```
|
||
外部调用方(云平台 / 算力盒)
|
||
↓
|
||
┌──────────────────────────────────────────────────┐
|
||
│ 接口层 │
|
||
│ FastAPI REST(同步短任务) gRPC Server(流式批量)│
|
||
└──────────────────────────────────────────────────┘
|
||
↓
|
||
┌──────────────────────────────────────────────────┐
|
||
│ 调度层 │
|
||
│ Celery Worker × N + Redis Broker │
|
||
│ (按优先级调度:实时课堂 > 作业批改 > 批量) │
|
||
└──────────────────────────────────────────────────┘
|
||
↓
|
||
┌──────────────────────────────────────────────────┐
|
||
│ 推理层 │
|
||
│ OCR引擎 数学识别 笔顺分析 书写质量 作文评分 │
|
||
│ (PaddleOCR)(PyTorch)(自研模型)(自研NLP) │
|
||
└──────────────────────────────────────────────────┘
|
||
↓
|
||
┌──────────────────────────────────────────────────┐
|
||
│ GPU管理层 │
|
||
│ NVIDIA Triton Inference Server │
|
||
│ (多模型并发推理,GPU显存池化管理) │
|
||
└──────────────────────────────────────────────────┘
|
||
↓
|
||
┌──────────────────────────────────────────────────┐
|
||
│ 模型仓库 │
|
||
│ MinIO(模型文件存储)+ MLflow(版本管理) │
|
||
└──────────────────────────────────────────────────┘
|
||
```
|
||
|
||
## 2.2 推理管道设计
|
||
|
||
手写笔迹识别的完整推理管道包含以下标准化阶段:
|
||
|
||
**阶段1:输入预处理**
|
||
|
||
将原始笔迹坐标序列转换为模型可接受的输入格式。对于图像输入的OCR模型,需要将坐标序列渲染为灰度图像;对于序列输入的模型,需要进行坐标归一化和序列填充。
|
||
|
||
预处理步骤:
|
||
```
|
||
原始坐标序列 [{x, y, pressure, timestamp}, ...]
|
||
↓
|
||
坐标去噪(去除异常跳变点,使用Savitzky-Golay滤波器)
|
||
↓
|
||
坐标归一化(缩放至[0,1]范围,消除书写面积差异)
|
||
↓
|
||
笔画分割(根据笔离纸事件标志分割为独立笔画)
|
||
↓
|
||
图像渲染(将坐标序列绘制为固定分辨率的灰度图像,用于OCR模型输入)
|
||
↓
|
||
格式化为模型输入张量
|
||
```
|
||
|
||
**阶段2:模型推理**
|
||
|
||
根据识别任务类型,将处理后的输入数据分发至对应的推理模型:
|
||
|
||
| 识别任务 | 模型类型 | 输入格式 | 输出格式 |
|
||
|---------|---------|---------|---------|
|
||
| 手写文字OCR | PP-OCR v4(检测+识别) | 灰度图像 224×224 | 文字字符串 + 置信度 |
|
||
| 数学公式识别 | Transformer-based | 笔画序列坐标 | LaTeX格式公式 |
|
||
| 笔顺评分 | LSTM序列分类 | 笔画顺序序列 | 笔顺正确性分数 |
|
||
| 书写质量评测 | CNN分类 | 单字图像 64×64 | 多维度质量分数 |
|
||
| 作文评分 | BERT+多任务学习 | 识别后的文字序列 | 各维度评分 |
|
||
|
||
**阶段3:后处理**
|
||
|
||
对模型原始输出进行格式化和质量过滤:
|
||
- 置信度过滤:低于阈值(默认0.7)的识别结果标记为"需人工确认"
|
||
- 结果合并:将多个字符的识别结果按位置关系合并为完整的词句
|
||
- 格式转换:将模型输出转换为标准化的JSON响应格式
|
||
- 错误恢复:推理异常时返回降级结果(如置信度为0的空结果)而非抛出异常
|
||
|
||
## 2.3 各层次模块说明
|
||
|
||
**接口层:**
|
||
|
||
接口层提供两种接入方式,适应不同的调用场景:
|
||
|
||
FastAPI REST接口用于云平台后端的同步调用,适合单次识别请求。接口为异步设计(async/await),在等待GPU推理时不阻塞服务器线程,理论上支持数千个并发连接。
|
||
|
||
gRPC接口用于高性能批量识别和流式识别场景。与云平台和算力盒之间的大量并发识别请求通过gRPC流式传输处理,gRPC基于HTTP/2协议,支持请求多路复用,网络利用率更高。
|
||
|
||
**调度层:**
|
||
|
||
Celery分布式任务队列负责管理识别请求的优先级和资源分配:
|
||
|
||
- 高优先级队列(realtime_queue):处理课堂互动场景的实时识别请求,延迟目标 < 500ms
|
||
- 中优先级队列(assignment_queue):处理作业批改请求,延迟目标 < 5s
|
||
- 低优先级队列(batch_queue):处理批量历史数据重新识别请求,无严格延迟要求
|
||
|
||
Celery Worker数量根据GPU资源动态调整,每个Worker绑定一个GPU推理进程。
|
||
|
||
**推理层:**
|
||
|
||
推理层为每种识别任务实现独立的引擎类,继承自统一的BaseEngine抽象基类:
|
||
|
||
```python
|
||
class BaseEngine:
|
||
def preprocess(self, stroke_data: StrokeData) -> Tensor
|
||
def infer(self, tensor: Tensor) -> RawResult
|
||
def postprocess(self, raw_result: RawResult) -> RecognitionResult
|
||
def recognize(self, stroke_data: StrokeData) -> RecognitionResult
|
||
```
|
||
|
||
各具体引擎类(OCREngine, MathEngine, StrokeOrderEngine等)分别实现上述接口,保持一致的调用协议。
|
||
|
||
## 2.4 数据结构设计
|
||
|
||
**输入数据结构(StrokeData):**
|
||
|
||
```python
|
||
@dataclass
|
||
class StrokePoint:
|
||
x: int # X坐标(点阵码坐标系,单位:0.01mm)
|
||
y: int # Y坐标
|
||
pressure: int # 笔压(0-255)
|
||
timestamp: int # 时间戳(毫秒)
|
||
pen_up: bool # 是否抬笔标志
|
||
|
||
@dataclass
|
||
class Stroke:
|
||
points: List[StrokePoint] # 单条笔画的坐标点列表
|
||
stroke_index: int # 笔画序号(从0开始)
|
||
|
||
@dataclass
|
||
class StrokeData:
|
||
strokes: List[Stroke] # 所有笔画列表
|
||
pen_id: str # 笔设备MAC地址
|
||
page_id: int # 对应点阵纸张页面ID
|
||
student_id: int # 学生ID
|
||
assignment_id: int # 作业ID
|
||
region_type: str # 书写区域类型(hanzi/math/text/essay)
|
||
```
|
||
|
||
**识别结果数据结构(RecognitionResult):**
|
||
|
||
```python
|
||
@dataclass
|
||
class BoundingBox:
|
||
x1: int; y1: int; x2: int; y2: int # 矩形边界坐标
|
||
|
||
@dataclass
|
||
class OCRResult:
|
||
text: str # 识别的文字内容
|
||
confidence: float # 识别置信度(0.0-1.0)
|
||
bbox: BoundingBox # 文字区域边界框
|
||
char_details: List[CharDetail] # 逐字详情(含每字的置信度)
|
||
|
||
@dataclass
|
||
class MathResult:
|
||
latex: str # LaTeX格式数学公式
|
||
display_formula: str # 可读展示格式
|
||
numeric_result: Optional[str] # 计算数值结果(若可计算)
|
||
is_correct: Optional[bool] # 是否答对(需标准答案对比)
|
||
steps: List[str] # 解题步骤列表
|
||
|
||
@dataclass
|
||
class StrokeOrderResult:
|
||
char: str # 被评估的汉字
|
||
written_order: List[int] # 学生实际书写的笔画顺序
|
||
correct_order: List[int] # 标准笔顺
|
||
score: int # 笔顺得分(0-100)
|
||
errors: List[StrokeOrderError] # 错误笔顺列表
|
||
|
||
@dataclass
|
||
class WritingQualityResult:
|
||
overall_score: int # 总体书写质量分(0-100)
|
||
structure_score: int # 字形结构分
|
||
proportion_score: int # 笔画比例分
|
||
regularity_score: int # 规范性分
|
||
suggestions: List[str] # 改进建议列表
|
||
|
||
@dataclass
|
||
class EssayResult:
|
||
total_score: int # 作文总分(满分100分)
|
||
structure_score: int # 结构完整性分
|
||
language_score: int # 语言表达分
|
||
content_score: int # 内容丰富性分
|
||
handwriting_score: int # 书写规范性分
|
||
error_marks: List[ErrorMark] # 错误标注位置列表
|
||
overall_comment: str # 总体评语
|
||
```
|
||
|
||
## 2.5 接口设计
|
||
|
||
**REST接口(FastAPI):**
|
||
|
||
| 接口名称 | HTTP方法 | 路径 | 请求体 | 响应体 | 说明 |
|
||
|---------|---------|-----|-------|-------|------|
|
||
| 文字OCR识别 | POST | /api/v1/ocr/recognize | StrokeData | OCRResult | 单次文字识别 |
|
||
| 数学公式识别 | POST | /api/v1/math/recognize | StrokeData | MathResult | 数学列式识别 |
|
||
| 笔顺评分 | POST | /api/v1/stroke-order/evaluate | StrokeData + char | StrokeOrderResult | 汉字笔顺评估 |
|
||
| 书写质量评测 | POST | /api/v1/writing/quality | StrokeData | WritingQualityResult | 书写质量分析 |
|
||
| 作文批改 | POST | /api/v1/essay/review | StrokeData | EssayResult | AI作文评分 |
|
||
| 批量识别(异步) | POST | /api/v1/recognize/batch | List[StrokeData] | task_id | 异步批量识别,返回任务ID |
|
||
| 查询任务结果 | GET | /api/v1/task/{task_id} | - | List[RecognitionResult] | 查询批量识别结果 |
|
||
| 模型状态 | GET | /api/v1/model/status | - | List[ModelStatus] | 所有已加载模型状态 |
|
||
|
||
**gRPC接口(高性能流式):**
|
||
|
||
```protobuf
|
||
service RecognitionService {
|
||
// 单次识别(Unary RPC)
|
||
rpc Recognize(RecognizeRequest) returns (RecognizeResponse);
|
||
|
||
// 流式批量识别(Client Streaming RPC)
|
||
rpc BatchRecognize(stream RecognizeRequest) returns (BatchRecognizeResponse);
|
||
|
||
// 实时课堂识别(Bidirectional Streaming RPC)
|
||
rpc StreamRecognize(stream RecognizeRequest) returns (stream RecognizeResponse);
|
||
}
|
||
|
||
message RecognizeRequest {
|
||
repeated StrokeProto strokes = 1;
|
||
string region_type = 2; // "ocr" / "math" / "stroke_order" / "essay"
|
||
int64 student_id = 3;
|
||
int64 assignment_id = 4;
|
||
string request_id = 5; // 请求幂等ID
|
||
}
|
||
|
||
message StrokeProto {
|
||
repeated PointProto points = 1;
|
||
int32 stroke_index = 2;
|
||
}
|
||
|
||
message PointProto {
|
||
int32 x = 1;
|
||
int32 y = 2;
|
||
int32 pressure = 3;
|
||
int64 timestamp = 4;
|
||
bool pen_up = 5;
|
||
}
|
||
```
|
||
|
||
## 2.6 安全设计
|
||
|
||
**服务间认证:**
|
||
|
||
AI引擎作为内部服务,不直接暴露至公网。服务间通信采用mTLS(双向TLS)认证,调用方需持有CA签发的客户端证书,AI引擎验证客户端证书合法性后才处理请求。
|
||
|
||
**输入安全校验:**
|
||
|
||
- 数据大小限制:单次识别请求的笔画数据不超过1MB,防止超大请求占用GPU资源
|
||
- 数据格式校验:使用Pydantic对输入数据进行严格类型校验,拒绝格式不合规的请求
|
||
- 超时控制:单次推理任务超时30秒自动终止,防止GPU资源被长期占用
|
||
|
||
**模型文件保护:**
|
||
|
||
- 模型文件以加密方式存储于MinIO,下载时使用签名URL,有效期1小时
|
||
- 模型加载到Triton Server后,在GPU显存中的模型参数不允许外部读取
|
||
- 所有模型版本信息记录至MLflow,支持版本追溯和合规审计
|
||
|
||
**数据隐私:**
|
||
|
||
- 学生笔迹数据在AI引擎服务中仅用于推理,不持久化存储
|
||
- 识别完成后临时数据(预处理图像、中间张量)立即从内存清除
|
||
- 所有识别请求的调用日志仅记录任务ID和耗时,不记录原始笔迹内容
|
||
|
||
## 2.7 部署架构
|
||
|
||
**GPU集群部署方案:**
|
||
|
||
```
|
||
识别请求入口(内部网络)
|
||
↓
|
||
NVIDIA Triton Inference Server集群
|
||
┌────────────────────────────┐
|
||
│ GPU服务器节点1 │
|
||
│ ┌─────────────────────┐ │
|
||
│ │ Triton Server进程 │ │
|
||
│ │ 模型1: PP-OCR v4 │ │
|
||
│ │ 模型2: MathNet │ │
|
||
│ │ 模型3: StrokeOrder │ │
|
||
│ └─────────────────────┘ │
|
||
│ NVIDIA T4 GPU × 2 │
|
||
└────────────────────────────┘
|
||
┌────────────────────────────┐
|
||
│ GPU服务器节点N(同上结构) │
|
||
└────────────────────────────┘
|
||
↓
|
||
Celery调度层(CPU服务器)
|
||
┌─────────────────────────────────────────┐
|
||
│ Celery Worker × 8(每个Worker对应一个GPU)│
|
||
│ Redis Broker(任务队列) │
|
||
└─────────────────────────────────────────┘
|
||
↓
|
||
FastAPI/gRPC接口层(CPU服务器,多副本)
|
||
```
|
||
|
||
**模型热更新流程:**
|
||
|
||
新模型上线采用金丝雀发布策略:
|
||
1. 新模型文件上传至MinIO并在MLflow注册新版本
|
||
2. Triton Server加载新模型版本,保留旧版本(双版本并行运行)
|
||
3. 通过路由配置将5%的请求路由至新版本模型(金丝雀流量)
|
||
4. 观察新版本模型的识别准确率和推理延迟指标(观察期24小时)
|
||
5. 指标正常则逐步将流量从5%提升至50%、100%,完成版本切换
|
||
6. 旧版本模型保留7天后从Triton中卸载,释放GPU显存
|
||
|
||
---
|
||
|
||
# 第三章 核心模块功能详细说明
|
||
|
||
## 3.1 中英文手写文字OCR识别模块
|
||
|
||
**模块文件:** `engine/ocr_engine.py`
|
||
|
||
**功能概述:**
|
||
|
||
OCR识别模块基于PaddleOCR的PP-OCR v4模型,对手写笔迹进行文字识别,支持简体中文、繁体中文、英文字母、阿拉伯数字的混合识别。针对点阵笔书写场景的特点(笔迹细、书写速度快、字体不规范),对通用OCR模型进行了针对性微调。
|
||
|
||
**处理流程:**
|
||
|
||
```
|
||
步骤1:接收StrokeData(笔画坐标序列)
|
||
步骤2:调用预处理模块(preprocessing/stroke_preprocessor.py)
|
||
- 去除噪声点(压力值异常的点)
|
||
- 坐标归一化(缩放至标准坐标系)
|
||
- 笔画平滑(贝塞尔曲线拟合)
|
||
步骤3:调用图像渲染器,将平滑后的笔画绘制为灰度图像
|
||
- 分辨率:640×480像素(A4纸比例)
|
||
- 线宽:根据压力值动态调整(2-5像素)
|
||
步骤4:将灰度图像发送至Triton Server(OCR检测模型)
|
||
- 检测模型:PP-OCRv4-det(文字区域检测)
|
||
- 输出:文字边界框列表(BBox坐标)
|
||
步骤5:裁剪各文字区域,发送至OCR识别模型
|
||
- 识别模型:PP-OCRv4-rec(字符序列识别)
|
||
- 输出:字符串序列 + CTC解码置信度
|
||
步骤6:后处理:合并相邻文字区域,生成完整识别结果
|
||
步骤7:返回OCRResult(含识别文本和置信度)
|
||
```
|
||
|
||
**微调策略:**
|
||
|
||
针对K12教育场景的手写特点,收集并标注了覆盖1-9年级所有常用汉字的手写样本10万余张,在PP-OCR预训练模型基础上进行微调,重点提升以下场景的识别准确率:
|
||
- 低年级学生字体不规范(笔画变形、偏旁错位)
|
||
- 书写速度快导致的笔画粘连
|
||
- 铅笔书写(笔迹较轻,对比度低)
|
||
|
||
**性能指标:**
|
||
|
||
| 指标 | 目标值 | 实测值 |
|
||
|------|-------|-------|
|
||
| 单字识别准确率(标准书写) | ≥ 96% | 97.3% |
|
||
| 单字识别准确率(低年级书写) | ≥ 90% | 91.8% |
|
||
| 单次识别延迟(单字) | ≤ 200ms | 约120ms(T4 GPU) |
|
||
| 单次识别延迟(整页约50字) | ≤ 1s | 约600ms(T4 GPU) |
|
||
|
||
## 3.2 数学列式与公式识别模块
|
||
|
||
**模块文件:** `engine/math_engine.py`
|
||
|
||
**功能概述:**
|
||
|
||
数学识别模块专门处理K12阶段数学手写内容,支持从小学四则运算到初中方程、不等式等数学表达式的识别和解析,是AI引擎的差异化核心能力之一。
|
||
|
||
**支持识别的数学元素:**
|
||
|
||
| 类别 | 示例 | 说明 |
|
||
|------|------|------|
|
||
| 四则运算 | 123 + 456 = 579 | 加减乘除运算式和结果验证 |
|
||
| 分数 | 3/4 + 1/2 = 5/4 | 分子分母识别,通分计算 |
|
||
| 小数 | 3.14 × 2 = 6.28 | 小数点精确识别 |
|
||
| 方程 | 2x + 3 = 7 | 含未知数方程,求解验证 |
|
||
| 不等式 | x > 5 | 不等号识别 |
|
||
| 几何符号 | ∠ABC = 90° | 角度、平行、垂直等几何符号 |
|
||
| 数学函数 | sin30° = 0.5 | 三角函数(初中及以上) |
|
||
|
||
**识别与验证流程:**
|
||
|
||
```
|
||
步骤1:笔迹预处理(同OCR预处理流程)
|
||
步骤2:数学符号检测(定位各运算符和数字的位置和类型)
|
||
步骤3:结构解析(根据位置关系建立数学表达式的树形结构)
|
||
- 识别分数线(水平长横线)分隔分子分母
|
||
- 识别上下标(指数、角标)
|
||
- 识别括号嵌套层次
|
||
步骤4:转换为LaTeX格式表达式(如 \frac{3}{4} + \frac{1}{2})
|
||
步骤5:调用符号计算引擎(SymPy库)进行数值验证
|
||
- 对于含等号的算式:计算左右两边并比对
|
||
- 对于方程:求解并返回解
|
||
步骤6:生成MathResult(LaTeX格式 + 计算结果 + 正确与否)
|
||
```
|
||
|
||
## 3.3 中文汉字笔顺识别与评分模块
|
||
|
||
**模块文件:** `engine/stroke_order_engine.py`
|
||
|
||
**功能概述:**
|
||
|
||
笔顺评分模块通过分析学生书写汉字时的笔画顺序,与内置的汉字标准笔顺数据库进行比对,评估书写的规范程度。该模块利用了点阵笔数据的独特优势——每支笔画的书写时间戳信息,使得笔顺分析成为可能(传统OCR仅能识别最终图像,无法获取书写过程)。
|
||
|
||
**标准笔顺数据库:**
|
||
|
||
内置覆盖3500个常用汉字(国家语委《现代汉语常用字表》)的标准笔顺库,数据来源为教育部发布的《汉字笔顺规范》。每个汉字的笔顺以笔画序号列表形式存储,如"一"字的标准笔顺为[1](1画横),"人"字的标准笔顺为[1, 2](先撇后捺)。
|
||
|
||
**笔顺评分算法:**
|
||
|
||
```
|
||
步骤1:按时间戳对笔画序列排序,得到学生实际书写顺序
|
||
步骤2:通过笔画方向特征(起笔方向、运笔方向、收笔方式)识别每条笔画的类型
|
||
(横/竖/撇/捺/点/折等8种基本笔画类型)
|
||
步骤3:将识别的笔画类型序列与目标汉字的标准笔顺库匹配
|
||
步骤4:使用编辑距离算法(Levenshtein Distance)计算学生笔顺与标准笔顺的差异度
|
||
步骤5:计算笔顺得分:
|
||
- 完全正确:100分
|
||
- 1处错误:80分
|
||
- 2处错误:60分
|
||
- 3处及以上错误:40分或以下
|
||
步骤6:标注具体的错误位置和正确顺序,生成评语
|
||
```
|
||
|
||
**输出示例:**
|
||
|
||
```json
|
||
{
|
||
"char": "永",
|
||
"written_order": [1, 2, 4, 3, 5, 6, 7, 8],
|
||
"correct_order": [1, 2, 3, 4, 5, 6, 7, 8],
|
||
"score": 85,
|
||
"errors": [
|
||
{
|
||
"position": 3,
|
||
"written": "竖弯钩",
|
||
"expected": "横折折撇",
|
||
"suggestion": "第3笔应先写横折折撇(㇘),再写竖弯钩"
|
||
}
|
||
]
|
||
}
|
||
```
|
||
|
||
## 3.4 书写质量评测模块
|
||
|
||
**模块文件:** `engine/writing_quality_engine.py`
|
||
|
||
**功能概述:**
|
||
|
||
书写质量评测模块对学生书写的汉字从字形结构、笔画比例、书写规范性三个维度进行综合评测,帮助学生提升书写美观度和规范性,适用于写字课、书法课等场景。
|
||
|
||
**评测维度与算法:**
|
||
|
||
(1)字形结构评分(占总分40%)
|
||
|
||
将学生书写的汉字渲染为标准尺寸图像,与字体模板库(仿宋体/楷体)进行结构对比。使用基于深度学习的相似度模型,计算笔画布局和重心位置的偏差程度,偏差越小得分越高。
|
||
|
||
(2)笔画比例评分(占总分30%)
|
||
|
||
分析各笔画的相对长度和角度是否符合标准比例。如"土"字中,下横应明显长于上横;"口"字的宽高比应接近1:1等。使用规则引擎对常见汉字的关键比例进行检测。
|
||
|
||
(3)书写规范性评分(占总分30%)
|
||
|
||
评估书写是否符合国家规定的书写规范:
|
||
- 笔画是否有起笔和收笔动作(而非随意涂划)
|
||
- 相邻笔画间距是否均匀
|
||
- 整字的倾斜角度是否在合理范围(±15°以内)
|
||
|
||
## 3.5 AI作文评分与批改模块
|
||
|
||
**模块文件:** `engine/essay_engine.py`
|
||
|
||
**功能概述:**
|
||
|
||
作文评分模块首先调用OCR模块将手写作文转换为文字,然后基于BERT预训练语言模型(使用Chinese-BERT-wwm微调版本)从多个维度对作文进行智能评分,并标注错别字和明显语法错误的位置。
|
||
|
||
**评分维度:**
|
||
|
||
| 维度 | 权重 | 评测内容 |
|
||
|------|------|---------|
|
||
| 结构完整性 | 25% | 是否有开头/主体/结尾,段落划分是否合理 |
|
||
| 语言表达 | 30% | 用词是否准确,句式是否多样,是否存在语病 |
|
||
| 内容丰富性 | 30% | 内容是否切题,是否有具体事例,立意是否新颖 |
|
||
| 书写规范性 | 15% | 错别字数量,标点符号使用是否正确 |
|
||
|
||
**错别字检测:**
|
||
|
||
使用基于字音相似和字形相似的错别字检测模型,结合N-gram语言模型判断词语在上下文中的合理性,综合识别常见错别字类型:
|
||
- 音近字:(渴/喝,带/戴)
|
||
- 形近字:(己/已,土/士)
|
||
- 字义混淆:(的/地/得,其他/其它)
|
||
|
||
## 3.6 自动批改引擎模块
|
||
|
||
**模块文件:** `service/grading_service.py`
|
||
|
||
**功能概述:**
|
||
|
||
自动批改引擎针对结构化题目(选择题、填空题、简答题)进行自动批改,根据教师预设的标准答案和评分规则,对学生识别后的作答内容进行评判。
|
||
|
||
**批改规则类型:**
|
||
|
||
| 规则类型 | 说明 | 示例 |
|
||
|---------|------|------|
|
||
| 精确匹配 | 作答与标准答案完全一致(字符级) | 填写"北京",标准答案"北京" |
|
||
| 容错匹配 | 允许同义词和变体(由教师配置容错词库) | "首都"视为"北京"的等价答案 |
|
||
| 数值范围匹配 | 数值结果在允许误差范围内 | 计算结果允许±0.01的误差 |
|
||
| 关键词匹配 | 简答题包含所有关键词即得分 | 简答题含"光合作用/叶绿体/葡萄糖"三个关键词 |
|
||
| 部分给分 | 简答题按关键词命中数量比例给分 | 3个关键词各占1/3分数 |
|
||
|
||
## 3.7 识别置信度评估模块
|
||
|
||
**模块文件:** `service/confidence_service.py`
|
||
|
||
**功能概述:**
|
||
|
||
置信度评估模块对每个识别结果的可靠性进行量化评分,引导后续处理流程决定是否需要人工干预,是AI引擎质量控制的关键环节。
|
||
|
||
**置信度计算方法:**
|
||
|
||
最终置信度由多个维度的分数加权计算得出:
|
||
|
||
```python
|
||
final_confidence = (
|
||
model_confidence * 0.5 + # 模型本身的softmax置信度
|
||
stroke_density_score * 0.2 + # 笔画密度质量分(过疏或过密均降低)
|
||
writing_consistency * 0.3 # 书写一致性分(与学生历史书写风格对比)
|
||
)
|
||
```
|
||
|
||
置信度分级与处理策略:
|
||
|
||
| 置信度范围 | 级别 | 处理策略 |
|
||
|-----------|------|---------|
|
||
| 0.90 - 1.00 | 高置信 | 自动接受,无需人工审核 |
|
||
| 0.70 - 0.89 | 中置信 | 自动接受,但在批改结果中标记颜色提示教师关注 |
|
||
| 0.50 - 0.69 | 低置信 | 标记为"需人工确认",教师手动复核 |
|
||
| 0.00 - 0.49 | 不可信 | 标记为"识别失败",不计入自动批改成绩 |
|
||
|
||
## 3.8 模型版本管理与热更新模块
|
||
|
||
**模块文件:** `service/model_manager.py`
|
||
|
||
**功能概述:**
|
||
|
||
模型版本管理模块负责管理AI引擎中所有推理模型的版本生命周期,支持模型的注册、加载、切换、回滚和归档操作,确保模型更新过程不影响线上服务的稳定性。
|
||
|
||
**版本管理功能:**
|
||
|
||
(1)模型注册:新训练完成的模型通过MLflow API注册,记录模型名称、版本号、训练数据集版本、训练时间、评估指标(准确率/召回率/F1)等元数据。
|
||
|
||
(2)版本状态管理:每个模型版本有以下状态:
|
||
- Staging(待验证):模型已注册,正在测试评估中
|
||
- Production(生产中):当前线上使用版本
|
||
- Archived(已归档):已退出使用的历史版本
|
||
|
||
(3)模型加载:Triton Server在启动时读取模型配置文件,从MinIO下载对应版本的模型文件,加载到GPU显存。支持运行时动态加载新版本而不重启服务。
|
||
|
||
(4)灰度发布:通过配置路由权重,将一定比例的推理请求路由至新版本模型,实现金丝雀发布。
|
||
|
||
(5)快速回滚:若新版本模型上线后发现准确率下降或错误率异常升高,可在30秒内将流量100%切回旧版本模型,最大限度减少对教学的影响。
|
||
|
||
## 3.9 推理任务调度模块
|
||
|
||
**模块文件:** `service/task_scheduler.py`
|
||
|
||
**功能概述:**
|
||
|
||
任务调度模块基于Celery实现分布式任务队列,管理多种优先级的识别任务,协调GPU资源的公平分配,防止低优先级的批量任务占用全部GPU资源导致高优先级实时任务延迟。
|
||
|
||
**调度策略:**
|
||
|
||
```python
|
||
# Celery队列配置
|
||
CELERY_TASK_ROUTES = {
|
||
'tasks.realtime_recognize': {'queue': 'realtime'}, # 实时课堂识别
|
||
'tasks.assignment_recognize': {'queue': 'assignment'}, # 作业批改识别
|
||
'tasks.batch_recognize': {'queue': 'batch'}, # 批量历史识别
|
||
}
|
||
|
||
# 各队列的Worker预留数量
|
||
QUEUE_WORKER_RESERVATION = {
|
||
'realtime': 4, # 预留4个Worker专用于实时任务,不被其他任务占用
|
||
'assignment': 2, # 2个Worker用于作业批改
|
||
'batch': 2, # 2个Worker用于批量任务(空闲时使用所有剩余Worker)
|
||
}
|
||
```
|
||
|
||
---
|
||
|
||
# 第四章 操作流程与使用步骤
|
||
|
||
## 4.1 服务部署与启动
|
||
|
||
**基于Docker Compose的开发环境部署:**
|
||
|
||
```
|
||
步骤1:确认NVIDIA驱动和CUDA已正确安装
|
||
nvidia-smi(应显示GPU信息)
|
||
步骤2:确认nvidia-container-toolkit已安装(使Docker容器可访问GPU)
|
||
步骤3:从代码仓库拉取AI引擎代码
|
||
git clone https://git.writech.com/ai-engine.git
|
||
步骤4:进入项目目录,复制环境配置文件
|
||
cp .env.example .env
|
||
步骤5:编辑.env文件,配置以下关键参数:
|
||
REDIS_URL=redis://redis:6379/0
|
||
MODEL_STORE_PATH=/models
|
||
MINIO_ENDPOINT=minio:9000
|
||
MINIO_ACCESS_KEY=your_access_key
|
||
MINIO_SECRET_KEY=your_secret_key
|
||
步骤6:启动服务(包含Redis、MinIO、Triton Server、Celery Worker、FastAPI)
|
||
docker compose up -d
|
||
步骤7:检查服务启动状态
|
||
docker compose ps(各服务应为running或healthy)
|
||
步骤8:验证API可用性
|
||
curl http://localhost:8000/api/v1/model/status
|
||
```
|
||
|
||
**模型文件准备:**
|
||
|
||
```
|
||
步骤1:从MinIO或共享存储下载预训练模型文件
|
||
python scripts/download_models.py --version v1.0
|
||
步骤2:验证模型文件完整性(SHA256校验)
|
||
python scripts/verify_models.py
|
||
步骤3:将模型文件放置至正确目录结构:
|
||
/models/
|
||
├── ocr_det/ # OCR检测模型
|
||
│ └── 1/
|
||
│ └── model.onnx
|
||
├── ocr_rec/ # OCR识别模型
|
||
│ └── 1/
|
||
│ └── model.onnx
|
||
├── math_rec/ # 数学识别模型
|
||
│ └── 1/
|
||
│ └── model.pt
|
||
└── stroke_order/ # 笔顺评分模型
|
||
└── 1/
|
||
└── model.onnx
|
||
步骤4:Triton Server将自动扫描/models目录并加载所有模型
|
||
```
|
||
|
||
## 4.2 模型加载与初始化
|
||
|
||
**Triton Server模型配置文件示例(OCR识别模型):**
|
||
|
||
```
|
||
文件路径:/models/ocr_rec/config.pbtxt
|
||
|
||
name: "ocr_rec"
|
||
platform: "onnxruntime_onnx"
|
||
max_batch_size: 32
|
||
input [
|
||
{
|
||
name: "images"
|
||
data_type: TYPE_FP32
|
||
dims: [3, 48, -1] # 通道×高度×宽度(宽度可变)
|
||
}
|
||
]
|
||
output [
|
||
{
|
||
name: "output"
|
||
data_type: TYPE_FP32
|
||
dims: [-1, 97] # 序列长度×字符表大小
|
||
}
|
||
]
|
||
instance_group [
|
||
{
|
||
count: 2
|
||
kind: KIND_GPU
|
||
gpus: [0]
|
||
}
|
||
]
|
||
```
|
||
|
||
## 4.3 识别任务提交流程
|
||
|
||
**通过REST API提交单次OCR识别任务:**
|
||
|
||
```
|
||
HTTP请求示例:
|
||
POST http://ai-engine:8000/api/v1/ocr/recognize
|
||
Content-Type: application/json
|
||
Authorization: Bearer <service_token>
|
||
|
||
{
|
||
"strokes": [
|
||
{
|
||
"stroke_index": 0,
|
||
"points": [
|
||
{"x": 1200, "y": 800, "pressure": 150, "timestamp": 1700000000100, "pen_up": false},
|
||
{"x": 1250, "y": 800, "pressure": 145, "timestamp": 1700000000110, "pen_up": false},
|
||
{"x": 1300, "y": 800, "pressure": 140, "timestamp": 1700000000120, "pen_up": true}
|
||
]
|
||
}
|
||
],
|
||
"region_type": "ocr",
|
||
"student_id": 12345,
|
||
"assignment_id": 67890
|
||
}
|
||
|
||
期望响应(200 OK):
|
||
{
|
||
"code": 200,
|
||
"data": {
|
||
"text": "一",
|
||
"confidence": 0.99,
|
||
"bbox": {"x1": 1180, "y1": 780, "x2": 1320, "y2": 820},
|
||
"char_details": [
|
||
{"char": "一", "confidence": 0.99, "bbox": {...}}
|
||
]
|
||
},
|
||
"latency_ms": 125
|
||
}
|
||
```
|
||
|
||
**通过gRPC提交批量识别任务(Python示例):**
|
||
|
||
```python
|
||
import grpc
|
||
from proto import recognition_pb2, recognition_pb2_grpc
|
||
|
||
channel = grpc.secure_channel('ai-engine:50051', credentials)
|
||
stub = recognition_pb2_grpc.RecognitionServiceStub(channel)
|
||
|
||
# 构建请求列表
|
||
requests = []
|
||
for stroke_data in stroke_data_list:
|
||
request = recognition_pb2.RecognizeRequest(
|
||
strokes=[...],
|
||
region_type="ocr",
|
||
student_id=student_id,
|
||
assignment_id=assignment_id
|
||
)
|
||
requests.append(request)
|
||
|
||
# 发送流式批量识别请求(返回结果流)
|
||
def request_generator():
|
||
for req in requests:
|
||
yield req
|
||
|
||
responses = stub.BatchRecognize(request_generator())
|
||
results = list(responses.results)
|
||
```
|
||
|
||
## 4.4 模型更新与灰度发布流程
|
||
|
||
**发布新版OCR模型的操作步骤:**
|
||
|
||
```
|
||
步骤1:模型训练完成后,在开发环境测试评估新模型
|
||
python eval/evaluate_ocr.py --model-path models/ocr_rec_v1.1.onnx
|
||
(应输出:准确率 97.5%,高于当前生产版本的97.3%)
|
||
步骤2:将新模型文件上传至MinIO
|
||
python scripts/upload_model.py --model ocr_rec_v1.1.onnx --version 1.1
|
||
步骤3:在MLflow注册新模型版本
|
||
python scripts/register_model.py --name ocr_rec --version 1.1 --stage Staging
|
||
步骤4:在Triton Server加载新版本模型(不停机)
|
||
python scripts/triton_ops.py --action load --model ocr_rec --version 1.1
|
||
步骤5:配置5%金丝雀流量至新版本
|
||
python scripts/routing_config.py --model ocr_rec --new-version 1.1 --traffic 5
|
||
步骤6:观察监控面板(Grafana中AI引擎Dashboard)
|
||
- 新版本准确率指标(应高于旧版本)
|
||
- 新版本推理延迟(应与旧版本持平或更低)
|
||
- 新版本错误率(应接近0)
|
||
步骤7:确认指标正常后,逐步扩大流量比例(5% → 50% → 100%)
|
||
步骤8:新版本稳定后,更新MLflow中的模型状态为Production
|
||
步骤9:旧版本模型状态改为Archived,7天后从Triton中卸载
|
||
```
|
||
|
||
## 4.5 性能调优与故障排除
|
||
|
||
**常见性能问题排查:**
|
||
|
||
| 问题现象 | 可能原因 | 排查方法 | 处理方案 |
|
||
|---------|---------|---------|---------|
|
||
| 识别延迟突然升高 | GPU利用率过高或Celery队列积压 | 查看Grafana中GPU利用率和队列深度 | 增加Worker数量或限流 |
|
||
| 识别准确率下降 | 模型加载异常或输入数据格式变化 | 查看模型版本,对比历史准确率 | 重新加载模型或回滚版本 |
|
||
| gRPC连接超时 | 网络问题或服务重启 | 检查服务状态和网络连通性 | 重启gRPC服务,检查网络配置 |
|
||
| 内存OOM崩溃 | 模型太大或并发请求过多占用内存 | 查看服务器内存使用率 | 减少并发Worker数量,增加内存 |
|
||
| 特定字符识别率低 | 训练数据不足或模型偏差 | 统计错误字符频率分布 | 补充相应字符的训练样本后重训 |
|
||
|
||
**GPU资源监控指标:**
|
||
|
||
```
|
||
通过NVIDIA Management Library(NVML)监控以下关键指标:
|
||
- GPU利用率(%):正常范围60-80%,超过95%需要扩容
|
||
- GPU显存使用(MB):加载所有模型后约占8-12GB(T4显卡16GB)
|
||
- GPU温度(°C):正常范围60-75°C,超过85°C触发降频告警
|
||
- GPU功耗(W):T4正常功耗80-120W
|
||
```
|
||
|
||
---
|
||
|
||
# 第五章 与源代码的对应关系
|
||
|
||
## 5.1 模块名称与源代码文件对应表
|
||
|
||
| 功能模块 | 目录/文件路径 | 主要类/函数 | 说明 |
|
||
|---------|-------------|-----------|------|
|
||
| 应用程序入口 | `main.py` | `app`(FastAPI实例), `main()` | 服务启动,路由注册,中间件配置 |
|
||
| REST接口层 | `api/` | `ocr_router.py`, `math_router.py`, `essay_router.py`, `model_router.py` | 各识别类型的REST接口路由 |
|
||
| gRPC服务层 | `grpc_server/` | `recognition_server.py`(RecognitionService实现) | gRPC流式识别服务 |
|
||
| 笔迹预处理 | `preprocessing/` | `stroke_preprocessor.py`(StrokePreprocessor类) | 去噪、归一化、笔画分割、图像渲染 |
|
||
| OCR识别引擎 | `engine/` | `ocr_engine.py`(OCREngine类) | 基于PaddleOCR的文字识别 |
|
||
| 数学识别引擎 | `engine/` | `math_engine.py`(MathEngine类) | 数学公式识别和验证 |
|
||
| 笔顺评分引擎 | `engine/` | (BaseEngine子类,stroke_order_engine.py) | 汉字笔顺评分 |
|
||
| 任务调度服务 | `service/` | `task_scheduler.py`(Celery任务定义) | Celery任务队列,优先级调度 |
|
||
| 批改业务服务 | `service/` | `grading_service.py`(GradingService类) | 自动批改规则引擎 |
|
||
| 服务配置 | `config/` | `settings.py`(Settings类,Pydantic BaseSettings) | 环境变量配置加载 |
|
||
|
||
## 5.2 核心函数与方法说明
|
||
|
||
**main.py 核心函数:**
|
||
|
||
```python
|
||
# FastAPI应用实例
|
||
app = FastAPI(title="Writech AI Recognition Engine", version="1.0.0")
|
||
|
||
# 启动时加载所有模型
|
||
@app.on_event("startup")
|
||
async def startup_event():
|
||
await ModelManager.load_all_models()
|
||
await initialize_celery_workers()
|
||
|
||
# 注册路由
|
||
app.include_router(ocr_router, prefix="/api/v1/ocr")
|
||
app.include_router(math_router, prefix="/api/v1/math")
|
||
app.include_router(essay_router, prefix="/api/v1/essay")
|
||
app.include_router(model_router, prefix="/api/v1/model")
|
||
```
|
||
|
||
**OCREngine 核心方法:**
|
||
|
||
| 方法名 | 签名 | 功能说明 |
|
||
|-------|-----|---------|
|
||
| `preprocess` | `preprocess(stroke_data: StrokeData) -> np.ndarray` | 将笔迹坐标转换为灰度图像数组 |
|
||
| `infer` | `infer(image: np.ndarray) -> RawOCRResult` | 调用Triton进行OCR推理 |
|
||
| `postprocess` | `postprocess(raw: RawOCRResult) -> OCRResult` | 合并字符,过滤低置信度结果 |
|
||
| `recognize` | `recognize(stroke_data: StrokeData) -> OCRResult` | 完整OCR识别流程入口 |
|
||
|
||
**preprocessing/stroke_preprocessor.py 核心方法:**
|
||
|
||
| 方法名 | 功能说明 |
|
||
|-------|---------|
|
||
| `remove_noise_points(strokes)` | 去除异常跳变点(使用统计方法检测离群点) |
|
||
| `normalize_coordinates(strokes)` | 坐标归一化至[0,1]范围 |
|
||
| `split_strokes_by_pen_up(strokes)` | 根据pen_up标志将连续坐标序列分割为独立笔画 |
|
||
| `render_to_image(strokes, size)` | 将笔画列表渲染为指定尺寸的灰度图像(PIL/OpenCV) |
|
||
| `apply_bezier_smoothing(stroke)` | 对单条笔画应用贝塞尔曲线平滑,去除抖动 |
|
||
|
||
## 5.3 命名规范
|
||
|
||
**Python包命名规范:**
|
||
|
||
```
|
||
writech_ai_engine/
|
||
├── api/ # REST接口路由(功能名+_router.py)
|
||
├── config/ # 配置类(settings.py,单文件)
|
||
├── engine/ # 识别引擎类(功能名+_engine.py)
|
||
├── grpc_server/ # gRPC服务实现
|
||
├── preprocessing/ # 数据预处理
|
||
├── service/ # 业务服务层(功能名+_service.py)
|
||
└── model/ # 数据模型(Pydantic Schema类)
|
||
```
|
||
|
||
**类命名规范:**
|
||
|
||
| 类型 | 命名规则 | 示例 |
|
||
|------|---------|------|
|
||
| 识别引擎类 | XxxEngine | OCREngine, MathEngine, StrokeOrderEngine |
|
||
| 数据模型类 | XxxResult / XxxData | OCRResult, StrokeData, MathResult |
|
||
| 服务类 | XxxService | GradingService, ModelManagerService |
|
||
| 配置类 | XxxSettings / XxxConfig | Settings, TritonConfig |
|
||
| Celery任务 | 函数命名:xxx_task | ocr_recognize_task, essay_review_task |
|
||
|
||
---
|
||
|
||
# 附录
|
||
|
||
## 附录A 术语表
|
||
|
||
| 术语 | 说明 |
|
||
|------|------|
|
||
| PaddleOCR | 百度开源的OCR工具库,基于飞桨深度学习框架,提供文字检测、识别等完整OCR能力 |
|
||
| Triton Inference Server | NVIDIA提供的高性能模型服务框架,支持多模型并发推理和GPU资源调度 |
|
||
| ONNX | 开放神经网络交换格式(Open Neural Network Exchange),统一不同框架模型的表示格式 |
|
||
| MLflow | 开源的机器学习生命周期管理平台,提供实验追踪、模型注册、版本管理功能 |
|
||
| Celery | Python分布式任务队列框架,支持异步任务和定时任务 |
|
||
| CTC | 连接时序分类(Connectionist Temporal Classification),OCR解码算法之一 |
|
||
| BERT | 双向编码器表示(Bidirectional Encoder Representations from Transformers),NLP预训练模型 |
|
||
| mTLS | 双向TLS认证,通信双方均需出示证书,用于服务间高安全性认证 |
|
||
| 置信度 | 模型对识别结果的确信程度,取值0-1,值越大表示结果越可靠 |
|
||
| 金丝雀发布 | 灰度发布策略,新版本先接受少量流量,验证稳定后再全量切换 |
|
||
|
||
## 附录B 版本历史
|
||
|
||
| 版本号 | 发布日期 | 变更说明 |
|
||
|-------|---------|---------|
|
||
| V1.0 | 2026年2月 | 初始版本发布,支持OCR/数学/笔顺/书写质量/作文评分全套功能 |
|
||
|
||
---
|
||
|
||
**编制单位**:深圳自然写科技有限公司
|
||
**文档版本**:V1.0
|
||
**编制日期**:2026年2月
|
||
**版权声明**:本文档版权归深圳自然写科技有限公司所有,未经授权不得复制或传播
|
||
|
||
---
|
||
|
||
## 附录C AI 模型详细技术说明
|
||
|
||
### C.1 PaddleOCR 手写识别模型架构
|
||
|
||
#### C.1.1 模型总体架构
|
||
|
||
自然写 AI 引擎的手写识别模块基于 PaddleOCR 框架,采用 DB+CRNN 两阶段识别流水线:
|
||
|
||
```
|
||
输入:笔迹图像(灰度图,224×224)
|
||
│
|
||
▼
|
||
Stage 1: 文本检测(DB - Differentiable Binarization)
|
||
│ 检测文字区域边界框
|
||
│ 骨干网络:ResNet-50 + FPN 特征金字塔
|
||
▼
|
||
文字区域裁剪(透视变换矫正倾斜)
|
||
│
|
||
▼
|
||
Stage 2: 文字识别(CRNN - CNN+RNN)
|
||
│ CNN(VGG-like)提取特征列
|
||
│ 双向 LSTM 序列建模
|
||
│ CTC 解码(Connectionist Temporal Classification)
|
||
▼
|
||
输出:识别文字字符串 + 置信度
|
||
```
|
||
|
||
#### C.1.2 训练数据集
|
||
|
||
| 数据集 | 数量 | 来源 |
|
||
|-------|------|------|
|
||
| 学生楷书练字数据 | 500万字 | 自然写平台采集(脱敏处理) |
|
||
| 学生硬笔书法数据 | 200万字 | 自然写平台采集 |
|
||
| CASIA 手写中文数据集 | 300万字 | 中科院开放数据集 |
|
||
| 合成数据(字体变形)| 1000万字 | 程序合成 |
|
||
| 数字与字母 | 50万样本 | 多来源混合 |
|
||
|
||
**数据增强策略:**
|
||
- 随机旋转(±15°)
|
||
- 随机缩放(0.8x ~ 1.2x)
|
||
- 高斯噪声(模拟点阵摄像头图像噪声)
|
||
- 透视变换(模拟书写角度偏差)
|
||
- 随机笔画粗细变化
|
||
- 随机对比度调整(模拟不同笔压)
|
||
|
||
#### C.1.3 CRNN 模型实现
|
||
|
||
```python
|
||
# crnn_model.py
|
||
import paddle
|
||
import paddle.nn as nn
|
||
|
||
class CRNN(nn.Layer):
|
||
"""
|
||
手写文字序列识别模型
|
||
CNN 提取特征 + BiLSTM 序列建模 + CTC 解码
|
||
"""
|
||
def __init__(self, num_classes: int, hidden_size: int = 256):
|
||
super(CRNN, self).__init__()
|
||
|
||
# CNN 特征提取(输出特征尺寸:N × 512 × 1 × W)
|
||
self.cnn = nn.Sequential(
|
||
# Block 1: 64 filters
|
||
nn.Conv2D(1, 64, kernel_size=3, padding=1),
|
||
nn.BatchNorm2D(64),
|
||
nn.ReLU(),
|
||
nn.MaxPool2D(kernel_size=(2, 2), stride=(2, 2)), # 32x32 -> 16x16
|
||
|
||
# Block 2: 128 filters
|
||
nn.Conv2D(64, 128, kernel_size=3, padding=1),
|
||
nn.BatchNorm2D(128),
|
||
nn.ReLU(),
|
||
nn.MaxPool2D(kernel_size=(2, 2), stride=(2, 2)), # 16x16 -> 8x8
|
||
|
||
# Block 3: 256 filters(不在高度方向池化)
|
||
nn.Conv2D(128, 256, kernel_size=3, padding=1),
|
||
nn.BatchNorm2D(256),
|
||
nn.ReLU(),
|
||
nn.Conv2D(256, 256, kernel_size=3, padding=1),
|
||
nn.BatchNorm2D(256),
|
||
nn.ReLU(),
|
||
nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1)), # 8x8 -> 4xW
|
||
|
||
# Block 4: 512 filters
|
||
nn.Conv2D(256, 512, kernel_size=3, padding=1),
|
||
nn.BatchNorm2D(512),
|
||
nn.ReLU(),
|
||
nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1)), # 4xW -> 2xW
|
||
|
||
# Block 5: 512 filters,压缩高度为1
|
||
nn.Conv2D(512, 512, kernel_size=2, padding=0), # 2xW -> 1xW
|
||
nn.BatchNorm2D(512),
|
||
nn.ReLU(),
|
||
)
|
||
|
||
# BiLSTM 序列建模
|
||
self.rnn = nn.Sequential(
|
||
BidirectionalLSTM(512, hidden_size, hidden_size),
|
||
BidirectionalLSTM(hidden_size, hidden_size, num_classes),
|
||
)
|
||
|
||
def forward(self, x):
|
||
# x: (N, 1, H, W)
|
||
conv = self.cnn(x) # (N, 512, 1, W)
|
||
N, C, H, W = conv.shape
|
||
# 重塑为序列:(W, N, C)
|
||
conv = conv.squeeze(2).transpose([2, 0, 1])
|
||
output = self.rnn(conv) # (W, N, num_classes)
|
||
return output
|
||
|
||
|
||
class BidirectionalLSTM(nn.Layer):
|
||
def __init__(self, input_size, hidden_size, output_size):
|
||
super(BidirectionalLSTM, self).__init__()
|
||
self.lstm = nn.LSTM(input_size, hidden_size,
|
||
direction='bidirect', time_major=True)
|
||
self.linear = nn.Linear(hidden_size * 2, output_size)
|
||
|
||
def forward(self, x):
|
||
output, _ = self.lstm(x)
|
||
output = self.linear(output)
|
||
return output
|
||
```
|
||
|
||
#### C.1.4 CTC 解码算法
|
||
|
||
```python
|
||
# ctc_decoder.py
|
||
import numpy as np
|
||
|
||
class CTCDecoder:
|
||
"""
|
||
CTC(Connectionist Temporal Classification)解码器
|
||
支持贪心解码和束搜索解码
|
||
"""
|
||
|
||
def __init__(self, vocabulary: list[str], blank_token_id: int = 0):
|
||
self.vocabulary = vocabulary # 字典(汉字列表)
|
||
self.blank_id = blank_token_id
|
||
|
||
def greedy_decode(self, log_probs: np.ndarray) -> tuple[str, float]:
|
||
"""
|
||
贪心解码(取每帧最大概率字符)
|
||
|
||
log_probs: (T, V) - T帧,V个字符的对数概率
|
||
返回: (识别文字, 平均置信度)
|
||
"""
|
||
# 每帧取最大概率的字符索引
|
||
indices = np.argmax(log_probs, axis=1) # (T,)
|
||
probs = np.exp(np.max(log_probs, axis=1)) # 对应概率
|
||
|
||
# 合并重复字符并去除 blank
|
||
decoded_chars = []
|
||
decoded_probs = []
|
||
prev = -1
|
||
for i, idx in enumerate(indices):
|
||
if idx != prev and idx != self.blank_id:
|
||
decoded_chars.append(self.vocabulary[idx])
|
||
decoded_probs.append(probs[i])
|
||
prev = idx
|
||
|
||
text = ''.join(decoded_chars)
|
||
confidence = float(np.mean(decoded_probs)) if decoded_probs else 0.0
|
||
return text, confidence
|
||
|
||
def beam_search_decode(self, log_probs: np.ndarray,
|
||
beam_width: int = 10) -> list[tuple[str, float]]:
|
||
"""
|
||
束搜索解码(返回 top-k 候选字符串)
|
||
|
||
beam_width: 束宽度(返回候选数量)
|
||
返回: [(候选文字, 分数)] 列表(按分数降序)
|
||
"""
|
||
T, V = log_probs.shape
|
||
|
||
# 初始化:空字符串,得分为0
|
||
beams = [("", 0.0, -1)] # (text, score, last_token)
|
||
|
||
for t in range(T):
|
||
new_beams = {}
|
||
for text, score, last_token in beams:
|
||
# 扩展每个 beam
|
||
for v in range(V):
|
||
log_p = log_probs[t, v]
|
||
if v == self.blank_id:
|
||
# blank:保持文字不变,得分累加
|
||
new_text = text
|
||
elif v == last_token:
|
||
# 重复字符:只在中间有 blank 时才追加
|
||
new_text = text
|
||
else:
|
||
# 新字符
|
||
new_text = text + self.vocabulary[v]
|
||
|
||
key = (new_text, v)
|
||
if key not in new_beams:
|
||
new_beams[key] = float('-inf')
|
||
# log-sum-exp 合并相同结果
|
||
new_beams[key] = np.logaddexp(new_beams[key], score + log_p)
|
||
|
||
# 取 top beam_width
|
||
sorted_beams = sorted(new_beams.items(), key=lambda x: -x[1])[:beam_width]
|
||
beams = [(text, score, last_token) for (text, last_token), score in sorted_beams]
|
||
|
||
return [(text, np.exp(score)) for text, score, _ in beams]
|
||
```
|
||
|
||
---
|
||
|
||
### C.2 数学公式识别模型
|
||
|
||
#### C.2.1 模型架构概述
|
||
|
||
数学公式识别采用 Im2Latex 序列到序列模型(CNN + Attention LSTM),将手写数学表达式图像直接转换为 LaTeX 字符串:
|
||
|
||
```python
|
||
# math_ocr_model.py
|
||
class Im2LatexModel(nn.Layer):
|
||
"""
|
||
数学公式图像→LaTeX 序列模型
|
||
基于 Attention 机制的编解码架构
|
||
"""
|
||
|
||
def __init__(self, vocab_size: int, embed_size: int = 80,
|
||
decode_hidden: int = 512, encode_hidden: int = 256):
|
||
super(Im2LatexModel, self).__init__()
|
||
|
||
# 编码器:CNN(提取视觉特征)
|
||
self.encoder = nn.Sequential(
|
||
nn.Conv2D(1, 64, kernel_size=3, padding=1), nn.ReLU(),
|
||
nn.MaxPool2D(2, 2),
|
||
nn.Conv2D(64, 128, kernel_size=3, padding=1), nn.ReLU(),
|
||
nn.MaxPool2D(2, 2),
|
||
nn.Conv2D(128, 256, kernel_size=3, padding=1), nn.ReLU(),
|
||
nn.Conv2D(256, 256, kernel_size=3, padding=1), nn.ReLU(),
|
||
nn.MaxPool2D((2, 1)), # 保留水平分辨率
|
||
nn.Conv2D(256, encode_hidden, kernel_size=3, padding=1), nn.ReLU(),
|
||
)
|
||
|
||
# 解码器:LSTM + Attention
|
||
self.decoder = AttentionDecoder(
|
||
encode_hidden, decode_hidden, embed_size, vocab_size
|
||
)
|
||
|
||
def forward(self, images, targets=None):
|
||
# 编码图像特征
|
||
features = self.encoder(images) # (N, C, H', W')
|
||
# 解码序列(训练时使用 teacher forcing)
|
||
logits, attention_weights = self.decoder(features, targets)
|
||
return logits, attention_weights
|
||
```
|
||
|
||
#### C.2.2 数学算式语义校验
|
||
|
||
识别完数学表达式后,系统进行语义校验(判断等式/不等式是否成立):
|
||
|
||
```python
|
||
# math_validator.py
|
||
import re
|
||
from decimal import Decimal, InvalidOperation
|
||
|
||
class MathExpressionValidator:
|
||
"""
|
||
对识别到的数学表达式进行语义校验
|
||
支持:四则运算、分数、简单方程验证
|
||
"""
|
||
|
||
def validate(self, expression: str) -> dict:
|
||
"""
|
||
校验数学表达式
|
||
|
||
expression: 识别结果(如 "3 + 5 = 8" 或 "2 × 4 = 8")
|
||
返回: {'is_correct': bool, 'expected': str, 'explanation': str}
|
||
"""
|
||
# 标准化符号(×→*, ÷→/)
|
||
normalized = self._normalize_symbols(expression)
|
||
|
||
# 识别等号或不等号
|
||
if '=' in normalized:
|
||
return self._validate_equation(normalized)
|
||
elif '>' in normalized or '<' in normalized:
|
||
return self._validate_inequality(normalized)
|
||
else:
|
||
return {'is_correct': None, 'explanation': '无法判断(表达式不包含等式)'}
|
||
|
||
def _validate_equation(self, expr: str) -> dict:
|
||
parts = expr.split('=')
|
||
if len(parts) != 2:
|
||
return {'is_correct': False, 'explanation': '格式错误'}
|
||
|
||
left_expr, right_expr = parts[0].strip(), parts[1].strip()
|
||
|
||
try:
|
||
left_val = self._safe_eval(left_expr)
|
||
right_val = self._safe_eval(right_expr)
|
||
is_correct = abs(left_val - right_val) < 1e-9
|
||
|
||
return {
|
||
'is_correct': is_correct,
|
||
'left_value': str(left_val),
|
||
'right_value': str(right_val),
|
||
'expected': f"{left_expr} = {left_val}",
|
||
'explanation': '正确' if is_correct else f'等号左边={left_val},右边={right_val},不相等'
|
||
}
|
||
except Exception as e:
|
||
return {'is_correct': False, 'explanation': f'计算出错:{str(e)}'}
|
||
|
||
def _safe_eval(self, expr: str) -> Decimal:
|
||
"""安全的数学表达式求值(防注入攻击)"""
|
||
# 只允许数字、运算符、括号、小数点
|
||
if not re.match(r'^[\d\s\+\-\*\/\(\)\.]+$', expr):
|
||
raise ValueError(f"不安全的表达式:{expr}")
|
||
# 使用 Decimal 保证精度
|
||
return Decimal(str(eval(expr)))
|
||
|
||
def _normalize_symbols(self, expr: str) -> str:
|
||
return (expr.replace('×', '*').replace('÷', '/')
|
||
.replace('=', '=').replace('−', '-'))
|
||
```
|
||
|
||
---
|
||
|
||
### C.3 笔顺评分模型
|
||
|
||
#### C.3.1 笔顺检测算法
|
||
|
||
```python
|
||
# stroke_order_evaluator.py
|
||
import numpy as np
|
||
from dataclasses import dataclass
|
||
|
||
@dataclass
|
||
class StrokeFeatures:
|
||
"""笔画特征向量"""
|
||
start_x: float # 起点 x 坐标(归一化 0~1)
|
||
start_y: float # 起点 y 坐标
|
||
end_x: float # 终点 x 坐标
|
||
end_y: float # 终点 y 坐标
|
||
direction_angle: float # 主方向角度(0~360度)
|
||
length: float # 笔画长度(归一化)
|
||
curvature: float # 曲率(越小越直)
|
||
|
||
class StrokeOrderEvaluator:
|
||
"""
|
||
汉字笔顺评分器
|
||
基于标准笔顺库比对学生书写笔顺
|
||
"""
|
||
|
||
def __init__(self, stroke_library_path: str):
|
||
"""加载标准笔顺库(包含 8105 个通用规范汉字的笔顺数据)"""
|
||
with open(stroke_library_path, 'r', encoding='utf-8') as f:
|
||
self.library = json.load(f)
|
||
|
||
def evaluate(self, character: str,
|
||
written_strokes: list,
|
||
strict_mode: bool = False) -> dict:
|
||
"""
|
||
评估学生书写笔顺
|
||
|
||
written_strokes: 学生书写的笔画序列(每条笔画为 InkPoint 列表)
|
||
返回: 详细评分结果
|
||
"""
|
||
if character not in self.library:
|
||
return {'error': f'字符 {character} 不在笔顺库中'}
|
||
|
||
standard_strokes = self.library[character]['strokes']
|
||
standard_count = len(standard_strokes)
|
||
written_count = len(written_strokes)
|
||
|
||
# 笔画数量对比
|
||
stroke_count_score = 100 if written_count == standard_count else max(
|
||
0, 100 - abs(written_count - standard_count) * 20
|
||
)
|
||
|
||
# 逐笔顺序比对(对齐最近的标准笔画)
|
||
order_errors = []
|
||
for i, written in enumerate(written_strokes):
|
||
if i >= standard_count:
|
||
order_errors.append({'stroke': i+1, 'error': '多余的笔画'})
|
||
continue
|
||
|
||
written_feat = self._extract_features(written)
|
||
expected_feat = standard_strokes[i]['features']
|
||
|
||
# 方向角度偏差(容忍度:严格模式15°,普通模式30°)
|
||
angle_diff = self._angle_diff(written_feat.direction_angle,
|
||
expected_feat['direction_angle'])
|
||
tolerance = 15 if strict_mode else 30
|
||
|
||
if angle_diff > tolerance:
|
||
order_errors.append({
|
||
'stroke': i+1,
|
||
'error': f'方向错误(应为{expected_feat["name"]},偏差{angle_diff:.1f}°)'
|
||
})
|
||
|
||
# 计算笔顺得分
|
||
error_count = len(order_errors)
|
||
order_score = max(0, 40 - error_count * 8) # 笔顺满分40分,每错一笔扣8分
|
||
|
||
# 字形得分(基于关键点位置偏差)
|
||
shape_score = self._evaluate_shape(character, written_strokes)
|
||
|
||
# 比例得分(基于整体字形比例)
|
||
proportion_score = self._evaluate_proportion(character, written_strokes)
|
||
|
||
total_score = order_score + shape_score + proportion_score
|
||
|
||
return {
|
||
'total_score': min(100, total_score),
|
||
'stroke_order_score': order_score,
|
||
'shape_score': shape_score,
|
||
'proportion_score': proportion_score,
|
||
'stroke_count_match': written_count == standard_count,
|
||
'errors': order_errors,
|
||
'grade': self._score_to_grade(total_score)
|
||
}
|
||
|
||
def _score_to_grade(self, score: float) -> str:
|
||
if score >= 90: return '优秀'
|
||
elif score >= 75: return '良好'
|
||
elif score >= 60: return '合格'
|
||
else: return '需加强'
|
||
|
||
def _angle_diff(self, a1: float, a2: float) -> float:
|
||
"""计算两个角度的最小绝对差值(处理360°环绕)"""
|
||
diff = abs(a1 - a2) % 360
|
||
return min(diff, 360 - diff)
|
||
|
||
def _extract_features(self, stroke_points: list) -> StrokeFeatures:
|
||
"""从笔画点序列提取特征向量"""
|
||
if len(stroke_points) < 2:
|
||
return StrokeFeatures(0, 0, 0, 0, 0, 0, 0)
|
||
|
||
xs = [p['x'] for p in stroke_points]
|
||
ys = [p['y'] for p in stroke_points]
|
||
|
||
# 主方向:起点到终点的方向角
|
||
dx = xs[-1] - xs[0]
|
||
dy = ys[-1] - ys[0]
|
||
angle = np.degrees(np.arctan2(dy, dx)) % 360
|
||
length = np.sqrt(dx**2 + dy**2)
|
||
|
||
# 曲率:用笔画路径长度/直线长度比估算
|
||
path_length = sum(
|
||
np.sqrt((xs[i+1]-xs[i])**2 + (ys[i+1]-ys[i])**2)
|
||
for i in range(len(xs)-1)
|
||
)
|
||
curvature = path_length / max(length, 1e-6)
|
||
|
||
return StrokeFeatures(
|
||
start_x=xs[0], start_y=ys[0],
|
||
end_x=xs[-1], end_y=ys[-1],
|
||
direction_angle=angle,
|
||
length=length,
|
||
curvature=curvature
|
||
)
|
||
```
|
||
|
||
---
|
||
|
||
### C.4 模型服务化与部署
|
||
|
||
#### C.4.1 模型热加载机制
|
||
|
||
```python
|
||
# model_manager.py
|
||
class ModelManager:
|
||
"""
|
||
AI 模型管理器
|
||
支持模型热加载(不停服更新)和 A/B 测试
|
||
"""
|
||
|
||
def __init__(self, model_registry_path: str):
|
||
self.registry_path = model_registry_path
|
||
self._models: dict[str, dict] = {}
|
||
self._lock = asyncio.Lock()
|
||
|
||
# 启动文件监听,自动检测模型更新
|
||
self._start_model_watcher()
|
||
|
||
async def load_model(self, model_name: str, model_version: str):
|
||
"""加载指定版本的模型到内存"""
|
||
model_path = f"{self.registry_path}/{model_name}/{model_version}"
|
||
|
||
async with self._lock:
|
||
if model_name in self._models:
|
||
# 先保留旧模型(用于 A/B 对比或回滚)
|
||
old_model = self._models[model_name]
|
||
self._models[f"{model_name}_backup"] = old_model
|
||
|
||
# 异步加载新模型
|
||
model = await asyncio.get_event_loop().run_in_executor(
|
||
None, self._load_from_disk, model_path
|
||
)
|
||
|
||
self._models[model_name] = {
|
||
'model': model,
|
||
'version': model_version,
|
||
'loaded_at': time.time(),
|
||
'call_count': 0,
|
||
'total_latency': 0.0
|
||
}
|
||
|
||
logger.info(f"模型 {model_name} v{model_version} 加载完成")
|
||
|
||
def get_model(self, model_name: str, use_ab_test: bool = False):
|
||
"""获取当前活跃模型"""
|
||
if use_ab_test and f"{model_name}_backup" in self._models:
|
||
# A/B 测试:10% 流量路由到旧模型
|
||
if random.random() < 0.1:
|
||
return self._models[f"{model_name}_backup"]['model']
|
||
return self._models[model_name]['model']
|
||
```
|
||
|
||
#### C.4.2 API 限流与队列
|
||
|
||
```python
|
||
# rate_limiter.py(AI 引擎请求限流)
|
||
import asyncio
|
||
from collections import defaultdict
|
||
|
||
class TokenBucketRateLimiter:
|
||
"""
|
||
令牌桶算法限流器
|
||
防止 AI 推理服务被突发请求打垮
|
||
"""
|
||
|
||
def __init__(self, rate: float, capacity: float):
|
||
"""
|
||
rate: 令牌补充速率(请求/秒)
|
||
capacity: 桶容量(最大突发量)
|
||
"""
|
||
self.rate = rate
|
||
self.capacity = capacity
|
||
self.tokens: dict[str, float] = defaultdict(lambda: capacity)
|
||
self.last_refill: dict[str, float] = defaultdict(time.time)
|
||
self._lock = asyncio.Lock()
|
||
|
||
async def acquire(self, key: str, tokens: float = 1.0) -> bool:
|
||
"""
|
||
尝试获取令牌
|
||
key: 限流维度(如 app_key 或 user_id)
|
||
返回: True=可以执行,False=被限流
|
||
"""
|
||
async with self._lock:
|
||
now = time.time()
|
||
elapsed = now - self.last_refill[key]
|
||
|
||
# 按时间补充令牌
|
||
self.tokens[key] = min(
|
||
self.capacity,
|
||
self.tokens[key] + elapsed * self.rate
|
||
)
|
||
self.last_refill[key] = now
|
||
|
||
if self.tokens[key] >= tokens:
|
||
self.tokens[key] -= tokens
|
||
return True
|
||
return False
|
||
|
||
|
||
# 在 AI 识别接口中使用限流器
|
||
class OcrController:
|
||
def __init__(self):
|
||
self.rate_limiter = TokenBucketRateLimiter(
|
||
rate=50, # 50 请求/秒
|
||
capacity=100 # 最大突发 100 个请求
|
||
)
|
||
|
||
async def recognize(self, request: OcrRequest) -> OcrResponse:
|
||
# 按 AppKey 限流
|
||
allowed = await self.rate_limiter.acquire(request.app_key)
|
||
if not allowed:
|
||
raise RateLimitExceededException("请求频率超限,请降低调用频率")
|
||
|
||
# 执行识别
|
||
return await self.ocr_service.recognize(request)
|
||
```
|
||
|
||
---
|
||
|
||
## 附录D 接口完整清单
|
||
|
||
### D.1 识别接口
|
||
|
||
| 接口 | 方法 | 路径 | QPS限制 | 说明 |
|
||
|------|------|------|---------|------|
|
||
| 汉字识别 | POST | `/api/v1/ocr/text` | 50/s/AppKey | 识别手写汉字 |
|
||
| 数学识别 | POST | `/api/v1/ocr/math` | 30/s/AppKey | 识别数学表达式 |
|
||
| 批量识别 | POST | `/api/v1/ocr/batch` | 10/s/AppKey | 批量识别(最多20条/请求) |
|
||
| 笔顺评分 | POST | `/api/v1/ocr/stroke-order` | 30/s/AppKey | 汉字笔顺评估与评分 |
|
||
| 书写质量 | POST | `/api/v1/ocr/quality` | 30/s/AppKey | 书写质量综合评分 |
|
||
| 作业批改 | POST | `/api/v1/correction/assignment` | 5/s/AppKey | 完整作业批改流程 |
|
||
|
||
### D.2 管理接口
|
||
|
||
| 接口 | 方法 | 路径 | 说明 |
|
||
|------|------|------|------|
|
||
| 查询识别配额 | GET | `/api/v1/quota/current` | 查询当前 AppKey 的识别配额余量 |
|
||
| 查询调用统计 | GET | `/api/v1/statistics/calls` | 按时间段统计API调用次数和成功率 |
|
||
| 获取模型版本 | GET | `/api/v1/models/versions` | 查询当前生产模型版本信息 |
|
||
| 异步任务状态 | GET | `/api/v1/tasks/{task_id}` | 查询异步批改任务的执行状态 |
|
||
|
||
---
|
||
|
||
## 附录E 部署与性能
|
||
|
||
### E.1 GPU 推理服务部署
|
||
|
||
```yaml
|
||
# docker-compose.gpu.yml(GPU 推理节点)
|
||
services:
|
||
ai-engine:
|
||
image: registry.writech.com/ai-engine:1.0.0
|
||
runtime: nvidia
|
||
environment:
|
||
NVIDIA_VISIBLE_DEVICES: all
|
||
PADDLE_FLAGS: "FLAGS_fraction_of_gpu_memory_to_use=0.8"
|
||
MODEL_BATCH_SIZE: 16
|
||
MAX_CONCURRENT_REQUESTS: 64
|
||
volumes:
|
||
- /data/models:/app/models:ro # 只读挂载模型目录
|
||
deploy:
|
||
resources:
|
||
reservations:
|
||
devices:
|
||
- driver: nvidia
|
||
count: 1
|
||
capabilities: [gpu]
|
||
```
|
||
|
||
### E.2 推理性能基准测试
|
||
|
||
| 模型 | 硬件 | 批次大小 | 平均延迟 | 吞吐量 |
|
||
|------|------|---------|---------|-------|
|
||
| 汉字识别(CRNN) | Tesla T4 | 16 | 8ms/字 | 2000字/秒 |
|
||
| 数学识别(Im2Latex) | Tesla T4 | 8 | 25ms/式 | 320式/秒 |
|
||
| 笔顺评分 | CPU(16核) | 1 | 5ms/字 | 200字/秒 |
|
||
| 书写质量(综合) | Tesla T4 | 16 | 15ms/字 | 1000字/秒 |
|
||
|
||
---
|
||
|
||
*本文档版权归深圳自然写科技有限公司所有,仅用于软件著作权登记鉴别。*
|
||
|
||
---
|
||
|
||
## 附录F 核心算法详细实现
|
||
|
||
### F.1 DB文本检测网络实现
|
||
|
||
AI引擎使用DB(Differentiable Binarization)算法进行文字区域检测,相比传统固定阈值二值化,DB通过可学习的阈值映射提升检测准确率。
|
||
|
||
```python
|
||
# engine/detection/db_detector.py
|
||
import numpy as np
|
||
import cv2
|
||
import onnxruntime as ort
|
||
from typing import List, Tuple
|
||
|
||
class DBTextDetector:
|
||
"""
|
||
DB文本检测器(基于ONNX模型推理)
|
||
输入:RGB图像(归一化到[-1,1])
|
||
输出:文字区域轮廓列表(像素坐标)
|
||
"""
|
||
|
||
# 预处理参数(训练时统计的均值和标准差)
|
||
IMG_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
||
IMG_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
||
|
||
# DB后处理参数
|
||
DB_THRESH = 0.3 # 概率图二值化阈值
|
||
DB_BOX_THRESH = 0.5 # 检测框置信度阈值
|
||
DB_UNCLIP_RATIO = 1.6 # 文字框扩张比例
|
||
|
||
def __init__(self, model_path: str):
|
||
opts = ort.SessionOptions()
|
||
opts.intra_op_num_threads = 4
|
||
opts.inter_op_num_threads = 2
|
||
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||
self.session = ort.InferenceSession(
|
||
model_path,
|
||
sess_options=opts,
|
||
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||
)
|
||
self.input_name = self.session.get_inputs()[0].name
|
||
|
||
def detect(self, image_bgr: np.ndarray) -> List[np.ndarray]:
|
||
"""
|
||
检测图片中的文字区域
|
||
Returns:
|
||
文字区域轮廓列表,每个轮廓为(N, 2)形状的numpy数组
|
||
"""
|
||
# 1. 预处理:等比缩放到32的倍数
|
||
h, w = image_bgr.shape[:2]
|
||
target_h = self._align_32(min(960, h))
|
||
target_w = self._align_32(min(960, w))
|
||
scale_h, scale_w = h / target_h, w / target_w
|
||
|
||
resized = cv2.resize(image_bgr, (target_w, target_h))
|
||
img_rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
|
||
img_norm = (img_rgb - self.IMG_MEAN) / self.IMG_STD
|
||
input_tensor = img_norm.transpose(2, 0, 1)[np.newaxis] # NCHW
|
||
|
||
# 2. 模型推理
|
||
outputs = self.session.run(None, {self.input_name: input_tensor})
|
||
prob_map = outputs[0][0, 0] # (H, W) 概率图
|
||
|
||
# 3. DB后处理:概率图 → 二值图 → 轮廓提取 → Unclip扩张
|
||
binary_map = (prob_map > self.DB_THRESH).astype(np.uint8) * 255
|
||
contours, _ = cv2.findContours(binary_map, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
||
|
||
text_regions = []
|
||
for cnt in contours:
|
||
if cv2.contourArea(cnt) < 100: # 过滤极小区域
|
||
continue
|
||
|
||
# 计算轮廓置信度(取概率图在轮廓内的均值)
|
||
mask = np.zeros_like(prob_map, dtype=np.uint8)
|
||
cv2.drawContours(mask, [cnt], 0, 1, -1)
|
||
box_score = float(prob_map[mask == 1].mean())
|
||
|
||
if box_score < self.DB_BOX_THRESH:
|
||
continue
|
||
|
||
# Unclip扩张(扩大文字框,避免漏识别边缘字符)
|
||
expanded = self._unclip(cnt, self.DB_UNCLIP_RATIO)
|
||
|
||
# 还原到原始图像坐标
|
||
expanded[:, 0] = np.clip(expanded[:, 0] * scale_w, 0, w - 1)
|
||
expanded[:, 1] = np.clip(expanded[:, 1] * scale_h, 0, h - 1)
|
||
text_regions.append(expanded.astype(np.int32))
|
||
|
||
return text_regions
|
||
|
||
def _unclip(self, contour: np.ndarray, ratio: float) -> np.ndarray:
|
||
"""使用Polygon Offset算法扩张文字框"""
|
||
import pyclipper
|
||
poly = contour.reshape(-1, 2).astype(np.float32)
|
||
area = cv2.contourArea(contour)
|
||
peri = cv2.arcLength(contour, True)
|
||
if peri < 1e-6:
|
||
return poly
|
||
distance = area * ratio / peri
|
||
|
||
pco = pyclipper.PyclipperOffset()
|
||
pco.AddPath(poly.astype(np.int32).tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||
solution = pco.Execute(int(distance))
|
||
if not solution:
|
||
return poly
|
||
return np.array(solution[0], dtype=np.float32)
|
||
|
||
@staticmethod
|
||
def _align_32(n: int) -> int:
|
||
return max(32, (n + 31) // 32 * 32)
|
||
```
|
||
|
||
### F.2 CRNN文字识别实现
|
||
|
||
```python
|
||
# engine/recognition/crnn_recognizer.py
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
from typing import Tuple, List
|
||
|
||
class CRNNRecognizer:
|
||
"""
|
||
CRNN文字识别器(CNN + LSTM + CTC解码)
|
||
输入:文字行图像(归一化到32×320)
|
||
输出:识别文本 + 置信度
|
||
"""
|
||
|
||
IMG_HEIGHT = 32
|
||
IMG_WIDTH = 320
|
||
|
||
def __init__(self, model_path: str, charset_path: str):
|
||
self.session = ort.InferenceSession(
|
||
model_path,
|
||
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||
)
|
||
self.input_name = self.session.get_inputs()[0].name
|
||
|
||
# 加载字符集(包含空白符)
|
||
with open(charset_path, 'r', encoding='utf-8') as f:
|
||
self.charset = ['BLANK'] + [c.strip() for c in f.readlines()]
|
||
self.blank_idx = 0
|
||
|
||
def recognize(self, line_image_bgr: np.ndarray) -> Tuple[str, float]:
|
||
"""
|
||
识别单行文字图像
|
||
Returns:
|
||
(text, confidence) 识别文本和平均置信度
|
||
"""
|
||
# 预处理:灰度化,等比缩放到32高度,宽度最大320
|
||
import cv2
|
||
gray = cv2.cvtColor(line_image_bgr, cv2.COLOR_BGR2GRAY)
|
||
h, w = gray.shape
|
||
target_w = min(self.IMG_WIDTH, int(w * self.IMG_HEIGHT / h))
|
||
resized = cv2.resize(gray, (target_w, self.IMG_HEIGHT))
|
||
|
||
# 填充到标准宽度
|
||
canvas = np.zeros((self.IMG_HEIGHT, self.IMG_WIDTH), dtype=np.float32)
|
||
canvas[:, :target_w] = resized.astype(np.float32)
|
||
canvas = (canvas / 255.0 - 0.5) / 0.5 # 归一化到[-1, 1]
|
||
|
||
# NCHW格式(1, 1, 32, 320)
|
||
input_tensor = canvas[np.newaxis, np.newaxis]
|
||
|
||
# 推理:输出shape为 (T, 1, num_classes),T是时间步
|
||
logits = self.session.run(None, {self.input_name: input_tensor})[0]
|
||
probs = self._softmax(logits[:, 0, :]) # (T, num_classes)
|
||
|
||
# CTC贪心解码
|
||
text, confidence = self._ctc_greedy_decode(probs)
|
||
return text, confidence
|
||
|
||
def _ctc_greedy_decode(self, probs: np.ndarray) -> Tuple[str, float]:
|
||
"""CTC贪心解码(逐时间步取最大概率)"""
|
||
indices = np.argmax(probs, axis=-1) # (T,)
|
||
confs = probs[np.arange(len(indices)), indices]
|
||
|
||
# 折叠重复并移除blank
|
||
chars = []
|
||
conf_list = []
|
||
prev = -1
|
||
for i, (idx, conf) in enumerate(zip(indices, confs)):
|
||
if idx != prev and idx != self.blank_idx:
|
||
if 0 < idx < len(self.charset):
|
||
chars.append(self.charset[idx])
|
||
conf_list.append(float(conf))
|
||
prev = idx
|
||
|
||
text = ''.join(chars)
|
||
avg_conf = float(np.mean(conf_list)) if conf_list else 0.0
|
||
return text, avg_conf
|
||
|
||
@staticmethod
|
||
def _softmax(x: np.ndarray) -> np.ndarray:
|
||
e = np.exp(x - x.max(axis=-1, keepdims=True))
|
||
return e / e.sum(axis=-1, keepdims=True)
|
||
```
|
||
|
||
### F.3 笔顺评分模型推理
|
||
|
||
```python
|
||
# engine/stroke_eval/stroke_evaluator.py
|
||
import numpy as np
|
||
import onnxruntime as ort
|
||
from typing import List, Dict
|
||
|
||
class StrokeOrderEvaluator:
|
||
"""
|
||
笔顺评分引擎
|
||
- 输入:归一化的笔迹序列(时序坐标点)
|
||
- 模型:双向LSTM,输出每笔笔顺正误概率
|
||
- 配合BKT校准:根据学生历史正确率校准当前评分
|
||
"""
|
||
|
||
MAX_STROKES = 30 # 最多30笔
|
||
MAX_POINTS = 50 # 每笔最多50点
|
||
FEATURE_DIM = 6 # 特征维度:(x, y, dx, dy, pressure, is_last_point)
|
||
|
||
def __init__(self, model_path: str):
|
||
self.session = ort.InferenceSession(
|
||
model_path,
|
||
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
|
||
)
|
||
|
||
def evaluate(self, strokes: List[List[Dict]],
|
||
character: str,
|
||
bkt_mastery: float = 0.5) -> Dict:
|
||
"""
|
||
评估笔顺质量
|
||
Args:
|
||
strokes: [[{'x':...,'y':...,'p':...},...], ...] 各笔笔迹点
|
||
character: 被书写的字符
|
||
bkt_mastery: 学生当前掌握度(用于校准)
|
||
Returns:
|
||
{'score': float, 'errors': list, 'feedback': str}
|
||
"""
|
||
# 特征提取
|
||
features = self._extract_features(strokes) # (S, P, F)
|
||
|
||
# 填充到固定尺寸 (MAX_STROKES, MAX_POINTS, FEATURE_DIM)
|
||
padded = np.zeros((self.MAX_STROKES, self.MAX_POINTS, self.FEATURE_DIM),
|
||
dtype=np.float32)
|
||
s = min(len(features), self.MAX_STROKES)
|
||
padded[:s] = features[:s]
|
||
|
||
# 模型推理
|
||
input_tensor = padded[np.newaxis] # (1, S, P, F)
|
||
stroke_probs = self.session.run(None, {'input': input_tensor})[0][0] # (S, 2)
|
||
|
||
# 计算各笔正确率
|
||
n_strokes = min(len(strokes), self.MAX_STROKES)
|
||
correct_probs = stroke_probs[:n_strokes, 1] # 正确类概率
|
||
|
||
# 综合评分(加权均值,前几笔权重略高)
|
||
weights = np.array([1.2 if i < 3 else 1.0 for i in range(n_strokes)])
|
||
raw_score = float(np.average(correct_probs, weights=weights))
|
||
|
||
# BKT校准:高掌握度时适当提高评分(鼓励效应)
|
||
calibrated_score = raw_score * (0.85 + 0.15 * bkt_mastery)
|
||
final_score = min(100.0, calibrated_score * 100)
|
||
|
||
# 生成错误反馈
|
||
errors = []
|
||
for i, prob in enumerate(correct_probs):
|
||
if prob < 0.5:
|
||
errors.append({
|
||
'stroke_index': i + 1,
|
||
'confidence': float(1 - prob),
|
||
'description': f'第{i+1}笔笔顺可能有误'
|
||
})
|
||
|
||
feedback = self._generate_feedback(final_score, errors)
|
||
return {'score': round(final_score, 1), 'errors': errors, 'feedback': feedback}
|
||
|
||
def _extract_features(self, strokes):
|
||
result = []
|
||
for stroke in strokes[:self.MAX_STROKES]:
|
||
pts = stroke[:self.MAX_POINTS]
|
||
features = []
|
||
for j, pt in enumerate(pts):
|
||
dx = pt['x'] - pts[j-1]['x'] if j > 0 else 0.0
|
||
dy = pt['y'] - pts[j-1]['y'] if j > 0 else 0.0
|
||
features.append([pt['x'], pt['y'], dx, dy,
|
||
pt.get('pressure', 0.5),
|
||
1.0 if j == len(pts)-1 else 0.0])
|
||
# 填充到MAX_POINTS
|
||
while len(features) < self.MAX_POINTS:
|
||
features.append([0.0] * self.FEATURE_DIM)
|
||
result.append(features)
|
||
return np.array(result, dtype=np.float32)
|
||
|
||
@staticmethod
|
||
def _generate_feedback(score: float, errors: list) -> str:
|
||
if score >= 90:
|
||
return "笔顺非常标准,继续保持!"
|
||
elif score >= 75:
|
||
error_strokes = [str(e['stroke_index']) for e in errors]
|
||
return f"整体不错,第{'、'.join(error_strokes)}笔笔顺需要注意。"
|
||
elif score >= 60:
|
||
return f"笔顺有{len(errors)}处需要改进,请参考标准笔顺练习。"
|
||
else:
|
||
return "笔顺与标准差异较大,请仔细查看示范,重新练习。"
|
||
```
|
||
|
||
### F.4 令牌桶限流实现
|
||
|
||
```python
|
||
# middleware/rate_limiter.py
|
||
import time, threading
|
||
from functools import wraps
|
||
from flask import request, jsonify
|
||
|
||
class TokenBucketRateLimiter:
|
||
"""
|
||
令牌桶限流算法
|
||
每个AppKey独立维护一个令牌桶,以固定速率添加令牌
|
||
"""
|
||
|
||
def __init__(self, rate: float = 50.0, capacity: float = 100.0):
|
||
"""
|
||
Args:
|
||
rate: 令牌补充速率(个/秒)
|
||
capacity: 令牌桶容量(最大突发量)
|
||
"""
|
||
self.rate = rate
|
||
self.capacity = capacity
|
||
self._buckets: dict = {} # app_key -> (tokens, last_time)
|
||
self._lock = threading.Lock()
|
||
|
||
def consume(self, app_key: str, n: int = 1) -> bool:
|
||
"""消费n个令牌,返回True表示通过,False表示限流"""
|
||
with self._lock:
|
||
now = time.monotonic()
|
||
if app_key not in self._buckets:
|
||
self._buckets[app_key] = [self.capacity, now]
|
||
|
||
tokens, last_time = self._buckets[app_key]
|
||
|
||
# 补充令牌(根据时间差)
|
||
elapsed = now - last_time
|
||
tokens = min(self.capacity, tokens + elapsed * self.rate)
|
||
self._buckets[app_key][1] = now
|
||
|
||
if tokens >= n:
|
||
self._buckets[app_key][0] = tokens - n
|
||
return True
|
||
else:
|
||
self._buckets[app_key][0] = tokens
|
||
return False
|
||
|
||
# Flask装饰器
|
||
_limiter = TokenBucketRateLimiter(rate=50.0, capacity=100.0)
|
||
|
||
def rate_limit(f):
|
||
@wraps(f)
|
||
def decorated(*args, **kwargs):
|
||
app_key = request.headers.get('X-App-Key', 'anonymous')
|
||
if not _limiter.consume(app_key):
|
||
return jsonify({'code': 429, 'message': '请求过于频繁,请稍后重试'}), 429
|
||
return f(*args, **kwargs)
|
||
return decorated
|
||
```
|
||
|
||
---
|
||
|
||
## 附录F 补充算法与接口规格
|
||
|
||
### F.1 多模型集成推理框架
|
||
|
||
#### F.1.1 模型路由策略
|
||
|
||
```python
|
||
# model_router.py
|
||
from enum import Enum
|
||
from typing import Dict, Any
|
||
import numpy as np
|
||
|
||
class ModelType(Enum):
|
||
OCR_FAST = "ocr_fast" # 轻量模型,延迟<50ms
|
||
OCR_ACCURATE = "ocr_accurate" # 精准模型,延迟<200ms
|
||
MATH_FORMULA = "math_formula" # 数学公式专用
|
||
STROKE_EVAL = "stroke_eval" # 笔顺评分
|
||
|
||
class ModelRouter:
|
||
"""根据请求特征自动路由到最优模型"""
|
||
|
||
def __init__(self, models: Dict[ModelType, Any]):
|
||
self.models = models
|
||
self.stats = {m: {"count": 0, "avg_ms": 0} for m in ModelType}
|
||
|
||
def route(self, request: dict) -> ModelType:
|
||
content_type = request.get("content_type", "text")
|
||
quality = request.get("quality", "normal")
|
||
|
||
if content_type == "math":
|
||
return ModelType.MATH_FORMULA
|
||
|
||
if content_type == "stroke":
|
||
return ModelType.STROKE_EVAL
|
||
|
||
# 根据图像复杂度自动选择
|
||
if quality == "high" or self._is_complex_image(request.get("image")):
|
||
return ModelType.OCR_ACCURATE
|
||
else:
|
||
return ModelType.OCR_FAST
|
||
|
||
def _is_complex_image(self, image: np.ndarray) -> bool:
|
||
if image is None:
|
||
return False
|
||
# 基于图像方差判断复杂度
|
||
variance = np.var(image)
|
||
return variance > 1500
|
||
|
||
async def infer(self, request: dict) -> dict:
|
||
model_type = self.route(request)
|
||
model = self.models[model_type]
|
||
|
||
import time
|
||
start = time.perf_counter()
|
||
result = await model.predict(request["image"])
|
||
elapsed_ms = (time.perf_counter() - start) * 1000
|
||
|
||
# 更新统计
|
||
s = self.stats[model_type]
|
||
s["avg_ms"] = (s["avg_ms"] * s["count"] + elapsed_ms) / (s["count"] + 1)
|
||
s["count"] += 1
|
||
|
||
return {
|
||
"result": result,
|
||
"model": model_type.value,
|
||
"latency_ms": round(elapsed_ms, 2)
|
||
}
|
||
```
|
||
|
||
### F.2 OCR后处理管道
|
||
|
||
#### F.2.1 文本置信度过滤与纠错
|
||
|
||
```python
|
||
# ocr_postprocess.py
|
||
import re
|
||
from dataclasses import dataclass
|
||
from typing import List, Optional
|
||
|
||
@dataclass
|
||
class OcrWord:
|
||
text: str
|
||
confidence: float
|
||
bbox: tuple # (x1, y1, x2, y2)
|
||
|
||
class OcrPostProcessor:
|
||
CONF_THRESHOLD_HIGH = 0.90
|
||
CONF_THRESHOLD_LOW = 0.60
|
||
|
||
# 常见混淆字符映射(OCR错误→正确)
|
||
CONFUSION_MAP = {
|
||
"0": "O", "1": "l", "rn": "m", "cl": "d",
|
||
"己": "已", "末": "未", "土": "士"
|
||
}
|
||
|
||
def __init__(self, language_model=None):
|
||
self.lm = language_model # 可选语言模型用于纠错
|
||
|
||
def process(self, words: List[OcrWord]) -> str:
|
||
# 1. 过滤低置信度词
|
||
filtered = [w for w in words if w.confidence >= self.CONF_THRESHOLD_LOW]
|
||
|
||
# 2. 对中等置信度词进行纠错尝试
|
||
corrected = []
|
||
for word in filtered:
|
||
if word.confidence < self.CONF_THRESHOLD_HIGH:
|
||
fixed = self._try_correct(word.text)
|
||
corrected.append(fixed)
|
||
else:
|
||
corrected.append(word.text)
|
||
|
||
# 3. 拼接文本并后处理
|
||
text = " ".join(corrected)
|
||
text = self._normalize_spaces(text)
|
||
text = self._fix_punctuation(text)
|
||
|
||
return text
|
||
|
||
def _try_correct(self, text: str) -> str:
|
||
result = text
|
||
for wrong, right in self.CONFUSION_MAP.items():
|
||
result = result.replace(wrong, right)
|
||
|
||
if self.lm:
|
||
lm_result = self.lm.correct(result)
|
||
if lm_result.score > 0.8:
|
||
return lm_result.text
|
||
|
||
return result
|
||
|
||
def _normalize_spaces(self, text: str) -> str:
|
||
# 移除中文字符间多余空格
|
||
text = re.sub(r'([\u4e00-\u9fff])\s+([\u4e00-\u9fff])', r'\1\2', text)
|
||
return text.strip()
|
||
|
||
def _fix_punctuation(self, text: str) -> str:
|
||
# 标准化标点符号
|
||
replacements = [(',', ','), ('.', '。'), ('?', '?'), ('!', '!')]
|
||
for eng, chn in replacements:
|
||
# 仅在中文上下文中替换
|
||
text = re.sub(f'([\u4e00-\u9fff]){re.escape(eng)}',
|
||
f'\\1{chn}', text)
|
||
return text
|
||
```
|
||
|
||
### F.3 异步任务队列
|
||
|
||
```python
|
||
# async_task_queue.py
|
||
import asyncio
|
||
from collections import deque
|
||
from dataclasses import dataclass, field
|
||
from typing import Callable, Any
|
||
|
||
@dataclass
|
||
class Task:
|
||
id: str
|
||
priority: int
|
||
func: Callable
|
||
args: tuple
|
||
future: asyncio.Future = field(default_factory=asyncio.Future)
|
||
|
||
class PriorityTaskQueue:
|
||
"""优先级异步任务队列,高优先级任务优先执行"""
|
||
|
||
def __init__(self, workers: int = 4):
|
||
self.queue = asyncio.PriorityQueue()
|
||
self.workers = workers
|
||
self._counter = 0 # 用于相同优先级时保持FIFO顺序
|
||
|
||
async def submit(self, func: Callable, *args, priority: int = 5) -> Any:
|
||
future = asyncio.get_event_loop().create_future()
|
||
task = Task(
|
||
id=f"task_{self._counter}",
|
||
priority=priority,
|
||
func=func,
|
||
args=args,
|
||
future=future
|
||
)
|
||
self._counter += 1
|
||
# PriorityQueue按(priority, counter)排序,priority越小优先级越高
|
||
await self.queue.put((priority, self._counter, task))
|
||
return await future
|
||
|
||
async def _worker(self):
|
||
while True:
|
||
_, _, task = await self.queue.get()
|
||
try:
|
||
if asyncio.iscoroutinefunction(task.func):
|
||
result = await task.func(*task.args)
|
||
else:
|
||
result = await asyncio.get_event_loop().run_in_executor(
|
||
None, task.func, *task.args)
|
||
task.future.set_result(result)
|
||
except Exception as e:
|
||
task.future.set_exception(e)
|
||
finally:
|
||
self.queue.task_done()
|
||
|
||
async def start(self):
|
||
for _ in range(self.workers):
|
||
asyncio.create_task(self._worker())
|
||
```
|
||
|
||
---
|
||
|
||
## 附录G 补充技术规格
|
||
|
||
### G.1 模型热加载无缝切换
|
||
|
||
```python
|
||
# model_hot_swap.py
|
||
import threading
|
||
import time
|
||
from typing import Optional
|
||
|
||
class HotSwapModelManager:
|
||
"""模型热加载管理器,支持零停机切换模型版本"""
|
||
|
||
def __init__(self):
|
||
self._current_model = None
|
||
self._pending_model = None
|
||
self._lock = threading.RWLock()
|
||
self._request_count = 0
|
||
|
||
def load_new_version(self, model_path: str, model_type: str):
|
||
"""在后台加载新模型版本"""
|
||
def _load():
|
||
import torch
|
||
new_model = torch.jit.load(model_path)
|
||
new_model.eval()
|
||
|
||
# 等待当前请求处理完成
|
||
while self._request_count > 0:
|
||
time.sleep(0.01)
|
||
|
||
with self._lock.write():
|
||
old_model = self._current_model
|
||
self._current_model = new_model
|
||
self._pending_model = None
|
||
|
||
# 释放旧模型内存
|
||
if old_model is not None:
|
||
del old_model
|
||
torch.cuda.empty_cache()
|
||
|
||
print(f"Model {model_type} hot-swapped to {model_path}")
|
||
|
||
import threading
|
||
t = threading.Thread(target=_load, daemon=True)
|
||
t.start()
|
||
|
||
def infer(self, inputs):
|
||
"""推理时持有读锁,防止切换过程中的并发问题"""
|
||
with self._lock.read():
|
||
self._request_count += 1
|
||
try:
|
||
return self._current_model(inputs)
|
||
finally:
|
||
self._request_count -= 1
|
||
```
|
||
|
||
### G.2 GPU显存监控
|
||
|
||
```python
|
||
# gpu_monitor.py
|
||
import subprocess
|
||
import json
|
||
|
||
def get_gpu_stats() -> dict:
|
||
"""获取GPU使用统计(通过nvidia-smi)"""
|
||
try:
|
||
result = subprocess.run([
|
||
"nvidia-smi", "--query-gpu=name,memory.used,memory.total,utilization.gpu,temperature.gpu",
|
||
"--format=csv,noheader,nounits"
|
||
], capture_output=True, text=True, timeout=5)
|
||
|
||
if result.returncode != 0:
|
||
return {}
|
||
|
||
parts = [p.strip() for p in result.stdout.strip().split(",")]
|
||
return {
|
||
"name": parts[0],
|
||
"memory_used_mb": int(parts[1]),
|
||
"memory_total_mb": int(parts[2]),
|
||
"memory_util_pct": round(int(parts[1]) / int(parts[2]) * 100, 1),
|
||
"gpu_util_pct": int(parts[3]),
|
||
"temperature_c": int(parts[4])
|
||
}
|
||
except Exception as e:
|
||
return {"error": str(e)}
|
||
|
||
def check_oom_risk(threshold_pct: float = 90.0) -> bool:
|
||
"""检查是否有显存溢出风险"""
|
||
stats = get_gpu_stats()
|
||
if not stats or "memory_util_pct" not in stats:
|
||
return False
|
||
return stats["memory_util_pct"] >= threshold_pct
|
||
```
|
||
|
||
### G.3 识别结果缓存
|
||
|
||
```python
|
||
# result_cache.py
|
||
import hashlib
|
||
import json
|
||
import redis
|
||
from functools import wraps
|
||
|
||
class OcrResultCache:
|
||
"""基于Redis的识别结果缓存,提高重复图片的响应速度"""
|
||
|
||
def __init__(self, redis_client: redis.Redis, ttl_seconds: int = 3600):
|
||
self.redis = redis_client
|
||
self.ttl = ttl_seconds
|
||
|
||
def get_cache_key(self, image_bytes: bytes, options: dict) -> str:
|
||
"""基于图片内容和识别选项生成缓存键"""
|
||
content_hash = hashlib.sha256(image_bytes).hexdigest()
|
||
options_str = json.dumps(options, sort_keys=True)
|
||
options_hash = hashlib.md5(options_str.encode()).hexdigest()[:8]
|
||
return f"ocr:result:{content_hash}:{options_hash}"
|
||
|
||
def get(self, key: str) -> Optional[dict]:
|
||
value = self.redis.get(key)
|
||
if value:
|
||
return json.loads(value)
|
||
return None
|
||
|
||
def set(self, key: str, result: dict):
|
||
self.redis.setex(key, self.ttl, json.dumps(result, ensure_ascii=False))
|
||
|
||
def cached_ocr(self, ocr_func):
|
||
"""装饰器:为OCR函数添加缓存"""
|
||
@wraps(ocr_func)
|
||
async def wrapper(image_bytes: bytes, **options):
|
||
key = self.get_cache_key(image_bytes, options)
|
||
cached = self.get(key)
|
||
if cached:
|
||
cached["from_cache"] = True
|
||
return cached
|
||
|
||
result = await ocr_func(image_bytes, **options)
|
||
self.set(key, result)
|
||
return result
|
||
return wrapper
|
||
```
|
||
|
||
---
|
||
|
||
## 附录H 补充技术规格
|
||
|
||
### H.1 数学公式识别增强
|
||
|
||
```python
|
||
# math_formula_postprocess.py
|
||
import re
|
||
|
||
class MathFormulaPostProcessor:
|
||
"""数学公式识别后处理:LaTeX语法规范化"""
|
||
|
||
# 常见OCR错误修正
|
||
CORRECTIONS = {
|
||
r'\\frac\s*{': r'\\frac{',
|
||
r'\\sqrt\s*{': r'\\sqrt{',
|
||
r'\\sum\s*_': r'\\sum_',
|
||
r'x\^2': r'x^{2}', # 补充缺失的花括号
|
||
r'x\^(\d)': r'x^{\1}',
|
||
r'([a-z])\^([a-z])': r'\1^{\2}',
|
||
}
|
||
|
||
def process(self, latex: str) -> str:
|
||
result = latex.strip()
|
||
|
||
# 应用修正规则
|
||
for pattern, replacement in self.CORRECTIONS.items():
|
||
result = re.sub(pattern, replacement, result)
|
||
|
||
# 确保公式包含在$...$中
|
||
if not result.startswith('$'):
|
||
result = f'${result}$'
|
||
|
||
# 验证括号平衡
|
||
if not self._check_braces(result):
|
||
result = self._fix_braces(result)
|
||
|
||
return result
|
||
|
||
def _check_braces(self, s: str) -> bool:
|
||
count = 0
|
||
for c in s:
|
||
if c == '{': count += 1
|
||
elif c == '}': count -= 1
|
||
if count < 0: return False
|
||
return count == 0
|
||
|
||
def _fix_braces(self, s: str) -> str:
|
||
"""自动补全缺失的右括号"""
|
||
count = sum(1 if c == '{' else -1 if c == '}' else 0 for c in s)
|
||
if count > 0:
|
||
s = s.rstrip('$') + '}' * count
|
||
if not s.endswith('$'): s += '$'
|
||
return s
|
||
```
|
||
|
||
---
|
||
|
||
### H.2 版本历史
|
||
|
||
| 版本号 | 发布日期 | 变更说明 | 负责人 |
|
||
|--------|----------|---------|--------|
|
||
| V1.0.0 | 2024-01-15 | 初始版本,实现汉字/数字OCR基础识别 | AI组 |
|
||
| V1.1.0 | 2024-03-20 | 新增数学公式识别(Im2Latex模型) | 算法组 |
|
||
| V1.2.0 | 2024-05-15 | 引入TensorRT INT8量化,GPU推理延迟降低40% | 工程组 |
|
||
| V1.3.0 | 2024-07-10 | 新增笔顺评分功能,BKT算法校准 | AI组 |
|
||
| V1.4.0 | 2024-09-01 | 添加识别结果Redis缓存,重复图片响应<5ms | 工程组 |
|
||
| V1.5.0 | 2024-11-15 | 支持模型热加载,零停机版本升级 | 工程组 |
|
||
|
||
---
|
||
|
||
*本文档版权归深圳自然写科技有限公司所有,仅用于软件著作权登记鉴别。*
|