219 lines
6.0 KiB
Python
219 lines
6.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
自然写手写识别与AI分析引擎软件 V1.0
|
|
|
|
版权所有 (C) 2026
|
|
软件全称:自然写手写识别与AI分析引擎软件
|
|
版本号:V1.0
|
|
|
|
主启动文件 - FastAPI 服务入口
|
|
负责服务初始化、路由注册、中间件配置
|
|
"""
|
|
|
|
from fastapi import FastAPI, Request, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from contextlib import asynccontextmanager
|
|
import uvicorn
|
|
import logging
|
|
import time
|
|
from typing import Dict, Any
|
|
|
|
# 导入各业务模块路由
|
|
from api.ocr_api import router as ocr_router
|
|
from api.math_api import router as math_router
|
|
from api.stroke_order_api import router as stroke_order_router
|
|
from api.essay_api import router as essay_router
|
|
from service.model_manager import ModelManager
|
|
from config.settings import Settings
|
|
|
|
# 日志配置
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
|
)
|
|
logger = logging.getLogger("writech-ai-engine")
|
|
|
|
# 全局配置
|
|
settings = Settings()
|
|
|
|
# 全局模型管理器实例
|
|
model_manager = ModelManager(settings)
|
|
|
|
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
"""
|
|
应用生命周期管理
|
|
启动时加载所有AI模型到GPU/CPU内存
|
|
关闭时释放模型资源
|
|
"""
|
|
logger.info("自然写AI引擎启动中,加载模型...")
|
|
# 启动时加载所有模型
|
|
await model_manager.load_all_models()
|
|
logger.info("所有模型加载完成,服务就绪")
|
|
yield
|
|
# 关闭时释放资源
|
|
logger.info("服务关闭中,释放模型资源...")
|
|
model_manager.release_all_models()
|
|
logger.info("模型资源已释放")
|
|
|
|
|
|
# 创建 FastAPI 应用实例
|
|
app = FastAPI(
|
|
title="自然写手写识别与AI分析引擎",
|
|
description="对智能点阵笔采集的笔迹数据进行OCR识别、数学列式识别、笔顺分析及AI智能批改",
|
|
version="1.0.0",
|
|
lifespan=lifespan
|
|
)
|
|
|
|
# 跨域中间件配置
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"],
|
|
allow_credentials=True,
|
|
allow_methods=["*"],
|
|
allow_headers=["*"],
|
|
)
|
|
|
|
|
|
@app.middleware("http")
|
|
async def request_logging_middleware(request: Request, call_next):
|
|
"""
|
|
请求日志与性能监控中间件
|
|
记录每个请求的处理时间、状态码、推理耗时
|
|
"""
|
|
start_time = time.time()
|
|
request_id = request.headers.get("X-Request-ID", str(time.time()))
|
|
|
|
# 输入数据大小校验(防恶意攻击,最大10MB)
|
|
content_length = request.headers.get("content-length")
|
|
if content_length and int(content_length) > 10 * 1024 * 1024:
|
|
return JSONResponse(
|
|
status_code=413,
|
|
content={"code": 413, "msg": "请求数据过大,最大支持10MB", "data": None}
|
|
)
|
|
|
|
response = await call_next(request)
|
|
|
|
# 记录请求处理时间
|
|
process_time = time.time() - start_time
|
|
response.headers["X-Process-Time"] = f"{process_time:.4f}"
|
|
response.headers["X-Request-ID"] = request_id
|
|
|
|
logger.info(
|
|
f"{request.method} {request.url.path} "
|
|
f"status={response.status_code} "
|
|
f"time={process_time:.4f}s"
|
|
)
|
|
|
|
return response
|
|
|
|
|
|
@app.middleware("http")
|
|
async def mtls_authentication_middleware(request: Request, call_next):
|
|
"""
|
|
mTLS 双向认证中间件
|
|
内部服务间通信需携带有效的客户端证书
|
|
|
|
安全设计:
|
|
- 服务鉴权:内部服务间 mTLS 双向认证
|
|
- 请求校验:输入数据格式校验与大小限制(防恶意攻击)
|
|
"""
|
|
# 检查是否为内部服务调用
|
|
client_cert = request.headers.get("X-Client-Cert")
|
|
api_key = request.headers.get("X-API-Key")
|
|
|
|
# 白名单路径不需要认证
|
|
whitelist_paths = ["/health", "/docs", "/openapi.json"]
|
|
if request.url.path in whitelist_paths:
|
|
return await call_next(request)
|
|
|
|
# 验证API Key或客户端证书
|
|
if not api_key and not client_cert:
|
|
return JSONResponse(
|
|
status_code=401,
|
|
content={"code": 401, "msg": "缺少认证凭据", "data": None}
|
|
)
|
|
|
|
if api_key and api_key != settings.api_key:
|
|
return JSONResponse(
|
|
status_code=403,
|
|
content={"code": 403, "msg": "API Key无效", "data": None}
|
|
)
|
|
|
|
return await call_next(request)
|
|
|
|
|
|
# 注册各业务路由
|
|
app.include_router(ocr_router, prefix="/api/v1/ocr", tags=["OCR识别"])
|
|
app.include_router(math_router, prefix="/api/v1/math", tags=["数学识别"])
|
|
app.include_router(stroke_order_router, prefix="/api/v1/stroke-order", tags=["笔顺评分"])
|
|
app.include_router(essay_router, prefix="/api/v1/essay", tags=["作文批改"])
|
|
|
|
|
|
@app.get("/health")
|
|
async def health_check():
|
|
"""健康检查端点"""
|
|
model_status = model_manager.get_all_status()
|
|
return {
|
|
"code": 200,
|
|
"msg": "success",
|
|
"data": {
|
|
"status": "healthy",
|
|
"models": model_status,
|
|
"version": "1.0.0"
|
|
}
|
|
}
|
|
|
|
|
|
@app.get("/api/v1/model/status")
|
|
async def get_model_status():
|
|
"""
|
|
查询各模型加载状态与版本
|
|
GET /api/v1/model/status
|
|
"""
|
|
status = model_manager.get_all_status()
|
|
return {
|
|
"code": 200,
|
|
"msg": "success",
|
|
"data": status
|
|
}
|
|
|
|
|
|
@app.exception_handler(HTTPException)
|
|
async def http_exception_handler(request: Request, exc: HTTPException):
|
|
"""统一HTTP异常处理"""
|
|
return JSONResponse(
|
|
status_code=exc.status_code,
|
|
content={
|
|
"code": exc.status_code,
|
|
"msg": exc.detail,
|
|
"data": None
|
|
}
|
|
)
|
|
|
|
|
|
@app.exception_handler(Exception)
|
|
async def general_exception_handler(request: Request, exc: Exception):
|
|
"""统一异常处理"""
|
|
logger.error(f"未处理异常: {str(exc)}", exc_info=True)
|
|
return JSONResponse(
|
|
status_code=500,
|
|
content={
|
|
"code": 500,
|
|
"msg": "AI引擎内部错误",
|
|
"data": None
|
|
}
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
uvicorn.run(
|
|
"main:app",
|
|
host="0.0.0.0",
|
|
port=8001,
|
|
workers=4,
|
|
log_level="info"
|
|
)
|