software copyright
This commit is contained in:
@@ -0,0 +1,336 @@
|
||||
# 自然写手写识别与AI分析引擎软件 V1.0
|
||||
# 配置与安全模块 - 全局配置管理与安全策略
|
||||
|
||||
"""
|
||||
全局配置管理
|
||||
提供AI引擎服务的所有配置项管理,包括:
|
||||
服务端口、模型路径、GPU配置、安全认证、日志级别等
|
||||
支持环境变量覆盖和配置热更新
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==================== 服务配置 ====================
|
||||
|
||||
@dataclass
class ServerConfig:
    """Settings for the HTTP and gRPC service endpoints."""

    # --- HTTP listener ---
    http_host: str = "0.0.0.0"
    http_port: int = 8000

    # --- gRPC listener ---
    grpc_host: str = "0.0.0.0"
    grpc_port: int = 50051

    # --- capacity and limits ---
    workers: int = 4  # number of FastAPI workers
    grpc_max_workers: int = 10  # gRPC thread-pool size
    max_request_size_mb: int = 10  # request body size cap (guards against abusive payloads)
    request_timeout_s: int = 30  # request timeout

    # Each instance gets its own list, hence the factory.
    cors_origins: List[str] = field(default_factory=lambda: ["*"])
    debug: bool = False
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Settings for model file locations and inference limits."""

    # --- model file locations ---
    models_dir: str = "/opt/models"  # root directory for model files
    ocr_model_path: str = "/opt/models/ocr"  # OCR model
    math_model_path: str = "/opt/models/math"  # math-recognition model
    stroke_model_path: str = "/opt/models/stroke"  # stroke-order model
    essay_model_path: str = "/opt/models/essay"  # essay-scoring model

    # --- inference limits ---
    max_batch_size: int = 32  # largest inference batch
    inference_timeout_ms: int = 5000  # timeout for a single inference
    enable_fp16: bool = True  # half-precision (FP16) inference
    model_cache_size_gb: float = 4.0  # in-memory model cache size
|
||||
|
||||
|
||||
@dataclass
class GPUConfig:
    """Settings for GPU/NPU hardware acceleration."""

    device: str = "cuda"  # inference device: cuda / cpu / npu
    # IDs of the GPUs to use; built per instance so the list is never shared.
    gpu_ids: List[int] = field(default_factory=lambda: [0])
    gpu_memory_fraction: float = 0.8  # upper bound on the share of GPU memory used

    # --- TensorRT / Triton ---
    enable_tensorrt: bool = True  # whether TensorRT acceleration is enabled
    tensorrt_precision: str = "fp16"  # TensorRT precision: fp32 / fp16 / int8
    triton_url: str = "localhost:8001"  # Triton Inference Server address
|
||||
|
||||
|
||||
@dataclass
class CeleryConfig:
    """Settings for the Celery task queue."""

    # --- transport ---
    broker_url: str = "redis://localhost:6379/0"  # Redis broker address
    result_backend: str = "redis://localhost:6379/1"  # result storage backend

    # --- serialization and routing ---
    task_serializer: str = "json"
    result_serializer: str = "json"
    task_default_queue: str = "writech.default"

    # --- time limits ---
    task_time_limit: int = 300  # maximum task run time (seconds)
    task_soft_time_limit: int = 240  # soft limit (triggers SoftTimeLimitExceeded)

    # --- workers ---
    worker_concurrency: int = 8  # worker concurrency
    worker_prefetch_multiplier: int = 2  # prefetch multiplier
|
||||
|
||||
|
||||
@dataclass
class DatabaseConfig:
    """Connection settings for the MySQL, Redis and MongoDB stores."""

    # --- connection URLs ---
    mysql_url: str = "mysql+pymysql://user:password@localhost:3306/writech_ai"
    redis_url: str = "redis://localhost:6379/0"
    mongodb_url: str = "mongodb://localhost:27017/writech_stroke"

    # --- connection pool ---
    pool_size: int = 20  # connection pool size
    pool_recycle: int = 3600  # connection recycle time (seconds)
|
||||
|
||||
|
||||
@dataclass
class LogConfig:
    """Settings for application and audit logging."""

    # --- general logging ---
    level: str = "INFO"
    format: str = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    log_dir: str = "/var/log/writech-ai"

    # --- rotation ---
    max_file_size_mb: int = 100  # size cap for a single log file
    backup_count: int = 10  # number of log files to keep

    # --- audit log ---
    enable_audit_log: bool = True  # turn the audit log on
    audit_log_file: str = "audit.log"  # audit log file name
