Files
2026-03-22 15:24:40 +08:00

471 lines
16 KiB
Java
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* 自然写互动课堂应用开发SDK软件 V1.0
* OCREngine - OCR识别引擎封装
*
* 功能说明:
* 1. 本地离线OCR识别(ONNX Runtime推理)
* 2. 云端在线OCR识别(REST API调用AI引擎)
* 3. 识别结果缓存与去重
* 4. 批量识别任务队列
* 5. 识别模式自动切换(在线优先,离线兜底)
*/
package com.writech.sdk.android;
import android.content.Context;
import android.graphics.Bitmap;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * OCR recognition engine.
 *
 * <p>Wraps local offline inference with an ONNX model (through JNI into the native
 * {@code writech_ocr} library) and calls to the cloud AI engine's REST API (through
 * {@code CloudClient}). Submitted tasks are queued and processed one at a time on a dedicated
 * {@link HandlerThread} named {@code "WritechOCR"}. {@link RecognitionCallback} methods are
 * invoked on that thread, except that tasks rejected because the engine was destroyed are
 * reported synchronously on the submitting thread.
 *
 * <p>Successful results are cached by input content for five minutes (at most 500 entries,
 * least recently used evicted first).
 */
public class OCREngine {
    private static final String TAG = "WritechOCREngine";

    /* ========== Recognition modes ========== */
    /** Automatic: try cloud recognition first and fall back to the offline model. */
    public static final int MODE_AUTO = 0;
    /** Cloud recognition only. */
    public static final int MODE_ONLINE_ONLY = 1;
    /** Offline (local ONNX model) recognition only. */
    public static final int MODE_OFFLINE_ONLY = 2;

    /* ========== Recognition types ========== */
    /** Handwritten text recognition. */
    public static final int TYPE_HANDWRITING = 0;
    /** Math formula recognition. */
    public static final int TYPE_MATH = 1;
    /** Stroke-order scoring. */
    public static final int TYPE_STROKE_ORDER = 2;

    /* ========== Error codes reported through RecognitionCallback#onError ========== */
    /** Recognition finished without producing a usable result. */
    public static final int ERROR_NO_RESULT = -1;
    /** Recognition threw an exception; the error message is the exception's message. */
    public static final int ERROR_EXCEPTION = -2;
    /** The engine was destroyed before the task could run. */
    public static final int ERROR_ENGINE_DESTROYED = -3;

    /** Cloud API timeout in milliseconds (passed to CloudClient). */
    private static final int API_TIMEOUT_MS = 5000;
    /** Maximum number of cached results. */
    private static final int MAX_CACHE_SIZE = 500;
    /** Cached results older than this are treated as misses (5 minutes). */
    private static final long CACHE_TTL_MS = 5 * 60 * 1000L;

    /** Width of the tensor fed to the offline model; images are scaled to fit inside it. */
    private static final int MODEL_INPUT_WIDTH = 320;
    /** Height of the tensor fed to the offline model. */
    private static final int MODEL_INPUT_HEIGHT = 48;

    /* Confidence values are fixed per result source; they are not derived from recognizer output. */
    private static final float CLOUD_CONFIDENCE = 0.95f;
    private static final float OFFLINE_CONFIDENCE = 0.85f;
    private static final float CACHE_CONFIDENCE = 1.0f;

    /* ========== Fields ========== */
    private final Context mContext;
    /** One of the MODE_* constants; written by any thread, read on the worker thread. */
    private volatile int mRecognitionMode = MODE_AUTO;

    /** Guards publishing, swapping and clearing the native session handle. */
    private final Object mModelLock = new Object();
    /** Filesystem path of the currently loaded ONNX model. */
    private volatile String mOnnxModelPath;
    private volatile boolean mOfflineModelLoaded = false;
    /** Native ONNX session handle returned by nativeLoadModel; 0 means "no session". */
    private volatile long mOnnxSessionHandle = 0;

    /** Cloud AI engine base URL; API paths are appended to it. */
    private final String mCloudApiBaseUrl;
    private final String mApiAccessToken;

