359 lines
13 KiB
Python
359 lines
13 KiB
Python
# 自然写手写识别与AI分析引擎软件 V1.0
|
||
# gRPC批量识别服务模块 - 高性能流式批量笔迹识别
|
||
|
||
"""
|
||
gRPC推理服务
|
||
提供高性能流式批量笔迹识别接口
|
||
采用gRPC双向流模式,适用于教室场景下多支笔并发识别需求
|
||
支持服务端流式响应,实现低延迟识别结果推送
|
||
"""
|
||
|
||
import time
|
||
import json
|
||
import logging
|
||
import uuid
|
||
import asyncio
|
||
from typing import List, Dict, Optional, AsyncIterator
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from concurrent import futures
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== gRPC消息定义(等效Proto) ====================
|
||
|
||
class RecognitionType(str, Enum):
    """Kinds of recognition a request can ask for.

    Each member is also a ``str``, so it compares equal to its wire value
    (for example ``RecognitionType.OCR == "ocr"``).
    """

    OCR = "ocr"  # text recognition
    MATH = "math"  # math recognition
    STROKE_ORDER = "stroke_order"  # stroke-order scoring
    ESSAY = "essay"  # essay correction
|
||
|
||
|
||
@dataclass
class StrokePoint:
    """One sampled pen coordinate (mirrors the protobuf ``StrokePoint`` message)."""

    # Position of the sample; both coordinates are required.
    x: float
    y: float
    # Optional attributes with the defaults used when a sample omits them.
    pressure: float = 0.5
    timestamp: int = 0
|
||
|
||
|
||
@dataclass
class StrokeData:
    """Handwriting payload (mirrors the protobuf ``StrokeData`` message).

    All identifier fields default to the empty string.
    """

    stroke_id: str = ""
    pen_id: str = ""
    page_id: str = ""
    student_id: str = ""
    # A list of strokes, each stroke being a list of sampled points.
    # ``default_factory`` gives every instance its own fresh list.
    strokes: List[List[StrokePoint]] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class RecognitionRequest:
    """Recognition request (mirrors the protobuf ``RecognitionRequest`` message)."""

    # Left empty by default.
    request_id: str = ""
    recognition_type: RecognitionType = RecognitionType.OCR
    # Handwriting to recognize; ``None`` means no data was attached.
    stroke_data: Optional[StrokeData] = None
    priority: int = 2  # 0 = highest priority, 4 = lowest
    callback_topic: str = ""  # MQTT topic for the result callback
    timeout_ms: int = 5000  # timeout in milliseconds
|
||
|
||
|
||
@dataclass
class RecognitionResult:
    """Recognition outcome (mirrors the protobuf ``RecognitionResult`` message)."""

    request_id: str = ""
    recognition_type: str = ""
    status: str = "success"  # one of: success / error / timeout
    result_text: str = ""
    confidence: float = 0.0
    # Free-form extra information (error descriptions are stored under the
    # "error" key by the code in this module); each instance gets its own dict.
    details: Dict = field(default_factory=dict)
    processing_time_ms: float = 0.0
    model_version: str = ""
