software copyright
This commit is contained in:
@@ -0,0 +1,218 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自然写手写识别与AI分析引擎软件 V1.0
|
||||
|
||||
版权所有 (C) 2026
|
||||
软件全称:自然写手写识别与AI分析引擎软件
|
||||
版本号:V1.0
|
||||
|
||||
主启动文件 - FastAPI 服务入口
|
||||
负责服务初始化、路由注册、中间件配置
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from contextlib import asynccontextmanager
|
||||
import uvicorn
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
|
||||
# 导入各业务模块路由
|
||||
from api.ocr_api import router as ocr_router
|
||||
from api.math_api import router as math_router
|
||||
from api.stroke_order_api import router as stroke_order_router
|
||||
from api.essay_api import router as essay_router
|
||||
from service.model_manager import ModelManager
|
||||
from config.settings import Settings
|
||||
|
||||
# 日志配置
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("writech-ai-engine")
|
||||
|
||||
# 全局配置
|
||||
settings = Settings()
|
||||
|
||||
# 全局模型管理器实例
|
||||
model_manager = ModelManager(settings)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
应用生命周期管理
|
||||
启动时加载所有AI模型到GPU/CPU内存
|
||||
关闭时释放模型资源
|
||||
"""
|
||||
logger.info("自然写AI引擎启动中,加载模型...")
|
||||
# 启动时加载所有模型
|
||||
await model_manager.load_all_models()
|
||||
logger.info("所有模型加载完成,服务就绪")
|
||||
yield
|
||||
# 关闭时释放资源
|
||||
logger.info("服务关闭中,释放模型资源...")
|
||||
model_manager.release_all_models()
|
||||
logger.info("模型资源已释放")
|
||||
|
||||
|
||||
# 创建 FastAPI 应用实例
|
||||
app = FastAPI(
|
||||
title="自然写手写识别与AI分析引擎",
|
||||
description="对智能点阵笔采集的笔迹数据进行OCR识别、数学列式识别、笔顺分析及AI智能批改",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 跨域中间件配置
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_logging_middleware(request: Request, call_next):
|
||||
"""
|
||||
请求日志与性能监控中间件
|
||||
记录每个请求的处理时间、状态码、推理耗时
|
||||
"""
|
||||
start_time = time.time()
|
||||
request_id = request.headers.get("X-Request-ID", str(time.time()))
|
||||
|
||||
# 输入数据大小校验(防恶意攻击,最大10MB)
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length and int(content_length) > 10 * 1024 * 1024:
|
||||
return JSONResponse(
|
||||
status_code=413,
|
||||
content={"code": 413, "msg": "请求数据过大,最大支持10MB", "data": None}
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# 记录请求处理时间
|
||||
process_time = time.time() - start_time
|
||||
response.headers["X-Process-Time"] = f"{process_time:.4f}"
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
|
||||
logger.info(
|
||||
f"{request.method} {request.url.path} "
|
||||
f"status={response.status_code} "
|
||||
f"time={process_time:.4f}s"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def mtls_authentication_middleware(request: Request, call_next):
|
||||
"""
|
||||
mTLS 双向认证中间件
|
||||
内部服务间通信需携带有效的客户端证书
|
||||
|
||||
安全设计:
|
||||
- 服务鉴权:内部服务间 mTLS 双向认证
|
||||
- 请求校验:输入数据格式校验与大小限制(防恶意攻击)
|
||||
"""
|
||||
# 检查是否为内部服务调用
|
||||
client_cert = request.headers.get("X-Client-Cert")
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
|
||||
# 白名单路径不需要认证
|
||||
whitelist_paths = ["/health", "/docs", "/openapi.json"]
|
||||
if request.url.path in whitelist_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# 验证API Key或客户端证书
|
||||
if not api_key and not client_cert:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"code": 401, "msg": "缺少认证凭据", "data": None}
|
||||
)
|
||||
|
||||
if api_key and api_key != settings.api_key:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"code": 403, "msg": "API Key无效", "data": None}
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# 注册各业务路由
|
||||
app.include_router(ocr_router, prefix="/api/v1/ocr", tags=["OCR识别"])
|
||||
app.include_router(math_router, prefix="/api/v1/math", tags=["数学识别"])
|
||||
app.include_router(stroke_order_router, prefix="/api/v1/stroke-order", tags=["笔顺评分"])
|
||||
app.include_router(essay_router, prefix="/api/v1/essay", tags=["作文批改"])
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查端点"""
|
||||
model_status = model_manager.get_all_status()
|
||||
return {
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
"data": {
|
||||
"status": "healthy",
|
||||
"models": model_status,
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/v1/model/status")
|
||||
async def get_model_status():
|
||||
"""
|
||||
查询各模型加载状态与版本
|
||||
GET /api/v1/model/status
|
||||
"""
|
||||
status = model_manager.get_all_status()
|
||||
return {
|
||||
"code": 200,
|
||||
"msg": "success",
|
||||
"data": status
|
||||
}
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
"""统一HTTP异常处理"""
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={
|
||||
"code": exc.status_code,
|
||||
"msg": exc.detail,
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_exception_handler(request: Request, exc: Exception):
|
||||
"""统一异常处理"""
|
||||
logger.error(f"未处理异常: {str(exc)}", exc_info=True)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"code": 500,
|
||||
"msg": "AI引擎内部错误",
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host="0.0.0.0",
|
||||
port=8001,
|
||||
workers=4,
|
||||
log_level="info"
|
||||
)
|
||||
Reference in New Issue
Block a user