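# 自然写 Handwriting Recognition and AI Analysis Engine Software V1.0
# Configuration and security module - global configuration management and security policy

"""
Global configuration management.

Manages every configuration item of the AI engine service, including:
service ports, model paths, GPU settings, security/authentication, log level, etc.
Supports overriding values through environment variables and hot-updating the configuration.
"""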
|
||
|
||
import hashlib
import hmac
import json
import logging
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
|
||
|
||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||
|
||
# ==================== 服务配置 ====================
|
||
|
||
@dataclass
class ServerConfig:
    """Settings for the HTTP and gRPC service endpoints."""

    # Host and port of the HTTP endpoint.
    http_host: str = "0.0.0.0"
    http_port: int = 8000

    # Host and port of the gRPC endpoint.
    grpc_host: str = "0.0.0.0"
    grpc_port: int = 50051

    # Number of FastAPI workers.
    workers: int = 4

    # Size of the gRPC thread pool.
    grpc_max_workers: int = 10

    # Request body size limit in MB (protects against abusive payloads).
    max_request_size_mb: int = 10

    # Request timeout in seconds.
    request_timeout_s: int = 30

    # Allowed CORS origins; the default is the "*" wildcard.
    cors_origins: List[str] = field(default_factory=lambda: ["*"])

    debug: bool = False
|
||
|
||
|
||
@dataclass
class ModelConfig:
    """Settings for model inference."""

    # Root directory holding the model files.
    models_dir: str = "/opt/models"

    # Per-task model locations: OCR, math recognition, stroke order, essay scoring.
    ocr_model_path: str = "/opt/models/ocr"
    math_model_path: str = "/opt/models/math"
    stroke_model_path: str = "/opt/models/stroke"
    essay_model_path: str = "/opt/models/essay"

    # Largest batch size used for inference.
    max_batch_size: int = 32

    # Timeout of a single inference call, in milliseconds.
    inference_timeout_ms: int = 5000

    # Run inference in FP16 half precision.
    enable_fp16: bool = True

    # Size of the in-memory model cache, in GB.
    model_cache_size_gb: float = 4.0
|
||
|
||
|
||
@dataclass
class GPUConfig:
    """Settings for GPU/NPU hardware acceleration."""

    # Inference device: cuda / cpu / npu.
    device: str = "cuda"

    # Indices of the GPUs to use.
    gpu_ids: List[int] = field(default_factory=lambda: [0])

    # Upper bound on the fraction of GPU memory that may be used.
    gpu_memory_fraction: float = 0.8

    # Whether TensorRT acceleration is enabled.
    enable_tensorrt: bool = True

    # TensorRT precision: fp32 / fp16 / int8.
    tensorrt_precision: str = "fp16"

    # Address of the Triton Inference Server.
    triton_url: str = "localhost:8001"
|
||
|
||
|
||
@dataclass
class CeleryConfig:
    """Settings for the Celery task queue."""

    # Address of the Redis broker.
    broker_url: str = "redis://localhost:6379/0"

    # Backend used to store task results.
    result_backend: str = "redis://localhost:6379/1"

    task_serializer: str = "json"
    result_serializer: str = "json"
    task_default_queue: str = "writech.default"

    # Hard limit on task execution time, in seconds.
    task_time_limit: int = 300

    # Soft time limit in seconds (triggers SoftTimeLimitExceeded).
    task_soft_time_limit: int = 240

    # Worker concurrency.
    worker_concurrency: int = 8

    # Prefetch multiplier.
    worker_prefetch_multiplier: int = 2
|
||
|
||
|
||
@dataclass
class DatabaseConfig:
    """Settings for the database connections."""

    # Connection URLs for MySQL, Redis and MongoDB.
    # NOTE(review): the MySQL default embeds "user:password" credentials - presumably
    # a placeholder meant to be overridden per deployment; confirm.
    mysql_url: str = "mysql+pymysql://user:password@localhost:3306/writech_ai"
    redis_url: str = "redis://localhost:6379/0"
    mongodb_url: str = "mongodb://localhost:27017/writech_stroke"

    # Connection pool size.
    pool_size: int = 20

    # Connection recycle time, in seconds.
    pool_recycle: int = 3600
|
||
|
||
|
||
@dataclass
class LogConfig:
    """Settings for logging."""

    level: str = "INFO"
    format: str = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    log_dir: str = "/var/log/writech-ai"

    # Upper size limit of a single log file, in MB.
    max_file_size_mb: int = 100

    # Number of log files to keep.
    backup_count: int = 10

    # Enable the audit log.
    enable_audit_log: bool = True

    # File name of the audit log.
    audit_log_file: str = "audit.log"
|
||
|
||
|
||
# ==================== 安全配置 ====================
|
||
|
||
@dataclass
class SecurityConfig:
    """Security settings."""

    # --- Mutual TLS (security design: mTLS between internal services) ---
    enable_mtls: bool = True
    server_cert_path: str = "/etc/ssl/server.crt"
    server_key_path: str = "/etc/ssl/server.key"
    ca_cert_path: str = "/etc/ssl/ca.crt"

    # --- Model file encryption (security design: models are stored encrypted
    # and decrypted in memory at inference time) ---
    model_encryption_enabled: bool = True
    # Name of the environment variable the encryption key is read from.
    model_encryption_key_env: str = "WRITECH_MODEL_KEY"

    # --- Request validation (security design: input format checks and size limits) ---
    # Maximum number of coordinate points in one request.
    max_stroke_points: int = 100000
    # Maximum number of strokes in one request.
    max_strokes_per_request: int = 500
    # Maximum length of an essay text.
    max_text_length: int = 10000

    # --- Rate limiting ---
    # Maximum number of requests per minute.
    rate_limit_per_minute: int = 600
    # Number of burst requests.
    rate_limit_burst: int = 50

    # --- Audit log (security design: every recognition request records the
    # caller, the time and the model version) ---
    enable_audit: bool = True
    # Number of days audit logs are retained.
    audit_retention_days: int = 90
|
||
|
||
|
||
# ==================== mTLS认证管理 ====================
|
||
|
||
class MTLSAuthenticator:
    """Mutual-TLS (mTLS) authentication manager.

    Keeps a registry of trusted client-certificate fingerprints so that only
    authorised internal services can call the AI engine.
    """

    def __init__(self, config: "SecurityConfig"):
        """Store the security configuration and start with no trusted clients.

        Args:
            config: Security settings; this class reads ``enable_mtls``,
                ``server_cert_path``, ``server_key_path`` and ``ca_cert_path``.
        """
        self._config = config
        # Maps certificate fingerprint -> registered client name.
        self._trusted_clients: Dict[str, str] = {}
        logger.info("mTLS认证管理器初始化")

    def load_certificates(self) -> bool:
        """Check that the server certificate, private key and CA certificate exist.

        Only the presence of the three files is verified; their contents are
        not parsed here.

        Returns:
            True when every configured path points to an existing file, False
            when any of them is missing or the paths cannot be inspected.
        """
        try:
            required = (
                ("服务端证书", Path(self._config.server_cert_path)),
                ("服务端私钥", Path(self._config.server_key_path)),
                ("CA证书", Path(self._config.ca_cert_path)),
            )

            # Report every missing file, not just the first, so a broken
            # deployment can be fixed in one pass.
            all_present = True
            for label, path in required:
                if not path.is_file():
                    logger.warning(f"{label}不存在: {path}")
                    all_present = False
            if not all_present:
                return False

            logger.info("mTLS证书加载完成")
            return True
        except Exception as e:
            # Best effort: any failure while inspecting the paths is reported
            # as "not loaded" rather than raised to the caller.
            logger.error(f"证书加载失败: {str(e)}")
            return False

    def verify_client_cert(self, cert_fingerprint: str) -> bool:
        """Return whether a client certificate fingerprint is trusted.

        Args:
            cert_fingerprint: Fingerprint of the certificate presented by the client.

        Returns:
            True when mTLS is disabled in the configuration, or when the
            fingerprint has been registered via ``register_trusted_client``;
            False otherwise.
        """
        if not self._config.enable_mtls:
            return True
        is_trusted = cert_fingerprint in self._trusted_clients
        if not is_trusted:
            logger.warning(f"未授信的客户端证书: {cert_fingerprint}")
        return is_trusted

    def register_trusted_client(self, name: str, fingerprint: str):
        """Register a client as trusted.

        Args:
            name: Human-readable client name, stored for reference and logged.
            fingerprint: Certificate fingerprint used as the lookup key; it is
                matched exactly, with no normalisation.
        """
        self._trusted_clients[fingerprint] = name
        logger.info(f"注册授信客户端: {name}")
|
||
|
||
|
||
# ==================== 请求签名校验 ====================
|
||
|
||
class RequestValidator:
    """HMAC request-signature validator.

    Signs and verifies API requests with HMAC-SHA256 to detect tampering, and
    rejects out-of-window timestamps and reused nonces to prevent replay.
    """

    # Maximum accepted difference, in seconds, between a request timestamp
    # and the local clock (i.e. a +/- 5 minute window).
    _TIMESTAMP_WINDOW_S = 300

    def __init__(self, secret_key: str = ""):
        """Pick the signing secret and create an empty nonce cache.

        Args:
            secret_key: Signing secret. When empty, the ``WRITECH_API_SECRET``
                environment variable is used; when that is unset or empty too,
                a built-in default is used and a warning is logged.
        """
        secret = secret_key or os.environ.get("WRITECH_API_SECRET", "")
        if not secret:
            # The fallback is a publicly known value: signatures made with it
            # offer no protection, so make the misconfiguration visible.
            logger.warning("未配置API签名密钥(WRITECH_API_SECRET),正在使用不安全的默认密钥")
            secret = "default-secret"
        self._secret = secret
        # Nonce -> reference time used for expiry (replay protection).
        self._nonce_cache: Dict[str, float] = {}
        # Nonce lifetime in seconds; must not be shorter than the timestamp window.
        self._nonce_ttl = 300

    def generate_signature(self, payload: str, timestamp: int, nonce: str) -> str:
        """Return the hex HMAC-SHA256 signature of a request.

        The signed message is ``{payload}&timestamp={timestamp}&nonce={nonce}``.

        Args:
            payload: Request payload text.
            timestamp: Request timestamp in seconds.
            nonce: Per-request random string.
        """
        message = f"{payload}&timestamp={timestamp}&nonce={nonce}"
        return hmac.new(
            self._secret.encode(), message.encode(), hashlib.sha256
        ).hexdigest()

    def verify_signature(self, payload: str, timestamp: int,
                         nonce: str, signature: str) -> bool:
        """Verify a request signature.

        Checks, in order:
        1. the timestamp lies within the accepted window (replay protection);
        2. the nonce has not been used before (replay protection);
        3. the HMAC signature matches (tamper protection).

        Returns:
            True when all checks pass; the nonce is then recorded so the same
            request cannot be accepted twice. False otherwise.
        """
        now = time.time()

        # Timestamp window check (+/- 5 minutes).
        if abs(int(now) - timestamp) > self._TIMESTAMP_WINDOW_S:
            logger.warning(f"请求时间戳过期: {timestamp}")
            return False

        # Nonce replay check.
        if nonce in self._nonce_cache:
            logger.warning(f"重复的请求随机数: {nonce}")
            return False

        # A signature that is not text can never match.
        if not isinstance(signature, str):
            return False

        # Compare as bytes: compare_digest rejects non-ASCII str arguments
        # with TypeError, and the signature comes from the caller.
        expected = self.generate_signature(payload, timestamp, nonce)
        is_valid = hmac.compare_digest(
            expected.encode(), signature.encode("utf-8", "replace")
        )

        if is_valid:
            # Anchor expiry on the later of "now" and the request timestamp, so a
            # future-dated request keeps its nonce until its timestamp leaves the
            # accepted window; otherwise it could be replayed after the nonce expired.
            self._nonce_cache[nonce] = max(now, float(timestamp))
            self._cleanup_nonce_cache()

        return is_valid

    def _cleanup_nonce_cache(self):
        """Drop cached nonces whose reference time is older than the nonce TTL."""
        current = time.time()
        expired = [k for k, v in self._nonce_cache.items() if current - v > self._nonce_ttl]
        for k in expired:
            del self._nonce_cache[k]
|
||
|
||
|
||
# ==================== 全局配置管理器 ====================
|
||
|
||
class Settings:
    """Global configuration manager (singleton).

    Holds one configuration object per section, loads overrides from
    environment variables and from a JSON file, and supports reloading the
    file at runtime. Environment variables take precedence over the file.
    """

    _instance = None

    # Attributes that hold configuration sections. Only these may be updated
    # from a configuration file, so a file cannot reach other attributes
    # (for example the security components).
    _CONFIG_SECTIONS = ("server", "model", "gpu", "celery", "database", "log", "security")

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Build the default configuration once; later constructions are no-ops."""
        if hasattr(self, '_initialized'):
            return

        # Default configuration of every section.
        self.server = ServerConfig()
        self.model = ModelConfig()
        self.gpu = GPUConfig()
        self.celery = CeleryConfig()
        self.database = DatabaseConfig()
        self.log = LogConfig()
        self.security = SecurityConfig()

        # Override defaults from environment variables.
        self._load_from_env()

        # Security components.
        self.mtls_auth = MTLSAuthenticator(self.security)
        self.request_validator = RequestValidator()

        # Mark as initialised only after everything succeeded, so a failed
        # first construction is retried instead of leaving a half-built singleton.
        self._initialized = True

        logger.info("全局配置加载完成")

    @staticmethod
    def _parse_bool(raw: str) -> bool:
        """Convert an environment-variable string to a bool.

        Accepts true/1/yes/on and false/0/no/off (case-insensitive, surrounding
        whitespace ignored).

        Raises:
            ValueError: for any other text, so the caller keeps the current
                value instead of silently treating it as False.
        """
        text = raw.strip().lower()
        if text in ("1", "true", "yes", "on"):
            return True
        if text in ("0", "false", "no", "off"):
            return False
        raise ValueError(f"invalid boolean value: {raw!r}")

    def _load_from_env(self):
        """Apply environment-variable overrides on top of the current values.

        A variable whose value cannot be converted is skipped with a warning and
        the current value is kept.
        """
        env_mapping = {
            "WRITECH_HTTP_PORT": ("server", "http_port", int),
            "WRITECH_GRPC_PORT": ("server", "grpc_port", int),
            "WRITECH_WORKERS": ("server", "workers", int),
            "WRITECH_DEBUG": ("server", "debug", self._parse_bool),
            "WRITECH_MODELS_DIR": ("model", "models_dir", str),
            "WRITECH_GPU_DEVICE": ("gpu", "device", str),
            "WRITECH_GPU_IDS": ("gpu", "gpu_ids", lambda x: [int(i) for i in x.split(",")]),
            "WRITECH_REDIS_URL": ("celery", "broker_url", str),
            "WRITECH_MYSQL_URL": ("database", "mysql_url", str),
            "WRITECH_LOG_LEVEL": ("log", "level", str),
            "WRITECH_ENABLE_MTLS": ("security", "enable_mtls", self._parse_bool),
        }

        for env_key, (section, attr_name, converter) in env_mapping.items():
            value = os.environ.get(env_key)
            if value is None:
                continue
            config_obj = getattr(self, section)
            try:
                setattr(config_obj, attr_name, converter(value))
                logger.info(f"环境变量覆盖配置: {env_key} -> {section}.{attr_name}")
            except (ValueError, TypeError) as e:
                logger.warning(f"环境变量转换失败: {env_key}={value}, 错误: {str(e)}")

    def load_from_file(self, config_path: str):
        """Load configuration from a JSON file.

        The file's top level maps section names to objects of ``key: value``
        pairs. Unknown sections and unknown keys are ignored. Environment
        overrides are re-applied afterwards so they keep precedence over the file.
        A missing file or invalid JSON is logged and otherwise ignored.

        Args:
            config_path: Path of the JSON configuration file.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config_data = json.load(f)
            logger.info(f"配置文件加载完成: {config_path}")
        except FileNotFoundError:
            logger.warning(f"配置文件不存在: {config_path}")
            return
        except json.JSONDecodeError as e:
            logger.error(f"配置文件JSON解析错误: {str(e)}")
            return

        if not isinstance(config_data, dict):
            logger.error(f"配置文件根节点必须是JSON对象: {config_path}")
            return

        # Update section by section.
        for section_name, section_data in config_data.items():
            if section_name not in self._CONFIG_SECTIONS or not isinstance(section_data, dict):
                continue
            config_obj = getattr(self, section_name)
            for key, value in section_data.items():
                if hasattr(config_obj, key):
                    setattr(config_obj, key, value)

        # Environment variables outrank the file: re-apply them last.
        self._load_from_env()

    def to_dict(self) -> Dict[str, Any]:
        """Export configuration as a dict with sensitive fields masked.

        Only the server, model, gpu, celery and log sections are exported; the
        database and security sections are not included. Any field whose name
        contains "password", "secret", "key" or "token" is replaced by "***".
        """
        result = {}
        for section in ['server', 'model', 'gpu', 'celery', 'log']:
            config_obj = getattr(self, section)
            section_dict = {}
            for key in vars(config_obj):
                value = getattr(config_obj, key)
                # Mask password- and key-like fields.
                if any(kw in key.lower() for kw in ['password', 'secret', 'key', 'token']):
                    section_dict[key] = "***"
                else:
                    section_dict[key] = value
            result[section] = section_dict
        return result
|
||
|
||
|
||
# 全局配置实例
|
||
# Module-level Settings instance, created when this module is imported.
settings = Settings()
|