|
||
|
||
|
||
# ==================== 批量识别处理器 ====================
|
||
|
||
class BatchRecognitionProcessor:
    """Group recognition requests by type and process them in batches.

    Requests are queued per recognition type. A queue is drained as soon as
    it holds ``max_batch_size`` requests, or when :meth:`flush_all` is called.
    Finished results are stored in memory keyed by request id;
    :meth:`get_result` looks them up but does not remove them.

    NOTE(review): ``max_wait_ms`` is stored but nothing in this class reads
    it, so no time-based flush happens here. The class takes no locks either;
    confirm how callers serialize access before sharing one instance between
    threads.
    """

    def __init__(self, max_batch_size: int = 32, max_wait_ms: int = 50):
        """Create the per-type queues and the result store.

        Args:
            max_batch_size: Maximum number of requests taken from a queue per
                batch; also the queue length that triggers processing.
            max_wait_ms: Stored for later use; not read by this class.

        Raises:
            ValueError: If ``max_batch_size`` is smaller than 1. A
                non-positive size would make every batch slice leave the
                queue non-empty, so ``flush_all`` would never terminate.
        """
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

        self._max_batch_size = max_batch_size
        self._max_wait_ms = max_wait_ms
        # One pending queue per recognition type, keyed by the enum's value.
        self._pending_requests: Dict[str, List[RecognitionRequest]] = {
            rt.value: [] for rt in RecognitionType
        }
        self._results: Dict[str, RecognitionResult] = {}
        logger.info(f"批量识别处理器初始化: batch_size={max_batch_size}, wait_ms={max_wait_ms}")

    def add_request(self, request: RecognitionRequest) -> str:
        """Queue a request and return its id.

        A request without an id is assigned a fresh UUID4 string (the request
        object is mutated). When the queue for the request's type reaches
        ``max_batch_size`` one batch is processed immediately.
        """
        if not request.request_id:
            request.request_id = str(uuid.uuid4())

        type_key = request.recognition_type.value
        queue = self._pending_requests.setdefault(type_key, [])
        queue.append(request)

        logger.debug(f"请求入队: id={request.request_id}, type={type_key}")

        # Trigger a batch as soon as the queue is full.
        if len(queue) >= self._max_batch_size:
            self._process_batch(type_key)

        return request.request_id

    def _process_batch(self, recognition_type: str):
        """Process up to ``max_batch_size`` queued requests of one type.

        The batch is removed from the queue before processing. A failure on
        one request is logged and recorded as an ``"error"`` result; the
        remaining requests of the batch are still processed.
        """
        queue = self._pending_requests.get(recognition_type, [])
        if not queue:
            return

        batch = queue[:self._max_batch_size]
        self._pending_requests[recognition_type] = queue[self._max_batch_size:]

        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        batch_start = time.perf_counter()
        logger.info(f"批处理开始: type={recognition_type}, batch_size={len(batch)}")

        for req in batch:
            try:
                result = self._process_single(req)
            except Exception as e:
                # Boundary handler: keep the batch going, but leave a
                # traceback in the log instead of hiding the failure.
                logger.exception(
                    f"请求处理失败: id={req.request_id}, type={recognition_type}"
                )
                result = RecognitionResult(
                    request_id=req.request_id,
                    recognition_type=recognition_type,
                    status="error",
                    details={"error": str(e)},
                )
            self._results[req.request_id] = result

        elapsed = (time.perf_counter() - batch_start) * 1000
        logger.info(f"批处理完成: type={recognition_type}, count={len(batch)}, time={elapsed:.1f}ms")

    def _process_single(self, request: RecognitionRequest) -> RecognitionResult:
        """Run the inference matching the request type and build the result.

        The confidence values are fixed per type in this implementation.

        NOTE(review): types without a dedicated branch (currently ``ESSAY``)
        yield a ``"success"`` result with empty text and confidence 0.0 -
        confirm whether that should be reported as an error instead.
        """
        start_time = time.perf_counter()

        # Dispatch to the inference routine for the requested type.
        if request.recognition_type == RecognitionType.OCR:
            result_text = self._run_ocr_inference(request.stroke_data)
            confidence = 0.92
        elif request.recognition_type == RecognitionType.MATH:
            result_text = self._run_math_inference(request.stroke_data)
            confidence = 0.88
        elif request.recognition_type == RecognitionType.STROKE_ORDER:
            result_text = self._run_stroke_order_inference(request.stroke_data)
            confidence = 0.95
        else:
            result_text = ""
            confidence = 0.0

        elapsed = (time.perf_counter() - start_time) * 1000

        return RecognitionResult(
            request_id=request.request_id,
            recognition_type=request.recognition_type.value,
            status="success",
            result_text=result_text,
            confidence=confidence,
            processing_time_ms=round(elapsed, 2),
            model_version="v1.0.0",
        )

    def _run_ocr_inference(self, stroke_data: Optional[StrokeData]) -> str:
        """Return the OCR text for ``stroke_data`` ("" when there are no strokes).

        Placeholder: returns a fixed marker string. In a real deployment this
        is where the PaddleOCR pipeline would be called, e.g.
        ``ocr_engine.recognize(preprocess(stroke_data))``.
        """
        if not stroke_data or not stroke_data.strokes:
            return ""
        return "[OCR识别结果]"

    def _run_math_inference(self, stroke_data: Optional[StrokeData]) -> str:
        """Return the math-expression text ("" when there are no strokes).

        Placeholder: returns a fixed marker string.
        """
        if not stroke_data or not stroke_data.strokes:
            return ""
        return "[数学识别结果]"

    def _run_stroke_order_inference(self, stroke_data: Optional[StrokeData]) -> str:
        """Return the stroke-order analysis text ("" when there are no strokes).

        Placeholder: returns a fixed marker string.
        """
        if not stroke_data or not stroke_data.strokes:
            return ""
        return "[笔顺分析结果]"

    def get_result(self, request_id: str) -> Optional[RecognitionResult]:
        """Return the stored result for ``request_id``, or ``None`` if absent."""
        return self._results.get(request_id)

    def flush_all(self):
        """Process every pending request of every type, batch by batch."""
        # Iterate over a snapshot of the keys; _process_batch rebinds the
        # dict values while the loop runs.
        for rt in list(self._pending_requests):
            while self._pending_requests[rt]:
                self._process_batch(rt)