    /** Pending recognition tasks, consumed on the worker thread. */
    private final Queue<RecognitionTask> mTaskQueue = new ConcurrentLinkedQueue<>();
    /** True while a processNextTask run is posted to, or executing on, the worker. */
    private final AtomicBoolean mIsProcessing = new AtomicBoolean(false);
    /** Source of task IDs; submitTask may be called from several threads. */
    private final AtomicInteger mTaskIdCounter = new AtomicInteger(0);
    private final AtomicBoolean mDestroyed = new AtomicBoolean(false);

    /** Background thread that runs all recognition work. */
    private HandlerThread mWorkerThread;
    private final Handler mWorkerHandler;

    /** Content-keyed result cache in access order (LRU eviction); guarded by its own monitor. */
    private final Map<String, CacheEntry> mResultCache = new LruMap<>(MAX_CACHE_SIZE);

    /* ========== Internal data structures ========== */

    /** A queued recognition request. */
    private static class RecognitionTask {
        int taskId;                    /* ID returned to the submitter */
        int recognitionType;           /* one of the TYPE_* constants */
        Bitmap inputImage;             /* rendered stroke image; null for stroke-order tasks */
        byte[] strokeData;             /* serialized stroke trajectory (stroke-order tasks) */
        String targetChar;             /* target character (stroke-order tasks) */
        RecognitionCallback callback;  /* receives the outcome */
    }

    /** A cached recognition result. */
    private static class CacheEntry {
        String result;   /* recognition result string */
        long timestamp;  /* System.currentTimeMillis() at insertion */
    }

    /** Access-ordered LinkedHashMap that evicts its least recently used entry beyond a size cap. */
    private static final class LruMap<K, V> extends LinkedHashMap<K, V> {
        private static final long serialVersionUID = 1L;
        private final int mMaxEntries;

        LruMap(int maxEntries) {
            super(16, 0.75f, true);
            mMaxEntries = maxEntries;
        }

