# 自然写 (Ziranxie) Teaching Data Analysis and Learning Diagnosis System V1.0
# etl/flink_processor.py - real-time ETL processing pipeline built on Flink
"""Real-time ETL pipeline for the analytics service.

Parses raw classroom events (pen strokes, grading results, homework
submissions, OCR results, stroke-order and writing-quality scores, exam
scores), aggregates them per student / day / subject in memory, and drains
the aggregates through ``FlinkETLProcessor.flush_to_clickhouse`` (the actual
ClickHouse / MySQL sink calls are still stubbed out in this module).
"""

import hashlib
import json
import logging
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger("writech.analytics.etl")


# ============================================================
# ETL data models
# ============================================================


class EventType(str, Enum):
    """Kinds of analytics events accepted by the ETL pipeline.

    Members subclass ``str``, so each one compares equal to its raw wire
    value, e.g. ``EventType.GRADE_RESULT == "grade_result"``.
    """

    STROKE_RAW = "stroke_raw"            # raw pen-stroke data
    GRADE_RESULT = "grade_result"        # grading result
    HOMEWORK_SUBMIT = "homework_submit"  # homework submission
    OCR_RESULT = "ocr_result"            # OCR recognition result
    STROKE_ORDER = "stroke_order"        # stroke-order scoring result
    WRITING_QUALITY = "writing_quality"  # handwriting-quality score
    EXAM_SCORE = "exam_score"            # exam score
    LOGIN_EVENT = "login_event"          # login event


@dataclass
class RawEvent:
    """A single parsed incoming event, before any aggregation.

    ``payload`` carries the type-specific body of the message; every other
    attribute is lifted from the message envelope.
    """

    event_id: str               # ID derived from the message contents
    event_type: EventType       # which kind of event this is
    student_id: str
    class_id: str
    school_id: str
    timestamp: str              # ISO 8601 timestamp string as received
    payload: Dict[str, Any]     # type-specific event body
    source: str = ""            # originating client: pad / mobile / pc / board


@dataclass
class AggregatedMetric:
    """One aggregated metric row destined for ClickHouse.

    NOTE(review): rows are serialized with ``asdict``; keep the field order
    stable in case the sink inserts columns positionally.
    """

    metric_id: str
    student_id: str
    class_id: str
    school_id: str
    subject: str
    metric_type: str                       # kind of metric, e.g. "stroke_count"
    metric_value: float
    dimension: str = field(default="")     # breakdown key (e.g. knowledge_id)
    period: str = field(default="daily")   # aggregation period
    period_start: str = field(default="")
    period_end: str = field(default="")
    created_at: str = field(default="")


@dataclass
class StudentDailyStats:
    """Running per-student, per-day, per-subject summary.

    All counters and averages start at zero and are updated in place as
    events arrive.
    """

    student_id: str
    date: str
    subject: str

    # ---- homework ----
    homework_count: int = 0
    homework_completed: int = 0
    homework_avg_score: float = 0.0

    # ---- practice ----
    practice_count: int = 0
    practice_total_chars: int = 0
    practice_avg_score: float = 0.0

    # ---- handwriting ----
    writing_quality_avg: float = 0.0
    stroke_order_accuracy: float = 0.0
    writing_speed_avg: float = 0.0

    # ---- mistakes ----
    error_count: int = 0
    # default_factory gives every instance its own list (no shared default)
    error_knowledge_points: List[str] = field(default_factory=list)

    # ---- time ----
    study_duration_minutes: int = 0


# ============================================================
# Flink ETL processing pipeline
# ============================================================


