software copyright

This commit is contained in:
jiahong
2026-03-22 15:24:40 +08:00
parent e303bb868a
commit 60f336e345
155 changed files with 127262 additions and 0 deletions
@@ -0,0 +1,470 @@
/*
* 自然写互动课堂应用开发SDK软件 V1.0
* OCREngine - OCR识别引擎封装
*
* 功能说明:
* 1. 本地离线OCR识别(ONNX Runtime推理)
* 2. 云端在线OCR识别(REST API调用AI引擎)
* 3. 识别结果缓存与去重
* 4. 批量识别任务队列
* 5. 识别模式自动切换(在线优先,离线兜底)
*/
package com.writech.sdk.android;
import android.content.Context;
import android.graphics.Bitmap;
import android.os.Handler;
import android.os.HandlerThread;
import android.util.Log;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* OCR识别引擎
* 封装本地ONNX推理与云端AI引擎调用
*/
public class OCREngine {
private static final String TAG = "WritechOCREngine";
/* 识别模式枚举 */
public static final int MODE_AUTO = 0; /* 自动(在线优先,离线兜底) */
public static final int MODE_ONLINE_ONLY = 1; /* 仅在线 */
public static final int MODE_OFFLINE_ONLY = 2; /* 仅离线 */
/* 识别类型枚举 */
public static final int TYPE_HANDWRITING = 0; /* 手写文字识别 */
public static final int TYPE_MATH = 1; /* 数学公式识别 */
public static final int TYPE_STROKE_ORDER = 2; /* 笔顺评分 */
/* 云端API超时时间(毫秒) */
private static final int API_TIMEOUT_MS = 5000;
/* 最大离线缓存条目数 */
private static final int MAX_CACHE_SIZE = 500;
/* ========== 成员变量 ========== */
private final Context mContext;
private int mRecognitionMode = MODE_AUTO;
/* 离线ONNX模型文件路径 */
private String mOnnxModelPath;
private boolean mOfflineModelLoaded = false;
/* ONNX推理会话句柄(通过JNI调用C层) */
private long mOnnxSessionHandle = 0;
/* 云端API基础地址 */
private String mCloudApiBaseUrl;
private String mApiAccessToken;
/* 识别任务队列 */
private final Queue<RecognitionTask> mTaskQueue = new ConcurrentLinkedQueue<>();
private final AtomicBoolean mIsProcessing = new AtomicBoolean(false);
/* 后台处理线程 */
private HandlerThread mWorkerThread;
private Handler mWorkerHandler;
/* 结果缓存(简单LRU */
private final LinkedList<CacheEntry> mResultCache = new LinkedList<>();
/* ========== 内部数据结构 ========== */
/** 识别任务 */
private static class RecognitionTask {
int taskId; /* 任务ID */
int recognitionType; /* 识别类型 */
Bitmap inputImage; /* 输入图像 */
byte[] strokeData; /* 笔迹数据(笔顺识别用) */
String targetChar; /* 目标汉字(笔顺识别用) */
RecognitionCallback callback; /* 结果回调 */
}
/** 缓存条目 */
private static class CacheEntry {
String cacheKey; /* 缓存键(图像哈希) */
String result; /* 识别结果 */
long timestamp; /* 缓存时间 */
}
/** 识别结果回调接口 */
public interface RecognitionCallback {
void onSuccess(String result, float confidence, boolean fromCache);
void onError(int errorCode, String errorMessage);
}
/* ========== 构造与初始化 ========== */
/**
* 创建OCR引擎实例
* @param context Android上下文
* @param cloudBaseUrl 云端AI引擎API地址
* @param accessToken API访问令牌
*/
public OCREngine(Context context, String cloudBaseUrl, String accessToken) {
mContext = context.getApplicationContext();
mCloudApiBaseUrl = cloudBaseUrl;
mApiAccessToken = accessToken;
/* 创建后台处理线程 */
mWorkerThread = new HandlerThread("WritechOCR");
mWorkerThread.start();
mWorkerHandler = new Handler(mWorkerThread.getLooper());
Log.i(TAG, "OCR引擎初始化完成,云端地址: " + cloudBaseUrl);
}
/**
* 加载离线ONNX识别模型
* 从assets或本地文件加载预训练的手写识别模型
*
* @param modelPath 模型文件路径(.onnx格式)
* @return 是否加载成功
*/
public boolean loadOfflineModel(String modelPath) {
File modelFile = new File(modelPath);
if (!modelFile.exists()) {
Log.e(TAG, "离线模型文件不存在: " + modelPath);
return false;
}
/* 通过JNI调用C层ONNX Runtime加载模型 */
mOnnxSessionHandle = nativeLoadModel(modelPath);
if (mOnnxSessionHandle != 0) {
mOnnxModelPath = modelPath;
mOfflineModelLoaded = true;
Log.i(TAG, "离线ONNX模型加载成功: " + modelPath);
return true;
}
Log.e(TAG, "离线ONNX模型加载失败");
return false;
}
/** 设置识别模式 */
public void setRecognitionMode(int mode) {
mRecognitionMode = mode;
}
/* ========== 识别请求接口 ========== */
/**
* 提交手写文字识别任务
* @param image 笔迹图像(已渲染的Bitmap)
* @param callback 结果回调
* @return 任务ID
*/
public int recognizeHandwriting(Bitmap image, RecognitionCallback callback) {
return submitTask(TYPE_HANDWRITING, image, null, null, callback);
}
/**
* 提交数学公式识别任务
* @param image 公式图像
* @param callback 结果回调
* @return 任务ID
*/
public int recognizeMath(Bitmap image, RecognitionCallback callback) {
return submitTask(TYPE_MATH, image, null, null, callback);
}
/**
* 提交笔顺评分任务
* @param strokeData 笔迹轨迹数据(序列化的坐标数组)
* @param targetChar 目标汉字
* @param callback 结果回调
* @return 任务ID
*/
public int evaluateStrokeOrder(byte[] strokeData, String targetChar,
RecognitionCallback callback) {
return submitTask(TYPE_STROKE_ORDER, null, strokeData, targetChar, callback);
}
/* ========== 任务管理 ========== */
private int mTaskIdCounter = 0;
/** 提交识别任务到队列 */
private int submitTask(int type, Bitmap image, byte[] strokeData,
String targetChar, RecognitionCallback callback) {
RecognitionTask task = new RecognitionTask();
task.taskId = ++mTaskIdCounter;
task.recognitionType = type;
task.inputImage = image;
task.strokeData = strokeData;
task.targetChar = targetChar;
task.callback = callback;
mTaskQueue.offer(task);
Log.d(TAG, "识别任务已提交 #" + task.taskId + " 类型=" + type);
/* 如果没有正在处理的任务,启动处理循环 */
if (mIsProcessing.compareAndSet(false, true)) {
mWorkerHandler.post(this::processNextTask);
}
return task.taskId;
}
/** 处理队列中的下一个任务 */
private void processNextTask() {
RecognitionTask task = mTaskQueue.poll();
if (task == null) {
mIsProcessing.set(false);
return;
}
Log.d(TAG, "开始处理识别任务 #" + task.taskId);
try {
/* 检查缓存 */
String cacheKey = computeCacheKey(task);
String cachedResult = lookupCache(cacheKey);
if (cachedResult != null) {
task.callback.onSuccess(cachedResult, 1.0f, true);
Log.d(TAG, "任务 #" + task.taskId + " 命中缓存");
mWorkerHandler.post(this::processNextTask);
return;
}
String result = null;
float confidence = 0.0f;
/* 根据识别模式选择执行路径 */
switch (mRecognitionMode) {
case MODE_ONLINE_ONLY:
result = executeCloudRecognition(task);
confidence = 0.95f;
break;
case MODE_OFFLINE_ONLY:
result = executeOfflineRecognition(task);
confidence = 0.85f;
break;
case MODE_AUTO:
default:
/* 自动模式:先尝试在线,失败则回退到离线 */
try {
result = executeCloudRecognition(task);
confidence = 0.95f;
} catch (Exception e) {
Log.w(TAG, "在线识别失败,回退到离线: " + e.getMessage());
result = executeOfflineRecognition(task);
confidence = 0.85f;
}
break;
}
if (result != null) {
/* 存入缓存 */
putCache(cacheKey, result);
task.callback.onSuccess(result, confidence, false);
} else {
task.callback.onError(-1, "识别失败,无可用结果");
}
} catch (Exception e) {
Log.e(TAG, "识别任务 #" + task.taskId + " 异常: " + e.getMessage());
task.callback.onError(-2, e.getMessage());
}
/* 继续处理下一个任务 */
mWorkerHandler.post(this::processNextTask);
}
/* ========== 云端识别 ========== */
/** 调用云端AI引擎执行识别 */
private String executeCloudRecognition(RecognitionTask task) throws IOException {
String apiPath;
switch (task.recognitionType) {
case TYPE_MATH:
apiPath = "/api/v1/math/recognize";
break;
case TYPE_STROKE_ORDER:
apiPath = "/api/v1/stroke-order/evaluate";
break;
case TYPE_HANDWRITING:
default:
apiPath = "/api/v1/ocr/recognize";
break;
}
String url = mCloudApiBaseUrl + apiPath;
Log.d(TAG, "调用云端识别API: " + url);
/* 构建multipart请求体 */
byte[] imageBytes = null;
if (task.inputImage != null) {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
task.inputImage.compress(Bitmap.CompressFormat.PNG, 100, baos);
imageBytes = baos.toByteArray();
}
/* 使用CloudClient发送HTTP请求 */
String responseJson = CloudClient.postMultipart(url, mApiAccessToken,
imageBytes, task.strokeData, task.targetChar, API_TIMEOUT_MS);
/* 解析JSON响应提取识别结果 */
return parseRecognitionResult(responseJson);
}
/* ========== 离线识别 ========== */
/** 使用本地ONNX模型执行离线识别 */
private String executeOfflineRecognition(RecognitionTask task) {
if (!mOfflineModelLoaded || mOnnxSessionHandle == 0) {
Log.e(TAG, "离线模型未加载");
return null;
}
if (task.inputImage == null) {
Log.e(TAG, "离线识别需要输入图像");
return null;
}
/* 图像预处理:缩放到模型输入尺寸,转为灰度float数组 */
float[] inputTensor = preprocessImage(task.inputImage);
/* 通过JNI调用ONNX Runtime执行推理 */
String result = nativeRunInference(mOnnxSessionHandle, inputTensor,
task.inputImage.getWidth(), task.inputImage.getHeight());
return result;
}
/** 图像预处理(缩放+归一化) */
private float[] preprocessImage(Bitmap bitmap) {
int targetWidth = 320;
int targetHeight = 48;
/* 保持宽高比缩放 */
float scale = Math.min(
(float) targetWidth / bitmap.getWidth(),
(float) targetHeight / bitmap.getHeight()
);
int scaledW = (int) (bitmap.getWidth() * scale);
int scaledH = (int) (bitmap.getHeight() * scale);
Bitmap scaled = Bitmap.createScaledBitmap(bitmap, scaledW, scaledH, true);
float[] tensor = new float[targetWidth * targetHeight];
/* 填充灰度值并归一化到[0, 1] */
for (int y = 0; y < scaledH && y < targetHeight; y++) {
for (int x = 0; x < scaledW && x < targetWidth; x++) {
int pixel = scaled.getPixel(x, y);
/* 灰度化:0.299R + 0.587G + 0.114B */
float gray = (0.299f * ((pixel >> 16) & 0xFF)
+ 0.587f * ((pixel >> 8) & 0xFF)
+ 0.114f * (pixel & 0xFF)) / 255.0f;
tensor[y * targetWidth + x] = gray;
}
}
scaled.recycle();
return tensor;
}
/* ========== 结果缓存 ========== */
/** 计算缓存键 */
private String computeCacheKey(RecognitionTask task) {
if (task.inputImage != null) {
return "img_" + task.recognitionType + "_" + task.inputImage.hashCode();
}
if (task.strokeData != null && task.targetChar != null) {
return "stroke_" + task.targetChar + "_" + task.strokeData.length;
}
return "unknown_" + task.taskId;
}
/** 查找缓存 */
private String lookupCache(String key) {
synchronized (mResultCache) {
for (CacheEntry entry : mResultCache) {
if (entry.cacheKey.equals(key)) {
/* 检查过期(5分钟) */
if (System.currentTimeMillis() - entry.timestamp < 300000) {
return entry.result;
}
}
}
}
return null;
}
/** 存入缓存 */
private void putCache(String key, String result) {
synchronized (mResultCache) {
CacheEntry entry = new CacheEntry();
entry.cacheKey = key;
entry.result = result;
entry.timestamp = System.currentTimeMillis();
mResultCache.addFirst(entry);
/* 限制缓存大小 */
while (mResultCache.size() > MAX_CACHE_SIZE) {
mResultCache.removeLast();
}
}
}
/** 解析云端识别API返回的JSON */
private String parseRecognitionResult(String json) {
if (json == null || json.isEmpty()) return null;
/* 简化的JSON解析:提取result字段 */
int idx = json.indexOf("\"result\"");
if (idx < 0) return null;
int start = json.indexOf("\"", idx + 8) + 1;
int end = json.indexOf("\"", start);
if (start > 0 && end > start) {
return json.substring(start, end);
}
return null;
}
/* ========== JNI本地方法声明 ========== */
/** 加载ONNX模型,返回会话句柄 */
private native long nativeLoadModel(String modelPath);
/** 执行ONNX推理,返回识别结果JSON */
private native String nativeRunInference(long sessionHandle, float[] inputTensor,
int width, int height);
/** 释放ONNX会话资源 */
private native void nativeReleaseModel(long sessionHandle);
static {
System.loadLibrary("writech_ocr");
}
/* ========== 资源释放 ========== */
/** 释放OCR引擎资源 */
public void destroy() {
mTaskQueue.clear();
if (mOnnxSessionHandle != 0) {
nativeReleaseModel(mOnnxSessionHandle);
mOnnxSessionHandle = 0;
}
if (mWorkerThread != null) {
mWorkerThread.quitSafely();
mWorkerThread = null;
}
mResultCache.clear();
Log.i(TAG, "OCR引擎资源已释放");
}
}