        @Override
        protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
            return size() > mMaxEntries;
        }
    }

    /** Receives the outcome of a recognition task. */
    public interface RecognitionCallback {
        /**
         * Called when recognition produced a result.
         *
         * @param result     the recognized result string
         * @param confidence fixed per-source value (cloud 0.95, offline 0.85, cache 1.0)
         * @param fromCache  true when the result was served from the result cache
         */
        void onSuccess(String result, float confidence, boolean fromCache);

        /**
         * Called when recognition failed.
         *
         * @param errorCode    one of the ERROR_* constants
         * @param errorMessage description of the failure (may be null for exceptions without one)
         */
        void onError(int errorCode, String errorMessage);
    }

    /* ========== Construction and initialization ========== */

    /**
     * Creates an OCR engine and starts its background worker thread.
     *
     * @param context      Android context (only its application context is retained)
     * @param cloudBaseUrl base URL of the cloud AI engine API
     * @param accessToken  API access token sent with cloud requests
     */
    public OCREngine(Context context, String cloudBaseUrl, String accessToken) {
        mContext = context.getApplicationContext();
        mCloudApiBaseUrl = cloudBaseUrl;
        mApiAccessToken = accessToken;
        mWorkerThread = new HandlerThread("WritechOCR");
        mWorkerThread.start();
        mWorkerHandler = new Handler(mWorkerThread.getLooper());
        Log.i(TAG, "OCR引擎初始化完成,云端地址: " + cloudBaseUrl);
    }

    /**
     * Loads (or replaces) the offline ONNX recognition model.
     *
     * <p>The path is checked with {@link File#exists()}, so it must be a filesystem path. On
     * success, any previously loaded session is released on the worker thread, after any
     * inference that may still be using it. On failure, the current model (if any) stays active.
     *
     * @param modelPath path of the model file (.onnx)
     * @return true if the model was loaded
     */
    public boolean loadOfflineModel(String modelPath) {
        if (modelPath == null || !new File(modelPath).exists()) {
            Log.e(TAG, "离线模型文件不存在: " + modelPath);
            return false;
        }
        if (mDestroyed.get()) {
            Log.e(TAG, "OCR引擎已销毁,无法加载离线模型");
            return false;
        }
        /* Load the model into the native ONNX Runtime through JNI. */
        long newHandle = nativeLoadModel(modelPath);
        if (newHandle == 0) {
            Log.e(TAG, "离线ONNX模型加载失败");
            return false;
        }
        long previousHandle;
        synchronized (mModelLock) {
            if (mDestroyed.get()) {
                // destroy() won the race; the new session was never published, so free it here.
                nativeReleaseModel(newHandle);
                Log.e(TAG, "OCR引擎已销毁,无法加载离线模型");
                return false;
            }
            previousHandle = mOnnxSessionHandle;
            mOnnxSessionHandle = newHandle;
            mOnnxModelPath = modelPath;
            mOfflineModelLoaded = true;
        }
        if (previousHandle != 0) {
            releaseSessionOnWorker(previousHandle);
        }
        Log.i(TAG, "离线ONNX模型加载成功: " + modelPath);
        return true;
    }

    /**
     * Sets the recognition mode used for tasks processed from now on.
     *
     * @param mode one of the MODE_* constants; unknown values behave like {@link #MODE_AUTO}
     */
    public void setRecognitionMode(int mode) {
        mRecognitionMode = mode;
    }

    /* ========== Recognition requests ========== */

    /**
     * Submits a handwritten-text recognition task.
     *
     * @param image    rendered stroke image
     * @param callback receives the result
     * @return the task ID
     */
    public int recognizeHandwriting(Bitmap image, RecognitionCallback callback) {
        return submitTask(TYPE_HANDWRITING, image, null, null, callback);
    }

    /**
     * Submits a math-formula recognition task.
     *
     * @param image    formula image
     * @param callback receives the result
     * @return the task ID
     */
    public int recognizeMath(Bitmap image, RecognitionCallback callback) {
        return submitTask(TYPE_MATH, image, null, null, callback);
    }

    /**
     * Submits a stroke-order scoring task.
     *
     * <p>Offline recognition needs an image, so in offline-only mode this task yields
     * ERROR_NO_RESULT.
     *
     * @param strokeData serialized stroke trajectory (coordinate array)
     * @param targetChar target character
     * @param callback   receives the result
     * @return the task ID
     */
    public int evaluateStrokeOrder(byte[] strokeData, String targetChar,
                                   RecognitionCallback callback) {
        return submitTask(TYPE_STROKE_ORDER, null, strokeData, targetChar, callback);
    }

    /* ========== Task management ========== */

    /** Enqueues a task and starts the worker loop if it is idle. Safe to call from any thread. */
    private int submitTask(int type, Bitmap image, byte[] strokeData,
                           String targetChar, RecognitionCallback callback) {
        RecognitionTask task = new RecognitionTask();
        task.taskId = mTaskIdCounter.incrementAndGet();
        task.recognitionType = type;
        task.inputImage = image;
        task.strokeData = strokeData;
        task.targetChar = targetChar;
        task.callback = callback;

        if (mDestroyed.get()) {
            Log.w(TAG, "OCR引擎已销毁,拒绝任务 #" + task.taskId);
            deliverError(task, ERROR_ENGINE_DESTROYED, "OCR引擎已销毁");
            return task.taskId;
        }

        mTaskQueue.offer(task);
        Log.d(TAG, "识别任务已提交 #" + task.taskId + " 类型=" + type);

        // Start the processing loop only if no run is already scheduled. If the post fails
        // (worker looper quit by a concurrent destroy()), fail the task instead of stranding it.
        if (mIsProcessing.compareAndSet(false, true)
                && !mWorkerHandler.post(this::processNextTask)) {
            mIsProcessing.set(false);
            if (mTaskQueue.remove(task)) {
                deliverError(task, ERROR_ENGINE_DESTROYED, "OCR引擎已销毁");
            }
        }
        return task.taskId;
    }

    /** Worker-thread loop step: runs one queued task, then reschedules itself. */
    private void processNextTask() {
        RecognitionTask task = mTaskQueue.poll();
        if (task == null) {
            mIsProcessing.set(false);
            // A submitter may have offered a task after poll() returned null but before the flag
            // was cleared; its CAS failed, so it is our job to restart the loop.
            if (!mTaskQueue.isEmpty() && mIsProcessing.compareAndSet(false, true)) {
                mWorkerHandler.post(this::processNextTask);
            }
            return;
        }
        try {
            runTask(task);
        } finally {
            /* Continue with the next task (the post is rejected once the looper has quit). */
            mWorkerHandler.post(this::processNextTask);
        }
    }

    /** Runs a single task: cache lookup, mode-dependent recognition, caching, callback. */
    private void runTask(RecognitionTask task) {
        Log.d(TAG, "开始处理识别任务 #" + task.taskId);
        String cacheKey;
        String result;
        float confidence;
        try {
            cacheKey = computeCacheKey(task);
            String cachedResult = lookupCache(cacheKey);
            if (cachedResult != null) {
                Log.d(TAG, "任务 #" + task.taskId + " 命中缓存");
                deliverSuccess(task, cachedResult, CACHE_CONFIDENCE, true);
                return;
            }

            switch (mRecognitionMode) {
                case MODE_ONLINE_ONLY:
                    result = executeCloudRecognition(task);
                    confidence = CLOUD_CONFIDENCE;
                    break;
                case MODE_OFFLINE_ONLY:
                    result = executeOfflineRecognition(task);
                    confidence = OFFLINE_CONFIDENCE;
                    break;
                case MODE_AUTO:
                default:
                    /* Auto mode: cloud first; fall back to offline on failure or empty result. */
                    result = tryCloudRecognition(task);
                    if (result != null) {
                        confidence = CLOUD_CONFIDENCE;
                    } else {
                        result = executeOfflineRecognition(task);
                        confidence = OFFLINE_CONFIDENCE;
                    }
                    break;
            }
        } catch (Exception e) {
            Log.e(TAG, "识别任务 #" + task.taskId + " 异常: " + e.getMessage(), e);
            deliverError(task, ERROR_EXCEPTION, e.getMessage());
            return;
        }

        if (result != null) {
            putCache(cacheKey, result);
            deliverSuccess(task, result, confidence, false);
        } else {
            deliverError(task, ERROR_NO_RESULT, "识别失败,无可用结果");
        }
    }

    /** Cloud attempt for auto mode: returns null (and logs) instead of throwing. */
    private String tryCloudRecognition(RecognitionTask task) {
        try {
            String result = executeCloudRecognition(task);
            if (result == null) {
                Log.w(TAG, "在线识别未返回结果,回退到离线");
            }
            return result;
        } catch (Exception e) {
            Log.w(TAG, "在线识别失败,回退到离线: " + e.getMessage());
            return null;
        }
    }

    /** Invokes onSuccess, isolating the worker loop from a missing or throwing callback. */
    private void deliverSuccess(RecognitionTask task, String result, float confidence,
                                boolean fromCache) {
        if (task.callback == null) {
            return;
        }
        try {
            task.callback.onSuccess(result, confidence, fromCache);
        } catch (RuntimeException e) {
            Log.e(TAG, "识别回调异常 #" + task.taskId, e);
        }
    }

    /** Invokes onError, isolating the caller from a missing or throwing callback. */
    private void deliverError(RecognitionTask task, int errorCode, String errorMessage) {
        if (task.callback == null) {
            return;
        }
        try {
            task.callback.onError(errorCode, errorMessage);
        } catch (RuntimeException e) {
            Log.e(TAG, "识别回调异常 #" + task.taskId, e);
        }
    }

    /* ========== Cloud recognition ========== */

    /**
     * Calls the cloud AI engine for the task's recognition type.
     *
     * @return the parsed "result" string, or null if the response has none
     * @throws IOException if the HTTP request fails (as declared by CloudClient)
     */
    private String executeCloudRecognition(RecognitionTask task) throws IOException {
        String apiPath;
        switch (task.recognitionType) {
            case TYPE_MATH:
                apiPath = "/api/v1/math/recognize";
                break;
            case TYPE_STROKE_ORDER:
                apiPath = "/api/v1/stroke-order/evaluate";
                break;
            case TYPE_HANDWRITING:
            default:
                apiPath = "/api/v1/ocr/recognize";
                break;
        }
        String url = mCloudApiBaseUrl + apiPath;
        Log.d(TAG, "调用云端识别API: " + url);

        /* Encode the image (if any) as PNG for the multipart body; PNG ignores the quality arg. */
        byte[] imageBytes = null;
        if (task.inputImage != null) {
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            task.inputImage.compress(Bitmap.CompressFormat.PNG, 100, baos);
            imageBytes = baos.toByteArray();
        }

        String responseJson = CloudClient.postMultipart(url, mApiAccessToken,
                imageBytes, task.strokeData, task.targetChar, API_TIMEOUT_MS);
        return parseRecognitionResult(responseJson);
    }

    /* ========== Offline recognition ========== */

    /**
     * Runs the local ONNX model on the task's image.
     *
     * <p>NOTE(review): the cloud path parses a JSON "result" field, whereas the native result
     * (documented as JSON) is returned as-is here. Confirm which format callers expect.
     *
     * @return the native inference result, or null if no model is loaded or the task has no image
     */
    private String executeOfflineRecognition(RecognitionTask task) {
        // Read the handle once so a concurrent reload/destroy cannot change it mid-call; released
        // sessions are freed on this (worker) thread, after this call returns.
        long session = mOnnxSessionHandle;
        if (!mOfflineModelLoaded || session == 0) {
            Log.e(TAG, "离线模型未加载");
            return null;
        }
        if (task.inputImage == null) {
            Log.e(TAG, "离线识别需要输入图像");
            return null;
        }
        float[] inputTensor = preprocessImage(task.inputImage);
        // Pass the dimensions of the tensor actually supplied (not of the source image), so the
        // native side never indexes past the fixed-size buffer.
        return nativeRunInference(session, inputTensor, MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT);
    }

    /**
     * Scales the image to fit MODEL_INPUT_WIDTH x MODEL_INPUT_HEIGHT (preserving aspect ratio),
     * converts it to luma in [0, 1] and writes it top-left aligned into a row-major tensor.
     * Cells outside the scaled image stay 0.0f (NOTE(review): confirm the model expects zero
     * padding).
     *
     * <p>The caller's bitmap is never recycled, even when createScaledBitmap returns it unchanged.
     */
    private float[] preprocessImage(Bitmap bitmap) {
        final int targetWidth = MODEL_INPUT_WIDTH;
        final int targetHeight = MODEL_INPUT_HEIGHT;
        int srcWidth = bitmap.getWidth();
        int srcHeight = bitmap.getHeight();
        float scale = Math.min((float) targetWidth / srcWidth, (float) targetHeight / srcHeight);
        // Clamp to [1, target]: very long/thin images would otherwise scale a side to 0, which
        // createScaledBitmap rejects.
        int scaledW = Math.max(1, Math.min(targetWidth, (int) (srcWidth * scale)));
        int scaledH = Math.max(1, Math.min(targetHeight, (int) (srcHeight * scale)));

        Bitmap scaled = Bitmap.createScaledBitmap(bitmap, scaledW, scaledH, true);
        try {
            int[] pixels = new int[scaledW * scaledH];
            scaled.getPixels(pixels, 0, scaledW, 0, 0, scaledW, scaledH);

            float[] tensor = new float[targetWidth * targetHeight];
            for (int y = 0; y < scaledH; y++) {
                int srcRow = y * scaledW;
                int dstRow = y * targetWidth;
                for (int x = 0; x < scaledW; x++) {
                    int pixel = pixels[srcRow + x];
                    /* Luma: 0.299R + 0.587G + 0.114B, normalized to [0, 1]. */
                    tensor[dstRow + x] = (0.299f * ((pixel >> 16) & 0xFF)
                            + 0.587f * ((pixel >> 8) & 0xFF)
                            + 0.114f * (pixel & 0xFF)) / 255.0f;
                }
            }
            return tensor;
        } finally {
            if (scaled != bitmap) {
                scaled.recycle();
            }
        }
    }

    /* ========== Result cache ========== */

    /**
     * Builds a content-based cache key, or returns null when the task cannot be cached.
     * Images are keyed by a SHA-256 of dimensions and pixels. Stroke tasks are keyed by the
     * target character plus a SHA-256 of the stroke bytes.
     */
    private String computeCacheKey(RecognitionTask task) {
        if (task.inputImage != null) {
            return "img_" + task.recognitionType + "_" + hashBitmapContent(task.inputImage);
        }
        if (task.strokeData != null && task.targetChar != null) {
            return "stroke_" + task.targetChar + "_" + toHex(newSha256().digest(task.strokeData));
        }
        return null;
    }

    /** SHA-256 (hex) over the bitmap's width, height and ARGB pixels, read one row at a time. */
    private static String hashBitmapContent(Bitmap bitmap) {
        MessageDigest digest = newSha256();
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        digest.update(ByteBuffer.allocate(8).putInt(width).putInt(height).array());
        int[] row = new int[width];
        ByteBuffer rowBytes = ByteBuffer.allocate(width * 4);
        for (int y = 0; y < height; y++) {
            bitmap.getPixels(row, 0, width, 0, y, width, 1);
            rowBytes.asIntBuffer().put(row);
            digest.update(rowBytes.array());
        }
        return toHex(digest.digest());
    }

    /** Returns a fresh SHA-256 digest; the algorithm is mandatory on every Java platform. */
    private static MessageDigest newSha256() {
        try {
            return MessageDigest.getInstance("SHA-256");
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("SHA-256 unavailable", e);
        }
    }

    /** Lower-case hex encoding of a byte array. */
    private static String toHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder(bytes.length * 2);
        for (byte b : bytes) {
            sb.append(Character.forDigit((b >> 4) & 0xF, 16))
              .append(Character.forDigit(b & 0xF, 16));
        }
        return sb.toString();
    }

    /** Returns the cached result for key if present and younger than CACHE_TTL_MS, else null. */
    private String lookupCache(String key) {
        if (key == null) {
            return null;
        }
        synchronized (mResultCache) {
            CacheEntry entry = mResultCache.get(key);  // also marks it most recently used
            if (entry == null) {
                return null;
            }
            if (System.currentTimeMillis() - entry.timestamp < CACHE_TTL_MS) {
                return entry.result;
            }
            mResultCache.remove(key);  // expired
            return null;
        }
    }

    /** Stores (or refreshes) a result; the LRU map evicts beyond MAX_CACHE_SIZE entries. */
    private void putCache(String key, String result) {
        if (key == null) {
            return;
        }
        CacheEntry entry = new CacheEntry();
        entry.result = result;
        entry.timestamp = System.currentTimeMillis();
        synchronized (mResultCache) {
            mResultCache.put(key, entry);
        }
    }

    /**
     * Extracts the string value of the top-level-looking {@code "result"} key from a JSON
     * response, decoding JSON escapes (including {@code \}{@code uXXXX}).
     *
     * @return the decoded value, or null if the key is missing, its value is not a non-empty
     *     string, or the string is malformed
     */
    private static String parseRecognitionResult(String json) {
        if (json == null || json.isEmpty()) {
            return null;
        }
        final String key = "\"result\"";
        int from = 0;
        while (true) {
            int idx = json.indexOf(key, from);
            if (idx < 0) {
                return null;
            }
            int pos = skipWhitespace(json, idx + key.length());
            if (pos < json.length() && json.charAt(pos) == ':') {
                pos = skipWhitespace(json, pos + 1);
                if (pos < json.length() && json.charAt(pos) == '"') {
                    String value = readJsonString(json, pos + 1);
                    return (value == null || value.isEmpty()) ? null : value;
                }
                return null;  // "result" holds a non-string value
            }
            // "result" occurred as a string value, not a key; keep searching.
            from = idx + key.length();
        }
    }

    /** Returns the first index at or after pos that is not JSON whitespace. */
    private static int skipWhitespace(String s, int pos) {
        while (pos < s.length()) {
            char c = s.charAt(pos);
            if (c != ' ' && c != '\t' && c != '\n' && c != '\r') {
                break;
            }
            pos++;
        }
        return pos;
    }

    /** Decodes a JSON string body starting just after its opening quote; null if malformed. */
    private static String readJsonString(String json, int start) {
        StringBuilder out = new StringBuilder();
        int i = start;
        while (i < json.length()) {
            char c = json.charAt(i++);
            if (c == '"') {
                return out.toString();
            }
            if (c != '\\') {
                out.append(c);
                continue;
            }
            if (i >= json.length()) {
                return null;
            }
            char esc = json.charAt(i++);
            switch (esc) {
                case '"':
                case '\\':
                case '/':
                    out.append(esc);
                    break;
                case 'b':
                    out.append('\b');
                    break;
                case 'f':
                    out.append('\f');
                    break;
                case 'n':
                    out.append('\n');
                    break;
                case 'r':
                    out.append('\r');
                    break;
                case 't':
                    out.append('\t');
                    break;
                case 'u': {
                    if (i + 4 > json.length()) {
                        return null;
                    }
                    int code = 0;
                    for (int k = 0; k < 4; k++) {
                        int digit = hexValue(json.charAt(i + k));
                        if (digit < 0) {
                            return null;
                        }
                        code = (code << 4) | digit;
                    }
                    out.append((char) code);  // surrogate pairs arrive as two escapes
                    i += 4;
                    break;
                }
                default:
                    return null;
            }
        }
        return null;  // unterminated string
    }

    /** Value of an ASCII hex digit, or -1. */
    private static int hexValue(char c) {
        if (c >= '0' && c <= '9') {
            return c - '0';
        }
        if (c >= 'a' && c <= 'f') {
            return c - 'a' + 10;
        }
        if (c >= 'A' && c <= 'F') {
            return c - 'A' + 10;
        }
        return -1;
    }

    /* ========== JNI native methods ========== */

    /** Loads an ONNX model; returns a session handle, or 0 on failure. */
    private native long nativeLoadModel(String modelPath);

    /** Runs ONNX inference on a width x height tensor; returns the recognition result. */
    private native String nativeRunInference(long sessionHandle, float[] inputTensor,
                                             int width, int height);

    /** Releases an ONNX session. */
    private native void nativeReleaseModel(long sessionHandle);

    static {
        System.loadLibrary("writech_ocr");
    }

    /**
     * Frees a native session on the worker thread, so it runs after any in-flight inference
     * (which also runs there). Falls back to releasing inline if the worker no longer accepts
     * messages.
     */
    private void releaseSessionOnWorker(final long handle) {
        Runnable release = () -> nativeReleaseModel(handle);
        if (!mWorkerHandler.post(release)) {
            release.run();
        }
    }

    /* ========== Resource release ========== */

    /**
     * Releases engine resources. Idempotent.
     *
     * <p>Pending tasks are discarded without callbacks. Later submissions are rejected with
     * ERROR_ENGINE_DESTROYED. The native session is released on the worker thread after the
     * currently running task, if any, and the worker then quits.
     */
    public void destroy() {
        if (!mDestroyed.compareAndSet(false, true)) {
            return;
        }
        mTaskQueue.clear();

        long handle;
        synchronized (mModelLock) {
            handle = mOnnxSessionHandle;
            mOnnxSessionHandle = 0;
            mOfflineModelLoaded = false;
        }
        // Post the release before quitSafely() so it is still delivered.
        if (handle != 0) {
            releaseSessionOnWorker(handle);
        }

        HandlerThread worker = mWorkerThread;
        mWorkerThread = null;
        if (worker != null) {
            worker.quitSafely();
        }

        synchronized (mResultCache) {
            mResultCache.clear();
        }
        Log.i(TAG, "OCR引擎资源已释放");
    }
}