|
||||
|
||||
|
||||
# ==================== 安全配置 ====================
|
||||
|
||||
@dataclass
class SecurityConfig:
    """Security-related settings: mTLS, model encryption, input limits, auditing."""

    # --- mTLS (mutual authentication between internal services) ---
    enable_mtls: bool = True
    server_cert_path: str = "/etc/ssl/server.crt"
    server_key_path: str = "/etc/ssl/server.key"
    ca_cert_path: str = "/etc/ssl/ca.crt"

    # --- model file encryption (encrypted at rest, decrypted in memory for inference) ---
    model_encryption_enabled: bool = True
    # Name of the environment variable the encryption key is read from.
    model_encryption_key_env: str = "WRITECH_MODEL_KEY"

    # --- request validation (input format and size limits) ---
    max_stroke_points: int = 100000  # max coordinate points per request
    max_strokes_per_request: int = 500  # max strokes per request
    max_text_length: int = 10000  # max essay text length

    # --- rate limiting ---
    rate_limit_per_minute: int = 600  # max requests per minute
    rate_limit_burst: int = 50  # burst request count

    # --- audit log (records caller, time and model version of each request) ---
    enable_audit: bool = True
    audit_retention_days: int = 90  # days audit logs are retained
|
||||
|
||||
|
||||
# ==================== mTLS认证管理 ====================
|
||||
|
||||
class MTLSAuthenticator:
    """Manage mutual-TLS (mTLS) client authorization.

    Keeps an in-memory allow-list of client certificate fingerprints and
    checks that the certificate files named in the security config exist.
    """

    def __init__(self, config: SecurityConfig):
        """Store *config* and start with an empty allow-list."""
        self._config = config
        # Maps certificate fingerprint -> registered client name.
        self._trusted_clients: Dict[str, str] = {}
        logger.info("mTLS认证管理器初始化")

    def load_certificates(self) -> bool:
        """Check that the server certificate, private key and CA certificate exist.

        Only the presence of the files is verified; their contents are not
        read or parsed here.

        Returns:
            True when all three files exist. False when any of them is
            missing (one warning per missing file) or when the check itself
            fails (logged as an error).
        """
        try:
            cert_path = Path(self._config.server_cert_path)
            key_path = Path(self._config.server_key_path)
            ca_path = Path(self._config.ca_cert_path)

            # All three files are required; report every missing one so the
            # operator can fix them in a single pass.
            all_present = True
            if not cert_path.exists():
                logger.warning(f"服务端证书不存在: {cert_path}")
                all_present = False
            if not key_path.exists():
                logger.warning(f"服务端私钥不存在: {key_path}")
                all_present = False
            if not ca_path.exists():
                logger.warning(f"CA证书不存在: {ca_path}")
                all_present = False
            if not all_present:
                return False

            logger.info("mTLS证书加载完成")
            return True
        except Exception as e:
            # Deliberate catch-all: this method reports failure through its
            # return value and must not raise to the caller.
            logger.error(f"证书加载失败: {str(e)}")
            return False

    def verify_client_cert(self, cert_fingerprint: str) -> bool:
        """Return whether *cert_fingerprint* belongs to a registered client.

        When mTLS is disabled in the config every fingerprint is accepted.
        An unknown fingerprint is logged as a warning.
        """
        if not self._config.enable_mtls:
            return True
        is_trusted = cert_fingerprint in self._trusted_clients
        if not is_trusted:
            logger.warning(f"未授信的客户端证书: {cert_fingerprint}")
        return is_trusted

    def register_trusted_client(self, name: str, fingerprint: str) -> None:
        """Add *fingerprint* to the allow-list under the client name *name*.

        Registering an already-known fingerprint replaces its stored name.
        """
        self._trusted_clients[fingerprint] = name
        logger.info(f"注册授信客户端: {name}")