|
||
|
||
|
||
# ==================== gRPC服务实现 ====================
|
||
|
||
class RecognitionServiceImpl:
    """Implementation of the gRPC ``RecognitionService``.

    Corresponds to the protobuf service definition::

        service RecognitionService {
            rpc Recognize(RecognitionRequest) returns (RecognitionResult);
            rpc BatchRecognize(stream RecognitionRequest) returns (stream RecognitionResult);
            rpc GetModelStatus(Empty) returns (ModelStatusResponse);
        }

    NOTE(review): the counters kept here are updated without locking;
    confirm how the hosting server serializes calls.
    """

    def __init__(self):
        """Create the batch processor and zero the statistics counters."""
        self._processor = BatchRecognitionProcessor()
        self._request_count = 0
        self._total_latency_ms = 0.0
        # Number of requests whose latency was added to _total_latency_ms.
        # Kept separately from _request_count, which also counts requests
        # that were rejected or arrived through BatchRecognize.
        self._latency_samples = 0
        logger.info("gRPC RecognitionService 初始化完成")

    @staticmethod
    def _has_strokes(request: RecognitionRequest) -> bool:
        """Tell whether the request carries at least one stroke."""
        return bool(request.stroke_data and request.stroke_data.strokes)

    def _release_result(self, request_id: str) -> None:
        """Drop a delivered result from the processor's result store.

        The processor never removes results on its own, so without this the
        store would grow with every request served.
        """
        self._processor._results.pop(request_id, None)

    def Recognize(self, request: RecognitionRequest) -> RecognitionResult:
        """Single-shot recognition RPC.

        Args:
            request: The recognition request to process.

        Returns:
            The recognition result. An ``"error"`` result is returned when
            the request carries no stroke data, or when no result is
            available after processing.
        """
        self._request_count += 1
        start_time = time.perf_counter()

        # Validate the request before queuing it.
        if not self._has_strokes(request):
            return RecognitionResult(
                request_id=request.request_id,
                status="error",
                details={"error": "笔迹数据为空"},
            )

        # Queue the request and flush right away: a single call does not
        # wait for a batch to fill up.
        request_id = self._processor.add_request(request)
        self._processor.flush_all()

        result = self._processor.get_result(request_id)
        self._release_result(request_id)

        elapsed = (time.perf_counter() - start_time) * 1000
        self._total_latency_ms += elapsed
        self._latency_samples += 1

        if result:
            # Audit log.
            logger.info(
                f"gRPC Recognize: id={request_id}, type={request.recognition_type.value}, "
                f"time={elapsed:.1f}ms, pen={request.stroke_data.pen_id}"
            )
            return result

        return RecognitionResult(
            request_id=request_id,
            status="error",
            details={"error": "处理超时"},
        )

    def BatchRecognize(self, request_iterator) -> List[RecognitionResult]:
        """Batch recognition RPC.

        Consumes every request from ``request_iterator``, processes them in
        batches and returns one result per request, in arrival order.
        Requests without stroke data are answered with the same ``"error"``
        result that ``Recognize`` produces, and a request for which no result
        is available yields an explicit ``"error"`` result rather than being
        left out.

        NOTE(review): the results are returned as a list once the whole input
        has been consumed; nothing is streamed back incrementally here.
        """
        # One (request_id, rejection) pair per request; ``rejection`` is the
        # ready-made error result for requests that failed validation.
        slots = []
        for request in request_iterator:
            self._request_count += 1
            if self._has_strokes(request):
                slots.append((self._processor.add_request(request), None))
            else:
                rejection = RecognitionResult(
                    request_id=request.request_id,
                    status="error",
                    details={"error": "笔迹数据为空"},
                )
                slots.append((request.request_id, rejection))

        # Process whatever is still waiting in the queues.
        self._processor.flush_all()

        results = []
        for rid, rejection in slots:
            if rejection is not None:
                results.append(rejection)
                continue
            result = self._processor.get_result(rid)
            if result is None:
                result = RecognitionResult(
                    request_id=rid,
                    status="error",
                    details={"error": "处理超时"},
                )
            results.append(result)

        # Release only after everything was collected, so that a request id
        # occurring twice in the input still resolves both times.
        for rid, rejection in slots:
            if rejection is None:
                self._release_result(rid)

        logger.info(f"BatchRecognize完成: 请求数={len(slots)}, 结果数={len(results)}")
        return results

    def GetModelStatus(self) -> Dict:
        """Model status RPC.

        Returns:
            A dict with ``total_requests`` (every request seen by either
            RPC), ``avg_latency_ms`` (mean latency of the requests measured
            in ``Recognize``, rounded to two decimals; 0.0 before the first
            measurement) and ``models`` (a fixed list of model descriptors).
        """
        samples = max(self._latency_samples, 1)
        return {
            "total_requests": self._request_count,
            "avg_latency_ms": round(self._total_latency_ms / samples, 2),
            "models": [
                {"name": "ocr_model", "version": "v1.0.0", "status": "active"},
                {"name": "math_model", "version": "v1.0.0", "status": "active"},
                {"name": "stroke_order_model", "version": "v1.0.0", "status": "active"},
            ],
        }
