Files
system-design/software-copyright/02-writech-ai-engine/config/settings.py
T
2026-03-22 15:24:40 +08:00

337 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 自然写手写识别与AI分析引擎软件 V1.0
# 配置与安全模块 - 全局配置管理与安全策略
"""
全局配置管理
提供AI引擎服务的所有配置项管理,包括:
服务端口、模型路径、GPU配置、安全认证、日志级别等
支持环境变量覆盖和配置热更新
"""
import os
import json
import logging
import hashlib
import hmac
import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass, field
from pathlib import Path
logger = logging.getLogger(__name__)
# ==================== 服务配置 ====================
@dataclass
class ServerConfig:
"""HTTP/gRPC服务配置"""
http_host: str = "0.0.0.0"
http_port: int = 8000
grpc_host: str = "0.0.0.0"
grpc_port: int = 50051
workers: int = 4 # FastAPI worker数量
grpc_max_workers: int = 10 # gRPC线程池大小
max_request_size_mb: int = 10 # 请求体大小限制(防恶意攻击)
request_timeout_s: int = 30 # 请求超时时间
cors_origins: List[str] = field(default_factory=lambda: ["*"])
debug: bool = False
@dataclass
class ModelConfig:
"""模型推理配置"""
models_dir: str = "/opt/models" # 模型文件根目录
ocr_model_path: str = "/opt/models/ocr" # OCR模型路径
math_model_path: str = "/opt/models/math" # 数学识别模型路径
stroke_model_path: str = "/opt/models/stroke" # 笔顺模型路径
essay_model_path: str = "/opt/models/essay" # 作文评分模型路径
max_batch_size: int = 32 # 最大推理批大小
inference_timeout_ms: int = 5000 # 单次推理超时
enable_fp16: bool = True # FP16半精度推理
model_cache_size_gb: float = 4.0 # 模型内存缓存大小
@dataclass
class GPUConfig:
"""GPU/NPU硬件加速配置"""
device: str = "cuda" # 推理设备: cuda / cpu / npu
gpu_ids: List[int] = field(default_factory=lambda: [0]) # 使用的GPU编号
gpu_memory_fraction: float = 0.8 # GPU显存使用比例上限
enable_tensorrt: bool = True # 是否启用TensorRT加速
tensorrt_precision: str = "fp16" # TensorRT精度: fp32/fp16/int8
triton_url: str = "localhost:8001" # Triton Inference Server地址
@dataclass
class CeleryConfig:
"""Celery任务队列配置"""
broker_url: str = "redis://localhost:6379/0" # Redis Broker地址
result_backend: str = "redis://localhost:6379/1" # 结果存储后端
task_serializer: str = "json"
result_serializer: str = "json"
task_default_queue: str = "writech.default"
task_time_limit: int = 300 # 任务最大执行时间(秒)
task_soft_time_limit: int = 240 # 软超时(触发SoftTimeLimitExceeded
worker_concurrency: int = 8 # Worker并发数
worker_prefetch_multiplier: int = 2 # 预取倍数
@dataclass
class DatabaseConfig:
"""数据库配置"""
mysql_url: str = "mysql+pymysql://user:password@localhost:3306/writech_ai"
redis_url: str = "redis://localhost:6379/0"
mongodb_url: str = "mongodb://localhost:27017/writech_stroke"
pool_size: int = 20 # 连接池大小
pool_recycle: int = 3600 # 连接回收时间(秒)
@dataclass
class LogConfig:
"""日志配置"""
level: str = "INFO"
format: str = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
log_dir: str = "/var/log/writech-ai"
max_file_size_mb: int = 100 # 单个日志文件大小上限
backup_count: int = 10 # 保留日志文件数量
enable_audit_log: bool = True # 启用审计日志
audit_log_file: str = "audit.log" # 审计日志文件名
# ==================== 安全配置 ====================
@dataclass
class SecurityConfig:
"""安全配置"""
# mTLS双向认证(安全设计:内部服务间mTLS双向认证)
enable_mtls: bool = True
server_cert_path: str = "/etc/ssl/server.crt"
server_key_path: str = "/etc/ssl/server.key"
ca_cert_path: str = "/etc/ssl/ca.crt"
# 模型文件加密(安全设计:模型文件加密存储,推理时内存解密)
model_encryption_enabled: bool = True
model_encryption_key_env: str = "WRITECH_MODEL_KEY" # 加密密钥从环境变量读取
# 请求校验(安全设计:输入数据格式校验与大小限制)
max_stroke_points: int = 100000 # 单次请求最大坐标点数
max_strokes_per_request: int = 500 # 单次请求最大笔画数
max_text_length: int = 10000 # 作文文本最大长度
# 速率限制
rate_limit_per_minute: int = 600 # 每分钟最大请求数
rate_limit_burst: int = 50 # 突发请求数
# 审计日志(安全设计:所有识别请求记录调用方、时间、模型版本)
enable_audit: bool = True
audit_retention_days: int = 90 # 审计日志保留天数
# ==================== mTLS认证管理 ====================
class MTLSAuthenticator:
"""
mTLS双向认证管理器
验证客户端证书,确保只有授权的内部服务可以调用AI引擎
"""
def __init__(self, config: SecurityConfig):
self._config = config
self._trusted_clients: Dict[str, str] = {} # 授信客户端证书指纹
logger.info("mTLS认证管理器初始化")
def load_certificates(self) -> bool:
"""加载服务端证书和CA证书"""
try:
cert_path = Path(self._config.server_cert_path)
key_path = Path(self._config.server_key_path)
ca_path = Path(self._config.ca_cert_path)
if not cert_path.exists():
logger.warning(f"服务端证书不存在: {cert_path}")
return False
logger.info("mTLS证书加载完成")
return True
except Exception as e:
logger.error(f"证书加载失败: {str(e)}")
return False
def verify_client_cert(self, cert_fingerprint: str) -> bool:
"""验证客户端证书指纹"""
if not self._config.enable_mtls:
return True
is_trusted = cert_fingerprint in self._trusted_clients
if not is_trusted:
logger.warning(f"未授信的客户端证书: {cert_fingerprint}")
return is_trusted
def register_trusted_client(self, name: str, fingerprint: str):
"""注册授信客户端"""
self._trusted_clients[fingerprint] = name
logger.info(f"注册授信客户端: {name}")
# ==================== 请求签名校验 ====================
class RequestValidator:
"""
请求签名校验器
对API请求进行HMAC签名校验,防止请求篡改和重放攻击
"""
def __init__(self, secret_key: str = ""):
self._secret = secret_key or os.environ.get("WRITECH_API_SECRET", "default-secret")
self._nonce_cache: Dict[str, float] = {} # 随机数缓存(防重放)
self._nonce_ttl = 300 # 随机数有效期(秒)
def generate_signature(self, payload: str, timestamp: int, nonce: str) -> str:
"""生成请求签名"""
message = f"{payload}&timestamp={timestamp}&nonce={nonce}"
return hmac.new(
self._secret.encode(), message.encode(), hashlib.sha256
).hexdigest()
def verify_signature(self, payload: str, timestamp: int,
nonce: str, signature: str) -> bool:
"""
校验请求签名
1. 检查时间戳是否在有效窗口内(防重放)
2. 检查随机数是否已使用(防重放)
3. 验证HMAC签名是否匹配(防篡改)
"""
# 时间窗口校验(±5分钟)
current_time = int(time.time())
if abs(current_time - timestamp) > 300:
logger.warning(f"请求时间戳过期: {timestamp}")
return False
# 随机数防重放检查
if nonce in self._nonce_cache:
logger.warning(f"重复的请求随机数: {nonce}")
return False
# HMAC签名验证
expected = self.generate_signature(payload, timestamp, nonce)
is_valid = hmac.compare_digest(expected, signature)
if is_valid:
# 缓存随机数
self._nonce_cache[nonce] = time.time()
self._cleanup_nonce_cache()
return is_valid
def _cleanup_nonce_cache(self):
"""清理过期的随机数缓存"""
current = time.time()
expired = [k for k, v in self._nonce_cache.items() if current - v > self._nonce_ttl]
for k in expired:
del self._nonce_cache[k]
# ==================== 全局配置管理器 ====================
class Settings:
"""
全局配置管理器(单例)
从环境变量和配置文件加载配置,支持运行时热更新
环境变量优先级高于配置文件
"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if hasattr(self, '_initialized'):
return
self._initialized = True
# 加载各模块配置
self.server = ServerConfig()
self.model = ModelConfig()
self.gpu = GPUConfig()
self.celery = CeleryConfig()
self.database = DatabaseConfig()
self.log = LogConfig()
self.security = SecurityConfig()
# 从环境变量覆盖配置
self._load_from_env()
# 初始化安全组件
self.mtls_auth = MTLSAuthenticator(self.security)
self.request_validator = RequestValidator()
logger.info("全局配置加载完成")
def _load_from_env(self):
"""从环境变量加载配置(覆盖默认值)"""
env_mapping = {
"WRITECH_HTTP_PORT": ("server", "http_port", int),
"WRITECH_GRPC_PORT": ("server", "grpc_port", int),
"WRITECH_WORKERS": ("server", "workers", int),
"WRITECH_DEBUG": ("server", "debug", lambda x: x.lower() == "true"),
"WRITECH_MODELS_DIR": ("model", "models_dir", str),
"WRITECH_GPU_DEVICE": ("gpu", "device", str),
"WRITECH_GPU_IDS": ("gpu", "gpu_ids", lambda x: [int(i) for i in x.split(",")]),
"WRITECH_REDIS_URL": ("celery", "broker_url", str),
"WRITECH_MYSQL_URL": ("database", "mysql_url", str),
"WRITECH_LOG_LEVEL": ("log", "level", str),
"WRITECH_ENABLE_MTLS": ("security", "enable_mtls", lambda x: x.lower() == "true"),
}
for env_key, (section, field, converter) in env_mapping.items():
value = os.environ.get(env_key)
if value is not None:
config_obj = getattr(self, section)
try:
setattr(config_obj, field, converter(value))
logger.info(f"环境变量覆盖配置: {env_key} -> {section}.{field}")
except (ValueError, TypeError) as e:
logger.warning(f"环境变量转换失败: {env_key}={value}, 错误: {str(e)}")
def load_from_file(self, config_path: str):
"""从JSON配置文件加载配置"""
try:
with open(config_path, 'r') as f:
config_data = json.load(f)
logger.info(f"配置文件加载完成: {config_path}")
# 逐section更新配置
for section_name, section_data in config_data.items():
if hasattr(self, section_name) and isinstance(section_data, dict):
config_obj = getattr(self, section_name)
for key, value in section_data.items():
if hasattr(config_obj, key):
setattr(config_obj, key, value)
except FileNotFoundError:
logger.warning(f"配置文件不存在: {config_path}")
except json.JSONDecodeError as e:
logger.error(f"配置文件JSON解析错误: {str(e)}")
def to_dict(self) -> Dict[str, Any]:
"""将所有配置导出为字典(隐藏敏感信息)"""
result = {}
for section in ['server', 'model', 'gpu', 'celery', 'log']:
config_obj = getattr(self, section)
section_dict = {}
for key in vars(config_obj):
value = getattr(config_obj, key)
# 隐藏密码和密钥类字段
if any(kw in key.lower() for kw in ['password', 'secret', 'key', 'token']):
section_dict[key] = "***"
else:
section_dict[key] = value
result[section] = section_dict
return result
# 全局配置实例
settings = Settings()