|
||||
|
||||
|
||||
# ==================== 请求签名校验 ====================
|
||||
|
||||
class RequestValidator:
    """HMAC-SHA256 request signature checker with replay protection.

    A request is accepted when its timestamp is within the allowed window,
    its nonce has not been seen before, and its signature matches the HMAC
    of ``payload&timestamp=...&nonce=...`` under the shared secret.
    """

    # Maximum allowed distance, in seconds, between the request timestamp
    # and the local clock (in either direction).
    _TIMESTAMP_WINDOW_S = 300

    def __init__(self, secret_key: str = ""):
        """Set up the validator.

        Args:
            secret_key: Shared HMAC secret. When empty, the
                ``WRITECH_API_SECRET`` environment variable is used; when
                that is unset too, a built-in default is used and a warning
                is logged.
        """
        secret = secret_key or os.environ.get("WRITECH_API_SECRET")
        if secret is None:
            # A well-known fallback secret gives no real protection; make
            # the misconfiguration visible instead of staying silent.
            logger.warning("未配置WRITECH_API_SECRET,正在使用默认签名密钥")
            secret = "default-secret"
        self._secret = secret
        # nonce -> reference time used by _cleanup_nonce_cache for expiry.
        self._nonce_cache: Dict[str, float] = {}
        self._nonce_ttl = 300  # seconds a nonce is kept past its reference time

    def generate_signature(self, payload: str, timestamp: int, nonce: str) -> str:
        """Return the hex HMAC-SHA256 signature for a request.

        The signed message is ``{payload}&timestamp={timestamp}&nonce={nonce}``.
        """
        message = f"{payload}&timestamp={timestamp}&nonce={nonce}"
        return hmac.new(
            self._secret.encode(), message.encode(), hashlib.sha256
        ).hexdigest()

    def verify_signature(self, payload: str, timestamp: int,
                         nonce: str, signature: str) -> bool:
        """Validate a signed request.

        Checks, in order:
          1. the timestamp is within the allowed window of the local clock;
          2. the nonce has not been used before;
          3. the HMAC signature matches.

        A nonce is recorded only for requests that pass all three checks.

        Returns:
            True when the request is accepted, False otherwise.
        """
        now = time.time()

        # 1. Timestamp window.
        if abs(int(now) - timestamp) > self._TIMESTAMP_WINDOW_S:
            logger.warning(f"请求时间戳过期: {timestamp}")
            return False

        # 2. Replay check.
        if nonce in self._nonce_cache:
            logger.warning(f"重复的请求随机数: {nonce}")
            return False

        # 3. Signature check. compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, so compare bytes; a
        # malformed signature must be rejected, not crash the handler.
        if not isinstance(signature, str):
            return False
        expected = self.generate_signature(payload, timestamp, nonce)
        is_valid = hmac.compare_digest(
            expected.encode("ascii"),
            # "replace" maps unencodable characters to "?", which can never
            # appear in a hex digest, so this cannot cause a false match.
            signature.encode("utf-8", "replace"),
        )

        if is_valid:
            # Use the later of "now" and the request timestamp as the
            # reference time: a future-dated request stays inside the
            # timestamp window for longer than a nonce stamped with "now"
            # would be remembered, which would allow a replay.
            self._nonce_cache[nonce] = max(now, float(timestamp))
            self._cleanup_nonce_cache()

        return is_valid

    def _cleanup_nonce_cache(self):
        """Drop nonces whose reference time is older than the TTL."""
        current = time.time()
        expired = [
            nonce for nonce, seen_at in self._nonce_cache.items()
            if current - seen_at > self._nonce_ttl
        ]
        for nonce in expired:
            del self._nonce_cache[nonce]