class FlinkETLProcessor:
    """Real-time ETL processor for writing / grading events.

    Intended data flow::

        raw stroke / grading data -> Kafka -> Flink real-time computation
        -> aggregated metrics written to ClickHouse -> periodic diagnosis reports

    Processing stages:

    1. Ingestion (Kafka source) -- sketched, still commented out, in
       ``start_pipeline``.
    2. Cleaning and normalization -- ``_parse_event`` / ``_validate_event``.
    3. Real-time aggregation -- ``process_event`` and the ``_process_*``
       handlers, which accumulate into in-memory buffers.
    4. Windowed statistics.
    5. Sink -- ``flush_to_clickhouse`` drains the buffers (the actual
       ClickHouse / MySQL calls are still commented out).

    NOTE(review): the buffers are plain dicts/lists with no locking; this
    presumably assumes one consumer thread per instance -- confirm.
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialise the processor from a plain config mapping.

        Args:
            config: Recognised keys are ``kafka_brokers`` (default
                ``"localhost:9092"``), ``kafka_topics`` (default ``[]``),
                ``clickhouse`` (default ``{}``), ``batch_size`` (default 100)
                and ``window_size`` in seconds (default 60).
        """
        self.kafka_brokers = config.get("kafka_brokers", "localhost:9092")
        self.kafka_topics = config.get("kafka_topics", [])
        self.clickhouse_config = config.get("clickhouse", {})
        self.batch_size = config.get("batch_size", 100)
        self.window_size_seconds = config.get("window_size", 60)

        # In-memory aggregation buffers, drained by flush_to_clickhouse().
        # Daily stats are keyed by (student_id, date, subject); a tuple key
        # cannot collide the way an "_"-joined string key can.
        self._daily_stats_buffer: Dict[
            Tuple[str, str, str], StudentDailyStats
        ] = {}
        # Number of writing-quality samples folded into each daily-stats
        # entry (same keys as _daily_stats_buffer). Tracked separately
        # because practice_count only counts stroke-order events.
        self._writing_quality_samples: Dict[Tuple[str, str, str], int] = {}
        self._metric_buffer: List[AggregatedMetric] = []
        self._error_records_buffer: List[Dict[str, Any]] = []

        # Data-quality counters reported by get_pipeline_stats().
        self._processed_count = 0
        self._error_count = 0
        self._dropped_count = 0

        logger.info(
            "FlinkETL初始化: brokers=%s, topics=%s, batch=%d",
            self.kafka_brokers,
            self.kafka_topics,
            self.batch_size,
        )

    def start_pipeline(self) -> None:
        """Start the ETL pipeline.

        NOTE(review): the Flink job definition below is still commented out,
        so this method currently only logs; no stream is actually started.
        """
        logger.info("启动Flink ETL处理管道...")

        # Intended PyFlink job (not wired up yet).
        # NOTE(review): _aggregate_window is not defined on this class, and
        # the 1-minute window ignores self.window_size_seconds.
        #
        # Execution environment:
        # env = StreamExecutionEnvironment.get_execution_environment()
        # env.set_parallelism(4)
        # env.enable_checkpointing(60000)  # checkpoint every 60 s
        #
        # Kafka source:
        # kafka_source = KafkaSource.builder() \
        #     .set_bootstrap_servers(self.kafka_brokers) \
        #     .set_topics(self.kafka_topics) \
        #     .set_group_id("analytics-etl") \
        #     .set_starting_offsets(KafkaOffsetsInitializer.latest()) \
        #     .set_value_only_deserializer(SimpleStringSchema()) \
        #     .build()
        #
        # Data stream:
        # stream = env.from_source(kafka_source, ...)
        #
        # Processing chain:
        # stream \
        #     .map(self._parse_event) \
        #     .filter(self._validate_event) \
        #     .key_by(lambda e: e.student_id) \
        #     .window(TumblingEventTimeWindows.of(Time.minutes(1))) \
        #     .process(self._aggregate_window) \
        #     .add_sink(clickhouse_sink)
        #
        # env.execute("WritechAnalyticsETL")

        logger.info("ETL管道已启动")

    def _parse_event(self, raw_json: str) -> Optional[RawEvent]:
        """Parse one raw JSON message into a ``RawEvent``.

        Cleaning rules -- a message that breaks any of them is dropped and
        counted in ``_dropped_count``:

        - the message must be a JSON object;
        - ``event_type``, ``student_id`` and ``timestamp`` must be present
          and non-empty;
        - ``event_type`` must be a known ``EventType`` value;
        - ``timestamp`` must be ISO 8601 (a trailing ``Z`` is accepted);
        - ``payload`` must be a non-empty JSON object.

        Args:
            raw_json: The raw message text.

        Returns:
            The parsed event, or ``None`` if the message was dropped or could
            not be decoded. Decode and unexpected failures are counted in
            ``_error_count`` rather than ``_dropped_count``.
        """
        try:
            data = json.loads(raw_json)

            if not isinstance(data, dict):
                self._dropped_count += 1
                logger.debug("丢弃非JSON对象事件: %s", type(data).__name__)
                return None

            # Required-field check (missing and empty values both fail).
            for field_name in ("event_type", "student_id", "timestamp"):
                if not data.get(field_name):
                    self._dropped_count += 1
                    logger.debug("丢弃不完整事件: 缺少%s", field_name)
                    return None

            # Event-type check.
            try:
                event_type = EventType(data["event_type"])
            except ValueError:
                self._dropped_count += 1
                logger.debug("丢弃未知事件类型: %s", data["event_type"])
                return None

            # Timestamp check; fromisoformat() on older Pythons rejects "Z".
            timestamp = data["timestamp"]
            try:
                datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
            except (ValueError, AttributeError):
                self._dropped_count += 1
                logger.debug("丢弃时间戳格式错误事件: %s", timestamp)
                return None

            # Payload check: every handler reads the payload with .get(), so
            # it must be a (non-empty) object.
            payload = data.get("payload")
            if not isinstance(payload, dict) or not payload:
                self._dropped_count += 1
                logger.debug("丢弃空payload事件: %s", event_type.value)
                return None

            # Deterministic ID for de-duplication. Hash the whole message in
            # canonical (sorted-key) form: a prefix of the raw text is nearly
            # identical across messages, so distinct events from the same
            # student with the same timestamp would otherwise collide.
            canonical = json.dumps(data, sort_keys=True, ensure_ascii=False)
            event_id = hashlib.md5(canonical.encode("utf-8")).hexdigest()

            event = RawEvent(
                event_id=event_id,
                event_type=event_type,
                student_id=data["student_id"],
                class_id=data.get("class_id", ""),
                school_id=data.get("school_id", ""),
                timestamp=timestamp,
                payload=payload,
                source=data.get("source", ""),
            )

            self._processed_count += 1
            return event

        except json.JSONDecodeError as e:
            self._error_count += 1
            logger.warning("JSON解析失败: %s", str(e))
            return None
        except Exception as e:
            # Stream boundary: one bad message must not stop the pipeline.
            self._error_count += 1
            logger.error("事件解析异常: %s", str(e))
            return None

    def _validate_event(self, event: Optional[RawEvent]) -> bool:
        """Return True if *event* should continue down the pipeline.

        Rejects ``None`` (messages already rejected by ``_parse_event``),
        events whose timestamp cannot be parsed, and events more than 7 days
        old. The last two cases are counted in ``_dropped_count``.
        """
        if event is None:
            return False

        try:
            event_time = datetime.fromisoformat(
                event.timestamp.replace("Z", "+00:00")
            )
        except (ValueError, AttributeError, TypeError):
            self._dropped_count += 1
            return False

        # now() takes the event's own tzinfo (None for naive timestamps), so
        # the subtraction never mixes naive and aware datetimes.
        if datetime.now(event_time.tzinfo) - event_time > timedelta(days=7):
            self._dropped_count += 1
            return False

        return True

    def process_event(self, event: RawEvent) -> None:
        """Dispatch *event* to the aggregation handler for its type.

        - stroke_raw: accumulate stroke volume
        - grade_result: update homework score stats and error records
        - homework_submit: accumulate study time
        - ocr_result: emit an OCR-confidence metric
        - stroke_order: update stroke-order accuracy
        - writing_quality: update writing-quality / speed averages
        - exam_score: emit an exam-score metric

        Types without a handler (e.g. login_event) are logged at debug level
        and ignored. If a handler fails on malformed payload values (wrong
        type, null, non-numeric), the failure is logged and counted in
        ``_error_count`` instead of propagating; updates the handler made
        before failing are kept.
        """
        handler_map = {
            EventType.STROKE_RAW: self._process_stroke,
            EventType.GRADE_RESULT: self._process_grade,
            EventType.HOMEWORK_SUBMIT: self._process_homework,
            EventType.OCR_RESULT: self._process_ocr,
            EventType.STROKE_ORDER: self._process_stroke_order,
            EventType.WRITING_QUALITY: self._process_writing_quality,
            EventType.EXAM_SCORE: self._process_exam_score,
        }

        handler = handler_map.get(event.event_type)
        if handler is None:
            logger.debug("未处理的事件类型: %s", event.event_type)
            return

        try:
            handler(event)
        except (TypeError, ValueError, AttributeError) as e:
            self._error_count += 1
            logger.warning("事件处理失败: %s, %s", event.event_id, str(e))

    @staticmethod
    def _event_date(event: RawEvent) -> str:
        """Return the day bucket of *event*: the YYYY-MM-DD prefix of its
        ISO timestamp (the date as written, in the event's own offset)."""
        return event.timestamp[:10]

    def _get_daily_stats(
        self, student_id: str, date_str: str, subject: str
    ) -> StudentDailyStats:
        """Return the daily-stats entry for (student, date, subject),
        creating an empty one on first use."""
        key = (student_id, date_str, subject)
        stats = self._daily_stats_buffer.get(key)
        if stats is None:
            stats = StudentDailyStats(
                student_id=student_id,
                date=date_str,
                subject=subject,
            )
            self._daily_stats_buffer[key] = stats
        return stats

    def _emit_metric(
        self,
        event: RawEvent,
        subject: str,
        metric_type: str,
        value: Any,
        dimension: str = "",
    ) -> None:
        """Append one per-event metric for *event* to the metric buffer.

        The metric reuses the event's ID, carries the event's day as
        ``period_start`` and its timestamp as ``created_at``.
        """
        self._metric_buffer.append(
            AggregatedMetric(
                metric_id=event.event_id,
                student_id=event.student_id,
                class_id=event.class_id,
                school_id=event.school_id,
                subject=subject,
                metric_type=metric_type,
                metric_value=float(value),
                dimension=dimension,
                period_start=self._event_date(event),
                created_at=event.timestamp,
            )
        )

    def _process_stroke(self, event: RawEvent) -> None:
        """Handle a raw pen-stroke event: add its stroke count to the daily
        stats and emit a ``stroke_count`` metric keyed by page."""
        payload = event.payload
        stroke_count = payload.get("stroke_count", 0)
        subject = payload.get("subject", "unknown")

        stats = self._get_daily_stats(
            event.student_id, self._event_date(event), subject
        )
        # NOTE(review): strokes are accumulated into practice_total_chars;
        # confirm strokes (not characters) are what that field should hold.
        stats.practice_total_chars += stroke_count

        self._emit_metric(
            event,
            subject,
            "stroke_count",
            stroke_count,
            dimension=payload.get("page_id", ""),
        )

    def _process_grade(self, event: RawEvent) -> None:
        """Handle a grading result.

        Updates homework counters, the running average score, and the set of
        knowledge points with mistakes; buffers one error record (for MySQL)
        per entry in ``payload["errors"]``.
        """
        payload = event.payload
        score = payload.get("score", 0)
        subject = payload.get("subject", "unknown")
        homework_id = payload.get("homework_id", "")
        # NOTE(review): the payload also carries total_score (default 100),
        # but the average is over raw scores, not normalized by it -- confirm
        # all homework uses the same scale.

        stats = self._get_daily_stats(
            event.student_id, self._event_date(event), subject
        )
        stats.homework_count += 1
        stats.homework_completed += 1

        # Incremental mean: new_avg = (old_avg * (n - 1) + x) / n
        n = stats.homework_completed
        stats.homework_avg_score = (
            stats.homework_avg_score * (n - 1) + score
        ) / n

        # Mistakes; tolerate an explicit "errors": null.
        for error in payload.get("errors") or []:
            knowledge_point = error.get("knowledge_point", "")
            if knowledge_point:
                stats.error_count += 1
                if knowledge_point not in stats.error_knowledge_points:
                    stats.error_knowledge_points.append(knowledge_point)

            # Every wrong question is buffered for MySQL, with or without a
            # knowledge point.
            self._error_records_buffer.append({
                "student_id": event.student_id,
                "homework_id": homework_id,
                "question_id": error.get("question_id", ""),
                "subject": subject,
                "knowledge_point": knowledge_point,
                "error_type": error.get("error_type", ""),
                "created_at": event.timestamp,
            })

    def _process_homework(self, event: RawEvent) -> None:
        """Handle a homework submission: add its time cost (minutes) to the
        daily study duration."""
        payload = event.payload
        subject = payload.get("subject", "unknown")
        time_cost = payload.get("time_cost_minutes", 0)

        stats = self._get_daily_stats(
            event.student_id, self._event_date(event), subject
        )
        stats.study_duration_minutes += time_cost

    def _process_ocr(self, event: RawEvent) -> None:
        """Handle an OCR result: emit an ``ocr_confidence`` metric, used as an
        auxiliary handwriting-legibility signal. Subject is fixed to
        "chinese"; the payload's char_count is currently not used."""
        confidence = event.payload.get("confidence", 0.0)
        self._emit_metric(event, "chinese", "ocr_confidence", confidence)

    def _process_stroke_order(self, event: RawEvent) -> None:
        """Handle a stroke-order score: fold it into the daily running
        stroke-order accuracy (subject "chinese"). Each such event also
        counts as one practice item."""
        score = event.payload.get("score", 0.0)

        stats = self._get_daily_stats(
            event.student_id, self._event_date(event), "chinese"
        )
        stats.practice_count += 1
        n = stats.practice_count
        stats.stroke_order_accuracy = (
            stats.stroke_order_accuracy * (n - 1) + score
        ) / n

    def _process_writing_quality(self, event: RawEvent) -> None:
        """Handle a writing-quality score: fold quality and speed into the
        daily running averages (subject "chinese").

        The averages are taken over writing-quality events only. Dividing by
        practice_count (which only stroke-order events increment) made the
        result depend on an unrelated event stream.
        """
        payload = event.payload
        quality_score = payload.get("quality_score", 0.0)
        speed = payload.get("speed", 0.0)

        date_str = self._event_date(event)
        stats = self._get_daily_stats(event.student_id, date_str, "chinese")

        key = (event.student_id, date_str, "chinese")
        n = self._writing_quality_samples.get(key, 0) + 1
        # Compute both averages before committing anything, so a bad value
        # leaves the stats and the sample count consistent.
        new_quality = (stats.writing_quality_avg * (n - 1) + quality_score) / n
        new_speed = (stats.writing_speed_avg * (n - 1) + speed) / n
        stats.writing_quality_avg = new_quality
        stats.writing_speed_avg = new_speed
        self._writing_quality_samples[key] = n

    def _process_exam_score(self, event: RawEvent) -> None:
        """Handle an exam score: emit an ``exam_score`` metric keyed by exam
        ID. NOTE(review): the raw score is stored; the payload's total_score
        is not used for normalization."""
        payload = event.payload
        self._emit_metric(
            event,
            payload.get("subject", "unknown"),
            "exam_score",
            payload.get("score", 0),
            dimension=payload.get("exam_id", ""),
        )

    def flush_to_clickhouse(self) -> int:
        """Drain all in-memory buffers to their sinks.

        Aggregated metrics and daily stats are meant for ClickHouse
        (batched INSERTs) and error records for MySQL. The actual client
        calls are still commented out, so at present the buffered rows are
        only counted, logged and discarded.

        Returns:
            Total number of rows drained across all three buffers.
        """
        if not (
            self._metric_buffer
            or self._daily_stats_buffer
            or self._error_records_buffer
        ):
            return 0

        total_written = 0

        # Aggregated metrics -> ClickHouse.
        if self._metric_buffer:
            metrics = [asdict(m) for m in self._metric_buffer]
            # clickhouse_client.execute(
            #     "INSERT INTO analytics_metrics VALUES",
            #     metrics,
            # )
            total_written += len(metrics)
            logger.info("写入%d条聚合指标到ClickHouse", len(metrics))
            self._metric_buffer.clear()

        # Daily stats -> ClickHouse.
        if self._daily_stats_buffer:
            daily_stats = [
                asdict(s) for s in self._daily_stats_buffer.values()
            ]
            # clickhouse_client.execute(
            #     "INSERT INTO student_daily_stats VALUES",
            #     daily_stats,
            # )
            total_written += len(daily_stats)
            logger.info("写入%d条每日统计到ClickHouse", len(daily_stats))
            self._daily_stats_buffer.clear()
            # Sample counts belong to the entries just flushed.
            self._writing_quality_samples.clear()

        # Error records -> MySQL.
        if self._error_records_buffer:
            # mysql_batch_insert("error_record", self._error_records_buffer)
            total_written += len(self._error_records_buffer)
            logger.info(
                "写入%d条错题记录到MySQL", len(self._error_records_buffer)
            )
            self._error_records_buffer.clear()

        return total_written

    def get_pipeline_stats(self) -> Dict[str, int]:
        """Return processing counters and current buffer sizes."""
        return {
            "processed": self._processed_count,
            "errors": self._error_count,
            "dropped": self._dropped_count,
            "buffer_metrics": len(self._metric_buffer),
            "buffer_daily": len(self._daily_stats_buffer),
            "buffer_errors": len(self._error_records_buffer),
        }

    def stop_pipeline(self) -> None:
        """Stop the pipeline: flush every buffer and log the final counters."""
        logger.info("正在停止ETL管道...")
        self.flush_to_clickhouse()
        logger.info(
            "ETL管道已停止. 统计: %s", self.get_pipeline_stats()
        )