software copyright
This commit is contained in:
@@ -0,0 +1,470 @@
|
||||
/*
|
||||
* 自然写互动课堂应用开发SDK软件 V1.0
|
||||
* OCREngine - OCR识别引擎封装
|
||||
*
|
||||
* 功能说明:
|
||||
* 1. 本地离线OCR识别(ONNX Runtime推理)
|
||||
* 2. 云端在线OCR识别(REST API调用AI引擎)
|
||||
* 3. 识别结果缓存与去重
|
||||
* 4. 批量识别任务队列
|
||||
* 5. 识别模式自动切换(在线优先,离线兜底)
|
||||
*/
|
||||
|
||||
package com.writech.sdk.android;
|
||||
|
||||
import android.content.Context;
|
||||
import android.graphics.Bitmap;
|
||||
import android.os.Handler;
|
||||
import android.os.HandlerThread;
|
||||
import android.util.Log;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Queue;
|
||||
import java.util.concurrent.ConcurrentLinkedQueue;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* OCR识别引擎
|
||||
* 封装本地ONNX推理与云端AI引擎调用
|
||||
*/
|
||||
public class OCREngine {
|
||||
|
||||
private static final String TAG = "WritechOCREngine";
|
||||
|
||||
/* 识别模式枚举 */
|
||||
public static final int MODE_AUTO = 0; /* 自动(在线优先,离线兜底) */
|
||||
public static final int MODE_ONLINE_ONLY = 1; /* 仅在线 */
|
||||
public static final int MODE_OFFLINE_ONLY = 2; /* 仅离线 */
|
||||
|
||||
/* 识别类型枚举 */
|
||||
public static final int TYPE_HANDWRITING = 0; /* 手写文字识别 */
|
||||
public static final int TYPE_MATH = 1; /* 数学公式识别 */
|
||||
public static final int TYPE_STROKE_ORDER = 2; /* 笔顺评分 */
|
||||
|
||||
/* 云端API超时时间(毫秒) */
|
||||
private static final int API_TIMEOUT_MS = 5000;
|
||||
|
||||
/* 最大离线缓存条目数 */
|
||||
private static final int MAX_CACHE_SIZE = 500;
|
||||
|
||||
/* ========== 成员变量 ========== */
|
||||
|
||||
private final Context mContext;
|
||||
private int mRecognitionMode = MODE_AUTO;
|
||||
|
||||
/* 离线ONNX模型文件路径 */
|
||||
private String mOnnxModelPath;
|
||||
private boolean mOfflineModelLoaded = false;
|
||||
|
||||
/* ONNX推理会话句柄(通过JNI调用C层) */
|
||||
private long mOnnxSessionHandle = 0;
|
||||
|
||||
/* 云端API基础地址 */
|
||||
private String mCloudApiBaseUrl;
|
||||
private String mApiAccessToken;
|
||||
|
||||
/* 识别任务队列 */
|
||||
private final Queue<RecognitionTask> mTaskQueue = new ConcurrentLinkedQueue<>();
|
||||
private final AtomicBoolean mIsProcessing = new AtomicBoolean(false);
|
||||
|
||||
/* 后台处理线程 */
|
||||
private HandlerThread mWorkerThread;
|
||||
private Handler mWorkerHandler;
|
||||
|
||||
/* 结果缓存(简单LRU) */
|
||||
private final LinkedList<CacheEntry> mResultCache = new LinkedList<>();
|
||||
|
||||
/* ========== 内部数据结构 ========== */
|
||||
|
||||
/** 识别任务 */
|
||||
private static class RecognitionTask {
|
||||
int taskId; /* 任务ID */
|
||||
int recognitionType; /* 识别类型 */
|
||||
Bitmap inputImage; /* 输入图像 */
|
||||
byte[] strokeData; /* 笔迹数据(笔顺识别用) */
|
||||
String targetChar; /* 目标汉字(笔顺识别用) */
|
||||
RecognitionCallback callback; /* 结果回调 */
|
||||
}
|
||||
|
||||
/** 缓存条目 */
|
||||
private static class CacheEntry {
|
||||
String cacheKey; /* 缓存键(图像哈希) */
|
||||
String result; /* 识别结果 */
|
||||
long timestamp; /* 缓存时间 */
|
||||
}
|
||||
|
||||
/** 识别结果回调接口 */
|
||||
public interface RecognitionCallback {
|
||||
void onSuccess(String result, float confidence, boolean fromCache);
|
||||
void onError(int errorCode, String errorMessage);
|
||||
}
|
||||
|
||||
/* ========== 构造与初始化 ========== */
|
||||
|
||||
/**
|
||||
* 创建OCR引擎实例
|
||||
* @param context Android上下文
|
||||
* @param cloudBaseUrl 云端AI引擎API地址
|
||||
* @param accessToken API访问令牌
|
||||
*/
|
||||
public OCREngine(Context context, String cloudBaseUrl, String accessToken) {
|
||||
mContext = context.getApplicationContext();
|
||||
mCloudApiBaseUrl = cloudBaseUrl;
|
||||
mApiAccessToken = accessToken;
|
||||
|
||||
/* 创建后台处理线程 */
|
||||
mWorkerThread = new HandlerThread("WritechOCR");
|
||||
mWorkerThread.start();
|
||||
mWorkerHandler = new Handler(mWorkerThread.getLooper());
|
||||
|
||||
Log.i(TAG, "OCR引擎初始化完成,云端地址: " + cloudBaseUrl);
|
||||
}
|
||||
|
||||
/**
|
||||
* 加载离线ONNX识别模型
|
||||
* 从assets或本地文件加载预训练的手写识别模型
|
||||
*
|
||||
* @param modelPath 模型文件路径(.onnx格式)
|
||||
* @return 是否加载成功
|
||||
*/
|
||||
public boolean loadOfflineModel(String modelPath) {
|
||||
File modelFile = new File(modelPath);
|
||||
if (!modelFile.exists()) {
|
||||
Log.e(TAG, "离线模型文件不存在: " + modelPath);
|
||||
return false;
|
||||
}
|
||||
|
||||
/* 通过JNI调用C层ONNX Runtime加载模型 */
|
||||
mOnnxSessionHandle = nativeLoadModel(modelPath);
|
||||
if (mOnnxSessionHandle != 0) {
|
||||
mOnnxModelPath = modelPath;
|
||||
mOfflineModelLoaded = true;
|
||||
Log.i(TAG, "离线ONNX模型加载成功: " + modelPath);
|
||||
return true;
|
||||
}
|
||||
|
||||
Log.e(TAG, "离线ONNX模型加载失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
/** 设置识别模式 */
|
||||
public void setRecognitionMode(int mode) {
|
||||
mRecognitionMode = mode;
|
||||
}
|
||||
|
||||
/* ========== 识别请求接口 ========== */
|
||||
|
||||
/**
|
||||
* 提交手写文字识别任务
|
||||
* @param image 笔迹图像(已渲染的Bitmap)
|
||||
* @param callback 结果回调
|
||||
* @return 任务ID
|
||||
*/
|
||||
public int recognizeHandwriting(Bitmap image, RecognitionCallback callback) {
|
||||
return submitTask(TYPE_HANDWRITING, image, null, null, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* 提交数学公式识别任务
|
||||
* @param image 公式图像
|
||||
* @param callback 结果回调
|
||||
* @return 任务ID
|
||||
*/
|
||||
public int recognizeMath(Bitmap image, RecognitionCallback callback) {
|
||||
return submitTask(TYPE_MATH, image, null, null, callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* 提交笔顺评分任务
|
||||
* @param strokeData 笔迹轨迹数据(序列化的坐标数组)
|
||||
* @param targetChar 目标汉字
|
||||
* @param callback 结果回调
|
||||
* @return 任务ID
|
||||
*/
|
||||
public int evaluateStrokeOrder(byte[] strokeData, String targetChar,
|
||||
RecognitionCallback callback) {
|
||||
return submitTask(TYPE_STROKE_ORDER, null, strokeData, targetChar, callback);
|
||||
}
|
||||
|
||||
/* ========== 任务管理 ========== */
|
||||
|
||||
private int mTaskIdCounter = 0;
|
||||
|
||||
/** 提交识别任务到队列 */
|
||||
private int submitTask(int type, Bitmap image, byte[] strokeData,
|
||||
String targetChar, RecognitionCallback callback) {
|
||||
RecognitionTask task = new RecognitionTask();
|
||||
task.taskId = ++mTaskIdCounter;
|
||||
task.recognitionType = type;
|
||||
task.inputImage = image;
|
||||
task.strokeData = strokeData;
|
||||
task.targetChar = targetChar;
|
||||
task.callback = callback;
|
||||
|
||||
mTaskQueue.offer(task);
|
||||
Log.d(TAG, "识别任务已提交 #" + task.taskId + " 类型=" + type);
|
||||
|
||||
/* 如果没有正在处理的任务,启动处理循环 */
|
||||
if (mIsProcessing.compareAndSet(false, true)) {
|
||||
mWorkerHandler.post(this::processNextTask);
|
||||
}
|
||||
|
||||
return task.taskId;
|
||||
}
|
||||
|
||||
/** 处理队列中的下一个任务 */
|
||||
private void processNextTask() {
|
||||
RecognitionTask task = mTaskQueue.poll();
|
||||
if (task == null) {
|
||||
mIsProcessing.set(false);
|
||||
return;
|
||||
}
|
||||
|
||||
Log.d(TAG, "开始处理识别任务 #" + task.taskId);
|
||||
|
||||
try {
|
||||
/* 检查缓存 */
|
||||
String cacheKey = computeCacheKey(task);
|
||||
String cachedResult = lookupCache(cacheKey);
|
||||
if (cachedResult != null) {
|
||||
task.callback.onSuccess(cachedResult, 1.0f, true);
|
||||
Log.d(TAG, "任务 #" + task.taskId + " 命中缓存");
|
||||
mWorkerHandler.post(this::processNextTask);
|
||||
return;
|
||||
}
|
||||
|
||||
String result = null;
|
||||
float confidence = 0.0f;
|
||||
|
||||
/* 根据识别模式选择执行路径 */
|
||||
switch (mRecognitionMode) {
|
||||
case MODE_ONLINE_ONLY:
|
||||
result = executeCloudRecognition(task);
|
||||
confidence = 0.95f;
|
||||
break;
|
||||
|
||||
case MODE_OFFLINE_ONLY:
|
||||
result = executeOfflineRecognition(task);
|
||||
confidence = 0.85f;
|
||||
break;
|
||||
|
||||
case MODE_AUTO:
|
||||
default:
|
||||
/* 自动模式:先尝试在线,失败则回退到离线 */
|
||||
try {
|
||||
result = executeCloudRecognition(task);
|
||||
confidence = 0.95f;
|
||||
} catch (Exception e) {
|
||||
Log.w(TAG, "在线识别失败,回退到离线: " + e.getMessage());
|
||||
result = executeOfflineRecognition(task);
|
||||
confidence = 0.85f;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (result != null) {
|
||||
/* 存入缓存 */
|
||||
putCache(cacheKey, result);
|
||||
task.callback.onSuccess(result, confidence, false);
|
||||
} else {
|
||||
task.callback.onError(-1, "识别失败,无可用结果");
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
Log.e(TAG, "识别任务 #" + task.taskId + " 异常: " + e.getMessage());
|
||||
task.callback.onError(-2, e.getMessage());
|
||||
}
|
||||
|
||||
/* 继续处理下一个任务 */
|
||||
mWorkerHandler.post(this::processNextTask);
|
||||
}
|
||||
|
||||
/* ========== 云端识别 ========== */
|
||||
|
||||
/** 调用云端AI引擎执行识别 */
|
||||
private String executeCloudRecognition(RecognitionTask task) throws IOException {
|
||||
String apiPath;
|
||||
switch (task.recognitionType) {
|
||||
case TYPE_MATH:
|
||||
apiPath = "/api/v1/math/recognize";
|
||||
break;
|
||||
case TYPE_STROKE_ORDER:
|
||||
apiPath = "/api/v1/stroke-order/evaluate";
|
||||
break;
|
||||
case TYPE_HANDWRITING:
|
||||
default:
|
||||
apiPath = "/api/v1/ocr/recognize";
|
||||
break;
|
||||
}
|
||||
|
||||
String url = mCloudApiBaseUrl + apiPath;
|
||||
Log.d(TAG, "调用云端识别API: " + url);
|
||||
|
||||
/* 构建multipart请求体 */
|
||||
byte[] imageBytes = null;
|
||||
if (task.inputImage != null) {
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
task.inputImage.compress(Bitmap.CompressFormat.PNG, 100, baos);
|
||||
imageBytes = baos.toByteArray();
|
||||
}
|
||||
|
||||
/* 使用CloudClient发送HTTP请求 */
|
||||
String responseJson = CloudClient.postMultipart(url, mApiAccessToken,
|
||||
imageBytes, task.strokeData, task.targetChar, API_TIMEOUT_MS);
|
||||
|
||||
/* 解析JSON响应提取识别结果 */
|
||||
return parseRecognitionResult(responseJson);
|
||||
}
|
||||
|
||||
/* ========== 离线识别 ========== */
|
||||
|
||||
/** 使用本地ONNX模型执行离线识别 */
|
||||
private String executeOfflineRecognition(RecognitionTask task) {
|
||||
if (!mOfflineModelLoaded || mOnnxSessionHandle == 0) {
|
||||
Log.e(TAG, "离线模型未加载");
|
||||
return null;
|
||||
}
|
||||
|
||||
if (task.inputImage == null) {
|
||||
Log.e(TAG, "离线识别需要输入图像");
|
||||
return null;
|
||||
}
|
||||
|
||||
/* 图像预处理:缩放到模型输入尺寸,转为灰度float数组 */
|
||||
float[] inputTensor = preprocessImage(task.inputImage);
|
||||
|
||||
/* 通过JNI调用ONNX Runtime执行推理 */
|
||||
String result = nativeRunInference(mOnnxSessionHandle, inputTensor,
|
||||
task.inputImage.getWidth(), task.inputImage.getHeight());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/** 图像预处理(缩放+归一化) */
|
||||
private float[] preprocessImage(Bitmap bitmap) {
|
||||
int targetWidth = 320;
|
||||
int targetHeight = 48;
|
||||
|
||||
/* 保持宽高比缩放 */
|
||||
float scale = Math.min(
|
||||
(float) targetWidth / bitmap.getWidth(),
|
||||
(float) targetHeight / bitmap.getHeight()
|
||||
);
|
||||
int scaledW = (int) (bitmap.getWidth() * scale);
|
||||
int scaledH = (int) (bitmap.getHeight() * scale);
|
||||
|
||||
Bitmap scaled = Bitmap.createScaledBitmap(bitmap, scaledW, scaledH, true);
|
||||
float[] tensor = new float[targetWidth * targetHeight];
|
||||
|
||||
/* 填充灰度值并归一化到[0, 1] */
|
||||
for (int y = 0; y < scaledH && y < targetHeight; y++) {
|
||||
for (int x = 0; x < scaledW && x < targetWidth; x++) {
|
||||
int pixel = scaled.getPixel(x, y);
|
||||
/* 灰度化:0.299R + 0.587G + 0.114B */
|
||||
float gray = (0.299f * ((pixel >> 16) & 0xFF)
|
||||
+ 0.587f * ((pixel >> 8) & 0xFF)
|
||||
+ 0.114f * (pixel & 0xFF)) / 255.0f;
|
||||
tensor[y * targetWidth + x] = gray;
|
||||
}
|
||||
}
|
||||
|
||||
scaled.recycle();
|
||||
return tensor;
|
||||
}
|
||||
|
||||
/* ========== 结果缓存 ========== */
|
||||
|
||||
/** 计算缓存键 */
|
||||
private String computeCacheKey(RecognitionTask task) {
|
||||
if (task.inputImage != null) {
|
||||
return "img_" + task.recognitionType + "_" + task.inputImage.hashCode();
|
||||
}
|
||||
if (task.strokeData != null && task.targetChar != null) {
|
||||
return "stroke_" + task.targetChar + "_" + task.strokeData.length;
|
||||
}
|
||||
return "unknown_" + task.taskId;
|
||||
}
|
||||
|
||||
/** 查找缓存 */
|
||||
private String lookupCache(String key) {
|
||||
synchronized (mResultCache) {
|
||||
for (CacheEntry entry : mResultCache) {
|
||||
if (entry.cacheKey.equals(key)) {
|
||||
/* 检查过期(5分钟) */
|
||||
if (System.currentTimeMillis() - entry.timestamp < 300000) {
|
||||
return entry.result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/** 存入缓存 */
|
||||
private void putCache(String key, String result) {
|
||||
synchronized (mResultCache) {
|
||||
CacheEntry entry = new CacheEntry();
|
||||
entry.cacheKey = key;
|
||||
entry.result = result;
|
||||
entry.timestamp = System.currentTimeMillis();
|
||||
mResultCache.addFirst(entry);
|
||||
|
||||
/* 限制缓存大小 */
|
||||
while (mResultCache.size() > MAX_CACHE_SIZE) {
|
||||
mResultCache.removeLast();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** 解析云端识别API返回的JSON */
|
||||
private String parseRecognitionResult(String json) {
|
||||
if (json == null || json.isEmpty()) return null;
|
||||
/* 简化的JSON解析:提取result字段 */
|
||||
int idx = json.indexOf("\"result\"");
|
||||
if (idx < 0) return null;
|
||||
int start = json.indexOf("\"", idx + 8) + 1;
|
||||
int end = json.indexOf("\"", start);
|
||||
if (start > 0 && end > start) {
|
||||
return json.substring(start, end);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/* ========== JNI本地方法声明 ========== */
|
||||
|
||||
/** 加载ONNX模型,返回会话句柄 */
|
||||
private native long nativeLoadModel(String modelPath);
|
||||
|
||||
/** 执行ONNX推理,返回识别结果JSON */
|
||||
private native String nativeRunInference(long sessionHandle, float[] inputTensor,
|
||||
int width, int height);
|
||||
|
||||
/** 释放ONNX会话资源 */
|
||||
private native void nativeReleaseModel(long sessionHandle);
|
||||
|
||||
static {
|
||||
System.loadLibrary("writech_ocr");
|
||||
}
|
||||
|
||||
/* ========== 资源释放 ========== */
|
||||
|
||||
/** 释放OCR引擎资源 */
|
||||
public void destroy() {
|
||||
mTaskQueue.clear();
|
||||
if (mOnnxSessionHandle != 0) {
|
||||
nativeReleaseModel(mOnnxSessionHandle);
|
||||
mOnnxSessionHandle = 0;
|
||||
}
|
||||
if (mWorkerThread != null) {
|
||||
mWorkerThread.quitSafely();
|
||||
mWorkerThread = null;
|
||||
}
|
||||
mResultCache.clear();
|
||||
Log.i(TAG, "OCR引擎资源已释放");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user