|
||||
|
||||
|
||||
# ==================== 全局配置管理器 ====================
|
||||
|
||||
class Settings:
    """Process-wide configuration holder (singleton).

    Holds one dataclass instance per configuration section, applies
    environment-variable overrides, and can merge values from a JSON file.
    Environment variables take precedence over the configuration file.
    """

    _instance = None

    # Attributes that hold configuration sections. Only these may be
    # populated from a configuration file, so a file cannot reach into other
    # attributes such as the security helper objects.
    _CONFIG_SECTIONS = (
        "server", "model", "gpu", "celery", "database", "log", "security",
    )

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Build default sections, apply env overrides, create security helpers.

        __init__ runs on every ``Settings()`` call, so the ``_initialized``
        marker makes all calls after the first a no-op.
        """
        if hasattr(self, '_initialized'):
            return
        self._initialized = True

        # One config object per section, starting from defaults.
        self.server = ServerConfig()
        self.model = ModelConfig()
        self.gpu = GPUConfig()
        self.celery = CeleryConfig()
        self.database = DatabaseConfig()
        self.log = LogConfig()
        self.security = SecurityConfig()

        # Environment variables override the defaults.
        self._load_from_env()

        # Security helpers.
        self.mtls_auth = MTLSAuthenticator(self.security)
        self.request_validator = RequestValidator()

        logger.info("全局配置加载完成")

    def _load_from_env(self):
        """Apply environment-variable overrides to the config sections.

        Each known variable maps to (section, attribute, converter). A value
        that fails conversion is logged and skipped; the current setting is
        kept.
        """
        env_mapping = {
            "WRITECH_HTTP_PORT": ("server", "http_port", int),
            "WRITECH_GRPC_PORT": ("server", "grpc_port", int),
            "WRITECH_WORKERS": ("server", "workers", int),
            "WRITECH_DEBUG": ("server", "debug", lambda x: x.lower() == "true"),
            "WRITECH_MODELS_DIR": ("model", "models_dir", str),
            "WRITECH_GPU_DEVICE": ("gpu", "device", str),
            "WRITECH_GPU_IDS": ("gpu", "gpu_ids", lambda x: [int(i) for i in x.split(",")]),
            "WRITECH_REDIS_URL": ("celery", "broker_url", str),
            "WRITECH_MYSQL_URL": ("database", "mysql_url", str),
            "WRITECH_LOG_LEVEL": ("log", "level", str),
            "WRITECH_ENABLE_MTLS": ("security", "enable_mtls", lambda x: x.lower() == "true"),
        }

        for env_key, (section, attr_name, converter) in env_mapping.items():
            value = os.environ.get(env_key)
            if value is None:
                continue
            config_obj = getattr(self, section)
            try:
                setattr(config_obj, attr_name, converter(value))
                logger.info(f"环境变量覆盖配置: {env_key} -> {section}.{attr_name}")
            except (ValueError, TypeError) as e:
                logger.warning(f"环境变量转换失败: {env_key}={value}, 错误: {str(e)}")

    def load_from_file(self, config_path: str):
        """Merge settings from a JSON configuration file.

        The file must contain a JSON object mapping section names to objects
        of attribute values. Unknown sections and unknown attributes are
        ignored. Environment overrides are re-applied afterwards so that
        they keep precedence over the file.

        A missing file, invalid JSON, or a non-object top level is logged and
        leaves the configuration unchanged.

        Args:
            config_path: Path of the JSON file to read.
        """
        try:
            with open(config_path, 'r', encoding="utf-8") as f:
                config_data = json.load(f)
        except FileNotFoundError:
            logger.warning(f"配置文件不存在: {config_path}")
            return
        except json.JSONDecodeError as e:
            logger.error(f"配置文件JSON解析错误: {str(e)}")
            return

        if not isinstance(config_data, dict):
            logger.error(f"配置文件顶层必须是JSON对象: {config_path}")
            return
        logger.info(f"配置文件加载完成: {config_path}")

        # Update section by section, touching only known sections/attributes.
        for section_name, section_data in config_data.items():
            if section_name not in self._CONFIG_SECTIONS:
                continue
            if not isinstance(section_data, dict):
                continue
            config_obj = getattr(self, section_name)
            for key, value in section_data.items():
                if hasattr(config_obj, key):
                    setattr(config_obj, key, value)

        # Environment variables outrank the file: apply them again on top.
        self._load_from_env()

    def to_dict(self) -> Dict[str, Any]:
        """Export the server, model, gpu, celery and log sections as a dict.

        The database and security sections are not exported. Any attribute
        whose name contains 'password', 'secret', 'key' or 'token' is
        replaced with "***".

        NOTE(review): masking is by attribute name only, so URL-valued
        attributes (e.g. celery.broker_url) are exported verbatim — confirm
        they never embed credentials.
        """
        sensitive_keywords = ('password', 'secret', 'key', 'token')
        result = {}
        for section in ['server', 'model', 'gpu', 'celery', 'log']:
            config_obj = getattr(self, section)
            section_dict = {}
            for key in vars(config_obj):
                if any(kw in key.lower() for kw in sensitive_keywords):
                    section_dict[key] = "***"
                else:
                    section_dict[key] = getattr(config_obj, key)
            result[section] = section_dict
        return result
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
# Shared configuration instance, created when this module is imported.
settings = Settings()
|
||||
Reference in New Issue
Block a user