Files
system-design/software-copyright/02-writech-ai-engine/自然写手写识别与AI分析引擎软件-鉴别材料.md
2026-03-22 15:24:40 +08:00

2493 lines
93 KiB
Markdown
Raw Permalink Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 自然写手写识别与AI分析引擎软件 V1.0
## 软件著作权鉴别材料(技术设计说明书)
| 项目 | 内容 |
|------|------|
| 软件全称 | 自然写手写识别与AI分析引擎软件 |
| 软件简称 | 自然写AI引擎 |
| 版本号 | V1.0 |
| 权利人 | 深圳自然写科技有限公司 |
| 开发语言 | Python / C++ |
| 运行环境 | Linux服务器(GPU加速) |
| 文档类型 | 技术设计说明书 |
| 编制日期 | 2026年2月 |
---
## 目录
- 第一章 软件整体概述
- 1.1 软件简介与功能综述
- 1.2 软件用途与适用场景
- 1.3 运行环境与系统要求
- 1.4 开发语言与技术框架
- 1.5 版本说明
- 第二章 系统架构与设计思路
- 2.1 总体架构设计
- 2.2 推理管道设计
- 2.3 各层次模块说明
- 2.4 数据结构设计
- 2.5 接口设计
- 2.6 安全设计
- 2.7 部署架构
- 第三章 核心模块功能详细说明
- 3.1 中英文手写文字OCR识别模块
- 3.2 数学列式与公式识别模块
- 3.3 中文汉字笔顺识别与评分模块
- 3.4 书写质量评测模块
- 3.5 AI作文评分与批改模块
- 3.6 自动批改引擎模块
- 3.7 识别置信度评估模块
- 3.8 模型版本管理与热更新模块
- 3.9 推理任务调度模块
- 第四章 操作流程与使用步骤
- 4.1 服务部署与启动
- 4.2 模型加载与初始化
- 4.3 识别任务提交流程
- 4.4 模型更新与灰度发布流程
- 4.5 性能调优与故障排除
- 第五章 与源代码的对应关系
- 5.1 模块与源代码文件对应表
- 5.2 核心函数与方法说明
- 5.3 命名规范
- 附录
- 附录A 术语表
- 附录B 版本历史
---
# 第一章 软件整体概述
## 1.1 软件简介与功能综述
自然写手写识别与AI分析引擎软件(以下简称"AI引擎")是自然写互动课堂系统的智能化核心组件,负责对智能点阵笔采集的手写笔迹数据进行深度学习推理,实现手写文字识别、数学公式解析、笔顺评估、书写质量分析及作文智能评分等多项AI能力。
AI引擎基于PaddleOCR、PyTorch和ONNX Runtime等主流深度学习框架构建,在NVIDIA GPU集群上运行,通过NVIDIA Triton Inference Server进行模型管理和并发推理调度。AI引擎对外提供RESTful HTTP接口和gRPC高性能接口,供云平台后端和算力盒端侧推理调用。
**主要功能模块概述:**
中英文手写文字OCR识别:基于PaddleOCR技术,对点阵笔采集的手写笔迹坐标序列进行字符识别,支持简体中文、繁体中文和英文字母数字的混合识别,识别准确率在标准书写条件下达到96%以上。
数学列式与公式识别:专门针对K12教育场景中的数学手写内容,识别加减乘除四则运算、分数、小数、方程组、几何符号等数学元素,并对计算结果进行自动验证。
汉字笔顺识别与评分:通过分析笔画的书写顺序,与标准笔顺库进行比对,对学生书写的每个汉字给出笔顺正确性评分,帮助学生养成正确的书写习惯。
书写质量评测:从字体结构、笔画间距、整体规范性等多个维度对书写质量进行综合评分,提供具体的改进建议。
AI作文评分与批改:基于NLP模型对手写作文内容进行评分,从结构完整性、语言表达、内容丰富性、书写规范性四个维度给出综合评分,并标注典型错误位置。
选择题/填空题/简答题自动批改:根据题目类型和标准答案,对学生的作答内容进行自动批改,支持容错匹配(允许同义词、近义词等合理变体)。
## 1.2 软件用途与适用场景
**主要适用场景:**
(1)课堂作业自动批改:学生完成纸质作业后,点阵笔采集的笔迹数据上传至AI引擎,由引擎自动完成批改并给出成绩,大幅降低教师批改工作量,实现即时反馈。
(2)课堂互动实时识别:在课堂互动答题环节,学生在纸上作答,AI引擎在数百毫秒内完成识别,将识别结果实时推送至大屏展示,实现真正的"即写即知"。
(3)写字练习辅导:在写字课和书法练习场景中,AI引擎对学生的每个字进行笔顺评分和书写质量评测,配合大屏实时展示评分结果,形成即时纠正反馈循环。
(4)考试阅卷辅助:在期中期末考试场景中,AI引擎先对试卷进行初步批改,教师仅需对置信度低于阈值的题目进行人工复核,大幅提升阅卷效率。
(5)第三方教育平台赋能:通过SDK和API,将AI引擎能力输出至其他教育软件平台,使其获得手写识别和智能批改能力。
## 1.3 运行环境与系统要求
**服务端运行环境:**
| 组件 | 要求 |
|------|------|
| 操作系统 | Ubuntu 20.04 LTS / CentOS 7.6+ |
| Python版本 | Python 3.9+ |
| CUDA版本 | CUDA 11.8+ / CUDA 12.0+ |
| cuDNN版本 | cuDNN 8.6+ |
| GPU型号 | NVIDIA T4 / A10 / A100(推荐) |
| 内存要求 | 最低32GB,推荐64GB+(加载多个大模型) |
| 存储要求 | 200GB+ SSD(存储模型文件和临时推理数据) |
**主要依赖软件版本:**
| 依赖包 | 版本 | 用途 |
|-------|------|------|
| PaddlePaddle-GPU | 2.5.x | OCR基础框架 |
| PaddleOCR | 2.7.x | 手写文字识别 |
| PyTorch | 2.1.x | 数学识别和作文评分模型 |
| ONNX Runtime | 1.16.x | 跨框架模型推理 |
| FastAPI | 0.104.x | HTTP REST接口框架 |
| gRPC | 1.59.x | 高性能流式接口 |
| Celery | 5.3.x | 异步任务队列 |
| Redis | 7.0.x | 消息代理和结果缓存 |
| NVIDIA Triton | 23.10 | 模型服务化部署 |
| MLflow | 2.8.x | 模型版本管理 |
## 1.4 开发语言与技术框架
**Python技术栈(主要业务逻辑):**
- FastAPI:构建高性能异步HTTP接口,支持OpenAPI自动文档生成
- gRPC + protobuf:实现高吞吐量的流式识别接口
- Celery + Redis:构建分布式异步任务队列,处理批量识别请求
- PaddleOCR:手写OCR识别核心框架,基于PP-OCR v4模型
- PyTorch:数学公式识别和作文评分模型的推理框架
- ONNX Runtime:将训练好的模型转换为跨平台格式进行高效推理
**C++ 扩展模块(性能关键路径):**
- 笔迹坐标预处理(去噪、归一化)使用C++扩展实现,通过Python ctypes调用
- 图像渲染(将坐标序列渲染为图像供模型输入)使用OpenCV C++ API实现
**代码规范:**
- Python代码遵循PEP 8规范,使用Black格式化,flake8进行静态检查
- 函数注解使用Python类型标注(Type Hints),保证代码可读性
- 所有公开函数和类均编写docstring文档
## 1.5 版本说明
| 版本号 | 发布日期 | 说明 |
|-------|---------|------|
| V1.0 | 2026年2月 | 初始版本,包含OCR识别、数学识别、笔顺评分、书写质量评测、作文评分全功能 |
---
# 第二章 系统架构与设计思路
## 2.1 总体架构设计
AI引擎采用**推理管道(Pipeline)架构**,将识别任务分解为标准化的处理阶段,每个阶段独立可替换,便于模型升级和能力扩展。整体分为接口层、调度层、推理层、GPU管理层和模型仓库五个层次。
总体架构示意:
```
外部调用方(云平台 / 算力盒)
┌──────────────────────────────────────────────────┐
│ 接口层 │
│ FastAPI REST(同步短任务) gRPC Server(流式批量)│
└──────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────┐
│ 调度层 │
│ Celery Worker × N + Redis Broker │
│ (按优先级调度:实时课堂 > 作业批改 > 批量) │
└──────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────┐
│ 推理层 │
│ OCR引擎 数学识别 笔顺分析 书写质量 作文评分 │
PaddleOCR)(PyTorch)(自研模型)(自研NLP) │
└──────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────┐
│ GPU管理层 │
│ NVIDIA Triton Inference Server │
│ (多模型并发推理,GPU显存池化管理) │
└──────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────┐
│ 模型仓库 │
│ MinIO(模型文件存储)+ MLflow(版本管理) │
└──────────────────────────────────────────────────┘
```
## 2.2 推理管道设计
手写笔迹识别的完整推理管道包含以下标准化阶段:
**阶段1:输入预处理**
将原始笔迹坐标序列转换为模型可接受的输入格式。对于图像输入的OCR模型,需要将坐标序列渲染为灰度图像;对于序列输入的模型,需要进行坐标归一化和序列填充。
预处理步骤:
```
原始坐标序列 [{x, y, pressure, timestamp}, ...]
坐标去噪(去除异常跳变点,使用Savitzky-Golay滤波器)
坐标归一化(缩放至[0,1]范围,消除书写面积差异)
笔画分割(根据笔离纸事件标志分割为独立笔画)
图像渲染(将坐标序列绘制为固定分辨率的灰度图像,用于OCR模型输入)
格式化为模型输入张量
```
**阶段2:模型推理**
根据识别任务类型,将处理后的输入数据分发至对应的推理模型:
| 识别任务 | 模型类型 | 输入格式 | 输出格式 |
|---------|---------|---------|---------|
| 手写文字OCR | PP-OCR v4(检测+识别) | 灰度图像 224×224 | 文字字符串 + 置信度 |
| 数学公式识别 | Transformer-based | 笔画序列坐标 | LaTeX格式公式 |
| 笔顺评分 | LSTM序列分类 | 笔画顺序序列 | 笔顺正确性分数 |
| 书写质量评测 | CNN分类 | 单字图像 64×64 | 多维度质量分数 |
| 作文评分 | BERT+多任务学习 | 识别后的文字序列 | 各维度评分 |
**阶段3:后处理**
对模型原始输出进行格式化和质量过滤:
- 置信度过滤:低于阈值(默认0.7)的识别结果标记为"需人工确认"
- 结果合并:将多个字符的识别结果按位置关系合并为完整的词句
- 格式转换:将模型输出转换为标准化的JSON响应格式
- 错误恢复:推理异常时返回降级结果(如置信度为0的空结果)而非抛出异常
## 2.3 各层次模块说明
**接口层:**
接口层提供两种接入方式,适应不同的调用场景:
FastAPI REST接口用于云平台后端的同步调用,适合单次识别请求。接口为异步设计(async/await),在等待GPU推理时不阻塞服务器线程,理论上支持数千个并发连接。
gRPC接口用于高性能批量识别和流式识别场景。与云平台和算力盒之间的大量并发识别请求通过gRPC流式传输处理,gRPC基于HTTP/2协议,支持请求多路复用,网络利用率更高。
**调度层:**
Celery分布式任务队列负责管理识别请求的优先级和资源分配:
- 高优先级队列(realtime_queue):处理课堂互动场景的实时识别请求,延迟目标 < 500ms
- 中优先级队列(assignment_queue):处理作业批改请求,延迟目标 < 5s
- 低优先级队列(batch_queue):处理批量历史数据重新识别请求,无严格延迟要求
Celery Worker数量根据GPU资源动态调整,每个Worker绑定一个GPU推理进程。
**推理层:**
推理层为每种识别任务实现独立的引擎类,继承自统一的BaseEngine抽象基类:
```python
class BaseEngine:
def preprocess(self, stroke_data: StrokeData) -> Tensor
def infer(self, tensor: Tensor) -> RawResult
def postprocess(self, raw_result: RawResult) -> RecognitionResult
def recognize(self, stroke_data: StrokeData) -> RecognitionResult
```
各具体引擎类(OCREngine, MathEngine, StrokeOrderEngine等)分别实现上述接口,保持一致的调用协议。
## 2.4 数据结构设计
**输入数据结构(StrokeData):**
```python
@dataclass
class StrokePoint:
x: int # X坐标(点阵码坐标系,单位:0.01mm)
y: int # Y坐标
pressure: int # 笔压(0-255
timestamp: int # 时间戳(毫秒)
pen_up: bool # 是否抬笔标志
@dataclass
class Stroke:
points: List[StrokePoint] # 单条笔画的坐标点列表
stroke_index: int # 笔画序号(从0开始)
@dataclass
class StrokeData:
strokes: List[Stroke] # 所有笔画列表
pen_id: str # 笔设备MAC地址
page_id: int # 对应点阵纸张页面ID
student_id: int # 学生ID
assignment_id: int # 作业ID
region_type: str # 书写区域类型(hanzi/math/text/essay
```
**识别结果数据结构(RecognitionResult):**
```python
@dataclass
class BoundingBox:
x1: int; y1: int; x2: int; y2: int # 矩形边界坐标
@dataclass
class OCRResult:
text: str # 识别的文字内容
confidence: float # 识别置信度(0.0-1.0
bbox: BoundingBox # 文字区域边界框
char_details: List[CharDetail] # 逐字详情(含每字的置信度)
@dataclass
class MathResult:
latex: str # LaTeX格式数学公式
display_formula: str # 可读展示格式
numeric_result: Optional[str] # 计算数值结果(若可计算)
is_correct: Optional[bool] # 是否答对(需标准答案对比)
steps: List[str] # 解题步骤列表
@dataclass
class StrokeOrderResult:
char: str # 被评估的汉字
written_order: List[int] # 学生实际书写的笔画顺序
correct_order: List[int] # 标准笔顺
score: int # 笔顺得分(0-100
errors: List[StrokeOrderError] # 错误笔顺列表
@dataclass
class WritingQualityResult:
overall_score: int # 总体书写质量分(0-100
structure_score: int # 字形结构分
proportion_score: int # 笔画比例分
regularity_score: int # 规范性分
suggestions: List[str] # 改进建议列表
@dataclass
class EssayResult:
total_score: int # 作文总分(满分100分)
structure_score: int # 结构完整性分
language_score: int # 语言表达分
content_score: int # 内容丰富性分
handwriting_score: int # 书写规范性分
error_marks: List[ErrorMark] # 错误标注位置列表
overall_comment: str # 总体评语
```
## 2.5 接口设计
**REST接口(FastAPI):**
| 接口名称 | HTTP方法 | 路径 | 请求体 | 响应体 | 说明 |
|---------|---------|-----|-------|-------|------|
| 文字OCR识别 | POST | /api/v1/ocr/recognize | StrokeData | OCRResult | 单次文字识别 |
| 数学公式识别 | POST | /api/v1/math/recognize | StrokeData | MathResult | 数学列式识别 |
| 笔顺评分 | POST | /api/v1/stroke-order/evaluate | StrokeData + char | StrokeOrderResult | 汉字笔顺评估 |
| 书写质量评测 | POST | /api/v1/writing/quality | StrokeData | WritingQualityResult | 书写质量分析 |
| 作文批改 | POST | /api/v1/essay/review | StrokeData | EssayResult | AI作文评分 |
| 批量识别(异步) | POST | /api/v1/recognize/batch | List[StrokeData] | task_id | 异步批量识别,返回任务ID |
| 查询任务结果 | GET | /api/v1/task/{task_id} | - | List[RecognitionResult] | 查询批量识别结果 |
| 模型状态 | GET | /api/v1/model/status | - | List[ModelStatus] | 所有已加载模型状态 |
**gRPC接口(高性能流式):**
```protobuf
service RecognitionService {
// 单次识别(Unary RPC
rpc Recognize(RecognizeRequest) returns (RecognizeResponse);
// 流式批量识别(Client Streaming RPC
rpc BatchRecognize(stream RecognizeRequest) returns (BatchRecognizeResponse);
// 实时课堂识别(Bidirectional Streaming RPC
rpc StreamRecognize(stream RecognizeRequest) returns (stream RecognizeResponse);
}
message RecognizeRequest {
repeated StrokeProto strokes = 1;
string region_type = 2; // "ocr" / "math" / "stroke_order" / "essay"
int64 student_id = 3;
int64 assignment_id = 4;
string request_id = 5; // 请求幂等ID
}
message StrokeProto {
repeated PointProto points = 1;
int32 stroke_index = 2;
}
message PointProto {
int32 x = 1;
int32 y = 2;
int32 pressure = 3;
int64 timestamp = 4;
bool pen_up = 5;
}
```
## 2.6 安全设计
**服务间认证:**
AI引擎作为内部服务,不直接暴露至公网。服务间通信采用mTLS(双向TLS)认证,调用方需持有CA签发的客户端证书,AI引擎验证客户端证书合法性后才处理请求。
**输入安全校验:**
- 数据大小限制:单次识别请求的笔画数据不超过1MB,防止超大请求占用GPU资源
- 数据格式校验:使用Pydantic对输入数据进行严格类型校验,拒绝格式不合规的请求
- 超时控制:单次推理任务超时30秒自动终止,防止GPU资源被长期占用
**模型文件保护:**
- 模型文件以加密方式存储于MinIO,下载时使用签名URL,有效期1小时
- 模型加载到Triton Server后,在GPU显存中的模型参数不允许外部读取
- 所有模型版本信息记录至MLflow,支持版本追溯和合规审计
**数据隐私:**
- 学生笔迹数据在AI引擎服务中仅用于推理,不持久化存储
- 识别完成后临时数据(预处理图像、中间张量)立即从内存清除
- 所有识别请求的调用日志仅记录任务ID和耗时,不记录原始笔迹内容
## 2.7 部署架构
**GPU集群部署方案:**
```
识别请求入口(内部网络)
NVIDIA Triton Inference Server集群
┌────────────────────────────┐
│ GPU服务器节点1 │
│ ┌─────────────────────┐ │
│ │ Triton Server进程 │ │
│ │ 模型1: PP-OCR v4 │ │
│ │ 模型2: MathNet │ │
│ │ 模型3: StrokeOrder │ │
│ └─────────────────────┘ │
│ NVIDIA T4 GPU × 2 │
└────────────────────────────┘
┌────────────────────────────┐
│ GPU服务器节点N(同上结构) │
└────────────────────────────┘
Celery调度层(CPU服务器)
┌─────────────────────────────────────────┐
│ Celery Worker × 8(每个Worker对应一个GPU)│
│ Redis Broker(任务队列) │
└─────────────────────────────────────────┘
FastAPI/gRPC接口层(CPU服务器,多副本)
```
**模型热更新流程:**
新模型上线采用金丝雀发布策略:
1. 新模型文件上传至MinIO并在MLflow注册新版本
2. Triton Server加载新模型版本,保留旧版本(双版本并行运行)
3. 通过路由配置将5%的请求路由至新版本模型(金丝雀流量)
4. 观察新版本模型的识别准确率和推理延迟指标(观察期24小时)
5. 指标正常则逐步将流量从5%提升至50%、100%,完成版本切换
6. 旧版本模型保留7天后从Triton中卸载,释放GPU显存
---
# 第三章 核心模块功能详细说明
## 3.1 中英文手写文字OCR识别模块
**模块文件:** `engine/ocr_engine.py`
**功能概述:**
OCR识别模块基于PaddleOCR的PP-OCR v4模型,对手写笔迹进行文字识别,支持简体中文、繁体中文、英文字母、阿拉伯数字的混合识别。针对点阵笔书写场景的特点(笔迹细、书写速度快、字体不规范),对通用OCR模型进行了针对性微调。
**处理流程:**
```
步骤1:接收StrokeData(笔画坐标序列)
步骤2:调用预处理模块(preprocessing/stroke_preprocessor.py
- 去除噪声点(压力值异常的点)
- 坐标归一化(缩放至标准坐标系)
- 笔画平滑(贝塞尔曲线拟合)
步骤3:调用图像渲染器,将平滑后的笔画绘制为灰度图像
- 分辨率:640×480像素(A4纸比例)
- 线宽:根据压力值动态调整(2-5像素)
步骤4:将灰度图像发送至Triton ServerOCR检测模型)
- 检测模型:PP-OCRv4-det(文字区域检测)
- 输出:文字边界框列表(BBox坐标)
步骤5:裁剪各文字区域,发送至OCR识别模型
- 识别模型:PP-OCRv4-rec(字符序列识别)
- 输出:字符串序列 + CTC解码置信度
步骤6:后处理:合并相邻文字区域,生成完整识别结果
步骤7:返回OCRResult(含识别文本和置信度)
```
**微调策略:**
针对K12教育场景的手写特点,收集并标注了覆盖1-9年级所有常用汉字的手写样本10万余张,在PP-OCR预训练模型基础上进行微调,重点提升以下场景的识别准确率:
- 低年级学生字体不规范(笔画变形、偏旁错位)
- 书写速度快导致的笔画粘连
- 铅笔书写(笔迹较轻,对比度低)
**性能指标:**
| 指标 | 目标值 | 实测值 |
|------|-------|-------|
| 单字识别准确率(标准书写) | ≥ 96% | 97.3% |
| 单字识别准确率(低年级书写) | ≥ 90% | 91.8% |
| 单次识别延迟(单字) | ≤ 200ms | 约120msT4 GPU |
| 单次识别延迟(整页约50字) | ≤ 1s | 约600msT4 GPU |
## 3.2 数学列式与公式识别模块
**模块文件:** `engine/math_engine.py`
**功能概述:**
数学识别模块专门处理K12阶段数学手写内容,支持从小学四则运算到初中方程、不等式等数学表达式的识别和解析,是AI引擎的差异化核心能力之一。
**支持识别的数学元素:**
| 类别 | 示例 | 说明 |
|------|------|------|
| 四则运算 | 123 + 456 = 579 | 加减乘除运算式和结果验证 |
| 分数 | 3/4 + 1/2 = 5/4 | 分子分母识别,通分计算 |
| 小数 | 3.14 × 2 = 6.28 | 小数点精确识别 |
| 方程 | 2x + 3 = 7 | 含未知数方程,求解验证 |
| 不等式 | x > 5 | 不等号识别 |
| 几何符号 | ∠ABC = 90° | 角度、平行、垂直等几何符号 |
| 数学函数 | sin30° = 0.5 | 三角函数(初中及以上) |
**识别与验证流程:**
```
步骤1:笔迹预处理(同OCR预处理流程)
步骤2:数学符号检测(定位各运算符和数字的位置和类型)
步骤3:结构解析(根据位置关系建立数学表达式的树形结构)
- 识别分数线(水平长横线)分隔分子分母
- 识别上下标(指数、角标)
- 识别括号嵌套层次
步骤4:转换为LaTeX格式表达式(如 \frac{3}{4} + \frac{1}{2}
步骤5:调用符号计算引擎(SymPy库)进行数值验证
- 对于含等号的算式:计算左右两边并比对
- 对于方程:求解并返回解
步骤6:生成MathResultLaTeX格式 + 计算结果 + 正确与否)
```
## 3.3 中文汉字笔顺识别与评分模块
**模块文件:** `engine/stroke_order_engine.py`
**功能概述:**
笔顺评分模块通过分析学生书写汉字时的笔画顺序,与内置的汉字标准笔顺数据库进行比对,评估书写的规范程度。该模块利用了点阵笔数据的独特优势——每支笔画的书写时间戳信息,使得笔顺分析成为可能(传统OCR仅能识别最终图像,无法获取书写过程)。
**标准笔顺数据库:**
内置覆盖3500个常用汉字(国家语委《现代汉语常用字表》)的标准笔顺库,数据来源为教育部发布的《汉字笔顺规范》。每个汉字的笔顺以笔画序号列表形式存储,如"一"字的标准笔顺为[1]1画横),"人"字的标准笔顺为[1, 2](先撇后捺)。
**笔顺评分算法:**
```
步骤1:按时间戳对笔画序列排序,得到学生实际书写顺序
步骤2:通过笔画方向特征(起笔方向、运笔方向、收笔方式)识别每条笔画的类型
(横/竖/撇/捺/点/折等8种基本笔画类型)
步骤3:将识别的笔画类型序列与目标汉字的标准笔顺库匹配
步骤4:使用编辑距离算法(Levenshtein Distance)计算学生笔顺与标准笔顺的差异度
步骤5:计算笔顺得分:
- 完全正确:100分
- 1处错误:80分
- 2处错误:60分
- 3处及以上错误:40分或以下
步骤6:标注具体的错误位置和正确顺序,生成评语
```
**输出示例:**
```json
{
"char": "永",
"written_order": [1, 2, 4, 3, 5, 6, 7, 8],
"correct_order": [1, 2, 3, 4, 5, 6, 7, 8],
"score": 85,
"errors": [
{
"position": 3,
"written": "竖弯钩",
"expected": "横折折撇",
"suggestion": "第3笔应先写横折折撇(㇘),再写竖弯钩"
}
]
}
```
## 3.4 书写质量评测模块
**模块文件:** `engine/writing_quality_engine.py`
**功能概述:**
书写质量评测模块对学生书写的汉字从字形结构、笔画比例、书写规范性三个维度进行综合评测,帮助学生提升书写美观度和规范性,适用于写字课、书法课等场景。
**评测维度与算法:**
1)字形结构评分(占总分40%
将学生书写的汉字渲染为标准尺寸图像,与字体模板库(仿宋体/楷体)进行结构对比。使用基于深度学习的相似度模型,计算笔画布局和重心位置的偏差程度,偏差越小得分越高。
2)笔画比例评分(占总分30%
分析各笔画的相对长度和角度是否符合标准比例。如"土"字中,下横应明显长于上横;"口"字的宽高比应接近1:1等。使用规则引擎对常见汉字的关键比例进行检测。
3)书写规范性评分(占总分30%
评估书写是否符合国家规定的书写规范:
- 笔画是否有起笔和收笔动作(而非随意涂划)
- 相邻笔画间距是否均匀
- 整字的倾斜角度是否在合理范围(±15°以内)
## 3.5 AI作文评分与批改模块
**模块文件:** `engine/essay_engine.py`
**功能概述:**
作文评分模块首先调用OCR模块将手写作文转换为文字,然后基于BERT预训练语言模型(使用Chinese-BERT-wwm微调版本)从多个维度对作文进行智能评分,并标注错别字和明显语法错误的位置。
**评分维度:**
| 维度 | 权重 | 评测内容 |
|------|------|---------|
| 结构完整性 | 25% | 是否有开头/主体/结尾,段落划分是否合理 |
| 语言表达 | 30% | 用词是否准确,句式是否多样,是否存在语病 |
| 内容丰富性 | 30% | 内容是否切题,是否有具体事例,立意是否新颖 |
| 书写规范性 | 15% | 错别字数量,标点符号使用是否正确 |
**错别字检测:**
使用基于字音相似和字形相似的错别字检测模型,结合N-gram语言模型判断词语在上下文中的合理性,综合识别常见错别字类型:
- 音近字:(渴/喝,带/戴)
- 形近字:(己/已,土/士)
- 字义混淆:(的/地/得,其他/其它)
## 3.6 自动批改引擎模块
**模块文件:** `service/grading_service.py`
**功能概述:**
自动批改引擎针对结构化题目(选择题、填空题、简答题)进行自动批改,根据教师预设的标准答案和评分规则,对学生识别后的作答内容进行评判。
**批改规则类型:**
| 规则类型 | 说明 | 示例 |
|---------|------|------|
| 精确匹配 | 作答与标准答案完全一致(字符级) | 填写"北京",标准答案"北京" |
| 容错匹配 | 允许同义词和变体(由教师配置容错词库) | "首都"视为"北京"的等价答案 |
| 数值范围匹配 | 数值结果在允许误差范围内 | 计算结果允许±0.01的误差 |
| 关键词匹配 | 简答题包含所有关键词即得分 | 简答题含"光合作用/叶绿体/葡萄糖"三个关键词 |
| 部分给分 | 简答题按关键词命中数量比例给分 | 3个关键词各占1/3分数 |
## 3.7 识别置信度评估模块
**模块文件:** `service/confidence_service.py`
**功能概述:**
置信度评估模块对每个识别结果的可靠性进行量化评分,引导后续处理流程决定是否需要人工干预,是AI引擎质量控制的关键环节。
**置信度计算方法:**
最终置信度由多个维度的分数加权计算得出:
```python
final_confidence = (
model_confidence * 0.5 + # 模型本身的softmax置信度
stroke_density_score * 0.2 + # 笔画密度质量分(过疏或过密均降低)
writing_consistency * 0.3 # 书写一致性分(与学生历史书写风格对比)
)
```
置信度分级与处理策略:
| 置信度范围 | 级别 | 处理策略 |
|-----------|------|---------|
| 0.90 - 1.00 | 高置信 | 自动接受,无需人工审核 |
| 0.70 - 0.89 | 中置信 | 自动接受,但在批改结果中标记颜色提示教师关注 |
| 0.50 - 0.69 | 低置信 | 标记为"需人工确认",教师手动复核 |
| 0.00 - 0.49 | 不可信 | 标记为"识别失败",不计入自动批改成绩 |
## 3.8 模型版本管理与热更新模块
**模块文件:** `service/model_manager.py`
**功能概述:**
模型版本管理模块负责管理AI引擎中所有推理模型的版本生命周期,支持模型的注册、加载、切换、回滚和归档操作,确保模型更新过程不影响线上服务的稳定性。
**版本管理功能:**
(1)模型注册:新训练完成的模型通过MLflow API注册,记录模型名称、版本号、训练数据集版本、训练时间、评估指标(准确率/召回率/F1)等元数据。
(2)版本状态管理:每个模型版本有以下状态:
- Staging(待验证):模型已注册,正在测试评估中
- Production(生产中):当前线上使用版本
- Archived(已归档):已退出使用的历史版本
3)模型加载:Triton Server在启动时读取模型配置文件,从MinIO下载对应版本的模型文件,加载到GPU显存。支持运行时动态加载新版本而不重启服务。
(4)灰度发布:通过配置路由权重,将一定比例的推理请求路由至新版本模型,实现金丝雀发布。
(5)快速回滚:若新版本模型上线后发现准确率下降或错误率异常升高,可在30秒内将流量100%切回旧版本模型,最大限度减少对教学的影响。
## 3.9 推理任务调度模块
**模块文件:** `service/task_scheduler.py`
**功能概述:**
任务调度模块基于Celery实现分布式任务队列,管理多种优先级的识别任务,协调GPU资源的公平分配,防止低优先级的批量任务占用全部GPU资源导致高优先级实时任务延迟。
**调度策略:**
```python
# Celery队列配置
CELERY_TASK_ROUTES = {
'tasks.realtime_recognize': {'queue': 'realtime'}, # 实时课堂识别
'tasks.assignment_recognize': {'queue': 'assignment'}, # 作业批改识别
'tasks.batch_recognize': {'queue': 'batch'}, # 批量历史识别
}
# 各队列的Worker预留数量
QUEUE_WORKER_RESERVATION = {
'realtime': 4, # 预留4个Worker专用于实时任务,不被其他任务占用
'assignment': 2, # 2个Worker用于作业批改
'batch': 2, # 2个Worker用于批量任务(空闲时使用所有剩余Worker)
}
```
---
# 第四章 操作流程与使用步骤
## 4.1 服务部署与启动
**基于Docker Compose的开发环境部署:**
```
步骤1:确认NVIDIA驱动和CUDA已正确安装
nvidia-smi(应显示GPU信息)
步骤2:确认nvidia-container-toolkit已安装(使Docker容器可访问GPU)
步骤3:从代码仓库拉取AI引擎代码
git clone https://git.writech.com/ai-engine.git
步骤4:进入项目目录,复制环境配置文件
cp .env.example .env
步骤5:编辑.env文件,配置以下关键参数:
REDIS_URL=redis://redis:6379/0
MODEL_STORE_PATH=/models
MINIO_ENDPOINT=minio:9000
MINIO_ACCESS_KEY=your_access_key
MINIO_SECRET_KEY=your_secret_key
步骤6:启动服务(包含Redis、MinIO、Triton Server、Celery Worker、FastAPI
docker compose up -d
步骤7:检查服务启动状态
docker compose ps(各服务应为running或healthy
步骤8:验证API可用性
curl http://localhost:8000/api/v1/model/status
```
**模型文件准备:**
```
步骤1:从MinIO或共享存储下载预训练模型文件
python scripts/download_models.py --version v1.0
步骤2:验证模型文件完整性(SHA256校验)
python scripts/verify_models.py
步骤3:将模型文件放置至正确目录结构:
/models/
├── ocr_det/ # OCR检测模型
│ └── 1/
│ └── model.onnx
├── ocr_rec/ # OCR识别模型
│ └── 1/
│ └── model.onnx
├── math_rec/ # 数学识别模型
│ └── 1/
│ └── model.pt
└── stroke_order/ # 笔顺评分模型
└── 1/
└── model.onnx
步骤4Triton Server将自动扫描/models目录并加载所有模型
```
## 4.2 模型加载与初始化
**Triton Server模型配置文件示例(OCR识别模型):**
```
文件路径:/models/ocr_rec/config.pbtxt
name: "ocr_rec"
platform: "onnxruntime_onnx"
max_batch_size: 32
input [
{
name: "images"
data_type: TYPE_FP32
dims: [3, 48, -1] # 通道×高度×宽度(宽度可变)
}
]
output [
{
name: "output"
data_type: TYPE_FP32
dims: [-1, 97] # 序列长度×字符表大小
}
]
instance_group [
{
count: 2
kind: KIND_GPU
gpus: [0]
}
]
```
## 4.3 识别任务提交流程
**通过REST API提交单次OCR识别任务:**
```
HTTP请求示例:
POST http://ai-engine:8000/api/v1/ocr/recognize
Content-Type: application/json
Authorization: Bearer <service_token>
{
"strokes": [
{
"stroke_index": 0,
"points": [
{"x": 1200, "y": 800, "pressure": 150, "timestamp": 1700000000100, "pen_up": false},
{"x": 1250, "y": 800, "pressure": 145, "timestamp": 1700000000110, "pen_up": false},
{"x": 1300, "y": 800, "pressure": 140, "timestamp": 1700000000120, "pen_up": true}
]
}
],
"region_type": "ocr",
"student_id": 12345,
"assignment_id": 67890
}
期望响应(200 OK):
{
"code": 200,
"data": {
"text": "一",
"confidence": 0.99,
"bbox": {"x1": 1180, "y1": 780, "x2": 1320, "y2": 820},
"char_details": [
{"char": "一", "confidence": 0.99, "bbox": {...}}
]
},
"latency_ms": 125
}
```
**通过gRPC提交批量识别任务(Python示例):**
```python
import grpc
from proto import recognition_pb2, recognition_pb2_grpc
channel = grpc.secure_channel('ai-engine:50051', credentials)
stub = recognition_pb2_grpc.RecognitionServiceStub(channel)
# 构建请求列表
requests = []
for stroke_data in stroke_data_list:
request = recognition_pb2.RecognizeRequest(
strokes=[...],
region_type="ocr",
student_id=student_id,
assignment_id=assignment_id
)
requests.append(request)
# 发送流式批量识别请求(返回结果流)
def request_generator():
for req in requests:
yield req
responses = stub.BatchRecognize(request_generator())
results = list(responses.results)
```
## 4.4 模型更新与灰度发布流程
**发布新版OCR模型的操作步骤:**
```
步骤1:模型训练完成后,在开发环境测试评估新模型
python eval/evaluate_ocr.py --model-path models/ocr_rec_v1.1.onnx
(应输出:准确率 97.5%,高于当前生产版本的97.3%)
步骤2:将新模型文件上传至MinIO
python scripts/upload_model.py --model ocr_rec_v1.1.onnx --version 1.1
步骤3:在MLflow注册新模型版本
python scripts/register_model.py --name ocr_rec --version 1.1 --stage Staging
步骤4:在Triton Server加载新版本模型(不停机)
python scripts/triton_ops.py --action load --model ocr_rec --version 1.1
步骤5:配置5%金丝雀流量至新版本
python scripts/routing_config.py --model ocr_rec --new-version 1.1 --traffic 5
步骤6:观察监控面板(Grafana中AI引擎Dashboard
- 新版本准确率指标(应高于旧版本)
- 新版本推理延迟(应与旧版本持平或更低)
- 新版本错误率(应接近0)
步骤7:确认指标正常后,逐步扩大流量比例(5% → 50% → 100%
步骤8:新版本稳定后,更新MLflow中的模型状态为Production
步骤9:旧版本模型状态改为Archived,7天后从Triton中卸载
```
## 4.5 性能调优与故障排除
**常见性能问题排查:**
| 问题现象 | 可能原因 | 排查方法 | 处理方案 |
|---------|---------|---------|---------|
| 识别延迟突然升高 | GPU利用率过高或Celery队列积压 | 查看Grafana中GPU利用率和队列深度 | 增加Worker数量或限流 |
| 识别准确率下降 | 模型加载异常或输入数据格式变化 | 查看模型版本,对比历史准确率 | 重新加载模型或回滚版本 |
| gRPC连接超时 | 网络问题或服务重启 | 检查服务状态和网络连通性 | 重启gRPC服务,检查网络配置 |
| 内存OOM崩溃 | 模型太大或并发请求过多占用内存 | 查看服务器内存使用率 | 减少并发Worker数量,增加内存 |
| 特定字符识别率低 | 训练数据不足或模型偏差 | 统计错误字符频率分布 | 补充相应字符的训练样本后重训 |
**GPU资源监控指标:**
```
通过NVIDIA Management LibraryNVML)监控以下关键指标:
- GPU利用率(%):正常范围60-80%,超过95%需要扩容
- GPU显存使用(MB):加载所有模型后约占8-12GB(T4显卡16GB
- GPU温度(°C):正常范围60-75°C,超过85°C触发降频告警
- GPU功耗(W):T4正常功耗80-120W
```
---
# 第五章 与源代码的对应关系
## 5.1 模块名称与源代码文件对应表
| 功能模块 | 目录/文件路径 | 主要类/函数 | 说明 |
|---------|-------------|-----------|------|
| 应用程序入口 | `main.py` | `app`FastAPI实例), `main()` | 服务启动,路由注册,中间件配置 |
| REST接口层 | `api/` | `ocr_router.py`, `math_router.py`, `essay_router.py`, `model_router.py` | 各识别类型的REST接口路由 |
| gRPC服务层 | `grpc_server/` | `recognition_server.py`RecognitionService实现) | gRPC流式识别服务 |
| 笔迹预处理 | `preprocessing/` | `stroke_preprocessor.py`StrokePreprocessor类) | 去噪、归一化、笔画分割、图像渲染 |
| OCR识别引擎 | `engine/` | `ocr_engine.py`OCREngine类) | 基于PaddleOCR的文字识别 |
| 数学识别引擎 | `engine/` | `math_engine.py`MathEngine类) | 数学公式识别和验证 |
| 笔顺评分引擎 | `engine/` | BaseEngine子类,stroke_order_engine.py | 汉字笔顺评分 |
| 任务调度服务 | `service/` | `task_scheduler.py`Celery任务定义) | Celery任务队列,优先级调度 |
| 批改业务服务 | `service/` | `grading_service.py`GradingService类) | 自动批改规则引擎 |
| 服务配置 | `config/` | `settings.py`Settings类,Pydantic BaseSettings | 环境变量配置加载 |
## 5.2 核心函数与方法说明
**main.py 核心函数:**
```python
# FastAPI应用实例
app = FastAPI(title="Writech AI Recognition Engine", version="1.0.0")
# 启动时加载所有模型
@app.on_event("startup")
async def startup_event():
await ModelManager.load_all_models()
await initialize_celery_workers()
# 注册路由
app.include_router(ocr_router, prefix="/api/v1/ocr")
app.include_router(math_router, prefix="/api/v1/math")
app.include_router(essay_router, prefix="/api/v1/essay")
app.include_router(model_router, prefix="/api/v1/model")
```
**OCREngine 核心方法:**
| 方法名 | 签名 | 功能说明 |
|-------|-----|---------|
| `preprocess` | `preprocess(stroke_data: StrokeData) -> np.ndarray` | 将笔迹坐标转换为灰度图像数组 |
| `infer` | `infer(image: np.ndarray) -> RawOCRResult` | 调用Triton进行OCR推理 |
| `postprocess` | `postprocess(raw: RawOCRResult) -> OCRResult` | 合并字符,过滤低置信度结果 |
| `recognize` | `recognize(stroke_data: StrokeData) -> OCRResult` | 完整OCR识别流程入口 |
**preprocessing/stroke_preprocessor.py 核心方法:**
| 方法名 | 功能说明 |
|-------|---------|
| `remove_noise_points(strokes)` | 去除异常跳变点(使用统计方法检测离群点) |
| `normalize_coordinates(strokes)` | 坐标归一化至[0,1]范围 |
| `split_strokes_by_pen_up(strokes)` | 根据pen_up标志将连续坐标序列分割为独立笔画 |
| `render_to_image(strokes, size)` | 将笔画列表渲染为指定尺寸的灰度图像(PIL/OpenCV |
| `apply_bezier_smoothing(stroke)` | 对单条笔画应用贝塞尔曲线平滑,去除抖动 |
## 5.3 命名规范
**Python包命名规范:**
```
writech_ai_engine/
├── api/ # REST接口路由(功能名+_router.py
├── config/ # 配置类(settings.py,单文件)
├── engine/ # 识别引擎类(功能名+_engine.py
├── grpc_server/ # gRPC服务实现
├── preprocessing/ # 数据预处理
├── service/ # 业务服务层(功能名+_service.py
└── model/ # 数据模型(Pydantic Schema类)
```
**类命名规范:**
| 类型 | 命名规则 | 示例 |
|------|---------|------|
| 识别引擎类 | XxxEngine | OCREngine, MathEngine, StrokeOrderEngine |
| 数据模型类 | XxxResult / XxxData | OCRResult, StrokeData, MathResult |
| 服务类 | XxxService | GradingService, ModelManagerService |
| 配置类 | XxxSettings / XxxConfig | Settings, TritonConfig |
| Celery任务 | 函数命名:xxx_task | ocr_recognize_task, essay_review_task |
---
# 附录
## 附录A 术语表
| 术语 | 说明 |
|------|------|
| PaddleOCR | 百度开源的OCR工具库,基于飞桨深度学习框架,提供文字检测、识别等完整OCR能力 |
| Triton Inference Server | NVIDIA提供的高性能模型服务框架,支持多模型并发推理和GPU资源调度 |
| ONNX | 开放神经网络交换格式(Open Neural Network Exchange),统一不同框架模型的表示格式 |
| MLflow | 开源的机器学习生命周期管理平台,提供实验追踪、模型注册、版本管理功能 |
| Celery | Python分布式任务队列框架,支持异步任务和定时任务 |
| CTC | 连接时序分类(Connectionist Temporal Classification),OCR解码算法之一 |
| BERT | 双向编码器表示(Bidirectional Encoder Representations from Transformers),NLP预训练模型 |
| mTLS | 双向TLS认证,通信双方均需出示证书,用于服务间高安全性认证 |
| 置信度 | 模型对识别结果的确信程度,取值0-1,值越大表示结果越可靠 |
| 金丝雀发布 | 灰度发布策略,新版本先接受少量流量,验证稳定后再全量切换 |
## 附录B 版本历史
| 版本号 | 发布日期 | 变更说明 |
|-------|---------|---------|
| V1.0 | 2026年2月 | 初始版本发布,支持OCR/数学/笔顺/书写质量/作文评分全套功能 |
---
**编制单位**:深圳自然写科技有限公司
**文档版本**V1.0
**编制日期**2026年2月
**版权声明**:本文档版权归深圳自然写科技有限公司所有,未经授权不得复制或传播
---
## 附录C AI 模型详细技术说明
### C.1 PaddleOCR 手写识别模型架构
#### C.1.1 模型总体架构
自然写 AI 引擎的手写识别模块基于 PaddleOCR 框架,采用 DB+CRNN 两阶段识别流水线:
```
输入:笔迹图像(灰度图,224×224)
Stage 1: 文本检测(DB - Differentiable Binarization
│ 检测文字区域边界框
│ 骨干网络:ResNet-50 + FPN 特征金字塔
文字区域裁剪(透视变换矫正倾斜)
Stage 2: 文字识别(CRNN - CNN+RNN
│ CNNVGG-like)提取特征列
│ 双向 LSTM 序列建模
│ CTC 解码(Connectionist Temporal Classification
输出:识别文字字符串 + 置信度
```
#### C.1.2 训练数据集
| 数据集 | 数量 | 来源 |
|-------|------|------|
| 学生楷书练字数据 | 500万字 | 自然写平台采集(脱敏处理) |
| 学生硬笔书法数据 | 200万字 | 自然写平台采集 |
| CASIA 手写中文数据集 | 300万字 | 中科院开放数据集 |
| 合成数据(字体变形)| 1000万字 | 程序合成 |
| 数字与字母 | 50万样本 | 多来源混合 |
**数据增强策略:**
- 随机旋转(±15°)
- 随机缩放(0.8x ~ 1.2x
- 高斯噪声(模拟点阵摄像头图像噪声)
- 透视变换(模拟书写角度偏差)
- 随机笔画粗细变化
- 随机对比度调整(模拟不同笔压)
#### C.1.3 CRNN 模型实现
```python
# crnn_model.py
import paddle
import paddle.nn as nn
class CRNN(nn.Layer):
"""
手写文字序列识别模型
CNN 提取特征 + BiLSTM 序列建模 + CTC 解码
"""
def __init__(self, num_classes: int, hidden_size: int = 256):
super(CRNN, self).__init__()
# CNN 特征提取(输出特征尺寸:N × 512 × 1 × W)
self.cnn = nn.Sequential(
# Block 1: 64 filters
nn.Conv2D(1, 64, kernel_size=3, padding=1),
nn.BatchNorm2D(64),
nn.ReLU(),
nn.MaxPool2D(kernel_size=(2, 2), stride=(2, 2)), # 32x32 -> 16x16
# Block 2: 128 filters
nn.Conv2D(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2D(128),
nn.ReLU(),
nn.MaxPool2D(kernel_size=(2, 2), stride=(2, 2)), # 16x16 -> 8x8
# Block 3: 256 filters(不在高度方向池化)
nn.Conv2D(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2D(256),
nn.ReLU(),
nn.Conv2D(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2D(256),
nn.ReLU(),
nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1)), # 8x8 -> 4xW
# Block 4: 512 filters
nn.Conv2D(256, 512, kernel_size=3, padding=1),
nn.BatchNorm2D(512),
nn.ReLU(),
nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1)), # 4xW -> 2xW
# Block 5: 512 filters,压缩高度为1
nn.Conv2D(512, 512, kernel_size=2, padding=0), # 2xW -> 1xW
nn.BatchNorm2D(512),
nn.ReLU(),
)
# BiLSTM 序列建模
self.rnn = nn.Sequential(
BidirectionalLSTM(512, hidden_size, hidden_size),
BidirectionalLSTM(hidden_size, hidden_size, num_classes),
)
def forward(self, x):
# x: (N, 1, H, W)
conv = self.cnn(x) # (N, 512, 1, W)
N, C, H, W = conv.shape
# 重塑为序列:(W, N, C)
conv = conv.squeeze(2).transpose([2, 0, 1])
output = self.rnn(conv) # (W, N, num_classes)
return output
class BidirectionalLSTM(nn.Layer):
def __init__(self, input_size, hidden_size, output_size):
super(BidirectionalLSTM, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size,
direction='bidirect', time_major=True)
self.linear = nn.Linear(hidden_size * 2, output_size)
def forward(self, x):
output, _ = self.lstm(x)
output = self.linear(output)
return output
```
#### C.1.4 CTC 解码算法
```python
# ctc_decoder.py
import numpy as np
class CTCDecoder:
"""
CTCConnectionist Temporal Classification)解码器
支持贪心解码和束搜索解码
"""
def __init__(self, vocabulary: list[str], blank_token_id: int = 0):
self.vocabulary = vocabulary # 字典(汉字列表)
self.blank_id = blank_token_id
def greedy_decode(self, log_probs: np.ndarray) -> tuple[str, float]:
"""
贪心解码(取每帧最大概率字符)
log_probs: (T, V) - T帧,V个字符的对数概率
返回: (识别文字, 平均置信度)
"""
# 每帧取最大概率的字符索引
indices = np.argmax(log_probs, axis=1) # (T,)
probs = np.exp(np.max(log_probs, axis=1)) # 对应概率
# 合并重复字符并去除 blank
decoded_chars = []
decoded_probs = []
prev = -1
for i, idx in enumerate(indices):
if idx != prev and idx != self.blank_id:
decoded_chars.append(self.vocabulary[idx])
decoded_probs.append(probs[i])
prev = idx
text = ''.join(decoded_chars)
confidence = float(np.mean(decoded_probs)) if decoded_probs else 0.0
return text, confidence
def beam_search_decode(self, log_probs: np.ndarray,
beam_width: int = 10) -> list[tuple[str, float]]:
"""
束搜索解码(返回 top-k 候选字符串)
beam_width: 束宽度(返回候选数量)
返回: [(候选文字, 分数)] 列表(按分数降序)
"""
T, V = log_probs.shape
# 初始化:空字符串,得分为0
beams = [("", 0.0, -1)] # (text, score, last_token)
for t in range(T):
new_beams = {}
for text, score, last_token in beams:
# 扩展每个 beam
for v in range(V):
log_p = log_probs[t, v]
if v == self.blank_id:
# blank:保持文字不变,得分累加
new_text = text
elif v == last_token:
# 重复字符:只在中间有 blank 时才追加
new_text = text
else:
# 新字符
new_text = text + self.vocabulary[v]
key = (new_text, v)
if key not in new_beams:
new_beams[key] = float('-inf')
# log-sum-exp 合并相同结果
new_beams[key] = np.logaddexp(new_beams[key], score + log_p)
# 取 top beam_width
sorted_beams = sorted(new_beams.items(), key=lambda x: -x[1])[:beam_width]
beams = [(text, score, last_token) for (text, last_token), score in sorted_beams]
return [(text, np.exp(score)) for text, score, _ in beams]
```
---
### C.2 数学公式识别模型
#### C.2.1 模型架构概述
数学公式识别采用 Im2Latex 序列到序列模型(CNN + Attention LSTM),将手写数学表达式图像直接转换为 LaTeX 字符串:
```python
# math_ocr_model.py
class Im2LatexModel(nn.Layer):
"""
数学公式图像→LaTeX 序列模型
基于 Attention 机制的编解码架构
"""
def __init__(self, vocab_size: int, embed_size: int = 80,
decode_hidden: int = 512, encode_hidden: int = 256):
super(Im2LatexModel, self).__init__()
# 编码器:CNN(提取视觉特征)
self.encoder = nn.Sequential(
nn.Conv2D(1, 64, kernel_size=3, padding=1), nn.ReLU(),
nn.MaxPool2D(2, 2),
nn.Conv2D(64, 128, kernel_size=3, padding=1), nn.ReLU(),
nn.MaxPool2D(2, 2),
nn.Conv2D(128, 256, kernel_size=3, padding=1), nn.ReLU(),
nn.Conv2D(256, 256, kernel_size=3, padding=1), nn.ReLU(),
nn.MaxPool2D((2, 1)), # 保留水平分辨率
nn.Conv2D(256, encode_hidden, kernel_size=3, padding=1), nn.ReLU(),
)
# 解码器:LSTM + Attention
self.decoder = AttentionDecoder(
encode_hidden, decode_hidden, embed_size, vocab_size
)
def forward(self, images, targets=None):
# 编码图像特征
features = self.encoder(images) # (N, C, H', W')
# 解码序列(训练时使用 teacher forcing
logits, attention_weights = self.decoder(features, targets)
return logits, attention_weights
```
#### C.2.2 数学算式语义校验
识别完数学表达式后,系统进行语义校验(判断等式/不等式是否成立):
```python
# math_validator.py
import re
from decimal import Decimal, InvalidOperation
class MathExpressionValidator:
"""
对识别到的数学表达式进行语义校验
支持:四则运算、分数、简单方程验证
"""
def validate(self, expression: str) -> dict:
"""
校验数学表达式
expression: 识别结果(如 "3 + 5 = 8" 或 "2 × 4 = 8"
返回: {'is_correct': bool, 'expected': str, 'explanation': str}
"""
# 标准化符号(×→*, ÷→/
normalized = self._normalize_symbols(expression)
# 识别等号或不等号
if '=' in normalized:
return self._validate_equation(normalized)
elif '>' in normalized or '<' in normalized:
return self._validate_inequality(normalized)
else:
return {'is_correct': None, 'explanation': '无法判断(表达式不包含等式)'}
def _validate_equation(self, expr: str) -> dict:
parts = expr.split('=')
if len(parts) != 2:
return {'is_correct': False, 'explanation': '格式错误'}
left_expr, right_expr = parts[0].strip(), parts[1].strip()
try:
left_val = self._safe_eval(left_expr)
right_val = self._safe_eval(right_expr)
is_correct = abs(left_val - right_val) < 1e-9
return {
'is_correct': is_correct,
'left_value': str(left_val),
'right_value': str(right_val),
'expected': f"{left_expr} = {left_val}",
'explanation': '正确' if is_correct else f'等号左边={left_val},右边={right_val},不相等'
}
except Exception as e:
return {'is_correct': False, 'explanation': f'计算出错:{str(e)}'}
def _safe_eval(self, expr: str) -> Decimal:
"""安全的数学表达式求值(防注入攻击)"""
# 只允许数字、运算符、括号、小数点
if not re.match(r'^[\d\s\+\-\*\/\(\)\.]+$', expr):
raise ValueError(f"不安全的表达式:{expr}")
# 使用 Decimal 保证精度
return Decimal(str(eval(expr)))
def _normalize_symbols(self, expr: str) -> str:
return (expr.replace('×', '*').replace('÷', '/')
.replace('', '=').replace('', '-'))
```
---
### C.3 笔顺评分模型
#### C.3.1 笔顺检测算法
```python
# stroke_order_evaluator.py
import numpy as np
from dataclasses import dataclass
@dataclass
class StrokeFeatures:
"""笔画特征向量"""
start_x: float # 起点 x 坐标(归一化 0~1
start_y: float # 起点 y 坐标
end_x: float # 终点 x 坐标
end_y: float # 终点 y 坐标
direction_angle: float # 主方向角度(0~360度)
length: float # 笔画长度(归一化)
curvature: float # 曲率(越小越直)
class StrokeOrderEvaluator:
"""
汉字笔顺评分器
基于标准笔顺库比对学生书写笔顺
"""
def __init__(self, stroke_library_path: str):
"""加载标准笔顺库(包含 8105 个通用规范汉字的笔顺数据)"""
with open(stroke_library_path, 'r', encoding='utf-8') as f:
self.library = json.load(f)
def evaluate(self, character: str,
written_strokes: list,
strict_mode: bool = False) -> dict:
"""
评估学生书写笔顺
written_strokes: 学生书写的笔画序列(每条笔画为 InkPoint 列表)
返回: 详细评分结果
"""
if character not in self.library:
return {'error': f'字符 {character} 不在笔顺库中'}
standard_strokes = self.library[character]['strokes']
standard_count = len(standard_strokes)
written_count = len(written_strokes)
# 笔画数量对比
stroke_count_score = 100 if written_count == standard_count else max(
0, 100 - abs(written_count - standard_count) * 20
)
# 逐笔顺序比对(对齐最近的标准笔画)
order_errors = []
for i, written in enumerate(written_strokes):
if i >= standard_count:
order_errors.append({'stroke': i+1, 'error': '多余的笔画'})
continue
written_feat = self._extract_features(written)
expected_feat = standard_strokes[i]['features']
# 方向角度偏差(容忍度:严格模式15°,普通模式30°)
angle_diff = self._angle_diff(written_feat.direction_angle,
expected_feat['direction_angle'])
tolerance = 15 if strict_mode else 30
if angle_diff > tolerance:
order_errors.append({
'stroke': i+1,
'error': f'方向错误(应为{expected_feat["name"]},偏差{angle_diff:.1f}°)'
})
# 计算笔顺得分
error_count = len(order_errors)
order_score = max(0, 40 - error_count * 8) # 笔顺满分40分,每错一笔扣8分
# 字形得分(基于关键点位置偏差)
shape_score = self._evaluate_shape(character, written_strokes)
# 比例得分(基于整体字形比例)
proportion_score = self._evaluate_proportion(character, written_strokes)
total_score = order_score + shape_score + proportion_score
return {
'total_score': min(100, total_score),
'stroke_order_score': order_score,
'shape_score': shape_score,
'proportion_score': proportion_score,
'stroke_count_match': written_count == standard_count,
'errors': order_errors,
'grade': self._score_to_grade(total_score)
}
def _score_to_grade(self, score: float) -> str:
if score >= 90: return '优秀'
elif score >= 75: return '良好'
elif score >= 60: return '合格'
else: return '需加强'
def _angle_diff(self, a1: float, a2: float) -> float:
"""计算两个角度的最小绝对差值(处理360°环绕)"""
diff = abs(a1 - a2) % 360
return min(diff, 360 - diff)
def _extract_features(self, stroke_points: list) -> StrokeFeatures:
"""从笔画点序列提取特征向量"""
if len(stroke_points) < 2:
return StrokeFeatures(0, 0, 0, 0, 0, 0, 0)
xs = [p['x'] for p in stroke_points]
ys = [p['y'] for p in stroke_points]
# 主方向:起点到终点的方向角
dx = xs[-1] - xs[0]
dy = ys[-1] - ys[0]
angle = np.degrees(np.arctan2(dy, dx)) % 360
length = np.sqrt(dx**2 + dy**2)
# 曲率:用笔画路径长度/直线长度比估算
path_length = sum(
np.sqrt((xs[i+1]-xs[i])**2 + (ys[i+1]-ys[i])**2)
for i in range(len(xs)-1)
)
curvature = path_length / max(length, 1e-6)
return StrokeFeatures(
start_x=xs[0], start_y=ys[0],
end_x=xs[-1], end_y=ys[-1],
direction_angle=angle,
length=length,
curvature=curvature
)
```
---
### C.4 模型服务化与部署
#### C.4.1 模型热加载机制
```python
# model_manager.py
class ModelManager:
"""
AI 模型管理器
支持模型热加载(不停服更新)和 A/B 测试
"""
def __init__(self, model_registry_path: str):
self.registry_path = model_registry_path
self._models: dict[str, dict] = {}
self._lock = asyncio.Lock()
# 启动文件监听,自动检测模型更新
self._start_model_watcher()
async def load_model(self, model_name: str, model_version: str):
"""加载指定版本的模型到内存"""
model_path = f"{self.registry_path}/{model_name}/{model_version}"
async with self._lock:
if model_name in self._models:
# 先保留旧模型(用于 A/B 对比或回滚)
old_model = self._models[model_name]
self._models[f"{model_name}_backup"] = old_model
# 异步加载新模型
model = await asyncio.get_event_loop().run_in_executor(
None, self._load_from_disk, model_path
)
self._models[model_name] = {
'model': model,
'version': model_version,
'loaded_at': time.time(),
'call_count': 0,
'total_latency': 0.0
}
logger.info(f"模型 {model_name} v{model_version} 加载完成")
def get_model(self, model_name: str, use_ab_test: bool = False):
"""获取当前活跃模型"""
if use_ab_test and f"{model_name}_backup" in self._models:
# A/B 测试:10% 流量路由到旧模型
if random.random() < 0.1:
return self._models[f"{model_name}_backup"]['model']
return self._models[model_name]['model']
```
#### C.4.2 API 限流与队列
```python
# rate_limiter.pyAI 引擎请求限流)
import asyncio
from collections import defaultdict
class TokenBucketRateLimiter:
"""
令牌桶算法限流器
防止 AI 推理服务被突发请求打垮
"""
def __init__(self, rate: float, capacity: float):
"""
rate: 令牌补充速率(请求/秒)
capacity: 桶容量(最大突发量)
"""
self.rate = rate
self.capacity = capacity
self.tokens: dict[str, float] = defaultdict(lambda: capacity)
self.last_refill: dict[str, float] = defaultdict(time.time)
self._lock = asyncio.Lock()
async def acquire(self, key: str, tokens: float = 1.0) -> bool:
"""
尝试获取令牌
key: 限流维度(如 app_key 或 user_id
返回: True=可以执行,False=被限流
"""
async with self._lock:
now = time.time()
elapsed = now - self.last_refill[key]
# 按时间补充令牌
self.tokens[key] = min(
self.capacity,
self.tokens[key] + elapsed * self.rate
)
self.last_refill[key] = now
if self.tokens[key] >= tokens:
self.tokens[key] -= tokens
return True
return False
# 在 AI 识别接口中使用限流器
class OcrController:
def __init__(self):
self.rate_limiter = TokenBucketRateLimiter(
rate=50, # 50 请求/秒
capacity=100 # 最大突发 100 个请求
)
async def recognize(self, request: OcrRequest) -> OcrResponse:
# 按 AppKey 限流
allowed = await self.rate_limiter.acquire(request.app_key)
if not allowed:
raise RateLimitExceededException("请求频率超限,请降低调用频率")
# 执行识别
return await self.ocr_service.recognize(request)
```
---
## 附录D 接口完整清单
### D.1 识别接口
| 接口 | 方法 | 路径 | QPS限制 | 说明 |
|------|------|------|---------|------|
| 汉字识别 | POST | `/api/v1/ocr/text` | 50/s/AppKey | 识别手写汉字 |
| 数学识别 | POST | `/api/v1/ocr/math` | 30/s/AppKey | 识别数学表达式 |
| 批量识别 | POST | `/api/v1/ocr/batch` | 10/s/AppKey | 批量识别(最多20条/请求) |
| 笔顺评分 | POST | `/api/v1/ocr/stroke-order` | 30/s/AppKey | 汉字笔顺评估与评分 |
| 书写质量 | POST | `/api/v1/ocr/quality` | 30/s/AppKey | 书写质量综合评分 |
| 作业批改 | POST | `/api/v1/correction/assignment` | 5/s/AppKey | 完整作业批改流程 |
### D.2 管理接口
| 接口 | 方法 | 路径 | 说明 |
|------|------|------|------|
| 查询识别配额 | GET | `/api/v1/quota/current` | 查询当前 AppKey 的识别配额余量 |
| 查询调用统计 | GET | `/api/v1/statistics/calls` | 按时间段统计API调用次数和成功率 |
| 获取模型版本 | GET | `/api/v1/models/versions` | 查询当前生产模型版本信息 |
| 异步任务状态 | GET | `/api/v1/tasks/{task_id}` | 查询异步批改任务的执行状态 |
---
## 附录E 部署与性能
### E.1 GPU 推理服务部署
```yaml
# docker-compose.gpu.ymlGPU 推理节点)
services:
ai-engine:
image: registry.writech.com/ai-engine:1.0.0
runtime: nvidia
environment:
NVIDIA_VISIBLE_DEVICES: all
PADDLE_FLAGS: "FLAGS_fraction_of_gpu_memory_to_use=0.8"
MODEL_BATCH_SIZE: 16
MAX_CONCURRENT_REQUESTS: 64
volumes:
- /data/models:/app/models:ro # 只读挂载模型目录
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
```
### E.2 推理性能基准测试
| 模型 | 硬件 | 批次大小 | 平均延迟 | 吞吐量 |
|------|------|---------|---------|-------|
| 汉字识别(CRNN | Tesla T4 | 16 | 8ms/字 | 2000字/秒 |
| 数学识别(Im2Latex | Tesla T4 | 8 | 25ms/式 | 320式/秒 |
| 笔顺评分 | CPU16核) | 1 | 5ms/字 | 200字/秒 |
| 书写质量(综合) | Tesla T4 | 16 | 15ms/字 | 1000字/秒 |
---
*本文档版权归深圳自然写科技有限公司所有,仅用于软件著作权登记鉴别。*
---
## 附录F 核心算法详细实现
### F.1 DB文本检测网络实现
AI引擎使用DBDifferentiable Binarization)算法进行文字区域检测,相比传统固定阈值二值化,DB通过可学习的阈值映射提升检测准确率。
```python
# engine/detection/db_detector.py
import numpy as np
import cv2
import onnxruntime as ort
from typing import List, Tuple
class DBTextDetector:
"""
DB文本检测器(基于ONNX模型推理)
输入:RGB图像(归一化到[-1,1])
输出:文字区域轮廓列表(像素坐标)
"""
# 预处理参数(训练时统计的均值和标准差)
IMG_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMG_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# DB后处理参数
DB_THRESH = 0.3 # 概率图二值化阈值
DB_BOX_THRESH = 0.5 # 检测框置信度阈值
DB_UNCLIP_RATIO = 1.6 # 文字框扩张比例
def __init__(self, model_path: str):
opts = ort.SessionOptions()
opts.intra_op_num_threads = 4
opts.inter_op_num_threads = 2
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self.session = ort.InferenceSession(
model_path,
sess_options=opts,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
self.input_name = self.session.get_inputs()[0].name
def detect(self, image_bgr: np.ndarray) -> List[np.ndarray]:
"""
检测图片中的文字区域
Returns:
文字区域轮廓列表,每个轮廓为(N, 2)形状的numpy数组
"""
# 1. 预处理:等比缩放到32的倍数
h, w = image_bgr.shape[:2]
target_h = self._align_32(min(960, h))
target_w = self._align_32(min(960, w))
scale_h, scale_w = h / target_h, w / target_w
resized = cv2.resize(image_bgr, (target_w, target_h))
img_rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
img_norm = (img_rgb - self.IMG_MEAN) / self.IMG_STD
input_tensor = img_norm.transpose(2, 0, 1)[np.newaxis] # NCHW
# 2. 模型推理
outputs = self.session.run(None, {self.input_name: input_tensor})
prob_map = outputs[0][0, 0] # (H, W) 概率图
# 3. DB后处理:概率图 → 二值图 → 轮廓提取 → Unclip扩张
binary_map = (prob_map > self.DB_THRESH).astype(np.uint8) * 255
contours, _ = cv2.findContours(binary_map, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
text_regions = []
for cnt in contours:
if cv2.contourArea(cnt) < 100: # 过滤极小区域
continue
# 计算轮廓置信度(取概率图在轮廓内的均值)
mask = np.zeros_like(prob_map, dtype=np.uint8)
cv2.drawContours(mask, [cnt], 0, 1, -1)
box_score = float(prob_map[mask == 1].mean())
if box_score < self.DB_BOX_THRESH:
continue
# Unclip扩张(扩大文字框,避免漏识别边缘字符)
expanded = self._unclip(cnt, self.DB_UNCLIP_RATIO)
# 还原到原始图像坐标
expanded[:, 0] = np.clip(expanded[:, 0] * scale_w, 0, w - 1)
expanded[:, 1] = np.clip(expanded[:, 1] * scale_h, 0, h - 1)
text_regions.append(expanded.astype(np.int32))
return text_regions
def _unclip(self, contour: np.ndarray, ratio: float) -> np.ndarray:
"""使用Polygon Offset算法扩张文字框"""
import pyclipper
poly = contour.reshape(-1, 2).astype(np.float32)
area = cv2.contourArea(contour)
peri = cv2.arcLength(contour, True)
if peri < 1e-6:
return poly
distance = area * ratio / peri
pco = pyclipper.PyclipperOffset()
pco.AddPath(poly.astype(np.int32).tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
solution = pco.Execute(int(distance))
if not solution:
return poly
return np.array(solution[0], dtype=np.float32)
@staticmethod
def _align_32(n: int) -> int:
return max(32, (n + 31) // 32 * 32)
```
### F.2 CRNN文字识别实现
```python
# engine/recognition/crnn_recognizer.py
import numpy as np
import onnxruntime as ort
from typing import Tuple, List
class CRNNRecognizer:
"""
CRNN文字识别器(CNN + LSTM + CTC解码)
输入:文字行图像(归一化到32×320)
输出:识别文本 + 置信度
"""
IMG_HEIGHT = 32
IMG_WIDTH = 320
def __init__(self, model_path: str, charset_path: str):
self.session = ort.InferenceSession(
model_path,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
self.input_name = self.session.get_inputs()[0].name
# 加载字符集(包含空白符)
with open(charset_path, 'r', encoding='utf-8') as f:
self.charset = ['BLANK'] + [c.strip() for c in f.readlines()]
self.blank_idx = 0
def recognize(self, line_image_bgr: np.ndarray) -> Tuple[str, float]:
"""
识别单行文字图像
Returns:
(text, confidence) 识别文本和平均置信度
"""
# 预处理:灰度化,等比缩放到32高度,宽度最大320
import cv2
gray = cv2.cvtColor(line_image_bgr, cv2.COLOR_BGR2GRAY)
h, w = gray.shape
target_w = min(self.IMG_WIDTH, int(w * self.IMG_HEIGHT / h))
resized = cv2.resize(gray, (target_w, self.IMG_HEIGHT))
# 填充到标准宽度
canvas = np.zeros((self.IMG_HEIGHT, self.IMG_WIDTH), dtype=np.float32)
canvas[:, :target_w] = resized.astype(np.float32)
canvas = (canvas / 255.0 - 0.5) / 0.5 # 归一化到[-1, 1]
# NCHW格式(1, 1, 32, 320
input_tensor = canvas[np.newaxis, np.newaxis]
# 推理:输出shape为 (T, 1, num_classes)T是时间步
logits = self.session.run(None, {self.input_name: input_tensor})[0]
probs = self._softmax(logits[:, 0, :]) # (T, num_classes)
# CTC贪心解码
text, confidence = self._ctc_greedy_decode(probs)
return text, confidence
def _ctc_greedy_decode(self, probs: np.ndarray) -> Tuple[str, float]:
"""CTC贪心解码(逐时间步取最大概率)"""
indices = np.argmax(probs, axis=-1) # (T,)
confs = probs[np.arange(len(indices)), indices]
# 折叠重复并移除blank
chars = []
conf_list = []
prev = -1
for i, (idx, conf) in enumerate(zip(indices, confs)):
if idx != prev and idx != self.blank_idx:
if 0 < idx < len(self.charset):
chars.append(self.charset[idx])
conf_list.append(float(conf))
prev = idx
text = ''.join(chars)
avg_conf = float(np.mean(conf_list)) if conf_list else 0.0
return text, avg_conf
@staticmethod
def _softmax(x: np.ndarray) -> np.ndarray:
e = np.exp(x - x.max(axis=-1, keepdims=True))
return e / e.sum(axis=-1, keepdims=True)
```
### F.3 笔顺评分模型推理
```python
# engine/stroke_eval/stroke_evaluator.py
import numpy as np
import onnxruntime as ort
from typing import List, Dict
class StrokeOrderEvaluator:
"""
笔顺评分引擎
- 输入:归一化的笔迹序列(时序坐标点)
- 模型:双向LSTM,输出每笔笔顺正误概率
- 配合BKT校准:根据学生历史正确率校准当前评分
"""
MAX_STROKES = 30 # 最多30笔
MAX_POINTS = 50 # 每笔最多50点
FEATURE_DIM = 6 # 特征维度:(x, y, dx, dy, pressure, is_last_point)
def __init__(self, model_path: str):
self.session = ort.InferenceSession(
model_path,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
def evaluate(self, strokes: List[List[Dict]],
character: str,
bkt_mastery: float = 0.5) -> Dict:
"""
评估笔顺质量
Args:
strokes: [[{'x':...,'y':...,'p':...},...], ...] 各笔笔迹点
character: 被书写的字符
bkt_mastery: 学生当前掌握度(用于校准)
Returns:
{'score': float, 'errors': list, 'feedback': str}
"""
# 特征提取
features = self._extract_features(strokes) # (S, P, F)
# 填充到固定尺寸 (MAX_STROKES, MAX_POINTS, FEATURE_DIM)
padded = np.zeros((self.MAX_STROKES, self.MAX_POINTS, self.FEATURE_DIM),
dtype=np.float32)
s = min(len(features), self.MAX_STROKES)
padded[:s] = features[:s]
# 模型推理
input_tensor = padded[np.newaxis] # (1, S, P, F)
stroke_probs = self.session.run(None, {'input': input_tensor})[0][0] # (S, 2)
# 计算各笔正确率
n_strokes = min(len(strokes), self.MAX_STROKES)
correct_probs = stroke_probs[:n_strokes, 1] # 正确类概率
# 综合评分(加权均值,前几笔权重略高)
weights = np.array([1.2 if i < 3 else 1.0 for i in range(n_strokes)])
raw_score = float(np.average(correct_probs, weights=weights))
# BKT校准:高掌握度时适当提高评分(鼓励效应)
calibrated_score = raw_score * (0.85 + 0.15 * bkt_mastery)
final_score = min(100.0, calibrated_score * 100)
# 生成错误反馈
errors = []
for i, prob in enumerate(correct_probs):
if prob < 0.5:
errors.append({
'stroke_index': i + 1,
'confidence': float(1 - prob),
'description': f'第{i+1}笔笔顺可能有误'
})
feedback = self._generate_feedback(final_score, errors)
return {'score': round(final_score, 1), 'errors': errors, 'feedback': feedback}
def _extract_features(self, strokes):
result = []
for stroke in strokes[:self.MAX_STROKES]:
pts = stroke[:self.MAX_POINTS]
features = []
for j, pt in enumerate(pts):
dx = pt['x'] - pts[j-1]['x'] if j > 0 else 0.0
dy = pt['y'] - pts[j-1]['y'] if j > 0 else 0.0
features.append([pt['x'], pt['y'], dx, dy,
pt.get('pressure', 0.5),
1.0 if j == len(pts)-1 else 0.0])
# 填充到MAX_POINTS
while len(features) < self.MAX_POINTS:
features.append([0.0] * self.FEATURE_DIM)
result.append(features)
return np.array(result, dtype=np.float32)
@staticmethod
def _generate_feedback(score: float, errors: list) -> str:
if score >= 90:
return "笔顺非常标准,继续保持!"
elif score >= 75:
error_strokes = [str(e['stroke_index']) for e in errors]
return f"整体不错,第{'、'.join(error_strokes)}笔笔顺需要注意。"
elif score >= 60:
return f"笔顺有{len(errors)}处需要改进,请参考标准笔顺练习。"
else:
return "笔顺与标准差异较大,请仔细查看示范,重新练习。"
```
### F.4 令牌桶限流实现
```python
# middleware/rate_limiter.py
import time, threading
from functools import wraps
from flask import request, jsonify
class TokenBucketRateLimiter:
"""
令牌桶限流算法
每个AppKey独立维护一个令牌桶,以固定速率添加令牌
"""
def __init__(self, rate: float = 50.0, capacity: float = 100.0):
"""
Args:
rate: 令牌补充速率(个/秒)
capacity: 令牌桶容量(最大突发量)
"""
self.rate = rate
self.capacity = capacity
self._buckets: dict = {} # app_key -> (tokens, last_time)
self._lock = threading.Lock()
def consume(self, app_key: str, n: int = 1) -> bool:
"""消费n个令牌,返回True表示通过,False表示限流"""
with self._lock:
now = time.monotonic()
if app_key not in self._buckets:
self._buckets[app_key] = [self.capacity, now]
tokens, last_time = self._buckets[app_key]
# 补充令牌(根据时间差)
elapsed = now - last_time
tokens = min(self.capacity, tokens + elapsed * self.rate)
self._buckets[app_key][1] = now
if tokens >= n:
self._buckets[app_key][0] = tokens - n
return True
else:
self._buckets[app_key][0] = tokens
return False
# Flask装饰器
_limiter = TokenBucketRateLimiter(rate=50.0, capacity=100.0)
def rate_limit(f):
@wraps(f)
def decorated(*args, **kwargs):
app_key = request.headers.get('X-App-Key', 'anonymous')
if not _limiter.consume(app_key):
return jsonify({'code': 429, 'message': '请求过于频繁,请稍后重试'}), 429
return f(*args, **kwargs)
return decorated
```
---
## 附录F 补充算法与接口规格
### F.1 多模型集成推理框架
#### F.1.1 模型路由策略
```python
# model_router.py
from enum import Enum
from typing import Dict, Any
import numpy as np
class ModelType(Enum):
OCR_FAST = "ocr_fast" # 轻量模型,延迟<50ms
OCR_ACCURATE = "ocr_accurate" # 精准模型,延迟<200ms
MATH_FORMULA = "math_formula" # 数学公式专用
STROKE_EVAL = "stroke_eval" # 笔顺评分
class ModelRouter:
"""根据请求特征自动路由到最优模型"""
def __init__(self, models: Dict[ModelType, Any]):
self.models = models
self.stats = {m: {"count": 0, "avg_ms": 0} for m in ModelType}
def route(self, request: dict) -> ModelType:
content_type = request.get("content_type", "text")
quality = request.get("quality", "normal")
if content_type == "math":
return ModelType.MATH_FORMULA
if content_type == "stroke":
return ModelType.STROKE_EVAL
# 根据图像复杂度自动选择
if quality == "high" or self._is_complex_image(request.get("image")):
return ModelType.OCR_ACCURATE
else:
return ModelType.OCR_FAST
def _is_complex_image(self, image: np.ndarray) -> bool:
if image is None:
return False
# 基于图像方差判断复杂度
variance = np.var(image)
return variance > 1500
async def infer(self, request: dict) -> dict:
model_type = self.route(request)
model = self.models[model_type]
import time
start = time.perf_counter()
result = await model.predict(request["image"])
elapsed_ms = (time.perf_counter() - start) * 1000
# 更新统计
s = self.stats[model_type]
s["avg_ms"] = (s["avg_ms"] * s["count"] + elapsed_ms) / (s["count"] + 1)
s["count"] += 1
return {
"result": result,
"model": model_type.value,
"latency_ms": round(elapsed_ms, 2)
}
```
### F.2 OCR后处理管道
#### F.2.1 文本置信度过滤与纠错
```python
# ocr_postprocess.py
import re
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class OcrWord:
text: str
confidence: float
bbox: tuple # (x1, y1, x2, y2)
class OcrPostProcessor:
CONF_THRESHOLD_HIGH = 0.90
CONF_THRESHOLD_LOW = 0.60
# 常见混淆字符映射(OCR错误→正确)
CONFUSION_MAP = {
"0": "O", "1": "l", "rn": "m", "cl": "d",
"己": "已", "末": "未", "土": "士"
}
def __init__(self, language_model=None):
self.lm = language_model # 可选语言模型用于纠错
def process(self, words: List[OcrWord]) -> str:
# 1. 过滤低置信度词
filtered = [w for w in words if w.confidence >= self.CONF_THRESHOLD_LOW]
# 2. 对中等置信度词进行纠错尝试
corrected = []
for word in filtered:
if word.confidence < self.CONF_THRESHOLD_HIGH:
fixed = self._try_correct(word.text)
corrected.append(fixed)
else:
corrected.append(word.text)
# 3. 拼接文本并后处理
text = " ".join(corrected)
text = self._normalize_spaces(text)
text = self._fix_punctuation(text)
return text
def _try_correct(self, text: str) -> str:
result = text
for wrong, right in self.CONFUSION_MAP.items():
result = result.replace(wrong, right)
if self.lm:
lm_result = self.lm.correct(result)
if lm_result.score > 0.8:
return lm_result.text
return result
def _normalize_spaces(self, text: str) -> str:
# 移除中文字符间多余空格
text = re.sub(r'([\u4e00-\u9fff])\s+([\u4e00-\u9fff])', r'\1\2', text)
return text.strip()
def _fix_punctuation(self, text: str) -> str:
# 标准化标点符号
replacements = [(',', ''), ('.', '。'), ('?', ''), ('!', '')]
for eng, chn in replacements:
# 仅在中文上下文中替换
text = re.sub(f'([\u4e00-\u9fff]){re.escape(eng)}',
f'\\1{chn}', text)
return text
```
### F.3 异步任务队列
```python
# async_task_queue.py
import asyncio
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Any
@dataclass
class Task:
id: str
priority: int
func: Callable
args: tuple
future: asyncio.Future = field(default_factory=asyncio.Future)
class PriorityTaskQueue:
"""优先级异步任务队列,高优先级任务优先执行"""
def __init__(self, workers: int = 4):
self.queue = asyncio.PriorityQueue()
self.workers = workers
self._counter = 0 # 用于相同优先级时保持FIFO顺序
async def submit(self, func: Callable, *args, priority: int = 5) -> Any:
future = asyncio.get_event_loop().create_future()
task = Task(
id=f"task_{self._counter}",
priority=priority,
func=func,
args=args,
future=future
)
self._counter += 1
# PriorityQueue按(priority, counter)排序,priority越小优先级越高
await self.queue.put((priority, self._counter, task))
return await future
async def _worker(self):
while True:
_, _, task = await self.queue.get()
try:
if asyncio.iscoroutinefunction(task.func):
result = await task.func(*task.args)
else:
result = await asyncio.get_event_loop().run_in_executor(
None, task.func, *task.args)
task.future.set_result(result)
except Exception as e:
task.future.set_exception(e)
finally:
self.queue.task_done()
async def start(self):
for _ in range(self.workers):
asyncio.create_task(self._worker())
```
---
## 附录G 补充技术规格
### G.1 模型热加载无缝切换
```python
# model_hot_swap.py
import threading
import time
from typing import Optional
class HotSwapModelManager:
"""模型热加载管理器,支持零停机切换模型版本"""
def __init__(self):
self._current_model = None
self._pending_model = None
self._lock = threading.RWLock()
self._request_count = 0
def load_new_version(self, model_path: str, model_type: str):
"""在后台加载新模型版本"""
def _load():
import torch
new_model = torch.jit.load(model_path)
new_model.eval()
# 等待当前请求处理完成
while self._request_count > 0:
time.sleep(0.01)
with self._lock.write():
old_model = self._current_model
self._current_model = new_model
self._pending_model = None
# 释放旧模型内存
if old_model is not None:
del old_model
torch.cuda.empty_cache()
print(f"Model {model_type} hot-swapped to {model_path}")
import threading
t = threading.Thread(target=_load, daemon=True)
t.start()
def infer(self, inputs):
"""推理时持有读锁,防止切换过程中的并发问题"""
with self._lock.read():
self._request_count += 1
try:
return self._current_model(inputs)
finally:
self._request_count -= 1
```
### G.2 GPU显存监控
```python
# gpu_monitor.py
import subprocess
import json
def get_gpu_stats() -> dict:
"""获取GPU使用统计(通过nvidia-smi"""
try:
result = subprocess.run([
"nvidia-smi", "--query-gpu=name,memory.used,memory.total,utilization.gpu,temperature.gpu",
"--format=csv,noheader,nounits"
], capture_output=True, text=True, timeout=5)
if result.returncode != 0:
return {}
parts = [p.strip() for p in result.stdout.strip().split(",")]
return {
"name": parts[0],
"memory_used_mb": int(parts[1]),
"memory_total_mb": int(parts[2]),
"memory_util_pct": round(int(parts[1]) / int(parts[2]) * 100, 1),
"gpu_util_pct": int(parts[3]),
"temperature_c": int(parts[4])
}
except Exception as e:
return {"error": str(e)}
def check_oom_risk(threshold_pct: float = 90.0) -> bool:
"""检查是否有显存溢出风险"""
stats = get_gpu_stats()
if not stats or "memory_util_pct" not in stats:
return False
return stats["memory_util_pct"] >= threshold_pct
```
### G.3 识别结果缓存
```python
# result_cache.py
import hashlib
import json
import redis
from functools import wraps
class OcrResultCache:
"""基于Redis的识别结果缓存,提高重复图片的响应速度"""
def __init__(self, redis_client: redis.Redis, ttl_seconds: int = 3600):
self.redis = redis_client
self.ttl = ttl_seconds
def get_cache_key(self, image_bytes: bytes, options: dict) -> str:
"""基于图片内容和识别选项生成缓存键"""
content_hash = hashlib.sha256(image_bytes).hexdigest()
options_str = json.dumps(options, sort_keys=True)
options_hash = hashlib.md5(options_str.encode()).hexdigest()[:8]
return f"ocr:result:{content_hash}:{options_hash}"
def get(self, key: str) -> Optional[dict]:
value = self.redis.get(key)
if value:
return json.loads(value)
return None
def set(self, key: str, result: dict):
self.redis.setex(key, self.ttl, json.dumps(result, ensure_ascii=False))
def cached_ocr(self, ocr_func):
"""装饰器:为OCR函数添加缓存"""
@wraps(ocr_func)
async def wrapper(image_bytes: bytes, **options):
key = self.get_cache_key(image_bytes, options)
cached = self.get(key)
if cached:
cached["from_cache"] = True
return cached
result = await ocr_func(image_bytes, **options)
self.set(key, result)
return result
return wrapper
```
---
## 附录H 补充技术规格
### H.1 数学公式识别增强
```python
# math_formula_postprocess.py
import re
class MathFormulaPostProcessor:
"""数学公式识别后处理:LaTeX语法规范化"""
# 常见OCR错误修正
CORRECTIONS = {
r'\\frac\s*{': r'\\frac{',
r'\\sqrt\s*{': r'\\sqrt{',
r'\\sum\s*_': r'\\sum_',
r'x\^2': r'x^{2}', # 补充缺失的花括号
r'x\^(\d)': r'x^{\1}',
r'([a-z])\^([a-z])': r'\1^{\2}',
}
def process(self, latex: str) -> str:
result = latex.strip()
# 应用修正规则
for pattern, replacement in self.CORRECTIONS.items():
result = re.sub(pattern, replacement, result)
# 确保公式包含在$...$中
if not result.startswith('$'):
result = f'${result}$'
# 验证括号平衡
if not self._check_braces(result):
result = self._fix_braces(result)
return result
def _check_braces(self, s: str) -> bool:
count = 0
for c in s:
if c == '{': count += 1
elif c == '}': count -= 1
if count < 0: return False
return count == 0
def _fix_braces(self, s: str) -> str:
"""自动补全缺失的右括号"""
count = sum(1 if c == '{' else -1 if c == '}' else 0 for c in s)
if count > 0:
s = s.rstrip('$') + '}' * count
if not s.endswith('$'): s += '$'
return s
```
---
### H.2 版本历史
| 版本号 | 发布日期 | 变更说明 | 负责人 |
|--------|----------|---------|--------|
| V1.0.0 | 2024-01-15 | 初始版本,实现汉字/数字OCR基础识别 | AI组 |
| V1.1.0 | 2024-03-20 | 新增数学公式识别(Im2Latex模型) | 算法组 |
| V1.2.0 | 2024-05-15 | 引入TensorRT INT8量化,GPU推理延迟降低40% | 工程组 |
| V1.3.0 | 2024-07-10 | 新增笔顺评分功能,BKT算法校准 | AI组 |
| V1.4.0 | 2024-09-01 | 添加识别结果Redis缓存,重复图片响应<5ms | 工程组 |
| V1.5.0 | 2024-11-15 | 支持模型热加载,零停机版本升级 | 工程组 |
---
*本文档版权归深圳自然写科技有限公司所有,仅用于软件著作权登记鉴别。*