|
||
|
||
|
||
# ==================== gRPC服务器启动 ====================
|
||
|
||
class GrpcServer:
    """Lifecycle wrapper around the gRPC inference endpoint.

    Holds the listen address, worker count and TLS flag together with the
    service implementation. The actual gRPC start-up calls are only sketched
    in comments inside :meth:`start`; as written, the class logs its
    lifecycle events and exposes the service statistics.
    """

    def __init__(self, host: str = "0.0.0.0", port: int = 50051,
                 max_workers: int = 10, enable_tls: bool = True):
        """Store the configuration and create the service implementation.

        Args:
            host: Address to listen on.
            port: Port to listen on.
            max_workers: Worker count intended for the server's thread pool.
            enable_tls: Whether the server is meant to use mutual TLS.
        """
        self._host, self._port = host, port
        self._max_workers = max_workers
        self._enable_tls = enable_tls
        self._service = RecognitionServiceImpl()
        # Stays None until a real gRPC server object is created.
        self._server = None
        logger.info(f"gRPC服务器配置: {host}:{port}, workers={max_workers}, tls={enable_tls}")

    def start(self):
        """Start the gRPC server.

        With TLS enabled the intended set-up loads the server key/certificate
        and the CA certificate for mutual authentication (see the reference
        sketch below). This method currently only logs.
        """
        logger.info(f"启动gRPC服务器: {self._host}:{self._port}")

        # Reference sketch of the start-up code for a real deployment:
        #
        # self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=self._max_workers))
        # inference_pb2_grpc.add_RecognitionServiceServicer_to_server(self._service, self._server)
        #
        # if self._enable_tls:
        #     # Mutual TLS: the client must present a certificate as well.
        #     with open('/etc/ssl/server.key', 'rb') as f:
        #         server_key = f.read()
        #     with open('/etc/ssl/server.crt', 'rb') as f:
        #         server_cert = f.read()
        #     with open('/etc/ssl/ca.crt', 'rb') as f:
        #         ca_cert = f.read()
        #     credentials = grpc.ssl_server_credentials(
        #         [(server_key, server_cert)],
        #         root_certificates=ca_cert,
        #         require_client_auth=True  # demand a client certificate
        #     )
        #     self._server.add_secure_port(f'{self._host}:{self._port}', credentials)
        # else:
        #     self._server.add_insecure_port(f'{self._host}:{self._port}')
        #
        # self._server.start()

        logger.info(f"gRPC服务器已启动: {self._host}:{self._port}")

    def stop(self, grace_seconds: int = 5):
        """Shut the server down gracefully.

        Nothing happens when no server object exists. ``grace_seconds`` is
        the grace period meant for the real ``stop`` call, which is still
        commented out.
        """
        if not self._server:
            return
        # self._server.stop(grace_seconds)
        logger.info("gRPC服务器已关闭")

    def get_stats(self) -> Dict:
        """Return the statistics reported by the service's ``GetModelStatus``."""
        status = self._service.GetModelStatus()
        return status
|