# 自然写手写识别与AI分析引擎软件 V1.0 # gRPC批量识别服务模块 - 高性能流式批量笔迹识别 """ gRPC推理服务 提供高性能流式批量笔迹识别接口 采用gRPC双向流模式,适用于教室场景下多支笔并发识别需求 支持服务端流式响应,实现低延迟识别结果推送 """ import time import json import logging import uuid import asyncio from typing import List, Dict, Optional, AsyncIterator from dataclasses import dataclass, field from enum import Enum from concurrent import futures logger = logging.getLogger(__name__) # ==================== gRPC消息定义(等效Proto) ==================== class RecognitionType(str, Enum): """识别类型枚举""" OCR = "ocr" # 文字识别 MATH = "math" # 数学识别 STROKE_ORDER = "stroke_order" # 笔顺评分 ESSAY = "essay" # 作文批改 @dataclass class StrokePoint: """笔迹坐标点(对应protobuf StrokePoint message)""" x: float y: float pressure: float = 0.5 timestamp: int = 0 @dataclass class StrokeData: """笔迹数据(对应protobuf StrokeData message)""" stroke_id: str = "" pen_id: str = "" page_id: str = "" student_id: str = "" strokes: List[List[StrokePoint]] = field(default_factory=list) @dataclass class RecognitionRequest: """识别请求(对应protobuf RecognitionRequest message)""" request_id: str = "" recognition_type: RecognitionType = RecognitionType.OCR stroke_data: Optional[StrokeData] = None priority: int = 2 # 0=最高优先级,4=最低 callback_topic: str = "" # 结果回调MQTT Topic timeout_ms: int = 5000 # 超时时间 @dataclass class RecognitionResult: """识别结果(对应protobuf RecognitionResult message)""" request_id: str = "" recognition_type: str = "" status: str = "success" # success / error / timeout result_text: str = "" confidence: float = 0.0 details: Dict = field(default_factory=dict) processing_time_ms: float = 0.0 model_version: str = "" # ==================== 批量识别处理器 ==================== class BatchRecognitionProcessor: """ 批量识别处理器 将多个识别请求按类型分组,批量送入GPU推理 通过批处理显著提升GPU利用率和吞吐量 """ def __init__(self, max_batch_size: int = 32, max_wait_ms: int = 50): self._max_batch_size = max_batch_size self._max_wait_ms = max_wait_ms self._pending_requests: Dict[str, List[RecognitionRequest]] = { rt.value: [] for rt in RecognitionType } self._results: Dict[str, RecognitionResult] = {} logger.info(f"批量识别处理器初始化: batch_size={max_batch_size}, wait_ms={max_wait_ms}") def add_request(self, request: RecognitionRequest) -> str: """添加识别请求到批处理队列""" if not request.request_id: request.request_id = str(uuid.uuid4()) queue = self._pending_requests.get(request.recognition_type.value, []) queue.append(request) self._pending_requests[request.recognition_type.value] = queue logger.debug(f"请求入队: id={request.request_id}, type={request.recognition_type.value}") # 当队列达到批大小时触发批处理 if len(queue) >= self._max_batch_size: self._process_batch(request.recognition_type.value) return request.request_id def _process_batch(self, recognition_type: str): """ 执行批处理推理 将队列中的请求按批大小取出,统一送入模型推理 """ queue = self._pending_requests.get(recognition_type, []) if not queue: return batch = queue[:self._max_batch_size] self._pending_requests[recognition_type] = queue[self._max_batch_size:] batch_start = time.time() logger.info(f"批处理开始: type={recognition_type}, batch_size={len(batch)}") for req in batch: try: result = self._process_single(req) self._results[req.request_id] = result except Exception as e: self._results[req.request_id] = RecognitionResult( request_id=req.request_id, recognition_type=recognition_type, status="error", details={"error": str(e)} ) elapsed = (time.time() - batch_start) * 1000 logger.info(f"批处理完成: type={recognition_type}, count={len(batch)}, time={elapsed:.1f}ms") def _process_single(self, request: RecognitionRequest) -> RecognitionResult: """处理单个识别请求""" start_time = time.time() # 根据识别类型分发到对应的推理引擎 if request.recognition_type == RecognitionType.OCR: result_text = self._run_ocr_inference(request.stroke_data) confidence = 0.92 elif request.recognition_type == RecognitionType.MATH: result_text = self._run_math_inference(request.stroke_data) confidence = 0.88 elif request.recognition_type == RecognitionType.STROKE_ORDER: result_text = self._run_stroke_order_inference(request.stroke_data) confidence = 0.95 else: result_text = "" confidence = 0.0 elapsed = (time.time() - start_time) * 1000 return RecognitionResult( request_id=request.request_id, recognition_type=request.recognition_type.value, status="success", result_text=result_text, confidence=confidence, processing_time_ms=round(elapsed, 2), model_version="v1.0.0" ) def _run_ocr_inference(self, stroke_data: Optional[StrokeData]) -> str: """执行OCR推理(调用PaddleOCR引擎)""" if not stroke_data or not stroke_data.strokes: return "" # 实际环境中调用PaddleOCR推理管道 # preprocessed = preprocess(stroke_data) # result = ocr_engine.recognize(preprocessed) return "[OCR识别结果]" def _run_math_inference(self, stroke_data: Optional[StrokeData]) -> str: """执行数学列式识别推理""" if not stroke_data or not stroke_data.strokes: return "" return "[数学识别结果]" def _run_stroke_order_inference(self, stroke_data: Optional[StrokeData]) -> str: """执行笔顺分析推理""" if not stroke_data or not stroke_data.strokes: return "" return "[笔顺分析结果]" def get_result(self, request_id: str) -> Optional[RecognitionResult]: """查询识别结果""" return self._results.get(request_id) def flush_all(self): """强制处理所有队列中的待处理请求""" for rt in self._pending_requests: while self._pending_requests[rt]: self._process_batch(rt) # ==================== gRPC服务实现 ==================== class RecognitionServiceImpl: """ gRPC RecognitionService 服务实现 对应 protobuf 服务定义: service RecognitionService { rpc Recognize(RecognitionRequest) returns (RecognitionResult); rpc BatchRecognize(stream RecognitionRequest) returns (stream RecognitionResult); rpc GetModelStatus(Empty) returns (ModelStatusResponse); } """ def __init__(self): self._processor = BatchRecognitionProcessor() self._request_count = 0 self._total_latency_ms = 0.0 logger.info("gRPC RecognitionService 初始化完成") def Recognize(self, request: RecognitionRequest) -> RecognitionResult: """ 单次识别RPC 接收单个识别请求,返回识别结果 """ self._request_count += 1 start_time = time.time() # 验证请求参数 if not request.stroke_data or not request.stroke_data.strokes: return RecognitionResult( request_id=request.request_id, status="error", details={"error": "笔迹数据为空"} ) # 提交到批处理器并等待结果 request_id = self._processor.add_request(request) self._processor.flush_all() # 立即处理(单次调用不等待攒批) result = self._processor.get_result(request_id) elapsed = (time.time() - start_time) * 1000 self._total_latency_ms += elapsed if result: # 审计日志 logger.info( f"gRPC Recognize: id={request_id}, type={request.recognition_type.value}, " f"time={elapsed:.1f}ms, pen={request.stroke_data.pen_id}" ) return result return RecognitionResult( request_id=request_id, status="error", details={"error": "处理超时"} ) def BatchRecognize(self, request_iterator) -> List[RecognitionResult]: """ 流式批量识别RPC(双向流) 接收笔迹数据流,批量处理后流式返回识别结果 适用于教室场景下40+支笔并发传输的高吞吐识别 """ results = [] request_ids = [] # 接收所有请求 for request in request_iterator: rid = self._processor.add_request(request) request_ids.append(rid) self._request_count += 1 # 批量处理 self._processor.flush_all() # 收集结果 for rid in request_ids: result = self._processor.get_result(rid) if result: results.append(result) logger.info(f"BatchRecognize完成: 请求数={len(request_ids)}, 结果数={len(results)}") return results def GetModelStatus(self) -> Dict: """查询模型状态RPC""" return { "total_requests": self._request_count, "avg_latency_ms": round(self._total_latency_ms / max(self._request_count, 1), 2), "models": [ {"name": "ocr_model", "version": "v1.0.0", "status": "active"}, {"name": "math_model", "version": "v1.0.0", "status": "active"}, {"name": "stroke_order_model", "version": "v1.0.0", "status": "active"}, ] } # ==================== gRPC服务器启动 ==================== class GrpcServer: """ gRPC服务器管理 启动和管理gRPC推理服务端口 支持TLS双向认证(mTLS安全设计) """ def __init__(self, host: str = "0.0.0.0", port: int = 50051, max_workers: int = 10, enable_tls: bool = True): self._host = host self._port = port self._max_workers = max_workers self._enable_tls = enable_tls self._service = RecognitionServiceImpl() self._server = None logger.info(f"gRPC服务器配置: {host}:{port}, workers={max_workers}, tls={enable_tls}") def start(self): """ 启动gRPC服务器 如启用TLS,加载服务端证书和CA证书用于mTLS双向认证 """ logger.info(f"启动gRPC服务器: {self._host}:{self._port}") # 实际环境中的gRPC服务器启动代码 # self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=self._max_workers)) # inference_pb2_grpc.add_RecognitionServiceServicer_to_server(self._service, self._server) # # if self._enable_tls: # # mTLS双向认证配置(安全设计) # with open('/etc/ssl/server.key', 'rb') as f: # server_key = f.read() # with open('/etc/ssl/server.crt', 'rb') as f: # server_cert = f.read() # with open('/etc/ssl/ca.crt', 'rb') as f: # ca_cert = f.read() # credentials = grpc.ssl_server_credentials( # [(server_key, server_cert)], # root_certificates=ca_cert, # require_client_auth=True # 要求客户端证书 # ) # self._server.add_secure_port(f'{self._host}:{self._port}', credentials) # else: # self._server.add_insecure_port(f'{self._host}:{self._port}') # # self._server.start() logger.info(f"gRPC服务器已启动: {self._host}:{self._port}") def stop(self, grace_seconds: int = 5): """优雅关闭gRPC服务器""" if self._server: # self._server.stop(grace_seconds) logger.info("gRPC服务器已关闭") def get_stats(self) -> Dict: """获取服务器统计信息""" return self._service.GetModelStatus()