diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..7d640881 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,11 @@ +# top-most EditorConfig file +root = true + +# Unix-style newlines with a newline ending every file +[*] +end_of_line = lf +insert_final_newline = true + +[*.{kt,kts}] +ij_kotlin_code_style_defaults = KOTLIN_OFFICIAL +ktlint_code_style = intellij_idea diff --git a/.github/workflows/build-check.yml b/.github/workflows/build-check.yml new file mode 100644 index 00000000..6d4c0807 --- /dev/null +++ b/.github/workflows/build-check.yml @@ -0,0 +1,62 @@ +name: Build Check + +on: + push: + branches: ["*"] + paths-ignore: + - "**.md" + - "**.txt" + - "docs/**" + - ".gitignore" + - "LICENSE" + pull_request: + types: + - opened + - reopened + - synchronize + +concurrency: + group: build-check + cancel-in-progress: true + +jobs: + build: + name: Build Check + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Java 21 + uses: actions/setup-java@v4 + with: + distribution: "temurin" + java-version: "21" + cache: "gradle" + + - name: Validate Gradle Wrapper + uses: gradle/wrapper-validation-action@v2 + + - name: Set up Gradle + uses: gradle/actions/setup-gradle@v3 + + - name: Check code style + run: | + ./gradlew --no-daemon ktlintCheck + + - name: Set up Android SDK + uses: android-actions/setup-android@v3 + with: + api-level: 34 + build-tools: 34.0.0 + + - name: Build Debug APK + run: | + ./gradlew --no-daemon clean assembleDebug + + - name: Upload Debug APK + uses: actions/upload-artifact@v4 + with: + name: debug-apk + path: app/build/outputs/apk/debug/*.apk + retention-days: 7 diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 27a56217..513b4dbe 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -2,6 +2,7 @@ plugins { id("com.android.application") id("org.jetbrains.kotlin.android") id("org.jetbrains.kotlin.plugin.serialization") version "2.2.20" + id("org.jlleitschuh.gradle.ktlint") } android { @@ -31,7 +32,8 @@ android { signingConfigs { create("release") { // 从环境变量读取签名配置 - storeFile = System.getenv("KEYSTORE_FILE")?.let { file(it) } + val keystoreFile = System.getenv("KEYSTORE_FILE") + storeFile = keystoreFile?.takeIf { it.isNotBlank() }?.let { file(it) } storePassword = System.getenv("KEYSTORE_PASSWORD") keyAlias = System.getenv("KEY_ALIAS") keyPassword = System.getenv("KEY_PASSWORD") @@ -46,7 +48,7 @@ android { proguardFiles( getDefaultProguardFile("proguard-android-optimize.txt"), - "proguard-rules.pro" + "proguard-rules.pro", ) signingConfig = signingConfigs.findByName("release")?.takeIf { @@ -86,14 +88,14 @@ android { excludes += listOf( "**/libonnxruntime4j_jni.so", "**/libsherpa-onnx-c-api.so", - "**/libsherpa-onnx-cxx-api.so" + "**/libsherpa-onnx-cxx-api.so", ) } resources { excludes += listOf( "META-INF/services/lombok.*", "README.md", - "META-INF/README.md" + "META-INF/README.md", ) } } @@ -113,7 +115,7 @@ tasks.withType(JavaCompile::class.java).configureEach { javaCompiler.set( toolchainService.compilerFor { languageVersion.set(JavaLanguageVersion.of(21)) - } + }, ) } diff --git a/app/src/main/java/com/brycewg/asrkb/App.kt b/app/src/main/java/com/brycewg/asrkb/App.kt index 7b75d61c..8dfa2c91 100644 --- a/app/src/main/java/com/brycewg/asrkb/App.kt +++ b/app/src/main/java/com/brycewg/asrkb/App.kt @@ -5,23 +5,23 @@ */ package com.brycewg.asrkb -import android.app.Application import android.app.Activity import android.app.ActivityManager -import android.os.Bundle +import android.app.Application import android.content.Intent +import android.os.Bundle import android.provider.Settings import android.util.Log -import com.google.android.material.color.DynamicColors -import com.brycewg.asrkb.store.Prefs import androidx.appcompat.app.AppCompatDelegate import androidx.core.os.LocaleListCompat +import com.brycewg.asrkb.analytics.AnalyticsManager +import com.brycewg.asrkb.asr.VadDetector +import com.brycewg.asrkb.store.Prefs import com.brycewg.asrkb.ui.floating.FloatingAsrService import com.brycewg.asrkb.ui.floating.FloatingKeepAliveService import com.brycewg.asrkb.ui.floating.PrivilegedKeepAliveScheduler import com.brycewg.asrkb.ui.floating.PrivilegedKeepAliveStarter -import com.brycewg.asrkb.asr.VadDetector -import com.brycewg.asrkb.analytics.AnalyticsManager +import com.google.android.material.color.DynamicColors import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.SupervisorJob diff --git a/app/src/main/java/com/brycewg/asrkb/BootReceiver.kt b/app/src/main/java/com/brycewg/asrkb/BootReceiver.kt index 9fc087b2..a0472611 100644 --- a/app/src/main/java/com/brycewg/asrkb/BootReceiver.kt +++ b/app/src/main/java/com/brycewg/asrkb/BootReceiver.kt @@ -95,7 +95,7 @@ class BootReceiver : BroadcastReceiver() { context, "keepalive", "boot_start_result", - mapOf("ok" to false, "method" to result.method.name.lowercase(), "exit" to result.exitCode) + mapOf("ok" to false, "method" to result.method.name.lowercase(), "exit" to result.exitCode), ) } } catch (t: Throwable) { diff --git a/app/src/main/java/com/brycewg/asrkb/LocaleHelper.kt b/app/src/main/java/com/brycewg/asrkb/LocaleHelper.kt index 3c395503..0d22270e 100644 --- a/app/src/main/java/com/brycewg/asrkb/LocaleHelper.kt +++ b/app/src/main/java/com/brycewg/asrkb/LocaleHelper.kt @@ -8,22 +8,22 @@ import androidx.core.os.LocaleListCompat import java.util.Locale object LocaleHelper { - fun wrap(newBase: Context): Context { - val locales = AppCompatDelegate.getApplicationLocales() - if (locales.isEmpty) return newBase - val config = Configuration(newBase.resources.configuration) - applyLocales(config, locales) - return newBase.createConfigurationContext(config) - } + fun wrap(newBase: Context): Context { + val locales = AppCompatDelegate.getApplicationLocales() + if (locales.isEmpty) return newBase + val config = Configuration(newBase.resources.configuration) + applyLocales(config, locales) + return newBase.createConfigurationContext(config) + } - private fun applyLocales(config: Configuration, locales: LocaleListCompat) { - if (locales.isEmpty) return - val tags = locales.toLanguageTags() - if (tags.isEmpty()) return - val localeList = LocaleList.forLanguageTags(tags) - if (localeList.isEmpty) return - config.setLocales(localeList) - LocaleList.setDefault(localeList) - Locale.setDefault(localeList[0]) - } + private fun applyLocales(config: Configuration, locales: LocaleListCompat) { + if (locales.isEmpty) return + val tags = locales.toLanguageTags() + if (tags.isEmpty()) return + val localeList = LocaleList.forLanguageTags(tags) + if (localeList.isEmpty) return + config.setLocales(localeList) + LocaleList.setDefault(localeList) + Locale.setDefault(localeList[0]) + } } diff --git a/app/src/main/java/com/brycewg/asrkb/UiColorTokens.kt b/app/src/main/java/com/brycewg/asrkb/UiColorTokens.kt index 95a36803..a0bff471 100644 --- a/app/src/main/java/com/brycewg/asrkb/UiColorTokens.kt +++ b/app/src/main/java/com/brycewg/asrkb/UiColorTokens.kt @@ -25,7 +25,6 @@ object UiColorTokens { /** 容器前景色 */ val containerFg = com.google.android.material.R.attr.colorOnSurfaceVariant - // ==================== 键盘相关 ==================== /** 键盘按键背景 */ @@ -37,7 +36,6 @@ object UiColorTokens { /** 键盘容器背景 */ val kbdContainerBg = com.google.android.material.R.attr.colorSurfaceVariant - // ==================== 强调与状态色 ==================== /** 主强调色(主要操作按钮等) */ @@ -76,7 +74,6 @@ object UiColorTokens { /** 错误容器前景色 */ val onErrorContainer = com.google.android.material.R.attr.colorOnErrorContainer - // ==================== 选中与高亮 ==================== /** 选中项背景色 */ @@ -91,7 +88,6 @@ object UiColorTokens { /** 遮罩色(用于暗化/系统栏对齐等) */ val scrim = R.attr.asrScrimColor - // ==================== 边框与分割线 ==================== /** 主要边框色 */ @@ -100,7 +96,6 @@ object UiColorTokens { /** 次要边框色(更淡) */ val outlineVariant = com.google.android.material.R.attr.colorOutlineVariant - // ==================== 悬浮球相关 ==================== /** 悬浮球容器背景 */ @@ -112,7 +107,6 @@ object UiColorTokens { /** 悬浮球错误状态色 */ val floatingError = android.R.attr.colorError - // ==================== 状态芯片 ==================== /** 芯片背景色 */ diff --git a/app/src/main/java/com/brycewg/asrkb/aidl/SpeechConfig.kt b/app/src/main/java/com/brycewg/asrkb/aidl/SpeechConfig.kt index 215020b2..0848d6f2 100644 --- a/app/src/main/java/com/brycewg/asrkb/aidl/SpeechConfig.kt +++ b/app/src/main/java/com/brycewg/asrkb/aidl/SpeechConfig.kt @@ -8,11 +8,16 @@ import android.os.Parcelable * 为保持向后兼容,新增字段请保持可空并提供合理默认。 */ data class SpeechConfig( - val vendorId: String? = null, // 供应商ID(如 "volc"、"soniox");为空则按应用内设置 - val streamingPreferred: Boolean = true,// 调用方偏好流式(若供应商/设置不支持则回落) - val punctuationEnabled: Boolean? = null,// 标点开关(部分供应商有效);null=按应用设置 - val autoStopOnSilence: Boolean? = null,// 静音自动判停(null=按应用设置) - val sessionTag: String? = null // 调用方自定义标记,用于打点/排障 + // 供应商ID(如 "volc"、"soniox");为空则按应用内设置 + val vendorId: String? = null, + // 调用方偏好流式(若供应商/设置不支持则回落) + val streamingPreferred: Boolean = true, + // 标点开关(部分供应商有效);null=按应用设置 + val punctuationEnabled: Boolean? = null, + // 静音自动判停(null=按应用设置) + val autoStopOnSilence: Boolean? = null, + // 调用方自定义标记,用于打点/排障 + val sessionTag: String? = null, ) : Parcelable { override fun writeToParcel(dest: Parcel, flags: Int) { dest.writeString(vendorId) @@ -51,4 +56,3 @@ data class SpeechConfig( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/analytics/AnalyticsManager.kt b/app/src/main/java/com/brycewg/asrkb/analytics/AnalyticsManager.kt index e1c6cc40..2fd4d9c5 100644 --- a/app/src/main/java/com/brycewg/asrkb/analytics/AnalyticsManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/analytics/AnalyticsManager.kt @@ -7,13 +7,6 @@ import com.brycewg.asrkb.BuildConfig import com.brycewg.asrkb.R import com.brycewg.asrkb.store.AnalyticsStore import com.brycewg.asrkb.store.Prefs -import java.time.LocalDate -import java.time.ZoneId -import java.net.URLEncoder -import java.util.Calendar -import java.util.Locale -import java.util.UUID -import kotlin.random.Random import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job @@ -29,6 +22,13 @@ import okhttp3.MediaType.Companion.toMediaType import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.RequestBody.Companion.toRequestBody +import java.net.URLEncoder +import java.time.LocalDate +import java.time.ZoneId +import java.util.Calendar +import java.util.Locale +import java.util.UUID +import kotlin.random.Random /** * 轻量匿名统计(PocketBase)。 @@ -37,410 +37,419 @@ import okhttp3.RequestBody.Companion.toRequestBody * - 本地缓存事件,按用户随机的每日时间上传一次 */ object AnalyticsManager { - private const val TAG = "AnalyticsManager" - private const val COLLECTION_DAILY_REPORT = "daily_reports" - private const val COLLECTION_CONSENT = "device_consents" - private const val MIN_UPLOAD_INTERVAL_DAYS = 1 - private const val RETRY_COOLDOWN_MS = 10 * 60 * 1000L - private const val RETRY_INTERVAL_MS = 3 * 60 * 1000L - private const val MAX_RETRIES = 3 - - private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) - private val json = Json { - ignoreUnknownKeys = true - isLenient = true - encodeDefaults = true - } - - private var scheduleJob: Job? = null - @Volatile private var uploading = false - - @Serializable - private data class DailyReport( - val userId: String, - val appVersion: String, - val language: String, - val reportAt: Long, - val randomMinuteOfDay: Int, - val eventCount: Int, - val audioMsTotal: Long, - val appStartCount: Int, - val asrEvents: List - ) - - @Serializable - private data class ConsentRecord( - val userId: String, - val enabled: Boolean, - val channel: String, - val firstSeen: Long, - val sdkInt: Int, - val deviceModel: String, - val timestamp: Long, - val appVersion: String, - val language: String - ) - - @Serializable - private data class PbListResponse( - val items: List = emptyList() - ) - - @Serializable - private data class ConsentItem( - val id: String - ) - - fun init(context: Context) { - val appContext = context.applicationContext - val prefs = Prefs(appContext) - if (!prefs.dataCollectionEnabled) return - ensureUserId(prefs) - ensureRandomMinute(prefs) - recordAppStart(appContext) - maybeUploadIfDue(appContext) - scheduleLoop(appContext) - } - - fun recordAppStart(context: Context) { - val appContext = context.applicationContext - val prefs = Prefs(appContext) - if (!prefs.dataCollectionEnabled) return - ensureUserId(prefs) - ensureRandomMinute(prefs) - - scope.launch { - try { - AnalyticsStore(appContext).addAppStart( - AnalyticsStore.AppStartEvent(timestamp = System.currentTimeMillis()) - ) - } catch (t: Throwable) { - Log.w(TAG, "Failed to add app start event", t) - } - } - } - - fun recordAsrEvent( - context: Context, - vendorId: String, - audioMs: Long, - procMs: Long, - source: String, - aiProcessed: Boolean, - charCount: Int - ) { - if (charCount <= 0 && audioMs <= 0L) return - val appContext = context.applicationContext - val prefs = Prefs(appContext) - if (!prefs.dataCollectionEnabled) return - ensureUserId(prefs) - ensureRandomMinute(prefs) - - scope.launch { - try { - AnalyticsStore(appContext).addAsrEvent( - AnalyticsStore.AsrEvent( - timestamp = System.currentTimeMillis(), - vendorId = vendorId, - audioMs = audioMs, - procMs = procMs, - source = source, - aiProcessed = aiProcessed, - charCount = charCount - ) - ) - } catch (t: Throwable) { - Log.w(TAG, "Failed to add ASR event", t) - } - maybeUploadIfDue(appContext) + private const val TAG = "AnalyticsManager" + private const val COLLECTION_DAILY_REPORT = "daily_reports" + private const val COLLECTION_CONSENT = "device_consents" + private const val MIN_UPLOAD_INTERVAL_DAYS = 1 + private const val RETRY_COOLDOWN_MS = 10 * 60 * 1000L + private const val RETRY_INTERVAL_MS = 3 * 60 * 1000L + private const val MAX_RETRIES = 3 + + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + private val json = Json { + ignoreUnknownKeys = true + isLenient = true + encodeDefaults = true } - } - - private fun ensureUserId(prefs: Prefs): String { - val cur = prefs.analyticsUserId - if (cur.isNotBlank()) return cur - val id = UUID.randomUUID().toString() - prefs.analyticsUserId = id - return id - } - - private fun ensureRandomMinute(prefs: Prefs): Int { - val cur = prefs.analyticsReportMinuteOfDay - if (cur in 0..1439) return cur - val minute = Random.nextInt(0, 1440) - prefs.analyticsReportMinuteOfDay = minute - return minute - } - - private fun maybeUploadIfDue(context: Context) { - val prefs = Prefs(context) - if (!prefs.dataCollectionEnabled) return - - val now = System.currentTimeMillis() - val reportMinute = ensureRandomMinute(prefs) - val cal = Calendar.getInstance() - cal.timeInMillis = now - val minuteOfDay = cal.get(Calendar.HOUR_OF_DAY) * 60 + cal.get(Calendar.MINUTE) - val todayEpochDay = LocalDate.now(ZoneId.systemDefault()).toEpochDay() - val lastUploadEpochDay = prefs.analyticsLastUploadEpochDay - if (lastUploadEpochDay >= 0L && (todayEpochDay - lastUploadEpochDay) < MIN_UPLOAD_INTERVAL_DAYS) return - if (minuteOfDay < reportMinute) return - - val lastAttemptEpochDay = prefs.analyticsLastAttemptEpochDay - val lastAttemptEpochMs = prefs.analyticsLastAttemptEpochMs - val canRetry = lastAttemptEpochDay == todayEpochDay && - prefs.analyticsRetryUsedEpochDay != todayEpochDay && - lastAttemptEpochMs > 0L && - (now - lastAttemptEpochMs) >= RETRY_COOLDOWN_MS - val isFirstAttemptToday = lastAttemptEpochDay != todayEpochDay - if (!isFirstAttemptToday && !canRetry) return - - scope.launch { uploadOnce(context.applicationContext, isRetry = !isFirstAttemptToday) } - } - - private suspend fun uploadOnce(context: Context, isRetry: Boolean) { - val prefs = Prefs(context) - if (!prefs.dataCollectionEnabled) return - val todayEpochDay = LocalDate.now(ZoneId.systemDefault()).toEpochDay() - if (uploading) return - uploading = true - prefs.analyticsLastAttemptEpochDay = todayEpochDay - prefs.analyticsLastAttemptEpochMs = System.currentTimeMillis() - if (isRetry) { - prefs.analyticsRetryUsedEpochDay = todayEpochDay + + private var scheduleJob: Job? = null + + @Volatile private var uploading = false + + @Serializable + private data class DailyReport( + val userId: String, + val appVersion: String, + val language: String, + val reportAt: Long, + val randomMinuteOfDay: Int, + val eventCount: Int, + val audioMsTotal: Long, + val appStartCount: Int, + val asrEvents: List, + ) + + @Serializable + private data class ConsentRecord( + val userId: String, + val enabled: Boolean, + val channel: String, + val firstSeen: Long, + val sdkInt: Int, + val deviceModel: String, + val timestamp: Long, + val appVersion: String, + val language: String, + ) + + @Serializable + private data class PbListResponse( + val items: List = emptyList(), + ) + + @Serializable + private data class ConsentItem( + val id: String, + ) + + fun init(context: Context) { + val appContext = context.applicationContext + val prefs = Prefs(appContext) + if (!prefs.dataCollectionEnabled) return + ensureUserId(prefs) + ensureRandomMinute(prefs) + recordAppStart(appContext) + maybeUploadIfDue(appContext) + scheduleLoop(appContext) } - try { - val baseUrl = try { - context.getString(R.string.pocketbase_base_url).trim().trimEnd('/') - } catch (t: Throwable) { - Log.w(TAG, "Failed to read PocketBase base url", t) - "" - } - if (baseUrl.isBlank()) { - Log.w(TAG, "PocketBase base url empty, skip upload") - return - } - - val userId = ensureUserId(prefs) - val store = AnalyticsStore(context) - val asrEvents = try { store.listAsrEvents() } catch (t: Throwable) { - Log.w(TAG, "Failed to read ASR events", t) - emptyList() - } - val appStarts = try { store.listAppStarts() } catch (t: Throwable) { - Log.w(TAG, "Failed to read app starts", t) - emptyList() - } - - val language = resolveLanguage(prefs) - val report = DailyReport( - userId = userId, - appVersion = BuildConfig.VERSION_NAME, - language = language, - reportAt = System.currentTimeMillis(), - randomMinuteOfDay = prefs.analyticsReportMinuteOfDay.coerceIn(0, 1439), - eventCount = asrEvents.size, - audioMsTotal = asrEvents.sumOf { it.audioMs.coerceAtLeast(0L) }, - appStartCount = appStarts.size, - asrEvents = asrEvents - ) - - val bodyText = try { json.encodeToString(report) } catch (t: Throwable) { - Log.w(TAG, "Failed to encode daily report", t) - return - } - - val reqBody = bodyText.toRequestBody("application/json; charset=utf-8".toMediaType()) - val url = "$baseUrl/api/collections/$COLLECTION_DAILY_REPORT/records" - val req = Request.Builder().url(url).post(reqBody).build() - - val ok = postWithRetries(req) - if (ok) { - prefs.analyticsLastUploadEpochDay = todayEpochDay - store.deleteAsrEventsByIds(asrEvents.map { it.id }.toSet()) - store.deleteAppStartsByIds(appStarts.map { it.id }.toSet()) - Log.i(TAG, "Daily report uploaded, events=${asrEvents.size}, starts=${appStarts.size}") - } - } finally { - uploading = false + + fun recordAppStart(context: Context) { + val appContext = context.applicationContext + val prefs = Prefs(appContext) + if (!prefs.dataCollectionEnabled) return + ensureUserId(prefs) + ensureRandomMinute(prefs) + + scope.launch { + try { + AnalyticsStore(appContext).addAppStart( + AnalyticsStore.AppStartEvent(timestamp = System.currentTimeMillis()), + ) + } catch (t: Throwable) { + Log.w(TAG, "Failed to add app start event", t) + } + } } - } - - /** - * 用户选择匿名统计开关时上报一次(即使未开启统计也会上报 enabled=false)。 - */ - fun sendConsentChoice(context: Context, enabled: Boolean) { - val appContext = context.applicationContext - val prefs = Prefs(appContext) - val baseUrl = try { - appContext.getString(R.string.pocketbase_base_url).trim().trimEnd('/') - } catch (t: Throwable) { - Log.w(TAG, "Failed to read PocketBase base url for consent", t) - "" + + fun recordAsrEvent( + context: Context, + vendorId: String, + audioMs: Long, + procMs: Long, + source: String, + aiProcessed: Boolean, + charCount: Int, + ) { + if (charCount <= 0 && audioMs <= 0L) return + val appContext = context.applicationContext + val prefs = Prefs(appContext) + if (!prefs.dataCollectionEnabled) return + ensureUserId(prefs) + ensureRandomMinute(prefs) + + scope.launch { + try { + AnalyticsStore(appContext).addAsrEvent( + AnalyticsStore.AsrEvent( + timestamp = System.currentTimeMillis(), + vendorId = vendorId, + audioMs = audioMs, + procMs = procMs, + source = source, + aiProcessed = aiProcessed, + charCount = charCount, + ), + ) + } catch (t: Throwable) { + Log.w(TAG, "Failed to add ASR event", t) + } + maybeUploadIfDue(appContext) + } } - if (baseUrl.isBlank()) { - Log.w(TAG, "PocketBase base url empty, skip consent upload") - return + + private fun ensureUserId(prefs: Prefs): String { + val cur = prefs.analyticsUserId + if (cur.isNotBlank()) return cur + val id = UUID.randomUUID().toString() + prefs.analyticsUserId = id + return id } - val userId = ensureUserId(prefs) - val language = resolveLanguage(prefs) - val channel = resolveInstallChannel(appContext) - val firstSeen = resolveFirstSeenEpochMs(prefs) - val sdkInt = resolveSdkInt() - val deviceModel = resolveDeviceModel() - val record = ConsentRecord( - userId = userId, - enabled = enabled, - channel = channel, - firstSeen = firstSeen, - sdkInt = sdkInt, - deviceModel = deviceModel, - timestamp = System.currentTimeMillis(), - appVersion = BuildConfig.VERSION_NAME, - language = language - ) - val bodyText = try { json.encodeToString(record) } catch (t: Throwable) { - Log.w(TAG, "Failed to encode consent record", t) - return + private fun ensureRandomMinute(prefs: Prefs): Int { + val cur = prefs.analyticsReportMinuteOfDay + if (cur in 0..1439) return cur + val minute = Random.nextInt(0, 1440) + prefs.analyticsReportMinuteOfDay = minute + return minute } - val reqBody = bodyText.toRequestBody("application/json; charset=utf-8".toMediaType()) - scope.launch { - val existingId = fetchLatestConsentId(baseUrl, userId) - val req = if (existingId != null) { - val url = "$baseUrl/api/collections/$COLLECTION_CONSENT/records/$existingId" - Request.Builder().url(url).patch(reqBody).build() - } else { - val url = "$baseUrl/api/collections/$COLLECTION_CONSENT/records" - Request.Builder().url(url).post(reqBody).build() - } - postWithRetries(req) + + private fun maybeUploadIfDue(context: Context) { + val prefs = Prefs(context) + if (!prefs.dataCollectionEnabled) return + + val now = System.currentTimeMillis() + val reportMinute = ensureRandomMinute(prefs) + val cal = Calendar.getInstance() + cal.timeInMillis = now + val minuteOfDay = cal.get(Calendar.HOUR_OF_DAY) * 60 + cal.get(Calendar.MINUTE) + val todayEpochDay = LocalDate.now(ZoneId.systemDefault()).toEpochDay() + val lastUploadEpochDay = prefs.analyticsLastUploadEpochDay + if (lastUploadEpochDay >= 0L && (todayEpochDay - lastUploadEpochDay) < MIN_UPLOAD_INTERVAL_DAYS) return + if (minuteOfDay < reportMinute) return + + val lastAttemptEpochDay = prefs.analyticsLastAttemptEpochDay + val lastAttemptEpochMs = prefs.analyticsLastAttemptEpochMs + val canRetry = lastAttemptEpochDay == todayEpochDay && + prefs.analyticsRetryUsedEpochDay != todayEpochDay && + lastAttemptEpochMs > 0L && + (now - lastAttemptEpochMs) >= RETRY_COOLDOWN_MS + val isFirstAttemptToday = lastAttemptEpochDay != todayEpochDay + if (!isFirstAttemptToday && !canRetry) return + + scope.launch { uploadOnce(context.applicationContext, isRetry = !isFirstAttemptToday) } } - } - - private suspend fun fetchLatestConsentId(baseUrl: String, userId: String): String? { - val filterRaw = "userId=\"$userId\"" - val filterEncoded = try { - URLEncoder.encode(filterRaw, "UTF-8") - } catch (t: Throwable) { - Log.w(TAG, "Failed to encode consent filter", t) - return null + + private suspend fun uploadOnce(context: Context, isRetry: Boolean) { + val prefs = Prefs(context) + if (!prefs.dataCollectionEnabled) return + val todayEpochDay = LocalDate.now(ZoneId.systemDefault()).toEpochDay() + if (uploading) return + uploading = true + prefs.analyticsLastAttemptEpochDay = todayEpochDay + prefs.analyticsLastAttemptEpochMs = System.currentTimeMillis() + if (isRetry) { + prefs.analyticsRetryUsedEpochDay = todayEpochDay + } + try { + val baseUrl = try { + context.getString(R.string.pocketbase_base_url).trim().trimEnd('/') + } catch (t: Throwable) { + Log.w(TAG, "Failed to read PocketBase base url", t) + "" + } + if (baseUrl.isBlank()) { + Log.w(TAG, "PocketBase base url empty, skip upload") + return + } + + val userId = ensureUserId(prefs) + val store = AnalyticsStore(context) + val asrEvents = try { + store.listAsrEvents() + } catch (t: Throwable) { + Log.w(TAG, "Failed to read ASR events", t) + emptyList() + } + val appStarts = try { + store.listAppStarts() + } catch (t: Throwable) { + Log.w(TAG, "Failed to read app starts", t) + emptyList() + } + + val language = resolveLanguage(prefs) + val report = DailyReport( + userId = userId, + appVersion = BuildConfig.VERSION_NAME, + language = language, + reportAt = System.currentTimeMillis(), + randomMinuteOfDay = prefs.analyticsReportMinuteOfDay.coerceIn(0, 1439), + eventCount = asrEvents.size, + audioMsTotal = asrEvents.sumOf { it.audioMs.coerceAtLeast(0L) }, + appStartCount = appStarts.size, + asrEvents = asrEvents, + ) + + val bodyText = try { + json.encodeToString(report) + } catch (t: Throwable) { + Log.w(TAG, "Failed to encode daily report", t) + return + } + + val reqBody = bodyText.toRequestBody("application/json; charset=utf-8".toMediaType()) + val url = "$baseUrl/api/collections/$COLLECTION_DAILY_REPORT/records" + val req = Request.Builder().url(url).post(reqBody).build() + + val ok = postWithRetries(req) + if (ok) { + prefs.analyticsLastUploadEpochDay = todayEpochDay + store.deleteAsrEventsByIds(asrEvents.map { it.id }.toSet()) + store.deleteAppStartsByIds(appStarts.map { it.id }.toSet()) + Log.i(TAG, "Daily report uploaded, events=${asrEvents.size}, starts=${appStarts.size}") + } + } finally { + uploading = false + } } - val url = - "$baseUrl/api/collections/$COLLECTION_CONSENT/records?filter=$filterEncoded&sort=-timestamp&perPage=1" - val req = Request.Builder().url(url).get().build() - return try { - OkHttpClient.Builder().build().newCall(req).execute().use { res -> - if (!res.isSuccessful) { - Log.w(TAG, "Fetch consent list failed: code=${res.code}") - return null + + /** + * 用户选择匿名统计开关时上报一次(即使未开启统计也会上报 enabled=false)。 + */ + fun sendConsentChoice(context: Context, enabled: Boolean) { + val appContext = context.applicationContext + val prefs = Prefs(appContext) + val baseUrl = try { + appContext.getString(R.string.pocketbase_base_url).trim().trimEnd('/') + } catch (t: Throwable) { + Log.w(TAG, "Failed to read PocketBase base url for consent", t) + "" + } + if (baseUrl.isBlank()) { + Log.w(TAG, "PocketBase base url empty, skip consent upload") + return + } + + val userId = ensureUserId(prefs) + val language = resolveLanguage(prefs) + val channel = resolveInstallChannel(appContext) + val firstSeen = resolveFirstSeenEpochMs(prefs) + val sdkInt = resolveSdkInt() + val deviceModel = resolveDeviceModel() + val record = ConsentRecord( + userId = userId, + enabled = enabled, + channel = channel, + firstSeen = firstSeen, + sdkInt = sdkInt, + deviceModel = deviceModel, + timestamp = System.currentTimeMillis(), + appVersion = BuildConfig.VERSION_NAME, + language = language, + ) + val bodyText = try { + json.encodeToString(record) + } catch (t: Throwable) { + Log.w(TAG, "Failed to encode consent record", t) + return + } + val reqBody = bodyText.toRequestBody("application/json; charset=utf-8".toMediaType()) + scope.launch { + val existingId = fetchLatestConsentId(baseUrl, userId) + val req = if (existingId != null) { + val url = "$baseUrl/api/collections/$COLLECTION_CONSENT/records/$existingId" + Request.Builder().url(url).patch(reqBody).build() + } else { + val url = "$baseUrl/api/collections/$COLLECTION_CONSENT/records" + Request.Builder().url(url).post(reqBody).build() + } + postWithRetries(req) } - val body = res.body?.string().orEmpty() - val parsed = json.decodeFromString>(body) - parsed.items.firstOrNull()?.id - } - } catch (t: Throwable) { - Log.w(TAG, "Fetch consent list error", t) - null } - } - private fun resolveLanguage(prefs: Prefs): String { - val tag = prefs.appLanguageTag.trim() - return if (tag.isNotBlank()) tag else Locale.getDefault().toLanguageTag() - } + private suspend fun fetchLatestConsentId(baseUrl: String, userId: String): String? { + val filterRaw = "userId=\"$userId\"" + val filterEncoded = try { + URLEncoder.encode(filterRaw, "UTF-8") + } catch (t: Throwable) { + Log.w(TAG, "Failed to encode consent filter", t) + return null + } + val url = + "$baseUrl/api/collections/$COLLECTION_CONSENT/records?filter=$filterEncoded&sort=-timestamp&perPage=1" + val req = Request.Builder().url(url).get().build() + return try { + OkHttpClient.Builder().build().newCall(req).execute().use { res -> + if (!res.isSuccessful) { + Log.w(TAG, "Fetch consent list failed: code=${res.code}") + return null + } + val body = res.body?.string().orEmpty() + val parsed = json.decodeFromString>(body) + parsed.items.firstOrNull()?.id + } + } catch (t: Throwable) { + Log.w(TAG, "Fetch consent list error", t) + null + } + } - private fun scheduleLoop(context: Context) { - scheduleJob?.cancel() - scheduleJob = scope.launch { - while (isActive) { - val prefs = Prefs(context) - if (!prefs.dataCollectionEnabled) break - val minute = ensureRandomMinute(prefs) - val delayMs = computeDelayToNextReport(minute) - delay(delayMs) - maybeUploadIfDue(context) - } + private fun resolveLanguage(prefs: Prefs): String { + val tag = prefs.appLanguageTag.trim() + return if (tag.isNotBlank()) tag else Locale.getDefault().toLanguageTag() } - } - - private fun computeDelayToNextReport(reportMinute: Int): Long { - val now = System.currentTimeMillis() - val cal = Calendar.getInstance() - cal.timeInMillis = now - val next = cal.clone() as Calendar - next.set(Calendar.HOUR_OF_DAY, reportMinute / 60) - next.set(Calendar.MINUTE, reportMinute % 60) - next.set(Calendar.SECOND, 0) - next.set(Calendar.MILLISECOND, 0) - if (next.timeInMillis <= now) { - next.add(Calendar.DAY_OF_YEAR, 1) + + private fun scheduleLoop(context: Context) { + scheduleJob?.cancel() + scheduleJob = scope.launch { + while (isActive) { + val prefs = Prefs(context) + if (!prefs.dataCollectionEnabled) break + val minute = ensureRandomMinute(prefs) + val delayMs = computeDelayToNextReport(minute) + delay(delayMs) + maybeUploadIfDue(context) + } + } } - return (next.timeInMillis - now).coerceAtLeast(60_000L) - } - - private fun resolveInstallChannel(context: Context): String { - if (BuildConfig.DEBUG) return "debug" - return try { - val pm = context.packageManager - val installer = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { - pm.getInstallSourceInfo(context.packageName).installingPackageName - } else { - @Suppress("DEPRECATION") - pm.getInstallerPackageName(context.packageName) - } - installer?.takeIf { it.isNotBlank() } ?: "unknown" - } catch (t: Throwable) { - Log.w(TAG, "Failed to resolve install channel", t) - "unknown" + + private fun computeDelayToNextReport(reportMinute: Int): Long { + val now = System.currentTimeMillis() + val cal = Calendar.getInstance() + cal.timeInMillis = now + val next = cal.clone() as Calendar + next.set(Calendar.HOUR_OF_DAY, reportMinute / 60) + next.set(Calendar.MINUTE, reportMinute % 60) + next.set(Calendar.SECOND, 0) + next.set(Calendar.MILLISECOND, 0) + if (next.timeInMillis <= now) { + next.add(Calendar.DAY_OF_YEAR, 1) + } + return (next.timeInMillis - now).coerceAtLeast(60_000L) } - } - private fun resolveFirstSeenEpochMs(prefs: Prefs): Long { - val fud = prefs.firstUseDate.ifBlank { - val today = LocalDate.now().format(java.time.format.DateTimeFormatter.BASIC_ISO_DATE) - prefs.firstUseDate = today - today + private fun resolveInstallChannel(context: Context): String { + if (BuildConfig.DEBUG) return "debug" + return try { + val pm = context.packageManager + val installer = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.R) { + pm.getInstallSourceInfo(context.packageName).installingPackageName + } else { + @Suppress("DEPRECATION") + pm.getInstallerPackageName(context.packageName) + } + installer?.takeIf { it.isNotBlank() } ?: "unknown" + } catch (t: Throwable) { + Log.w(TAG, "Failed to resolve install channel", t) + "unknown" + } } - return try { - val date = LocalDate.parse(fud, java.time.format.DateTimeFormatter.BASIC_ISO_DATE) - date.atStartOfDay(ZoneId.systemDefault()).toInstant().toEpochMilli() - } catch (t: Throwable) { - Log.w(TAG, "Failed to parse firstUseDate '$fud' for firstSeen", t) - System.currentTimeMillis() + + private fun resolveFirstSeenEpochMs(prefs: Prefs): Long { + val fud = prefs.firstUseDate.ifBlank { + val today = LocalDate.now().format(java.time.format.DateTimeFormatter.BASIC_ISO_DATE) + prefs.firstUseDate = today + today + } + return try { + val date = LocalDate.parse(fud, java.time.format.DateTimeFormatter.BASIC_ISO_DATE) + date.atStartOfDay(ZoneId.systemDefault()).toInstant().toEpochMilli() + } catch (t: Throwable) { + Log.w(TAG, "Failed to parse firstUseDate '$fud' for firstSeen", t) + System.currentTimeMillis() + } } - } - private fun resolveSdkInt(): Int = Build.VERSION.SDK_INT + private fun resolveSdkInt(): Int = Build.VERSION.SDK_INT - private fun resolveDeviceModel(): String { - val manufacturer = Build.MANUFACTURER?.trim().orEmpty() - val model = Build.MODEL?.trim().orEmpty() - val combined = listOf(manufacturer, model).filter { it.isNotBlank() }.joinToString(" ") - return combined.ifBlank { "unknown" } - } + private fun resolveDeviceModel(): String { + val manufacturer = Build.MANUFACTURER?.trim().orEmpty() + val model = Build.MODEL?.trim().orEmpty() + val combined = listOf(manufacturer, model).filter { it.isNotBlank() }.joinToString(" ") + return combined.ifBlank { "unknown" } + } - private suspend fun postWithRetries(req: Request): Boolean { - try { - val client = OkHttpClient.Builder().build() - for (attempt in 1..MAX_RETRIES) { + private suspend fun postWithRetries(req: Request): Boolean { try { - client.newCall(req).execute().use { res -> - if (res.isSuccessful) { - return true + val client = OkHttpClient.Builder().build() + for (attempt in 1..MAX_RETRIES) { + try { + client.newCall(req).execute().use { res -> + if (res.isSuccessful) { + return true + } + Log.w(TAG, "POST failed (attempt=$attempt): code=${res.code}") + } + } catch (t: Throwable) { + Log.w(TAG, "POST error (attempt=$attempt)", t) + } + if (attempt < MAX_RETRIES) { + delay(RETRY_INTERVAL_MS) + } } - Log.w(TAG, "POST failed (attempt=$attempt): code=${res.code}") - } } catch (t: Throwable) { - Log.w(TAG, "POST error (attempt=$attempt)", t) - } - if (attempt < MAX_RETRIES) { - delay(RETRY_INTERVAL_MS) + Log.w(TAG, "postWithRetries unexpected error", t) } - } - } catch (t: Throwable) { - Log.w(TAG, "postWithRetries unexpected error", t) + return false } - return false - } } diff --git a/app/src/main/java/com/brycewg/asrkb/api/AsrRecognitionService.kt b/app/src/main/java/com/brycewg/asrkb/api/AsrRecognitionService.kt index ecdff844..78f902e8 100644 --- a/app/src/main/java/com/brycewg/asrkb/api/AsrRecognitionService.kt +++ b/app/src/main/java/com/brycewg/asrkb/api/AsrRecognitionService.kt @@ -13,7 +13,33 @@ import android.speech.SpeechRecognizer import android.util.Log import androidx.annotation.RequiresApi import androidx.core.content.ContextCompat -import com.brycewg.asrkb.asr.* +import com.brycewg.asrkb.asr.AsrTimeoutCalculator +import com.brycewg.asrkb.asr.AsrVendor +import com.brycewg.asrkb.asr.DashscopeFileAsrEngine +import com.brycewg.asrkb.asr.DashscopeStreamAsrEngine +import com.brycewg.asrkb.asr.ElevenLabsFileAsrEngine +import com.brycewg.asrkb.asr.ElevenLabsStreamAsrEngine +import com.brycewg.asrkb.asr.FunAsrNanoFileAsrEngine +import com.brycewg.asrkb.asr.GeminiFileAsrEngine +import com.brycewg.asrkb.asr.LOCAL_MODEL_READY_WAIT_MAX_MS +import com.brycewg.asrkb.asr.OpenAiFileAsrEngine +import com.brycewg.asrkb.asr.ParaformerStreamAsrEngine +import com.brycewg.asrkb.asr.ParallelAsrEngine +import com.brycewg.asrkb.asr.SenseVoiceFileAsrEngine +import com.brycewg.asrkb.asr.SenseVoicePseudoStreamAsrEngine +import com.brycewg.asrkb.asr.SiliconFlowFileAsrEngine +import com.brycewg.asrkb.asr.SonioxFileAsrEngine +import com.brycewg.asrkb.asr.SonioxStreamAsrEngine +import com.brycewg.asrkb.asr.StreamingAsrEngine +import com.brycewg.asrkb.asr.TelespeechFileAsrEngine +import com.brycewg.asrkb.asr.TelespeechPseudoStreamAsrEngine +import com.brycewg.asrkb.asr.VadAutoStopGuard +import com.brycewg.asrkb.asr.VolcFileAsrEngine +import com.brycewg.asrkb.asr.VolcStandardFileAsrEngine +import com.brycewg.asrkb.asr.VolcStreamAsrEngine +import com.brycewg.asrkb.asr.ZhipuFileAsrEngine +import com.brycewg.asrkb.asr.awaitLocalAsrReady +import com.brycewg.asrkb.asr.isLocalAsrVendor import com.brycewg.asrkb.store.Prefs import com.brycewg.asrkb.util.TypewriterTextAnimator import kotlinx.coroutines.CoroutineScope @@ -57,7 +83,7 @@ class AsrRecognitionService : RecognitionService() { // 检查录音权限 val hasPermission = ContextCompat.checkSelfPermission( this, - android.Manifest.permission.RECORD_AUDIO + android.Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED if (!hasPermission) { @@ -68,7 +94,7 @@ class AsrRecognitionService : RecognitionService() { // 校验调用方录音权限(系统通常已做校验,但此处显式防御,避免成为权限代理录音入口) val callingHasPermission = checkCallingOrSelfPermission( - android.Manifest.permission.RECORD_AUDIO + android.Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED if (!callingHasPermission) { Log.w(TAG, "Calling app missing RECORD_AUDIO permission") @@ -89,7 +115,7 @@ class AsrRecognitionService : RecognitionService() { val config = parseRecognizerIntent(intent) Log.d( TAG, - "Parsed external config: language=${config.language}, partialResults=${config.partialResults}" + "Parsed external config: language=${config.language}, partialResults=${config.partialResults}", ) // language 等参数仅用于日志/回调控制(如是否返回 partialResults), // 不会影响内部供应商选择、本地/云端模型或具体识别配置,相关行为完全由 Prefs 决定。 @@ -97,7 +123,7 @@ class AsrRecognitionService : RecognitionService() { // 创建会话(先创建以便作为 listener 传递给引擎) val session = RecognitionSession( callback = callback, - config = config + config = config, ) // 构建引擎 @@ -147,7 +173,7 @@ class AsrRecognitionService : RecognitionService() { return RecognitionConfig( language = language, partialResults = partialResults, - maxResults = maxResults + maxResults = maxResults, ) } @@ -185,7 +211,7 @@ class AsrRecognitionService : RecognitionService() { prefs = prefs, listener = listener, primaryVendor = vendor, - backupVendor = backupVendor + backupVendor = backupVendor, ) } @@ -244,7 +270,11 @@ class AsrRecognitionService : RecognitionService() { } private fun shouldUseBackupAsr(primaryVendor: AsrVendor, backupVendor: AsrVendor): Boolean { - val enabled = try { prefs.backupAsrEnabled } catch (_: Throwable) { false } + val enabled = try { + prefs.backupAsrEnabled + } catch (_: Throwable) { + false + } if (!enabled) return false if (backupVendor == primaryVendor) return false return try { @@ -281,21 +311,21 @@ class AsrRecognitionService : RecognitionService() { message.contains("permission", ignoreCase = true) -> SpeechRecognizer.ERROR_INSUFFICIENT_PERMISSIONS message.contains("network", ignoreCase = true) || - message.contains("timeout", ignoreCase = true) || - message.contains("connect", ignoreCase = true) -> + message.contains("timeout", ignoreCase = true) || + message.contains("connect", ignoreCase = true) -> SpeechRecognizer.ERROR_NETWORK message.contains("audio", ignoreCase = true) || - message.contains("microphone", ignoreCase = true) || - message.contains("record", ignoreCase = true) -> + message.contains("microphone", ignoreCase = true) || + message.contains("record", ignoreCase = true) -> SpeechRecognizer.ERROR_AUDIO message.contains("busy", ignoreCase = true) -> SpeechRecognizer.ERROR_RECOGNIZER_BUSY message.contains("empty", ignoreCase = true) || - message.contains("no speech", ignoreCase = true) || - message.contains("no match", ignoreCase = true) -> + message.contains("no speech", ignoreCase = true) || + message.contains("no match", ignoreCase = true) -> SpeechRecognizer.ERROR_NO_MATCH message.contains("server", ignoreCase = true) || - message.contains("api", ignoreCase = true) -> + message.contains("api", ignoreCase = true) -> SpeechRecognizer.ERROR_SERVER else -> SpeechRecognizer.ERROR_CLIENT } @@ -307,7 +337,7 @@ class AsrRecognitionService : RecognitionService() { private data class RecognitionConfig( val language: String?, val partialResults: Boolean, - val maxResults: Int + val maxResults: Int, ) /** @@ -315,7 +345,7 @@ class AsrRecognitionService : RecognitionService() { */ private inner class RecognitionSession( private val callback: Callback, - private val config: RecognitionConfig + private val config: RecognitionConfig, ) : StreamingAsrEngine.Listener { private var engine: StreamingAsrEngine? = null @@ -409,13 +439,23 @@ class AsrRecognitionService : RecognitionService() { val usedBackupResult = (engine as? ParallelAsrEngine)?.wasLastResultFromBackup() == true - val doAi = try { prefs.postProcessEnabled && prefs.hasLlmKeys() } catch (_: Throwable) { false } + val doAi = try { + prefs.postProcessEnabled && prefs.hasLlmKeys() + } catch (_: Throwable) { + false + } serviceScope.launch { if (canceled || finished) return@launch val processedText = if (doAi) { val allowPartial = config.partialResults - val typewriterEnabled = allowPartial && (try { prefs.postprocTypewriterEnabled } catch (_: Throwable) { true }) + val typewriterEnabled = allowPartial && ( + try { + prefs.postprocTypewriterEnabled + } catch (_: Throwable) { + true + } + ) var postprocCommitted = false var lastPostprocTarget: String? = null val typewriter = if (typewriterEnabled) { @@ -428,7 +468,7 @@ class AsrRecognitionService : RecognitionService() { deliverPartialResults(typed) }, frameDelayMs = 20L, - idleStopDelayMs = 1200L + idleStopDelayMs = 1200L, ) } else { null @@ -446,32 +486,34 @@ class AsrRecognitionService : RecognitionService() { deliverPartialResults(streamed) } } - } else null - try { - val res = com.brycewg.asrkb.util.AsrFinalFilters.applyWithAi( - this@AsrRecognitionService, - prefs, - text, - onStreamingUpdate = onStreamingUpdate - ) - val aiUsed = (res.usedAi && res.ok) - val finalOut = res.text.ifBlank { - try { - com.brycewg.asrkb.util.AsrFinalFilters.applySimple( - this@AsrRecognitionService, - prefs, - text - ) - } catch (_: Throwable) { - text - } - } - if (typewriter != null && aiUsed && finalOut.isNotEmpty()) { - typewriter?.submit(finalOut, rush = true) - val finalLen = finalOut.length - val t0 = SystemClock.uptimeMillis() - while (!canceled && !finished && (SystemClock.uptimeMillis() - t0) < 2_000L && - typewriter?.currentText()?.length != finalLen + } else { + null + } + try { + val res = com.brycewg.asrkb.util.AsrFinalFilters.applyWithAi( + this@AsrRecognitionService, + prefs, + text, + onStreamingUpdate = onStreamingUpdate, + ) + val aiUsed = (res.usedAi && res.ok) + val finalOut = res.text.ifBlank { + try { + com.brycewg.asrkb.util.AsrFinalFilters.applySimple( + this@AsrRecognitionService, + prefs, + text, + ) + } catch (_: Throwable) { + text + } + } + if (typewriter != null && aiUsed && finalOut.isNotEmpty()) { + typewriter?.submit(finalOut, rush = true) + val finalLen = finalOut.length + val t0 = SystemClock.uptimeMillis() + while (!canceled && !finished && (SystemClock.uptimeMillis() - t0) < 2_000L && + typewriter?.currentText()?.length != finalLen ) { delay(20) } @@ -483,7 +525,7 @@ class AsrRecognitionService : RecognitionService() { com.brycewg.asrkb.util.AsrFinalFilters.applySimple( this@AsrRecognitionService, prefs, - text + text, ) } catch (_: Throwable) { text @@ -497,7 +539,7 @@ class AsrRecognitionService : RecognitionService() { com.brycewg.asrkb.util.AsrFinalFilters.applySimple( this@AsrRecognitionService, prefs, - text + text, ) } catch (t: Throwable) { Log.w(TAG, "Post-processing failed", t) @@ -511,12 +553,13 @@ class AsrRecognitionService : RecognitionService() { val results = Bundle().apply { putStringArrayList( SpeechRecognizer.RESULTS_RECOGNITION, - arrayListOf(processedText) + arrayListOf(processedText), ) // 可选:添加置信度分数 putFloatArray( SpeechRecognizer.CONFIDENCE_SCORES, - floatArrayOf(1.0f) // 单结果,置信度设为 1.0 + // 单结果,置信度设为 1.0 + floatArrayOf(1.0f), ) putBoolean(EXTRA_USED_BACKUP_ASR, usedBackupResult) } @@ -673,7 +716,7 @@ class AsrRecognitionService : RecognitionService() { val partialBundle = Bundle().apply { putStringArrayList( SpeechRecognizer.RESULTS_RECOGNITION, - arrayListOf(text) + arrayListOf(text), ) } try { diff --git a/app/src/main/java/com/brycewg/asrkb/api/ExternalSpeechService.kt b/app/src/main/java/com/brycewg/asrkb/api/ExternalSpeechService.kt index c9a0a662..aaad572d 100644 --- a/app/src/main/java/com/brycewg/asrkb/api/ExternalSpeechService.kt +++ b/app/src/main/java/com/brycewg/asrkb/api/ExternalSpeechService.kt @@ -6,14 +6,39 @@ import android.content.Intent import android.content.pm.PackageManager import android.os.Binder import android.os.IBinder -import android.os.SystemClock import android.os.Parcel +import android.os.SystemClock import android.util.Log import androidx.core.content.ContextCompat import com.brycewg.asrkb.R import com.brycewg.asrkb.aidl.SpeechConfig -import com.brycewg.asrkb.asr.* import com.brycewg.asrkb.analytics.AnalyticsManager +import com.brycewg.asrkb.asr.AsrTimeoutCalculator +import com.brycewg.asrkb.asr.AsrVendor +import com.brycewg.asrkb.asr.DashscopeFileAsrEngine +import com.brycewg.asrkb.asr.DashscopeStreamAsrEngine +import com.brycewg.asrkb.asr.ElevenLabsFileAsrEngine +import com.brycewg.asrkb.asr.ElevenLabsStreamAsrEngine +import com.brycewg.asrkb.asr.FunAsrNanoFileAsrEngine +import com.brycewg.asrkb.asr.GeminiFileAsrEngine +import com.brycewg.asrkb.asr.GenericPushFileAsrAdapter +import com.brycewg.asrkb.asr.OpenAiFileAsrEngine +import com.brycewg.asrkb.asr.ParaformerStreamAsrEngine +import com.brycewg.asrkb.asr.ParallelAsrEngine +import com.brycewg.asrkb.asr.SenseVoiceFileAsrEngine +import com.brycewg.asrkb.asr.SiliconFlowFileAsrEngine +import com.brycewg.asrkb.asr.SonioxFileAsrEngine +import com.brycewg.asrkb.asr.SonioxStreamAsrEngine +import com.brycewg.asrkb.asr.StreamingAsrEngine +import com.brycewg.asrkb.asr.TelespeechFileAsrEngine +import com.brycewg.asrkb.asr.VadAutoStopGuard +import com.brycewg.asrkb.asr.VolcFileAsrEngine +import com.brycewg.asrkb.asr.VolcStandardFileAsrEngine +import com.brycewg.asrkb.asr.VolcStreamAsrEngine +import com.brycewg.asrkb.asr.ZhipuFileAsrEngine +import com.brycewg.asrkb.asr.awaitLocalAsrReady +import com.brycewg.asrkb.asr.isLocalAsrReady +import com.brycewg.asrkb.asr.isLocalAsrVendor import com.brycewg.asrkb.store.Prefs import com.brycewg.asrkb.util.TypewriterTextAnimator import kotlinx.coroutines.CoroutineScope @@ -34,6 +59,7 @@ class ExternalSpeechService : Service() { private val prefs by lazy { Prefs(this) } private val sessions = ConcurrentHashMap() + @Volatile private var nextId: Int = 1 override fun onBind(intent: Intent?): IBinder? = object : Binder() { @@ -43,7 +69,7 @@ class ExternalSpeechService : Service() { reply?.writeString(DESCRIPTOR_SVC) return true } - TRANSACTION_startSession -> { + TRANSACTION_START_SESSION -> { data.enforceInterface(DESCRIPTOR_SVC) val cfg = if (data.readInt() != 0) SpeechConfig.CREATOR.createFromParcel(data) else null val cbBinder = data.readStrongBinder() @@ -52,7 +78,10 @@ class ExternalSpeechService : Service() { // 开关与权限检查:仅要求开启外部联动 if (!prefs.externalAidlEnabled) { safe { cb.onError(-1, 403, "feature disabled") } - reply?.apply { writeNoException(); writeInt(-3) } + reply?.apply { + writeNoException() + writeInt(-3) + } return true } // 联通测试:当 vendorId == "mock" 时,无需录音权限,直接回调固定内容并结束 @@ -62,43 +91,58 @@ class ExternalSpeechService : Service() { safe { cb.onPartial(sid, "【联通测试中】……") } safe { cb.onFinal(sid, "说点啥外部AIDL联通成功(mock)") } safe { cb.onState(sid, STATE_IDLE, "final") } - reply?.apply { writeNoException(); writeInt(sid) } + reply?.apply { + writeNoException() + writeInt(sid) + } return true } val permOk = ContextCompat.checkSelfPermission( this@ExternalSpeechService, - android.Manifest.permission.RECORD_AUDIO + android.Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED if (!permOk) { safe { cb.onError(-1, 401, "record permission denied") } - reply?.apply { writeNoException(); writeInt(-4) } + reply?.apply { + writeNoException() + writeInt(-4) + } return true } if (sessions.values.any { it.engine?.isRunning == true }) { - reply?.apply { writeNoException(); writeInt(-2) } + reply?.apply { + writeNoException() + writeInt(-2) + } return true } val sid = synchronized(this@ExternalSpeechService) { nextId++ } val s = Session(sid, this@ExternalSpeechService, prefs, cb) if (!s.prepare()) { - reply?.apply { writeNoException(); writeInt(-3) } + reply?.apply { + writeNoException() + writeInt(-3) + } return true } sessions[sid] = s s.start() - reply?.apply { writeNoException(); writeInt(sid) } + reply?.apply { + writeNoException() + writeInt(sid) + } return true } - TRANSACTION_stopSession -> { + TRANSACTION_STOP_SESSION -> { data.enforceInterface(DESCRIPTOR_SVC) val sid = data.readInt() sessions[sid]?.stop() reply?.writeNoException() return true } - TRANSACTION_cancelSession -> { + TRANSACTION_CANCEL_SESSION -> { data.enforceInterface(DESCRIPTOR_SVC) val sid = data.readInt() sessions[sid]?.cancel() @@ -106,26 +150,35 @@ class ExternalSpeechService : Service() { reply?.writeNoException() return true } - TRANSACTION_isRecording -> { + TRANSACTION_IS_RECORDING -> { data.enforceInterface(DESCRIPTOR_SVC) val sid = data.readInt() val r = sessions[sid]?.engine?.isRunning == true - reply?.apply { writeNoException(); writeInt(if (r) 1 else 0) } + reply?.apply { + writeNoException() + writeInt(if (r) 1 else 0) + } return true } - TRANSACTION_isAnyRecording -> { + TRANSACTION_IS_ANY_RECORDING -> { data.enforceInterface(DESCRIPTOR_SVC) val r = sessions.values.any { it.engine?.isRunning == true } - reply?.apply { writeNoException(); writeInt(if (r) 1 else 0) } + reply?.apply { + writeNoException() + writeInt(if (r) 1 else 0) + } return true } - TRANSACTION_getVersion -> { + TRANSACTION_GET_VERSION -> { data.enforceInterface(DESCRIPTOR_SVC) - reply?.apply { writeNoException(); writeString(com.brycewg.asrkb.BuildConfig.VERSION_NAME) } + reply?.apply { + writeNoException() + writeString(com.brycewg.asrkb.BuildConfig.VERSION_NAME) + } return true } // ================= 推送 PCM 模式 ================= - TRANSACTION_startPcmSession -> { + TRANSACTION_START_PCM_SESSION -> { data.enforceInterface(DESCRIPTOR_SVC) if (data.readInt() != 0) SpeechConfig.CREATOR.createFromParcel(data) else null val cbBinder = data.readStrongBinder() @@ -133,26 +186,38 @@ class ExternalSpeechService : Service() { if (!prefs.externalAidlEnabled) { safe { cb.onError(-1, 403, "feature disabled") } - reply?.apply { writeNoException(); writeInt(-3) } + reply?.apply { + writeNoException() + writeInt(-3) + } return true } if (sessions.values.any { it.engine?.isRunning == true }) { - reply?.apply { writeNoException(); writeInt(-2) } + reply?.apply { + writeNoException() + writeInt(-2) + } return true } val sid = synchronized(this@ExternalSpeechService) { nextId++ } val s = Session(sid, this@ExternalSpeechService, prefs, cb) if (!s.preparePushPcm()) { - reply?.apply { writeNoException(); writeInt(-5) } + reply?.apply { + writeNoException() + writeInt(-5) + } return true } sessions[sid] = s s.start() - reply?.apply { writeNoException(); writeInt(sid) } + reply?.apply { + writeNoException() + writeInt(sid) + } return true } - TRANSACTION_writePcm -> { + TRANSACTION_WRITE_PCM -> { data.enforceInterface(DESCRIPTOR_SVC) val sid = data.readInt() val bytes = data.createByteArray() ?: ByteArray(0) @@ -162,7 +227,7 @@ class ExternalSpeechService : Service() { reply?.writeNoException() return true } - TRANSACTION_finishPcm -> { + TRANSACTION_FINISH_PCM -> { data.enforceInterface(DESCRIPTOR_SVC) val sid = data.readInt() sessions[sid]?.stop() @@ -178,14 +243,15 @@ class ExternalSpeechService : Service() { private val id: Int, private val context: Context, private val prefs: Prefs, - private val cb: CallbackProxy + private val cb: CallbackProxy, ) : StreamingAsrEngine.Listener { var engine: StreamingAsrEngine? = null private var autoStopSuppression: AutoCloseable? = null + // 统计:录音起止与耗时(用于历史记录展示) - private var sessionStartUptimeMs: Long = 0L - private var sessionStartTotalUptimeMs: Long = 0L - private var lastAudioMsForStats: Long = 0L + private var sessionStartUptimeMs: Long = 0L + private var sessionStartTotalUptimeMs: Long = 0L + private var lastAudioMsForStats: Long = 0L private var lastRequestDurationMs: Long? = null private var lastPostprocPreview: String? = null private var vendor: AsrVendor? = null @@ -198,9 +264,13 @@ class ExternalSpeechService : Service() { private val sessionJob = SupervisorJob() private val sessionScope = CoroutineScope(sessionJob + Dispatchers.Default) private val processingTimeoutLock = Any() + @Volatile private var processingTimeoutJob: Job? = null + @Volatile private var finished: Boolean = false + @Volatile private var canceled: Boolean = false + @Volatile private var hasAsrPartial: Boolean = false private fun ensureAutoStopSuppressed() { @@ -275,7 +345,7 @@ class ExternalSpeechService : Service() { cancelLocalModelReadyWait() } - private fun computeProcMsForStats(): Long { + private fun computeProcMsForStats(): Long { val fromEngine = lastRequestDurationMs if (fromEngine != null) return fromEngine val start = processingStartUptimeMs @@ -284,28 +354,32 @@ class ExternalSpeechService : Service() { val total = (end - start).coerceAtLeast(0L) val wait = localModelReadyWaitMs.get().coerceAtLeast(0L) return (total - wait).coerceAtLeast(0L) - } + } - private fun popTotalElapsedMsForStats(): Long { - val start = sessionStartTotalUptimeMs - if (start <= 0L) return 0L - val now = try { - SystemClock.uptimeMillis() - } catch (t: Throwable) { - Log.w(TAG, "Failed to read uptime for total elapsed ms", t) - sessionStartTotalUptimeMs = 0L - return 0L - } - val elapsed = if (now >= start) (now - start).coerceAtLeast(0L) else 0L - sessionStartTotalUptimeMs = if (engine?.isRunning == true) now else 0L - return elapsed - } + private fun popTotalElapsedMsForStats(): Long { + val start = sessionStartTotalUptimeMs + if (start <= 0L) return 0L + val now = try { + SystemClock.uptimeMillis() + } catch (t: Throwable) { + Log.w(TAG, "Failed to read uptime for total elapsed ms", t) + sessionStartTotalUptimeMs = 0L + return 0L + } + val elapsed = if (now >= start) (now - start).coerceAtLeast(0L) else 0L + sessionStartTotalUptimeMs = if (engine?.isRunning == true) now else 0L + return elapsed + } private fun resolveFinalVendorForRecord(): AsrVendor { val e = engine return when (e) { is ParallelAsrEngine -> if (e.wasLastResultFromBackup()) e.backupVendor else e.primaryVendor - else -> vendor ?: try { prefs.asrVendor } catch (_: Throwable) { AsrVendor.Volc } + else -> vendor ?: try { + prefs.asrVendor + } catch (_: Throwable) { + AsrVendor.Volc + } } } @@ -376,7 +450,7 @@ class ExternalSpeechService : Service() { listener = this, primaryVendor = primaryVendor, backupVendor = backupVendor, - onPrimaryRequestDuration = ::onRequestDuration + onPrimaryRequestDuration = ::onRequestDuration, ) } else { buildEngine(primaryVendor, streamingPref) @@ -399,7 +473,7 @@ class ExternalSpeechService : Service() { primaryVendor = primaryVendor, backupVendor = backupVendor, onPrimaryRequestDuration = ::onRequestDuration, - externalPcmInput = true + externalPcmInput = true, ) } else { buildPushPcmEngine(primaryVendor, streamingPref) @@ -407,15 +481,15 @@ class ExternalSpeechService : Service() { return engine != null } - fun start() { - safe { cb.onState(id, STATE_RECORDING, "recording") } - try { - sessionStartUptimeMs = SystemClock.uptimeMillis() - sessionStartTotalUptimeMs = sessionStartUptimeMs - // 新会话开始时重置上次请求耗时,避免串台(流式模式不会更新此值) - lastRequestDurationMs = null - lastAudioMsForStats = 0L - lastPostprocPreview = null + fun start() { + safe { cb.onState(id, STATE_RECORDING, "recording") } + try { + sessionStartUptimeMs = SystemClock.uptimeMillis() + sessionStartTotalUptimeMs = sessionStartUptimeMs + // 新会话开始时重置上次请求耗时,避免串台(流式模式不会更新此值) + lastRequestDurationMs = null + lastAudioMsForStats = 0L + lastPostprocPreview = null processingStartUptimeMs = 0L processingEndUptimeMs = 0L localModelWaitStartUptimeMs = 0L @@ -426,18 +500,18 @@ class ExternalSpeechService : Service() { hasAsrPartial = false finished = false cancelProcessingTimeout() - } catch (t: Throwable) { - Log.w(TAG, "Failed to mark session start", t) - } - ensureAutoStopSuppressed() - engine?.start() - } + } catch (t: Throwable) { + Log.w(TAG, "Failed to mark session start", t) + } + ensureAutoStopSuppressed() + engine?.start() + } fun stop() { if (canceled || finished) return releaseAutoStopSuppression() // 记录一次会话录音时长(用于超时与统计);部分引擎 stop() 不会回调 onStopped(如外部推流的本地流式),因此这里也做一次兜底快照。 - if (sessionStartUptimeMs > 0L) { + if (sessionStartUptimeMs > 0L) { try { if (lastAudioMsForStats == 0L) { val dur = (SystemClock.uptimeMillis() - sessionStartUptimeMs).coerceAtLeast(0) @@ -449,14 +523,14 @@ class ExternalSpeechService : Service() { sessionStartUptimeMs = 0L } } - if (processingStartUptimeMs == 0L) { - processingStartUptimeMs = SystemClock.uptimeMillis() - } + if (processingStartUptimeMs == 0L) { + processingStartUptimeMs = SystemClock.uptimeMillis() + } markLocalModelProcessingStartIfNeeded() scheduleProcessingTimeoutIfNeeded() - engine?.stop() - safe { cb.onState(id, STATE_PROCESSING, "processing") } - } + engine?.stop() + safe { cb.onState(id, STATE_PROCESSING, "processing") } + } fun cancel() { canceled = true @@ -496,7 +570,11 @@ class ExternalSpeechService : Service() { } private fun shouldUseBackupAsr(primaryVendor: AsrVendor, backupVendor: AsrVendor): Boolean { - val enabled = try { prefs.backupAsrEnabled } catch (_: Throwable) { false } + val enabled = try { + prefs.backupAsrEnabled + } catch (_: Throwable) { + false + } if (!enabled) return false if (backupVendor == primaryVendor) return false return try { @@ -535,60 +613,90 @@ class ExternalSpeechService : Service() { scope, prefs, this, - onRequestDuration = ::onRequestDuration + onRequestDuration = ::onRequestDuration, ) } AsrVendor.SiliconFlow -> SiliconFlowFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) AsrVendor.ElevenLabs -> if (streamingPreferred) { ElevenLabsStreamAsrEngine(context, scope, prefs, this) } else { ElevenLabsFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) } AsrVendor.OpenAI -> OpenAiFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) AsrVendor.DashScope -> if (streamingPreferred) { DashscopeStreamAsrEngine(context, scope, prefs, this) } else { DashscopeFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) } AsrVendor.Gemini -> GeminiFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) AsrVendor.Soniox -> if (streamingPreferred) { SonioxStreamAsrEngine(context, scope, prefs, this) } else { SonioxFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) } AsrVendor.Zhipu -> ZhipuFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) AsrVendor.SenseVoice -> SenseVoiceFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) AsrVendor.FunAsrNano -> FunAsrNanoFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) AsrVendor.Telespeech -> TelespeechFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) AsrVendor.Paraformer -> ParaformerStreamAsrEngine(context, scope, prefs, this) } @@ -602,19 +710,31 @@ class ExternalSpeechService : Service() { } else { if (prefs.volcFileStandardEnabled) { com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.VolcStandardFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } else { com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.VolcFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } } @@ -623,11 +743,17 @@ class ExternalSpeechService : Service() { com.brycewg.asrkb.asr.DashscopeStreamAsrEngine(context, scope, prefs, this, externalPcmMode = true) } else { com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.DashscopeFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } // Soniox:依据设置走流式或非流式 @@ -635,11 +761,17 @@ class ExternalSpeechService : Service() { com.brycewg.asrkb.asr.SonioxStreamAsrEngine(context, scope, prefs, this, externalPcmMode = true) } else { com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.SonioxFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } // 其他云厂商:仅非流式(若供应商另行支持流式则走对应分支) @@ -647,40 +779,70 @@ class ExternalSpeechService : Service() { com.brycewg.asrkb.asr.ElevenLabsStreamAsrEngine(context, scope, prefs, this, externalPcmMode = true) } else { com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.ElevenLabsFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } AsrVendor.OpenAI -> com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.OpenAiFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) AsrVendor.Gemini -> com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.GeminiFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) AsrVendor.SiliconFlow -> com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.SiliconFlowFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) AsrVendor.Zhipu -> com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.ZhipuFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) // 本地:Paraformer 固定流式 AsrVendor.Paraformer -> com.brycewg.asrkb.asr.ParaformerStreamAsrEngine(context, scope, prefs, this, externalPcmMode = true) @@ -688,16 +850,25 @@ class ExternalSpeechService : Service() { AsrVendor.SenseVoice -> { if (prefs.svPseudoStreamEnabled) { com.brycewg.asrkb.asr.SenseVoicePushPcmPseudoStreamAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) } else { com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.SenseVoiceFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } } @@ -705,27 +876,42 @@ class ExternalSpeechService : Service() { AsrVendor.FunAsrNano -> { // FunASR Nano 算力开销高:不支持伪流式预览,仅保留整段离线识别 com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.FunAsrNanoFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } // TeleSpeech:支持伪流式(VAD 分片预览 + 整段离线识别) AsrVendor.Telespeech -> { if (prefs.tsPseudoStreamEnabled) { com.brycewg.asrkb.asr.TelespeechPushPcmPseudoStreamAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, ) } else { com.brycewg.asrkb.asr.GenericPushFileAsrAdapter( - context, scope, prefs, this, + context, + scope, + prefs, + this, com.brycewg.asrkb.asr.TelespeechFileAsrEngine( - context, scope, prefs, this, - onRequestDuration = ::onRequestDuration - ) + context, + scope, + prefs, + this, + onRequestDuration = ::onRequestDuration, + ), ) } } @@ -748,7 +934,11 @@ class ExternalSpeechService : Service() { Log.w(TAG, "Failed to compute audio duration on final", t) } } - val doAi = try { prefs.postProcessEnabled && prefs.hasLlmKeys() } catch (_: Throwable) { false } + val doAi = try { + prefs.postProcessEnabled && prefs.hasLlmKeys() + } catch (_: Throwable) { + false + } if (doAi) { if (!hasAsrPartial && text.isNotEmpty()) { hasAsrPartial = true @@ -757,7 +947,11 @@ class ExternalSpeechService : Service() { // 执行带 AI 的完整后处理链(IO 在线程内切换) CoroutineScope(Dispatchers.Main).launch { if (canceled) return@launch - val typewriterEnabled = try { prefs.postprocTypewriterEnabled } catch (_: Throwable) { true } + val typewriterEnabled = try { + prefs.postprocTypewriterEnabled + } catch (_: Throwable) { + true + } var postprocCommitted = false var lastPostprocTarget: String? = null val typewriter = if (typewriterEnabled) { @@ -774,7 +968,7 @@ class ExternalSpeechService : Service() { normalTargetFrames = 18, normalMaxStep = 6, rushTargetFrames = 8, - rushMaxStep = 24 + rushMaxStep = 24, ) } else { null @@ -795,103 +989,111 @@ class ExternalSpeechService : Service() { var aiPostMs = 0L var aiPostStatus = com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.NONE val out = try { - val res = com.brycewg.asrkb.util.AsrFinalFilters.applyWithAi( - context, - prefs, - text, - onStreamingUpdate = onStreamingUpdate - ) - aiUsed = (res.usedAi && res.ok) - aiPostMs = if (res.attempted) res.llmMs else 0L - aiPostStatus = when { - res.attempted && aiUsed -> com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.SUCCESS - res.attempted -> com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.FAILED - else -> com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.NONE - } - - val processed = res.text - val finalOut = processed.ifBlank { - // AI 返回空:回退到简单后处理(包含正则/繁体) - try { - com.brycewg.asrkb.util.AsrFinalFilters.applySimple( + val res = com.brycewg.asrkb.util.AsrFinalFilters.applyWithAi( context, prefs, - text - ) - } catch (_: Throwable) { - text + text, + onStreamingUpdate = onStreamingUpdate, + ) + aiUsed = (res.usedAi && res.ok) + aiPostMs = if (res.attempted) res.llmMs else 0L + aiPostStatus = when { + res.attempted && aiUsed -> com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.SUCCESS + res.attempted -> com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.FAILED + else -> com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.NONE } - } - if (typewriter != null && aiUsed && finalOut.isNotEmpty()) { - typewriter.submit(finalOut, rush = true) - val finalLen = finalOut.length - val t0 = SystemClock.uptimeMillis() - while (!canceled && (SystemClock.uptimeMillis() - t0) < 2_000L && - typewriter.currentText().length != finalLen - ) { - delay(20) + + val processed = res.text + val finalOut = processed.ifBlank { + // AI 返回空:回退到简单后处理(包含正则/繁体) + try { + com.brycewg.asrkb.util.AsrFinalFilters.applySimple( + context, + prefs, + text, + ) + } catch (_: Throwable) { + text + } + } + if (typewriter != null && aiUsed && finalOut.isNotEmpty()) { + typewriter.submit(finalOut, rush = true) + val finalLen = finalOut.length + val t0 = SystemClock.uptimeMillis() + while (!canceled && (SystemClock.uptimeMillis() - t0) < 2_000L && + typewriter.currentText().length != finalLen + ) { + delay(20) + } } - } - finalOut + finalOut } catch (t: Throwable) { - Log.w(TAG, "applyWithAi failed, fallback to simple", t) - aiUsed = false - aiPostMs = 0L - aiPostStatus = com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.FAILED - try { - com.brycewg.asrkb.util.AsrFinalFilters.applySimple(context, prefs, text) - } catch (_: Throwable) { - text - } + Log.w(TAG, "applyWithAi failed, fallback to simple", t) + aiUsed = false + aiPostMs = 0L + aiPostStatus = com.brycewg.asrkb.store.AsrHistoryStore.AiPostStatus.FAILED + try { + com.brycewg.asrkb.util.AsrFinalFilters.applySimple(context, prefs, text) + } catch (_: Throwable) { + text + } } finally { - postprocCommitted = true - typewriter?.cancel() + postprocCommitted = true + typewriter?.cancel() } if (canceled) return@launch - // 记录使用统计与识别历史(来源标记为 external;尊重开关) - try { - val audioMs = lastAudioMsForStats - val totalElapsedMs = popTotalElapsedMsForStats() - val procMs = computeProcMsForStats() - val chars = try { com.brycewg.asrkb.util.TextSanitizer.countEffectiveChars(out) } catch (_: Throwable) { out.length } - val vendorForRecord = resolveFinalVendorForRecord() - AnalyticsManager.recordAsrEvent( - context = context, - vendorId = vendorForRecord.id, - audioMs = audioMs, - procMs = procMs, - source = "external", - aiProcessed = aiUsed, - charCount = chars - ) + // 记录使用统计与识别历史(来源标记为 external;尊重开关) + try { + val audioMs = lastAudioMsForStats + val totalElapsedMs = popTotalElapsedMsForStats() + val procMs = computeProcMsForStats() + val chars = try { + com.brycewg.asrkb.util.TextSanitizer.countEffectiveChars(out) + } catch (_: Throwable) { + out.length + } + val vendorForRecord = resolveFinalVendorForRecord() + AnalyticsManager.recordAsrEvent( + context = context, + vendorId = vendorForRecord.id, + audioMs = audioMs, + procMs = procMs, + source = "external", + aiProcessed = aiUsed, + charCount = chars, + ) if (!prefs.disableUsageStats) { prefs.recordUsageCommit("external", vendorForRecord, audioMs, chars, procMs) } if (!prefs.disableAsrHistory) { val store = com.brycewg.asrkb.store.AsrHistoryStore(context) - store.add( - com.brycewg.asrkb.store.AsrHistoryStore.AsrHistoryRecord( - timestamp = System.currentTimeMillis(), - text = out, - vendorId = vendorForRecord.id, - audioMs = audioMs, - totalElapsedMs = totalElapsedMs, - procMs = procMs, - source = "external", - aiProcessed = aiUsed, - aiPostMs = aiPostMs, - aiPostStatus = aiPostStatus, - charCount = chars - ) - ) - } + store.add( + com.brycewg.asrkb.store.AsrHistoryStore.AsrHistoryRecord( + timestamp = System.currentTimeMillis(), + text = out, + vendorId = vendorForRecord.id, + audioMs = audioMs, + totalElapsedMs = totalElapsedMs, + procMs = procMs, + source = "external", + aiProcessed = aiUsed, + aiPostMs = aiPostMs, + aiPostStatus = aiPostStatus, + charCount = chars, + ), + ) + } } catch (e: Exception) { Log.e(TAG, "Failed to add ASR history (external, ai)", e) } if (canceled) return@launch safe { cb.onFinal(id, out) } safe { cb.onState(id, STATE_IDLE, "final") } - try { (context as? ExternalSpeechService)?.onSessionDone(id) } catch (t: Throwable) { Log.w(TAG, "remove session on final failed", t) } + try { + (context as? ExternalSpeechService)?.onSessionDone(id) + } catch (t: Throwable) { + Log.w(TAG, "remove session on final failed", t) + } } } else { if (canceled) return @@ -902,13 +1104,17 @@ class ExternalSpeechService : Service() { Log.w(TAG, "applySimple failed, fallback to raw text", t) text } - // 记录使用统计与识别历史(来源标记为 external;尊重开关) - try { - val audioMs = lastAudioMsForStats - val totalElapsedMs = popTotalElapsedMsForStats() - val procMs = computeProcMsForStats() - val chars = try { com.brycewg.asrkb.util.TextSanitizer.countEffectiveChars(out) } catch (_: Throwable) { out.length } - val vendorForRecord = resolveFinalVendorForRecord() + // 记录使用统计与识别历史(来源标记为 external;尊重开关) + try { + val audioMs = lastAudioMsForStats + val totalElapsedMs = popTotalElapsedMsForStats() + val procMs = computeProcMsForStats() + val chars = try { + com.brycewg.asrkb.util.TextSanitizer.countEffectiveChars(out) + } catch (_: Throwable) { + out.length + } + val vendorForRecord = resolveFinalVendorForRecord() AnalyticsManager.recordAsrEvent( context = context, vendorId = vendorForRecord.id, @@ -916,33 +1122,37 @@ class ExternalSpeechService : Service() { procMs = procMs, source = "external", aiProcessed = false, - charCount = chars + charCount = chars, ) if (!prefs.disableUsageStats) { prefs.recordUsageCommit("external", vendorForRecord, audioMs, chars, procMs) } if (!prefs.disableAsrHistory) { val store = com.brycewg.asrkb.store.AsrHistoryStore(context) - store.add( - com.brycewg.asrkb.store.AsrHistoryStore.AsrHistoryRecord( - timestamp = System.currentTimeMillis(), - text = out, - vendorId = vendorForRecord.id, - audioMs = audioMs, - totalElapsedMs = totalElapsedMs, - procMs = procMs, - source = "external", - aiProcessed = false, - charCount = chars - ) - ) - } + store.add( + com.brycewg.asrkb.store.AsrHistoryStore.AsrHistoryRecord( + timestamp = System.currentTimeMillis(), + text = out, + vendorId = vendorForRecord.id, + audioMs = audioMs, + totalElapsedMs = totalElapsedMs, + procMs = procMs, + source = "external", + aiProcessed = false, + charCount = chars, + ), + ) + } } catch (e: Exception) { Log.e(TAG, "Failed to add ASR history (external, simple)", e) } safe { cb.onFinal(id, out) } safe { cb.onState(id, STATE_IDLE, "final") } - try { (context as? ExternalSpeechService)?.onSessionDone(id) } catch (t: Throwable) { Log.w(TAG, "remove session on final failed", t) } + try { + (context as? ExternalSpeechService)?.onSessionDone(id) + } catch (t: Throwable) { + Log.w(TAG, "remove session on final failed", t) + } } } @@ -957,7 +1167,11 @@ class ExternalSpeechService : Service() { cb.onError(id, 500, message) cb.onState(id, STATE_ERROR, message) } - try { (context as? ExternalSpeechService)?.onSessionDone(id) } catch (t: Throwable) { Log.w(TAG, "remove session on error failed", t) } + try { + (context as? ExternalSpeechService)?.onSessionDone(id) + } catch (t: Throwable) { + Log.w(TAG, "remove session on error failed", t) + } } override fun onPartial(text: String) { @@ -997,7 +1211,7 @@ class ExternalSpeechService : Service() { private class CallbackProxy(private val remote: IBinder?) { fun onState(sessionId: Int, state: Int, msg: String) { - transact(CB_onState) { data -> + transact(CB_ON_STATE) { data -> data.writeInterfaceToken(DESCRIPTOR_CB) data.writeInt(sessionId) data.writeInt(state) @@ -1005,21 +1219,21 @@ class ExternalSpeechService : Service() { } } fun onPartial(sessionId: Int, text: String) { - transact(CB_onPartial) { data -> + transact(CB_ON_PARTIAL) { data -> data.writeInterfaceToken(DESCRIPTOR_CB) data.writeInt(sessionId) data.writeString(text) } } fun onFinal(sessionId: Int, text: String) { - transact(CB_onFinal) { data -> + transact(CB_ON_FINAL) { data -> data.writeInterfaceToken(DESCRIPTOR_CB) data.writeInt(sessionId) data.writeString(text) } } fun onError(sessionId: Int, code: Int, message: String) { - transact(CB_onError) { data -> + transact(CB_ON_ERROR) { data -> data.writeInterfaceToken(DESCRIPTOR_CB) data.writeInt(sessionId) data.writeInt(code) @@ -1027,7 +1241,7 @@ class ExternalSpeechService : Service() { } } fun onAmplitude(sessionId: Int, amp: Float) { - transact(CB_onAmplitude) { data -> + transact(CB_ON_AMPLITUDE) { data -> data.writeInterfaceToken(DESCRIPTOR_CB) data.writeInt(sessionId) data.writeFloat(amp) @@ -1064,22 +1278,22 @@ class ExternalSpeechService : Service() { // 与 AIDL 生成的 Stub 保持一致的描述符与事务号 private const val DESCRIPTOR_SVC = "com.brycewg.asrkb.aidl.IExternalSpeechService" - private const val TRANSACTION_startSession = IBinder.FIRST_CALL_TRANSACTION + 0 - private const val TRANSACTION_stopSession = IBinder.FIRST_CALL_TRANSACTION + 1 - private const val TRANSACTION_cancelSession = IBinder.FIRST_CALL_TRANSACTION + 2 - private const val TRANSACTION_isRecording = IBinder.FIRST_CALL_TRANSACTION + 3 - private const val TRANSACTION_isAnyRecording = IBinder.FIRST_CALL_TRANSACTION + 4 - private const val TRANSACTION_getVersion = IBinder.FIRST_CALL_TRANSACTION + 5 - private const val TRANSACTION_startPcmSession = IBinder.FIRST_CALL_TRANSACTION + 6 - private const val TRANSACTION_writePcm = IBinder.FIRST_CALL_TRANSACTION + 7 - private const val TRANSACTION_finishPcm = IBinder.FIRST_CALL_TRANSACTION + 8 + private const val TRANSACTION_START_SESSION = IBinder.FIRST_CALL_TRANSACTION + 0 + private const val TRANSACTION_STOP_SESSION = IBinder.FIRST_CALL_TRANSACTION + 1 + private const val TRANSACTION_CANCEL_SESSION = IBinder.FIRST_CALL_TRANSACTION + 2 + private const val TRANSACTION_IS_RECORDING = IBinder.FIRST_CALL_TRANSACTION + 3 + private const val TRANSACTION_IS_ANY_RECORDING = IBinder.FIRST_CALL_TRANSACTION + 4 + private const val TRANSACTION_GET_VERSION = IBinder.FIRST_CALL_TRANSACTION + 5 + private const val TRANSACTION_START_PCM_SESSION = IBinder.FIRST_CALL_TRANSACTION + 6 + private const val TRANSACTION_WRITE_PCM = IBinder.FIRST_CALL_TRANSACTION + 7 + private const val TRANSACTION_FINISH_PCM = IBinder.FIRST_CALL_TRANSACTION + 8 private const val DESCRIPTOR_CB = "com.brycewg.asrkb.aidl.ISpeechCallback" - private const val CB_onState = IBinder.FIRST_CALL_TRANSACTION + 0 - private const val CB_onPartial = IBinder.FIRST_CALL_TRANSACTION + 1 - private const val CB_onFinal = IBinder.FIRST_CALL_TRANSACTION + 2 - private const val CB_onError = IBinder.FIRST_CALL_TRANSACTION + 3 - private const val CB_onAmplitude = IBinder.FIRST_CALL_TRANSACTION + 4 + private const val CB_ON_STATE = IBinder.FIRST_CALL_TRANSACTION + 0 + private const val CB_ON_PARTIAL = IBinder.FIRST_CALL_TRANSACTION + 1 + private const val CB_ON_FINAL = IBinder.FIRST_CALL_TRANSACTION + 2 + private const val CB_ON_ERROR = IBinder.FIRST_CALL_TRANSACTION + 3 + private const val CB_ON_AMPLITUDE = IBinder.FIRST_CALL_TRANSACTION + 4 private const val STATE_IDLE = 0 private const val STATE_RECORDING = 1 @@ -1089,7 +1303,13 @@ class ExternalSpeechService : Service() { private const val LOCAL_MODEL_READY_WAIT_MAX_MS = 60_000L private const val LOCAL_MODEL_READY_WAIT_CONSUMED = -1L - private inline fun safe(block: () -> Unit) { try { block() } catch (t: Throwable) { Log.w(TAG, "callback failed", t) } } + private inline fun safe(block: () -> Unit) { + try { + block() + } catch (t: Throwable) { + Log.w(TAG, "callback failed", t) + } + } } // 统一的会话清理入口:在 onFinal/onError 触发后移除,避免内存泄漏 diff --git a/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorFormatter.kt b/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorFormatter.kt index 965c8b9e..e86bba8b 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorFormatter.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorFormatter.kt @@ -1,10 +1,10 @@ package com.brycewg.asrkb.asr internal fun formatHttpDetail(message: String?, extra: String? = null): String { - val primary = message?.trim().orEmpty() - val extraPart = extra?.trim().orEmpty() - return buildString { - if (primary.isNotEmpty()) append(": ").append(primary) - if (extraPart.isNotEmpty()) append(" — ").append(extraPart) - } + val primary = message?.trim().orEmpty() + val extraPart = extra?.trim().orEmpty() + return buildString { + if (primary.isNotEmpty()) append(": ").append(primary) + if (extraPart.isNotEmpty()) append(" — ").append(extraPart) + } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorMessageMapper.kt b/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorMessageMapper.kt index da5a6b00..c779d2d9 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorMessageMapper.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/AsrErrorMessageMapper.kt @@ -92,4 +92,3 @@ internal object AsrErrorMessageMapper { } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/asr/AsrVendor.kt b/app/src/main/java/com/brycewg/asrkb/asr/AsrVendor.kt index 01066da0..f88d6c52 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/AsrVendor.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/AsrVendor.kt @@ -12,7 +12,8 @@ enum class AsrVendor(val id: String) { SenseVoice("sensevoice"), FunAsrNano("funasr_nano"), Telespeech("telespeech"), - Paraformer("paraformer"); + Paraformer("paraformer"), + ; companion object { fun fromId(id: String?): AsrVendor = when (id?.lowercase()) { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/AsrVendorAvailability.kt b/app/src/main/java/com/brycewg/asrkb/asr/AsrVendorAvailability.kt index 1866b5cb..28a3d150 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/AsrVendorAvailability.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/AsrVendorAvailability.kt @@ -12,13 +12,13 @@ import java.io.File internal data class AsrVendorPartition( val configured: List, - val unconfigured: List + val unconfigured: List, ) internal fun partitionAsrVendorsByConfigured( context: Context, prefs: Prefs, - vendors: List + vendors: List, ): AsrVendorPartition { val configured = mutableListOf() val unconfigured = mutableListOf() @@ -31,14 +31,14 @@ internal fun partitionAsrVendorsByConfigured( } return AsrVendorPartition( configured = configured, - unconfigured = unconfigured + unconfigured = unconfigured, ) } internal fun isAsrVendorConfigured( context: Context, prefs: Prefs, - vendor: AsrVendor + vendor: AsrVendor, ): Boolean { return try { when (vendor) { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/AudioCaptureManager.kt b/app/src/main/java/com/brycewg/asrkb/asr/AudioCaptureManager.kt index 9e97d594..ad836df3 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/AudioCaptureManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/AudioCaptureManager.kt @@ -15,12 +15,12 @@ import android.util.Log import androidx.annotation.RequiresApi import androidx.core.content.ContextCompat import com.brycewg.asrkb.store.Prefs +import com.brycewg.asrkb.store.debug.DebugLogManager import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow import kotlinx.coroutines.suspendCancellableCoroutine import kotlinx.coroutines.withTimeoutOrNull import kotlin.coroutines.resume -import com.brycewg.asrkb.store.debug.DebugLogManager /** * 音频采集管理器 @@ -43,7 +43,7 @@ class AudioCaptureManager( private val sampleRate: Int = 16000, private val channelConfig: Int = AudioFormat.CHANNEL_IN_MONO, private val audioFormat: Int = AudioFormat.ENCODING_PCM_16BIT, - private val chunkMillis: Int = 200 + private val chunkMillis: Int = 200, ) { private val bytesPerSample = 2 // 16bit mono PCM private val prefs by lazy { Prefs(context) } @@ -68,7 +68,7 @@ class AudioCaptureManager( fun hasPermission(): Boolean { return ContextCompat.checkSelfPermission( context, - Manifest.permission.RECORD_AUDIO + Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED } @@ -80,7 +80,7 @@ class AudioCaptureManager( sampleRate = sampleRate, channelConfig = channelConfig, audioEncoding = audioFormat, - bufferSize = bufferSize + bufferSize = bufferSize, ) } else { AudioRecord( @@ -88,7 +88,7 @@ class AudioCaptureManager( sampleRate, channelConfig, audioFormat, - bufferSize + bufferSize, ) } } @@ -101,7 +101,7 @@ class AudioCaptureManager( sampleRate: Int, channelConfig: Int, audioEncoding: Int, - bufferSize: Int + bufferSize: Int, ): AudioRecord { val format = AudioFormat.Builder() .setSampleRate(sampleRate) @@ -137,15 +137,17 @@ class AudioCaptureManager( event = "acm_start", data = mapOf( "sr" to sampleRate, - "chunkMs" to chunkMillis - ) + "chunkMs" to chunkMillis, + ), ) } catch (_: Throwable) { } // 1. 权限检查 if (!hasPermission()) { val error = SecurityException("Missing RECORD_AUDIO permission") Log.e(TAG, "Permission check failed", error) - try { DebugLogManager.log("audio", "acm_error", mapOf("stage" to "perm", "msg" to error.message)) } catch (_: Throwable) { } + try { + DebugLogManager.log("audio", "acm_error", mapOf("stage" to "perm", "msg" to error.message)) + } catch (_: Throwable) { } throw error } @@ -163,7 +165,11 @@ class AudioCaptureManager( if (prefs.headsetMicPriorityEnabled) { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) { // 若已有通信设备(可能由预热设置),则不重复设置,也不在 finally 清理 - val cur = try { audioManager.getCommunicationDevice() } catch (_: Throwable) { null } + val cur = try { + audioManager.getCommunicationDevice() + } catch (_: Throwable) { + null + } if (cur != null && (cur.type == AudioDeviceInfo.TYPE_BLE_HEADSET || cur.type == AudioDeviceInfo.TYPE_BLUETOOTH_SCO || cur.type == AudioDeviceInfo.TYPE_WIRED_HEADSET)) { preferredInputDevice = cur routePrepared = true @@ -195,7 +201,11 @@ class AudioCaptureManager( if (preferredInputDevice?.type == AudioDeviceInfo.TYPE_BLUETOOTH_SCO) { // 通话模式可改善部分设备的路由与增益(仅在当前非通话模式时切换) - val curMode = try { audioManager.mode } catch (_: Throwable) { AudioManager.MODE_NORMAL } + val curMode = try { + audioManager.mode + } catch (_: Throwable) { + AudioManager.MODE_NORMAL + } if (curMode != AudioManager.MODE_IN_COMMUNICATION) { previousAudioMode = curMode try { @@ -273,14 +283,20 @@ class AudioCaptureManager( try { activeRecorder.startRecording() Log.d(TAG, "AudioRecord started successfully") - try { DebugLogManager.log("audio", "recorder_started") } catch (_: Throwable) { } + try { + DebugLogManager.log("audio", "recorder_started") + } catch (_: Throwable) { } } catch (se: SecurityException) { Log.e(TAG, "SecurityException during startRecording", se) - try { DebugLogManager.log("audio", "acm_error", mapOf("stage" to "start", "type" to "security", "msg" to se.message)) } catch (_: Throwable) { } + try { + DebugLogManager.log("audio", "acm_error", mapOf("stage" to "start", "type" to "security", "msg" to se.message)) + } catch (_: Throwable) { } throw se } catch (t: Throwable) { Log.e(TAG, "Failed to start recording", t) - try { DebugLogManager.log("audio", "acm_error", mapOf("stage" to "start", "type" to "start_fail", "msg" to t.message)) } catch (_: Throwable) { } + try { + DebugLogManager.log("audio", "acm_error", mapOf("stage" to "start", "type" to "start_fail", "msg" to t.message)) + } catch (_: Throwable) { } throw IllegalStateException("Failed to start recording", t) } @@ -301,7 +317,9 @@ class AudioCaptureManager( activeRecorder.read(buf, 0, buf.size) } catch (t: Throwable) { Log.e(TAG, "Error reading audio data", t) - try { DebugLogManager.log("audio", "acm_error", mapOf("stage" to "read", "msg" to t.message)) } catch (_: Throwable) { } + try { + DebugLogManager.log("audio", "acm_error", mapOf("stage" to "read", "msg" to t.message)) + } catch (_: Throwable) { } throw IllegalStateException("Error reading audio data", t) } @@ -332,7 +350,7 @@ class AudioCaptureManager( try { DebugLogManager.log( category = "audio", - event = "acm_cleanup" + event = "acm_cleanup", ) } catch (_: Throwable) { } // 清理通信设备与 SCO / 恢复模式 @@ -367,7 +385,7 @@ class AudioCaptureManager( current: AudioRecord, buf: ByteArray, bufferSize: Int, - avoidMicFallback: Boolean + avoidMicFallback: Boolean, ): Pair { if (!hasPermission()) { val error = SecurityException("RECORD_AUDIO permission was revoked during warmup") @@ -516,7 +534,11 @@ class AudioCaptureManager( var listenerToken: Any? = null var setOk = false var routeReady = false - val t0 = try { android.os.SystemClock.elapsedRealtime() } catch (_: Throwable) { 0L } + val t0 = try { + android.os.SystemClock.elapsedRealtime() + } catch (_: Throwable) { + 0L + } try { val candidates = try { audioManager.getAvailableCommunicationDevices() @@ -545,7 +567,11 @@ class AudioCaptureManager( } if (!setOk) return CommRouteResult(false, selected, null, false) - val cur = try { audioManager.getCommunicationDevice() } catch (_: Throwable) { null } + val cur = try { + audioManager.getCommunicationDevice() + } catch (_: Throwable) { + null + } if (cur != null && selected.id == cur.id) { val dt = if (t0 > 0) (android.os.SystemClock.elapsedRealtime() - t0) else -1 if (dt >= 0) Log.i(TAG, "Communication device ready immediately in ${dt}ms (id=${cur.id})") @@ -556,7 +582,11 @@ class AudioCaptureManager( routeReady = withTimeoutOrNull(2000L) { suspendCancellableCoroutine { cont -> val exec = java.util.concurrent.Executor { r -> - try { r.run() } catch (t: Throwable) { Log.w(TAG, "CommDevice listener runnable error", t) } + try { + r.run() + } catch (t: Throwable) { + Log.w(TAG, "CommDevice listener runnable error", t) + } } val l = AudioManager.OnCommunicationDeviceChangedListener { dev -> selected?.let { sel -> @@ -575,7 +605,9 @@ class AudioCaptureManager( if (cont.isActive) cont.resume(false) } cont.invokeOnCancellation { - try { audioManager.removeOnCommunicationDeviceChangedListener(l) } catch (_: Throwable) {} + try { + audioManager.removeOnCommunicationDeviceChangedListener(l) + } catch (_: Throwable) {} } } } ?: false @@ -628,7 +660,11 @@ class AudioCaptureManager( val filter = IntentFilter(AudioManager.ACTION_SCO_AUDIO_STATE_UPDATED) var receiver: android.content.BroadcastReceiver? = null - val t0 = try { android.os.SystemClock.elapsedRealtime() } catch (_: Throwable) { 0L } + val t0 = try { + android.os.SystemClock.elapsedRealtime() + } catch (_: Throwable) { + 0L + } val ok = withTimeoutOrNull(2500L) { suspendCancellableCoroutine { cont -> receiver = object : android.content.BroadcastReceiver() { @@ -636,7 +672,7 @@ class AudioCaptureManager( if (intent?.action != AudioManager.ACTION_SCO_AUDIO_STATE_UPDATED) return val state = intent.getIntExtra( AudioManager.EXTRA_SCO_AUDIO_STATE, - AudioManager.SCO_AUDIO_STATE_ERROR + AudioManager.SCO_AUDIO_STATE_ERROR, ) when (state) { AudioManager.SCO_AUDIO_STATE_CONNECTED -> { @@ -648,7 +684,11 @@ class AudioCaptureManager( } } } - try { context.registerReceiver(receiver, filter) } catch (t: Throwable) { Log.w(TAG, "registerReceiver failed", t) } + try { + context.registerReceiver(receiver, filter) + } catch (t: Throwable) { + Log.w(TAG, "registerReceiver failed", t) + } try { if (!am.isBluetoothScoOn) am.startBluetoothSco() } catch (t: Throwable) { @@ -656,11 +696,15 @@ class AudioCaptureManager( if (cont.isActive) cont.resume(false) } cont.invokeOnCancellation { - try { receiver?.let { context.unregisterReceiver(it) } } catch (_: Throwable) {} + try { + receiver?.let { context.unregisterReceiver(it) } + } catch (_: Throwable) {} } } } ?: false - try { receiver?.let { context.unregisterReceiver(it) } } catch (_: Throwable) {} + try { + receiver?.let { context.unregisterReceiver(it) } + } catch (_: Throwable) {} ok } catch (t: Throwable) { Log.w(TAG, "startScoAndAwaitConnected exception", t) @@ -693,6 +737,6 @@ class AudioCaptureManager( val commDeviceSet: Boolean, val selectedDevice: AudioDeviceInfo?, val listenerToken: Any?, - val routeReady: Boolean + val routeReady: Boolean, ) } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/AudioUtils.kt b/app/src/main/java/com/brycewg/asrkb/asr/AudioUtils.kt index 8087e737..aca9a2e3 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/AudioUtils.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/AudioUtils.kt @@ -26,7 +26,7 @@ data class FrameStats( val maxAbs: Int, val sumSquares: Long, val countAboveThreshold: Int, - val sampleCount: Int + val sampleCount: Int, ) /** diff --git a/app/src/main/java/com/brycewg/asrkb/asr/BaseFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/BaseFileAsrEngine.kt index 8dbb8dfe..bdfef82e 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/BaseFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/BaseFileAsrEngine.kt @@ -5,17 +5,17 @@ import android.media.AudioFormat import android.util.Log import com.brycewg.asrkb.R import com.brycewg.asrkb.store.Prefs +import com.brycewg.asrkb.store.debug.DebugLogManager import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job -import kotlinx.coroutines.launch import kotlinx.coroutines.channels.BufferOverflow import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.launch import java.io.ByteArrayOutputStream import java.nio.ByteBuffer import java.nio.ByteOrder import java.util.concurrent.atomic.AtomicBoolean -import com.brycewg.asrkb.store.debug.DebugLogManager /** * 基础的文件识别 ASR 引擎,封装了麦克风采集、静音判停等通用逻辑, @@ -26,7 +26,7 @@ abstract class BaseFileAsrEngine( private val scope: CoroutineScope, protected val prefs: Prefs, protected val listener: StreamingAsrEngine.Listener, - protected val onRequestDuration: ((Long) -> Unit)? = null + protected val onRequestDuration: ((Long) -> Unit)? = null, ) : StreamingAsrEngine { companion object { @@ -34,18 +34,22 @@ abstract class BaseFileAsrEngine( } private val running = AtomicBoolean(false) + @Volatile private var stopRequested: Boolean = false + @Volatile private var stoppedDelivered: Boolean = false private var audioJob: Job? = null private var processingJob: Job? = null private var segmentChan: Channel? = null private var lastSegmentForRetry: ByteArray? = null + @Volatile private var discardOnStop: Boolean = false protected open val sampleRate: Int = 16000 protected open val channelConfig: Int = AudioFormat.CHANNEL_IN_MONO protected open val audioFormat: Int = AudioFormat.ENCODING_PCM_16BIT protected open val chunkMillis: Int = 200 + // 非流式录音的最大时长(子类按供应商覆盖)。 // 达到该时长会立即结束录音并触发一次识别请求,以避免超过服务商限制。 protected open val maxRecordDurationMillis: Int = 30 * 60 * 1000 // 默认 30 分钟 @@ -67,7 +71,7 @@ abstract class BaseFileAsrEngine( // 使用有界队列并在溢出时丢弃最旧的数据,避免内存溢出 val chan = Channel( capacity = 10, - onBufferOverflow = BufferOverflow.DROP_OLDEST + onBufferOverflow = BufferOverflow.DROP_OLDEST, ) segmentChan = chan // 顺序消费识别请求,确保结果按段落顺序提交 @@ -84,14 +88,14 @@ abstract class BaseFileAsrEngine( context = context, prefs = prefs, pcm = seg, - sampleRate = sampleRate + sampleRate = sampleRate, ) recognize(denoised) } catch (t: Throwable) { Log.e(TAG, "Recognition failed for segment", t) try { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } catch (e: Throwable) { Log.e(TAG, "Failed to notify recognition error", e) @@ -108,7 +112,9 @@ abstract class BaseFileAsrEngine( recordAndEnqueueSegments(chan) } finally { running.set(false) - try { DebugLogManager.log("asr", "engine_run_end", mapOf("reason" to "audio_job_end")) } catch (_: Throwable) { } + try { + DebugLogManager.log("asr", "engine_run_end", mapOf("reason" to "audio_job_end")) + } catch (_: Throwable) { } // 若录音流意外结束且未显式通知 onStopped,则补发一次,确保上层释放音频焦点与路由。 if (!stoppedDelivered) { try { @@ -174,7 +180,7 @@ abstract class BaseFileAsrEngine( sampleRate = sampleRate, channelConfig = channelConfig, audioFormat = audioFormat, - chunkMillis = chunkMillis + chunkMillis = chunkMillis, ) // 权限检查 @@ -194,9 +200,11 @@ abstract class BaseFileAsrEngine( context, sampleRate, prefs.autoStopSilenceWindowMs, - prefs.autoStopSilenceSensitivity + prefs.autoStopSilenceSensitivity, ) - } else null + } else { + null + } // 计算分段阈值 val maxBytes = (maxRecordDurationMillis / 1000.0 * sampleRate * bytesPerSample).toInt() @@ -236,7 +244,9 @@ abstract class BaseFileAsrEngine( val ok = chan.trySend(head).isSuccess if (ok) { pendingList.removeFirst() - } else break + } else { + break + } } // 尝试直接投递最后一段;不成则加入待发送 val ok2 = chan.trySend(last).isSuccess @@ -263,7 +273,9 @@ abstract class BaseFileAsrEngine( if (r.isSuccess) { pendingList.removeFirst() Log.d(TAG, "Pending segment sent (${head.size} bytes)") - } else break + } else { + break + } } // 达到上限:切出一个片段,不打断录音 @@ -278,7 +290,9 @@ abstract class BaseFileAsrEngine( val ok = chan.trySend(head).isSuccess if (ok) { pendingList.removeFirst() - } else break + } else { + break + } } // 再投递当前片段;不成则加入待发送队列 @@ -426,14 +440,14 @@ abstract class BaseFileAsrEngine( context = context, prefs = prefs, pcm = data, - sampleRate = sampleRate + sampleRate = sampleRate, ) recognize(denoised) } catch (t: Throwable) { Log.e(TAG, "retryLastSegment recognize failed", t) try { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } catch (e: Throwable) { Log.e(TAG, "Failed to notify recognition error (retry)", e) diff --git a/app/src/main/java/com/brycewg/asrkb/asr/BluetoothRouteManager.kt b/app/src/main/java/com/brycewg/asrkb/asr/BluetoothRouteManager.kt index 78f86e75..b99eee0b 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/BluetoothRouteManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/BluetoothRouteManager.kt @@ -1,13 +1,11 @@ package com.brycewg.asrkb.asr import android.content.Context -import android.content.Intent -import android.content.IntentFilter import android.media.AudioDeviceInfo import android.media.AudioManager import android.os.Build -import android.util.Log import android.os.SystemClock +import android.util.Log import androidx.annotation.RequiresApi import com.brycewg.asrkb.store.Prefs import kotlinx.coroutines.CoroutineScope @@ -33,16 +31,24 @@ object BluetoothRouteManager { // 状态 @Volatile private var imeSceneActive: Boolean = false + @Volatile private var recordingCount: Int = 0 // 当前路由句柄 @Volatile private var commDeviceSet: Boolean = false + @Volatile private var commListener: Any? = null + @Volatile private var selectedDevice: AudioDeviceInfo? = null + @Volatile private var scoStarted: Boolean = false + @Volatile private var audioModeChanged: Boolean = false + @Volatile private var previousAudioMode: Int? = null + @Volatile private var pendingConnectStartAtMs: Long = 0L + @Volatile private var lastAutoReconnectAtMs: Long = 0L private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) @@ -73,9 +79,17 @@ object BluetoothRouteManager { } fun cleanup() { - try { scope.cancel() } catch (t: Throwable) { Log.w(TAG, "cleanup cancel scope", t) } + try { + scope.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cleanup cancel scope", t) + } // 尝试断开路由 - try { disconnectInternal("cleanup") } catch (t: Throwable) { Log.w(TAG, "cleanup disconnect", t) } + try { + disconnectInternal("cleanup") + } catch (t: Throwable) { + Log.w(TAG, "cleanup disconnect", t) + } } private fun shouldRoute(): Boolean { @@ -123,18 +137,30 @@ object BluetoothRouteManager { if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) { if (commDeviceSet) { - try { clearCommunicationDeviceSafely(am, commListener) } catch (t: Throwable) { Log.w(TAG, "clearCommunicationDeviceSafely", t) } + try { + clearCommunicationDeviceSafely(am, commListener) + } catch (t: Throwable) { + Log.w(TAG, "clearCommunicationDeviceSafely", t) + } commDeviceSet = false selectedDevice = null commListener = null } } if (scoStarted) { - try { am.stopBluetoothSco() } catch (t: Throwable) { Log.w(TAG, "stopBluetoothSco", t) } + try { + am.stopBluetoothSco() + } catch (t: Throwable) { + Log.w(TAG, "stopBluetoothSco", t) + } scoStarted = false } if (audioModeChanged) { - try { if (previousAudioMode != null) am.mode = previousAudioMode!! } catch (t: Throwable) { Log.w(TAG, "restore audio mode", t) } + try { + if (previousAudioMode != null) am.mode = previousAudioMode!! + } catch (t: Throwable) { + Log.w(TAG, "restore audio mode", t) + } previousAudioMode = null audioModeChanged = false } @@ -159,7 +185,9 @@ object BluetoothRouteManager { if (target == null) return pendingConnectStartAtMs = SystemClock.elapsedRealtime() - val ok = try { am.setCommunicationDevice(target) } catch (t: Throwable) { + val ok = try { + am.setCommunicationDevice(target) + } catch (t: Throwable) { Log.w(TAG, "setCommunicationDevice failed", t) false } @@ -168,7 +196,13 @@ object BluetoothRouteManager { selectedDevice = target commDeviceSet = true // 轻量监听:遇到系统切走时可感知(不强制等待) - val exec = java.util.concurrent.Executor { r -> try { r.run() } catch (t: Throwable) { Log.w(TAG, "CommDeviceChanged runnable error", t) } } + val exec = java.util.concurrent.Executor { r -> + try { + r.run() + } catch (t: Throwable) { + Log.w(TAG, "CommDeviceChanged runnable error", t) + } + } val l = AudioManager.OnCommunicationDeviceChangedListener { dev -> val sel = selectedDevice if (dev == null) return@OnCommunicationDeviceChangedListener @@ -186,13 +220,19 @@ object BluetoothRouteManager { lastAutoReconnectAtMs = now try { // 重新选择目标并设置 - val devices = try { am.getAvailableCommunicationDevices() } catch (_: Throwable) { emptyList() } + val devices = try { + am.getAvailableCommunicationDevices() + } catch (_: Throwable) { + emptyList() + } val desired = devices.firstOrNull { it.type == AudioDeviceInfo.TYPE_BLE_HEADSET } ?: devices.firstOrNull { it.type == AudioDeviceInfo.TYPE_BLUETOOTH_SCO } ?: devices.firstOrNull { it.type == AudioDeviceInfo.TYPE_WIRED_HEADSET } if (desired != null && desired.id != dev.id) { pendingConnectStartAtMs = SystemClock.elapsedRealtime() - val ok2 = try { am.setCommunicationDevice(desired) } catch (t: Throwable) { + val ok2 = try { + am.setCommunicationDevice(desired) + } catch (t: Throwable) { Log.w(TAG, "auto-reconnect setCommunicationDevice failed", t) false } @@ -209,7 +249,11 @@ object BluetoothRouteManager { } } commListener = l - try { am.addOnCommunicationDeviceChangedListener(exec, l) } catch (t: Throwable) { Log.w(TAG, "addOnCommunicationDeviceChangedListener", t) } + try { + am.addOnCommunicationDeviceChangedListener(exec, l) + } catch (t: Throwable) { + Log.w(TAG, "addOnCommunicationDeviceChangedListener", t) + } } @Suppress("DEPRECATION") @@ -218,7 +262,12 @@ object BluetoothRouteManager { if (!am.isBluetoothScoAvailableOffCall) return // 切到通信模式可提升部分机型的稳定性 previousAudioMode = am.mode - try { am.mode = AudioManager.MODE_IN_COMMUNICATION; audioModeChanged = true } catch (t: Throwable) { Log.w(TAG, "set MODE_IN_COMMUNICATION", t) } + try { + am.mode = AudioManager.MODE_IN_COMMUNICATION + audioModeChanged = true + } catch (t: Throwable) { + Log.w(TAG, "set MODE_IN_COMMUNICATION", t) + } try { if (!am.isBluetoothScoOn) am.startBluetoothSco() @@ -231,9 +280,17 @@ object BluetoothRouteManager { @RequiresApi(Build.VERSION_CODES.S) private fun clearCommunicationDeviceSafely(audioManager: AudioManager, listenerToken: Any?) { if (listenerToken is AudioManager.OnCommunicationDeviceChangedListener) { - try { audioManager.removeOnCommunicationDeviceChangedListener(listenerToken) } catch (t: Throwable) { Log.w(TAG, "removeOnCommunicationDeviceChangedListener", t) } + try { + audioManager.removeOnCommunicationDeviceChangedListener(listenerToken) + } catch (t: Throwable) { + Log.w(TAG, "removeOnCommunicationDeviceChangedListener", t) + } + } + try { + audioManager.clearCommunicationDevice() + } catch (t: Throwable) { + Log.w(TAG, "clearCommunicationDevice", t) } - try { audioManager.clearCommunicationDevice() } catch (t: Throwable) { Log.w(TAG, "clearCommunicationDevice", t) } } private fun requireNotNullContext(): Context { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/ChineseItn.kt b/app/src/main/java/com/brycewg/asrkb/asr/ChineseItn.kt index b1db4ac9..14953b07 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/ChineseItn.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/ChineseItn.kt @@ -5,578 +5,586 @@ package com.brycewg.asrkb.asr * 参考 CapsWriter-Offline 自研规则,支持范围、百分/分数/比值、日期与时间等场景。 */ object ChineseItn { - private val unitMapping: Map = linkedMapOf( - "千米每小时" to "km/h", - "千克" to "kg", - "千米" to "千米", - "克" to "g", - "米" to "米", - "个" to null, - "只" to null, - "分" to null, - "万" to null, - "亿" to null, - "秒" to null, - "年" to null, - "月" to null, - "日" to null, - "天" to null, - "时" to null, - "钟" to null, - "人" to null, - "层" to null, - "楼" to null, - "倍" to null, - "块" to null, - "次" to null - ) - - private val commonUnits: String = unitMapping.keys - .sortedByDescending { it.length } - .joinToString("|") { Regex.escape(it) } - - private val unitSuffixRegex = Regex("($commonUnits|[a-zA-Z]+)$") - - private val numMapper: Map = mapOf( - '零' to '0', - '一' to '1', - '幺' to '1', - '二' to '2', - '两' to '2', - '三' to '3', - '四' to '4', - '五' to '5', - '六' to '6', - '七' to '7', - '八' to '8', - '九' to '9', - '点' to '.' - ) - - private val valueMapper: Map = mapOf( - '零' to 0, - '一' to 1, - '二' to 2, - '两' to 2, - '三' to 3, - '四' to 4, - '五' to 5, - '六' to 6, - '七' to 7, - '八' to 8, - '九' to 9, - '十' to 10, - '百' to 100, - '千' to 1000, - '万' to 10000, - '亿' to 100000000 - ) - - private val idioms: Set = setOf( - "正经八百", "五零二落", "五零四散", "五十步笑百步", "乌七八糟", "污七八糟", "四百四病", "思绪万千", - "十有八九", "十之八九", "三十而立", "三十六策", "三十六计", "三十六行", "三五成群", "三百六十行", "三六九等", - "七老八十", "七零八落", "七零八碎", "七七八八", "乱七八遭", "乱七八糟", "略知一二", "零零星星", "零七八碎", - "九九归一", "二三其德", "二三其意", "无银三百两", "八九不离十", "百分之百", "年三十", "烂七八糟", - "一点一滴", "路易十六", "九三学社", "五四运动", "入木三分", "三十六计", "九九八十一", "三七二十一", - "十二五", "十三五", "十四五", "十五五", "十六五", "十七五", "十八五" - ) - - private val idiomRegex = Regex( - idioms - .sortedByDescending { it.length } - .joinToString("|") { Regex.escape(it) } - ) - - private val fuzzyRegex = Regex("几") - - private val pureNumRegex = Regex("[零幺一二两三四五六七八九]+(点[零幺一二两三四五六七八九]+)* *([a-zA-Z]|$commonUnits)?") - private val valueNumRegex = Regex("十?(零?[一二两三四五六七八九十][十百千万]{1,2})*零?十?[一二三四五六七八九]?(点[零一二三四五六七八九]+)? *([a-zA-Z]|$commonUnits)?") - private val consecutiveTensRegex = Regex("^((?:十[一二三四五六七八九])+)(?:($commonUnits))?$") - private val consecutiveHundredsRegex = Regex("^((?:[一二三四五六七八九]百零?[一二三四五六七八九])+)(?:($commonUnits))?$") - - private val percentRegex = Regex("(? lastIndex) { - sb.append(input.substring(lastIndex, start)) - } - if (idiomRanges.any { rangesOverlap(it, m.range) }) { - sb.append(m.value) - } else { - sb.append(replaceSegment(m.value)) - } - lastIndex = m.range.last + 1 - } - if (lastIndex < input.length) { - sb.append(input.substring(lastIndex)) + private val unitMapping: Map = linkedMapOf( + "千米每小时" to "km/h", + "千克" to "kg", + "千米" to "千米", + "克" to "g", + "米" to "米", + "个" to null, + "只" to null, + "分" to null, + "万" to null, + "亿" to null, + "秒" to null, + "年" to null, + "月" to null, + "日" to null, + "天" to null, + "时" to null, + "钟" to null, + "人" to null, + "层" to null, + "楼" to null, + "倍" to null, + "块" to null, + "次" to null, + ) + + private val commonUnits: String = unitMapping.keys + .sortedByDescending { it.length } + .joinToString("|") { Regex.escape(it) } + + private val unitSuffixRegex = Regex("($commonUnits|[a-zA-Z]+)$") + + private val numMapper: Map = mapOf( + '零' to '0', + '一' to '1', + '幺' to '1', + '二' to '2', + '两' to '2', + '三' to '3', + '四' to '4', + '五' to '5', + '六' to '6', + '七' to '7', + '八' to '8', + '九' to '9', + '点' to '.', + ) + + private val valueMapper: Map = mapOf( + '零' to 0, + '一' to 1, + '二' to 2, + '两' to 2, + '三' to 3, + '四' to 4, + '五' to 5, + '六' to 6, + '七' to 7, + '八' to 8, + '九' to 9, + '十' to 10, + '百' to 100, + '千' to 1000, + '万' to 10000, + '亿' to 100000000, + ) + + private val idioms: Set = setOf( + "正经八百", "五零二落", "五零四散", "五十步笑百步", "乌七八糟", "污七八糟", "四百四病", "思绪万千", + "十有八九", "十之八九", "三十而立", "三十六策", "三十六计", "三十六行", "三五成群", "三百六十行", "三六九等", + "七老八十", "七零八落", "七零八碎", "七七八八", "乱七八遭", "乱七八糟", "略知一二", "零零星星", "零七八碎", + "九九归一", "二三其德", "二三其意", "无银三百两", "八九不离十", "百分之百", "年三十", "烂七八糟", + "一点一滴", "路易十六", "九三学社", "五四运动", "入木三分", "三十六计", "九九八十一", "三七二十一", + "十二五", "十三五", "十四五", "十五五", "十六五", "十七五", "十八五", + ) + + private val idiomRegex = Regex( + idioms + .sortedByDescending { it.length } + .joinToString("|") { Regex.escape(it) }, + ) + + private val fuzzyRegex = Regex("几") + + private val pureNumRegex = Regex("[零幺一二两三四五六七八九]+(点[零幺一二两三四五六七八九]+)* *([a-zA-Z]|$commonUnits)?") + private val valueNumRegex = Regex("十?(零?[一二两三四五六七八九十][十百千万]{1,2})*零?十?[一二三四五六七八九]?(点[零一二三四五六七八九]+)? *([a-zA-Z]|$commonUnits)?") + private val consecutiveTensRegex = Regex("^((?:十[一二三四五六七八九])+)(?:($commonUnits))?$") + private val consecutiveHundredsRegex = Regex("^((?:[一二三四五六七八九]百零?[一二三四五六七八九])+)(?:($commonUnits))?$") + + private val percentRegex = Regex("(? lastIndex) { + sb.append(input.substring(lastIndex, start)) + } + if (idiomRanges.any { rangesOverlap(it, m.range) }) { + sb.append(m.value) + } else { + sb.append(replaceSegment(m.value)) + } + lastIndex = m.range.last + 1 + } + if (lastIndex < input.length) { + sb.append(input.substring(lastIndex)) + } + return normalizeArabicDecimalDotSpacing(sb.toString()) } - return normalizeArabicDecimalDotSpacing(sb.toString()) - } - - private fun replaceSegment(segment: String): String { - val (leading, core, trailing) = splitOuterWhitespace(segment) - if (core.isEmpty()) return segment - return leading + replaceSegmentCore(core) + trailing - } - - private fun replaceSegmentCore(segment: String): String { - if (!containsChineseNumber(segment)) return segment - if (idioms.any { segment.contains(it) }) return segment - if (fuzzyRegex.containsMatchIn(segment)) return segment - - val prefixMatch = Regex("^([a-zA-Z]+\\s*)(.+)$").find(segment) - if (prefixMatch != null && containsChineseNumber(prefixMatch.groupValues[2])) { - return prefixMatch.groupValues[1] + replaceSegment(prefixMatch.groupValues[2]) + + private fun replaceSegment(segment: String): String { + val (leading, core, trailing) = splitOuterWhitespace(segment) + if (core.isEmpty()) return segment + return leading + replaceSegmentCore(core) + trailing } - val range = convertRange(segment) - if (range != null) return range + private fun replaceSegmentCore(segment: String): String { + if (!containsChineseNumber(segment)) return segment + if (idioms.any { segment.contains(it) }) return segment + if (fuzzyRegex.containsMatchIn(segment)) return segment - val dateTime = convertDateTimeConnected(segment) - if (dateTime != null) return dateTime + val prefixMatch = Regex("^([a-zA-Z]+\\s*)(.+)$").find(segment) + if (prefixMatch != null && containsChineseNumber(prefixMatch.groupValues[2])) { + return prefixMatch.groupValues[1] + replaceSegment(prefixMatch.groupValues[2]) + } - val time = convertTime(segment) - if (time != null) return time + val range = convertRange(segment) + if (range != null) return range - val matchBase = stripTrailingUnit(segment) - val compactBase = stripWhitespace(matchBase) - if (pureNumRegex.matches(compactBase)) { - val pure = convertPureNum(segment, strict = false) - if (pure != null) return pure - } + val dateTime = convertDateTimeConnected(segment) + if (dateTime != null) return dateTime - val consecutive = convertConsecutive(segment) - if (consecutive != null) return consecutive + val time = convertTime(segment) + if (time != null) return time - if (valueNumRegex.matches(compactBase)) { - val value = convertValueNum(segment) - if (value != null) return value - } + val matchBase = stripTrailingUnit(segment) + val compactBase = stripWhitespace(matchBase) + if (pureNumRegex.matches(compactBase)) { + val pure = convertPureNum(segment, strict = false) + if (pure != null) return pure + } - val percent = convertPercent(segment) - if (percent != null) return percent - - val fraction = convertFraction(segment) - if (fraction != null) return fraction - - val date = convertDate(segment) - if (date != null) return date - - return segment - } - - private fun convertRange(input: String): String? { - if (input.contains('点')) return null - val (baseRaw, unit) = stripUnit(input) - val base = stripWhitespace(baseRaw) - - rangePattern1.find(base)?.let { m -> - val d1 = valueMapper[m.groupValues[1][0]] ?: return null - val d2 = valueMapper[m.groupValues[2][0]] ?: return null - val scale = m.groupValues[3] - val suffix = m.groupValues.getOrNull(4).orEmpty() - val out = when (scale) { - "十" -> "${d1 * 10}~${d2 * 10}$suffix" - "百" -> "${d1 * 100}~${d2 * 100}$suffix" - "千" -> "${d1 * 1000}~${d2 * 1000}$suffix" - "万", "亿" -> "$d1~$d2$scale$suffix" - else -> return null - } - return out + unit - } + val consecutive = convertConsecutive(segment) + if (consecutive != null) return consecutive - rangePattern2.find(base)?.let { m -> - val basePart = m.groupValues[1] - val d1 = valueMapper[m.groupValues[2][0]] ?: return null - val d2 = valueMapper[m.groupValues[3][0]] ?: return null - val suffix = m.groupValues.getOrNull(4).orEmpty() - val lastChar = basePart.lastOrNull() ?: return null - val lastValue = valueMapper[lastChar] ?: return null - val baseValue = parseValueWithoutUnit(basePart) ?: return null - val multiplier = lastValue / 10 - val left = baseValue + d1 * multiplier - val right = baseValue + d2 * multiplier - return "${left}~${right}$suffix$unit" - } + if (valueNumRegex.matches(compactBase)) { + val value = convertValueNum(segment) + if (value != null) return value + } + + val percent = convertPercent(segment) + if (percent != null) return percent - rangePattern3.find(base)?.let { m -> - val d1 = valueMapper[m.groupValues[1][0]] ?: return null - val d2 = valueMapper[m.groupValues[2][0]] ?: return null - return "$d1~$d2$unit" + val fraction = convertFraction(segment) + if (fraction != null) return fraction + + val date = convertDate(segment) + if (date != null) return date + + return segment } - return null - } - - private fun convertTime(input: String): String? { - val text = stripWhitespace(input) - if (!timeRegex.matches(text)) return null - val dot = text.indexOf('点') - val fen = text.indexOf('分') - if (dot <= 0 || fen <= dot) return null - val miao = text.indexOf('秒') - val hourText = text.substring(0, dot) - val minuteText = text.substring(dot + 1, fen) - val secondText = if (miao > fen) text.substring(fen + 1, miao) else "" - - val hour = parseIntValue(hourText) ?: return null - val minute = parseIntValue(minuteText) ?: return null - val second = if (secondText.isNotEmpty()) parseIntValue(secondText) else null - - val base = "${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}" - return if (second != null) { - "$base:${second.toString().padStart(2, '0')}" - } else { - base + private fun convertRange(input: String): String? { + if (input.contains('点')) return null + val (baseRaw, unit) = stripUnit(input) + val base = stripWhitespace(baseRaw) + + rangePattern1.find(base)?.let { m -> + val d1 = valueMapper[m.groupValues[1][0]] ?: return null + val d2 = valueMapper[m.groupValues[2][0]] ?: return null + val scale = m.groupValues[3] + val suffix = m.groupValues.getOrNull(4).orEmpty() + val out = when (scale) { + "十" -> "${d1 * 10}~${d2 * 10}$suffix" + "百" -> "${d1 * 100}~${d2 * 100}$suffix" + "千" -> "${d1 * 1000}~${d2 * 1000}$suffix" + "万", "亿" -> "$d1~$d2$scale$suffix" + else -> return null + } + return out + unit + } + + rangePattern2.find(base)?.let { m -> + val basePart = m.groupValues[1] + val d1 = valueMapper[m.groupValues[2][0]] ?: return null + val d2 = valueMapper[m.groupValues[3][0]] ?: return null + val suffix = m.groupValues.getOrNull(4).orEmpty() + val lastChar = basePart.lastOrNull() ?: return null + val lastValue = valueMapper[lastChar] ?: return null + val baseValue = parseValueWithoutUnit(basePart) ?: return null + val multiplier = lastValue / 10 + val left = baseValue + d1 * multiplier + val right = baseValue + d2 * multiplier + return "$left~${right}$suffix$unit" + } + + rangePattern3.find(base)?.let { m -> + val d1 = valueMapper[m.groupValues[1][0]] ?: return null + val d2 = valueMapper[m.groupValues[2][0]] ?: return null + return "$d1~$d2$unit" + } + + return null } - } - - private fun convertPercent(input: String): String? { - val m = percentRegex.matchEntire(input) ?: return null - val num = m.value.removePrefix("百分之") - val converted = convertValueNum(num) ?: return null - return converted + "%" - } - - private fun convertFraction(input: String): String? { - val m = fractionRegex.matchEntire(input) ?: return null - val left = m.groupValues[1] - val right = m.groupValues[2] - val leftValue = convertValueNum(left) ?: return null - val rightValue = convertValueNum(right) ?: return null - return "$rightValue/$leftValue" - } - - private fun convertDate(input: String): String? { - val text = stripWhitespace(input) - val m = dateRegex.matchEntire(text) ?: return null - val yearRaw = m.groupValues.getOrNull(1).orEmpty() - val monthRaw = m.groupValues.getOrNull(2).orEmpty() - val dayRaw = m.groupValues.getOrNull(3).orEmpty() - if (yearRaw.isEmpty() && monthRaw.isEmpty() && dayRaw.isEmpty()) return null - - val year = if (yearRaw.isNotEmpty()) { - val part = yearRaw.removeSuffix("年") - part.toIntOrNull()?.toString() - ?: convertPureNum(part, strict = true)?.takeIf { it.isNotEmpty() } - ?: convertValueNum(part) - } else null - - val month = if (monthRaw.isNotEmpty()) { - val part = monthRaw.removeSuffix("月") - part.toIntOrNull()?.toString() ?: convertValueNum(part) - } else null - - val day = if (dayRaw.isNotEmpty()) { - val suffix = if (dayRaw.endsWith("日")) "日" else "号" - val part = dayRaw.removeSuffix(suffix) - (part.toIntOrNull()?.toString() ?: convertValueNum(part))?.let { it + suffix } - } else null - - val sb = StringBuilder() - if (year != null) sb.append(year).append("年") - if (month != null) sb.append(month).append("月") - if (day != null) sb.append(day) - return if (sb.isEmpty()) null else sb.toString() - } - - private fun convertConsecutive(input: String): String? { - consecutiveTensRegex.matchEntire(input)?.let { m -> - val body = m.groupValues[1] - val unit = mapUnit(m.groupValues.getOrNull(2).orEmpty()) - val parts = Regex("十[一二三四五六七八九]").findAll(body).mapNotNull { - convertValueNum(it.value) - }.toList() - if (parts.isNotEmpty()) { - return parts.joinToString(" ") + unit - } + + private fun convertTime(input: String): String? { + val text = stripWhitespace(input) + if (!timeRegex.matches(text)) return null + val dot = text.indexOf('点') + val fen = text.indexOf('分') + if (dot <= 0 || fen <= dot) return null + val miao = text.indexOf('秒') + val hourText = text.substring(0, dot) + val minuteText = text.substring(dot + 1, fen) + val secondText = if (miao > fen) text.substring(fen + 1, miao) else "" + + val hour = parseIntValue(hourText) ?: return null + val minute = parseIntValue(minuteText) ?: return null + val second = if (secondText.isNotEmpty()) parseIntValue(secondText) else null + + val base = "${hour.toString().padStart(2, '0')}:${minute.toString().padStart(2, '0')}" + return if (second != null) { + "$base:${second.toString().padStart(2, '0')}" + } else { + base + } } - consecutiveHundredsRegex.matchEntire(input)?.let { m -> - val body = m.groupValues[1] - val unit = mapUnit(m.groupValues.getOrNull(2).orEmpty()) - val parts = Regex("[一二三四五六七八九]百零?[一二三四五六七八九]").findAll(body).mapNotNull { - convertValueNum(it.value) - }.toList() - if (parts.isNotEmpty()) { - return parts.joinToString(" ") + unit - } + + private fun convertPercent(input: String): String? { + val m = percentRegex.matchEntire(input) ?: return null + val num = m.value.removePrefix("百分之") + val converted = convertValueNum(num) ?: return null + return converted + "%" } - return null - } - - private fun convertPureNum(input: String, strict: Boolean): String? { - val (raw, unit) = stripUnit(input) - val text = stripWhitespace(raw) - if (text.isEmpty()) return null - if (!strict && text.length == 1 && text[0] == '一' && unit.isEmpty()) return input - val out = StringBuilder() - for (ch in text) { - val mapped = numMapper[ch] ?: return null - out.append(mapped) + + private fun convertFraction(input: String): String? { + val m = fractionRegex.matchEntire(input) ?: return null + val left = m.groupValues[1] + val right = m.groupValues[2] + val leftValue = convertValueNum(left) ?: return null + val rightValue = convertValueNum(right) ?: return null + return "$rightValue/$leftValue" } - return out.toString() + unit - } - - private fun convertValueNum(input: String): String? { - val (raw, unit) = stripUnit(input) - val text = stripWhitespace(raw) - if (text.isEmpty()) return null - - val parts = text.split('点', limit = 2) - val intPart = parts[0] - val decPart = if (parts.size > 1) parts[1] else "" - - var value = 0L - var temp = 0L - var base = 1L - for (ch in intPart) { - when (ch) { - '十' -> { - temp = if (temp == 0L) 10L else temp * 10L - base = 1L + + private fun convertDate(input: String): String? { + val text = stripWhitespace(input) + val m = dateRegex.matchEntire(text) ?: return null + val yearRaw = m.groupValues.getOrNull(1).orEmpty() + val monthRaw = m.groupValues.getOrNull(2).orEmpty() + val dayRaw = m.groupValues.getOrNull(3).orEmpty() + if (yearRaw.isEmpty() && monthRaw.isEmpty() && dayRaw.isEmpty()) return null + + val year = if (yearRaw.isNotEmpty()) { + val part = yearRaw.removeSuffix("年") + part.toIntOrNull()?.toString() + ?: convertPureNum(part, strict = true)?.takeIf { it.isNotEmpty() } + ?: convertValueNum(part) + } else { + null + } + + val month = if (monthRaw.isNotEmpty()) { + val part = monthRaw.removeSuffix("月") + part.toIntOrNull()?.toString() ?: convertValueNum(part) + } else { + null } - '零' -> base = 1L - '一', '二', '两', '三', '四', '五', '六', '七', '八', '九' -> { - temp += valueMapper[ch] ?: 0 + + val day = if (dayRaw.isNotEmpty()) { + val suffix = if (dayRaw.endsWith("日")) "日" else "号" + val part = dayRaw.removeSuffix(suffix) + (part.toIntOrNull()?.toString() ?: convertValueNum(part))?.let { it + suffix } + } else { + null } - '万' -> { - value += temp - value *= 10000L - base = 1000L - temp = 0L + + val sb = StringBuilder() + if (year != null) sb.append(year).append("年") + if (month != null) sb.append(month).append("月") + if (day != null) sb.append(day) + return if (sb.isEmpty()) null else sb.toString() + } + + private fun convertConsecutive(input: String): String? { + consecutiveTensRegex.matchEntire(input)?.let { m -> + val body = m.groupValues[1] + val unit = mapUnit(m.groupValues.getOrNull(2).orEmpty()) + val parts = Regex("十[一二三四五六七八九]").findAll(body).mapNotNull { + convertValueNum(it.value) + }.toList() + if (parts.isNotEmpty()) { + return parts.joinToString(" ") + unit + } } - '百', '千' -> { - val weight = valueMapper[ch] ?: return null - value += temp * weight - base = weight / 10L - temp = 0L + consecutiveHundredsRegex.matchEntire(input)?.let { m -> + val body = m.groupValues[1] + val unit = mapUnit(m.groupValues.getOrNull(2).orEmpty()) + val parts = Regex("[一二三四五六七八九]百零?[一二三四五六七八九]").findAll(body).mapNotNull { + convertValueNum(it.value) + }.toList() + if (parts.isNotEmpty()) { + return parts.joinToString(" ") + unit + } } - else -> return null - } + return null } - value += temp * base - - val decOut = if (decPart.isNotEmpty()) { - val sb = StringBuilder() - for (ch in decPart) { - val mapped = numMapper[ch] ?: return null - sb.append(mapped) - } - sb.toString() - } else "" - - val out = if (decOut.isNotEmpty()) "$value.$decOut" else value.toString() - return out + unit - } - - private fun parseIntValue(input: String): Int? { - val text = stripWhitespace(input) - text.toIntOrNull()?.let { return it } - val value = convertValueNum(text) ?: convertPureNum(text, strict = true) - return value?.toIntOrNull() - } - - private fun parseValueWithoutUnit(input: String): Long? { - val text = stripWhitespace(input) - if (text.isEmpty()) return null - - val parts = text.split('点', limit = 2) - val intPart = parts[0] - val decPart = if (parts.size > 1) parts[1] else "" - if (decPart.isNotEmpty()) return null - - var value = 0L - var temp = 0L - var base = 1L - for (ch in intPart) { - when (ch) { - '十' -> { - temp = if (temp == 0L) 10L else temp * 10L - base = 1L + + private fun convertPureNum(input: String, strict: Boolean): String? { + val (raw, unit) = stripUnit(input) + val text = stripWhitespace(raw) + if (text.isEmpty()) return null + if (!strict && text.length == 1 && text[0] == '一' && unit.isEmpty()) return input + val out = StringBuilder() + for (ch in text) { + val mapped = numMapper[ch] ?: return null + out.append(mapped) } - '零' -> base = 1L - '一', '二', '两', '三', '四', '五', '六', '七', '八', '九' -> { - temp += valueMapper[ch] ?: 0 + return out.toString() + unit + } + + private fun convertValueNum(input: String): String? { + val (raw, unit) = stripUnit(input) + val text = stripWhitespace(raw) + if (text.isEmpty()) return null + + val parts = text.split('点', limit = 2) + val intPart = parts[0] + val decPart = if (parts.size > 1) parts[1] else "" + + var value = 0L + var temp = 0L + var base = 1L + for (ch in intPart) { + when (ch) { + '十' -> { + temp = if (temp == 0L) 10L else temp * 10L + base = 1L + } + '零' -> base = 1L + '一', '二', '两', '三', '四', '五', '六', '七', '八', '九' -> { + temp += valueMapper[ch] ?: 0 + } + '万' -> { + value += temp + value *= 10000L + base = 1000L + temp = 0L + } + '百', '千' -> { + val weight = valueMapper[ch] ?: return null + value += temp * weight + base = weight / 10L + temp = 0L + } + else -> return null + } } - '万' -> { - value += temp - value *= 10000L - base = 1000L - temp = 0L + value += temp * base + + val decOut = if (decPart.isNotEmpty()) { + val sb = StringBuilder() + for (ch in decPart) { + val mapped = numMapper[ch] ?: return null + sb.append(mapped) + } + sb.toString() + } else { + "" } - '百', '千' -> { - val weight = valueMapper[ch] ?: return null - value += temp * weight - base = weight / 10L - temp = 0L + + val out = if (decOut.isNotEmpty()) "$value.$decOut" else value.toString() + return out + unit + } + + private fun parseIntValue(input: String): Int? { + val text = stripWhitespace(input) + text.toIntOrNull()?.let { return it } + val value = convertValueNum(text) ?: convertPureNum(text, strict = true) + return value?.toIntOrNull() + } + + private fun parseValueWithoutUnit(input: String): Long? { + val text = stripWhitespace(input) + if (text.isEmpty()) return null + + val parts = text.split('点', limit = 2) + val intPart = parts[0] + val decPart = if (parts.size > 1) parts[1] else "" + if (decPart.isNotEmpty()) return null + + var value = 0L + var temp = 0L + var base = 1L + for (ch in intPart) { + when (ch) { + '十' -> { + temp = if (temp == 0L) 10L else temp * 10L + base = 1L + } + '零' -> base = 1L + '一', '二', '两', '三', '四', '五', '六', '七', '八', '九' -> { + temp += valueMapper[ch] ?: 0 + } + '万' -> { + value += temp + value *= 10000L + base = 1000L + temp = 0L + } + '百', '千' -> { + val weight = valueMapper[ch] ?: return null + value += temp * weight + base = weight / 10L + temp = 0L + } + else -> return null + } } - else -> return null - } + value += temp * base + return value + } + + private fun stripTrailingUnit(input: String): String { + val m = unitSuffixRegex.find(input) ?: return input + return input.substring(0, m.range.first) } - value += temp * base - return value - } - - private fun stripTrailingUnit(input: String): String { - val m = unitSuffixRegex.find(input) ?: return input - return input.substring(0, m.range.first) - } - - private fun stripUnit(input: String): Pair { - val m = unitSuffixRegex.find(input) - if (m == null) return input to "" - val unit = m.value - val mapped = mapUnit(unit) - return input.substring(0, m.range.first) to mapped - } - - private fun mapUnit(unit: String): String { - if (unit.isEmpty()) return "" - val mapped = unitMapping[unit] ?: return unit - return mapped ?: unit - } - - private fun convertDateTimeConnected(input: String): String? { - val text = stripWhitespace(input) - val dotIndex = text.indexOf('点') - if (dotIndex <= 0) return null - val splitIndex = text.lastIndexOfAny(charArrayOf('日', '号', '月', '年'), startIndex = dotIndex - 1) - if (splitIndex < 0) return null - - val dateCandidate = text.substring(0, splitIndex + 1) - val timeCandidate = text.substring(splitIndex + 1) - if (dateCandidate.isEmpty() || timeCandidate.isEmpty()) return null - - val convertedDate = convertDate(dateCandidate) ?: dateCandidate - val convertedTime = convertTime(timeCandidate) ?: timeCandidate - val out = convertedDate + convertedTime - return out.takeIf { it != text } - } - - private fun containsChineseNumber(input: String): Boolean { - for (ch in input) { - if (numMapper.containsKey(ch) || valueMapper.containsKey(ch) || ch == '幺' || ch == '两') { - return true - } + + private fun stripUnit(input: String): Pair { + val m = unitSuffixRegex.find(input) + if (m == null) return input to "" + val unit = m.value + val mapped = mapUnit(unit) + return input.substring(0, m.range.first) to mapped } - return false - } - - private fun findIdiomRanges(input: String): List { - val ranges = idiomRegex.findAll(input).map { it.range }.toList() - if (ranges.isEmpty()) return emptyList() - val sorted = ranges.sortedBy { it.first } - val merged = ArrayList(sorted.size) - var current = sorted[0] - for (i in 1 until sorted.size) { - val next = sorted[i] - if (next.first <= current.last + 1) { - current = current.first..maxOf(current.last, next.last) - } else { - merged.add(current) - current = next - } + + private fun mapUnit(unit: String): String { + if (unit.isEmpty()) return "" + val mapped = unitMapping[unit] ?: return unit + return mapped ?: unit } - merged.add(current) - return merged - } - - private fun rangesOverlap(a: IntRange, b: IntRange): Boolean { - return a.first <= b.last && b.first <= a.last - } - - private fun splitOuterWhitespace(input: String): Triple { - if (input.isEmpty()) return Triple("", "", "") - var start = 0 - while (start < input.length && isAnyWhitespace(input[start])) start++ - var end = input.length - while (end > start && isAnyWhitespace(input[end - 1])) end-- - return Triple(input.substring(0, start), input.substring(start, end), input.substring(end)) - } - - private fun stripWhitespace(input: String): String { - if (input.isEmpty()) return input - return input.filterNot { isAnyWhitespace(it) } - } - - private fun isAnyWhitespace(ch: Char): Boolean { - return ch.isWhitespace() || Character.isSpaceChar(ch) - } - - private fun normalizeArabicDecimalDotSpacing(input: String): String { - if (input.isEmpty()) return input - if (!input.contains('.') && !input.contains('.')) return input - - fun isAsciiDigit(ch: Char): Boolean = ch in '0'..'9' - - val out = StringBuilder(input.length) - var i = 0 - while (i < input.length) { - val ch = input[i] - if (isAsciiDigit(ch)) { - out.append(ch) - val afterDigit = i + 1 - var k = afterDigit - while (k < input.length && isAnyWhitespace(input[k])) k++ - if (k < input.length && (input[k] == '.' || input[k] == '.')) { - var l = k + 1 - while (l < input.length && isAnyWhitespace(input[l])) l++ - if (l < input.length && isAsciiDigit(input[l])) { - out.append('.') - out.append(input[l]) - i = l + 1 - continue - } + + private fun convertDateTimeConnected(input: String): String? { + val text = stripWhitespace(input) + val dotIndex = text.indexOf('点') + if (dotIndex <= 0) return null + val splitIndex = text.lastIndexOfAny(charArrayOf('日', '号', '月', '年'), startIndex = dotIndex - 1) + if (splitIndex < 0) return null + + val dateCandidate = text.substring(0, splitIndex + 1) + val timeCandidate = text.substring(splitIndex + 1) + if (dateCandidate.isEmpty() || timeCandidate.isEmpty()) return null + + val convertedDate = convertDate(dateCandidate) ?: dateCandidate + val convertedTime = convertTime(timeCandidate) ?: timeCandidate + val out = convertedDate + convertedTime + return out.takeIf { it != text } + } + + private fun containsChineseNumber(input: String): Boolean { + for (ch in input) { + if (numMapper.containsKey(ch) || valueMapper.containsKey(ch) || ch == '幺' || ch == '两') { + return true + } } - i = afterDigit - continue - } - out.append(ch) - i++ + return false } - return out.toString() - } - - private fun buildAllowedCharClass(): String { - val chars = LinkedHashSet() - numMapper.keys.forEach { chars.add(it) } - valueMapper.keys.forEach { chars.add(it) } - chars.add('几') - chars.add('分') - chars.add('之') - chars.add('年') - chars.add('月') - chars.add('日') - chars.add('号') - chars.add('秒') - unitMapping.keys.forEach { unit -> - unit.forEach { chars.add(it) } + + private fun findIdiomRanges(input: String): List { + val ranges = idiomRegex.findAll(input).map { it.range }.toList() + if (ranges.isEmpty()) return emptyList() + val sorted = ranges.sortedBy { it.first } + val merged = ArrayList(sorted.size) + var current = sorted[0] + for (i in 1 until sorted.size) { + val next = sorted[i] + if (next.first <= current.last + 1) { + current = current.first..maxOf(current.last, next.last) + } else { + merged.add(current) + current = next + } + } + merged.add(current) + return merged + } + + private fun rangesOverlap(a: IntRange, b: IntRange): Boolean { + return a.first <= b.last && b.first <= a.last + } + + private fun splitOuterWhitespace(input: String): Triple { + if (input.isEmpty()) return Triple("", "", "") + var start = 0 + while (start < input.length && isAnyWhitespace(input[start])) start++ + var end = input.length + while (end > start && isAnyWhitespace(input[end - 1])) end-- + return Triple(input.substring(0, start), input.substring(start, end), input.substring(end)) + } + + private fun stripWhitespace(input: String): String { + if (input.isEmpty()) return input + return input.filterNot { isAnyWhitespace(it) } } - val sb = StringBuilder() - for (ch in chars) { - when (ch) { - '\\', '-', ']', '^' -> { - sb.append('\\').append(ch) + private fun isAnyWhitespace(ch: Char): Boolean { + return ch.isWhitespace() || Character.isSpaceChar(ch) + } + + private fun normalizeArabicDecimalDotSpacing(input: String): String { + if (input.isEmpty()) return input + if (!input.contains('.') && !input.contains('.')) return input + + fun isAsciiDigit(ch: Char): Boolean = ch in '0'..'9' + + val out = StringBuilder(input.length) + var i = 0 + while (i < input.length) { + val ch = input[i] + if (isAsciiDigit(ch)) { + out.append(ch) + val afterDigit = i + 1 + var k = afterDigit + while (k < input.length && isAnyWhitespace(input[k])) k++ + if (k < input.length && (input[k] == '.' || input[k] == '.')) { + var l = k + 1 + while (l < input.length && isAnyWhitespace(input[l])) l++ + if (l < input.length && isAsciiDigit(input[l])) { + out.append('.') + out.append(input[l]) + i = l + 1 + continue + } + } + i = afterDigit + continue + } + out.append(ch) + i++ + } + return out.toString() + } + + private fun buildAllowedCharClass(): String { + val chars = LinkedHashSet() + numMapper.keys.forEach { chars.add(it) } + valueMapper.keys.forEach { chars.add(it) } + chars.add('几') + chars.add('分') + chars.add('之') + chars.add('年') + chars.add('月') + chars.add('日') + chars.add('号') + chars.add('秒') + unitMapping.keys.forEach { unit -> + unit.forEach { chars.add(it) } + } + + val sb = StringBuilder() + for (ch in chars) { + when (ch) { + '\\', '-', ']', '^' -> { + sb.append('\\').append(ch) + } + else -> sb.append(ch) + } } - else -> sb.append(ch) - } + return sb.toString() } - return sb.toString() - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/DashscopeFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/DashscopeFileAsrEngine.kt index 6f10ed1d..8df337d6 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/DashscopeFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/DashscopeFileAsrEngine.kt @@ -7,9 +7,6 @@ import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationP import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult import com.alibaba.dashscope.common.MultiModalMessage import com.alibaba.dashscope.common.Role -import com.alibaba.dashscope.exception.ApiException -import com.alibaba.dashscope.exception.NoApiKeyException -import com.alibaba.dashscope.exception.UploadFileException import com.alibaba.dashscope.utils.Constants import com.alibaba.dashscope.utils.JsonUtils import com.brycewg.asrkb.R @@ -29,7 +26,7 @@ class DashscopeFileAsrEngine( scope: CoroutineScope, prefs: Prefs, listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + onRequestDuration: ((Long) -> Unit)? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -58,7 +55,7 @@ class DashscopeFileAsrEngine( } catch (e: Throwable) { Log.e(TAG, "Failed to materialize WAV file", e) listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, e.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, e.message ?: ""), ) return } @@ -103,23 +100,31 @@ class DashscopeFileAsrEngine( val result: MultiModalConversationResult = conv.call(param) // 6) 解析结果(沿用原有 JSON 解析逻辑) - val json = try { JsonUtils.toJson(result) } catch (e: Throwable) { "" } + val json = try { + JsonUtils.toJson(result) + } catch (e: Throwable) { + "" + } val text = parseDashscopeText(json) if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (e: Throwable) { + try { + onRequestDuration?.invoke(dt) + } catch (e: Throwable) { Log.w(TAG, "Failed to dispatch duration", e) } listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) } - }finally { + } finally { tmp.delete() } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } /** * 从 DashScope 响应体中解析转写文本 @@ -152,7 +157,9 @@ class DashscopeFileAsrEngine( * 获取文件名的安全方法 */ private fun File.nameIfExists(): String { - return try { name } catch (t: Throwable) { + return try { + name + } catch (t: Throwable) { Log.e(TAG, "Failed to get file name", t) "upload.wav" } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/DashscopeStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/DashscopeStreamAsrEngine.kt index 86327c93..96b99d8c 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/DashscopeStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/DashscopeStreamAsrEngine.kt @@ -10,13 +10,13 @@ import androidx.core.content.ContextCompat import com.alibaba.dashscope.audio.asr.recognition.Recognition import com.alibaba.dashscope.audio.asr.recognition.RecognitionParam import com.alibaba.dashscope.audio.asr.recognition.RecognitionResult -import com.alibaba.dashscope.common.ResultCallback import com.alibaba.dashscope.audio.omni.OmniRealtimeCallback import com.alibaba.dashscope.audio.omni.OmniRealtimeConfig import com.alibaba.dashscope.audio.omni.OmniRealtimeConversation import com.alibaba.dashscope.audio.omni.OmniRealtimeModality import com.alibaba.dashscope.audio.omni.OmniRealtimeParam import com.alibaba.dashscope.audio.omni.OmniRealtimeTranscriptionParam +import com.alibaba.dashscope.common.ResultCallback import com.alibaba.dashscope.utils.Constants import com.brycewg.asrkb.R import com.brycewg.asrkb.store.Prefs @@ -41,646 +41,657 @@ import java.util.concurrent.atomic.AtomicBoolean * - 支持 language 和 corpusText 参数提升识别准确度。 */ class DashscopeStreamAsrEngine( - private val context: Context, - private val scope: CoroutineScope, - private val prefs: Prefs, - private val listener: StreamingAsrEngine.Listener, - private val externalPcmMode: Boolean = false + private val context: Context, + private val scope: CoroutineScope, + private val prefs: Prefs, + private val listener: StreamingAsrEngine.Listener, + private val externalPcmMode: Boolean = false, ) : StreamingAsrEngine, ExternalPcmConsumer { - companion object { - private const val TAG = "DashscopeStreamAsrEngine" - private const val MODEL_QWEN3 = Prefs.DASH_MODEL_QWEN3_REALTIME - private const val MODEL_FUN_ASR = Prefs.DASH_MODEL_FUN_ASR_REALTIME - private const val WS_URL_CN = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" - private const val WS_URL_INTL = "wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime" - private const val WS_URL_INFER_CN = "wss://dashscope.aliyuncs.com/api-ws/v1/inference" - private const val WS_URL_INFER_INTL = "wss://dashscope-intl.aliyuncs.com/api-ws/v1/inference" - private const val FINAL_RESULT_TIMEOUT_MS = 6000L - } - - private val running = AtomicBoolean(false) - private var audioJob: Job? = null - private var controlJob: Job? = null - - private val sampleRate = 16000 - private val channelConfig = AudioFormat.CHANNEL_IN_MONO - private val audioFormat = AudioFormat.ENCODING_PCM_16BIT - - private var conversation: OmniRealtimeConversation? = null - private var recognizer: Recognition? = null - private var useFunAsrModel: Boolean = false - - // 用于识别结果 - // currentTurnText: 当前已确定的文本(来自 text 事件的 text 字段,用于实时预览) - // currentTurnStash: 当前未确定的中间文本(来自 text 事件的 stash 字段,用于实时预览) - // finalTranscript: 用户停止后,由 commit() 触发的最终完整识别结果 - private var currentTurnText: String = "" - private var currentTurnStash: String = "" - private var finalTranscript: String? = null - private var finalResultDeferred: CompletableDeferred? = null - private val finalDelivered = AtomicBoolean(false) - - override val isRunning: Boolean - get() = running.get() - - private val prebuffer = java.util.ArrayDeque() - private val prebufferLock = Any() - @Volatile private var convReady: Boolean = false - - override fun start() { - if (running.get()) return - if (!externalPcmMode) { - val hasPermission = ContextCompat.checkSelfPermission( - context, - Manifest.permission.RECORD_AUDIO - ) == PackageManager.PERMISSION_GRANTED - if (!hasPermission) { - listener.onError(context.getString(R.string.error_record_permission_denied)) - return - } - } - if (prefs.dashApiKey.isBlank()) { - listener.onError(context.getString(R.string.error_missing_dashscope_key)) - return + companion object { + private const val TAG = "DashscopeStreamAsrEngine" + private const val MODEL_QWEN3 = Prefs.DASH_MODEL_QWEN3_REALTIME + private const val MODEL_FUN_ASR = Prefs.DASH_MODEL_FUN_ASR_REALTIME + private const val WS_URL_CN = "wss://dashscope.aliyuncs.com/api-ws/v1/realtime" + private const val WS_URL_INTL = "wss://dashscope-intl.aliyuncs.com/api-ws/v1/realtime" + private const val WS_URL_INFER_CN = "wss://dashscope.aliyuncs.com/api-ws/v1/inference" + private const val WS_URL_INFER_INTL = "wss://dashscope-intl.aliyuncs.com/api-ws/v1/inference" + private const val FINAL_RESULT_TIMEOUT_MS = 6000L } - useFunAsrModel = prefs.dashAsrModel.startsWith("fun-asr", ignoreCase = true) + private val running = AtomicBoolean(false) + private var audioJob: Job? = null + private var controlJob: Job? = null - running.set(true) - currentTurnText = "" - currentTurnStash = "" - finalTranscript = null - finalResultDeferred = null - finalDelivered.set(false) + private val sampleRate = 16000 + private val channelConfig = AudioFormat.CHANNEL_IN_MONO + private val audioFormat = AudioFormat.ENCODING_PCM_16BIT - // 在 IO 线程启动 SDK 识别并随后启动采集 - controlJob?.cancel() - controlJob = scope.launch(Dispatchers.IO) { - try { - if (useFunAsrModel) { - startFunAsrStreaming() - return@launch - } + private var conversation: OmniRealtimeConversation? = null + private var recognizer: Recognition? = null + private var useFunAsrModel: Boolean = false - // 根据地域选择 WebSocket URL - val wsUrl = if (prefs.dashRegion.equals("intl", ignoreCase = true)) WS_URL_INTL else WS_URL_CN - - // 构建 OmniRealtimeParam - val param = OmniRealtimeParam.builder() - .model(MODEL_QWEN3) - .apikey(prefs.dashApiKey) - .url(wsUrl) - .build() - - // 构建 OmniRealtimeTranscriptionParam - val transcriptionParam = OmniRealtimeTranscriptionParam() - transcriptionParam.setInputSampleRate(sampleRate) - transcriptionParam.setInputAudioFormat("pcm") - // 可选:设置语言以提升准确度 - val lang = prefs.dashLanguage - if (lang.isNotBlank()) { - transcriptionParam.setLanguage(lang) + // 用于识别结果 + // currentTurnText: 当前已确定的文本(来自 text 事件的 text 字段,用于实时预览) + // currentTurnStash: 当前未确定的中间文本(来自 text 事件的 stash 字段,用于实时预览) + // finalTranscript: 用户停止后,由 commit() 触发的最终完整识别结果 + private var currentTurnText: String = "" + private var currentTurnStash: String = "" + private var finalTranscript: String? = null + private var finalResultDeferred: CompletableDeferred? = null + private val finalDelivered = AtomicBoolean(false) + + override val isRunning: Boolean + get() = running.get() + + private val prebuffer = java.util.ArrayDeque() + private val prebufferLock = Any() + + @Volatile private var convReady: Boolean = false + + override fun start() { + if (running.get()) return + if (!externalPcmMode) { + val hasPermission = ContextCompat.checkSelfPermission( + context, + Manifest.permission.RECORD_AUDIO, + ) == PackageManager.PERMISSION_GRANTED + if (!hasPermission) { + listener.onError(context.getString(R.string.error_record_permission_denied)) + return + } } - // 可选:设置语料文本 - val corpus = prefs.dashPrompt - if (corpus.isNotBlank()) { - transcriptionParam.setCorpusText(corpus) + if (prefs.dashApiKey.isBlank()) { + listener.onError(context.getString(R.string.error_missing_dashscope_key)) + return } - // 构建 OmniRealtimeConfig(关闭服务端 VAD,使用手动模式) - // 手动模式下:text 事件仍实时返回用于预览,用户停止时调用 commit() 触发最终识别 - val config = OmniRealtimeConfig.builder() - .modalities(listOf(OmniRealtimeModality.TEXT)) - .enableTurnDetection(false) // 关闭服务端 VAD - .transcriptionConfig(transcriptionParam) - .build() - - // 创建回调 - val callback = object : OmniRealtimeCallback() { - override fun onOpen() { - Log.d(TAG, "WebSocket opened, updating session config") + useFunAsrModel = prefs.dashAsrModel.startsWith("fun-asr", ignoreCase = true) + + running.set(true) + currentTurnText = "" + currentTurnStash = "" + finalTranscript = null + finalResultDeferred = null + finalDelivered.set(false) + + // 在 IO 线程启动 SDK 识别并随后启动采集 + controlJob?.cancel() + controlJob = scope.launch(Dispatchers.IO) { try { - conversation?.updateSession(config) - convReady = true - // 冲刷预缓冲 - flushPrebuffer() + if (useFunAsrModel) { + startFunAsrStreaming() + return@launch + } + + // 根据地域选择 WebSocket URL + val wsUrl = if (prefs.dashRegion.equals("intl", ignoreCase = true)) WS_URL_INTL else WS_URL_CN + + // 构建 OmniRealtimeParam + val param = OmniRealtimeParam.builder() + .model(MODEL_QWEN3) + .apikey(prefs.dashApiKey) + .url(wsUrl) + .build() + + // 构建 OmniRealtimeTranscriptionParam + val transcriptionParam = OmniRealtimeTranscriptionParam() + transcriptionParam.setInputSampleRate(sampleRate) + transcriptionParam.setInputAudioFormat("pcm") + // 可选:设置语言以提升准确度 + val lang = prefs.dashLanguage + if (lang.isNotBlank()) { + transcriptionParam.setLanguage(lang) + } + // 可选:设置语料文本 + val corpus = prefs.dashPrompt + if (corpus.isNotBlank()) { + transcriptionParam.setCorpusText(corpus) + } + + // 构建 OmniRealtimeConfig(关闭服务端 VAD,使用手动模式) + // 手动模式下:text 事件仍实时返回用于预览,用户停止时调用 commit() 触发最终识别 + val config = OmniRealtimeConfig.builder() + .modalities(listOf(OmniRealtimeModality.TEXT)) + .enableTurnDetection(false) // 关闭服务端 VAD + .transcriptionConfig(transcriptionParam) + .build() + + // 创建回调 + val callback = object : OmniRealtimeCallback() { + override fun onOpen() { + Log.d(TAG, "WebSocket opened, updating session config") + try { + conversation?.updateSession(config) + convReady = true + // 冲刷预缓冲 + flushPrebuffer() + } catch (t: Throwable) { + Log.e(TAG, "updateSession failed", t) + } + } + + override fun onEvent(message: JsonObject) { + handleServerEvent(message) + } + + override fun onClose(code: Int, reason: String) { + Log.d(TAG, "WebSocket closed: $code $reason") + if (running.get()) { + // 非预期关闭 + running.set(false) + try { + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, reason)) + } catch (t: Throwable) { + Log.e(TAG, "notify error failed", t) + } + } + } + } + + // 创建并连接(使用构造函数,与官方示例一致) + val conv = OmniRealtimeConversation(param, callback) + conversation = conv + convReady = false + conv.connect() + + // 建立连接后开始推送音频(仅非外部模式) + if (!externalPcmMode) { + startCaptureAndSend() + } } catch (t: Throwable) { - Log.e(TAG, "updateSession failed", t) + Log.e(TAG, "Failed to start DashScope streaming recognition", t) + try { + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + } catch (notifyError: Throwable) { + Log.e(TAG, "notify error failed", notifyError) + } + running.set(false) + safeClose() } - } - - override fun onEvent(message: JsonObject) { - handleServerEvent(message) - } - - override fun onClose(code: Int, reason: String) { - Log.d(TAG, "WebSocket closed: $code $reason") - if (running.get()) { - // 非预期关闭 - running.set(false) - try { - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, reason)) - } catch (t: Throwable) { - Log.e(TAG, "notify error failed", t) - } - } - } } + } - // 创建并连接(使用构造函数,与官方示例一致) - val conv = OmniRealtimeConversation(param, callback) - conversation = conv - convReady = false - conv.connect() - - // 建立连接后开始推送音频(仅非外部模式) - if (!externalPcmMode) { - startCaptureAndSend() - } - } catch (t: Throwable) { - Log.e(TAG, "Failed to start DashScope streaming recognition", t) + private fun startFunAsrStreaming() { + // Fun-ASR 使用 Recognition SDK:需要通过 Constants.baseWebsocketApiUrl 指定地域 endpoint + val wsUrl = if (prefs.dashRegion.equals("intl", ignoreCase = true)) WS_URL_INFER_INTL else WS_URL_INFER_CN try { - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) - } catch (notifyError: Throwable) { - Log.e(TAG, "notify error failed", notifyError) + Constants.baseWebsocketApiUrl = wsUrl + } catch (t: Throwable) { + Log.w(TAG, "Failed to set baseWebsocketApiUrl", t) } - running.set(false) - safeClose() - } - } - } - - private fun startFunAsrStreaming() { - // Fun-ASR 使用 Recognition SDK:需要通过 Constants.baseWebsocketApiUrl 指定地域 endpoint - val wsUrl = if (prefs.dashRegion.equals("intl", ignoreCase = true)) WS_URL_INFER_INTL else WS_URL_INFER_CN - try { - Constants.baseWebsocketApiUrl = wsUrl - } catch (t: Throwable) { - Log.w(TAG, "Failed to set baseWebsocketApiUrl", t) - } - val builder = RecognitionParam.builder() - .model(MODEL_FUN_ASR) - .apiKey(prefs.dashApiKey) - .format("pcm") - .sampleRate(sampleRate) - - val lang = prefs.dashLanguage.trim() - if (lang.isNotBlank()) { - try { - builder.parameter("language_hints", arrayOf(lang)) - } catch (t: Throwable) { - Log.w(TAG, "Failed to set language_hints", t) - } - } + val builder = RecognitionParam.builder() + .model(MODEL_FUN_ASR) + .apiKey(prefs.dashApiKey) + .format("pcm") + .sampleRate(sampleRate) - // 语义断句:开启时使用 LLM 语义断句,关闭时使用 VAD 断句 - try { - builder.parameter("semantic_punctuation_enabled", prefs.dashFunAsrSemanticPunctEnabled) - } catch (t: Throwable) { - Log.w(TAG, "Failed to set semantic_punctuation_enabled", t) - } + val lang = prefs.dashLanguage.trim() + if (lang.isNotBlank()) { + try { + builder.parameter("language_hints", arrayOf(lang)) + } catch (t: Throwable) { + Log.w(TAG, "Failed to set language_hints", t) + } + } - val param = builder.build() - val rec = Recognition() - recognizer = rec - conversation = null + // 语义断句:开启时使用 LLM 语义断句,关闭时使用 VAD 断句 + try { + builder.parameter("semantic_punctuation_enabled", prefs.dashFunAsrSemanticPunctEnabled) + } catch (t: Throwable) { + Log.w(TAG, "Failed to set semantic_punctuation_enabled", t) + } - convReady = false - val callback = object : ResultCallback() { - override fun onEvent(result: RecognitionResult) { - handleFunAsrEvent(result) - } + val param = builder.build() + val rec = Recognition() + recognizer = rec + conversation = null - override fun onComplete() { - handleFunAsrComplete() - } + convReady = false + val callback = object : ResultCallback() { + override fun onEvent(result: RecognitionResult) { + handleFunAsrEvent(result) + } - override fun onError(e: Exception) { - handleFunAsrError(e) - } - } + override fun onComplete() { + handleFunAsrComplete() + } - rec.call(param, callback) + override fun onError(e: Exception) { + handleFunAsrError(e) + } + } - convReady = true - flushPrebuffer() - if (!externalPcmMode) { - startCaptureAndSend() - } - } - - private fun handleFunAsrEvent(result: RecognitionResult) { - val sentenceText = result.getSentence()?.getText().orEmpty() - if (sentenceText.isBlank()) return - - val isEnd = result.isSentenceEnd - if (isEnd) { - currentTurnText = appendSentence(currentTurnText, sentenceText) - currentTurnStash = "" - } else { - currentTurnStash = sentenceText - } + rec.call(param, callback) - if (!running.get()) return - val preview = (currentTurnText + currentTurnStash).trim() - if (preview.isNotEmpty()) { - try { - listener.onPartial(preview) - } catch (t: Throwable) { - Log.e(TAG, "notify partial failed", t) - } - } - } - - private fun handleFunAsrComplete() { - val finalText = (currentTurnText + currentTurnStash).trim() - finalTranscript = finalText - finalResultDeferred?.complete(finalText) - - if (finalDelivered.compareAndSet(false, true)) { - try { - listener.onFinal(finalText) - } catch (t: Throwable) { - Log.e(TAG, "notify final failed", t) - } - } - } - - private fun handleFunAsrError(e: Exception) { - val msg = e.message ?: "Recognition error" - Log.e(TAG, "Fun-ASR streaming error: $msg", e) - if (running.get()) { - running.set(false) - if (!finalDelivered.get()) { - try { - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, msg)) - } catch (t: Throwable) { - Log.e(TAG, "notify error failed", t) + convReady = true + flushPrebuffer() + if (!externalPcmMode) { + startCaptureAndSend() } - } } - finalResultDeferred?.complete(null) - try { - audioJob?.cancel() - } catch (t: Throwable) { - Log.w(TAG, "cancel audio job after failure failed", t) - } - audioJob = null - safeClose() - } - - private fun appendSentence(existing: String, sentence: String): String { - val s = sentence.trim() - if (s.isEmpty()) return existing - val cur = existing.trim() - if (cur.isEmpty()) return s - val last = cur.last() - val first = s.first() - val needsSpace = last.isAsciiLetterOrDigit() && first.isAsciiLetterOrDigit() - return if (needsSpace) "$cur $s" else cur + s - } - - private fun Char.isAsciiLetterOrDigit(): Boolean { - return (this in 'a'..'z') || (this in 'A'..'Z') || (this in '0'..'9') - } - - /** - * 处理服务端事件 - */ - private fun handleServerEvent(message: JsonObject) { - val eventType = message.get("type")?.asString ?: return - - when (eventType) { - "session.created" -> { - Log.d(TAG, "Session created") - } - "session.updated" -> { - Log.d(TAG, "Session updated") - } - "input_audio_buffer.speech_started" -> { - Log.d(TAG, "Speech started") - } - "input_audio_buffer.speech_stopped" -> { - Log.d(TAG, "Speech stopped") - } - "input_audio_buffer.committed" -> { - Log.d(TAG, "Audio committed") - } - "conversation.item.created" -> { - Log.d(TAG, "Conversation item created") - } - "conversation.item.input_audio_transcription.text" -> { - // 实时识别结果(用于预览) - // text: 已确定的文本(完整,非增量) - // stash: 尚未确定的中间文本 - val text = message.get("text")?.asString ?: "" - val stash = message.get("stash")?.asString ?: "" - if (running.get()) { - // 更新当前文本(用于实时预览) - currentTurnText = text - currentTurnStash = stash + private fun handleFunAsrEvent(result: RecognitionResult) { + val sentenceText = result.getSentence()?.getText().orEmpty() + if (sentenceText.isBlank()) return - // 实时预览 = text + stash - val preview = currentTurnText + currentTurnStash - if (preview.isNotEmpty()) { + val isEnd = result.isSentenceEnd + if (isEnd) { + currentTurnText = appendSentence(currentTurnText, sentenceText) + currentTurnStash = "" + } else { + currentTurnStash = sentenceText + } + + if (!running.get()) return + val preview = (currentTurnText + currentTurnStash).trim() + if (preview.isNotEmpty()) { try { - listener.onPartial(preview) + listener.onPartial(preview) } catch (t: Throwable) { - Log.e(TAG, "notify partial failed", t) + Log.e(TAG, "notify partial failed", t) } - } } - } - "conversation.item.input_audio_transcription.completed" -> { - // 最终识别结果(由 commit() 触发,手动模式下只会有一次) - val transcript = message.get("transcript")?.asString ?: "" - Log.d(TAG, "Transcription completed: $transcript") - - // 保存最终结果(用于 stop() 时返回) - finalTranscript = transcript + } - finalResultDeferred?.complete(transcript) + private fun handleFunAsrComplete() { + val finalText = (currentTurnText + currentTurnStash).trim() + finalTranscript = finalText + finalResultDeferred?.complete(finalText) - // 通知 UI 最终结果 if (finalDelivered.compareAndSet(false, true)) { - try { - listener.onFinal(transcript) - } catch (t: Throwable) { - Log.e(TAG, "notify final failed", t) - } + try { + listener.onFinal(finalText) + } catch (t: Throwable) { + Log.e(TAG, "notify final failed", t) + } } - } - "conversation.item.input_audio_transcription.failed" -> { - // 识别失败 - val error = message.getAsJsonObject("error") - val errorMsg = error?.get("message")?.asString ?: "Transcription failed" - Log.e(TAG, "Transcription failed: $errorMsg") - running.set(false) - if (!finalDelivered.get()) { - try { - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, errorMsg)) - } catch (t: Throwable) { - Log.e(TAG, "notify error failed", t) - } + } + + private fun handleFunAsrError(e: Exception) { + val msg = e.message ?: "Recognition error" + Log.e(TAG, "Fun-ASR streaming error: $msg", e) + if (running.get()) { + running.set(false) + if (!finalDelivered.get()) { + try { + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, msg)) + } catch (t: Throwable) { + Log.e(TAG, "notify error failed", t) + } + } } finalResultDeferred?.complete(null) try { - audioJob?.cancel() + audioJob?.cancel() } catch (t: Throwable) { - Log.w(TAG, "cancel audio job after failure failed", t) + Log.w(TAG, "cancel audio job after failure failed", t) } audioJob = null safeClose() - } - "error" -> { - // 通用错误 - val error = message.getAsJsonObject("error") - val errorMsg = error?.get("message")?.asString ?: "Unknown error" - Log.e(TAG, "Server error: $errorMsg") - if (running.get()) { - running.set(false) - try { - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, errorMsg)) - } catch (t: Throwable) { - Log.e(TAG, "notify error failed", t) - } - } - } - else -> { - Log.d(TAG, "Unknown event: $eventType") - } - } - } - - /** - * 冲刷预缓冲区 - */ - private fun flushPrebuffer() { - var flushed: Array? = null - synchronized(prebufferLock) { - if (prebuffer.isNotEmpty()) { - flushed = prebuffer.toTypedArray() - prebuffer.clear() - } - } - flushed?.forEach { b -> - sendAudioFrame(b) - } - } - - /** - * 发送音频帧(Base64 编码) - */ - private fun sendAudioFrame(audioChunk: ByteArray) { - if (useFunAsrModel) { - try { - recognizer?.sendAudioFrame(ByteBuffer.wrap(audioChunk)) - } catch (t: Throwable) { - Log.e(TAG, "sendAudioFrame failed", t) - } - return - } - try { - val base64Audio = Base64.encodeToString(audioChunk, Base64.NO_WRAP) - conversation?.appendAudio(base64Audio) - } catch (t: Throwable) { - Log.e(TAG, "appendAudio failed", t) - } - } - - // ========== ExternalPcmConsumer(外部推流) ========== - override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { - if (!running.get()) return - if (sampleRate != 16000 || channels != 1) return - try { - listener.onAmplitude(calculateNormalizedAmplitude(pcm)) - } catch (t: Throwable) { - Log.w(TAG, "notify amplitude failed", t) } - if (!convReady) { - synchronized(prebufferLock) { prebuffer.addLast(pcm.copyOf()) } - } else { - // 先冲刷预缓冲 - flushPrebuffer() - sendAudioFrame(pcm) + private fun appendSentence(existing: String, sentence: String): String { + val s = sentence.trim() + if (s.isEmpty()) return existing + val cur = existing.trim() + if (cur.isEmpty()) return s + val last = cur.last() + val first = s.first() + val needsSpace = last.isAsciiLetterOrDigit() && first.isAsciiLetterOrDigit() + return if (needsSpace) "$cur $s" else cur + s } - } - override fun stop() { - if (!running.get()) return - running.set(false) + private fun Char.isAsciiLetterOrDigit(): Boolean { + return (this in 'a'..'z') || (this in 'A'..'Z') || (this in '0'..'9') + } - // 先取消音频采集,然后调用 commit() 触发最终识别 - scope.launch(Dispatchers.IO) { - val resultDeferred = CompletableDeferred() - finalResultDeferred = resultDeferred - try { - // 通知 UI:录音阶段结束,可复位麦克风按钮 - try { listener.onStopped() } catch (t: Throwable) { Log.e(TAG, "notify stopped failed", t) } + /** + * 处理服务端事件 + */ + private fun handleServerEvent(message: JsonObject) { + val eventType = message.get("type")?.asString ?: return - // 取消音频采集协程,触发 AudioRecord 释放 - try { - audioJob?.cancel() - // 等待音频采集协程完全结束,确保 AudioRecord 被完全释放 - audioJob?.join() - } catch (t: Throwable) { - Log.w(TAG, "cancel/join audio job failed", t) - } - audioJob = null - - if (useFunAsrModel) { - // Fun-ASR:调用 stop() 触发最终回调(onComplete) - try { - Log.d(TAG, "Calling recognizer.stop() to trigger final recognition") - recognizer?.stop() - } catch (t: Throwable) { - Log.w(TAG, "recognizer.stop() failed", t) - val fallbackText = (currentTurnText + currentTurnStash).trim() - if (finalDelivered.compareAndSet(false, true)) { - try { - listener.onFinal(fallbackText) - } catch (notifyError: Throwable) { - Log.e(TAG, "notify final fallback failed", notifyError) - } + when (eventType) { + "session.created" -> { + Log.d(TAG, "Session created") } - if (!resultDeferred.isCompleted) { - resultDeferred.complete(fallbackText) + "session.updated" -> { + Log.d(TAG, "Session updated") } - } - } else { - // 调用 commit() 触发最终识别(手动模式必需) - // completed 事件会在回调中调用 listener.onFinal() - try { - Log.d(TAG, "Calling commit() to trigger final recognition") - conversation?.commit() - } catch (t: Throwable) { - Log.w(TAG, "commit() failed", t) - // 如果 commit 失败,使用当前预览作为最终结果 - val fallbackText = (currentTurnText + currentTurnStash).trim() - if (finalDelivered.compareAndSet(false, true)) { - try { - listener.onFinal(fallbackText) - } catch (notifyError: Throwable) { - Log.e(TAG, "notify final fallback failed", notifyError) - } + "input_audio_buffer.speech_started" -> { + Log.d(TAG, "Speech started") + } + "input_audio_buffer.speech_stopped" -> { + Log.d(TAG, "Speech stopped") + } + "input_audio_buffer.committed" -> { + Log.d(TAG, "Audio committed") + } + "conversation.item.created" -> { + Log.d(TAG, "Conversation item created") + } + "conversation.item.input_audio_transcription.text" -> { + // 实时识别结果(用于预览) + // text: 已确定的文本(完整,非增量) + // stash: 尚未确定的中间文本 + val text = message.get("text")?.asString ?: "" + val stash = message.get("stash")?.asString ?: "" + + if (running.get()) { + // 更新当前文本(用于实时预览) + currentTurnText = text + currentTurnStash = stash + + // 实时预览 = text + stash + val preview = currentTurnText + currentTurnStash + if (preview.isNotEmpty()) { + try { + listener.onPartial(preview) + } catch (t: Throwable) { + Log.e(TAG, "notify partial failed", t) + } + } + } } - if (!resultDeferred.isCompleted) { - resultDeferred.complete(fallbackText) + "conversation.item.input_audio_transcription.completed" -> { + // 最终识别结果(由 commit() 触发,手动模式下只会有一次) + val transcript = message.get("transcript")?.asString ?: "" + Log.d(TAG, "Transcription completed: $transcript") + + // 保存最终结果(用于 stop() 时返回) + finalTranscript = transcript + + finalResultDeferred?.complete(transcript) + + // 通知 UI 最终结果 + if (finalDelivered.compareAndSet(false, true)) { + try { + listener.onFinal(transcript) + } catch (t: Throwable) { + Log.e(TAG, "notify final failed", t) + } + } + } + "conversation.item.input_audio_transcription.failed" -> { + // 识别失败 + val error = message.getAsJsonObject("error") + val errorMsg = error?.get("message")?.asString ?: "Transcription failed" + Log.e(TAG, "Transcription failed: $errorMsg") + running.set(false) + if (!finalDelivered.get()) { + try { + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, errorMsg)) + } catch (t: Throwable) { + Log.e(TAG, "notify error failed", t) + } + } + finalResultDeferred?.complete(null) + try { + audioJob?.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cancel audio job after failure failed", t) + } + audioJob = null + safeClose() + } + "error" -> { + // 通用错误 + val error = message.getAsJsonObject("error") + val errorMsg = error?.get("message")?.asString ?: "Unknown error" + Log.e(TAG, "Server error: $errorMsg") + if (running.get()) { + running.set(false) + try { + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, errorMsg)) + } catch (t: Throwable) { + Log.e(TAG, "notify error failed", t) + } + } + } + else -> { + Log.d(TAG, "Unknown event: $eventType") } - } } + } - // 等待 completed 事件返回或超时 - val awaited = withTimeoutOrNull(FINAL_RESULT_TIMEOUT_MS) { resultDeferred.await() } - if (awaited == null && finalDelivered.compareAndSet(false, true)) { - // 超时后使用当前文本作为兜底结果 - val fallbackText = (finalTranscript ?: (currentTurnText + currentTurnStash)).trim() - try { - listener.onFinal(fallbackText) - } catch (notifyError: Throwable) { - Log.e(TAG, "notify final timeout fallback failed", notifyError) - } + /** + * 冲刷预缓冲区 + */ + private fun flushPrebuffer() { + var flushed: Array? = null + synchronized(prebufferLock) { + if (prebuffer.isNotEmpty()) { + flushed = prebuffer.toTypedArray() + prebuffer.clear() + } } - } catch (t: Throwable) { - Log.w(TAG, "stop cleanup failed", t) - } finally { - if (!resultDeferred.isCompleted) { - resultDeferred.complete(finalTranscript) + flushed?.forEach { b -> + sendAudioFrame(b) } - finalResultDeferred = null - safeClose() - } } - } - - private fun startCaptureAndSend() { - audioJob?.cancel() - audioJob = scope.launch(Dispatchers.IO) { - val chunkMillis = 100 // 建议 100ms 左右 - val audioManager = AudioCaptureManager( - context = context, - sampleRate = sampleRate, - channelConfig = channelConfig, - audioFormat = audioFormat, - chunkMillis = chunkMillis - ) - - if (!audioManager.hasPermission()) { - Log.e(TAG, "Missing RECORD_AUDIO permission") - listener.onError(context.getString(R.string.error_record_permission_denied)) - running.set(false) - return@launch - } - - val vadDetector = if (isVadAutoStopEnabled(context, prefs)) - VadDetector(context, sampleRate, prefs.autoStopSilenceWindowMs, prefs.autoStopSilenceSensitivity) - else null - - try { - audioManager.startCapture().collect { audioChunk -> - if (!running.get()) return@collect - - // Calculate and send audio amplitude (for waveform animation) - try { - val amplitude = calculateNormalizedAmplitude(audioChunk) - listener.onAmplitude(amplitude) - } catch (t: Throwable) { - Log.w(TAG, "Failed to calculate amplitude", t) - } - - // 客户端 VAD 自动停止(可选,与服务端 VAD 独立) - if (vadDetector?.shouldStop(audioChunk, audioChunk.size) == true) { - Log.d(TAG, "Client VAD: silence detected, stopping recording") - try { listener.onStopped() } catch (t: Throwable) { Log.e(TAG, "notify stopped failed", t) } - stop() - return@collect - } - - // 发送音频 - if (!convReady) { - synchronized(prebufferLock) { prebuffer.addLast(audioChunk.copyOf()) } - } else { - flushPrebuffer() - sendAudioFrame(audioChunk) - } + + /** + * 发送音频帧(Base64 编码) + */ + private fun sendAudioFrame(audioChunk: ByteArray) { + if (useFunAsrModel) { + try { + recognizer?.sendAudioFrame(ByteBuffer.wrap(audioChunk)) + } catch (t: Throwable) { + Log.e(TAG, "sendAudioFrame failed", t) + } + return } - } catch (t: Throwable) { - if (t is kotlinx.coroutines.CancellationException) { - Log.d(TAG, "Audio streaming cancelled: ${t.message}") - } else { - Log.e(TAG, "Audio streaming failed: ${t.message}", t) - listener.onError(context.getString(R.string.error_audio_error, t.message ?: "")) + try { + val base64Audio = Base64.encodeToString(audioChunk, Base64.NO_WRAP) + conversation?.appendAudio(base64Audio) + } catch (t: Throwable) { + Log.e(TAG, "appendAudio failed", t) } - } finally { + } + + // ========== ExternalPcmConsumer(外部推流) ========== + override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { + if (!running.get()) return + if (sampleRate != 16000 || channels != 1) return try { - vadDetector?.release() + listener.onAmplitude(calculateNormalizedAmplitude(pcm)) } catch (t: Throwable) { - Log.w(TAG, "VAD release failed", t) + Log.w(TAG, "notify amplitude failed", t) + } + + if (!convReady) { + synchronized(prebufferLock) { prebuffer.addLast(pcm.copyOf()) } + } else { + // 先冲刷预缓冲 + flushPrebuffer() + sendAudioFrame(pcm) } - } } - } - - private fun safeClose() { - convReady = false - try { - conversation?.close() - } catch (t: Throwable) { - Log.w(TAG, "conversation close failed", t) - } finally { - conversation = null + + override fun stop() { + if (!running.get()) return + running.set(false) + + // 先取消音频采集,然后调用 commit() 触发最终识别 + scope.launch(Dispatchers.IO) { + val resultDeferred = CompletableDeferred() + finalResultDeferred = resultDeferred + try { + // 通知 UI:录音阶段结束,可复位麦克风按钮 + try { + listener.onStopped() + } catch (t: Throwable) { + Log.e(TAG, "notify stopped failed", t) + } + + // 取消音频采集协程,触发 AudioRecord 释放 + try { + audioJob?.cancel() + // 等待音频采集协程完全结束,确保 AudioRecord 被完全释放 + audioJob?.join() + } catch (t: Throwable) { + Log.w(TAG, "cancel/join audio job failed", t) + } + audioJob = null + + if (useFunAsrModel) { + // Fun-ASR:调用 stop() 触发最终回调(onComplete) + try { + Log.d(TAG, "Calling recognizer.stop() to trigger final recognition") + recognizer?.stop() + } catch (t: Throwable) { + Log.w(TAG, "recognizer.stop() failed", t) + val fallbackText = (currentTurnText + currentTurnStash).trim() + if (finalDelivered.compareAndSet(false, true)) { + try { + listener.onFinal(fallbackText) + } catch (notifyError: Throwable) { + Log.e(TAG, "notify final fallback failed", notifyError) + } + } + if (!resultDeferred.isCompleted) { + resultDeferred.complete(fallbackText) + } + } + } else { + // 调用 commit() 触发最终识别(手动模式必需) + // completed 事件会在回调中调用 listener.onFinal() + try { + Log.d(TAG, "Calling commit() to trigger final recognition") + conversation?.commit() + } catch (t: Throwable) { + Log.w(TAG, "commit() failed", t) + // 如果 commit 失败,使用当前预览作为最终结果 + val fallbackText = (currentTurnText + currentTurnStash).trim() + if (finalDelivered.compareAndSet(false, true)) { + try { + listener.onFinal(fallbackText) + } catch (notifyError: Throwable) { + Log.e(TAG, "notify final fallback failed", notifyError) + } + } + if (!resultDeferred.isCompleted) { + resultDeferred.complete(fallbackText) + } + } + } + + // 等待 completed 事件返回或超时 + val awaited = withTimeoutOrNull(FINAL_RESULT_TIMEOUT_MS) { resultDeferred.await() } + if (awaited == null && finalDelivered.compareAndSet(false, true)) { + // 超时后使用当前文本作为兜底结果 + val fallbackText = (finalTranscript ?: (currentTurnText + currentTurnStash)).trim() + try { + listener.onFinal(fallbackText) + } catch (notifyError: Throwable) { + Log.e(TAG, "notify final timeout fallback failed", notifyError) + } + } + } catch (t: Throwable) { + Log.w(TAG, "stop cleanup failed", t) + } finally { + if (!resultDeferred.isCompleted) { + resultDeferred.complete(finalTranscript) + } + finalResultDeferred = null + safeClose() + } + } } - try { - recognizer?.stop() - } catch (t: Throwable) { - Log.w(TAG, "recognizer stop failed", t) - } finally { - recognizer = null + private fun startCaptureAndSend() { + audioJob?.cancel() + audioJob = scope.launch(Dispatchers.IO) { + val chunkMillis = 100 // 建议 100ms 左右 + val audioManager = AudioCaptureManager( + context = context, + sampleRate = sampleRate, + channelConfig = channelConfig, + audioFormat = audioFormat, + chunkMillis = chunkMillis, + ) + + if (!audioManager.hasPermission()) { + Log.e(TAG, "Missing RECORD_AUDIO permission") + listener.onError(context.getString(R.string.error_record_permission_denied)) + running.set(false) + return@launch + } + + val vadDetector = if (isVadAutoStopEnabled(context, prefs)) { + VadDetector(context, sampleRate, prefs.autoStopSilenceWindowMs, prefs.autoStopSilenceSensitivity) + } else { + null + } + + try { + audioManager.startCapture().collect { audioChunk -> + if (!running.get()) return@collect + + // Calculate and send audio amplitude (for waveform animation) + try { + val amplitude = calculateNormalizedAmplitude(audioChunk) + listener.onAmplitude(amplitude) + } catch (t: Throwable) { + Log.w(TAG, "Failed to calculate amplitude", t) + } + + // 客户端 VAD 自动停止(可选,与服务端 VAD 独立) + if (vadDetector?.shouldStop(audioChunk, audioChunk.size) == true) { + Log.d(TAG, "Client VAD: silence detected, stopping recording") + try { + listener.onStopped() + } catch (t: Throwable) { + Log.e(TAG, "notify stopped failed", t) + } + stop() + return@collect + } + + // 发送音频 + if (!convReady) { + synchronized(prebufferLock) { prebuffer.addLast(audioChunk.copyOf()) } + } else { + flushPrebuffer() + sendAudioFrame(audioChunk) + } + } + } catch (t: Throwable) { + if (t is kotlinx.coroutines.CancellationException) { + Log.d(TAG, "Audio streaming cancelled: ${t.message}") + } else { + Log.e(TAG, "Audio streaming failed: ${t.message}", t) + listener.onError(context.getString(R.string.error_audio_error, t.message ?: "")) + } + } finally { + try { + vadDetector?.release() + } catch (t: Throwable) { + Log.w(TAG, "VAD release failed", t) + } + } + } + } + + private fun safeClose() { + convReady = false + try { + conversation?.close() + } catch (t: Throwable) { + Log.w(TAG, "conversation close failed", t) + } finally { + conversation = null + } + + try { + recognizer?.stop() + } catch (t: Throwable) { + Log.w(TAG, "recognizer stop failed", t) + } finally { + recognizer = null + } } - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsFileAsrEngine.kt index 7df893c3..bc1426a5 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsFileAsrEngine.kt @@ -24,7 +24,7 @@ class ElevenLabsFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -59,7 +59,7 @@ class ElevenLabsFileAsrEngine( .addFormDataPart( "file", "audio.wav", - tmp.asRequestBody("audio/wav".toMediaType()) + tmp.asRequestBody("audio/wav".toMediaType()), ) .addFormDataPart("model_id", MODEL_ID) .addFormDataPart("tag_audio_events", "false") @@ -83,14 +83,16 @@ class ElevenLabsFileAsrEngine( val extra = extractErrorHint(bodyStr) val detail = formatHttpDetail(r.message, extra) listener.onError( - context.getString(R.string.error_request_failed_http, r.code, detail) + context.getString(R.string.error_request_failed_http, r.code, detail), ) return } val text = parseTextFromResponse(bodyStr) if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -98,12 +100,14 @@ class ElevenLabsFileAsrEngine( } } catch (t: Throwable) { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } /** * 从响应体中提取错误提示信息 @@ -163,7 +167,9 @@ class ElevenLabsFileAsrEngine( if (t.isNotEmpty()) list.add(t) } list.joinToString("\n").trim() - } else "" + } else { + "" + } } else -> "" } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsStreamAsrEngine.kt index 49efae6a..6c4fd304 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/ElevenLabsStreamAsrEngine.kt @@ -40,7 +40,7 @@ class ElevenLabsStreamAsrEngine( private val scope: CoroutineScope, private val prefs: Prefs, private val listener: StreamingAsrEngine.Listener, - private val externalPcmMode: Boolean = false + private val externalPcmMode: Boolean = false, ) : StreamingAsrEngine, ExternalPcmConsumer { companion object { @@ -143,55 +143,58 @@ class ElevenLabsStreamAsrEngine( .addHeader("xi-api-key", prefs.elevenApiKey.trim()) .build() - ws = httpClient.newWebSocket(request, object : WebSocketListener() { - override fun onOpen(webSocket: WebSocket, response: Response) { - Log.d(TAG, "WebSocket opened: $response") - wsReady.set(true) - flushPrebuffer() - } + ws = httpClient.newWebSocket( + request, + object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + Log.d(TAG, "WebSocket opened: $response") + wsReady.set(true) + flushPrebuffer() + } - override fun onMessage(webSocket: WebSocket, text: String) { - handleMessage(text) - } + override fun onMessage(webSocket: WebSocket, text: String) { + handleMessage(text) + } - override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { - Log.d(TAG, "WebSocket closing: code=$code reason=$reason") - } + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + Log.d(TAG, "WebSocket closing: code=$code reason=$reason") + } - override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { - Log.d(TAG, "WebSocket closed: code=$code reason=$reason") - ws = null - running.set(false) - audioJob?.cancel() - audioJob = null - if (closingByUser.get() || !running.get()) { - emitFinalIfNeeded("closed") - } else if (!finalEmitted.get()) { - listener.onError( - context.getString( - R.string.error_recognize_failed_with_reason, - "connection closed" + override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { + Log.d(TAG, "WebSocket closed: code=$code reason=$reason") + ws = null + running.set(false) + audioJob?.cancel() + audioJob = null + if (closingByUser.get() || !running.get()) { + emitFinalIfNeeded("closed") + } else if (!finalEmitted.get()) { + listener.onError( + context.getString( + R.string.error_recognize_failed_with_reason, + "connection closed", + ), ) - ) + } } - } - override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { - Log.e(TAG, "WebSocket failure: ${t.message}", t) - ws = null - audioJob?.cancel() - audioJob = null - if (closingByUser.get()) { - emitFinalIfNeeded("failure_after_stop") - return + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + Log.e(TAG, "WebSocket failure: ${t.message}", t) + ws = null + audioJob?.cancel() + audioJob = null + if (closingByUser.get()) { + emitFinalIfNeeded("failure_after_stop") + return + } + val detail = response?.message ?: t.message.orEmpty() + listener.onError( + context.getString(R.string.error_recognize_failed_with_reason, detail), + ) + running.set(false) } - val detail = response?.message ?: t.message.orEmpty() - listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, detail) - ) - running.set(false) - } - }) + }, + ) } private fun startCaptureAndStream() { @@ -203,7 +206,7 @@ class ElevenLabsStreamAsrEngine( sampleRate = sampleRate, channelConfig = channelConfig, audioFormat = audioFormat, - chunkMillis = chunkMillis + chunkMillis = chunkMillis, ) if (!audioManager.hasPermission()) { @@ -218,9 +221,11 @@ class ElevenLabsStreamAsrEngine( context, sampleRate, prefs.autoStopSilenceWindowMs, - prefs.autoStopSilenceSensitivity + prefs.autoStopSilenceSensitivity, ) - } else null + } else { + null + } val maxFrames = (2000 / chunkMillis).coerceAtLeast(1) try { @@ -259,8 +264,8 @@ class ElevenLabsStreamAsrEngine( listener.onError( context.getString( R.string.error_audio_error, - t.message ?: "" - ) + t.message ?: "", + ), ) stop() } @@ -347,7 +352,7 @@ class ElevenLabsStreamAsrEngine( "error", "auth_error" -> { val err = obj.optString("error").ifBlank { "unknown" } listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, err) + context.getString(R.string.error_recognize_failed_with_reason, err), ) stop() } @@ -396,7 +401,7 @@ class ElevenLabsStreamAsrEngine( private fun hasRecordPermission(): Boolean { return ContextCompat.checkSelfPermission( context, - Manifest.permission.RECORD_AUDIO + Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/FunAsrNanoFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/FunAsrNanoFileAsrEngine.kt index a032e5ae..310eedc2 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/FunAsrNanoFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/FunAsrNanoFileAsrEngine.kt @@ -25,247 +25,247 @@ import java.io.File * - 目前仅支持 int8 版本。 */ class FunAsrNanoFileAsrEngine( - context: Context, - scope: CoroutineScope, - prefs: Prefs, - listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + context: Context, + scope: CoroutineScope, + prefs: Prefs, + listener: StreamingAsrEngine.Listener, + onRequestDuration: ((Long) -> Unit)? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { - // FunASR Nano 本地:同 SenseVoice/TeleSpeech,默认限制为 5 分钟以控制内存与处理时长 - override val maxRecordDurationMillis: Int = 5 * 60 * 1000 + // FunASR Nano 本地:同 SenseVoice/TeleSpeech,默认限制为 5 分钟以控制内存与处理时长 + override val maxRecordDurationMillis: Int = 5 * 60 * 1000 - private fun showToast(resId: Int) { - try { - Handler(Looper.getMainLooper()).post { + private fun showToast(resId: Int) { try { - Toast.makeText(context, context.getString(resId), Toast.LENGTH_SHORT).show() + Handler(Looper.getMainLooper()).post { + try { + Toast.makeText(context, context.getString(resId), Toast.LENGTH_SHORT).show() + } catch (t: Throwable) { + Log.e("FunAsrNanoFileAsrEngine", "Failed to show toast", t) + } + } } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Failed to show toast", t) + Log.e("FunAsrNanoFileAsrEngine", "Failed to post toast", t) } - } - } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Failed to post toast", t) } - } - - private fun notifyLoadStart() { - val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) - if (ui != null) { - try { - ui.onLocalModelLoadStart() - } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Failed to notify load start", t) - } - } else { - showToast(R.string.sv_loading_model) + + private fun notifyLoadStart() { + val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) + if (ui != null) { + try { + ui.onLocalModelLoadStart() + } catch (t: Throwable) { + Log.e("FunAsrNanoFileAsrEngine", "Failed to notify load start", t) + } + } else { + showToast(R.string.sv_loading_model) + } } - } - - private fun notifyLoadDone() { - val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) - if (ui != null) { - try { - ui.onLocalModelLoadDone() - } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Failed to notify load done", t) - } + + private fun notifyLoadDone() { + val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) + if (ui != null) { + try { + ui.onLocalModelLoadDone() + } catch (t: Throwable) { + Log.e("FunAsrNanoFileAsrEngine", "Failed to notify load done", t) + } + } } - } - override fun ensureReady(): Boolean { - if (!super.ensureReady()) return false - val manager = FunAsrNanoOnnxManager.getInstance() - if (!manager.isOnnxAvailable()) { - try { - listener.onError(context.getString(R.string.error_local_asr_not_ready)) - } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Failed to send error callback", t) - } - return false + override fun ensureReady(): Boolean { + if (!super.ensureReady()) return false + val manager = FunAsrNanoOnnxManager.getInstance() + if (!manager.isOnnxAvailable()) { + try { + listener.onError(context.getString(R.string.error_local_asr_not_ready)) + } catch (t: Throwable) { + Log.e("FunAsrNanoFileAsrEngine", "Failed to send error callback", t) + } + return false + } + return true } - return true - } - override suspend fun recognize(pcm: ByteArray) { - val t0 = System.currentTimeMillis() - try { - val manager = FunAsrNanoOnnxManager.getInstance() - if (!manager.isOnnxAvailable()) { - listener.onError(context.getString(R.string.error_local_asr_not_ready)) - return - } - - val base = try { - context.getExternalFilesDir(null) - } catch (t: Throwable) { - Log.w("FunAsrNanoFileAsrEngine", "Failed to get external files dir", t) - null - } ?: context.filesDir - - val probeRoot = File(base, "funasr_nano") - val variantDir = File(probeRoot, "nano-int8") - val modelDir = findFnModelDir(variantDir) ?: findFnModelDir(probeRoot) - if (modelDir == null) { - listener.onError(context.getString(R.string.error_funasr_model_missing)) - return - } - - val encoderAdaptor = File(modelDir, "encoder_adaptor.int8.onnx") - val llm = File(modelDir, "llm.int8.onnx") - val embedding = File(modelDir, "embedding.int8.onnx") - val tokenizerDir = findFnTokenizerDir(modelDir) - - val minOnnxBytes = 8L * 1024L * 1024L - val minLlmBytes = 32L * 1024L * 1024L - val tokenizerJsonOk = tokenizerDir?.let { File(it, "tokenizer.json").exists() } == true - if ( - !encoderAdaptor.exists() || - !embedding.exists() || - !llm.exists() || - tokenizerDir == null || - !tokenizerJsonOk || - encoderAdaptor.length() < minOnnxBytes || - embedding.length() < minOnnxBytes || - llm.length() < minLlmBytes - ) { - listener.onError(context.getString(R.string.error_funasr_model_missing)) - return - } - - val samples = pcmToFloatArray(pcm) - if (samples.isEmpty()) { - listener.onError(context.getString(R.string.error_audio_empty)) - return - } - - val keepMinutes = try { - prefs.fnKeepAliveMinutes - } catch (t: Throwable) { - Log.w("FunAsrNanoFileAsrEngine", "Failed to get keep alive minutes", t) - -1 - } - val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L - val alwaysKeep = keepMinutes < 0 - - val userPrompt = try { - prefs.fnUserPrompt.trim().ifBlank { "语音转写:" } - } catch (t: Throwable) { - Log.w("FunAsrNanoFileAsrEngine", "Failed to get fnUserPrompt", t) - "语音转写:" - } - - val text = manager.decodeOffline( - assetManager = null, - encoderAdaptor = encoderAdaptor.absolutePath, - llm = llm.absolutePath, - embedding = embedding.absolutePath, - tokenizerDir = tokenizerDir.absolutePath, - userPrompt = userPrompt, - provider = "cpu", - numThreads = try { - prefs.fnNumThreads - } catch (t: Throwable) { - Log.w("FunAsrNanoFileAsrEngine", "Failed to get num threads", t) - 2 - }, - samples = samples, - sampleRate = sampleRate, - keepAliveMs = keepMs, - alwaysKeep = alwaysKeep, - onLoadStart = { notifyLoadStart() }, - onLoadDone = { notifyLoadDone() } - ) - - if (text.isNullOrBlank()) { - listener.onError(context.getString(R.string.error_asr_empty_result)) - } else { - val raw = text.trim() - val useItn = try { - prefs.fnUseItn + override suspend fun recognize(pcm: ByteArray) { + val t0 = System.currentTimeMillis() + try { + val manager = FunAsrNanoOnnxManager.getInstance() + if (!manager.isOnnxAvailable()) { + listener.onError(context.getString(R.string.error_local_asr_not_ready)) + return + } + + val base = try { + context.getExternalFilesDir(null) + } catch (t: Throwable) { + Log.w("FunAsrNanoFileAsrEngine", "Failed to get external files dir", t) + null + } ?: context.filesDir + + val probeRoot = File(base, "funasr_nano") + val variantDir = File(probeRoot, "nano-int8") + val modelDir = findFnModelDir(variantDir) ?: findFnModelDir(probeRoot) + if (modelDir == null) { + listener.onError(context.getString(R.string.error_funasr_model_missing)) + return + } + + val encoderAdaptor = File(modelDir, "encoder_adaptor.int8.onnx") + val llm = File(modelDir, "llm.int8.onnx") + val embedding = File(modelDir, "embedding.int8.onnx") + val tokenizerDir = findFnTokenizerDir(modelDir) + + val minOnnxBytes = 8L * 1024L * 1024L + val minLlmBytes = 32L * 1024L * 1024L + val tokenizerJsonOk = tokenizerDir?.let { File(it, "tokenizer.json").exists() } == true + if ( + !encoderAdaptor.exists() || + !embedding.exists() || + !llm.exists() || + tokenizerDir == null || + !tokenizerJsonOk || + encoderAdaptor.length() < minOnnxBytes || + embedding.length() < minOnnxBytes || + llm.length() < minLlmBytes + ) { + listener.onError(context.getString(R.string.error_funasr_model_missing)) + return + } + + val samples = pcmToFloatArray(pcm) + if (samples.isEmpty()) { + listener.onError(context.getString(R.string.error_audio_empty)) + return + } + + val keepMinutes = try { + prefs.fnKeepAliveMinutes + } catch (t: Throwable) { + Log.w("FunAsrNanoFileAsrEngine", "Failed to get keep alive minutes", t) + -1 + } + val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L + val alwaysKeep = keepMinutes < 0 + + val userPrompt = try { + prefs.fnUserPrompt.trim().ifBlank { "语音转写:" } + } catch (t: Throwable) { + Log.w("FunAsrNanoFileAsrEngine", "Failed to get fnUserPrompt", t) + "语音转写:" + } + + val text = manager.decodeOffline( + assetManager = null, + encoderAdaptor = encoderAdaptor.absolutePath, + llm = llm.absolutePath, + embedding = embedding.absolutePath, + tokenizerDir = tokenizerDir.absolutePath, + userPrompt = userPrompt, + provider = "cpu", + numThreads = try { + prefs.fnNumThreads + } catch (t: Throwable) { + Log.w("FunAsrNanoFileAsrEngine", "Failed to get num threads", t) + 2 + }, + samples = samples, + sampleRate = sampleRate, + keepAliveMs = keepMs, + alwaysKeep = alwaysKeep, + onLoadStart = { notifyLoadStart() }, + onLoadDone = { notifyLoadDone() }, + ) + + if (text.isNullOrBlank()) { + listener.onError(context.getString(R.string.error_asr_empty_result)) + } else { + val raw = text.trim() + val useItn = try { + prefs.fnUseItn + } catch (t: Throwable) { + Log.w("FunAsrNanoFileAsrEngine", "Failed to get fnUseItn", t) + false + } + val finalText = if (useItn) ChineseItn.normalize(raw) else raw + listener.onFinal(finalText) + } } catch (t: Throwable) { - Log.w("FunAsrNanoFileAsrEngine", "Failed to get fnUseItn", t) - false + Log.e("FunAsrNanoFileAsrEngine", "Recognition failed", t) + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + } finally { + val dt = System.currentTimeMillis() - t0 + try { + onRequestDuration?.invoke(dt) + } catch (t: Throwable) { + Log.e("FunAsrNanoFileAsrEngine", "Failed to invoke duration callback", t) + } } - val finalText = if (useItn) ChineseItn.normalize(raw) else raw - listener.onFinal(finalText) - } - } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Recognition failed", t) - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) - } finally { - val dt = System.currentTimeMillis() - t0 - try { - onRequestDuration?.invoke(dt) - } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Failed to invoke duration callback", t) - } } - } - - override suspend fun recognizeFromPcm(pcm: ByteArray) { - recognize(pcm) - } - - private fun pcmToFloatArray(pcm: ByteArray): FloatArray { - if (pcm.isEmpty()) return FloatArray(0) - val n = pcm.size / 2 - val out = FloatArray(n) - val bb = java.nio.ByteBuffer.wrap(pcm).order(java.nio.ByteOrder.LITTLE_ENDIAN) - var i = 0 - while (i < n) { - val s = bb.short.toInt() - var f = s / 32768.0f - if (f > 1f) f = 1f else if (f < -1f) f = -1f - out[i] = f - i++ + + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } + + private fun pcmToFloatArray(pcm: ByteArray): FloatArray { + if (pcm.isEmpty()) return FloatArray(0) + val n = pcm.size / 2 + val out = FloatArray(n) + val bb = java.nio.ByteBuffer.wrap(pcm).order(java.nio.ByteOrder.LITTLE_ENDIAN) + var i = 0 + while (i < n) { + val s = bb.short.toInt() + var f = s / 32768.0f + if (f > 1f) f = 1f else if (f < -1f) f = -1f + out[i] = f + i++ + } + return out } - return out - } } // 公开卸载入口:供设置页在清除模型后释放本地识别器内存 fun unloadFunAsrNanoRecognizer() { - LocalModelLoadCoordinator.cancel() - FunAsrNanoOnnxManager.getInstance().unload() + LocalModelLoadCoordinator.cancel() + FunAsrNanoOnnxManager.getInstance().unload() } // 判断是否已有缓存的本地识别器(已加载或正在加载中) fun isFunAsrNanoPrepared(): Boolean { - val manager = FunAsrNanoOnnxManager.getInstance() - return manager.isPrepared() || manager.isPreparing() + val manager = FunAsrNanoOnnxManager.getInstance() + return manager.isPrepared() || manager.isPreparing() } // FunASR Nano 模型目录探测: // - 官方包内不含 tokens.txt;通过 onnx + tokenizer 目录判定(最多一层) fun findFnModelDir(root: File?): File? { - if (root == null || !root.exists()) return null - if (isFnModelDir(root)) return root - val subs = root.listFiles() ?: return null - for (f in subs) { - if (f.isDirectory && isFnModelDir(f)) return f - } - return null + if (root == null || !root.exists()) return null + if (isFnModelDir(root)) return root + val subs = root.listFiles() ?: return null + for (f in subs) { + if (f.isDirectory && isFnModelDir(f)) return f + } + return null } fun findFnTokenizerDir(modelDir: File): File? { - val direct = File(modelDir, "tokenizer.json") - if (direct.exists()) return modelDir - val qwen = File(modelDir, "Qwen3-0.6B") - if (File(qwen, "tokenizer.json").exists()) return qwen - val subs = modelDir.listFiles() ?: return null - for (f in subs) { - if (f.isDirectory && File(f, "tokenizer.json").exists()) return f - } - return null + val direct = File(modelDir, "tokenizer.json") + if (direct.exists()) return modelDir + val qwen = File(modelDir, "Qwen3-0.6B") + if (File(qwen, "tokenizer.json").exists()) return qwen + val subs = modelDir.listFiles() ?: return null + for (f in subs) { + if (f.isDirectory && File(f, "tokenizer.json").exists()) return f + } + return null } private fun isFnModelDir(dir: File): Boolean { - val encoderAdaptor = File(dir, "encoder_adaptor.int8.onnx") - val llm = File(dir, "llm.int8.onnx") - val embedding = File(dir, "embedding.int8.onnx") - if (!encoderAdaptor.exists() || !llm.exists() || !embedding.exists()) return false - return findFnTokenizerDir(dir) != null + val encoderAdaptor = File(dir, "encoder_adaptor.int8.onnx") + val llm = File(dir, "llm.int8.onnx") + val embedding = File(dir, "embedding.int8.onnx") + if (!encoderAdaptor.exists() || !llm.exists() || !embedding.exists()) return false + return findFnTokenizerDir(dir) != null } /** @@ -273,504 +273,506 @@ private fun isFnModelDir(dir: File): Boolean { */ class FunAsrNanoOnnxManager private constructor() { - companion object { - private const val TAG = "FunAsrNanoOnnxManager" + companion object { + private const val TAG = "FunAsrNanoOnnxManager" - @Volatile - private var instance: FunAsrNanoOnnxManager? = null + @Volatile + private var instance: FunAsrNanoOnnxManager? = null - fun getInstance(): FunAsrNanoOnnxManager { - return instance ?: synchronized(this) { - instance ?: FunAsrNanoOnnxManager().also { instance = it } - } + fun getInstance(): FunAsrNanoOnnxManager { + return instance ?: synchronized(this) { + instance ?: FunAsrNanoOnnxManager().also { instance = it } + } + } } - } - private val scope = CoroutineScope(SupervisorJob()) - private val mutex = Mutex() + private val scope = CoroutineScope(SupervisorJob()) + private val mutex = Mutex() - @Volatile - private var cachedConfig: RecognizerConfig? = null + @Volatile + private var cachedConfig: RecognizerConfig? = null - @Volatile - private var cachedRecognizer: ReflectiveRecognizer? = null + @Volatile + private var cachedRecognizer: ReflectiveRecognizer? = null - @Volatile - private var preparing: Boolean = false + @Volatile + private var preparing: Boolean = false - @Volatile - private var clsOfflineRecognizer: Class<*>? = null + @Volatile + private var clsOfflineRecognizer: Class<*>? = null - @Volatile - private var clsOfflineRecognizerConfig: Class<*>? = null + @Volatile + private var clsOfflineRecognizerConfig: Class<*>? = null - @Volatile - private var clsOfflineModelConfig: Class<*>? = null + @Volatile + private var clsOfflineModelConfig: Class<*>? = null - @Volatile - private var clsFeatureConfig: Class<*>? = null + @Volatile + private var clsFeatureConfig: Class<*>? = null - @Volatile - private var clsOfflineFunAsrNanoModelConfig: Class<*>? = null + @Volatile + private var clsOfflineFunAsrNanoModelConfig: Class<*>? = null - @Volatile - private var unloadJob: Job? = null + @Volatile + private var unloadJob: Job? = null - @Volatile - private var lastKeepAliveMs: Long = 0L + @Volatile + private var lastKeepAliveMs: Long = 0L - @Volatile - private var lastAlwaysKeep: Boolean = false + @Volatile + private var lastAlwaysKeep: Boolean = false - fun isOnnxAvailable(): Boolean { - return try { - Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") - Class.forName("com.k2fsa.sherpa.onnx.OfflineFunAsrNanoModelConfig") - true - } catch (t: Throwable) { - Log.d(TAG, "sherpa-onnx not available", t) - false - } - } - - fun unload() { - val snapshot = cachedRecognizer ?: return - scope.launch { - val shouldRelease = mutex.withLock { - if (cachedRecognizer !== snapshot) return@withLock false - cachedRecognizer = null - cachedConfig = null - unloadJob?.cancel() - unloadJob = null - true - } - if (shouldRelease) { - try { - snapshot.release() + fun isOnnxAvailable(): Boolean { + return try { + Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") + Class.forName("com.k2fsa.sherpa.onnx.OfflineFunAsrNanoModelConfig") + true } catch (t: Throwable) { - Log.e(TAG, "Failed to release recognizer on unload", t) + Log.d(TAG, "sherpa-onnx not available", t) + false + } + } + + fun unload() { + val snapshot = cachedRecognizer ?: return + scope.launch { + val shouldRelease = mutex.withLock { + if (cachedRecognizer !== snapshot) return@withLock false + cachedRecognizer = null + cachedConfig = null + unloadJob?.cancel() + unloadJob = null + true + } + if (shouldRelease) { + try { + snapshot.release() + } catch (t: Throwable) { + Log.e(TAG, "Failed to release recognizer on unload", t) + } + Log.d(TAG, "Recognizer unloaded") + } } - Log.d(TAG, "Recognizer unloaded") - } } - } - fun isPrepared(): Boolean = cachedRecognizer != null + fun isPrepared(): Boolean = cachedRecognizer != null - fun isPreparing(): Boolean = preparing + fun isPreparing(): Boolean = preparing - private fun scheduleAutoUnload(keepAliveMs: Long, alwaysKeep: Boolean) { - unloadJob?.cancel() - if (alwaysKeep) { - Log.d(TAG, "Recognizer will be kept alive indefinitely") - return - } - if (keepAliveMs <= 0L) { - Log.d(TAG, "Auto-unloading immediately (keepAliveMs=$keepAliveMs)") - unload() - return - } - Log.d(TAG, "Scheduling auto-unload in ${keepAliveMs}ms") - unloadJob = scope.launch { - delay(keepAliveMs) - Log.d(TAG, "Auto-unloading recognizer after timeout") - unload() - } - } - - private data class RecognizerConfig( - val encoderAdaptor: String, - val llm: String, - val embedding: String, - val tokenizerDir: String, - val userPrompt: String, - val provider: String, - val numThreads: Int, - val sampleRate: Int, - val featureDim: Int - ) - - private fun initClasses() { - if (clsOfflineRecognizer == null) { - clsOfflineRecognizer = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") - clsOfflineRecognizerConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizerConfig") - clsOfflineModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineModelConfig") - clsFeatureConfig = Class.forName("com.k2fsa.sherpa.onnx.FeatureConfig") - clsOfflineFunAsrNanoModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineFunAsrNanoModelConfig") - Log.d(TAG, "Initialized sherpa-onnx reflection classes for FunASR Nano") + private fun scheduleAutoUnload(keepAliveMs: Long, alwaysKeep: Boolean) { + unloadJob?.cancel() + if (alwaysKeep) { + Log.d(TAG, "Recognizer will be kept alive indefinitely") + return + } + if (keepAliveMs <= 0L) { + Log.d(TAG, "Auto-unloading immediately (keepAliveMs=$keepAliveMs)") + unload() + return + } + Log.d(TAG, "Scheduling auto-unload in ${keepAliveMs}ms") + unloadJob = scope.launch { + delay(keepAliveMs) + Log.d(TAG, "Auto-unloading recognizer after timeout") + unload() + } } - } - - private fun trySetField(target: Any, name: String, value: Any?): Boolean { - return try { - val f = target.javaClass.getDeclaredField(name) - f.isAccessible = true - f.set(target, value) - true - } catch (t: Throwable) { - try { - val methodName = "set" + name.replaceFirstChar { - if (it.isLowerCase()) it.titlecase() else it.toString() + + private data class RecognizerConfig( + val encoderAdaptor: String, + val llm: String, + val embedding: String, + val tokenizerDir: String, + val userPrompt: String, + val provider: String, + val numThreads: Int, + val sampleRate: Int, + val featureDim: Int, + ) + + private fun initClasses() { + if (clsOfflineRecognizer == null) { + clsOfflineRecognizer = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") + clsOfflineRecognizerConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizerConfig") + clsOfflineModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineModelConfig") + clsFeatureConfig = Class.forName("com.k2fsa.sherpa.onnx.FeatureConfig") + clsOfflineFunAsrNanoModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineFunAsrNanoModelConfig") + Log.d(TAG, "Initialized sherpa-onnx reflection classes for FunASR Nano") } - val m = if (value == null) { - target.javaClass.getMethod(methodName, Any::class.java) - } else { - target.javaClass.getMethod(methodName, value.javaClass) + } + + private fun trySetField(target: Any, name: String, value: Any?): Boolean { + return try { + val f = target.javaClass.getDeclaredField(name) + f.isAccessible = true + f.set(target, value) + true + } catch (t: Throwable) { + try { + val methodName = "set" + name.replaceFirstChar { + if (it.isLowerCase()) it.titlecase() else it.toString() + } + val m = if (value == null) { + target.javaClass.getMethod(methodName, Any::class.java) + } else { + target.javaClass.getMethod(methodName, value.javaClass) + } + m.invoke(target, value) + true + } catch (t2: Throwable) { + Log.w(TAG, "Failed to set field '$name'", t2) + false + } } - m.invoke(target, value) - true - } catch (t2: Throwable) { - Log.w(TAG, "Failed to set field '$name'", t2) - false - } } - } - - private fun buildFeatureConfig(sampleRate: Int, featureDim: Int): Any { - val feat = clsFeatureConfig!!.getDeclaredConstructor().newInstance() - trySetField(feat, "sampleRate", sampleRate) - trySetField(feat, "featureDim", featureDim) - return feat - } - - private fun buildFunAsrNanoModelConfig( - encoderAdaptor: String, - llm: String, - embedding: String, - tokenizerDir: String, - userPrompt: String - ): Any { - val inst = clsOfflineFunAsrNanoModelConfig!!.getDeclaredConstructor().newInstance() - trySetField(inst, "encoderAdaptor", encoderAdaptor) - trySetField(inst, "llm", llm) - trySetField(inst, "embedding", embedding) - trySetField(inst, "tokenizer", tokenizerDir) - trySetField(inst, "userPrompt", userPrompt) - return inst - } - - private fun buildModelConfig( - encoderAdaptor: String, - llm: String, - embedding: String, - tokenizerDir: String, - userPrompt: String, - numThreads: Int, - provider: String - ): Any { - val modelConfig = clsOfflineModelConfig!!.getDeclaredConstructor().newInstance() - trySetField(modelConfig, "tokens", "") - trySetField(modelConfig, "numThreads", numThreads) - trySetField(modelConfig, "provider", provider) - trySetField(modelConfig, "debug", false) - - val funasrNano = buildFunAsrNanoModelConfig(encoderAdaptor, llm, embedding, tokenizerDir, userPrompt) - if (!trySetField(modelConfig, "funasrNano", funasrNano)) { - trySetField(modelConfig, "funasr_nano", funasrNano) + + private fun buildFeatureConfig(sampleRate: Int, featureDim: Int): Any { + val feat = clsFeatureConfig!!.getDeclaredConstructor().newInstance() + trySetField(feat, "sampleRate", sampleRate) + trySetField(feat, "featureDim", featureDim) + return feat } - return modelConfig - } - - private fun buildRecognizerConfig(config: RecognizerConfig): Any { - val modelConfig = buildModelConfig( - encoderAdaptor = config.encoderAdaptor, - llm = config.llm, - embedding = config.embedding, - tokenizerDir = config.tokenizerDir, - userPrompt = config.userPrompt, - numThreads = config.numThreads, - provider = config.provider - ) - val featConfig = buildFeatureConfig(config.sampleRate, config.featureDim) - val recConfig = clsOfflineRecognizerConfig!!.getDeclaredConstructor().newInstance() - if (!trySetField(recConfig, "modelConfig", modelConfig)) { - trySetField(recConfig, "model_config", modelConfig) + + private fun buildFunAsrNanoModelConfig( + encoderAdaptor: String, + llm: String, + embedding: String, + tokenizerDir: String, + userPrompt: String, + ): Any { + val inst = clsOfflineFunAsrNanoModelConfig!!.getDeclaredConstructor().newInstance() + trySetField(inst, "encoderAdaptor", encoderAdaptor) + trySetField(inst, "llm", llm) + trySetField(inst, "embedding", embedding) + trySetField(inst, "tokenizer", tokenizerDir) + trySetField(inst, "userPrompt", userPrompt) + return inst } - if (!trySetField(recConfig, "featConfig", featConfig)) { - trySetField(recConfig, "feat_config", featConfig) + + private fun buildModelConfig( + encoderAdaptor: String, + llm: String, + embedding: String, + tokenizerDir: String, + userPrompt: String, + numThreads: Int, + provider: String, + ): Any { + val modelConfig = clsOfflineModelConfig!!.getDeclaredConstructor().newInstance() + trySetField(modelConfig, "tokens", "") + trySetField(modelConfig, "numThreads", numThreads) + trySetField(modelConfig, "provider", provider) + trySetField(modelConfig, "debug", false) + + val funasrNano = buildFunAsrNanoModelConfig(encoderAdaptor, llm, embedding, tokenizerDir, userPrompt) + if (!trySetField(modelConfig, "funasrNano", funasrNano)) { + trySetField(modelConfig, "funasr_nano", funasrNano) + } + return modelConfig } - trySetField(recConfig, "decodingMethod", "greedy_search") - trySetField(recConfig, "maxActivePaths", 4) - return recConfig - } - - private fun createRecognizer(assetManager: android.content.res.AssetManager?, recConfig: Any): Any { - val ctor = if (assetManager == null) { - try { - clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) - } catch (t: Throwable) { - Log.d(TAG, "No single-param constructor, using AssetManager variant", t) - clsOfflineRecognizer!!.getDeclaredConstructor( - android.content.res.AssetManager::class.java, - clsOfflineRecognizerConfig - ) - } - } else { - try { - clsOfflineRecognizer!!.getDeclaredConstructor( - android.content.res.AssetManager::class.java, - clsOfflineRecognizerConfig + + private fun buildRecognizerConfig(config: RecognizerConfig): Any { + val modelConfig = buildModelConfig( + encoderAdaptor = config.encoderAdaptor, + llm = config.llm, + embedding = config.embedding, + tokenizerDir = config.tokenizerDir, + userPrompt = config.userPrompt, + numThreads = config.numThreads, + provider = config.provider, ) - } catch (t: Throwable) { - Log.d(TAG, "No AssetManager constructor, using single-param variant", t) - clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) - } + val featConfig = buildFeatureConfig(config.sampleRate, config.featureDim) + val recConfig = clsOfflineRecognizerConfig!!.getDeclaredConstructor().newInstance() + if (!trySetField(recConfig, "modelConfig", modelConfig)) { + trySetField(recConfig, "model_config", modelConfig) + } + if (!trySetField(recConfig, "featConfig", featConfig)) { + trySetField(recConfig, "feat_config", featConfig) + } + trySetField(recConfig, "decodingMethod", "greedy_search") + trySetField(recConfig, "maxActivePaths", 4) + return recConfig } - return if (ctor.parameterCount == 2) { - ctor.newInstance(assetManager, recConfig) - } else { - ctor.newInstance(recConfig) + + private fun createRecognizer(assetManager: android.content.res.AssetManager?, recConfig: Any): Any { + val ctor = if (assetManager == null) { + try { + clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) + } catch (t: Throwable) { + Log.d(TAG, "No single-param constructor, using AssetManager variant", t) + clsOfflineRecognizer!!.getDeclaredConstructor( + android.content.res.AssetManager::class.java, + clsOfflineRecognizerConfig, + ) + } + } else { + try { + clsOfflineRecognizer!!.getDeclaredConstructor( + android.content.res.AssetManager::class.java, + clsOfflineRecognizerConfig, + ) + } catch (t: Throwable) { + Log.d(TAG, "No AssetManager constructor, using single-param variant", t) + clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) + } + } + return if (ctor.parameterCount == 2) { + ctor.newInstance(assetManager, recConfig) + } else { + ctor.newInstance(recConfig) + } } - } - private fun releaseRecognizerSafely(recognizer: ReflectiveRecognizer?, reason: String) { - if (recognizer == null) return - try { - recognizer.release() - } catch (t: Throwable) { - Log.e(TAG, "Failed to release recognizer ($reason)", t) + private fun releaseRecognizerSafely(recognizer: ReflectiveRecognizer?, reason: String) { + if (recognizer == null) return + try { + recognizer.release() + } catch (t: Throwable) { + Log.e(TAG, "Failed to release recognizer ($reason)", t) + } } - } - private fun invokeCallbackSafely(name: String, callback: (() -> Unit)?) { - if (callback == null) return - try { - callback() - } catch (t: Throwable) { - Log.e(TAG, "$name callback failed", t) + private fun invokeCallbackSafely(name: String, callback: (() -> Unit)?) { + if (callback == null) return + try { + callback() + } catch (t: Throwable) { + Log.e(TAG, "$name callback failed", t) + } } - } - - private suspend fun ensurePreparedLocked( - assetManager: android.content.res.AssetManager?, - config: RecognizerConfig, - onLoadStart: (() -> Unit)?, - onLoadDone: (() -> Unit)? - ): ReflectiveRecognizer? { - initClasses() - val cached = cachedRecognizer - if (cached != null && cachedConfig == config) return cached - - preparing = true - unloadJob?.cancel() - unloadJob = null - - var newRecognizer: ReflectiveRecognizer? = null - try { - currentCoroutineContext().ensureActive() - invokeCallbackSafely("onLoadStart", onLoadStart) - currentCoroutineContext().ensureActive() - - val recConfig = buildRecognizerConfig(config) - currentCoroutineContext().ensureActive() - val raw = createRecognizer(assetManager, recConfig) - newRecognizer = ReflectiveRecognizer(raw, clsOfflineRecognizer!!) - currentCoroutineContext().ensureActive() - - val oldRecognizer = cachedRecognizer - cachedRecognizer = newRecognizer - cachedConfig = config - invokeCallbackSafely("onLoadDone", onLoadDone) - - if (oldRecognizer != null && oldRecognizer !== newRecognizer) { - releaseRecognizerSafely(oldRecognizer, "old") - } - return newRecognizer - } catch (t: CancellationException) { - releaseRecognizerSafely(newRecognizer, "canceled") - throw t - } catch (t: Throwable) { - releaseRecognizerSafely(newRecognizer, "failed") - throw t - } finally { - preparing = false + + private suspend fun ensurePreparedLocked( + assetManager: android.content.res.AssetManager?, + config: RecognizerConfig, + onLoadStart: (() -> Unit)?, + onLoadDone: (() -> Unit)?, + ): ReflectiveRecognizer? { + initClasses() + val cached = cachedRecognizer + if (cached != null && cachedConfig == config) return cached + + preparing = true + unloadJob?.cancel() + unloadJob = null + + var newRecognizer: ReflectiveRecognizer? = null + try { + currentCoroutineContext().ensureActive() + invokeCallbackSafely("onLoadStart", onLoadStart) + currentCoroutineContext().ensureActive() + + val recConfig = buildRecognizerConfig(config) + currentCoroutineContext().ensureActive() + val raw = createRecognizer(assetManager, recConfig) + newRecognizer = ReflectiveRecognizer(raw, clsOfflineRecognizer!!) + currentCoroutineContext().ensureActive() + + val oldRecognizer = cachedRecognizer + cachedRecognizer = newRecognizer + cachedConfig = config + invokeCallbackSafely("onLoadDone", onLoadDone) + + if (oldRecognizer != null && oldRecognizer !== newRecognizer) { + releaseRecognizerSafely(oldRecognizer, "old") + } + return newRecognizer + } catch (t: CancellationException) { + releaseRecognizerSafely(newRecognizer, "canceled") + throw t + } catch (t: Throwable) { + releaseRecognizerSafely(newRecognizer, "failed") + throw t + } finally { + preparing = false + } } - } - - suspend fun decodeOffline( - assetManager: android.content.res.AssetManager?, - encoderAdaptor: String, - llm: String, - embedding: String, - tokenizerDir: String, - userPrompt: String, - provider: String, - numThreads: Int, - samples: FloatArray, - sampleRate: Int, - keepAliveMs: Long, - alwaysKeep: Boolean, - onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null - ): String? = mutex.withLock { - try { - val cfg = RecognizerConfig( - encoderAdaptor = encoderAdaptor, - llm = llm, - embedding = embedding, - tokenizerDir = tokenizerDir, - userPrompt = userPrompt, - provider = provider, - numThreads = numThreads, - sampleRate = sampleRate, - featureDim = 80 - ) - val recognizer = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) - ?: return@withLock null - lastKeepAliveMs = keepAliveMs - lastAlwaysKeep = alwaysKeep - val stream = recognizer.createStream() - try { - stream.acceptWaveform(samples, sampleRate) - val text = recognizer.decode(stream) - scheduleAutoUnload(keepAliveMs, alwaysKeep) - return@withLock text - } finally { - stream.release() - } - } catch (t: CancellationException) { - throw t - } catch (t: Throwable) { - Log.e(TAG, "Failed to decode offline FunASR Nano: ${t.message}", t) - return@withLock null + + suspend fun decodeOffline( + assetManager: android.content.res.AssetManager?, + encoderAdaptor: String, + llm: String, + embedding: String, + tokenizerDir: String, + userPrompt: String, + provider: String, + numThreads: Int, + samples: FloatArray, + sampleRate: Int, + keepAliveMs: Long, + alwaysKeep: Boolean, + onLoadStart: (() -> Unit)? = null, + onLoadDone: (() -> Unit)? = null, + ): String? = mutex.withLock { + try { + val cfg = RecognizerConfig( + encoderAdaptor = encoderAdaptor, + llm = llm, + embedding = embedding, + tokenizerDir = tokenizerDir, + userPrompt = userPrompt, + provider = provider, + numThreads = numThreads, + sampleRate = sampleRate, + featureDim = 80, + ) + val recognizer = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) + ?: return@withLock null + lastKeepAliveMs = keepAliveMs + lastAlwaysKeep = alwaysKeep + val stream = recognizer.createStream() + try { + stream.acceptWaveform(samples, sampleRate) + val text = recognizer.decode(stream) + scheduleAutoUnload(keepAliveMs, alwaysKeep) + return@withLock text + } finally { + stream.release() + } + } catch (t: CancellationException) { + throw t + } catch (t: Throwable) { + Log.e(TAG, "Failed to decode offline FunASR Nano: ${t.message}", t) + return@withLock null + } } - } - - suspend fun prepare( - assetManager: android.content.res.AssetManager?, - encoderAdaptor: String, - llm: String, - embedding: String, - tokenizerDir: String, - userPrompt: String, - provider: String, - numThreads: Int, - keepAliveMs: Long, - alwaysKeep: Boolean, - onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null - ): Boolean = mutex.withLock { - try { - val cfg = RecognizerConfig( - encoderAdaptor = encoderAdaptor, - llm = llm, - embedding = embedding, - tokenizerDir = tokenizerDir, - userPrompt = userPrompt, - provider = provider, - numThreads = numThreads, - sampleRate = 16000, - featureDim = 80 - ) - val ok = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) != null - if (!ok) return@withLock false - lastKeepAliveMs = keepAliveMs - lastAlwaysKeep = alwaysKeep - true - } catch (t: CancellationException) { - throw t - } catch (t: Throwable) { - Log.e(TAG, "Failed to prepare FunASR Nano recognizer: ${t.message}", t) - false + + suspend fun prepare( + assetManager: android.content.res.AssetManager?, + encoderAdaptor: String, + llm: String, + embedding: String, + tokenizerDir: String, + userPrompt: String, + provider: String, + numThreads: Int, + keepAliveMs: Long, + alwaysKeep: Boolean, + onLoadStart: (() -> Unit)? = null, + onLoadDone: (() -> Unit)? = null, + ): Boolean = mutex.withLock { + try { + val cfg = RecognizerConfig( + encoderAdaptor = encoderAdaptor, + llm = llm, + embedding = embedding, + tokenizerDir = tokenizerDir, + userPrompt = userPrompt, + provider = provider, + numThreads = numThreads, + sampleRate = 16000, + featureDim = 80, + ) + val ok = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) != null + if (!ok) return@withLock false + lastKeepAliveMs = keepAliveMs + lastAlwaysKeep = alwaysKeep + true + } catch (t: CancellationException) { + throw t + } catch (t: Throwable) { + Log.e(TAG, "Failed to prepare FunASR Nano recognizer: ${t.message}", t) + false + } } - } } /** * FunASR Nano 预加载:根据当前配置尝试构建本地识别器,便于降低首次点击等待 */ fun preloadFunAsrNanoIfConfigured( - context: Context, - prefs: Prefs, - onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null, - suppressToastOnStart: Boolean = false, - forImmediateUse: Boolean = false + context: Context, + prefs: Prefs, + onLoadStart: (() -> Unit)? = null, + onLoadDone: (() -> Unit)? = null, + suppressToastOnStart: Boolean = false, + forImmediateUse: Boolean = false, ) { - try { - val manager = FunAsrNanoOnnxManager.getInstance() - if (!manager.isOnnxAvailable()) return - - val base = context.getExternalFilesDir(null) ?: context.filesDir - - val probeRoot = File(base, "funasr_nano") - val variantDir = File(probeRoot, "nano-int8") - val modelDir = findFnModelDir(variantDir) ?: findFnModelDir(probeRoot) ?: return - - val encoderAdaptor = File(modelDir, "encoder_adaptor.int8.onnx") - val llm = File(modelDir, "llm.int8.onnx") - val embedding = File(modelDir, "embedding.int8.onnx") - val tokenizerDir = findFnTokenizerDir(modelDir) ?: return - - val minOnnxBytes = 8L * 1024L * 1024L - val minLlmBytes = 32L * 1024L * 1024L - if ( - !encoderAdaptor.exists() || - !embedding.exists() || - !llm.exists() || - encoderAdaptor.length() < minOnnxBytes || - embedding.length() < minOnnxBytes || - llm.length() < minLlmBytes || - !File(tokenizerDir, "tokenizer.json").exists() - ) return - - val keepMinutes = prefs.fnKeepAliveMinutes - val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L - val alwaysKeep = keepMinutes < 0 - - val userPrompt = prefs.fnUserPrompt.trim().ifBlank { "语音转写:" } - - val numThreads = prefs.fnNumThreads - val key = "funasr_nano|" + - "encoder=${encoderAdaptor.absolutePath}|" + - "llm=${llm.absolutePath}|" + - "embedding=${embedding.absolutePath}|" + - "tokenizer=${tokenizerDir.absolutePath}|" + - "prompt=$userPrompt|" + - "provider=cpu|" + - "threads=$numThreads" - - val mainHandler = Handler(Looper.getMainLooper()) - LocalModelLoadCoordinator.request(key) { - val t0 = android.os.SystemClock.uptimeMillis() - val ok = manager.prepare( - assetManager = null, - encoderAdaptor = encoderAdaptor.absolutePath, - llm = llm.absolutePath, - embedding = embedding.absolutePath, - tokenizerDir = tokenizerDir.absolutePath, - userPrompt = userPrompt, - provider = "cpu", - numThreads = numThreads, - keepAliveMs = keepMs, - alwaysKeep = alwaysKeep, - onLoadStart = { - if (!suppressToastOnStart) { - mainHandler.post { - Toast.makeText( - context, - context.getString(R.string.sv_loading_model), - Toast.LENGTH_SHORT - ).show() + try { + val manager = FunAsrNanoOnnxManager.getInstance() + if (!manager.isOnnxAvailable()) return + + val base = context.getExternalFilesDir(null) ?: context.filesDir + + val probeRoot = File(base, "funasr_nano") + val variantDir = File(probeRoot, "nano-int8") + val modelDir = findFnModelDir(variantDir) ?: findFnModelDir(probeRoot) ?: return + + val encoderAdaptor = File(modelDir, "encoder_adaptor.int8.onnx") + val llm = File(modelDir, "llm.int8.onnx") + val embedding = File(modelDir, "embedding.int8.onnx") + val tokenizerDir = findFnTokenizerDir(modelDir) ?: return + + val minOnnxBytes = 8L * 1024L * 1024L + val minLlmBytes = 32L * 1024L * 1024L + if ( + !encoderAdaptor.exists() || + !embedding.exists() || + !llm.exists() || + encoderAdaptor.length() < minOnnxBytes || + embedding.length() < minOnnxBytes || + llm.length() < minLlmBytes || + !File(tokenizerDir, "tokenizer.json").exists() + ) { + return + } + + val keepMinutes = prefs.fnKeepAliveMinutes + val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L + val alwaysKeep = keepMinutes < 0 + + val userPrompt = prefs.fnUserPrompt.trim().ifBlank { "语音转写:" } + + val numThreads = prefs.fnNumThreads + val key = "funasr_nano|" + + "encoder=${encoderAdaptor.absolutePath}|" + + "llm=${llm.absolutePath}|" + + "embedding=${embedding.absolutePath}|" + + "tokenizer=${tokenizerDir.absolutePath}|" + + "prompt=$userPrompt|" + + "provider=cpu|" + + "threads=$numThreads" + + val mainHandler = Handler(Looper.getMainLooper()) + LocalModelLoadCoordinator.request(key) { + val t0 = android.os.SystemClock.uptimeMillis() + val ok = manager.prepare( + assetManager = null, + encoderAdaptor = encoderAdaptor.absolutePath, + llm = llm.absolutePath, + embedding = embedding.absolutePath, + tokenizerDir = tokenizerDir.absolutePath, + userPrompt = userPrompt, + provider = "cpu", + numThreads = numThreads, + keepAliveMs = keepMs, + alwaysKeep = alwaysKeep, + onLoadStart = { + if (!suppressToastOnStart) { + mainHandler.post { + Toast.makeText( + context, + context.getString(R.string.sv_loading_model), + Toast.LENGTH_SHORT, + ).show() + } + } + onLoadStart?.invoke() + }, + onLoadDone = onLoadDone, + ) + + if (ok && !forImmediateUse) { + val dt = (android.os.SystemClock.uptimeMillis() - t0).coerceAtLeast(0) + mainHandler.post { + Toast.makeText( + context, + context.getString(R.string.sv_model_ready_with_ms, dt), + Toast.LENGTH_SHORT, + ).show() + } } - } - onLoadStart?.invoke() - }, - onLoadDone = onLoadDone - ) - - if (ok && !forImmediateUse) { - val dt = (android.os.SystemClock.uptimeMillis() - t0).coerceAtLeast(0) - mainHandler.post { - Toast.makeText( - context, - context.getString(R.string.sv_model_ready_with_ms, dt), - Toast.LENGTH_SHORT - ).show() } - } + } catch (t: Throwable) { + Log.e("FunAsrNanoFileAsrEngine", "Failed to preload FunASR Nano model", t) } - } catch (t: Throwable) { - Log.e("FunAsrNanoFileAsrEngine", "Failed to preload FunASR Nano model", t) - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/GeminiFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/GeminiFileAsrEngine.kt index cd1ad8c1..25817351 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/GeminiFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/GeminiFileAsrEngine.kt @@ -27,7 +27,7 @@ class GeminiFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -75,14 +75,16 @@ class GeminiFileAsrEngine( val hint = extractGeminiError(str) val detail = formatHttpDetail(r.message, hint) listener.onError( - context.getString(R.string.error_request_failed_http, r.code, detail) + context.getString(R.string.error_request_failed_http, r.code, detail), ) return } val text = parseGeminiText(str) if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -90,51 +92,68 @@ class GeminiFileAsrEngine( } } catch (t: Throwable) { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } /** * 构建 Gemini API 请求体 */ private fun buildGeminiRequestBody(base64Wav: String, prompt: String, model: String): String { val inlineAudio = JSONObject().apply { - put("inline_data", JSONObject().apply { - put("mime_type", "audio/wav") - put("data", base64Wav) - }) + put( + "inline_data", + JSONObject().apply { + put("mime_type", "audio/wav") + put("data", base64Wav) + }, + ) } val systemInstruction = JSONObject().apply { - put("parts", org.json.JSONArray().apply { - put(JSONObject().apply { put("text", prompt) }) - }) + put( + "parts", + org.json.JSONArray().apply { + put(JSONObject().apply { put("text", prompt) }) + }, + ) } val user = JSONObject().apply { put("role", "user") - put("parts", org.json.JSONArray().apply { - put(inlineAudio) - }) + put( + "parts", + org.json.JSONArray().apply { + put(inlineAudio) + }, + ) } return JSONObject().apply { put("system_instruction", systemInstruction) put("contents", org.json.JSONArray().apply { put(user) }) - put("generation_config", JSONObject().apply { - put("temperature", 0) - if (prefs.geminiDisableThinking) { - // 根据模型类型设置合适的 thinkingBudget - val budget = when { - model.contains("2.5-pro", ignoreCase = true) -> 128 - model.contains("2.5-flash", ignoreCase = true) -> 0 // Flash 可以为 0 - else -> 0 // 其他情况默认为 0 + put( + "generation_config", + JSONObject().apply { + put("temperature", 0) + if (prefs.geminiDisableThinking) { + // 根据模型类型设置合适的 thinkingBudget + val budget = when { + model.contains("2.5-pro", ignoreCase = true) -> 128 + model.contains("2.5-flash", ignoreCase = true) -> 0 // Flash 可以为 0 + else -> 0 // 其他情况默认为 0 + } + put( + "thinkingConfig", + JSONObject().apply { + put("thinkingBudget", budget) + }, + ) } - put("thinkingConfig", JSONObject().apply { - put("thinkingBudget", budget) - }) - } - }) + }, + ) }.toString() } @@ -169,7 +188,9 @@ class GeminiFileAsrEngine( val msg = e?.optString("message").orEmpty() val status = e?.optString("status").orEmpty() listOf(status, msg).filter { it.isNotBlank() }.joinToString(": ") - } else body.take(200).trim() + } else { + body.take(200).trim() + } } catch (t: Throwable) { Log.e(TAG, "Failed to parse Gemini error", t) body.take(200).trim() @@ -192,7 +213,10 @@ class GeminiFileAsrEngine( for (i in 0 until parts.length()) { val p = parts.optJSONObject(i) ?: continue val t = p.optString("text").trim() - if (t.isNotEmpty()) { txt = t; break } + if (t.isNotEmpty()) { + txt = t + break + } } txt } catch (t: Throwable) { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/GenericPushFileAsrAdapter.kt b/app/src/main/java/com/brycewg/asrkb/asr/GenericPushFileAsrAdapter.kt index ecf46cdc..fa190861 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/GenericPushFileAsrAdapter.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/GenericPushFileAsrAdapter.kt @@ -25,10 +25,12 @@ class GenericPushFileAsrAdapter( private val scope: CoroutineScope, private val prefs: Prefs, private val listener: StreamingAsrEngine.Listener, - private val recognizer: PcmBatchRecognizer + private val recognizer: PcmBatchRecognizer, ) : StreamingAsrEngine, ExternalPcmConsumer { - companion object { private const val TAG = "PushFileAdapter" } + companion object { + private const val TAG = "PushFileAdapter" + } private val running = AtomicBoolean(false) private val bos = ByteArrayOutputStream() @@ -44,11 +46,17 @@ class GenericPushFileAsrAdapter( override fun stop() { if (!running.get()) return running.set(false) - try { listener.onStopped() } catch (t: Throwable) { Log.w(TAG, "notify stopped failed", t) } + try { + listener.onStopped() + } catch (t: Throwable) { + Log.w(TAG, "notify stopped failed", t) + } val data = bos.toByteArray() bos.reset() if (data.isEmpty()) { - try { listener.onError(context.getString(R.string.error_audio_empty)) } catch (_: Throwable) {} + try { + listener.onError(context.getString(R.string.error_audio_empty)) + } catch (_: Throwable) {} return } scope.launch(Dispatchers.IO) { @@ -57,7 +65,7 @@ class GenericPushFileAsrAdapter( context = context, prefs = prefs, pcm = data, - sampleRate = 16000 + sampleRate = 16000, ) recognizer.recognizeFromPcm(denoised) } catch (t: Throwable) { @@ -75,7 +83,15 @@ class GenericPushFileAsrAdapter( Log.w(TAG, "ignore frame: sr=$sampleRate ch=$channels") return } - try { listener.onAmplitude(calculateNormalizedAmplitude(pcm)) } catch (t: Throwable) { Log.w(TAG, "amp cb failed", t) } - try { bos.write(pcm) } catch (t: Throwable) { Log.e(TAG, "buffer write failed", t) } + try { + listener.onAmplitude(calculateNormalizedAmplitude(pcm)) + } catch (t: Throwable) { + Log.w(TAG, "amp cb failed", t) + } + try { + bos.write(pcm) + } catch (t: Throwable) { + Log.e(TAG, "buffer write failed", t) + } } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/LlmPostProcessor.kt b/app/src/main/java/com/brycewg/asrkb/asr/LlmPostProcessor.kt index 9349cb79..b7a6b729 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/LlmPostProcessor.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/LlmPostProcessor.kt @@ -11,8 +11,8 @@ import com.brycewg.asrkb.R import com.brycewg.asrkb.store.Prefs import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -import okhttp3.MediaType.Companion.toMediaType import okhttp3.Call +import okhttp3.MediaType.Companion.toMediaType import okhttp3.OkHttpClient import okhttp3.Request import okhttp3.RequestBody.Companion.toRequestBody @@ -27,931 +27,951 @@ import java.util.concurrent.TimeUnit * 使用与 Chat Completions 兼容的 API,并在存在简单字段时回退使用。 */ class LlmPostProcessor(private val client: OkHttpClient? = null) { - private val jsonMedia = "application/json; charset=utf-8".toMediaType() - @Volatile - private var activeCall: Call? = null - - /** - * LLM 测试结果 - */ - data class LlmTestResult( - val ok: Boolean, - val httpCode: Int? = null, - val message: String? = null, - val contentPreview: String? = null - ) - - /** - * /models 拉取结果 - */ - data class LlmModelsResult( - val ok: Boolean, - val models: List = emptyList(), - val httpCode: Int? = null, - val message: String? = null - ) - - /** - * 统一的底层调用结果 - */ - private data class RawCallResult( - val ok: Boolean, - val httpCode: Int? = null, - val text: String? = null, - val error: String? = null - ) - - /** - * 标准化的上层处理结果,用于向调用方传递是否成功以及返回文本。 - */ - data class LlmProcessResult( - val ok: Boolean, - val text: String, - val errorMessage: String? = null, - val httpCode: Int? = null, - // 表示本次结果是否“实际使用了 AI 输出”(调用成功并采用其文本) - val usedAi: Boolean = false, - // 是否实际发起了 LLM 请求(跳过/空输入等场景为 false) - val attempted: Boolean = false, - // LLM 请求耗时(毫秒);未尝试时为 0 - val llmMs: Long = 0 - ) - - /** - * LLM 请求配置 - */ - private data class LlmRequestConfig( - val apiKey: String, - val endpoint: String, - val model: String, - val temperature: Double, - val vendor: LlmVendor, - val enableReasoning: Boolean, - val supportsReasoningControl: Boolean, - val useCustomReasoningParams: Boolean, - val reasoningParamsOnJson: String, - val reasoningParamsOffJson: String - ) - - companion object { - private const val TAG = "LlmPostProcessor" - - /** 连接超时(秒) */ - private const val CONNECT_TIMEOUT_SECONDS = 30L - - /** 首 token 超时(秒)- streaming 模式下等待首个数据块的最大时间 */ - private const val FIRST_TOKEN_TIMEOUT_SECONDS = 60L - - } - - private fun buildRequestConfig( - apiKey: String, - endpoint: String, - model: String, - temperature: Double, - vendor: LlmVendor, - enableReasoning: Boolean, - useCustomReasoningParams: Boolean, - reasoningParamsOnJson: String, - reasoningParamsOffJson: String - ): LlmRequestConfig { - val supportsReasoning = vendor.supportsReasoningControl(model) - return LlmRequestConfig( - apiKey = apiKey, - endpoint = endpoint, - model = model, - temperature = temperature, - vendor = vendor, - enableReasoning = enableReasoning, - supportsReasoningControl = supportsReasoning, - useCustomReasoningParams = useCustomReasoningParams, - reasoningParamsOnJson = reasoningParamsOnJson, - reasoningParamsOffJson = reasoningParamsOffJson + private val jsonMedia = "application/json; charset=utf-8".toMediaType() + + @Volatile + private var activeCall: Call? = null + + /** + * LLM 测试结果 + */ + data class LlmTestResult( + val ok: Boolean, + val httpCode: Int? = null, + val message: String? = null, + val contentPreview: String? = null, ) - } - - /** - * 从 Prefs 获取活动的 LLM 配置(使用新的供应商架构) - */ - private fun getActiveConfig(prefs: Prefs): LlmRequestConfig { - val vendor = prefs.llmVendor - - // SiliconFlow 免费服务特殊处理 - if (vendor == LlmVendor.SF_FREE && !prefs.sfFreeLlmUsePaidKey) { - val model = prefs.sfFreeLlmModel - val effective = prefs.getEffectiveLlmConfig() - return buildRequestConfig( - apiKey = BuildConfig.SF_FREE_API_KEY, - endpoint = Prefs.SF_CHAT_COMPLETIONS_ENDPOINT, - model = model, - temperature = Prefs.DEFAULT_LLM_TEMPERATURE.toDouble(), - vendor = vendor, - enableReasoning = prefs.getLlmVendorReasoningEnabled(vendor), - useCustomReasoningParams = effective?.useCustomReasoningParams ?: false, - reasoningParamsOnJson = effective?.reasoningParamsOnJson ?: Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_ON_JSON, - reasoningParamsOffJson = effective?.reasoningParamsOffJson ?: Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_OFF_JSON - ) - } - // 使用统一的 getEffectiveLlmConfig - val config = prefs.getEffectiveLlmConfig() - if (config != null) { - return buildRequestConfig( - apiKey = config.apiKey, - endpoint = config.endpoint, - model = config.model, - temperature = config.temperature.toDouble(), - vendor = config.vendor, - enableReasoning = config.enableReasoning, - useCustomReasoningParams = config.useCustomReasoningParams, - reasoningParamsOnJson = config.reasoningParamsOnJson, - reasoningParamsOffJson = config.reasoningParamsOffJson - ) - } + /** + * /models 拉取结果 + */ + data class LlmModelsResult( + val ok: Boolean, + val models: List = emptyList(), + val httpCode: Int? = null, + val message: String? = null, + ) - // 回退到旧的逻辑(兼容性) - val active = prefs.getActiveLlmProvider() - val fallbackEndpoint = if (vendor.hasBuiltinEndpoint) vendor.endpoint else (active?.endpoint ?: prefs.llmEndpoint) - return buildRequestConfig( - apiKey = active?.apiKey ?: prefs.llmApiKey, - endpoint = fallbackEndpoint, - model = active?.model ?: prefs.llmModel, - temperature = (active?.temperature ?: prefs.llmTemperature).toDouble(), - vendor = vendor, - enableReasoning = prefs.getLlmVendorReasoningEnabled(vendor), - useCustomReasoningParams = false, - reasoningParamsOnJson = Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_ON_JSON, - reasoningParamsOffJson = Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_OFF_JSON + /** + * 统一的底层调用结果 + */ + private data class RawCallResult( + val ok: Boolean, + val httpCode: Int? = null, + val text: String? = null, + val error: String? = null, ) - } - - /** - * 解析 URL,自动添加 /chat/completions 后缀 - */ - private fun resolveUrl(base: String): String { - val raw = base.trim() - if (raw.isEmpty()) return Prefs.DEFAULT_LLM_ENDPOINT.trimEnd('/') + "/chat/completions" - val b = raw.trimEnd('/') - // 要求用户填写完整 URL(包含 http/https),不再自动补全协议 - val hasScheme = b.startsWith("http://", true) || b.startsWith("https://", true) - if (!hasScheme) throw IllegalArgumentException("Endpoint must start with http:// or https://") - - // 如果已直接指向 chat/completions 或 responses,则原样使用 - if (b.endsWith("/chat/completions")) return b - - // 其他情况:直接补全 /chat/completions - return "$b/chat/completions" - } - - /** - * 解析 /models URL,支持将 /chat/completions 转换为 /models - */ - private fun resolveModelsUrl(base: String): String { - val raw = base.trim() - if (raw.isEmpty()) throw IllegalArgumentException("Missing endpoint") - val b = raw.trimEnd('/') - val hasScheme = b.startsWith("http://", true) || b.startsWith("https://", true) - if (!hasScheme) throw IllegalArgumentException("Endpoint must start with http:// or https://") - if (b.endsWith("/models")) return b - if (b.endsWith("/chat/completions")) { - return b.removeSuffix("/chat/completions") + "/models" - } - return "$b/models" - } - - /** - * 根据供应商添加推理控制参数到请求体 - * - * @param body 请求 JSON 对象 - * @param config LLM 配置 - */ - private fun addReasoningParams(body: JSONObject, config: LlmRequestConfig) { - val vendor = config.vendor - if (config.useCustomReasoningParams) { - val raw = if (config.enableReasoning) config.reasoningParamsOnJson else config.reasoningParamsOffJson - val trimmed = raw.trim() - if (trimmed.isEmpty()) return - if (!trimmed.startsWith("{")) { - Log.w(TAG, "Reasoning params must be a JSON object: $trimmed") - return - } - val obj = try { - JSONObject(trimmed) - } catch (t: Throwable) { - Log.w(TAG, "Failed to parse reasoning params JSON: $trimmed", t) - return - } - val keys = obj.keys() - while (keys.hasNext()) { - val key = keys.next() - body.put(key, obj.opt(key)) - } - return - } - if (!config.supportsReasoningControl) return - - when (vendor) { - LlmVendor.SF_FREE -> { - // SiliconFlow: enable_thinking 支持显式开关 - body.put("enable_thinking", config.enableReasoning) - return - } - LlmVendor.VOLCENGINE, LlmVendor.ZHIPU -> { - // 火山/智谱:通过 thinking.type 控制开关 - val type = if (config.enableReasoning) "enabled" else "disabled" - body.put("thinking", JSONObject().put("type", type)) - return - } - LlmVendor.GEMINI -> { - // Gemini Pro 只能将预算调低;flash 系列可关闭 - if (config.enableReasoning) return - val modelLower = config.model.lowercase() - val effort = if (modelLower.contains("pro") || modelLower.startsWith("gemini-3")) "low" else "none" - body.put("reasoning_effort", effort) - return - } - LlmVendor.GROQ -> { - // Groq:仅对支持思考的模型下发对应最小值 - if (config.enableReasoning) return - val modelLower = config.model.lowercase() - val effort = when { - modelLower.contains("qwen3") || modelLower.contains("qwen/") -> "none" - modelLower.contains("gpt-oss") -> "low" - else -> return - } - body.put("reasoning_effort", effort) - return - } - LlmVendor.CEREBRAS -> { - // Cerebras 仅 gpt-oss-120b 支持 reasoning_effort,且最小为 low - val isGptOss120b = config.model.equals("gpt-oss-120b", ignoreCase = true) - if (!isGptOss120b) return - if (!config.enableReasoning) { - body.put("reasoning_effort", "low") - } - return - } - LlmVendor.FIREWORKS -> { - // Fireworks 模型有不同的推理控制行为: - // - DeepSeek V3.1/V3.2: 二进制开关,默认关闭 - // - GLM 4.5/4.6: 二进制开关,默认开启 - // - GPT-OSS: 只支持 low/medium/high,不支持 none - val modelLower = config.model.lowercase() - when { - modelLower.contains("deepseek") -> { - // DeepSeek: 开启发送 medium,关闭发送 none - body.put("reasoning_effort", if (config.enableReasoning) "medium" else "none") - } - modelLower.contains("glm") -> { - // GLM: 默认开启,仅关闭时发送 none - if (!config.enableReasoning) { - body.put("reasoning_effort", "none") - } - } - modelLower.contains("gpt-oss") -> { - // GPT-OSS: 不支持 none,开启用 medium,关闭用 low - body.put("reasoning_effort", if (config.enableReasoning) "medium" else "low") - } - } - return - } - else -> { - // fall through to generic handling - } - } + /** + * 标准化的上层处理结果,用于向调用方传递是否成功以及返回文本。 + */ + data class LlmProcessResult( + val ok: Boolean, + val text: String, + val errorMessage: String? = null, + val httpCode: Int? = null, + // 表示本次结果是否“实际使用了 AI 输出”(调用成功并采用其文本) + val usedAi: Boolean = false, + // 是否实际发起了 LLM 请求(跳过/空输入等场景为 false) + val attempted: Boolean = false, + // LLM 请求耗时(毫秒);未尝试时为 0 + val llmMs: Long = 0, + ) - when (vendor.reasoningMode) { - ReasoningMode.ENABLE_THINKING -> { - body.put("enable_thinking", config.enableReasoning) - } - ReasoningMode.REASONING_EFFORT -> { - if (!config.enableReasoning) { - body.put("reasoning_effort", "none") - } - } - ReasoningMode.THINKING_TYPE -> { - val type = if (config.enableReasoning) "enabled" else "disabled" - body.put("thinking", JSONObject().put("type", type)) - } - ReasoningMode.MODEL_SELECTION, ReasoningMode.NONE -> { - // No parameter needed - controlled via model selection or not supported - } - } - } - - /** - * 构建标准的 OpenAI Chat Completions 请求 - * - * @param config LLM 配置 - * @param messages 消息列表(JSONArray) - * @param streaming 是否启用流式传输 - * @return 构建好的 Request 对象 - */ - private fun buildRequest( - config: LlmRequestConfig, - messages: JSONArray, - streaming: Boolean = true - ): Request { - val url = resolveUrl(config.endpoint) - - val reqJson = JSONObject().apply { - if (config.model.isNotBlank()) { - put("model", config.model) - } - put("temperature", kotlin.math.round(config.temperature * 100) / 100) - put("messages", messages) - put("stream", streaming) - - // Add reasoning control parameters based on vendor - addReasoningParams(this, config) - }.toString() - - val body = reqJson.toRequestBody(jsonMedia) - val builder = Request.Builder() - .url(url) - .addHeader("Content-Type", "application/json") - .post(body) - - if (config.apiKey.isNotBlank()) { - builder.addHeader("Authorization", "Bearer ${config.apiKey}") - } + /** + * LLM 请求配置 + */ + private data class LlmRequestConfig( + val apiKey: String, + val endpoint: String, + val model: String, + val temperature: Double, + val vendor: LlmVendor, + val enableReasoning: Boolean, + val supportsReasoningControl: Boolean, + val useCustomReasoningParams: Boolean, + val reasoningParamsOnJson: String, + val reasoningParamsOffJson: String, + ) - return builder.build() - } - - /** - * 获取或创建 OkHttpClient - * - * 超时策略说明: - * - connectTimeout: 建立连接的超时时间 - * - readTimeout: 等待首个数据块的超时时间(首 token 超时) - * - writeTimeout: 写入请求体的超时时间 - * - 不设置 callTimeout: streaming 模式下总时长不受限制 - */ - private fun getHttpClient(): OkHttpClient { - return client ?: OkHttpClient.Builder() - .connectTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) - .readTimeout(0, TimeUnit.SECONDS) - .writeTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) - // 不设置 callTimeout,让 streaming 可以持续接收数据 - .build() - } - - /** - * 获取模型列表时使用的客户端(避免无限读超时) - */ - private fun getModelsHttpClient(): OkHttpClient { - return client ?: OkHttpClient.Builder() - .connectTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) - .readTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) - .writeTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) - .build() - } - - /** - * 过滤掉 AI 输出中的 ... 标签及其内容 - * 部分模型会将推理内容放在正文中,需要过滤 - * - * @param text 原始文本 - * @return 过滤后的文本 - */ - private fun filterThinkTags(text: String): String { - // 使用正则表达式移除 ... 标签及其内容 - // (?s) 表示 DOTALL 模式,让 . 可以匹配换行符 - return text.replace(Regex("""(?s).*?"""), "").trim() - } - - private fun filterThinkTagsForStreaming(text: String): String { - val filtered = filterThinkTags(text) - val start = filtered.indexOf("") - if (start < 0) return filtered - val end = filtered.indexOf("", start + 7) - if (end >= 0) return filtered - return filtered.substring(0, start).trimEnd() - } - - /** - * 从响应 JSON 中提取文本内容 - * - * 支持标准 OpenAI 格式和自定义 output_text 字段 - * - * @param responseJson 响应的 JSON 字符串 - * @param fallback 提取失败时的回退文本 - * @return 提取的文本或 fallback - */ - private fun extractTextFromResponse(responseJson: String, fallback: String): String { - return try { - val obj = JSONObject(responseJson) - val rawText = when { - obj.has("choices") -> { - val choices = obj.getJSONArray("choices") - if (choices.length() > 0) { - val msg = choices.getJSONObject(0).optJSONObject("message") - msg?.optString("content")?.ifBlank { fallback } ?: fallback - } else fallback - } - obj.has("output_text") -> obj.optString("output_text", fallback) - else -> fallback - } - // 过滤掉 think 标签及其内容 - filterThinkTags(rawText) - } catch (t: Throwable) { - Log.e(TAG, "Failed to extract text from response", t) - fallback - } - } - - /** - * 解析 OpenAI 标准 /models 返回,抽取模型 ID 列表 - */ - private fun parseModelsFromResponse(responseJson: String): List { - val obj = JSONObject(responseJson) - val data = obj.optJSONArray("data") ?: return emptyList() - val result = mutableListOf() - for (i in 0 until data.length()) { - val item = data.optJSONObject(i) ?: continue - val id = item.optString("id").trim() - if (id.isNotEmpty()) { - result.add(id) - } + companion object { + private const val TAG = "LlmPostProcessor" + + /** 连接超时(秒) */ + private const val CONNECT_TIMEOUT_SECONDS = 30L + + /** 首 token 超时(秒)- streaming 模式下等待首个数据块的最大时间 */ + private const val FIRST_TOKEN_TIMEOUT_SECONDS = 60L + } + + private fun buildRequestConfig( + apiKey: String, + endpoint: String, + model: String, + temperature: Double, + vendor: LlmVendor, + enableReasoning: Boolean, + useCustomReasoningParams: Boolean, + reasoningParamsOnJson: String, + reasoningParamsOffJson: String, + ): LlmRequestConfig { + val supportsReasoning = vendor.supportsReasoningControl(model) + return LlmRequestConfig( + apiKey = apiKey, + endpoint = endpoint, + model = model, + temperature = temperature, + vendor = vendor, + enableReasoning = enableReasoning, + supportsReasoningControl = supportsReasoning, + useCustomReasoningParams = useCustomReasoningParams, + reasoningParamsOnJson = reasoningParamsOnJson, + reasoningParamsOffJson = reasoningParamsOffJson, + ) + } + + /** + * 从 Prefs 获取活动的 LLM 配置(使用新的供应商架构) + */ + private fun getActiveConfig(prefs: Prefs): LlmRequestConfig { + val vendor = prefs.llmVendor + + // SiliconFlow 免费服务特殊处理 + if (vendor == LlmVendor.SF_FREE && !prefs.sfFreeLlmUsePaidKey) { + val model = prefs.sfFreeLlmModel + val effective = prefs.getEffectiveLlmConfig() + return buildRequestConfig( + apiKey = BuildConfig.SF_FREE_API_KEY, + endpoint = Prefs.SF_CHAT_COMPLETIONS_ENDPOINT, + model = model, + temperature = Prefs.DEFAULT_LLM_TEMPERATURE.toDouble(), + vendor = vendor, + enableReasoning = prefs.getLlmVendorReasoningEnabled(vendor), + useCustomReasoningParams = effective?.useCustomReasoningParams ?: false, + reasoningParamsOnJson = effective?.reasoningParamsOnJson ?: Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_ON_JSON, + reasoningParamsOffJson = effective?.reasoningParamsOffJson ?: Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_OFF_JSON, + ) + } + + // 使用统一的 getEffectiveLlmConfig + val config = prefs.getEffectiveLlmConfig() + if (config != null) { + return buildRequestConfig( + apiKey = config.apiKey, + endpoint = config.endpoint, + model = config.model, + temperature = config.temperature.toDouble(), + vendor = config.vendor, + enableReasoning = config.enableReasoning, + useCustomReasoningParams = config.useCustomReasoningParams, + reasoningParamsOnJson = config.reasoningParamsOnJson, + reasoningParamsOffJson = config.reasoningParamsOffJson, + ) + } + + // 回退到旧的逻辑(兼容性) + val active = prefs.getActiveLlmProvider() + val fallbackEndpoint = if (vendor.hasBuiltinEndpoint) vendor.endpoint else (active?.endpoint ?: prefs.llmEndpoint) + return buildRequestConfig( + apiKey = active?.apiKey ?: prefs.llmApiKey, + endpoint = fallbackEndpoint, + model = active?.model ?: prefs.llmModel, + temperature = (active?.temperature ?: prefs.llmTemperature).toDouble(), + vendor = vendor, + enableReasoning = prefs.getLlmVendorReasoningEnabled(vendor), + useCustomReasoningParams = false, + reasoningParamsOnJson = Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_ON_JSON, + reasoningParamsOffJson = Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_OFF_JSON, + ) + } + + /** + * 解析 URL,自动添加 /chat/completions 后缀 + */ + private fun resolveUrl(base: String): String { + val raw = base.trim() + if (raw.isEmpty()) return Prefs.DEFAULT_LLM_ENDPOINT.trimEnd('/') + "/chat/completions" + val b = raw.trimEnd('/') + // 要求用户填写完整 URL(包含 http/https),不再自动补全协议 + val hasScheme = b.startsWith("http://", true) || b.startsWith("https://", true) + if (!hasScheme) throw IllegalArgumentException("Endpoint must start with http:// or https://") + + // 如果已直接指向 chat/completions 或 responses,则原样使用 + if (b.endsWith("/chat/completions")) return b + + // 其他情况:直接补全 /chat/completions + return "$b/chat/completions" + } + + /** + * 解析 /models URL,支持将 /chat/completions 转换为 /models + */ + private fun resolveModelsUrl(base: String): String { + val raw = base.trim() + if (raw.isEmpty()) throw IllegalArgumentException("Missing endpoint") + val b = raw.trimEnd('/') + val hasScheme = b.startsWith("http://", true) || b.startsWith("https://", true) + if (!hasScheme) throw IllegalArgumentException("Endpoint must start with http:// or https://") + if (b.endsWith("/models")) return b + if (b.endsWith("/chat/completions")) { + return b.removeSuffix("/chat/completions") + "/models" + } + return "$b/models" + } + + /** + * 根据供应商添加推理控制参数到请求体 + * + * @param body 请求 JSON 对象 + * @param config LLM 配置 + */ + private fun addReasoningParams(body: JSONObject, config: LlmRequestConfig) { + val vendor = config.vendor + if (config.useCustomReasoningParams) { + val raw = if (config.enableReasoning) config.reasoningParamsOnJson else config.reasoningParamsOffJson + val trimmed = raw.trim() + if (trimmed.isEmpty()) return + if (!trimmed.startsWith("{")) { + Log.w(TAG, "Reasoning params must be a JSON object: $trimmed") + return + } + val obj = try { + JSONObject(trimmed) + } catch (t: Throwable) { + Log.w(TAG, "Failed to parse reasoning params JSON: $trimmed", t) + return + } + val keys = obj.keys() + while (keys.hasNext()) { + val key = keys.next() + body.put(key, obj.opt(key)) + } + return + } + + if (!config.supportsReasoningControl) return + + when (vendor) { + LlmVendor.SF_FREE -> { + // SiliconFlow: enable_thinking 支持显式开关 + body.put("enable_thinking", config.enableReasoning) + return + } + LlmVendor.VOLCENGINE, LlmVendor.ZHIPU -> { + // 火山/智谱:通过 thinking.type 控制开关 + val type = if (config.enableReasoning) "enabled" else "disabled" + body.put("thinking", JSONObject().put("type", type)) + return + } + LlmVendor.GEMINI -> { + // Gemini Pro 只能将预算调低;flash 系列可关闭 + if (config.enableReasoning) return + val modelLower = config.model.lowercase() + val effort = if (modelLower.contains("pro") || modelLower.startsWith("gemini-3")) "low" else "none" + body.put("reasoning_effort", effort) + return + } + LlmVendor.GROQ -> { + // Groq:仅对支持思考的模型下发对应最小值 + if (config.enableReasoning) return + val modelLower = config.model.lowercase() + val effort = when { + modelLower.contains("qwen3") || modelLower.contains("qwen/") -> "none" + modelLower.contains("gpt-oss") -> "low" + else -> return + } + body.put("reasoning_effort", effort) + return + } + LlmVendor.CEREBRAS -> { + // Cerebras 仅 gpt-oss-120b 支持 reasoning_effort,且最小为 low + val isGptOss120b = config.model.equals("gpt-oss-120b", ignoreCase = true) + if (!isGptOss120b) return + if (!config.enableReasoning) { + body.put("reasoning_effort", "low") + } + return + } + LlmVendor.FIREWORKS -> { + // Fireworks 模型有不同的推理控制行为: + // - DeepSeek V3.1/V3.2: 二进制开关,默认关闭 + // - GLM 4.5/4.6: 二进制开关,默认开启 + // - GPT-OSS: 只支持 low/medium/high,不支持 none + val modelLower = config.model.lowercase() + when { + modelLower.contains("deepseek") -> { + // DeepSeek: 开启发送 medium,关闭发送 none + body.put("reasoning_effort", if (config.enableReasoning) "medium" else "none") + } + modelLower.contains("glm") -> { + // GLM: 默认开启,仅关闭时发送 none + if (!config.enableReasoning) { + body.put("reasoning_effort", "none") + } + } + modelLower.contains("gpt-oss") -> { + // GPT-OSS: 不支持 none,开启用 medium,关闭用 low + body.put("reasoning_effort", if (config.enableReasoning) "medium" else "low") + } + } + return + } + else -> { + // fall through to generic handling + } + } + + when (vendor.reasoningMode) { + ReasoningMode.ENABLE_THINKING -> { + body.put("enable_thinking", config.enableReasoning) + } + ReasoningMode.REASONING_EFFORT -> { + if (!config.enableReasoning) { + body.put("reasoning_effort", "none") + } + } + ReasoningMode.THINKING_TYPE -> { + val type = if (config.enableReasoning) "enabled" else "disabled" + body.put("thinking", JSONObject().put("type", type)) + } + ReasoningMode.MODEL_SELECTION, ReasoningMode.NONE -> { + // No parameter needed - controlled via model selection or not supported + } + } } - return result - } - - /** - * 从 SSE 流中解析并拼接所有文本内容 - * - * @param source 响应的 BufferedSource - * @return 拼接后的完整文本 - */ - private fun parseStreamingResponse( - source: BufferedSource, - onStreamingUpdate: ((String) -> Unit)? = null - ): String { - val contentBuilder = StringBuilder() - var lastEmittedText: String? = null - val timeout = source.timeout() - // 仅首个数据块启用超时,之后允许长间隔 - timeout.timeout(FIRST_TOKEN_TIMEOUT_SECONDS, TimeUnit.SECONDS) - var waitingFirstEvent = true - var shouldStop = false - val eventBuilder = StringBuilder() - - fun emitStreamingUpdateIfNeeded() { - val handler = onStreamingUpdate ?: return - val current = filterThinkTagsForStreaming(contentBuilder.toString()) - if (current.isEmpty() || current == lastEmittedText) return - lastEmittedText = current - try { - handler(current) - } catch (t: Throwable) { - Log.w(TAG, "Streaming update callback failed", t) - } + + /** + * 构建标准的 OpenAI Chat Completions 请求 + * + * @param config LLM 配置 + * @param messages 消息列表(JSONArray) + * @param streaming 是否启用流式传输 + * @return 构建好的 Request 对象 + */ + private fun buildRequest( + config: LlmRequestConfig, + messages: JSONArray, + streaming: Boolean = true, + ): Request { + val url = resolveUrl(config.endpoint) + + val reqJson = JSONObject().apply { + if (config.model.isNotBlank()) { + put("model", config.model) + } + put("temperature", kotlin.math.round(config.temperature * 100) / 100) + put("messages", messages) + put("stream", streaming) + + // Add reasoning control parameters based on vendor + addReasoningParams(this, config) + }.toString() + + val body = reqJson.toRequestBody(jsonMedia) + val builder = Request.Builder() + .url(url) + .addHeader("Content-Type", "application/json") + .post(body) + + if (config.apiKey.isNotBlank()) { + builder.addHeader("Authorization", "Bearer ${config.apiKey}") + } + + return builder.build() + } + + /** + * 获取或创建 OkHttpClient + * + * 超时策略说明: + * - connectTimeout: 建立连接的超时时间 + * - readTimeout: 等待首个数据块的超时时间(首 token 超时) + * - writeTimeout: 写入请求体的超时时间 + * - 不设置 callTimeout: streaming 模式下总时长不受限制 + */ + private fun getHttpClient(): OkHttpClient { + return client ?: OkHttpClient.Builder() + .connectTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .readTimeout(0, TimeUnit.SECONDS) + .writeTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) + // 不设置 callTimeout,让 streaming 可以持续接收数据 + .build() + } + + /** + * 获取模型列表时使用的客户端(避免无限读超时) + */ + private fun getModelsHttpClient(): OkHttpClient { + return client ?: OkHttpClient.Builder() + .connectTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .readTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .writeTimeout(CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS) + .build() + } + + /** + * 过滤掉 AI 输出中的 ... 标签及其内容 + * 部分模型会将推理内容放在正文中,需要过滤 + * + * @param text 原始文本 + * @return 过滤后的文本 + */ + private fun filterThinkTags(text: String): String { + // 使用正则表达式移除 ... 标签及其内容 + // (?s) 表示 DOTALL 模式,让 . 可以匹配换行符 + return text.replace(Regex("""(?s).*?"""), "").trim() + } + + private fun filterThinkTagsForStreaming(text: String): String { + val filtered = filterThinkTags(text) + val start = filtered.indexOf("") + if (start < 0) return filtered + val end = filtered.indexOf("", start + 7) + if (end >= 0) return filtered + return filtered.substring(0, start).trimEnd() + } + + /** + * 从响应 JSON 中提取文本内容 + * + * 支持标准 OpenAI 格式和自定义 output_text 字段 + * + * @param responseJson 响应的 JSON 字符串 + * @param fallback 提取失败时的回退文本 + * @return 提取的文本或 fallback + */ + private fun extractTextFromResponse(responseJson: String, fallback: String): String { + return try { + val obj = JSONObject(responseJson) + val rawText = when { + obj.has("choices") -> { + val choices = obj.getJSONArray("choices") + if (choices.length() > 0) { + val msg = choices.getJSONObject(0).optJSONObject("message") + msg?.optString("content")?.ifBlank { fallback } ?: fallback + } else { + fallback + } + } + obj.has("output_text") -> obj.optString("output_text", fallback) + else -> fallback + } + // 过滤掉 think 标签及其内容 + filterThinkTags(rawText) + } catch (t: Throwable) { + Log.e(TAG, "Failed to extract text from response", t) + fallback + } } - fun flushEvent() { - if (eventBuilder.isEmpty()) return - val rawData = eventBuilder.toString().trim() - eventBuilder.clear() - - if (waitingFirstEvent) { - timeout.timeout(0, TimeUnit.MILLISECONDS) - timeout.clearDeadline() - waitingFirstEvent = false - } - - if (rawData.isEmpty()) return - if (rawData == "[DONE]") { - shouldStop = true - return - } - - try { - val json = JSONObject(rawData) - val choices = json.optJSONArray("choices") ?: return - if (choices.length() == 0) return - - val choice = choices.getJSONObject(0) - val delta = choice.optJSONObject("delta") - var appended = false - if (delta != null) { - when (val content = delta.opt("content")) { - is String -> if (content.isNotEmpty()) { - contentBuilder.append(content) - appended = true - } - is JSONArray -> { - for (i in 0 until content.length()) { - when (val item = content.get(i)) { - is String -> if (item.isNotEmpty()) { - contentBuilder.append(item) - appended = true - } - is JSONObject -> { - val textPart = item.optString("text") - if (textPart.isNotEmpty()) { - contentBuilder.append(textPart) - appended = true + /** + * 解析 OpenAI 标准 /models 返回,抽取模型 ID 列表 + */ + private fun parseModelsFromResponse(responseJson: String): List { + val obj = JSONObject(responseJson) + val data = obj.optJSONArray("data") ?: return emptyList() + val result = mutableListOf() + for (i in 0 until data.length()) { + val item = data.optJSONObject(i) ?: continue + val id = item.optString("id").trim() + if (id.isNotEmpty()) { + result.add(id) + } + } + return result + } + + /** + * 从 SSE 流中解析并拼接所有文本内容 + * + * @param source 响应的 BufferedSource + * @return 拼接后的完整文本 + */ + private fun parseStreamingResponse( + source: BufferedSource, + onStreamingUpdate: ((String) -> Unit)? = null, + ): String { + val contentBuilder = StringBuilder() + var lastEmittedText: String? = null + val timeout = source.timeout() + // 仅首个数据块启用超时,之后允许长间隔 + timeout.timeout(FIRST_TOKEN_TIMEOUT_SECONDS, TimeUnit.SECONDS) + var waitingFirstEvent = true + var shouldStop = false + val eventBuilder = StringBuilder() + + fun emitStreamingUpdateIfNeeded() { + val handler = onStreamingUpdate ?: return + val current = filterThinkTagsForStreaming(contentBuilder.toString()) + if (current.isEmpty() || current == lastEmittedText) return + lastEmittedText = current + try { + handler(current) + } catch (t: Throwable) { + Log.w(TAG, "Streaming update callback failed", t) + } + } + + fun flushEvent() { + if (eventBuilder.isEmpty()) return + val rawData = eventBuilder.toString().trim() + eventBuilder.clear() + + if (waitingFirstEvent) { + timeout.timeout(0, TimeUnit.MILLISECONDS) + timeout.clearDeadline() + waitingFirstEvent = false + } + + if (rawData.isEmpty()) return + if (rawData == "[DONE]") { + shouldStop = true + return + } + + try { + val json = JSONObject(rawData) + val choices = json.optJSONArray("choices") ?: return + if (choices.length() == 0) return + + val choice = choices.getJSONObject(0) + val delta = choice.optJSONObject("delta") + var appended = false + if (delta != null) { + when (val content = delta.opt("content")) { + is String -> if (content.isNotEmpty()) { + contentBuilder.append(content) + appended = true + } + is JSONArray -> { + for (i in 0 until content.length()) { + when (val item = content.get(i)) { + is String -> if (item.isNotEmpty()) { + contentBuilder.append(item) + appended = true + } + is JSONObject -> { + val textPart = item.optString("text") + if (textPart.isNotEmpty()) { + contentBuilder.append(textPart) + appended = true + } + } + } + } + } } - } } - } + if (appended) { + emitStreamingUpdateIfNeeded() + } + + val finishReason = choice.optString("finish_reason", "") + if (finishReason == "stop") { + shouldStop = true + } + } catch (e: Exception) { + Log.w(TAG, "Parse SSE chunk failed: $rawData", e) + } + } + + while (!source.exhausted() && !shouldStop) { + val line = try { + source.readUtf8Line() ?: break + } catch (e: IOException) { + Log.w(TAG, "Read line failed", e) + break + } + + if (line.isEmpty()) { + flushEvent() + continue + } + + // SSE 格式: 以 data: 开头的事件行,可能跨多行 + if (line.startsWith("data:")) { + eventBuilder.append(line.removePrefix("data:").trim()).append('\n') } - } } - if (appended) { - emitStreamingUpdateIfNeeded() + + // 处理未以空行结尾的事件 + if (!shouldStop) { + flushEvent() } - val finishReason = choice.optString("finish_reason", "") - if (finishReason == "stop") { - shouldStop = true + return contentBuilder.toString() + } + + /** + * 复用的底层 Chat 调用:构建请求、执行并解析文本。 + * 使用 streaming 模式,支持长时间等待和持续接收。 + * 需确保在非主线程调用。 + */ + private fun performChat( + config: LlmRequestConfig, + messages: JSONArray, + onStreamingUpdate: ((String) -> Unit)? = null, + ): RawCallResult { + val streamingResult = performChatInternal( + config, + messages, + streaming = true, + onStreamingUpdate = onStreamingUpdate, + ) + if (streamingResult.ok) return streamingResult + + // 若服务端拒绝或不支持流式,尝试回退到非流模式 + val shouldRetryWithoutStream = streamingResult.httpCode in listOf(400, 404, 405, 415, 422) || + (streamingResult.error?.contains("stream", ignoreCase = true) == true) || + (streamingResult.error?.contains("sse", ignoreCase = true) == true) + + if (!shouldRetryWithoutStream) return streamingResult + + Log.w(TAG, "Streaming call failed (code=${streamingResult.httpCode}): ${streamingResult.error ?: ""}. Retrying without stream.") + val fallback = performChatInternal(config, messages, streaming = false) + if (fallback.ok) return fallback + + return fallback.copy(error = fallback.error ?: streamingResult.error) + } + + private fun performChatInternal( + config: LlmRequestConfig, + messages: JSONArray, + streaming: Boolean, + onStreamingUpdate: ((String) -> Unit)? = null, + ): RawCallResult { + val req = try { + buildRequest(config, messages, streaming = streaming) + } catch (t: Throwable) { + Log.e(TAG, "Failed to build request", t) + return RawCallResult(false, error = "Build request failed: ${t.message}") } - } catch (e: Exception) { - Log.w(TAG, "Parse SSE chunk failed: $rawData", e) - } - } - while (!source.exhausted() && !shouldStop) { - val line = try { - source.readUtf8Line() ?: break - } catch (e: IOException) { - Log.w(TAG, "Read line failed", e) - break - } - - if (line.isEmpty()) { - flushEvent() - continue - } - - // SSE 格式: 以 data: 开头的事件行,可能跨多行 - if (line.startsWith("data:")) { - eventBuilder.append(line.removePrefix("data:").trim()).append('\n') - } - } + val http = getHttpClient() + val call = http.newCall(req) + activeCall = call + val resp = try { + call.execute() + } catch (t: Throwable) { + if (activeCall === call) { + activeCall = null + } + Log.e(TAG, "HTTP request failed", t) + return RawCallResult(false, error = t.message ?: "Network error") + } - // 处理未以空行结尾的事件 - if (!shouldStop) { - flushEvent() - } + if (!resp.isSuccessful) { + val code = resp.code + val err = try { + resp.body?.string() + } catch (_: Throwable) { + null + } finally { + resp.close() + } + if (activeCall === call) { + activeCall = null + } + return RawCallResult(false, httpCode = code, error = err?.take(256) ?: "HTTP $code") + } - return contentBuilder.toString() - } - - /** - * 复用的底层 Chat 调用:构建请求、执行并解析文本。 - * 使用 streaming 模式,支持长时间等待和持续接收。 - * 需确保在非主线程调用。 - */ - private fun performChat( - config: LlmRequestConfig, - messages: JSONArray, - onStreamingUpdate: ((String) -> Unit)? = null - ): RawCallResult { - val streamingResult = performChatInternal( - config, - messages, - streaming = true, - onStreamingUpdate = onStreamingUpdate - ) - if (streamingResult.ok) return streamingResult - - // 若服务端拒绝或不支持流式,尝试回退到非流模式 - val shouldRetryWithoutStream = streamingResult.httpCode in listOf(400, 404, 405, 415, 422) || - (streamingResult.error?.contains("stream", ignoreCase = true) == true) || - (streamingResult.error?.contains("sse", ignoreCase = true) == true) - - if (!shouldRetryWithoutStream) return streamingResult - - Log.w(TAG, "Streaming call failed (code=${streamingResult.httpCode}): ${streamingResult.error ?: ""}. Retrying without stream.") - val fallback = performChatInternal(config, messages, streaming = false) - if (fallback.ok) return fallback - - return fallback.copy(error = fallback.error ?: streamingResult.error) - } - - private fun performChatInternal( - config: LlmRequestConfig, - messages: JSONArray, - streaming: Boolean, - onStreamingUpdate: ((String) -> Unit)? = null - ): RawCallResult { - val req = try { - buildRequest(config, messages, streaming = streaming) - } catch (t: Throwable) { - Log.e(TAG, "Failed to build request", t) - return RawCallResult(false, error = "Build request failed: ${t.message}") - } + val text = try { + val body = resp.body ?: run { + Log.w(TAG, "Response body is null") + return RawCallResult(false, error = "Empty body") + } - val http = getHttpClient() - val call = http.newCall(req) - activeCall = call - val resp = try { - call.execute() - } catch (t: Throwable) { - if (activeCall === call) { - activeCall = null - } - Log.e(TAG, "HTTP request failed", t) - return RawCallResult(false, error = t.message ?: "Network error") - } + val contentType = resp.header("Content-Type") ?: body.contentType()?.toString().orEmpty() + val isEventStream = streaming && contentType.contains("text/event-stream", ignoreCase = true) - if (!resp.isSuccessful) { - val code = resp.code - val err = try { resp.body?.string() } catch (_: Throwable) { null } finally { resp.close() } - if (activeCall === call) { - activeCall = null - } - return RawCallResult(false, httpCode = code, error = err?.take(256) ?: "HTTP $code") - } + val parsed = if (isEventStream) { + parseStreamingResponse(body.source(), onStreamingUpdate = onStreamingUpdate) + } else { + val respText = body.string() + extractTextFromResponse(respText, fallback = "") + } - val text = try { - val body = resp.body ?: run { - Log.w(TAG, "Response body is null") - return RawCallResult(false, error = "Empty body") - } - - val contentType = resp.header("Content-Type") ?: body.contentType()?.toString().orEmpty() - val isEventStream = streaming && contentType.contains("text/event-stream", ignoreCase = true) - - val parsed = if (isEventStream) { - parseStreamingResponse(body.source(), onStreamingUpdate = onStreamingUpdate) - } else { - val respText = body.string() - extractTextFromResponse(respText, fallback = "") - } - - val filtered = filterThinkTags(parsed) - if (filtered.isBlank()) { - return RawCallResult(false, error = "Empty result") - } - filtered - } catch (t: Throwable) { - Log.e(TAG, "Failed to parse ${if (streaming) "streaming" else "non-streaming"} response", t) - return RawCallResult(false, error = t.message ?: "Parse error") - } finally { - try { - resp.close() - } catch (closeErr: Throwable) { - Log.w(TAG, "Close response failed", closeErr) - } - if (activeCall === call) { - activeCall = null - } - } + val filtered = filterThinkTags(parsed) + if (filtered.isBlank()) { + return RawCallResult(false, error = "Empty result") + } + filtered + } catch (t: Throwable) { + Log.e(TAG, "Failed to parse ${if (streaming) "streaming" else "non-streaming"} response", t) + return RawCallResult(false, error = t.message ?: "Parse error") + } finally { + try { + resp.close() + } catch (closeErr: Throwable) { + Log.w(TAG, "Close response failed", closeErr) + } + if (activeCall === call) { + activeCall = null + } + } - return RawCallResult(true, text = text) - } - - /** - * 取消当前进行中的 LLM 请求。 - */ - fun cancelActiveRequest() { - val call = activeCall - if (call == null) return - try { - call.cancel() - } catch (t: Throwable) { - Log.w(TAG, "Cancel active request failed", t) - } - } - - /** - * 带一次自动重试的调用。 - */ - private suspend fun performChatWithRetry( - config: LlmRequestConfig, - messages: JSONArray, - maxRetry: Int = 1, - onStreamingUpdate: ((String) -> Unit)? = null - ): RawCallResult { - var attempt = 0 - var last: RawCallResult - while (true) { - attempt++ - last = performChat(config, messages, onStreamingUpdate = onStreamingUpdate) - if (last.ok) return last - if (attempt > maxRetry) return last - Log.w(TAG, "performChat failed (attempt=$attempt), will retry once: ${last.httpCode ?: ""} ${last.error ?: ""}") - try { - kotlinx.coroutines.delay(350) - } catch (t: Throwable) { - Log.w(TAG, "Retry delay interrupted", t) - } - } - } - - /** - * 测试 LLM 调用是否可用:发送最简单 Prompt,看是否有返回内容。 - * 不改变任何业务状态,仅用于连通性自检/配置校验。 - */ - suspend fun testConnectivity(prefs: Prefs): LlmTestResult = withContext(Dispatchers.IO) { - // 基础必填校验(endpoint / model) - val active = getActiveConfig(prefs) - val requiresModel = active.vendor != LlmVendor.CUSTOM - if (active.endpoint.isBlank() || (requiresModel && active.model.isBlank())) { - val message = if (active.endpoint.isBlank()) "Missing endpoint" else "Missing model" - return@withContext LlmTestResult( - ok = false, - message = message - ) + return RawCallResult(true, text = text) } - val messages = JSONArray().apply { - put(JSONObject().apply { - put("role", "user") - put("content", "say `hi`") - }) + /** + * 取消当前进行中的 LLM 请求。 + */ + fun cancelActiveRequest() { + val call = activeCall + if (call == null) return + try { + call.cancel() + } catch (t: Throwable) { + Log.w(TAG, "Cancel active request failed", t) + } } - val result = performChat(active, messages) - if (result.ok) { - return@withContext LlmTestResult(true, contentPreview = result.text?.take(120)) - } else { - return@withContext LlmTestResult(false, httpCode = result.httpCode, message = result.error) - } - } - - /** - * 拉取 OpenAI 标准 /models 列表 - */ - suspend fun fetchModels(endpoint: String, apiKey: String): LlmModelsResult = withContext(Dispatchers.IO) { - val url = try { - resolveModelsUrl(endpoint) - } catch (t: Throwable) { - Log.e(TAG, "Resolve /models url failed", t) - return@withContext LlmModelsResult(false, message = t.message ?: "Invalid endpoint") + /** + * 带一次自动重试的调用。 + */ + private suspend fun performChatWithRetry( + config: LlmRequestConfig, + messages: JSONArray, + maxRetry: Int = 1, + onStreamingUpdate: ((String) -> Unit)? = null, + ): RawCallResult { + var attempt = 0 + var last: RawCallResult + while (true) { + attempt++ + last = performChat(config, messages, onStreamingUpdate = onStreamingUpdate) + if (last.ok) return last + if (attempt > maxRetry) return last + Log.w(TAG, "performChat failed (attempt=$attempt), will retry once: ${last.httpCode ?: ""} ${last.error ?: ""}") + try { + kotlinx.coroutines.delay(350) + } catch (t: Throwable) { + Log.w(TAG, "Retry delay interrupted", t) + } + } } - val reqBuilder = Request.Builder() - .url(url) - .get() - .addHeader("Content-Type", "application/json") + /** + * 测试 LLM 调用是否可用:发送最简单 Prompt,看是否有返回内容。 + * 不改变任何业务状态,仅用于连通性自检/配置校验。 + */ + suspend fun testConnectivity(prefs: Prefs): LlmTestResult = withContext(Dispatchers.IO) { + // 基础必填校验(endpoint / model) + val active = getActiveConfig(prefs) + val requiresModel = active.vendor != LlmVendor.CUSTOM + if (active.endpoint.isBlank() || (requiresModel && active.model.isBlank())) { + val message = if (active.endpoint.isBlank()) "Missing endpoint" else "Missing model" + return@withContext LlmTestResult( + ok = false, + message = message, + ) + } - if (apiKey.isNotBlank()) { - reqBuilder.addHeader("Authorization", "Bearer $apiKey") - } + val messages = JSONArray().apply { + put( + JSONObject().apply { + put("role", "user") + put("content", "say `hi`") + }, + ) + } - val resp = try { - getModelsHttpClient().newCall(reqBuilder.build()).execute() - } catch (t: Throwable) { - Log.e(TAG, "Fetch /models failed", t) - return@withContext LlmModelsResult(false, message = t.message ?: "Network error") + val result = performChat(active, messages) + if (result.ok) { + return@withContext LlmTestResult(true, contentPreview = result.text?.take(120)) + } else { + return@withContext LlmTestResult(false, httpCode = result.httpCode, message = result.error) + } } - val code = resp.code - val isSuccessful = resp.isSuccessful - val rawBody = try { resp.body?.string().orEmpty() } catch (t: Throwable) { - Log.w(TAG, "Read /models response failed", t) - "" - } finally { - try { - resp.close() - } catch (closeErr: Throwable) { - Log.w(TAG, "Close /models response failed", closeErr) - } - } + /** + * 拉取 OpenAI 标准 /models 列表 + */ + suspend fun fetchModels(endpoint: String, apiKey: String): LlmModelsResult = withContext(Dispatchers.IO) { + val url = try { + resolveModelsUrl(endpoint) + } catch (t: Throwable) { + Log.e(TAG, "Resolve /models url failed", t) + return@withContext LlmModelsResult(false, message = t.message ?: "Invalid endpoint") + } - if (!isSuccessful) { - val msg = rawBody.take(256).ifBlank { "HTTP $code" } - return@withContext LlmModelsResult(false, httpCode = code, message = msg) - } + val reqBuilder = Request.Builder() + .url(url) + .get() + .addHeader("Content-Type", "application/json") - val models = try { - parseModelsFromResponse(rawBody) - } catch (t: Throwable) { - Log.e(TAG, "Parse /models response failed", t) - return@withContext LlmModelsResult(false, httpCode = code, message = t.message ?: "Parse error") - } + if (apiKey.isNotBlank()) { + reqBuilder.addHeader("Authorization", "Bearer $apiKey") + } - if (models.isEmpty()) { - return@withContext LlmModelsResult(false, httpCode = code, message = "No models found") - } + val resp = try { + getModelsHttpClient().newCall(reqBuilder.build()).execute() + } catch (t: Throwable) { + Log.e(TAG, "Fetch /models failed", t) + return@withContext LlmModelsResult(false, message = t.message ?: "Network error") + } - return@withContext LlmModelsResult(true, models = models.distinct()) - } - - /** - * 与 process 等价,但返回是否成功及错误信息,便于 UI 反馈。 - * - * 用户选择的 prompt 直接作为完整的 system prompt 使用, - * 待处理的文本统一放在 user prompt 中,使用简洁的包装格式。 - */ - suspend fun processWithStatus( - input: String, - prefs: Prefs, - promptOverride: String? = null, - onStreamingUpdate: ((String) -> Unit)? = null - ): LlmProcessResult = withContext(Dispatchers.IO) { - if (input.isBlank()) { - Log.d(TAG, "Input is blank, skipping processing") - return@withContext LlmProcessResult(ok = true, text = input, usedAi = false, attempted = false, llmMs = 0) - } + val code = resp.code + val isSuccessful = resp.isSuccessful + val rawBody = try { + resp.body?.string().orEmpty() + } catch (t: Throwable) { + Log.w(TAG, "Read /models response failed", t) + "" + } finally { + try { + resp.close() + } catch (closeErr: Throwable) { + Log.w(TAG, "Close /models response failed", closeErr) + } + } - val config = getActiveConfig(prefs) - val systemPrompt = (promptOverride ?: prefs.activePromptContent) - val userInputPrefix = prefs.getLocalizedString(R.string.llm_prompt_user_input_prefix) - val userContent = "$userInputPrefix$input" - - val messages = JSONArray().apply { - put(JSONObject().apply { - put("role", "system") - put("content", systemPrompt) - }) - put(JSONObject().apply { - put("role", "user") - put("content", userContent) - }) - } + if (!isSuccessful) { + val msg = rawBody.take(256).ifBlank { "HTTP $code" } + return@withContext LlmModelsResult(false, httpCode = code, message = msg) + } - val t0 = System.nanoTime() - val result = performChatWithRetry(config, messages, onStreamingUpdate = onStreamingUpdate) - val dt = TimeUnit.NANOSECONDS - .toMillis((System.nanoTime() - t0).coerceAtLeast(0L)) - .coerceAtLeast(0L) - if (!result.ok) { - if (result.httpCode != null) { - Log.w(TAG, "LLM process() failed: HTTP ${result.httpCode}, ${result.error}") - } else { - Log.w(TAG, "LLM process() failed: ${result.error}") - } - return@withContext LlmProcessResult( - false, - text = input, - errorMessage = result.error, - httpCode = result.httpCode, - usedAi = false, - attempted = true, - llmMs = dt - ) - } + val models = try { + parseModelsFromResponse(rawBody) + } catch (t: Throwable) { + Log.e(TAG, "Parse /models response failed", t) + return@withContext LlmModelsResult(false, httpCode = code, message = t.message ?: "Parse error") + } - val text = result.text ?: input - Log.d(TAG, "Text processing completed, output length: ${text.length}") - return@withContext LlmProcessResult(true, text = text, usedAi = true, attempted = true, llmMs = dt) - } - - /** - * 与 editText 等价,但返回是否成功及错误信息,便于 UI 反馈。 - */ - suspend fun editTextWithStatus(original: String, instruction: String, prefs: Prefs): LlmProcessResult = withContext(Dispatchers.IO) { - if (original.isBlank() || instruction.isBlank()) { - Log.d(TAG, "Original or instruction is blank, skipping edit") - return@withContext LlmProcessResult(true, text = original, usedAi = false, attempted = false, llmMs = 0) + if (models.isEmpty()) { + return@withContext LlmModelsResult(false, httpCode = code, message = "No models found") + } + + return@withContext LlmModelsResult(true, models = models.distinct()) + } + + /** + * 与 process 等价,但返回是否成功及错误信息,便于 UI 反馈。 + * + * 用户选择的 prompt 直接作为完整的 system prompt 使用, + * 待处理的文本统一放在 user prompt 中,使用简洁的包装格式。 + */ + suspend fun processWithStatus( + input: String, + prefs: Prefs, + promptOverride: String? = null, + onStreamingUpdate: ((String) -> Unit)? = null, + ): LlmProcessResult = withContext(Dispatchers.IO) { + if (input.isBlank()) { + Log.d(TAG, "Input is blank, skipping processing") + return@withContext LlmProcessResult(ok = true, text = input, usedAi = false, attempted = false, llmMs = 0) + } + + val config = getActiveConfig(prefs) + val systemPrompt = (promptOverride ?: prefs.activePromptContent) + val userInputPrefix = prefs.getLocalizedString(R.string.llm_prompt_user_input_prefix) + val userContent = "$userInputPrefix$input" + + val messages = JSONArray().apply { + put( + JSONObject().apply { + put("role", "system") + put("content", systemPrompt) + }, + ) + put( + JSONObject().apply { + put("role", "user") + put("content", userContent) + }, + ) + } + + val t0 = System.nanoTime() + val result = performChatWithRetry(config, messages, onStreamingUpdate = onStreamingUpdate) + val dt = TimeUnit.NANOSECONDS + .toMillis((System.nanoTime() - t0).coerceAtLeast(0L)) + .coerceAtLeast(0L) + if (!result.ok) { + if (result.httpCode != null) { + Log.w(TAG, "LLM process() failed: HTTP ${result.httpCode}, ${result.error}") + } else { + Log.w(TAG, "LLM process() failed: ${result.error}") + } + return@withContext LlmProcessResult( + false, + text = input, + errorMessage = result.error, + httpCode = result.httpCode, + usedAi = false, + attempted = true, + llmMs = dt, + ) + } + + val text = result.text ?: input + Log.d(TAG, "Text processing completed, output length: ${text.length}") + return@withContext LlmProcessResult(true, text = text, usedAi = true, attempted = true, llmMs = dt) } - val config = getActiveConfig(prefs) + /** + * 与 editText 等价,但返回是否成功及错误信息,便于 UI 反馈。 + */ + suspend fun editTextWithStatus(original: String, instruction: String, prefs: Prefs): LlmProcessResult = withContext(Dispatchers.IO) { + if (original.isBlank() || instruction.isBlank()) { + Log.d(TAG, "Original or instruction is blank, skipping edit") + return@withContext LlmProcessResult(true, text = original, usedAi = false, attempted = false, llmMs = 0) + } + + val config = getActiveConfig(prefs) - val systemPrompt = prefs.getLocalizedString(R.string.llm_edit_system_prompt) - val instructionLabel = prefs.getLocalizedString(R.string.llm_edit_instruction_label) - val originalLabel = prefs.getLocalizedString(R.string.llm_edit_original_label) + val systemPrompt = prefs.getLocalizedString(R.string.llm_edit_system_prompt) + val instructionLabel = prefs.getLocalizedString(R.string.llm_edit_instruction_label) + val originalLabel = prefs.getLocalizedString(R.string.llm_edit_original_label) - val userContent = """ + val userContent = """ $instructionLabel $instruction $originalLabel $original - """.trimIndent() - - val messages = JSONArray().apply { - put(JSONObject().apply { - put("role", "system") - put("content", systemPrompt) - }) - put(JSONObject().apply { - put("role", "user") - put("content", userContent) - }) - } + """.trimIndent() + + val messages = JSONArray().apply { + put( + JSONObject().apply { + put("role", "system") + put("content", systemPrompt) + }, + ) + put( + JSONObject().apply { + put("role", "user") + put("content", userContent) + }, + ) + } - val t0 = System.nanoTime() - val result = performChatWithRetry(config, messages) - val dt = TimeUnit.NANOSECONDS - .toMillis((System.nanoTime() - t0).coerceAtLeast(0L)) - .coerceAtLeast(0L) - if (!result.ok) { - if (result.httpCode != null) { - Log.w(TAG, "LLM editText() failed: HTTP ${result.httpCode}, ${result.error}") - } else { - Log.w(TAG, "LLM editText() failed: ${result.error}") - } - return@withContext LlmProcessResult( - false, - text = original, - errorMessage = result.error, - httpCode = result.httpCode, - usedAi = false, - attempted = true, - llmMs = dt - ) - } + val t0 = System.nanoTime() + val result = performChatWithRetry(config, messages) + val dt = TimeUnit.NANOSECONDS + .toMillis((System.nanoTime() - t0).coerceAtLeast(0L)) + .coerceAtLeast(0L) + if (!result.ok) { + if (result.httpCode != null) { + Log.w(TAG, "LLM editText() failed: HTTP ${result.httpCode}, ${result.error}") + } else { + Log.w(TAG, "LLM editText() failed: ${result.error}") + } + return@withContext LlmProcessResult( + false, + text = original, + errorMessage = result.error, + httpCode = result.httpCode, + usedAi = false, + attempted = true, + llmMs = dt, + ) + } - val out = result.text ?: original + val out = result.text ?: original - Log.d(TAG, "Text editing completed, output length: ${out.length}") - return@withContext LlmProcessResult(true, text = out, usedAi = true, attempted = true, llmMs = dt) - } + Log.d(TAG, "Text editing completed, output length: ${out.length}") + return@withContext LlmProcessResult(true, text = out, usedAi = true, attempted = true, llmMs = dt) + } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/LlmVendor.kt b/app/src/main/java/com/brycewg/asrkb/asr/LlmVendor.kt index 71c9f828..3434eeb6 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/LlmVendor.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/LlmVendor.kt @@ -8,14 +8,18 @@ import com.brycewg.asrkb.R enum class ReasoningMode { /** No reasoning control support */ NONE, + /** Control via model selection (DeepSeek, Moonshot) */ MODEL_SELECTION, + /** SiliconFlow: enable_thinking parameter */ ENABLE_THINKING, + /** Gemini/Groq/Cerebras/OhMyGPT: reasoning_effort parameter */ REASONING_EFFORT, + /** Volcengine/Zhipu: thinking.type parameter */ - THINKING_TYPE + THINKING_TYPE, } /** @@ -37,7 +41,7 @@ enum class LlmVendor( /** How this vendor controls reasoning/thinking mode */ val reasoningMode: ReasoningMode = ReasoningMode.NONE, /** Models that support reasoning control (empty = all models) */ - val reasoningModels: Set = emptySet() + val reasoningModels: Set = emptySet(), ) { /** SiliconFlow - supports free tier and paid API */ SF_FREE( @@ -57,7 +61,7 @@ enum class LlmVendor( "Qwen/Qwen3-Next-80B-A3B-Thinking", "deepseek-ai/DeepSeek-V3.1-Terminus", "deepseek-ai/DeepSeek-V3.2", - "zai-org/GLM-4.6" + "zai-org/GLM-4.6", ), registerUrl = "https://cloud.siliconflow.cn/i/g8thUcWa", guideUrl = "https://docs.siliconflow.cn/cn/api-reference/chat-completions/chat-completions", @@ -74,8 +78,8 @@ enum class LlmVendor( "Qwen/Qwen3-235B-A22B-Thinking-2507", "deepseek-ai/DeepSeek-V3.1-Terminus", "deepseek-ai/DeepSeek-V3.2", - "zai-org/GLM-4.6" - ) + "zai-org/GLM-4.6", + ), ), /** OpenAI - GPT models */ @@ -89,7 +93,7 @@ enum class LlmVendor( guideUrl = "https://platform.openai.com/docs/quickstart", temperatureMin = 0f, temperatureMax = 2f, - reasoningMode = ReasoningMode.NONE + reasoningMode = ReasoningMode.NONE, ), /** Google Gemini */ @@ -104,7 +108,7 @@ enum class LlmVendor( "gemini-2.5-flash", "gemini-2.5-pro", "gemini-3-flash-preview", - "gemini-3-pro-preview" + "gemini-3-pro-preview", ), registerUrl = "https://aistudio.google.com/apikey", guideUrl = "https://ai.google.dev/gemini-api/docs/openai?hl=zh-cn", @@ -116,8 +120,8 @@ enum class LlmVendor( "gemini-2.5-flash", "gemini-2.5-pro", "gemini-3-flash-preview", - "gemini-3-pro-preview" - ) + "gemini-3-pro-preview", + ), ), /** DeepSeek - V3.2 models */ @@ -131,8 +135,9 @@ enum class LlmVendor( guideUrl = "https://api-docs.deepseek.com/", temperatureMin = 0f, temperatureMax = 2f, - reasoningMode = ReasoningMode.MODEL_SELECTION, // chat=non-thinking, reasoner=thinking - reasoningModels = setOf("deepseek-reasoner") + // chat=non-thinking, reasoner=thinking + reasoningMode = ReasoningMode.MODEL_SELECTION, + reasoningModels = setOf("deepseek-reasoner"), ), /** Moonshot (Kimi) */ @@ -143,14 +148,14 @@ enum class LlmVendor( defaultModel = "kimi-k2-0905-preview", models = listOf( "kimi-k2-0905-preview", - "kimi-k2-thinking" + "kimi-k2-thinking", ), registerUrl = "https://platform.moonshot.cn/console/api-keys", guideUrl = "https://platform.moonshot.cn/docs/api/chat", temperatureMin = 0f, temperatureMax = 1f, reasoningMode = ReasoningMode.MODEL_SELECTION, - reasoningModels = setOf("kimi-k2-thinking") + reasoningModels = setOf("kimi-k2-thinking"), ), /** Zhipu GLM */ @@ -166,14 +171,14 @@ enum class LlmVendor( "glm-4.5-air", "glm-4.5-flash", "glm-4-plus", - "glm-4-flashx" + "glm-4-flashx", ), registerUrl = "https://bigmodel.cn/usercenter/proj-mgmt/apikeys", guideUrl = "https://docs.bigmodel.cn/api-reference", temperatureMin = 0f, temperatureMax = 1f, reasoningMode = ReasoningMode.THINKING_TYPE, - reasoningModels = setOf("glm-4.7","glm-4.6", "glm-4.5", "glm-4.5-air", "glm-4.5-flash") + reasoningModels = setOf("glm-4.7", "glm-4.6", "glm-4.5", "glm-4.5-air", "glm-4.5-flash"), ), /** Volcengine (火山引擎) */ @@ -187,7 +192,7 @@ enum class LlmVendor( "doubao-seed-1-6-251015", "doubao-seed-1-6-flash-250828", "deepseek-v3-1-terminus", - "deepseek-v3-2-251201" + "deepseek-v3-2-251201", ), registerUrl = "https://console.volcengine.com/ark", guideUrl = "https://www.volcengine.com/docs/82379/1399328", @@ -199,8 +204,8 @@ enum class LlmVendor( "doubao-seed-1-6-251015", "doubao-seed-1-6-flash-250828", "deepseek-v3-1-terminus", - "deepseek-v3-2-251201" - ) + "deepseek-v3-2-251201", + ), ), /** Groq - fast inference */ @@ -215,7 +220,7 @@ enum class LlmVendor( "openai/gpt-oss-120b", "openai/gpt-oss-20b", "llama-3.3-70b-versatile", - "meta-llama/llama-4-maverick-17b-128e-instruct" + "meta-llama/llama-4-maverick-17b-128e-instruct", ), registerUrl = "https://console.groq.com/keys", guideUrl = "https://console.groq.com/docs/api-reference#chat-create", @@ -225,8 +230,8 @@ enum class LlmVendor( reasoningModels = setOf( "qwen/qwen3-32b", "openai/gpt-oss-120b", - "openai/gpt-oss-20b" - ) + "openai/gpt-oss-20b", + ), ), /** Cerebras - fast inference */ @@ -241,14 +246,14 @@ enum class LlmVendor( "qwen-3-32b", "qwen-3-235b-a22b-instruct-2507", "gpt-oss-120b", - "zai-glm-4.6" + "zai-glm-4.6", ), registerUrl = "https://cloud.cerebras.ai/platform", guideUrl = "https://inference-docs.cerebras.ai/api-reference/chat-completions", temperatureMin = 0f, temperatureMax = 1.5f, reasoningMode = ReasoningMode.REASONING_EFFORT, - reasoningModels = setOf("gpt-oss-120b") + reasoningModels = setOf("gpt-oss-120b"), ), /** OhMyGPT - multi-provider relay */ @@ -269,7 +274,7 @@ enum class LlmVendor( "gemini-2.5-flash-lite", "gemini-2.5-flash", "claude-haiku-4-5", - "claude-sonnet-4-5" + "claude-sonnet-4-5", ), registerUrl = "https://x.dogenet.win/i/CXuHm49s", guideUrl = "https://docs.ohmygpt.com/zh", @@ -277,10 +282,13 @@ enum class LlmVendor( temperatureMax = 2f, reasoningMode = ReasoningMode.REASONING_EFFORT, reasoningModels = setOf( - "gemini-2.5-flash-lite", "gemini-2.5-flash", - "claude-haiku-4-5", "claude-sonnet-4-5", - "gpt-5-mini", "gpt-5-nano" - ) + "gemini-2.5-flash-lite", + "gemini-2.5-flash", + "claude-haiku-4-5", + "claude-sonnet-4-5", + "gpt-5-mini", + "gpt-5-nano", + ), ), /** Fireworks AI - fast inference with multiple models */ @@ -301,7 +309,7 @@ enum class LlmVendor( "accounts/fireworks/models/gpt-oss-20b", // GLM models "accounts/fireworks/models/glm-4p6", - "accounts/fireworks/models/glm-4p7" + "accounts/fireworks/models/glm-4p7", ), registerUrl = "https://fireworks.ai/login", guideUrl = "https://docs.fireworks.ai/", @@ -317,8 +325,8 @@ enum class LlmVendor( "accounts/fireworks/models/glm-4p7", // GPT-OSS: only low/medium/high, no 'none' support "accounts/fireworks/models/gpt-oss-120b", - "accounts/fireworks/models/gpt-oss-20b" - ) + "accounts/fireworks/models/gpt-oss-20b", + ), ), /** Alibaba DashScope - 阿里云百炼 */ @@ -330,7 +338,7 @@ enum class LlmVendor( models = listOf( "qwen3.5-plus", "qwen3.5-397b-a17b", - "qwen3-max" + "qwen3-max", ), registerUrl = "https://dashscope.aliyun.com/", guideUrl = "https://help.aliyun.com/zh/dashscope/", @@ -340,8 +348,8 @@ enum class LlmVendor( reasoningModels = setOf( "qwen3.5-plus", "qwen3.5-397b-a17b", - "qwen3-max" - ) + "qwen3-max", + ), ), /** Custom - user-defined OpenAI-compatible API */ @@ -352,8 +360,9 @@ enum class LlmVendor( defaultModel = "", models = emptyList(), registerUrl = "", - guideUrl = "" - ); + guideUrl = "", + ), + ; /** Whether this vendor requires an API key */ val requiresApiKey: Boolean @@ -367,7 +376,7 @@ enum class LlmVendor( fun supportsReasoningControl(model: String): Boolean { return when (reasoningMode) { ReasoningMode.NONE -> false - ReasoningMode.MODEL_SELECTION -> false // Controlled via model selection, no switch needed + ReasoningMode.MODEL_SELECTION -> false // Controlled via model selection, no switch needed else -> reasoningModels.isEmpty() || reasoningModels.contains(model) } } @@ -395,19 +404,32 @@ enum class LlmVendor( * Ordered by: Free tier -> Domestic (China) -> International -> Custom */ fun allVendors(): List = listOf( - SF_FREE, // 1. Free service - DEEPSEEK, // 2. Domestic - popular - ZHIPU, // 3. Domestic - MOONSHOT, // 4. Domestic - VOLCENGINE, // 5. Domestic - DASHSCOPE, // 6. Domestic - Alibaba - OPENAI, // 7. International - GEMINI, // 8. International - GROQ, // 9. International - free tier - CEREBRAS, // 10. International - free tier - FIREWORKS, // 11. International - fast inference - OHMYGPT, // 12. Relay platform - CUSTOM // 13. Custom + // 1. Free service + SF_FREE, + // 2. Domestic - popular + DEEPSEEK, + // 3. Domestic + ZHIPU, + // 4. Domestic + MOONSHOT, + // 5. Domestic + VOLCENGINE, + // 6. Domestic - Alibaba + DASHSCOPE, + // 7. International + OPENAI, + // 8. International + GEMINI, + // 9. International - free tier + GROQ, + // 10. International - free tier + CEREBRAS, + // 11. International - fast inference + FIREWORKS, + // 12. Relay platform + OHMYGPT, + // 13. Custom + CUSTOM, ) /** Get built-in vendors (excluding custom) */ diff --git a/app/src/main/java/com/brycewg/asrkb/asr/LlmVendorAvailability.kt b/app/src/main/java/com/brycewg/asrkb/asr/LlmVendorAvailability.kt index 8a887cf3..6b883457 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/LlmVendorAvailability.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/LlmVendorAvailability.kt @@ -11,12 +11,12 @@ import com.brycewg.asrkb.store.Prefs internal data class LlmVendorPartition( val configured: List, - val unconfigured: List + val unconfigured: List, ) internal fun partitionLlmVendorsByConfigured( prefs: Prefs, - vendors: List + vendors: List, ): LlmVendorPartition { val configured = mutableListOf() val unconfigured = mutableListOf() @@ -29,7 +29,7 @@ internal fun partitionLlmVendorsByConfigured( } return LlmVendorPartition( configured = configured, - unconfigured = unconfigured + unconfigured = unconfigured, ) } @@ -60,4 +60,3 @@ internal fun isLlmVendorConfigured(prefs: Prefs, vendor: LlmVendor): Boolean { } private const val TAG = "LlmVendorAvailability" - diff --git a/app/src/main/java/com/brycewg/asrkb/asr/LocalModelLoadCoordinator.kt b/app/src/main/java/com/brycewg/asrkb/asr/LocalModelLoadCoordinator.kt index f490b8eb..c8c8b312 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/LocalModelLoadCoordinator.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/LocalModelLoadCoordinator.kt @@ -18,106 +18,106 @@ import kotlinx.coroutines.sync.withLock * - 不同 key 的请求会取消当前加载,并以新 key 重新加载。 */ internal object LocalModelLoadCoordinator { - private const val TAG = "LocalModelLoadCoordinator" + private const val TAG = "LocalModelLoadCoordinator" - private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) - private val stateMutex = Mutex() + private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + private val stateMutex = Mutex() - private var runningKey: String? = null - private var runningJob: Job? = null - private var pendingKey: String? = null - private var pendingJob: Job? = null + private var runningKey: String? = null + private var runningJob: Job? = null + private var pendingKey: String? = null + private var pendingJob: Job? = null - fun request(key: String, loader: suspend () -> Unit) { - scope.launch { - val currentJob = coroutineContext[Job] - if (currentJob == null) return@launch + fun request(key: String, loader: suspend () -> Unit) { + scope.launch { + val currentJob = coroutineContext[Job] + if (currentJob == null) return@launch - var shouldRunNow = false - var jobToCancel: Job? = null - var pendingToCancel: Job? = null - val dedup = stateMutex.withLock { - val running = runningJob - val pending = pendingJob - val sameRunning = running?.isActive == true && runningKey == key - val samePending = pending?.isActive == true && pendingKey == key - if (sameRunning || samePending) return@withLock true + var shouldRunNow = false + var jobToCancel: Job? = null + var pendingToCancel: Job? = null + val dedup = stateMutex.withLock { + val running = runningJob + val pending = pendingJob + val sameRunning = running?.isActive == true && runningKey == key + val samePending = pending?.isActive == true && pendingKey == key + if (sameRunning || samePending) return@withLock true - pendingToCancel = pending?.takeIf { it.isActive && it != currentJob } - pendingKey = key - pendingJob = currentJob + pendingToCancel = pending?.takeIf { it.isActive && it != currentJob } + pendingKey = key + pendingJob = currentJob - jobToCancel = running?.takeIf { !it.isCompleted } - if (jobToCancel == null) { - runningKey = key - runningJob = currentJob - pendingKey = null - pendingJob = null - shouldRunNow = true - } - false - } + jobToCancel = running?.takeIf { !it.isCompleted } + if (jobToCancel == null) { + runningKey = key + runningJob = currentJob + pendingKey = null + pendingJob = null + shouldRunNow = true + } + false + } - if (dedup) return@launch + if (dedup) return@launch - pendingToCancel?.cancel() - if (!shouldRunNow) { - jobToCancel?.cancelAndJoin() - val promoted = stateMutex.withLock { - if (pendingJob != currentJob || pendingKey != key) return@withLock false - runningKey = key - runningJob = currentJob - pendingKey = null - pendingJob = null - true - } - if (!promoted) return@launch - } + pendingToCancel?.cancel() + if (!shouldRunNow) { + jobToCancel?.cancelAndJoin() + val promoted = stateMutex.withLock { + if (pendingJob != currentJob || pendingKey != key) return@withLock false + runningKey = key + runningJob = currentJob + pendingKey = null + pendingJob = null + true + } + if (!promoted) return@launch + } - try { - loader() - } catch (t: CancellationException) { - throw t - } catch (t: Throwable) { - Log.e(TAG, "Local model load failed (key=$key)", t) - } finally { - stateMutex.withLock { - if (runningJob == currentJob) { - runningJob = null - runningKey = null - } - if (pendingJob == currentJob) { - pendingJob = null - pendingKey = null - } + try { + loader() + } catch (t: CancellationException) { + throw t + } catch (t: Throwable) { + Log.e(TAG, "Local model load failed (key=$key)", t) + } finally { + stateMutex.withLock { + if (runningJob == currentJob) { + runningJob = null + runningKey = null + } + if (pendingJob == currentJob) { + pendingJob = null + pendingKey = null + } + } + } } - } } - } - fun cancel() { - scope.launch { - val (running, pending) = stateMutex.withLock { - runningKey = null - pendingKey = null - val running = runningJob - val pending = pendingJob - pendingJob = null - running to pending - } + fun cancel() { + scope.launch { + val (running, pending) = stateMutex.withLock { + runningKey = null + pendingKey = null + val running = runningJob + val pending = pendingJob + pendingJob = null + running to pending + } - pending?.cancelAndJoin() - running?.cancelAndJoin() - stateMutex.withLock { - if (runningJob == running && (running == null || running.isCompleted)) { - runningJob = null - runningKey = null - } - if (pendingJob == pending && (pending == null || pending.isCompleted)) { - pendingJob = null - pendingKey = null + pending?.cancelAndJoin() + running?.cancelAndJoin() + stateMutex.withLock { + if (runningJob == running && (running == null || running.isCompleted)) { + runningJob = null + runningKey = null + } + if (pendingJob == pending && (pending == null || pending.isCompleted)) { + pendingJob = null + pendingKey = null + } + } } - } } - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPreload.kt b/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPreload.kt index e58e0f07..752d26d4 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPreload.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPreload.kt @@ -21,21 +21,41 @@ fun preloadLocalAsrIfConfigured( onLoadStart: (() -> Unit)? = null, onLoadDone: (() -> Unit)? = null, suppressToastOnStart: Boolean = false, - forImmediateUse: Boolean = false + forImmediateUse: Boolean = false, ) { try { when (prefs.asrVendor) { AsrVendor.SenseVoice -> preloadSenseVoiceIfConfigured( - context, prefs, onLoadStart, onLoadDone, suppressToastOnStart, forImmediateUse + context, + prefs, + onLoadStart, + onLoadDone, + suppressToastOnStart, + forImmediateUse, ) AsrVendor.FunAsrNano -> preloadFunAsrNanoIfConfigured( - context, prefs, onLoadStart, onLoadDone, suppressToastOnStart, forImmediateUse + context, + prefs, + onLoadStart, + onLoadDone, + suppressToastOnStart, + forImmediateUse, ) AsrVendor.Telespeech -> preloadTelespeechIfConfigured( - context, prefs, onLoadStart, onLoadDone, suppressToastOnStart, forImmediateUse + context, + prefs, + onLoadStart, + onLoadDone, + suppressToastOnStart, + forImmediateUse, ) AsrVendor.Paraformer -> preloadParaformerIfConfigured( - context, prefs, onLoadStart, onLoadDone, suppressToastOnStart, forImmediateUse + context, + prefs, + onLoadStart, + onLoadDone, + suppressToastOnStart, + forImmediateUse, ) else -> { /* no-op for cloud vendors */ } } @@ -71,7 +91,8 @@ fun isLocalAsrVendor(vendor: AsrVendor): Boolean { AsrVendor.SenseVoice, AsrVendor.FunAsrNano, AsrVendor.Telespeech, - AsrVendor.Paraformer -> true + AsrVendor.Paraformer, + -> true else -> false } } @@ -118,9 +139,11 @@ fun isLocalAsrReady(prefs: Prefs): Boolean { suspend fun awaitLocalAsrReady( prefs: Prefs, pollIntervalMs: Long = 50L, - maxWaitMs: Long = 0L + maxWaitMs: Long = 0L, ): Boolean { - val vendor = try { prefs.asrVendor } catch (t: Throwable) { + val vendor = try { + prefs.asrVendor + } catch (t: Throwable) { Log.e("LocalModelPreload", "awaitLocalAsrReady: read vendor failed", t) return false } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPseudoStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPseudoStreamAsrEngine.kt index 59592aec..c47800a5 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPseudoStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/LocalModelPseudoStreamAsrEngine.kt @@ -24,7 +24,7 @@ abstract class LocalModelPseudoStreamAsrEngine( protected val scope: CoroutineScope, protected val prefs: Prefs, protected val listener: StreamingAsrEngine.Listener, - protected val onRequestDuration: ((Long) -> Unit)? = null + protected val onRequestDuration: ((Long) -> Unit)? = null, ) : StreamingAsrEngine { companion object { @@ -121,7 +121,7 @@ abstract class LocalModelPseudoStreamAsrEngine( sampleRate = sampleRate, channelConfig = channelConfig, audioFormat = audioFormat, - chunkMillis = chunkMillis + chunkMillis = chunkMillis, ) if (!audioManager.hasPermission()) { @@ -149,7 +149,7 @@ abstract class LocalModelPseudoStreamAsrEngine( context = context, sampleRate = sampleRate, windowMs = stopWindowMs, - sensitivityLevel = prefs.autoStopSilenceSensitivity + sensitivityLevel = prefs.autoStopSilenceSensitivity, ) } catch (t: Throwable) { Log.e(TAG, "Failed to create stop VAD for pseudo stream", t) @@ -280,7 +280,7 @@ abstract class LocalModelPseudoStreamAsrEngine( context = context, prefs = prefs, pcm = fullPcm, - sampleRate = sampleRate + sampleRate = sampleRate, ) // stop() 会 cancel 录音协程。若直接在 finally 内调用 suspend 的 onSessionFinished, // 其内部若使用可取消的 suspend API(mutex.withLock / ensureActive 等)会被 CancellationException 中断, @@ -296,8 +296,8 @@ abstract class LocalModelPseudoStreamAsrEngine( listener.onError( context.getString( R.string.error_recognize_failed_with_reason, - t.message ?: "" - ) + t.message ?: "", + ), ) } catch (e: Throwable) { Log.e(TAG, "Failed to notify final recognition error", e) diff --git a/app/src/main/java/com/brycewg/asrkb/asr/OfflineSpeechDenoiserManager.kt b/app/src/main/java/com/brycewg/asrkb/asr/OfflineSpeechDenoiserManager.kt index b07c9836..8ea67c79 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/OfflineSpeechDenoiserManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/OfflineSpeechDenoiserManager.kt @@ -7,155 +7,157 @@ import java.nio.ByteBuffer import java.nio.ByteOrder object OfflineSpeechDenoiserManager { - private const val TAG = "OfflineDenoiser" - private const val MODEL_ASSET_PATH = "denoiser/gtcrn_simple.onnx" - - @Volatile private var denoiser: Any? = null - @Volatile private var denoiserClass: Class<*>? = null - @Volatile private var loadFailed: Boolean = false - private val runLock = Any() - - fun denoiseIfEnabled( - context: Context, - prefs: Prefs, - pcm: ByteArray, - sampleRate: Int - ): ByteArray { - if (!prefs.offlineDenoiseEnabled) return pcm - if (pcm.isEmpty() || sampleRate <= 0) return pcm - - val denoiser = getOrCreate(context) ?: return pcm - val samples = pcm16leToFloatArray(pcm) - if (samples.isEmpty()) return pcm - - return try { - synchronized(runLock) { - val cls = denoiserClass ?: return pcm - val runMethod = cls.getMethod( - "run", - FloatArray::class.java, - Int::class.javaPrimitiveType - ) - val out = runMethod.invoke(denoiser, samples, sampleRate) ?: return pcm - val outSamples = out.javaClass.getMethod("getSamples").invoke(out) as? FloatArray ?: return pcm - val outRate = out.javaClass.getMethod("getSampleRate").invoke(out) as? Int ?: sampleRate - if (outRate != sampleRate) { - Log.w(TAG, "Denoiser output sampleRate=$outRate mismatch input=$sampleRate, skip") - return pcm - } - floatArrayToPcm16le(outSamples) - } - } catch (t: Throwable) { - Log.e(TAG, "Offline denoise failed", t) - pcm - } - } - - private fun getOrCreate(context: Context): Any? { - if (denoiser != null || loadFailed) return denoiser - synchronized(this) { - if (denoiser != null || loadFailed) return denoiser - return try { - try { - System.loadLibrary("sherpa-onnx-jni") + private const val TAG = "OfflineDenoiser" + private const val MODEL_ASSET_PATH = "denoiser/gtcrn_simple.onnx" + + @Volatile private var denoiser: Any? = null + + @Volatile private var denoiserClass: Class<*>? = null + + @Volatile private var loadFailed: Boolean = false + private val runLock = Any() + + fun denoiseIfEnabled( + context: Context, + prefs: Prefs, + pcm: ByteArray, + sampleRate: Int, + ): ByteArray { + if (!prefs.offlineDenoiseEnabled) return pcm + if (pcm.isEmpty() || sampleRate <= 0) return pcm + + val denoiser = getOrCreate(context) ?: return pcm + val samples = pcm16leToFloatArray(pcm) + if (samples.isEmpty()) return pcm + + return try { + synchronized(runLock) { + val cls = denoiserClass ?: return pcm + val runMethod = cls.getMethod( + "run", + FloatArray::class.java, + Int::class.javaPrimitiveType, + ) + val out = runMethod.invoke(denoiser, samples, sampleRate) ?: return pcm + val outSamples = out.javaClass.getMethod("getSamples").invoke(out) as? FloatArray ?: return pcm + val outRate = out.javaClass.getMethod("getSampleRate").invoke(out) as? Int ?: sampleRate + if (outRate != sampleRate) { + Log.w(TAG, "Denoiser output sampleRate=$outRate mismatch input=$sampleRate, skip") + return pcm + } + floatArrayToPcm16le(outSamples) + } } catch (t: Throwable) { - Log.w(TAG, "Failed to load sherpa-onnx-jni", t) + Log.e(TAG, "Offline denoise failed", t) + pcm } - - val gtcrnClass = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig") - val modelClass = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig") - val configClass = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig") - val denoiserCls = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser") - - val gtcrn = gtcrnClass.getDeclaredConstructor().newInstance() - setField(gtcrn, "model", MODEL_ASSET_PATH) - - val modelConfig = modelClass.getDeclaredConstructor().newInstance() - setField(modelConfig, "gtcrn", gtcrn) - setField(modelConfig, "numThreads", 1) - setField(modelConfig, "debug", false) - setField(modelConfig, "provider", "cpu") - - val config = configClass.getDeclaredConstructor().newInstance() - setField(config, "model", modelConfig) - - val ctor = denoiserCls.getDeclaredConstructor( - android.content.res.AssetManager::class.java, - configClass - ) - val instance = ctor.newInstance(context.assets, config) - denoiser = instance - denoiserClass = denoiserCls - Log.i(TAG, "Offline denoiser loaded: $MODEL_ASSET_PATH") - instance - } catch (t: Throwable) { - Log.e(TAG, "Failed to initialize offline denoiser", t) - loadFailed = true - denoiser = null - null - } } - } - - private fun setField(target: Any, name: String, value: Any?): Boolean { - return try { - val field = target.javaClass.getDeclaredField(name) - field.isAccessible = true - field.set(target, value) - true - } catch (t: Throwable) { - try { - val methodName = "set" + name.replaceFirstChar { - if (it.isLowerCase()) it.titlecase() else it.toString() + + private fun getOrCreate(context: Context): Any? { + if (denoiser != null || loadFailed) return denoiser + synchronized(this) { + if (denoiser != null || loadFailed) return denoiser + return try { + try { + System.loadLibrary("sherpa-onnx-jni") + } catch (t: Throwable) { + Log.w(TAG, "Failed to load sherpa-onnx-jni", t) + } + + val gtcrnClass = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserGtcrnModelConfig") + val modelClass = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserModelConfig") + val configClass = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiserConfig") + val denoiserCls = Class.forName("com.k2fsa.sherpa.onnx.OfflineSpeechDenoiser") + + val gtcrn = gtcrnClass.getDeclaredConstructor().newInstance() + setField(gtcrn, "model", MODEL_ASSET_PATH) + + val modelConfig = modelClass.getDeclaredConstructor().newInstance() + setField(modelConfig, "gtcrn", gtcrn) + setField(modelConfig, "numThreads", 1) + setField(modelConfig, "debug", false) + setField(modelConfig, "provider", "cpu") + + val config = configClass.getDeclaredConstructor().newInstance() + setField(config, "model", modelConfig) + + val ctor = denoiserCls.getDeclaredConstructor( + android.content.res.AssetManager::class.java, + configClass, + ) + val instance = ctor.newInstance(context.assets, config) + denoiser = instance + denoiserClass = denoiserCls + Log.i(TAG, "Offline denoiser loaded: $MODEL_ASSET_PATH") + instance + } catch (t: Throwable) { + Log.e(TAG, "Failed to initialize offline denoiser", t) + loadFailed = true + denoiser = null + null + } } - val paramType = when (value) { - is Int -> Int::class.javaPrimitiveType - is Boolean -> Boolean::class.javaPrimitiveType - is Float -> Float::class.javaPrimitiveType - is Double -> Double::class.javaPrimitiveType - is Long -> Long::class.javaPrimitiveType - is String -> String::class.java - else -> value?.javaClass ?: Any::class.java + } + + private fun setField(target: Any, name: String, value: Any?): Boolean { + return try { + val field = target.javaClass.getDeclaredField(name) + field.isAccessible = true + field.set(target, value) + true + } catch (t: Throwable) { + try { + val methodName = "set" + name.replaceFirstChar { + if (it.isLowerCase()) it.titlecase() else it.toString() + } + val paramType = when (value) { + is Int -> Int::class.javaPrimitiveType + is Boolean -> Boolean::class.javaPrimitiveType + is Float -> Float::class.javaPrimitiveType + is Double -> Double::class.javaPrimitiveType + is Long -> Long::class.javaPrimitiveType + is String -> String::class.java + else -> value?.javaClass ?: Any::class.java + } + val method = target.javaClass.getMethod(methodName, paramType) + method.invoke(target, value) + true + } catch (t2: Throwable) { + Log.w(TAG, "Failed to set field $name", t2) + false + } } - val method = target.javaClass.getMethod(methodName, paramType) - method.invoke(target, value) - true - } catch (t2: Throwable) { - Log.w(TAG, "Failed to set field $name", t2) - false - } } - } - - private fun pcm16leToFloatArray(pcm: ByteArray): FloatArray { - if (pcm.isEmpty()) return FloatArray(0) - val n = pcm.size / 2 - val out = FloatArray(n) - val bb = ByteBuffer.wrap(pcm).order(ByteOrder.LITTLE_ENDIAN) - var i = 0 - while (i < n) { - val s = bb.short.toInt() - var f = s / 32768.0f - if (f > 1f) f = 1f else if (f < -1f) f = -1f - out[i] = f - i++ + + private fun pcm16leToFloatArray(pcm: ByteArray): FloatArray { + if (pcm.isEmpty()) return FloatArray(0) + val n = pcm.size / 2 + val out = FloatArray(n) + val bb = ByteBuffer.wrap(pcm).order(ByteOrder.LITTLE_ENDIAN) + var i = 0 + while (i < n) { + val s = bb.short.toInt() + var f = s / 32768.0f + if (f > 1f) f = 1f else if (f < -1f) f = -1f + out[i] = f + i++ + } + return out } - return out - } - - private fun floatArrayToPcm16le(samples: FloatArray): ByteArray { - if (samples.isEmpty()) return ByteArray(0) - val out = ByteArray(samples.size * 2) - var i = 0 - var j = 0 - while (i < samples.size) { - val f = samples[i].coerceIn(-1f, 1f) - val v = (f * 32767f).toInt() - out[j] = (v and 0xFF).toByte() - out[j + 1] = ((v shr 8) and 0xFF).toByte() - i++ - j += 2 + + private fun floatArrayToPcm16le(samples: FloatArray): ByteArray { + if (samples.isEmpty()) return ByteArray(0) + val out = ByteArray(samples.size * 2) + var i = 0 + var j = 0 + while (i < samples.size) { + val f = samples[i].coerceIn(-1f, 1f) + val v = (f * 32767f).toInt() + out[j] = (v and 0xFF).toByte() + out[j + 1] = ((v shr 8) and 0xFF).toByte() + i++ + j += 2 + } + return out } - return out - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/OpenAiFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/OpenAiFileAsrEngine.kt index 32649cfd..a1785787 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/OpenAiFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/OpenAiFileAsrEngine.kt @@ -25,7 +25,7 @@ class OpenAiFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -57,7 +57,7 @@ class OpenAiFileAsrEngine( .addFormDataPart( "file", "audio.wav", - tmp.asRequestBody("audio/wav".toMediaType()) + tmp.asRequestBody("audio/wav".toMediaType()), ) .addFormDataPart("response_format", "json") if (usePrompt && prompt.isNotEmpty()) { @@ -85,14 +85,16 @@ class OpenAiFileAsrEngine( val extra = extractErrorHint(bodyStr) val detail = formatHttpDetail(r.message, extra) listener.onError( - context.getString(R.string.error_request_failed_http, r.code, detail) + context.getString(R.string.error_request_failed_http, r.code, detail), ) return } val text = parseTextFromResponse(bodyStr) if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -100,12 +102,14 @@ class OpenAiFileAsrEngine( } } catch (t: Throwable) { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } /** * 从响应体中提取错误提示信息 diff --git a/app/src/main/java/com/brycewg/asrkb/asr/ParaformerStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/ParaformerStreamAsrEngine.kt index e671bcf1..352fa518 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/ParaformerStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/ParaformerStreamAsrEngine.kt @@ -38,7 +38,7 @@ class ParaformerStreamAsrEngine( private val scope: CoroutineScope, private val prefs: Prefs, private val listener: StreamingAsrEngine.Listener, - private val externalPcmMode: Boolean = false + private val externalPcmMode: Boolean = false, ) : StreamingAsrEngine, ExternalPcmConsumer { companion object { @@ -50,9 +50,11 @@ class ParaformerStreamAsrEngine( private val closing = AtomicBoolean(false) private val finalizeOnce = AtomicBoolean(false) private val closeSilently = AtomicBoolean(false) + @Volatile private var useItnForSession: Boolean = false private var audioJob: Job? = null private val mgr = ParaformerOnnxManager.getInstance() + @Volatile private var currentStream: Any? = null private val streamMutex = Mutex() @@ -81,7 +83,7 @@ class ParaformerStreamAsrEngine( if (!externalPcmMode) { val hasPermission = ContextCompat.checkSelfPermission( context, - Manifest.permission.RECORD_AUDIO + Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED if (!hasPermission) { listener.onError(context.getString(R.string.error_record_permission_denied)) @@ -168,12 +170,23 @@ class ParaformerStreamAsrEngine( // 若在准备期间已调用 stop(),此处直接做最终解码 if (closing.get() && finalizeOnce.compareAndSet(false, true)) { if (closeSilently.get()) { - try { releaseStreamSilently(stream) } catch (t: Throwable) { Log.e(TAG, "releaseStreamSilently failed", t) } + try { + releaseStreamSilently(stream) + } catch (t: Throwable) { + Log.e(TAG, "releaseStreamSilently failed", t) + } } else { - val finalText = try { finalizeAndRelease(stream) } catch (t: Throwable) { - Log.e(TAG, "finalizeAndRelease failed", t); "" + val finalText = try { + finalizeAndRelease(stream) + } catch (t: Throwable) { + Log.e(TAG, "finalizeAndRelease failed", t) + "" + } + try { + listener.onFinal(finalText) + } catch (t: Throwable) { + Log.e(TAG, "notify final failed", t) } - try { listener.onFinal(finalText) } catch (t: Throwable) { Log.e(TAG, "notify final failed", t) } } closing.set(false) running.set(false) @@ -186,7 +199,9 @@ class ParaformerStreamAsrEngine( override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { if (!running.get() && currentStream == null && !closing.get()) return if (sampleRate != 16000 || channels != 1) return - try { listener.onAmplitude(com.brycewg.asrkb.asr.calculateNormalizedAmplitude(pcm)) } catch (_: Throwable) { } + try { + listener.onAmplitude(com.brycewg.asrkb.asr.calculateNormalizedAmplitude(pcm)) + } catch (_: Throwable) { } val s = currentStream if (s == null) { scope.launch { appendPrebuffer(pcm) } @@ -211,10 +226,17 @@ class ParaformerStreamAsrEngine( val s = currentStream if (s != null && finalizeOnce.compareAndSet(false, true)) { scope.launch(Dispatchers.Default) { - val finalText = try { finalizeAndRelease(s) } catch (t: Throwable) { - Log.e(TAG, "finalizeAndRelease failed", t); "" + val finalText = try { + finalizeAndRelease(s) + } catch (t: Throwable) { + Log.e(TAG, "finalizeAndRelease failed", t) + "" + } + try { + listener.onFinal(finalText) + } catch (t: Throwable) { + Log.e(TAG, "notify final failed", t) } - try { listener.onFinal(finalText) } catch (t: Throwable) { Log.e(TAG, "notify final failed", t) } closing.set(false) } } @@ -234,7 +256,7 @@ class ParaformerStreamAsrEngine( sampleRate = sampleRate, channelConfig = channelConfig, audioFormat = audioFormat, - chunkMillis = chunkMillis + chunkMillis = chunkMillis, ) if (!audioManager.hasPermission()) { @@ -244,9 +266,11 @@ class ParaformerStreamAsrEngine( return@launch } - val vadDetector = if (isVadAutoStopEnabled(context, prefs)) + val vadDetector = if (isVadAutoStopEnabled(context, prefs)) { VadDetector(context, sampleRate, prefs.autoStopSilenceWindowMs, prefs.autoStopSilenceSensitivity) - else null + } else { + null + } try { Log.d(TAG, "Starting audio capture for Paraformer with chunk=${chunkMillis}ms") @@ -264,7 +288,11 @@ class ParaformerStreamAsrEngine( // VAD 自动判停 if (vadDetector?.shouldStop(audioChunk, audioChunk.size) == true) { Log.d(TAG, "Silence detected, stopping recording") - try { listener.onStopped() } catch (t: Throwable) { Log.e(TAG, "Failed to notify stopped", t) } + try { + listener.onStopped() + } catch (t: Throwable) { + Log.e(TAG, "Failed to notify stopped", t) + } stop() return@collect } @@ -290,7 +318,11 @@ class ParaformerStreamAsrEngine( } else { context.getString(R.string.error_audio_error, t.message ?: "") } - try { listener.onError(msg) } catch (err: Throwable) { Log.e(TAG, "notify error failed", err) } + try { + listener.onError(msg) + } catch (err: Throwable) { + Log.e(TAG, "notify error failed", err) + } // 录音被系统中断:静默释放(不再回调 onFinal),避免后续 stop() 触发 JNI 竞态 closeSilently.set(true) @@ -299,7 +331,9 @@ class ParaformerStreamAsrEngine( val s = currentStream if (s != null && finalizeOnce.compareAndSet(false, true)) { scope.launch(Dispatchers.Default) { - try { releaseStreamSilently(s) } catch (releaseErr: Throwable) { + try { + releaseStreamSilently(s) + } catch (releaseErr: Throwable) { Log.e(TAG, "releaseStreamSilently failed", releaseErr) } finally { closeSilently.set(false) @@ -364,7 +398,11 @@ class ParaformerStreamAsrEngine( val normalized = if (useItnForSession) ChineseItn.normalize(trimmed) else trimmed val needEmit = (now - lastEmitUptimeMs) >= FRAME_MS && normalized != lastEmittedText if (needEmit) { - try { listener.onPartial(normalized) } catch (t: Throwable) { Log.e(TAG, "notify partial failed", t) } + try { + listener.onPartial(normalized) + } catch (t: Throwable) { + Log.e(TAG, "notify partial failed", t) + } lastEmitUptimeMs = now lastEmittedText = normalized } @@ -389,7 +427,11 @@ class ParaformerStreamAsrEngine( loops++ } text = mgr.getResultText(stream) - try { mgr.releaseStream(stream) } catch (t: Throwable) { Log.e(TAG, "releaseStream failed", t) } + try { + mgr.releaseStream(stream) + } catch (t: Throwable) { + Log.e(TAG, "releaseStream failed", t) + } currentStream = null } var out = text?.trim().orEmpty() @@ -409,7 +451,11 @@ class ParaformerStreamAsrEngine( private suspend fun releaseStreamSilently(stream: Any) { streamMutex.withLock { if (currentStream !== stream) return - try { mgr.releaseStream(stream) } catch (t: Throwable) { Log.e(TAG, "releaseStream failed", t) } + try { + mgr.releaseStream(stream) + } catch (t: Throwable) { + Log.e(TAG, "releaseStream failed", t) + } currentStream = null } } @@ -506,13 +552,17 @@ private class ReflectiveOnlineStream(val instance: Any) { } fun inputFinished() { - try { cls.getMethod("inputFinished").invoke(instance) } catch (t: Throwable) { + try { + cls.getMethod("inputFinished").invoke(instance) + } catch (t: Throwable) { Log.e("ROnlineStream", "inputFinished failed", t) } } fun release() { - try { cls.getMethod("release").invoke(instance) } catch (t: Throwable) { + try { + cls.getMethod("release").invoke(instance) + } catch (t: Throwable) { Log.e("ROnlineStream", "release failed", t) } } @@ -546,7 +596,9 @@ private class ReflectiveOnlineRecognizer(private val instance: Any, private val } fun release() { - try { cls.getMethod("release").invoke(instance) } catch (t: Throwable) { + try { + cls.getMethod("release").invoke(instance) + } catch (t: Throwable) { Log.e("ROnlineRecognizer", "release failed", t) } } @@ -567,24 +619,35 @@ class ParaformerOnnxManager private constructor() { private val runtimeLock = Any() @Volatile private var cachedConfig: RecognizerConfig? = null + @Volatile private var cachedRecognizer: ReflectiveOnlineRecognizer? = null + @Volatile private var preparing: Boolean = false + @Volatile private var clsOnlineRecognizer: Class<*>? = null + @Volatile private var clsOnlineRecognizerConfig: Class<*>? = null + @Volatile private var clsOnlineModelConfig: Class<*>? = null + @Volatile private var clsOnlineParaformerModelConfig: Class<*>? = null + @Volatile private var clsFeatureConfig: Class<*>? = null + @Volatile private var unloadJob: Job? = null // 最近一次配置与流计数:用于保留/卸载 @Volatile private var lastKeepAliveMs: Long = 0L + @Volatile private var lastAlwaysKeep: Boolean = false private val activeStreams = AtomicInteger(0) + @Volatile private var pendingUnload: Boolean = false fun isOnnxAvailable(): Boolean { return try { - Class.forName("com.k2fsa.sherpa.onnx.OnlineRecognizer"); true + Class.forName("com.k2fsa.sherpa.onnx.OnlineRecognizer") + true } catch (t: Throwable) { Log.d(TAG, "sherpa-onnx online not available", t) false @@ -632,7 +695,10 @@ class ParaformerOnnxManager private constructor() { private fun scheduleAutoUnload(keepAliveMs: Long, alwaysKeep: Boolean) { unloadJob?.cancel() if (alwaysKeep) return - if (keepAliveMs <= 0L) { unload(); return } + if (keepAliveMs <= 0L) { + unload() + return + } unloadJob = scope.launch { delay(keepAliveMs) unload() @@ -660,7 +726,7 @@ class ParaformerOnnxManager private constructor() { try { val m = target.javaClass.getMethod( "set" + name.replaceFirstChar { if (it.isLowerCase()) it.titlecase() else it.toString() }, - value?.javaClass ?: Any::class.java + value?.javaClass ?: Any::class.java, ) m.invoke(target, value) true @@ -679,7 +745,7 @@ class ParaformerOnnxManager private constructor() { val provider: String = "cpu", val sampleRate: Int = 16000, val featureDim: Int = 80, - val debug: Boolean = false + val debug: Boolean = false, ) { fun toCacheKey(): String = listOf(tokens, encoder, decoder, numThreads, provider, sampleRate, featureDim, debug).joinToString("|") } @@ -720,7 +786,7 @@ class ParaformerOnnxManager private constructor() { private fun createRecognizer(recConfig: Any): Any { val ctor = clsOnlineRecognizer!!.getDeclaredConstructor( android.content.res.AssetManager::class.java, - clsOnlineRecognizerConfig!! + clsOnlineRecognizerConfig!!, ) return ctor.newInstance(null, recConfig) } @@ -813,7 +879,8 @@ class ParaformerOnnxManager private constructor() { activeStreams.incrementAndGet() s } catch (t: Throwable) { - Log.e(TAG, "createStream failed", t); null + Log.e(TAG, "createStream failed", t) + null } } @@ -844,7 +911,8 @@ class ParaformerOnnxManager private constructor() { if (r != null && stream is ReflectiveOnlineStream) r.isReady(stream) else false } } catch (t: Throwable) { - Log.e(TAG, "isReady failed", t); false + Log.e(TAG, "isReady failed", t) + false } } @@ -866,7 +934,8 @@ class ParaformerOnnxManager private constructor() { if (r != null && stream is ReflectiveOnlineStream) r.getResultText(stream) else null } } catch (t: Throwable) { - Log.e(TAG, "getResultText failed", t); null + Log.e(TAG, "getResultText failed", t) + null } } @@ -901,7 +970,7 @@ fun preloadParaformerIfConfigured( onLoadStart: (() -> Unit)? = null, onLoadDone: (() -> Unit)? = null, suppressToastOnStart: Boolean = false, - forImmediateUse: Boolean = false + forImmediateUse: Boolean = false, ) { try { val manager = ParaformerOnnxManager.getInstance() @@ -941,7 +1010,7 @@ fun preloadParaformerIfConfigured( android.widget.Toast.makeText( context, context.getString(com.brycewg.asrkb.R.string.pf_loading_model), - android.widget.Toast.LENGTH_SHORT + android.widget.Toast.LENGTH_SHORT, ).show() } } @@ -955,7 +1024,7 @@ fun preloadParaformerIfConfigured( android.widget.Toast.makeText( context, context.getString(com.brycewg.asrkb.R.string.sv_model_ready_with_ms, dt), - android.widget.Toast.LENGTH_SHORT + android.widget.Toast.LENGTH_SHORT, ).show() } manager.scheduleUnloadIfIdle() diff --git a/app/src/main/java/com/brycewg/asrkb/asr/ParallelAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/ParallelAsrEngine.kt index 5aea4807..025a4fbc 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/ParallelAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/ParallelAsrEngine.kt @@ -13,8 +13,8 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.delay import kotlinx.coroutines.isActive import kotlinx.coroutines.launch -import java.util.concurrent.atomic.AtomicLong import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicLong /** * 并行主备 ASR 引擎: @@ -25,606 +25,631 @@ import java.util.concurrent.atomic.AtomicBoolean * - 主用超时或失败(onError/空 onFinal):尝试采用备用结果 */ class ParallelAsrEngine( - private val context: Context, - private val scope: CoroutineScope, - private val prefs: Prefs, - private val listener: StreamingAsrEngine.Listener, - val primaryVendor: AsrVendor, - val backupVendor: AsrVendor, - private val onPrimaryRequestDuration: ((Long) -> Unit)? = null, - private val externalPcmInput: Boolean = false + private val context: Context, + private val scope: CoroutineScope, + private val prefs: Prefs, + private val listener: StreamingAsrEngine.Listener, + val primaryVendor: AsrVendor, + val backupVendor: AsrVendor, + private val onPrimaryRequestDuration: ((Long) -> Unit)? = null, + private val externalPcmInput: Boolean = false, ) : StreamingAsrEngine, ExternalPcmConsumer { - companion object { - private const val TAG = "ParallelAsrEngine" - private const val SAMPLE_RATE = 16000 - private const val CHANNELS = 1 - private const val CHUNK_MS = 200 - private const val PRIMARY_SWITCH_RATIO_BALANCED = 0.75 - private const val PRIMARY_SWITCH_RATIO_SENSITIVE = 0.5 - private const val PRIMARY_SWITCH_MIN_MS = 6_000L - private const val PRIMARY_SWITCH_MAX_MS = 15_000L - private const val PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_BALANCED_MS = 25_000L - private const val PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_SENSITIVE_MS = 18_000L - } - - private enum class Source { PRIMARY, BACKUP } - - private sealed class Terminal { - data class Final(val text: String) : Terminal() - data class Error(val message: String) : Terminal() - } - - override val isRunning: Boolean - get() = running.get() - - private val running = AtomicBoolean(false) - private val stopRequested = AtomicBoolean(false) - private val terminalDelivered = AtomicBoolean(false) - - private val stateLock = Any() - - private var audioJob: Job? = null - private var primaryTimeoutJob: Job? = null - - @Volatile private var startUptimeMs: Long = 0L - private val audioBytes = AtomicLong(0L) - @Volatile private var stoppedNotified: Boolean = false - @Volatile private var primaryTimedOut: Boolean = false - @Volatile private var primaryTerminal: Terminal? = null - @Volatile private var backupTerminal: Terminal? = null - @Volatile private var lastFinalFromBackup: Boolean = false - - fun wasLastResultFromBackup(): Boolean = lastFinalFromBackup - - private val primaryListener = EngineListener(Source.PRIMARY, forwardLocalModelUi = true) - private val backupListener = EngineListener(Source.BACKUP, forwardLocalModelUi = false) - - private var primaryEngine: StreamingAsrEngine? = null - private var backupEngine: StreamingAsrEngine? = null - private var primaryConsumer: ExternalPcmConsumer? = null - private var backupConsumer: ExternalPcmConsumer? = null - - override fun start() { - if (!running.compareAndSet(false, true)) return - - stopRequested.set(false) - terminalDelivered.set(false) - stoppedNotified = false - primaryTimedOut = false - primaryTerminal = null - backupTerminal = null - lastFinalFromBackup = false - audioBytes.set(0L) - primaryTimeoutJob?.cancel() - primaryTimeoutJob = null - - startUptimeMs = try { SystemClock.uptimeMillis() } catch (_: Throwable) { 0L } - - primaryEngine = buildPushPcmEngine(primaryVendor, primaryListener, onPrimaryRequestDuration) - backupEngine = buildPushPcmEngine(backupVendor, backupListener, onRequestDuration = null) - primaryConsumer = primaryEngine as? ExternalPcmConsumer - backupConsumer = backupEngine as? ExternalPcmConsumer - - if (primaryEngine == null && backupEngine == null) { - running.set(false) - try { - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, "No engine available")) - } catch (t: Throwable) { - Log.w(TAG, "notify no engine available failed", t) - } - return + companion object { + private const val TAG = "ParallelAsrEngine" + private const val SAMPLE_RATE = 16000 + private const val CHANNELS = 1 + private const val CHUNK_MS = 200 + private const val PRIMARY_SWITCH_RATIO_BALANCED = 0.75 + private const val PRIMARY_SWITCH_RATIO_SENSITIVE = 0.5 + private const val PRIMARY_SWITCH_MIN_MS = 6_000L + private const val PRIMARY_SWITCH_MAX_MS = 15_000L + private const val PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_BALANCED_MS = 25_000L + private const val PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_SENSITIVE_MS = 18_000L } - try { - primaryEngine?.start() - } catch (t: Throwable) { - Log.e(TAG, "primary start failed", t) - onTerminal(Source.PRIMARY, Terminal.Error(t.message ?: "primary start failed")) - } - try { - backupEngine?.start() - } catch (t: Throwable) { - Log.e(TAG, "backup start failed", t) - onTerminal(Source.BACKUP, Terminal.Error(t.message ?: "backup start failed")) - } + private enum class Source { PRIMARY, BACKUP } - if (!externalPcmInput) { - startAudioCapture() + private sealed class Terminal { + data class Final(val text: String) : Terminal() + data class Error(val message: String) : Terminal() } - } - override fun stop() { - if (stopRequested.getAndSet(true)) return + override val isRunning: Boolean + get() = running.get() - running.set(false) - if (!terminalDelivered.get()) { - notifyStoppedIfNeeded() - } + private val running = AtomicBoolean(false) + private val stopRequested = AtomicBoolean(false) + private val terminalDelivered = AtomicBoolean(false) - try { - audioJob?.cancel() - } catch (t: Throwable) { - Log.w(TAG, "cancel audio job failed", t) - } finally { - audioJob = null - } + private val stateLock = Any() - try { - primaryEngine?.stop() - } catch (t: Throwable) { - Log.w(TAG, "primary stop failed", t) - } - try { - backupEngine?.stop() - } catch (t: Throwable) { - Log.w(TAG, "backup stop failed", t) - } + private var audioJob: Job? = null + private var primaryTimeoutJob: Job? = null - schedulePrimaryTimeoutIfNeeded() - } + @Volatile private var startUptimeMs: Long = 0L + private val audioBytes = AtomicLong(0L) - override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { - if (!externalPcmInput) return - if (!running.get()) return - if (terminalDelivered.get()) return - if (sampleRate != SAMPLE_RATE || channels != CHANNELS) return + @Volatile private var stoppedNotified: Boolean = false + + @Volatile private var primaryTimedOut: Boolean = false + + @Volatile private var primaryTerminal: Terminal? = null + + @Volatile private var backupTerminal: Terminal? = null + + @Volatile private var lastFinalFromBackup: Boolean = false + + fun wasLastResultFromBackup(): Boolean = lastFinalFromBackup + + private val primaryListener = EngineListener(Source.PRIMARY, forwardLocalModelUi = true) + private val backupListener = EngineListener(Source.BACKUP, forwardLocalModelUi = false) + + private var primaryEngine: StreamingAsrEngine? = null + private var backupEngine: StreamingAsrEngine? = null + private var primaryConsumer: ExternalPcmConsumer? = null + private var backupConsumer: ExternalPcmConsumer? = null + + override fun start() { + if (!running.compareAndSet(false, true)) return + + stopRequested.set(false) + terminalDelivered.set(false) + stoppedNotified = false + primaryTimedOut = false + primaryTerminal = null + backupTerminal = null + lastFinalFromBackup = false + audioBytes.set(0L) + primaryTimeoutJob?.cancel() + primaryTimeoutJob = null + + startUptimeMs = try { + SystemClock.uptimeMillis() + } catch (_: Throwable) { + 0L + } + + primaryEngine = buildPushPcmEngine(primaryVendor, primaryListener, onPrimaryRequestDuration) + backupEngine = buildPushPcmEngine(backupVendor, backupListener, onRequestDuration = null) + primaryConsumer = primaryEngine as? ExternalPcmConsumer + backupConsumer = backupEngine as? ExternalPcmConsumer + + if (primaryEngine == null && backupEngine == null) { + running.set(false) + try { + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, "No engine available")) + } catch (t: Throwable) { + Log.w(TAG, "notify no engine available failed", t) + } + return + } - audioBytes.addAndGet(pcm.size.toLong()) - try { - listener.onAmplitude(calculateNormalizedAmplitude(pcm)) - } catch (t: Throwable) { - Log.w(TAG, "notify amplitude failed (externalPcmInput)", t) - } - try { - primaryConsumer?.appendPcm(pcm, sampleRate, channels) - } catch (t: Throwable) { - Log.w(TAG, "primary appendPcm failed (externalPcmInput)", t) - } - try { - backupConsumer?.appendPcm(pcm, sampleRate, channels) - } catch (t: Throwable) { - Log.w(TAG, "backup appendPcm failed (externalPcmInput)", t) - } - } - - private fun cleanupAfterTerminal() { - stopRequested.set(true) - running.set(false) - try { - primaryTimeoutJob?.cancel() - } catch (t: Throwable) { - Log.w(TAG, "cancel primaryTimeoutJob failed in cleanupAfterTerminal", t) - } finally { - primaryTimeoutJob = null - } - try { - audioJob?.cancel() - } catch (t: Throwable) { - Log.w(TAG, "cancel audio job failed in cleanupAfterTerminal", t) - } finally { - audioJob = null - } - try { - primaryEngine?.stop() - } catch (t: Throwable) { - Log.w(TAG, "primary stop failed in cleanupAfterTerminal", t) - } - try { - backupEngine?.stop() - } catch (t: Throwable) { - Log.w(TAG, "backup stop failed in cleanupAfterTerminal", t) - } - } - - private fun startAudioCapture() { - audioJob?.cancel() - audioJob = scope.launch(Dispatchers.IO) { - val audioManager = AudioCaptureManager( - context = context, - sampleRate = SAMPLE_RATE, - channelConfig = AudioFormat.CHANNEL_IN_MONO, - audioFormat = AudioFormat.ENCODING_PCM_16BIT, - chunkMillis = CHUNK_MS - ) - - if (!audioManager.hasPermission()) { - Log.e(TAG, "Missing RECORD_AUDIO permission") - fatalCaptureError(context.getString(R.string.error_record_permission_denied)) - return@launch - } - - val vadDetector = if (isVadAutoStopEnabled(context, prefs)) { try { - VadDetector(context, SAMPLE_RATE, prefs.autoStopSilenceWindowMs, prefs.autoStopSilenceSensitivity) + primaryEngine?.start() } catch (t: Throwable) { - Log.e(TAG, "Failed to create VAD detector", t) - null - } - } else { - null - } - - try { - audioManager.startCapture().collect { chunk -> - if (!isActive || !running.get()) return@collect - if (terminalDelivered.get()) return@collect - - try { - listener.onAmplitude(calculateNormalizedAmplitude(chunk)) - } catch (t: Throwable) { - Log.w(TAG, "notify amplitude failed", t) - } - audioBytes.addAndGet(chunk.size.toLong()) - - try { - primaryConsumer?.appendPcm(chunk, SAMPLE_RATE, CHANNELS) - } catch (t: Throwable) { - Log.w(TAG, "primary appendPcm failed", t) - } - try { - backupConsumer?.appendPcm(chunk, SAMPLE_RATE, CHANNELS) - } catch (t: Throwable) { - Log.w(TAG, "backup appendPcm failed", t) - } - - if (vadDetector?.shouldStop(chunk, chunk.size) == true) { - Log.d(TAG, "VAD silence detected, stopping session") + Log.e(TAG, "primary start failed", t) + onTerminal(Source.PRIMARY, Terminal.Error(t.message ?: "primary start failed")) + } + try { + backupEngine?.start() + } catch (t: Throwable) { + Log.e(TAG, "backup start failed", t) + onTerminal(Source.BACKUP, Terminal.Error(t.message ?: "backup start failed")) + } + + if (!externalPcmInput) { + startAudioCapture() + } + } + + override fun stop() { + if (stopRequested.getAndSet(true)) return + + running.set(false) + if (!terminalDelivered.get()) { notifyStoppedIfNeeded() - stop() - return@collect - } } - } catch (t: Throwable) { - if (t is CancellationException) { - Log.d(TAG, "Audio capture cancelled: ${t.message}") - } else { - Log.e(TAG, "Audio capture failed", t) - fatalCaptureError(context.getString(R.string.error_audio_error, t.message ?: "")) + + try { + audioJob?.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cancel audio job failed", t) + } finally { + audioJob = null } - } finally { + try { - vadDetector?.release() + primaryEngine?.stop() } catch (t: Throwable) { - Log.w(TAG, "VAD release failed", t) + Log.w(TAG, "primary stop failed", t) } - } - } - } - - private fun fatalCaptureError(message: String) { - onTerminal(Source.PRIMARY, Terminal.Error(message)) - onTerminal(Source.BACKUP, Terminal.Error(message)) - } - - private fun notifyStoppedIfNeeded() { - if (stoppedNotified) return - stoppedNotified = true - try { - listener.onStopped() - } catch (t: Throwable) { - Log.w(TAG, "notify onStopped failed", t) - } - } - - private fun schedulePrimaryTimeoutIfNeeded() { - if (primaryEngine == null || backupEngine == null) return - if (terminalDelivered.get()) return - - val bytesAudioMs = audioMsFromBytes(audioBytes.get()) - val audioMs = if (bytesAudioMs > 0L) { - bytesAudioMs - } else { - val t0 = startUptimeMs - val t1 = try { SystemClock.uptimeMillis() } catch (_: Throwable) { 0L } - if (t0 > 0L && t1 >= t0) (t1 - t0) else 0L + try { + backupEngine?.stop() + } catch (t: Throwable) { + Log.w(TAG, "backup stop failed", t) + } + + schedulePrimaryTimeoutIfNeeded() } - val baseTimeoutMs = AsrTimeoutCalculator.calculateTimeoutMs(audioMs) - val sensitivityTier = try { prefs.backupAsrTimeoutSensitivity } catch (_: Throwable) { 1 } - val switchTimeoutMs = calculatePrimarySwitchTimeoutMs(baseTimeoutMs, isPrimaryStreamingForSwitch(), sensitivityTier) + override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { + if (!externalPcmInput) return + if (!running.get()) return + if (terminalDelivered.get()) return + if (sampleRate != SAMPLE_RATE || channels != CHANNELS) return - try { - primaryTimeoutJob?.cancel() - } catch (t: Throwable) { - Log.w(TAG, "cancel primaryTimeoutJob failed", t) - } - primaryTimeoutJob = scope.launch { - delay(switchTimeoutMs) - synchronized(stateLock) { - if (terminalDelivered.get()) return@synchronized - if (primaryTerminal == null) { - primaryTimedOut = true - Log.w(TAG, "Primary timeout fired (audioMs=$audioMs, switchTimeoutMs=$switchTimeoutMs)") - } - tryResolveLocked() - } - } - Log.d( - TAG, - "Primary timeout scheduled: audioMs=$audioMs, baseTimeoutMs=$baseTimeoutMs, switchTimeoutMs=$switchTimeoutMs" - ) - } - - private fun calculatePrimarySwitchTimeoutMs(baseTimeoutMs: Long, primaryStreaming: Boolean, sensitivityTier: Int): Long { - val ratio = when (sensitivityTier.coerceIn(0, 2)) { - 0 -> 1.0 - 2 -> PRIMARY_SWITCH_RATIO_SENSITIVE - else -> PRIMARY_SWITCH_RATIO_BALANCED + audioBytes.addAndGet(pcm.size.toLong()) + try { + listener.onAmplitude(calculateNormalizedAmplitude(pcm)) + } catch (t: Throwable) { + Log.w(TAG, "notify amplitude failed (externalPcmInput)", t) + } + try { + primaryConsumer?.appendPcm(pcm, sampleRate, channels) + } catch (t: Throwable) { + Log.w(TAG, "primary appendPcm failed (externalPcmInput)", t) + } + try { + backupConsumer?.appendPcm(pcm, sampleRate, channels) + } catch (t: Throwable) { + Log.w(TAG, "backup appendPcm failed (externalPcmInput)", t) + } } - var timeoutMs = (baseTimeoutMs.toDouble() * ratio).toLong().coerceAtLeast(0L) - - if (primaryStreaming) { - timeoutMs = timeoutMs.coerceIn(PRIMARY_SWITCH_MIN_MS, PRIMARY_SWITCH_MAX_MS) - } else { - val minMs = when (sensitivityTier.coerceIn(0, 2)) { - 2 -> 5_000L - 1 -> 6_000L - else -> 0L - } - val softMaxMs = when (sensitivityTier.coerceIn(0, 2)) { - 2 -> PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_SENSITIVE_MS - 1 -> PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_BALANCED_MS - else -> Long.MAX_VALUE - } - timeoutMs = timeoutMs.coerceAtLeast(minMs).coerceAtMost(softMaxMs) + private fun cleanupAfterTerminal() { + stopRequested.set(true) + running.set(false) + try { + primaryTimeoutJob?.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cancel primaryTimeoutJob failed in cleanupAfterTerminal", t) + } finally { + primaryTimeoutJob = null + } + try { + audioJob?.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cancel audio job failed in cleanupAfterTerminal", t) + } finally { + audioJob = null + } + try { + primaryEngine?.stop() + } catch (t: Throwable) { + Log.w(TAG, "primary stop failed in cleanupAfterTerminal", t) + } + try { + backupEngine?.stop() + } catch (t: Throwable) { + Log.w(TAG, "backup stop failed in cleanupAfterTerminal", t) + } } - return timeoutMs - } - - private fun audioMsFromBytes(bytes: Long): Long { - if (bytes <= 0L) return 0L - val denom = SAMPLE_RATE.toLong() * CHANNELS.toLong() * 2L - if (denom <= 0L) return 0L - return (bytes * 1000L / denom).coerceAtLeast(0L) - } - - private fun isPrimaryStreamingForSwitch(): Boolean { - return when (primaryVendor) { - AsrVendor.Volc -> prefs.volcStreamingEnabled - AsrVendor.DashScope -> prefs.isDashStreamingModelSelected() - AsrVendor.Soniox -> prefs.sonioxStreamingEnabled - AsrVendor.ElevenLabs -> prefs.elevenStreamingEnabled - AsrVendor.Paraformer -> true - else -> false - } - } - - private fun onTerminal(source: Source, t: Terminal) { - val shouldStopCapture = synchronized(stateLock) { - if (terminalDelivered.get()) return - when (source) { - Source.PRIMARY -> primaryTerminal = t - Source.BACKUP -> backupTerminal = t - } - tryResolveLocked() - terminalDelivered.get() - } - if (shouldStopCapture) { - cleanupAfterTerminal() - } - } - - private fun tryResolveLocked() { - if (terminalDelivered.get()) return - - val p = primaryTerminal - val b = backupTerminal - val hasPrimary = primaryEngine != null - val hasBackup = backupEngine != null - - // 无主用:直接采用备用 - if (!hasPrimary) { - when (b) { - is Terminal.Final -> deliverFinalLocked(b.text, Source.BACKUP) - is Terminal.Error -> deliverErrorLocked(b.message) - null -> Unit - } - return + private fun startAudioCapture() { + audioJob?.cancel() + audioJob = scope.launch(Dispatchers.IO) { + val audioManager = AudioCaptureManager( + context = context, + sampleRate = SAMPLE_RATE, + channelConfig = AudioFormat.CHANNEL_IN_MONO, + audioFormat = AudioFormat.ENCODING_PCM_16BIT, + chunkMillis = CHUNK_MS, + ) + + if (!audioManager.hasPermission()) { + Log.e(TAG, "Missing RECORD_AUDIO permission") + fatalCaptureError(context.getString(R.string.error_record_permission_denied)) + return@launch + } + + val vadDetector = if (isVadAutoStopEnabled(context, prefs)) { + try { + VadDetector(context, SAMPLE_RATE, prefs.autoStopSilenceWindowMs, prefs.autoStopSilenceSensitivity) + } catch (t: Throwable) { + Log.e(TAG, "Failed to create VAD detector", t) + null + } + } else { + null + } + + try { + audioManager.startCapture().collect { chunk -> + if (!isActive || !running.get()) return@collect + if (terminalDelivered.get()) return@collect + + try { + listener.onAmplitude(calculateNormalizedAmplitude(chunk)) + } catch (t: Throwable) { + Log.w(TAG, "notify amplitude failed", t) + } + audioBytes.addAndGet(chunk.size.toLong()) + + try { + primaryConsumer?.appendPcm(chunk, SAMPLE_RATE, CHANNELS) + } catch (t: Throwable) { + Log.w(TAG, "primary appendPcm failed", t) + } + try { + backupConsumer?.appendPcm(chunk, SAMPLE_RATE, CHANNELS) + } catch (t: Throwable) { + Log.w(TAG, "backup appendPcm failed", t) + } + + if (vadDetector?.shouldStop(chunk, chunk.size) == true) { + Log.d(TAG, "VAD silence detected, stopping session") + notifyStoppedIfNeeded() + stop() + return@collect + } + } + } catch (t: Throwable) { + if (t is CancellationException) { + Log.d(TAG, "Audio capture cancelled: ${t.message}") + } else { + Log.e(TAG, "Audio capture failed", t) + fatalCaptureError(context.getString(R.string.error_audio_error, t.message ?: "")) + } + } finally { + try { + vadDetector?.release() + } catch (t: Throwable) { + Log.w(TAG, "VAD release failed", t) + } + } + } } - // 主用有结果(非空)直接采用 - val pFinal = p as? Terminal.Final - if (pFinal != null && pFinal.text.isNotBlank()) { - deliverFinalLocked(pFinal.text, Source.PRIMARY) - return + private fun fatalCaptureError(message: String) { + onTerminal(Source.PRIMARY, Terminal.Error(message)) + onTerminal(Source.BACKUP, Terminal.Error(message)) } - // 没有备用:主用终止即交付 - if (!hasBackup) { - when (p) { - is Terminal.Final -> deliverFinalLocked(p.text, Source.PRIMARY) - is Terminal.Error -> deliverErrorLocked(p.message) - null -> Unit - } - return + private fun notifyStoppedIfNeeded() { + if (stoppedNotified) return + stoppedNotified = true + try { + listener.onStopped() + } catch (t: Throwable) { + Log.w(TAG, "notify onStopped failed", t) + } } - val pFailed = when (p) { - is Terminal.Error -> true - is Terminal.Final -> p.text.isBlank() - null -> false - } + private fun schedulePrimaryTimeoutIfNeeded() { + if (primaryEngine == null || backupEngine == null) return + if (terminalDelivered.get()) return - // 主用失败:尽快尝试采用备用(无需等待主用超时阈值) - if (pFailed) { - val bFinal = b as? Terminal.Final - when { - bFinal != null && bFinal.text.isNotBlank() -> deliverFinalLocked(bFinal.text, Source.BACKUP) - b is Terminal.Error -> deliverErrorLocked(b.message) - bFinal != null -> deliverFinalLocked(bFinal.text, Source.BACKUP) - else -> Unit - } - return - } + val bytesAudioMs = audioMsFromBytes(audioBytes.get()) + val audioMs = if (bytesAudioMs > 0L) { + bytesAudioMs + } else { + val t0 = startUptimeMs + val t1 = try { + SystemClock.uptimeMillis() + } catch (_: Throwable) { + 0L + } + if (t0 > 0L && t1 >= t0) (t1 - t0) else 0L + } - // 主用未终止但已超时:切换到备用 - if (primaryTimedOut) { - when (b) { - is Terminal.Final -> deliverFinalLocked(b.text, Source.BACKUP) - is Terminal.Error -> deliverErrorLocked(b.message) - null -> Unit - } - return + val baseTimeoutMs = AsrTimeoutCalculator.calculateTimeoutMs(audioMs) + val sensitivityTier = try { + prefs.backupAsrTimeoutSensitivity + } catch (_: Throwable) { + 1 + } + val switchTimeoutMs = calculatePrimarySwitchTimeoutMs(baseTimeoutMs, isPrimaryStreamingForSwitch(), sensitivityTier) + + try { + primaryTimeoutJob?.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cancel primaryTimeoutJob failed", t) + } + primaryTimeoutJob = scope.launch { + delay(switchTimeoutMs) + synchronized(stateLock) { + if (terminalDelivered.get()) return@synchronized + if (primaryTerminal == null) { + primaryTimedOut = true + Log.w(TAG, "Primary timeout fired (audioMs=$audioMs, switchTimeoutMs=$switchTimeoutMs)") + } + tryResolveLocked() + } + } + Log.d( + TAG, + "Primary timeout scheduled: audioMs=$audioMs, baseTimeoutMs=$baseTimeoutMs, switchTimeoutMs=$switchTimeoutMs", + ) } - // 否则:继续等待主用(备用结果缓存,不立即提交) - } - - private fun deliverFinalLocked(text: String, from: Source) { - if (!terminalDelivered.compareAndSet(false, true)) return - lastFinalFromBackup = (from == Source.BACKUP) - try { - primaryTimeoutJob?.cancel() - } catch (t: Throwable) { - Log.w(TAG, "cancel primaryTimeoutJob failed on deliverFinal", t) - } finally { - primaryTimeoutJob = null + private fun calculatePrimarySwitchTimeoutMs(baseTimeoutMs: Long, primaryStreaming: Boolean, sensitivityTier: Int): Long { + val ratio = when (sensitivityTier.coerceIn(0, 2)) { + 0 -> 1.0 + 2 -> PRIMARY_SWITCH_RATIO_SENSITIVE + else -> PRIMARY_SWITCH_RATIO_BALANCED + } + + var timeoutMs = (baseTimeoutMs.toDouble() * ratio).toLong().coerceAtLeast(0L) + + if (primaryStreaming) { + timeoutMs = timeoutMs.coerceIn(PRIMARY_SWITCH_MIN_MS, PRIMARY_SWITCH_MAX_MS) + } else { + val minMs = when (sensitivityTier.coerceIn(0, 2)) { + 2 -> 5_000L + 1 -> 6_000L + else -> 0L + } + val softMaxMs = when (sensitivityTier.coerceIn(0, 2)) { + 2 -> PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_SENSITIVE_MS + 1 -> PRIMARY_SWITCH_NONSTREAM_SOFT_MAX_BALANCED_MS + else -> Long.MAX_VALUE + } + timeoutMs = timeoutMs.coerceAtLeast(minMs).coerceAtMost(softMaxMs) + } + + return timeoutMs } - try { - listener.onFinal(text) - } catch (t: Throwable) { - Log.e(TAG, "notify final failed", t) + + private fun audioMsFromBytes(bytes: Long): Long { + if (bytes <= 0L) return 0L + val denom = SAMPLE_RATE.toLong() * CHANNELS.toLong() * 2L + if (denom <= 0L) return 0L + return (bytes * 1000L / denom).coerceAtLeast(0L) } - } - - private fun deliverErrorLocked(message: String) { - if (!terminalDelivered.compareAndSet(false, true)) return - try { - primaryTimeoutJob?.cancel() - } catch (t: Throwable) { - Log.w(TAG, "cancel primaryTimeoutJob failed on deliverError", t) - } finally { - primaryTimeoutJob = null + + private fun isPrimaryStreamingForSwitch(): Boolean { + return when (primaryVendor) { + AsrVendor.Volc -> prefs.volcStreamingEnabled + AsrVendor.DashScope -> prefs.isDashStreamingModelSelected() + AsrVendor.Soniox -> prefs.sonioxStreamingEnabled + AsrVendor.ElevenLabs -> prefs.elevenStreamingEnabled + AsrVendor.Paraformer -> true + else -> false + } } - try { - listener.onError(message) - } catch (t: Throwable) { - Log.e(TAG, "notify error failed", t) + + private fun onTerminal(source: Source, t: Terminal) { + val shouldStopCapture = synchronized(stateLock) { + if (terminalDelivered.get()) return + when (source) { + Source.PRIMARY -> primaryTerminal = t + Source.BACKUP -> backupTerminal = t + } + tryResolveLocked() + terminalDelivered.get() + } + if (shouldStopCapture) { + cleanupAfterTerminal() + } } - } - private inner class EngineListener( - private val source: Source, - private val forwardLocalModelUi: Boolean - ) : StreamingAsrEngine.Listener, SenseVoiceFileAsrEngine.LocalModelLoadUi { + private fun tryResolveLocked() { + if (terminalDelivered.get()) return - override fun onFinal(text: String) { - onTerminal(source, Terminal.Final(text)) - } + val p = primaryTerminal + val b = backupTerminal + val hasPrimary = primaryEngine != null + val hasBackup = backupEngine != null - override fun onError(message: String) { - onTerminal(source, Terminal.Error(message)) - } + // 无主用:直接采用备用 + if (!hasPrimary) { + when (b) { + is Terminal.Final -> deliverFinalLocked(b.text, Source.BACKUP) + is Terminal.Error -> deliverErrorLocked(b.message) + null -> Unit + } + return + } + + // 主用有结果(非空)直接采用 + val pFinal = p as? Terminal.Final + if (pFinal != null && pFinal.text.isNotBlank()) { + deliverFinalLocked(pFinal.text, Source.PRIMARY) + return + } + + // 没有备用:主用终止即交付 + if (!hasBackup) { + when (p) { + is Terminal.Final -> deliverFinalLocked(p.text, Source.PRIMARY) + is Terminal.Error -> deliverErrorLocked(p.message) + null -> Unit + } + return + } + + val pFailed = when (p) { + is Terminal.Error -> true + is Terminal.Final -> p.text.isBlank() + null -> false + } + + // 主用失败:尽快尝试采用备用(无需等待主用超时阈值) + if (pFailed) { + val bFinal = b as? Terminal.Final + when { + bFinal != null && bFinal.text.isNotBlank() -> deliverFinalLocked(bFinal.text, Source.BACKUP) + b is Terminal.Error -> deliverErrorLocked(b.message) + bFinal != null -> deliverFinalLocked(bFinal.text, Source.BACKUP) + else -> Unit + } + return + } - override fun onPartial(text: String) { - if (source != Source.PRIMARY) return - try { - listener.onPartial(text) - } catch (t: Throwable) { - Log.w(TAG, "notify partial failed", t) - } + // 主用未终止但已超时:切换到备用 + if (primaryTimedOut) { + when (b) { + is Terminal.Final -> deliverFinalLocked(b.text, Source.BACKUP) + is Terminal.Error -> deliverErrorLocked(b.message) + null -> Unit + } + return + } + + // 否则:继续等待主用(备用结果缓存,不立即提交) } - override fun onLocalModelLoadStart() { - if (!forwardLocalModelUi) return - val ui = listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi ?: return - try { ui.onLocalModelLoadStart() } catch (t: Throwable) { Log.w(TAG, "forward loadStart failed", t) } + private fun deliverFinalLocked(text: String, from: Source) { + if (!terminalDelivered.compareAndSet(false, true)) return + lastFinalFromBackup = (from == Source.BACKUP) + try { + primaryTimeoutJob?.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cancel primaryTimeoutJob failed on deliverFinal", t) + } finally { + primaryTimeoutJob = null + } + try { + listener.onFinal(text) + } catch (t: Throwable) { + Log.e(TAG, "notify final failed", t) + } } - override fun onLocalModelLoadDone() { - if (!forwardLocalModelUi) return - val ui = listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi ?: return - try { ui.onLocalModelLoadDone() } catch (t: Throwable) { Log.w(TAG, "forward loadDone failed", t) } + private fun deliverErrorLocked(message: String) { + if (!terminalDelivered.compareAndSet(false, true)) return + try { + primaryTimeoutJob?.cancel() + } catch (t: Throwable) { + Log.w(TAG, "cancel primaryTimeoutJob failed on deliverError", t) + } finally { + primaryTimeoutJob = null + } + try { + listener.onError(message) + } catch (t: Throwable) { + Log.e(TAG, "notify error failed", t) + } } - } - - private fun buildPushPcmEngine( - vendor: AsrVendor, - engineListener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? - ): StreamingAsrEngine? { - val hasKeys = try { - when (vendor) { - AsrVendor.SiliconFlow -> prefs.hasSfKeys() - else -> prefs.hasVendorKeys(vendor) - } - } catch (t: Throwable) { - Log.w(TAG, "Failed to read keys for vendor=$vendor", t) - false + + private inner class EngineListener( + private val source: Source, + private val forwardLocalModelUi: Boolean, + ) : StreamingAsrEngine.Listener, SenseVoiceFileAsrEngine.LocalModelLoadUi { + + override fun onFinal(text: String) { + onTerminal(source, Terminal.Final(text)) + } + + override fun onError(message: String) { + onTerminal(source, Terminal.Error(message)) + } + + override fun onPartial(text: String) { + if (source != Source.PRIMARY) return + try { + listener.onPartial(text) + } catch (t: Throwable) { + Log.w(TAG, "notify partial failed", t) + } + } + + override fun onLocalModelLoadStart() { + if (!forwardLocalModelUi) return + val ui = listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi ?: return + try { + ui.onLocalModelLoadStart() + } catch (t: Throwable) { + Log.w(TAG, "forward loadStart failed", t) + } + } + + override fun onLocalModelLoadDone() { + if (!forwardLocalModelUi) return + val ui = listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi ?: return + try { + ui.onLocalModelLoadDone() + } catch (t: Throwable) { + Log.w(TAG, "forward loadDone failed", t) + } + } } - if (!hasKeys) return null - - return when (vendor) { - AsrVendor.Volc -> if (prefs.volcStreamingEnabled) { - VolcStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) - } else { - if (prefs.volcFileStandardEnabled) { - GenericPushFileAsrAdapter( - context, - scope, - prefs, - engineListener, - VolcStandardFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - ) - } else { - GenericPushFileAsrAdapter( - context, - scope, - prefs, - engineListener, - VolcFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - ) - } - } - AsrVendor.SiliconFlow -> SiliconFlowFileAsrEngine( - context, - scope, - prefs, - engineListener, - onRequestDuration = onRequestDuration - ).let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - AsrVendor.ElevenLabs -> if (prefs.elevenStreamingEnabled) { - ElevenLabsStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) - } else { - ElevenLabsFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - } - AsrVendor.OpenAI -> OpenAiFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - AsrVendor.DashScope -> if (prefs.isDashStreamingModelSelected()) { - DashscopeStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) - } else { - DashscopeFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - } - AsrVendor.Gemini -> GeminiFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - AsrVendor.Soniox -> if (prefs.sonioxStreamingEnabled) { - SonioxStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) - } else { - SonioxFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - } - AsrVendor.Zhipu -> ZhipuFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - AsrVendor.SenseVoice -> if (prefs.svPseudoStreamEnabled) { - SenseVoicePushPcmPseudoStreamAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - } else { - SenseVoiceFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - } - AsrVendor.FunAsrNano -> FunAsrNanoFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - AsrVendor.Telespeech -> if (prefs.tsPseudoStreamEnabled) { - TelespeechPushPcmPseudoStreamAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - } else { - TelespeechFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) - .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } - } - AsrVendor.Paraformer -> ParaformerStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) + + private fun buildPushPcmEngine( + vendor: AsrVendor, + engineListener: StreamingAsrEngine.Listener, + onRequestDuration: ((Long) -> Unit)?, + ): StreamingAsrEngine? { + val hasKeys = try { + when (vendor) { + AsrVendor.SiliconFlow -> prefs.hasSfKeys() + else -> prefs.hasVendorKeys(vendor) + } + } catch (t: Throwable) { + Log.w(TAG, "Failed to read keys for vendor=$vendor", t) + false + } + if (!hasKeys) return null + + return when (vendor) { + AsrVendor.Volc -> if (prefs.volcStreamingEnabled) { + VolcStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) + } else { + if (prefs.volcFileStandardEnabled) { + GenericPushFileAsrAdapter( + context, + scope, + prefs, + engineListener, + VolcStandardFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration), + ) + } else { + GenericPushFileAsrAdapter( + context, + scope, + prefs, + engineListener, + VolcFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration), + ) + } + } + AsrVendor.SiliconFlow -> SiliconFlowFileAsrEngine( + context, + scope, + prefs, + engineListener, + onRequestDuration = onRequestDuration, + ).let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + AsrVendor.ElevenLabs -> if (prefs.elevenStreamingEnabled) { + ElevenLabsStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) + } else { + ElevenLabsFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + } + AsrVendor.OpenAI -> OpenAiFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + AsrVendor.DashScope -> if (prefs.isDashStreamingModelSelected()) { + DashscopeStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) + } else { + DashscopeFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + } + AsrVendor.Gemini -> GeminiFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + AsrVendor.Soniox -> if (prefs.sonioxStreamingEnabled) { + SonioxStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) + } else { + SonioxFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + } + AsrVendor.Zhipu -> ZhipuFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + AsrVendor.SenseVoice -> if (prefs.svPseudoStreamEnabled) { + SenseVoicePushPcmPseudoStreamAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + } else { + SenseVoiceFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + } + AsrVendor.FunAsrNano -> FunAsrNanoFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + AsrVendor.Telespeech -> if (prefs.tsPseudoStreamEnabled) { + TelespeechPushPcmPseudoStreamAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + } else { + TelespeechFileAsrEngine(context, scope, prefs, engineListener, onRequestDuration = onRequestDuration) + .let { GenericPushFileAsrAdapter(context, scope, prefs, engineListener, it) } + } + AsrVendor.Paraformer -> ParaformerStreamAsrEngine(context, scope, prefs, engineListener, externalPcmMode = true) + } } - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/PushPcmPseudoStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/PushPcmPseudoStreamAsrEngine.kt index 0ded5241..963d45b6 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/PushPcmPseudoStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/PushPcmPseudoStreamAsrEngine.kt @@ -21,161 +21,173 @@ import java.util.concurrent.atomic.AtomicBoolean * 注意:该引擎仅做“分句预览”,不做自动判停(由外部 finishPcm 决定会话结束)。 */ abstract class PushPcmPseudoStreamAsrEngine( - protected val context: Context, - protected val scope: CoroutineScope, - protected val prefs: Prefs, - protected val listener: StreamingAsrEngine.Listener, - protected val onRequestDuration: ((Long) -> Unit)? = null + protected val context: Context, + protected val scope: CoroutineScope, + protected val prefs: Prefs, + protected val listener: StreamingAsrEngine.Listener, + protected val onRequestDuration: ((Long) -> Unit)? = null, ) : StreamingAsrEngine, ExternalPcmConsumer { - companion object { - private const val TAG = "PushPcmPseudoStream" - private const val PREVIEW_SEGMENT_MS = 800 - } - - protected open val sampleRate: Int = 16000 - protected open val channelConfig: Int = AudioFormat.CHANNEL_IN_MONO - protected open val audioFormat: Int = AudioFormat.ENCODING_PCM_16BIT - - private val running = AtomicBoolean(false) - private val finalized = AtomicBoolean(false) + companion object { + private const val TAG = "PushPcmPseudoStream" + private const val PREVIEW_SEGMENT_MS = 800 + } - private val sessionBuffer = ByteArrayOutputStream() - private val segmentBuffer = ByteArrayOutputStream() - private var segmentElapsedMs: Long = 0L + protected open val sampleRate: Int = 16000 + protected open val channelConfig: Int = AudioFormat.CHANNEL_IN_MONO + protected open val audioFormat: Int = AudioFormat.ENCODING_PCM_16BIT - override val isRunning: Boolean - get() = running.get() + private val running = AtomicBoolean(false) + private val finalized = AtomicBoolean(false) - protected open fun ensureReady(): Boolean = true + private val sessionBuffer = ByteArrayOutputStream() + private val segmentBuffer = ByteArrayOutputStream() + private var segmentElapsedMs: Long = 0L - protected abstract fun onSegmentBoundary(pcmSegment: ByteArray) + override val isRunning: Boolean + get() = running.get() - protected abstract suspend fun onSessionFinished(fullPcm: ByteArray) + protected open fun ensureReady(): Boolean = true - override fun start() { - if (running.get()) return - if (!ensureReady()) return + protected abstract fun onSegmentBoundary(pcmSegment: ByteArray) - running.set(true) - finalized.set(false) - sessionBuffer.reset() - segmentBuffer.reset() - segmentElapsedMs = 0L - } + protected abstract suspend fun onSessionFinished(fullPcm: ByteArray) - override fun stop() { - if (!running.get()) return - running.set(false) - try { listener.onStopped() } catch (t: Throwable) { Log.w(TAG, "notify onStopped failed", t) } - finalizeOnce() - } + override fun start() { + if (running.get()) return + if (!ensureReady()) return - override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { - if (!running.get()) return - if (sampleRate != this.sampleRate || channels != 1) { - Log.w(TAG, "ignore frame: sr=$sampleRate ch=$channels") - return + running.set(true) + finalized.set(false) + sessionBuffer.reset() + segmentBuffer.reset() + segmentElapsedMs = 0L } - if (pcm.isEmpty()) return - try { listener.onAmplitude(calculateNormalizedAmplitude(pcm)) } catch (t: Throwable) { - Log.w(TAG, "amp cb failed", t) + override fun stop() { + if (!running.get()) return + running.set(false) + try { + listener.onStopped() + } catch (t: Throwable) { + Log.w(TAG, "notify onStopped failed", t) + } + finalizeOnce() } - try { - segmentBuffer.write(pcm) - } catch (t: Throwable) { - Log.e(TAG, "Failed to buffer audio chunk", t) - return - } + override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { + if (!running.get()) return + if (sampleRate != this.sampleRate || channels != 1) { + Log.w(TAG, "ignore frame: sr=$sampleRate ch=$channels") + return + } + if (pcm.isEmpty()) return - val frameMs = if (sampleRate > 0) { - ((pcm.size / 2) * 1000L) / sampleRate - } else { - 0L - } - if (frameMs > 0L) { - segmentElapsedMs += frameMs - } - if (segmentElapsedMs >= PREVIEW_SEGMENT_MS && segmentBuffer.size() > 0) { - val segBytes = try { segmentBuffer.toByteArray() } catch (t: Throwable) { - Log.e(TAG, "Failed to toByteArray for segment", t) - null - } ?: return - try { - sessionBuffer.write(segBytes) - } catch (t: Throwable) { - Log.e(TAG, "Failed to append segment to session buffer", t) - } - try { - onSegmentBoundary(segBytes) - } catch (t: Throwable) { - Log.e(TAG, "onSegmentBoundary failed", t) - } - segmentBuffer.reset() - segmentElapsedMs = 0L - } - } - - private fun finalizeOnce() { - if (!finalized.compareAndSet(false, true)) return - if (segmentBuffer.size() > 0) { - val segBytes = try { segmentBuffer.toByteArray() } catch (t: Throwable) { - Log.e(TAG, "Failed to toByteArray for tail segment", t) - null - } - if (segBytes != null) { try { - sessionBuffer.write(segBytes) + listener.onAmplitude(calculateNormalizedAmplitude(pcm)) } catch (t: Throwable) { - Log.e(TAG, "Failed to append tail segment to session buffer", t) + Log.w(TAG, "amp cb failed", t) } - } - segmentBuffer.reset() - } - val fullPcm = try { sessionBuffer.toByteArray() } catch (t: Throwable) { - Log.e(TAG, "Failed to dump session buffer", t) - ByteArray(0) - } - sessionBuffer.reset() - segmentBuffer.reset() - - if (fullPcm.isEmpty()) { - try { - listener.onError(context.getString(R.string.error_audio_empty)) - } catch (t: Throwable) { - Log.w(TAG, "notify audio empty failed", t) - } - return - } - scope.launch(Dispatchers.IO) { - try { - val denoised = OfflineSpeechDenoiserManager.denoiseIfEnabled( - context = context, - prefs = prefs, - pcm = fullPcm, - sampleRate = sampleRate - ) - onSessionFinished(denoised) - } catch (t: Throwable) { - if (t is CancellationException) { - Log.d(TAG, "final recognition cancelled: ${t.message}") + try { + segmentBuffer.write(pcm) + } catch (t: Throwable) { + Log.e(TAG, "Failed to buffer audio chunk", t) + return + } + + val frameMs = if (sampleRate > 0) { + ((pcm.size / 2) * 1000L) / sampleRate } else { - Log.e(TAG, "Final recognition failed", t) - try { - listener.onError( - context.getString( - R.string.error_recognize_failed_with_reason, - t.message ?: "" - ) - ) - } catch (e: Throwable) { - Log.w(TAG, "notify final error failed", e) - } + 0L + } + if (frameMs > 0L) { + segmentElapsedMs += frameMs + } + if (segmentElapsedMs >= PREVIEW_SEGMENT_MS && segmentBuffer.size() > 0) { + val segBytes = try { + segmentBuffer.toByteArray() + } catch (t: Throwable) { + Log.e(TAG, "Failed to toByteArray for segment", t) + null + } ?: return + try { + sessionBuffer.write(segBytes) + } catch (t: Throwable) { + Log.e(TAG, "Failed to append segment to session buffer", t) + } + try { + onSegmentBoundary(segBytes) + } catch (t: Throwable) { + Log.e(TAG, "onSegmentBoundary failed", t) + } + segmentBuffer.reset() + segmentElapsedMs = 0L + } + } + + private fun finalizeOnce() { + if (!finalized.compareAndSet(false, true)) return + if (segmentBuffer.size() > 0) { + val segBytes = try { + segmentBuffer.toByteArray() + } catch (t: Throwable) { + Log.e(TAG, "Failed to toByteArray for tail segment", t) + null + } + if (segBytes != null) { + try { + sessionBuffer.write(segBytes) + } catch (t: Throwable) { + Log.e(TAG, "Failed to append tail segment to session buffer", t) + } + } + segmentBuffer.reset() + } + val fullPcm = try { + sessionBuffer.toByteArray() + } catch (t: Throwable) { + Log.e(TAG, "Failed to dump session buffer", t) + ByteArray(0) + } + sessionBuffer.reset() + segmentBuffer.reset() + + if (fullPcm.isEmpty()) { + try { + listener.onError(context.getString(R.string.error_audio_empty)) + } catch (t: Throwable) { + Log.w(TAG, "notify audio empty failed", t) + } + return + } + + scope.launch(Dispatchers.IO) { + try { + val denoised = OfflineSpeechDenoiserManager.denoiseIfEnabled( + context = context, + prefs = prefs, + pcm = fullPcm, + sampleRate = sampleRate, + ) + onSessionFinished(denoised) + } catch (t: Throwable) { + if (t is CancellationException) { + Log.d(TAG, "final recognition cancelled: ${t.message}") + } else { + Log.e(TAG, "Final recognition failed", t) + try { + listener.onError( + context.getString( + R.string.error_recognize_failed_with_reason, + t.message ?: "", + ), + ) + } catch (e: Throwable) { + Log.w(TAG, "notify final error failed", e) + } + } + } } - } } - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoiceFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoiceFileAsrEngine.kt index 32bc8c6e..2254d50e 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoiceFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoiceFileAsrEngine.kt @@ -28,7 +28,7 @@ class SenseVoiceFileAsrEngine( scope: CoroutineScope, prefs: Prefs, listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + onRequestDuration: ((Long) -> Unit)? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { // 本地 SenseVoice:为降低内存占用,主动限制为 5 分钟 @@ -42,7 +42,9 @@ class SenseVoiceFileAsrEngine( private fun showToast(resId: Int) { try { Handler(Looper.getMainLooper()).post { - try { Toast.makeText(context, context.getString(resId), Toast.LENGTH_SHORT).show() } catch (t: Throwable) { + try { + Toast.makeText(context, context.getString(resId), Toast.LENGTH_SHORT).show() + } catch (t: Throwable) { Log.e("SenseVoiceFileAsrEngine", "Failed to show toast", t) } } @@ -54,7 +56,9 @@ class SenseVoiceFileAsrEngine( private fun notifyLoadStart() { val ui = (listener as? LocalModelLoadUi) if (ui != null) { - try { ui.onLocalModelLoadStart() } catch (t: Throwable) { + try { + ui.onLocalModelLoadStart() + } catch (t: Throwable) { Log.e("SenseVoiceFileAsrEngine", "Failed to notify load start", t) } } else { @@ -65,7 +69,9 @@ class SenseVoiceFileAsrEngine( private fun notifyLoadDone() { val ui = (listener as? LocalModelLoadUi) if (ui != null) { - try { ui.onLocalModelLoadDone() } catch (t: Throwable) { + try { + ui.onLocalModelLoadDone() + } catch (t: Throwable) { Log.e("SenseVoiceFileAsrEngine", "Failed to notify load done", t) } } @@ -76,7 +82,9 @@ class SenseVoiceFileAsrEngine( // 若未集成 sherpa-onnx Kotlin/so,则直接报错以避免无意义的录音 val manager = SenseVoiceOnnxManager.getInstance() if (!manager.isOnnxAvailable()) { - try { listener.onError(context.getString(R.string.error_local_asr_not_ready)) } catch (t: Throwable) { + try { + listener.onError(context.getString(R.string.error_local_asr_not_ready)) + } catch (t: Throwable) { Log.e("SenseVoiceFileAsrEngine", "Failed to send error callback", t) } return false @@ -93,12 +101,16 @@ class SenseVoiceFileAsrEngine( return } // 模型目录:固定为外部专属目录(不可配置);外部不可用时回退内部目录 - val base = try { context.getExternalFilesDir(null) } catch (t: Throwable) { + val base = try { + context.getExternalFilesDir(null) + } catch (t: Throwable) { Log.w("SenseVoiceFileAsrEngine", "Failed to get external files dir", t) null } ?: context.filesDir val probeRoot = java.io.File(base, "sensevoice") - val rawVariant = try { prefs.svModelVariant } catch (t: Throwable) { + val rawVariant = try { + prefs.svModelVariant + } catch (t: Throwable) { Log.w("SenseVoiceFileAsrEngine", "Failed to get model variant", t) "small-int8" } @@ -136,7 +148,9 @@ class SenseVoiceFileAsrEngine( // 注意:当从绝对路径加载模型/词表时,必须将 assetManager 设为 null // 参考 sherpa-onnx 提示 https://github.com/k2-fsa/sherpa-onnx/issues/2562 // 在需要创建新识别器时,向用户提示"加载中/完成" - val keepMinutes = try { prefs.svKeepAliveMinutes } catch (t: Throwable) { + val keepMinutes = try { + prefs.svKeepAliveMinutes + } catch (t: Throwable) { Log.w("SenseVoiceFileAsrEngine", "Failed to get keep alive minutes", t) -1 } @@ -153,12 +167,16 @@ class SenseVoiceFileAsrEngine( Log.w("SenseVoiceFileAsrEngine", "Failed to get language", t) "auto" }, - useItn = try { prefs.svUseItn } catch (t: Throwable) { + useItn = try { + prefs.svUseItn + } catch (t: Throwable) { Log.w("SenseVoiceFileAsrEngine", "Failed to get useItn", t) false }, provider = "cpu", - numThreads = try { prefs.svNumThreads } catch (t: Throwable) { + numThreads = try { + prefs.svNumThreads + } catch (t: Throwable) { Log.w("SenseVoiceFileAsrEngine", "Failed to get num threads", t) 2 }, @@ -167,7 +185,7 @@ class SenseVoiceFileAsrEngine( keepAliveMs = keepMs, alwaysKeep = alwaysKeep, onLoadStart = { notifyLoadStart() }, - onLoadDone = { notifyLoadDone() } + onLoadDone = { notifyLoadDone() }, ) if (text.isNullOrBlank()) { @@ -181,13 +199,17 @@ class SenseVoiceFileAsrEngine( listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) } finally { val dt = System.currentTimeMillis() - t0 - try { onRequestDuration?.invoke(dt) } catch (t: Throwable) { + try { + onRequestDuration?.invoke(dt) + } catch (t: Throwable) { Log.e("SenseVoiceFileAsrEngine", "Failed to invoke duration callback", t) } } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } private fun pcmToFloatArray(pcm: ByteArray): FloatArray { if (pcm.isEmpty()) return FloatArray(0) @@ -226,7 +248,7 @@ fun preloadSenseVoiceIfConfigured( force: Boolean = false, // 该预加载是否紧跟着会被立即使用(例如开始录音)。 // 若为 true,则不在此处调度卸载,由实际使用/释放时再调度;否则预加载后按设置调度卸载。 - forImmediateUse: Boolean = false + forImmediateUse: Boolean = false, ) { try { val manager = SenseVoiceOnnxManager.getInstance() @@ -283,7 +305,7 @@ fun preloadSenseVoiceIfConfigured( Toast.makeText( context, context.getString(R.string.sv_model_ready_with_ms, dt), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() } } @@ -346,7 +368,7 @@ private data class RecognizerConfig( val language: String, val useItn: Boolean, val provider: String, - val numThreads: Int + val numThreads: Int, ) { fun toCacheKey(): String = listOf(tokens, model, language, useItn, provider, numThreads).joinToString("|") } @@ -384,7 +406,7 @@ internal class ReflectiveStream(val instance: Any) { */ internal class ReflectiveRecognizer( private val instance: Any, - private val clsOfflineRecognizer: Class<*> + private val clsOfflineRecognizer: Class<*>, ) { fun createStream(): ReflectiveStream { @@ -435,7 +457,7 @@ class SenseVoiceOnnxManager private constructor() { private const val TAG = "SenseVoiceOnnxManager" @Volatile - private var instance: SenseVoiceOnnxManager? = null + private var instance: SenseVoiceOnnxManager? = null fun getInstance(): SenseVoiceOnnxManager { return instance ?: synchronized(this) { @@ -448,12 +470,19 @@ class SenseVoiceOnnxManager private constructor() { private val mutex = Mutex() @Volatile private var cachedConfig: RecognizerConfig? = null + @Volatile private var cachedRecognizer: ReflectiveRecognizer? = null + @Volatile private var preparing: Boolean = false + @Volatile private var clsOfflineRecognizer: Class<*>? = null + @Volatile private var clsOfflineRecognizerConfig: Class<*>? = null + @Volatile private var clsOfflineModelConfig: Class<*>? = null + @Volatile private var clsOfflineSenseVoiceModelConfig: Class<*>? = null + @Volatile private var unloadJob: Job? = null fun isOnnxAvailable(): Boolean { @@ -496,6 +525,7 @@ class SenseVoiceOnnxManager private constructor() { // 记录最近一次配置,用于解码完成后按用户设置调度卸载 @Volatile private var lastKeepAliveMs: Long = 0L + @Volatile private var lastAlwaysKeep: Boolean = false /** @@ -659,7 +689,7 @@ class SenseVoiceOnnxManager private constructor() { assetManager: android.content.res.AssetManager?, config: RecognizerConfig, onLoadStart: (() -> Unit)?, - onLoadDone: (() -> Unit)? + onLoadDone: (() -> Unit)?, ): ReflectiveRecognizer? { initClasses() @@ -718,7 +748,7 @@ class SenseVoiceOnnxManager private constructor() { keepAliveMs: Long, alwaysKeep: Boolean, onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null + onLoadDone: (() -> Unit)? = null, ): String? = mutex.withLock { try { val config = RecognizerConfig(tokens, model, language, useItn, provider, numThreads) @@ -760,7 +790,7 @@ class SenseVoiceOnnxManager private constructor() { keepAliveMs: Long, alwaysKeep: Boolean, onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null + onLoadDone: (() -> Unit)? = null, ): Boolean = mutex.withLock { try { val config = RecognizerConfig(tokens, model, language, useItn, provider, numThreads) @@ -778,7 +808,6 @@ class SenseVoiceOnnxManager private constructor() { return@withLock false } } - } /** @@ -808,10 +837,10 @@ object SenseVoiceOnnxBridge { keepAliveMs: Long, alwaysKeep: Boolean, onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null + onLoadDone: (() -> Unit)? = null, ): String? = manager.decodeOffline( assetManager, tokens, model, language, useItn, provider, numThreads, - samples, sampleRate, keepAliveMs, alwaysKeep, onLoadStart, onLoadDone + samples, sampleRate, keepAliveMs, alwaysKeep, onLoadStart, onLoadDone, ) suspend fun prepare( @@ -825,10 +854,10 @@ object SenseVoiceOnnxBridge { keepAliveMs: Long, alwaysKeep: Boolean, onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null + onLoadDone: (() -> Unit)? = null, ): Boolean = manager.prepare( assetManager, tokens, model, language, useItn, provider, numThreads, - keepAliveMs, alwaysKeep, onLoadStart, onLoadDone + keepAliveMs, alwaysKeep, onLoadStart, onLoadDone, ) } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamAsrEngine.kt index 9f62e455..78af402d 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamAsrEngine.kt @@ -14,7 +14,7 @@ class SenseVoicePseudoStreamAsrEngine( scope: CoroutineScope, prefs: Prefs, listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + onRequestDuration: ((Long) -> Unit)? = null, ) : LocalModelPseudoStreamAsrEngine(context, scope, prefs, listener, onRequestDuration) { companion object { @@ -28,7 +28,7 @@ class SenseVoicePseudoStreamAsrEngine( listener = listener, sampleRate = sampleRate, onRequestDuration = onRequestDuration, - tag = TAG + tag = TAG, ) override fun ensureReady(): Boolean { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamDelegate.kt b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamDelegate.kt index 269db7e1..a6d745eb 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamDelegate.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePseudoStreamDelegate.kt @@ -22,7 +22,7 @@ internal class SenseVoicePseudoStreamDelegate( private val listener: StreamingAsrEngine.Listener, private val sampleRate: Int, private val onRequestDuration: ((Long) -> Unit)?, - private val tag: String + private val tag: String, ) { private val previewMutex = Mutex() @@ -140,8 +140,8 @@ internal class SenseVoicePseudoStreamDelegate( listener.onError( context.getString( R.string.error_recognize_failed_with_reason, - t.message ?: "" - ) + t.message ?: "", + ), ) } catch (e: Throwable) { Log.e(tag, "Failed to notify final recognition error", e) @@ -174,7 +174,7 @@ internal class SenseVoicePseudoStreamDelegate( Toast.makeText( context, context.getString(R.string.sv_loading_model), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() } catch (t: Throwable) { Log.e(tag, "Failed to show toast", t) @@ -199,7 +199,7 @@ internal class SenseVoicePseudoStreamDelegate( private suspend fun decodeOnce( pcm: ByteArray, - reportErrorToUser: Boolean + reportErrorToUser: Boolean, ): String? { val manager = SenseVoiceOnnxManager.getInstance() if (!manager.isOnnxAvailable()) { @@ -314,7 +314,7 @@ internal class SenseVoicePseudoStreamDelegate( keepAliveMs = keepMs, alwaysKeep = alwaysKeep, onLoadStart = { notifyLoadStart() }, - onLoadDone = { notifyLoadDone() } + onLoadDone = { notifyLoadDone() }, ) if (text.isNullOrBlank()) { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePushPcmPseudoStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePushPcmPseudoStreamAsrEngine.kt index ec16f1bc..b0186fa8 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePushPcmPseudoStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SenseVoicePushPcmPseudoStreamAsrEngine.kt @@ -12,59 +12,59 @@ import java.util.concurrent.atomic.AtomicLong * - finishPcm/stop 时对整段音频做一次离线识别(onFinal)。 */ class SenseVoicePushPcmPseudoStreamAsrEngine( - context: Context, - scope: CoroutineScope, - prefs: Prefs, - listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + context: Context, + scope: CoroutineScope, + prefs: Prefs, + listener: StreamingAsrEngine.Listener, + onRequestDuration: ((Long) -> Unit)? = null, ) : PushPcmPseudoStreamAsrEngine(context, scope, prefs, listener, onRequestDuration) { - companion object { - private const val TAG = "SvPushPcmPseudo" - } + companion object { + private const val TAG = "SvPushPcmPseudo" + } - private val delegate = SenseVoicePseudoStreamDelegate( - context = context, - scope = scope, - prefs = prefs, - listener = listener, - sampleRate = sampleRate, - onRequestDuration = onRequestDuration, - tag = TAG - ) + private val delegate = SenseVoicePseudoStreamDelegate( + context = context, + scope = scope, + prefs = prefs, + listener = listener, + sampleRate = sampleRate, + onRequestDuration = onRequestDuration, + tag = TAG, + ) - private val sessionIdGenerator = AtomicLong(0L) + private val sessionIdGenerator = AtomicLong(0L) - @Volatile - private var activeSessionId: Long = 0L + @Volatile + private var activeSessionId: Long = 0L - @Volatile - private var finishingSessionId: Long = 0L + @Volatile + private var finishingSessionId: Long = 0L - override fun start() { - val wasRunning = isRunning - super.start() - if (wasRunning || !isRunning) return - val sessionId = sessionIdGenerator.incrementAndGet() - activeSessionId = sessionId - finishingSessionId = sessionId - delegate.onSessionStart(sessionId) - } + override fun start() { + val wasRunning = isRunning + super.start() + if (wasRunning || !isRunning) return + val sessionId = sessionIdGenerator.incrementAndGet() + activeSessionId = sessionId + finishingSessionId = sessionId + delegate.onSessionStart(sessionId) + } - override fun stop() { - finishingSessionId = activeSessionId - super.stop() - } + override fun stop() { + finishingSessionId = activeSessionId + super.stop() + } - override fun ensureReady(): Boolean { - return delegate.ensureReady() - } + override fun ensureReady(): Boolean { + return delegate.ensureReady() + } - override fun onSegmentBoundary(pcmSegment: ByteArray) { - delegate.onSegmentBoundary(activeSessionId, pcmSegment) - } + override fun onSegmentBoundary(pcmSegment: ByteArray) { + delegate.onSegmentBoundary(activeSessionId, pcmSegment) + } - override suspend fun onSessionFinished(fullPcm: ByteArray) { - delegate.onSessionFinished(finishingSessionId, fullPcm) - } + override suspend fun onSessionFinished(fullPcm: ByteArray) { + delegate.onSessionFinished(finishingSessionId, fullPcm) + } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SherpaPunctuationManager.kt b/app/src/main/java/com/brycewg/asrkb/asr/SherpaPunctuationManager.kt index 74d06ad6..c44cda93 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SherpaPunctuationManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SherpaPunctuationManager.kt @@ -5,9 +5,9 @@ import android.util.Log import android.widget.Toast import com.brycewg.asrkb.R import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.withContext import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext /** * sherpa-onnx 文本标点管理器(Offline/Online Punctuation)。 @@ -19,440 +19,448 @@ import kotlinx.coroutines.sync.withLock */ class SherpaPunctuationManager private constructor() { - companion object { - private const val TAG = "SherpaPunctuationManager" + companion object { + private const val TAG = "SherpaPunctuationManager" + + // 统一目录/文件约定:三个本地模型共用同一套标点模型 + private const val MODEL_DIR_NAME = "punctuation" + private const val MODEL_FILE_NAME = "model.int8.onnx" + + // 兼容:官方压缩包解压后的默认目录名 + private const val LEGACY_MODEL_PARENT = + "sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8" + + @Volatile + private var hasWarnedMissingModel: Boolean = false + + @Volatile + private var instance: SherpaPunctuationManager? = null - // 统一目录/文件约定:三个本地模型共用同一套标点模型 - private const val MODEL_DIR_NAME = "punctuation" - private const val MODEL_FILE_NAME = "model.int8.onnx" + fun getInstance(): SherpaPunctuationManager { + return instance ?: synchronized(this) { + instance ?: SherpaPunctuationManager().also { instance = it } + } + } + + /** + * 统一查找标点模型所在目录: + * - 优先 externalFilesDir/punctuation 下的 model.int8.onnx; + * - 次选 externalFilesDir/LEGACY_MODEL_PARENT; + * - 最后在 punctuation/ 下遍历一级子目录寻找 model.int8.onnx。 + */ + fun findPunctuationModelDir(context: Context): java.io.File? { + val base = try { + context.getExternalFilesDir(null) + } catch (t: Throwable) { + Log.w(TAG, "Failed to get external files dir", t) + null + } ?: context.filesDir + + val root = java.io.File(base, MODEL_DIR_NAME) + val direct = java.io.File(root, MODEL_FILE_NAME) + if (direct.exists()) return root + + val legacyRoot = java.io.File(base, LEGACY_MODEL_PARENT) + val legacyFile = java.io.File(legacyRoot, MODEL_FILE_NAME) + if (legacyFile.exists()) return legacyRoot + + if (root.exists()) { + val subs = root.listFiles() ?: emptyArray() + for (dir in subs) { + if (!dir.isDirectory) continue + val f = java.io.File(dir, MODEL_FILE_NAME) + if (f.exists()) return dir + } + } + return null + } - // 兼容:官方压缩包解压后的默认目录名 - private const val LEGACY_MODEL_PARENT = - "sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12-int8" + /** + * 返回离线标点模型路径(model.int8.onnx),若未找到则返回 null。 + */ + fun findOfflineModelPath(context: Context): String? { + val dir = findPunctuationModelDir(context) ?: return null + val file = java.io.File(dir, MODEL_FILE_NAME) + return if (file.exists()) file.absolutePath else null + } + + /** + * 是否已安装本地标点模型。 + */ + fun isModelInstalled(context: Context): Boolean { + return findOfflineModelPath(context) != null + } + /** + * 清除已安装的标点模型(通用),并释放缓存实例。 + * 返回是否实际删除了任何文件。 + */ + fun clearInstalledModel(context: Context): Boolean { + // 先尝试释放运行时实例 + try { + getInstance().unloadOffline() + } catch (t: Throwable) { + Log.e(TAG, "Failed to unload offline punctuation", t) + } + try { + getInstance().unloadOnline() + } catch (t: Throwable) { + Log.e(TAG, "Failed to unload online punctuation", t) + } + + val base = try { + context.getExternalFilesDir(null) + } catch (t: Throwable) { + Log.w(TAG, "Failed to get external files dir for clearInstalledModel", t) + null + } ?: context.filesDir + + var deleted = false + val primary = java.io.File(base, MODEL_DIR_NAME) + val legacy = java.io.File(base, LEGACY_MODEL_PARENT) + + fun clearDir(dir: java.io.File) { + if (!dir.exists()) return + try { + if (dir.deleteRecursively()) { + deleted = true + } else { + Log.w(TAG, "Failed to delete punctuation dir: ${dir.path}") + } + } catch (t: Throwable) { + Log.w(TAG, "Error deleting punctuation dir: ${dir.path}", t) + } + } + + clearDir(primary) + clearDir(legacy) + // 清理过程中重置提示标志,使后续可再次提示 + hasWarnedMissingModel = false + return deleted + } + + /** + * 缺少标点模型时给出一次性 Toast 提示(当前进程仅提示一次)。 + */ + fun maybeWarnModelMissing(context: Context) { + if (isModelInstalled(context)) return + if (hasWarnedMissingModel) return + hasWarnedMissingModel = true + try { + Toast.makeText( + context.applicationContext, + context.getString(R.string.toast_punct_model_missing), + Toast.LENGTH_SHORT, + ).show() + } catch (t: Throwable) { + Log.e(TAG, "Failed to show punctuation missing toast", t) + } + } + } + + private val mutex = Mutex() + + // OfflinePunctuation 反射类与缓存实例 @Volatile - private var hasWarnedMissingModel: Boolean = false + private var clsOfflinePunctuation: Class<*>? = null @Volatile - private var instance: SherpaPunctuationManager? = null + private var clsOfflinePuncConfig: Class<*>? = null - fun getInstance(): SherpaPunctuationManager { - return instance ?: synchronized(this) { - instance ?: SherpaPunctuationManager().also { instance = it } - } - } + @Volatile + private var clsOfflinePuncModelConfig: Class<*>? = null - /** - * 统一查找标点模型所在目录: - * - 优先 externalFilesDir/punctuation 下的 model.int8.onnx; - * - 次选 externalFilesDir/LEGACY_MODEL_PARENT; - * - 最后在 punctuation/ 下遍历一级子目录寻找 model.int8.onnx。 - */ - fun findPunctuationModelDir(context: Context): java.io.File? { - val base = try { - context.getExternalFilesDir(null) - } catch (t: Throwable) { - Log.w(TAG, "Failed to get external files dir", t) - null - } ?: context.filesDir - - val root = java.io.File(base, MODEL_DIR_NAME) - val direct = java.io.File(root, MODEL_FILE_NAME) - if (direct.exists()) return root - - val legacyRoot = java.io.File(base, LEGACY_MODEL_PARENT) - val legacyFile = java.io.File(legacyRoot, MODEL_FILE_NAME) - if (legacyFile.exists()) return legacyRoot - - if (root.exists()) { - val subs = root.listFiles() ?: emptyArray() - for (dir in subs) { - if (!dir.isDirectory) continue - val f = java.io.File(dir, MODEL_FILE_NAME) - if (f.exists()) return dir + @Volatile + private var offline: ReflectiveOfflinePunctuation? = null + + @Volatile + private var offlineModelPath: String? = null + + // OnlinePunctuation 反射类与缓存实例(为流式模型预留) + @Volatile + private var clsOnlinePunctuation: Class<*>? = null + + @Volatile + private var clsOnlinePuncConfig: Class<*>? = null + + @Volatile + private var clsOnlinePuncModelConfig: Class<*>? = null + + @Volatile + private var online: ReflectiveOnlinePunctuation? = null + + @Volatile + private var onlineModelKey: OnlineModelKey? = null + + private data class OnlineModelKey( + val cnnBilstm: String, + val bpeVocab: String?, + ) + + fun isOfflineSupported(): Boolean { + return try { + Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuation") + true + } catch (t: Throwable) { + Log.d(TAG, "OfflinePunctuation not available", t) + false } - } - return null } - /** - * 返回离线标点模型路径(model.int8.onnx),若未找到则返回 null。 - */ - fun findOfflineModelPath(context: Context): String? { - val dir = findPunctuationModelDir(context) ?: return null - val file = java.io.File(dir, MODEL_FILE_NAME) - return if (file.exists()) file.absolutePath else null + fun isOnlineSupported(): Boolean { + return try { + Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuation") + true + } catch (t: Throwable) { + Log.d(TAG, "OnlinePunctuation not available", t) + false + } } /** - * 是否已安装本地标点模型。 + * 确保 OfflinePunctuation 已按当前文件路径加载。 + * 不触发任何下载,仅根据现有文件构建实例。 */ - fun isModelInstalled(context: Context): Boolean { - return findOfflineModelPath(context) != null + suspend fun ensureOfflineLoaded(context: Context): Boolean = mutex.withLock { + if (!isOfflineSupported()) return@withLock false + val modelPath = findOfflineModelPath(context) ?: return@withLock false + if (offline != null && offlineModelPath == modelPath) return@withLock true + + initOfflineClasses() + val inst = createOfflineInstance(modelPath) ?: return@withLock false + try { + offline?.release() + } catch (t: Throwable) { + Log.e(TAG, "Failed to release previous OfflinePunctuation", t) + } + offline = inst + offlineModelPath = modelPath + true } /** - * 清除已安装的标点模型(通用),并释放缓存实例。 - * 返回是否实际删除了任何文件。 + * 通过指定的模型文件路径准备 OnlinePunctuation(流式标点)。 + * - cnnBilstm: CNN-BiLSTM 模型 onnx 路径; + * - bpeVocab: 词表路径,可为空。 */ - fun clearInstalledModel(context: Context): Boolean { - // 先尝试释放运行时实例 - try { - getInstance().unloadOffline() - } catch (t: Throwable) { - Log.e(TAG, "Failed to unload offline punctuation", t) - } - try { - getInstance().unloadOnline() - } catch (t: Throwable) { - Log.e(TAG, "Failed to unload online punctuation", t) - } - - val base = try { - context.getExternalFilesDir(null) - } catch (t: Throwable) { - Log.w(TAG, "Failed to get external files dir for clearInstalledModel", t) - null - } ?: context.filesDir - - var deleted = false - val primary = java.io.File(base, MODEL_DIR_NAME) - val legacy = java.io.File(base, LEGACY_MODEL_PARENT) - - fun clearDir(dir: java.io.File) { - if (!dir.exists()) return + suspend fun ensureOnlineLoaded( + cnnBilstm: String, + bpeVocab: String?, + ): Boolean = mutex.withLock { + if (!isOnlineSupported()) return@withLock false + val key = OnlineModelKey(cnnBilstm, bpeVocab) + if (online != null && onlineModelKey == key) return@withLock true + + initOnlineClasses() + val inst = createOnlineInstance(cnnBilstm, bpeVocab) ?: return@withLock false try { - if (dir.deleteRecursively()) { - deleted = true - } else { - Log.w(TAG, "Failed to delete punctuation dir: ${dir.path}") - } + online?.release() } catch (t: Throwable) { - Log.w(TAG, "Error deleting punctuation dir: ${dir.path}", t) + Log.e(TAG, "Failed to release previous OnlinePunctuation", t) } - } - - clearDir(primary) - clearDir(legacy) - // 清理过程中重置提示标志,使后续可再次提示 - hasWarnedMissingModel = false - return deleted + online = inst + onlineModelKey = key + true } /** - * 缺少标点模型时给出一次性 Toast 提示(当前进程仅提示一次)。 + * 对文本执行离线标点(ct-transformer),失败时返回原文。 */ - fun maybeWarnModelMissing(context: Context) { - if (isModelInstalled(context)) return - if (hasWarnedMissingModel) return - hasWarnedMissingModel = true - try { - Toast.makeText( - context.applicationContext, - context.getString(R.string.toast_punct_model_missing), - Toast.LENGTH_SHORT - ).show() - } catch (t: Throwable) { - Log.e(TAG, "Failed to show punctuation missing toast", t) - } - } - } - - private val mutex = Mutex() - - // OfflinePunctuation 反射类与缓存实例 - @Volatile - private var clsOfflinePunctuation: Class<*>? = null - @Volatile - private var clsOfflinePuncConfig: Class<*>? = null - @Volatile - private var clsOfflinePuncModelConfig: Class<*>? = null - @Volatile - private var offline: ReflectiveOfflinePunctuation? = null - @Volatile - private var offlineModelPath: String? = null - - // OnlinePunctuation 反射类与缓存实例(为流式模型预留) - @Volatile - private var clsOnlinePunctuation: Class<*>? = null - @Volatile - private var clsOnlinePuncConfig: Class<*>? = null - @Volatile - private var clsOnlinePuncModelConfig: Class<*>? = null - @Volatile - private var online: ReflectiveOnlinePunctuation? = null - @Volatile - private var onlineModelKey: OnlineModelKey? = null - - private data class OnlineModelKey( - val cnnBilstm: String, - val bpeVocab: String? - ) - - fun isOfflineSupported(): Boolean { - return try { - Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuation") - true - } catch (t: Throwable) { - Log.d(TAG, "OfflinePunctuation not available", t) - false - } - } - - fun isOnlineSupported(): Boolean { - return try { - Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuation") - true - } catch (t: Throwable) { - Log.d(TAG, "OnlinePunctuation not available", t) - false - } - } - - /** - * 确保 OfflinePunctuation 已按当前文件路径加载。 - * 不触发任何下载,仅根据现有文件构建实例。 - */ - suspend fun ensureOfflineLoaded(context: Context): Boolean = mutex.withLock { - if (!isOfflineSupported()) return@withLock false - val modelPath = findOfflineModelPath(context) ?: return@withLock false - if (offline != null && offlineModelPath == modelPath) return@withLock true - - initOfflineClasses() - val inst = createOfflineInstance(modelPath) ?: return@withLock false - try { - offline?.release() - } catch (t: Throwable) { - Log.e(TAG, "Failed to release previous OfflinePunctuation", t) - } - offline = inst - offlineModelPath = modelPath - true - } - - /** - * 通过指定的模型文件路径准备 OnlinePunctuation(流式标点)。 - * - cnnBilstm: CNN-BiLSTM 模型 onnx 路径; - * - bpeVocab: 词表路径,可为空。 - */ - suspend fun ensureOnlineLoaded( - cnnBilstm: String, - bpeVocab: String? - ): Boolean = mutex.withLock { - if (!isOnlineSupported()) return@withLock false - val key = OnlineModelKey(cnnBilstm, bpeVocab) - if (online != null && onlineModelKey == key) return@withLock true - - initOnlineClasses() - val inst = createOnlineInstance(cnnBilstm, bpeVocab) ?: return@withLock false - try { - online?.release() - } catch (t: Throwable) { - Log.e(TAG, "Failed to release previous OnlinePunctuation", t) - } - online = inst - onlineModelKey = key - true - } - - /** - * 对文本执行离线标点(ct-transformer),失败时返回原文。 - */ - suspend fun addOfflinePunctuation(context: Context, text: String): String { - if (text.isEmpty()) return text - val ok = ensureOfflineLoaded(context) - val p = offline - if (!ok || p == null) return text - - return withContext(Dispatchers.Default) { - try { - p.addPunctuation(text) ?: text - } catch (t: Throwable) { - Log.e(TAG, "Offline punctuation failed", t) - text - } - } - } - - /** - * 对文本执行在线标点(CNN-BiLSTM),失败时返回原文。 - * 调用方需事先通过 ensureOnlineLoaded 准备好模型。 - */ - suspend fun addOnlinePunctuation(text: String): String { - if (text.isEmpty()) return text - val p = online ?: return text - return withContext(Dispatchers.Default) { - try { - p.addPunctuation(text) ?: text - } catch (t: Throwable) { - Log.e(TAG, "Online punctuation failed", t) - text - } + suspend fun addOfflinePunctuation(context: Context, text: String): String { + if (text.isEmpty()) return text + val ok = ensureOfflineLoaded(context) + val p = offline + if (!ok || p == null) return text + + return withContext(Dispatchers.Default) { + try { + p.addPunctuation(text) ?: text + } catch (t: Throwable) { + Log.e(TAG, "Offline punctuation failed", t) + text + } + } } - } - - fun unloadOffline() { - try { - val inst = offline - offline = null - offlineModelPath = null - inst?.release() - } catch (t: Throwable) { - Log.e(TAG, "unloadOffline failed", t) + + /** + * 对文本执行在线标点(CNN-BiLSTM),失败时返回原文。 + * 调用方需事先通过 ensureOnlineLoaded 准备好模型。 + */ + suspend fun addOnlinePunctuation(text: String): String { + if (text.isEmpty()) return text + val p = online ?: return text + return withContext(Dispatchers.Default) { + try { + p.addPunctuation(text) ?: text + } catch (t: Throwable) { + Log.e(TAG, "Online punctuation failed", t) + text + } + } } - } - - fun unloadOnline() { - try { - val inst = online - online = null - onlineModelKey = null - inst?.release() - } catch (t: Throwable) { - Log.e(TAG, "unloadOnline failed", t) + + fun unloadOffline() { + try { + val inst = offline + offline = null + offlineModelPath = null + inst?.release() + } catch (t: Throwable) { + Log.e(TAG, "unloadOffline failed", t) + } } - } - - private fun initOfflineClasses() { - if (clsOfflinePunctuation == null) { - clsOfflinePunctuation = Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuation") - clsOfflinePuncConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuationConfig") - clsOfflinePuncModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuationModelConfig") - Log.d(TAG, "Initialized OfflinePunctuation reflection classes") + + fun unloadOnline() { + try { + val inst = online + online = null + onlineModelKey = null + inst?.release() + } catch (t: Throwable) { + Log.e(TAG, "unloadOnline failed", t) + } } - } - - private fun initOnlineClasses() { - if (clsOnlinePunctuation == null) { - clsOnlinePunctuation = Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuation") - clsOnlinePuncConfig = Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuationConfig") - clsOnlinePuncModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuationModelConfig") - Log.d(TAG, "Initialized OnlinePunctuation reflection classes") + + private fun initOfflineClasses() { + if (clsOfflinePunctuation == null) { + clsOfflinePunctuation = Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuation") + clsOfflinePuncConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuationConfig") + clsOfflinePuncModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflinePunctuationModelConfig") + Log.d(TAG, "Initialized OfflinePunctuation reflection classes") + } } - } - - private fun trySetField(target: Any, name: String, value: Any?): Boolean { - return try { - val f = target.javaClass.getDeclaredField(name) - f.isAccessible = true - f.set(target, value) - true - } catch (t: Throwable) { - try { - val methodName = "set" + name.replaceFirstChar { - if (it.isLowerCase()) it.titlecase() else it.toString() + + private fun initOnlineClasses() { + if (clsOnlinePunctuation == null) { + clsOnlinePunctuation = Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuation") + clsOnlinePuncConfig = Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuationConfig") + clsOnlinePuncModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OnlinePunctuationModelConfig") + Log.d(TAG, "Initialized OnlinePunctuation reflection classes") } - val m = if (value == null) { - target.javaClass.getMethod(methodName, Any::class.java) - } else { - target.javaClass.getMethod(methodName, value.javaClass) + } + + private fun trySetField(target: Any, name: String, value: Any?): Boolean { + return try { + val f = target.javaClass.getDeclaredField(name) + f.isAccessible = true + f.set(target, value) + true + } catch (t: Throwable) { + try { + val methodName = "set" + name.replaceFirstChar { + if (it.isLowerCase()) it.titlecase() else it.toString() + } + val m = if (value == null) { + target.javaClass.getMethod(methodName, Any::class.java) + } else { + target.javaClass.getMethod(methodName, value.javaClass) + } + m.invoke(target, value) + true + } catch (t2: Throwable) { + Log.w(TAG, "Failed to set field '$name'", t2) + false + } } - m.invoke(target, value) - true - } catch (t2: Throwable) { - Log.w(TAG, "Failed to set field '$name'", t2) - false - } } - } - - private fun createOfflineInstance(modelPath: String): ReflectiveOfflinePunctuation? { - return try { - val modelCfg = clsOfflinePuncModelConfig!!.getDeclaredConstructor().newInstance() - trySetField(modelCfg, "ctTransformer", modelPath) - trySetField(modelCfg, "numThreads", 1) - trySetField(modelCfg, "debug", false) - trySetField(modelCfg, "provider", "cpu") - - val cfgCtor = clsOfflinePuncConfig!!.getDeclaredConstructor(clsOfflinePuncModelConfig) - val cfg = cfgCtor.newInstance(modelCfg) - - val ctor = clsOfflinePunctuation!!.getDeclaredConstructor( - android.content.res.AssetManager::class.java, - clsOfflinePuncConfig - ) - val inst = ctor.newInstance(null, cfg) - ReflectiveOfflinePunctuation(inst, clsOfflinePunctuation!!) - } catch (t: Throwable) { - Log.e(TAG, "Failed to create OfflinePunctuation instance", t) - null + + private fun createOfflineInstance(modelPath: String): ReflectiveOfflinePunctuation? { + return try { + val modelCfg = clsOfflinePuncModelConfig!!.getDeclaredConstructor().newInstance() + trySetField(modelCfg, "ctTransformer", modelPath) + trySetField(modelCfg, "numThreads", 1) + trySetField(modelCfg, "debug", false) + trySetField(modelCfg, "provider", "cpu") + + val cfgCtor = clsOfflinePuncConfig!!.getDeclaredConstructor(clsOfflinePuncModelConfig) + val cfg = cfgCtor.newInstance(modelCfg) + + val ctor = clsOfflinePunctuation!!.getDeclaredConstructor( + android.content.res.AssetManager::class.java, + clsOfflinePuncConfig, + ) + val inst = ctor.newInstance(null, cfg) + ReflectiveOfflinePunctuation(inst, clsOfflinePunctuation!!) + } catch (t: Throwable) { + Log.e(TAG, "Failed to create OfflinePunctuation instance", t) + null + } } - } - - private fun createOnlineInstance( - cnnBilstm: String, - bpeVocab: String? - ): ReflectiveOnlinePunctuation? { - return try { - val modelCfg = clsOnlinePuncModelConfig!!.getDeclaredConstructor().newInstance() - trySetField(modelCfg, "cnnBilstm", cnnBilstm) - if (!bpeVocab.isNullOrBlank()) { - trySetField(modelCfg, "bpeVocab", bpeVocab) - } - trySetField(modelCfg, "numThreads", 1) - trySetField(modelCfg, "debug", false) - trySetField(modelCfg, "provider", "cpu") - - val cfgCtor = clsOnlinePuncConfig!!.getDeclaredConstructor(clsOnlinePuncModelConfig) - val cfg = cfgCtor.newInstance(modelCfg) - - val ctor = clsOnlinePunctuation!!.getDeclaredConstructor( - android.content.res.AssetManager::class.java, - clsOnlinePuncConfig - ) - val inst = ctor.newInstance(null, cfg) - ReflectiveOnlinePunctuation(inst, clsOnlinePunctuation!!) - } catch (t: Throwable) { - Log.e(TAG, "Failed to create OnlinePunctuation instance", t) - null + + private fun createOnlineInstance( + cnnBilstm: String, + bpeVocab: String?, + ): ReflectiveOnlinePunctuation? { + return try { + val modelCfg = clsOnlinePuncModelConfig!!.getDeclaredConstructor().newInstance() + trySetField(modelCfg, "cnnBilstm", cnnBilstm) + if (!bpeVocab.isNullOrBlank()) { + trySetField(modelCfg, "bpeVocab", bpeVocab) + } + trySetField(modelCfg, "numThreads", 1) + trySetField(modelCfg, "debug", false) + trySetField(modelCfg, "provider", "cpu") + + val cfgCtor = clsOnlinePuncConfig!!.getDeclaredConstructor(clsOnlinePuncModelConfig) + val cfg = cfgCtor.newInstance(modelCfg) + + val ctor = clsOnlinePunctuation!!.getDeclaredConstructor( + android.content.res.AssetManager::class.java, + clsOnlinePuncConfig, + ) + val inst = ctor.newInstance(null, cfg) + ReflectiveOnlinePunctuation(inst, clsOnlinePunctuation!!) + } catch (t: Throwable) { + Log.e(TAG, "Failed to create OnlinePunctuation instance", t) + null + } } - } } private class ReflectiveOfflinePunctuation( - private val instance: Any, - private val cls: Class<*> + private val instance: Any, + private val cls: Class<*>, ) { - fun addPunctuation(text: String): String? { - return try { - cls.getMethod("addPunctuation", String::class.java) - .invoke(instance, text) as? String - } catch (t: Throwable) { - Log.e("ReflectiveOfflinePunc", "addPunctuation failed", t) - null + fun addPunctuation(text: String): String? { + return try { + cls.getMethod("addPunctuation", String::class.java) + .invoke(instance, text) as? String + } catch (t: Throwable) { + Log.e("ReflectiveOfflinePunc", "addPunctuation failed", t) + null + } } - } - fun release() { - try { - cls.getMethod("release").invoke(instance) - } catch (t: Throwable) { - Log.e("ReflectiveOfflinePunc", "release failed", t) + fun release() { + try { + cls.getMethod("release").invoke(instance) + } catch (t: Throwable) { + Log.e("ReflectiveOfflinePunc", "release failed", t) + } } - } } private class ReflectiveOnlinePunctuation( - private val instance: Any, - private val cls: Class<*> + private val instance: Any, + private val cls: Class<*>, ) { - fun addPunctuation(text: String): String? { - return try { - cls.getMethod("addPunctuation", String::class.java) - .invoke(instance, text) as? String - } catch (t: Throwable) { - Log.e("ReflectiveOnlinePunc", "addPunctuation failed", t) - null + fun addPunctuation(text: String): String? { + return try { + cls.getMethod("addPunctuation", String::class.java) + .invoke(instance, text) as? String + } catch (t: Throwable) { + Log.e("ReflectiveOnlinePunc", "addPunctuation failed", t) + null + } } - } - fun release() { - try { - cls.getMethod("release").invoke(instance) - } catch (t: Throwable) { - Log.e("ReflectiveOnlinePunc", "release failed", t) + fun release() { + try { + cls.getMethod("release").invoke(instance) + } catch (t: Throwable) { + Log.e("ReflectiveOnlinePunc", "release failed", t) + } } - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SiliconFlowFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/SiliconFlowFileAsrEngine.kt index 814273f0..2b99fadf 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SiliconFlowFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SiliconFlowFileAsrEngine.kt @@ -35,7 +35,7 @@ class SiliconFlowFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -87,7 +87,7 @@ class SiliconFlowFileAsrEngine( } } catch (t: Throwable) { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } @@ -104,7 +104,7 @@ class SiliconFlowFileAsrEngine( .addFormDataPart( "file", "audio.wav", - tmp.asRequestBody("audio/wav".toMediaType()) + tmp.asRequestBody("audio/wav".toMediaType()), ) .build() @@ -120,17 +120,21 @@ class SiliconFlowFileAsrEngine( if (!r.isSuccessful) { val detail = formatHttpDetail(r.message, null) listener.onError( - context.getString(R.string.error_request_failed_http, r.code, detail) + context.getString(R.string.error_request_failed_http, r.code, detail), ) return } val text = try { val obj = JSONObject(bodyStr) obj.optString("text", "") - } catch (_: Throwable) { "" } + } catch (_: Throwable) { + "" + } if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -174,14 +178,16 @@ class SiliconFlowFileAsrEngine( if (!r.isSuccessful) { val detail = formatHttpDetail(r.message, null) listener.onError( - context.getString(R.string.error_request_failed_http, r.code, detail) + context.getString(R.string.error_request_failed_http, r.code, detail), ) return } val text = parseSfChatText(str) if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -197,7 +203,7 @@ class SiliconFlowFileAsrEngine( .addFormDataPart( "file", "audio.wav", - tmp.asRequestBody("audio/wav".toMediaType()) + tmp.asRequestBody("audio/wav".toMediaType()), ) .build() val request = Request.Builder() @@ -210,7 +216,7 @@ class SiliconFlowFileAsrEngine( if (!r.isSuccessful) { val detail = formatHttpDetail(r.message, null) listener.onError( - context.getString(R.string.error_request_failed_http, r.code, detail) + context.getString(R.string.error_request_failed_http, r.code, detail), ) return } @@ -218,10 +224,14 @@ class SiliconFlowFileAsrEngine( val text = try { val obj = JSONObject(bodyStr) obj.optString("text", "") - } catch (_: Throwable) { "" } + } catch (_: Throwable) { + "" + } if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -230,8 +240,9 @@ class SiliconFlowFileAsrEngine( } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } - + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } /** * 构建 SiliconFlow Chat Completions API 请求体 @@ -239,31 +250,45 @@ class SiliconFlowFileAsrEngine( private fun buildSfChatCompletionsBody(model: String, base64Wav: String, prompt: String): String { val audioPart = JSONObject().apply { put("type", "audio_url") - put("audio_url", JSONObject().apply { - put("url", "data:audio/wav;base64,$base64Wav") - }) + put( + "audio_url", + JSONObject().apply { + put("url", "data:audio/wav;base64,$base64Wav") + }, + ) } val system = JSONObject().apply { put("role", "system") - put("content", org.json.JSONArray().apply { - put(JSONObject().apply { - put("type", "text") - put("text", prompt) - }) - }) + put( + "content", + org.json.JSONArray().apply { + put( + JSONObject().apply { + put("type", "text") + put("text", prompt) + }, + ) + }, + ) } val user = JSONObject().apply { put("role", "user") - put("content", org.json.JSONArray().apply { - put(audioPart) - }) + put( + "content", + org.json.JSONArray().apply { + put(audioPart) + }, + ) } return JSONObject().apply { put("model", model) - put("messages", org.json.JSONArray().apply { - put(system) - put(user) - }) + put( + "messages", + org.json.JSONArray().apply { + put(system) + put(user) + }, + ) }.toString() } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SonioxFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/SonioxFileAsrEngine.kt index afb19dcd..70f9e710 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SonioxFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SonioxFileAsrEngine.kt @@ -27,7 +27,7 @@ class SonioxFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -65,26 +65,30 @@ class SonioxFileAsrEngine( if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) } } catch (t: Throwable) { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } private fun uploadAudioFile(apiKey: String, file: File): String { val multipart = MultipartBody.Builder().setType(MultipartBody.FORM) .addFormDataPart( "file", file.name, - file.asRequestBody("audio/wav".toMediaType()) + file.asRequestBody("audio/wav".toMediaType()), ) .build() val req = Request.Builder() @@ -99,7 +103,11 @@ class SonioxFileAsrEngine( val detail = formatHttpDetail(r.message, extractErrorHint(body)) throw RuntimeException(context.getString(R.string.error_request_failed_http, r.code, detail)) } - val id = try { JSONObject(body).optString("id").trim() } catch (_: Throwable) { "" } + val id = try { + JSONObject(body).optString("id").trim() + } catch (_: Throwable) { + "" + } if (id.isBlank()) throw RuntimeException("uploadAudio: empty file id") return id } @@ -134,7 +142,11 @@ class SonioxFileAsrEngine( val detail = formatHttpDetail(r.message, extractErrorHint(body)) throw RuntimeException(context.getString(R.string.error_request_failed_http, r.code, detail)) } - val id = try { JSONObject(body).optString("id").trim() } catch (_: Throwable) { "" } + val id = try { + JSONObject(body).optString("id").trim() + } catch (_: Throwable) { + "" + } if (id.isBlank()) throw RuntimeException("createTranscription: empty id") return id } @@ -154,11 +166,19 @@ class SonioxFileAsrEngine( val detail = formatHttpDetail(r.message, extractErrorHint(body)) throw RuntimeException(context.getString(R.string.error_request_failed_http, r.code, detail)) } - val status = try { JSONObject(body).optString("status").lowercase() } catch (_: Throwable) { "" } + val status = try { + JSONObject(body).optString("status").lowercase() + } catch (_: Throwable) { + "" + } when (status) { "completed" -> return "error" -> { - val err = try { JSONObject(body).optString("error_message") } catch (_: Throwable) { "" } + val err = try { + JSONObject(body).optString("error_message") + } catch (_: Throwable) { + "" + } throw RuntimeException("Soniox error: $err") } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/SonioxStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/SonioxStreamAsrEngine.kt index 43d9c7ce..392e4b0f 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/SonioxStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/SonioxStreamAsrEngine.kt @@ -4,8 +4,6 @@ import android.Manifest import android.content.Context import android.content.pm.PackageManager import android.media.AudioFormat -import android.media.AudioRecord -import android.media.MediaRecorder import android.util.Log import androidx.core.content.ContextCompat import com.brycewg.asrkb.R @@ -22,9 +20,9 @@ import okhttp3.WebSocket import okhttp3.WebSocketListener import okio.ByteString import org.json.JSONObject +import java.util.ArrayDeque import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean -import java.util.ArrayDeque /** * Soniox WebSocket 实时 ASR 引擎实现。 @@ -36,7 +34,7 @@ class SonioxStreamAsrEngine( private val scope: CoroutineScope, private val prefs: Prefs, private val listener: StreamingAsrEngine.Listener, - private val externalPcmMode: Boolean = false + private val externalPcmMode: Boolean = false, ) : StreamingAsrEngine, ExternalPcmConsumer { companion object { @@ -74,7 +72,7 @@ class SonioxStreamAsrEngine( if (!externalPcmMode) { val hasPermission = ContextCompat.checkSelfPermission( context, - Manifest.permission.RECORD_AUDIO + Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED if (!hasPermission) { listener.onError(context.getString(R.string.error_record_permission_denied)) @@ -138,70 +136,75 @@ class SonioxStreamAsrEngine( val req = Request.Builder() .url(Prefs.SONIOX_WS_URL) .build() - ws = http.newWebSocket(req, object : WebSocketListener() { - override fun onOpen(webSocket: WebSocket, response: Response) { - Log.d(TAG, "WebSocket opened, sending config") - try { - // 发送配置 - val config = buildConfigJson() - webSocket.send(config) - // 标记 WS 已就绪,录音线程将冲刷预缓冲并进入实时发送 - wsReady.set(true) - Log.d(TAG, "WebSocket ready, audio streaming can begin") - if (!running.get() && awaitingFinal.get()) { - flushPrebufferAndRequestEnd("stop_before_ready") + ws = http.newWebSocket( + req, + object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + Log.d(TAG, "WebSocket opened, sending config") + try { + // 发送配置 + val config = buildConfigJson() + webSocket.send(config) + // 标记 WS 已就绪,录音线程将冲刷预缓冲并进入实时发送 + wsReady.set(true) + Log.d(TAG, "WebSocket ready, audio streaming can begin") + if (!running.get() && awaitingFinal.get()) { + flushPrebufferAndRequestEnd("stop_before_ready") + } + } catch (t: Throwable) { + Log.e(TAG, "Failed to send config", t) + val message = context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + notifyError(message, "config", t) + running.set(false) + awaitingFinal.set(false) + finalizeJob?.cancel() + closeWebSocket("config_error") } - } catch (t: Throwable) { - Log.e(TAG, "Failed to send config", t) - val message = context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") - notifyError(message, "config", t) - running.set(false) - awaitingFinal.set(false) - finalizeJob?.cancel() - closeWebSocket("config_error") } - } - - override fun onMessage(webSocket: WebSocket, text: String) { - handleMessage(text) - } - override fun onMessage(webSocket: WebSocket, bytes: ByteString) { - // Soniox 文档为文本 JSON 响应;此处保底兼容 - try { handleMessage(bytes.utf8()) } catch (t: Throwable) { - Log.e(TAG, "Failed to decode binary message", t) + override fun onMessage(webSocket: WebSocket, text: String) { + handleMessage(text) } - } - override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { - val message = context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") - if (running.get() || awaitingFinal.get()) { - notifyError(message, "failure", t) + override fun onMessage(webSocket: WebSocket, bytes: ByteString) { + // Soniox 文档为文本 JSON 响应;此处保底兼容 + try { + handleMessage(bytes.utf8()) + } catch (t: Throwable) { + Log.e(TAG, "Failed to decode binary message", t) + } } - running.set(false) - awaitingFinal.set(false) - finalizeJob?.cancel() - closeWebSocket("failure") - } - override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { - Log.d(TAG, "WebSocket closed with code $code: $reason") - if (awaitingFinal.get()) { - val message = if (reason.isNotBlank()) { - context.getString(R.string.error_recognize_failed_with_reason, reason) - } else { - context.getString(R.string.error_asr_timeout) + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + val message = context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + if (running.get() || awaitingFinal.get()) { + notifyError(message, "failure", t) } - notifyError(message, "closed_without_final", null) + running.set(false) awaitingFinal.set(false) + finalizeJob?.cancel() + closeWebSocket("failure") } - running.set(false) - finalizeJob?.cancel() - ws = null - wsReady.set(false) - synchronized(prebufferLock) { prebuffer.clear() } - } - }) + + override fun onClosed(webSocket: WebSocket, code: Int, reason: String) { + Log.d(TAG, "WebSocket closed with code $code: $reason") + if (awaitingFinal.get()) { + val message = if (reason.isNotBlank()) { + context.getString(R.string.error_recognize_failed_with_reason, reason) + } else { + context.getString(R.string.error_asr_timeout) + } + notifyError(message, "closed_without_final", null) + awaitingFinal.set(false) + } + running.set(false) + finalizeJob?.cancel() + ws = null + wsReady.set(false) + synchronized(prebufferLock) { prebuffer.clear() } + } + }, + ) } /** @@ -221,7 +224,7 @@ class SonioxStreamAsrEngine( sampleRate = sampleRate, channelConfig = channelConfig, audioFormat = audioFormat, - chunkMillis = chunkMillis + chunkMillis = chunkMillis, ) if (!audioManager.hasPermission()) { @@ -232,9 +235,11 @@ class SonioxStreamAsrEngine( } // 长按说话模式下由用户松手决定停止,绕过 VAD 自动判停 - val vadDetector = if (isVadAutoStopEnabled(context, prefs)) + val vadDetector = if (isVadAutoStopEnabled(context, prefs)) { VadDetector(context, sampleRate, prefs.autoStopSilenceWindowMs, prefs.autoStopSilenceSensitivity) - else null + } else { + null + } val maxFrames = (2000 / chunkMillis).coerceAtLeast(1) // 预缓冲≈2s @@ -254,7 +259,9 @@ class SonioxStreamAsrEngine( // VAD 自动判停 if (vadDetector?.shouldStop(audioChunk, audioChunk.size) == true) { Log.d(TAG, "Silence detected, stopping recording") - try { listener.onStopped() } catch (t: Throwable) { + try { + listener.onStopped() + } catch (t: Throwable) { Log.e(TAG, "Failed to notify stopped", t) } stop() @@ -341,13 +348,18 @@ class SonioxStreamAsrEngine( override fun appendPcm(pcm: ByteArray, sampleRate: Int, channels: Int) { if (!running.get()) return if (sampleRate != 16000 || channels != 1) return - try { listener.onAmplitude(com.brycewg.asrkb.asr.calculateNormalizedAmplitude(pcm)) } catch (_: Throwable) {} + try { + listener.onAmplitude(com.brycewg.asrkb.asr.calculateNormalizedAmplitude(pcm)) + } catch (_: Throwable) {} if (!wsReady.get()) { synchronized(prebufferLock) { prebuffer.addLast(pcm.copyOf()) } } else { var flushed: Array? = null synchronized(prebufferLock) { - if (prebuffer.isNotEmpty()) { flushed = prebuffer.toTypedArray(); prebuffer.clear() } + if (prebuffer.isNotEmpty()) { + flushed = prebuffer.toTypedArray() + prebuffer.clear() + } } val socket = ws ?: return flushed?.forEach { b -> kotlin.runCatching { socket.send(ByteString.of(*b)) } } @@ -406,7 +418,9 @@ class SonioxStreamAsrEngine( val stable = stripEndMarker(finalTextBuffer.toString()) val preview = if (nonFinal.isNotEmpty()) { stripEndMarker(mergeWithOverlapDedup(stable, nonFinal.toString())) - } else stable + } else { + stable + } // 仅在会话运行中才发出中间预览;用户 stop() 后可能仍有零星 non-final 到达,需忽略以避免重复追加 if (preview.isNotEmpty() && running.get()) listener.onPartial(preview) @@ -538,5 +552,4 @@ class SonioxStreamAsrEngine( if (s.isEmpty()) return s return s.replace("", "").replace("", "").trimEnd() } - } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechFileAsrEngine.kt index eddd8048..4caaca68 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechFileAsrEngine.kt @@ -24,237 +24,237 @@ import kotlinx.coroutines.sync.withLock * - 仅支持本地离线识别,不支持流式。 */ class TelespeechFileAsrEngine( - context: Context, - scope: CoroutineScope, - prefs: Prefs, - listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + context: Context, + scope: CoroutineScope, + prefs: Prefs, + listener: StreamingAsrEngine.Listener, + onRequestDuration: ((Long) -> Unit)? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { - // TeleSpeech 本地:同 SenseVoice,默认限制为 5 分钟以控制内存与处理时长 - override val maxRecordDurationMillis: Int = 5 * 60 * 1000 + // TeleSpeech 本地:同 SenseVoice,默认限制为 5 分钟以控制内存与处理时长 + override val maxRecordDurationMillis: Int = 5 * 60 * 1000 - private fun showToast(resId: Int) { - try { - Handler(Looper.getMainLooper()).post { + private fun showToast(resId: Int) { try { - Toast.makeText(context, context.getString(resId), Toast.LENGTH_SHORT).show() + Handler(Looper.getMainLooper()).post { + try { + Toast.makeText(context, context.getString(resId), Toast.LENGTH_SHORT).show() + } catch (t: Throwable) { + Log.e("TelespeechFileAsrEngine", "Failed to show toast", t) + } + } } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to show toast", t) + Log.e("TelespeechFileAsrEngine", "Failed to post toast", t) } - } - } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to post toast", t) } - } - - private fun notifyLoadStart() { - val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) - if (ui != null) { - try { - ui.onLocalModelLoadStart() - } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to notify load start", t) - } - } else { - // 复用通用“本地模型加载中”文案 - showToast(R.string.sv_loading_model) - } - } - - private fun notifyLoadDone() { - val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) - if (ui != null) { - try { - ui.onLocalModelLoadDone() - } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to notify load done", t) - } + + private fun notifyLoadStart() { + val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) + if (ui != null) { + try { + ui.onLocalModelLoadStart() + } catch (t: Throwable) { + Log.e("TelespeechFileAsrEngine", "Failed to notify load start", t) + } + } else { + // 复用通用“本地模型加载中”文案 + showToast(R.string.sv_loading_model) + } } - } - override fun ensureReady(): Boolean { - if (!super.ensureReady()) return false - val manager = TelespeechOnnxManager.getInstance() - if (!manager.isOnnxAvailable()) { - try { - listener.onError(context.getString(R.string.error_local_asr_not_ready)) - } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to send error callback", t) - } - return false + private fun notifyLoadDone() { + val ui = (listener as? SenseVoiceFileAsrEngine.LocalModelLoadUi) + if (ui != null) { + try { + ui.onLocalModelLoadDone() + } catch (t: Throwable) { + Log.e("TelespeechFileAsrEngine", "Failed to notify load done", t) + } + } } - return true - } - override suspend fun recognize(pcm: ByteArray) { - val t0 = System.currentTimeMillis() - try { - // 调用前若缺少通用标点模型,则给出一次性提示(不影响识别流程) - try { - SherpaPunctuationManager.maybeWarnModelMissing(context) - } catch (t: Throwable) { - Log.w("TelespeechFileAsrEngine", "Failed to warn punctuation model missing", t) - } - - val manager = TelespeechOnnxManager.getInstance() - if (!manager.isOnnxAvailable()) { - listener.onError(context.getString(R.string.error_local_asr_not_ready)) - return - } - - val base = try { - context.getExternalFilesDir(null) - } catch (t: Throwable) { - Log.w("TelespeechFileAsrEngine", "Failed to get external files dir", t) - null - } ?: context.filesDir - - val probeRoot = java.io.File(base, "telespeech") - val variant = try { - prefs.tsModelVariant - } catch (t: Throwable) { - Log.w("TelespeechFileAsrEngine", "Failed to get TeleSpeech variant", t) - "int8" - } - val variantDir = when (variant) { - "full" -> java.io.File(probeRoot, "full") - else -> java.io.File(probeRoot, "int8") - } - val auto = findTsModelDir(variantDir) ?: findTsModelDir(probeRoot) - if (auto == null) { - listener.onError(context.getString(R.string.error_telespeech_model_missing)) - return - } - val dir = auto.absolutePath - - val tokensPath = java.io.File(dir, "tokens.txt").absolutePath - val int8File = java.io.File(dir, "model.int8.onnx") - val f32File = java.io.File(dir, "model.onnx") - val modelFile = when { - int8File.exists() -> int8File - f32File.exists() -> f32File - else -> null - } - val modelPath = modelFile?.absolutePath - val minBytes = 8L * 1024L * 1024L - if (modelPath == null || !java.io.File(tokensPath).exists() || (modelFile?.length() ?: 0L) < minBytes) { - listener.onError(context.getString(R.string.error_telespeech_model_missing)) - return - } - - val samples = pcmToFloatArray(pcm) - if (samples.isEmpty()) { - listener.onError(context.getString(R.string.error_audio_empty)) - return - } - - val keepMinutes = try { - prefs.tsKeepAliveMinutes - } catch (t: Throwable) { - Log.w("TelespeechFileAsrEngine", "Failed to get keep alive minutes", t) - -1 - } - val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L - val alwaysKeep = keepMinutes < 0 - - val text = manager.decodeOffline( - assetManager = null, - tokens = tokensPath, - model = modelPath, - provider = "cpu", - numThreads = try { - prefs.tsNumThreads - } catch (t: Throwable) { - Log.w("TelespeechFileAsrEngine", "Failed to get num threads", t) - 2 - }, - samples = samples, - sampleRate = sampleRate, - keepAliveMs = keepMs, - alwaysKeep = alwaysKeep, - onLoadStart = { notifyLoadStart() }, - onLoadDone = { notifyLoadDone() } - ) - - if (text.isNullOrBlank()) { - listener.onError(context.getString(R.string.error_asr_empty_result)) - } else { - val raw = text.trim() - val useItn = try { - prefs.tsUseItn - } catch (t: Throwable) { - Log.w("TelespeechFileAsrEngine", "Failed to get tsUseItn", t) - false + override fun ensureReady(): Boolean { + if (!super.ensureReady()) return false + val manager = TelespeechOnnxManager.getInstance() + if (!manager.isOnnxAvailable()) { + try { + listener.onError(context.getString(R.string.error_local_asr_not_ready)) + } catch (t: Throwable) { + Log.e("TelespeechFileAsrEngine", "Failed to send error callback", t) + } + return false } - val normalized = if (useItn) ChineseItn.normalize(raw) else raw - val finalText = try { - SherpaPunctuationManager.getInstance().addOfflinePunctuation(context, normalized) + return true + } + + override suspend fun recognize(pcm: ByteArray) { + val t0 = System.currentTimeMillis() + try { + // 调用前若缺少通用标点模型,则给出一次性提示(不影响识别流程) + try { + SherpaPunctuationManager.maybeWarnModelMissing(context) + } catch (t: Throwable) { + Log.w("TelespeechFileAsrEngine", "Failed to warn punctuation model missing", t) + } + + val manager = TelespeechOnnxManager.getInstance() + if (!manager.isOnnxAvailable()) { + listener.onError(context.getString(R.string.error_local_asr_not_ready)) + return + } + + val base = try { + context.getExternalFilesDir(null) + } catch (t: Throwable) { + Log.w("TelespeechFileAsrEngine", "Failed to get external files dir", t) + null + } ?: context.filesDir + + val probeRoot = java.io.File(base, "telespeech") + val variant = try { + prefs.tsModelVariant + } catch (t: Throwable) { + Log.w("TelespeechFileAsrEngine", "Failed to get TeleSpeech variant", t) + "int8" + } + val variantDir = when (variant) { + "full" -> java.io.File(probeRoot, "full") + else -> java.io.File(probeRoot, "int8") + } + val auto = findTsModelDir(variantDir) ?: findTsModelDir(probeRoot) + if (auto == null) { + listener.onError(context.getString(R.string.error_telespeech_model_missing)) + return + } + val dir = auto.absolutePath + + val tokensPath = java.io.File(dir, "tokens.txt").absolutePath + val int8File = java.io.File(dir, "model.int8.onnx") + val f32File = java.io.File(dir, "model.onnx") + val modelFile = when { + int8File.exists() -> int8File + f32File.exists() -> f32File + else -> null + } + val modelPath = modelFile?.absolutePath + val minBytes = 8L * 1024L * 1024L + if (modelPath == null || !java.io.File(tokensPath).exists() || (modelFile?.length() ?: 0L) < minBytes) { + listener.onError(context.getString(R.string.error_telespeech_model_missing)) + return + } + + val samples = pcmToFloatArray(pcm) + if (samples.isEmpty()) { + listener.onError(context.getString(R.string.error_audio_empty)) + return + } + + val keepMinutes = try { + prefs.tsKeepAliveMinutes + } catch (t: Throwable) { + Log.w("TelespeechFileAsrEngine", "Failed to get keep alive minutes", t) + -1 + } + val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L + val alwaysKeep = keepMinutes < 0 + + val text = manager.decodeOffline( + assetManager = null, + tokens = tokensPath, + model = modelPath, + provider = "cpu", + numThreads = try { + prefs.tsNumThreads + } catch (t: Throwable) { + Log.w("TelespeechFileAsrEngine", "Failed to get num threads", t) + 2 + }, + samples = samples, + sampleRate = sampleRate, + keepAliveMs = keepMs, + alwaysKeep = alwaysKeep, + onLoadStart = { notifyLoadStart() }, + onLoadDone = { notifyLoadDone() }, + ) + + if (text.isNullOrBlank()) { + listener.onError(context.getString(R.string.error_asr_empty_result)) + } else { + val raw = text.trim() + val useItn = try { + prefs.tsUseItn + } catch (t: Throwable) { + Log.w("TelespeechFileAsrEngine", "Failed to get tsUseItn", t) + false + } + val normalized = if (useItn) ChineseItn.normalize(raw) else raw + val finalText = try { + SherpaPunctuationManager.getInstance().addOfflinePunctuation(context, normalized) + } catch (t: Throwable) { + Log.e("TelespeechFileAsrEngine", "Failed to apply offline punctuation", t) + normalized + } + listener.onFinal(finalText) + } } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to apply offline punctuation", t) - normalized + Log.e("TelespeechFileAsrEngine", "Recognition failed", t) + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + } finally { + val dt = System.currentTimeMillis() - t0 + try { + onRequestDuration?.invoke(dt) + } catch (t: Throwable) { + Log.e("TelespeechFileAsrEngine", "Failed to invoke duration callback", t) + } } - listener.onFinal(finalText) - } - } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Recognition failed", t) - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) - } finally { - val dt = System.currentTimeMillis() - t0 - try { - onRequestDuration?.invoke(dt) - } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to invoke duration callback", t) - } } - } - - override suspend fun recognizeFromPcm(pcm: ByteArray) { - recognize(pcm) - } - - private fun pcmToFloatArray(pcm: ByteArray): FloatArray { - if (pcm.isEmpty()) return FloatArray(0) - val n = pcm.size / 2 - val out = FloatArray(n) - val bb = java.nio.ByteBuffer.wrap(pcm).order(java.nio.ByteOrder.LITTLE_ENDIAN) - var i = 0 - while (i < n) { - val s = bb.short.toInt() - var f = s / 32768.0f - if (f > 1f) f = 1f else if (f < -1f) f = -1f - out[i] = f - i++ + + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } + + private fun pcmToFloatArray(pcm: ByteArray): FloatArray { + if (pcm.isEmpty()) return FloatArray(0) + val n = pcm.size / 2 + val out = FloatArray(n) + val bb = java.nio.ByteBuffer.wrap(pcm).order(java.nio.ByteOrder.LITTLE_ENDIAN) + var i = 0 + while (i < n) { + val s = bb.short.toInt() + var f = s / 32768.0f + if (f > 1f) f = 1f else if (f < -1f) f = -1f + out[i] = f + i++ + } + return out } - return out - } } // 公开卸载入口:供设置页在清除模型后释放本地识别器内存 fun unloadTelespeechRecognizer() { - LocalModelLoadCoordinator.cancel() - TelespeechOnnxManager.getInstance().unload() + LocalModelLoadCoordinator.cancel() + TelespeechOnnxManager.getInstance().unload() } // 判断是否已有缓存的本地识别器(已加载或正在加载中) fun isTelespeechPrepared(): Boolean { - val manager = TelespeechOnnxManager.getInstance() - return manager.isPrepared() || manager.isPreparing() + val manager = TelespeechOnnxManager.getInstance() + return manager.isPrepared() || manager.isPreparing() } // TeleSpeech 模型目录探测:与 SenseVoice 一致,查找含 tokens.txt 的目录 fun findTsModelDir(root: java.io.File?): java.io.File? { - if (root == null || !root.exists()) return null - val direct = java.io.File(root, "tokens.txt") - if (direct.exists()) return root - val subs = root.listFiles() ?: return null - for (f in subs) { - if (f.isDirectory) { - val t = java.io.File(f, "tokens.txt") - if (t.exists()) return f + if (root == null || !root.exists()) return null + val direct = java.io.File(root, "tokens.txt") + if (direct.exists()) return root + val subs = root.listFiles() ?: return null + for (f in subs) { + if (f.isDirectory) { + val t = java.io.File(f, "tokens.txt") + if (t.exists()) return f + } } - } - return null + return null } /** @@ -262,423 +262,430 @@ fun findTsModelDir(root: java.io.File?): java.io.File? { */ class TelespeechOnnxManager private constructor() { - companion object { - private const val TAG = "TelespeechOnnxManager" + companion object { + private const val TAG = "TelespeechOnnxManager" + + @Volatile + private var instance: TelespeechOnnxManager? = null + + fun getInstance(): TelespeechOnnxManager { + return instance ?: synchronized(this) { + instance ?: TelespeechOnnxManager().also { instance = it } + } + } + } + + private val scope = CoroutineScope(SupervisorJob()) + private val mutex = Mutex() + + @Volatile + private var cachedConfig: RecognizerConfig? = null + + @Volatile + private var cachedRecognizer: ReflectiveRecognizer? = null + + @Volatile + private var preparing: Boolean = false + + @Volatile + private var clsOfflineRecognizer: Class<*>? = null + + @Volatile + private var clsOfflineRecognizerConfig: Class<*>? = null + + @Volatile + private var clsOfflineModelConfig: Class<*>? = null + + @Volatile + private var clsFeatureConfig: Class<*>? = null + + @Volatile + private var unloadJob: Job? = null + + @Volatile + private var lastKeepAliveMs: Long = 0L @Volatile - private var instance: TelespeechOnnxManager? = null + private var lastAlwaysKeep: Boolean = false - fun getInstance(): TelespeechOnnxManager { - return instance ?: synchronized(this) { - instance ?: TelespeechOnnxManager().also { instance = it } - } + fun isOnnxAvailable(): Boolean { + return try { + Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") + true + } catch (t: Throwable) { + Log.d(TAG, "sherpa-onnx not available", t) + false + } } - } - - private val scope = CoroutineScope(SupervisorJob()) - private val mutex = Mutex() - - @Volatile - private var cachedConfig: RecognizerConfig? = null - @Volatile - private var cachedRecognizer: ReflectiveRecognizer? = null - - @Volatile - private var preparing: Boolean = false - @Volatile - private var clsOfflineRecognizer: Class<*>? = null - @Volatile - private var clsOfflineRecognizerConfig: Class<*>? = null - @Volatile - private var clsOfflineModelConfig: Class<*>? = null - @Volatile - private var clsFeatureConfig: Class<*>? = null - @Volatile - private var unloadJob: Job? = null - - @Volatile - private var lastKeepAliveMs: Long = 0L - @Volatile - private var lastAlwaysKeep: Boolean = false - - fun isOnnxAvailable(): Boolean { - return try { - Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") - true - } catch (t: Throwable) { - Log.d(TAG, "sherpa-onnx not available", t) - false + + fun unload() { + val snapshot = cachedRecognizer ?: return + scope.launch { + val shouldRelease = mutex.withLock { + if (cachedRecognizer !== snapshot) return@withLock false + cachedRecognizer = null + cachedConfig = null + unloadJob?.cancel() + unloadJob = null + true + } + if (shouldRelease) { + try { + snapshot.release() + } catch (t: Throwable) { + Log.e(TAG, "Failed to release recognizer on unload", t) + } + Log.d(TAG, "Recognizer unloaded") + } + } } - } - - fun unload() { - val snapshot = cachedRecognizer ?: return - scope.launch { - val shouldRelease = mutex.withLock { - if (cachedRecognizer !== snapshot) return@withLock false - cachedRecognizer = null - cachedConfig = null + + fun isPrepared(): Boolean = cachedRecognizer != null + + fun isPreparing(): Boolean = preparing + + private fun scheduleAutoUnload(keepAliveMs: Long, alwaysKeep: Boolean) { unloadJob?.cancel() - unloadJob = null - true - } - if (shouldRelease) { - try { - snapshot.release() - } catch (t: Throwable) { - Log.e(TAG, "Failed to release recognizer on unload", t) + if (alwaysKeep) { + Log.d(TAG, "Recognizer will be kept alive indefinitely") + return + } + if (keepAliveMs <= 0L) { + Log.d(TAG, "Auto-unloading immediately (keepAliveMs=$keepAliveMs)") + unload() + return + } + Log.d(TAG, "Scheduling auto-unload in ${keepAliveMs}ms") + unloadJob = scope.launch { + delay(keepAliveMs) + Log.d(TAG, "Auto-unloading recognizer after timeout") + unload() } - Log.d(TAG, "Recognizer unloaded") - } } - } - fun isPrepared(): Boolean = cachedRecognizer != null - - fun isPreparing(): Boolean = preparing + private data class RecognizerConfig( + val tokens: String, + val model: String, + val provider: String, + val numThreads: Int, + val sampleRate: Int, + val featureDim: Int, + ) + + private fun initClasses() { + if (clsOfflineRecognizer == null) { + clsOfflineRecognizer = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") + clsOfflineRecognizerConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizerConfig") + clsOfflineModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineModelConfig") + clsFeatureConfig = Class.forName("com.k2fsa.sherpa.onnx.FeatureConfig") + Log.d(TAG, "Initialized sherpa-onnx reflection classes for TeleSpeech") + } + } - private fun scheduleAutoUnload(keepAliveMs: Long, alwaysKeep: Boolean) { - unloadJob?.cancel() - if (alwaysKeep) { - Log.d(TAG, "Recognizer will be kept alive indefinitely") - return + private fun trySetField(target: Any, name: String, value: Any?): Boolean { + return try { + val f = target.javaClass.getDeclaredField(name) + f.isAccessible = true + f.set(target, value) + true + } catch (t: Throwable) { + try { + val methodName = "set" + name.replaceFirstChar { + if (it.isLowerCase()) it.titlecase() else it.toString() + } + val m = if (value == null) { + target.javaClass.getMethod(methodName, Any::class.java) + } else { + target.javaClass.getMethod(methodName, value.javaClass) + } + m.invoke(target, value) + true + } catch (t2: Throwable) { + Log.w(TAG, "Failed to set field '$name'", t2) + false + } + } } - if (keepAliveMs <= 0L) { - Log.d(TAG, "Auto-unloading immediately (keepAliveMs=$keepAliveMs)") - unload() - return + + private fun buildFeatureConfig(sampleRate: Int, featureDim: Int): Any { + val feat = clsFeatureConfig!!.getDeclaredConstructor().newInstance() + trySetField(feat, "sampleRate", sampleRate) + trySetField(feat, "featureDim", featureDim) + return feat } - Log.d(TAG, "Scheduling auto-unload in ${keepAliveMs}ms") - unloadJob = scope.launch { - delay(keepAliveMs) - Log.d(TAG, "Auto-unloading recognizer after timeout") - unload() + + private fun buildModelConfig(tokens: String, model: String, numThreads: Int, provider: String): Any { + val modelConfig = clsOfflineModelConfig!!.getDeclaredConstructor().newInstance() + trySetField(modelConfig, "tokens", tokens) + trySetField(modelConfig, "numThreads", numThreads) + trySetField(modelConfig, "provider", provider) + trySetField(modelConfig, "debug", false) + trySetField(modelConfig, "teleSpeech", model) + trySetField(modelConfig, "modelType", "telespeech_ctc") + return modelConfig } - } - - private data class RecognizerConfig( - val tokens: String, - val model: String, - val provider: String, - val numThreads: Int, - val sampleRate: Int, - val featureDim: Int - ) - - private fun initClasses() { - if (clsOfflineRecognizer == null) { - clsOfflineRecognizer = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizer") - clsOfflineRecognizerConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineRecognizerConfig") - clsOfflineModelConfig = Class.forName("com.k2fsa.sherpa.onnx.OfflineModelConfig") - clsFeatureConfig = Class.forName("com.k2fsa.sherpa.onnx.FeatureConfig") - Log.d(TAG, "Initialized sherpa-onnx reflection classes for TeleSpeech") + + private fun buildRecognizerConfig(config: RecognizerConfig): Any { + val modelConfig = buildModelConfig(config.tokens, config.model, config.numThreads, config.provider) + val featConfig = buildFeatureConfig(config.sampleRate, config.featureDim) + val recConfig = clsOfflineRecognizerConfig!!.getDeclaredConstructor().newInstance() + if (!trySetField(recConfig, "modelConfig", modelConfig)) { + trySetField(recConfig, "model_config", modelConfig) + } + if (!trySetField(recConfig, "featConfig", featConfig)) { + trySetField(recConfig, "feat_config", featConfig) + } + trySetField(recConfig, "decodingMethod", "greedy_search") + trySetField(recConfig, "maxActivePaths", 4) + return recConfig } - } - - private fun trySetField(target: Any, name: String, value: Any?): Boolean { - return try { - val f = target.javaClass.getDeclaredField(name) - f.isAccessible = true - f.set(target, value) - true - } catch (t: Throwable) { - try { - val methodName = "set" + name.replaceFirstChar { - if (it.isLowerCase()) it.titlecase() else it.toString() + + private fun createRecognizer(assetManager: android.content.res.AssetManager?, recConfig: Any): Any { + val ctor = if (assetManager == null) { + try { + clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) + } catch (t: Throwable) { + Log.d(TAG, "No single-param constructor, using AssetManager variant", t) + clsOfflineRecognizer!!.getDeclaredConstructor( + android.content.res.AssetManager::class.java, + clsOfflineRecognizerConfig, + ) + } + } else { + try { + clsOfflineRecognizer!!.getDeclaredConstructor( + android.content.res.AssetManager::class.java, + clsOfflineRecognizerConfig, + ) + } catch (t: Throwable) { + Log.d(TAG, "No AssetManager constructor, using single-param variant", t) + clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) + } } - val m = if (value == null) { - target.javaClass.getMethod(methodName, Any::class.java) + return if (ctor.parameterCount == 2) { + ctor.newInstance(assetManager, recConfig) } else { - target.javaClass.getMethod(methodName, value.javaClass) + ctor.newInstance(recConfig) } - m.invoke(target, value) - true - } catch (t2: Throwable) { - Log.w(TAG, "Failed to set field '$name'", t2) - false - } - } - } - - private fun buildFeatureConfig(sampleRate: Int, featureDim: Int): Any { - val feat = clsFeatureConfig!!.getDeclaredConstructor().newInstance() - trySetField(feat, "sampleRate", sampleRate) - trySetField(feat, "featureDim", featureDim) - return feat - } - - private fun buildModelConfig(tokens: String, model: String, numThreads: Int, provider: String): Any { - val modelConfig = clsOfflineModelConfig!!.getDeclaredConstructor().newInstance() - trySetField(modelConfig, "tokens", tokens) - trySetField(modelConfig, "numThreads", numThreads) - trySetField(modelConfig, "provider", provider) - trySetField(modelConfig, "debug", false) - trySetField(modelConfig, "teleSpeech", model) - trySetField(modelConfig, "modelType", "telespeech_ctc") - return modelConfig - } - - private fun buildRecognizerConfig(config: RecognizerConfig): Any { - val modelConfig = buildModelConfig(config.tokens, config.model, config.numThreads, config.provider) - val featConfig = buildFeatureConfig(config.sampleRate, config.featureDim) - val recConfig = clsOfflineRecognizerConfig!!.getDeclaredConstructor().newInstance() - if (!trySetField(recConfig, "modelConfig", modelConfig)) { - trySetField(recConfig, "model_config", modelConfig) - } - if (!trySetField(recConfig, "featConfig", featConfig)) { - trySetField(recConfig, "feat_config", featConfig) - } - trySetField(recConfig, "decodingMethod", "greedy_search") - trySetField(recConfig, "maxActivePaths", 4) - return recConfig - } - - private fun createRecognizer(assetManager: android.content.res.AssetManager?, recConfig: Any): Any { - val ctor = if (assetManager == null) { - try { - clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) - } catch (t: Throwable) { - Log.d(TAG, "No single-param constructor, using AssetManager variant", t) - clsOfflineRecognizer!!.getDeclaredConstructor( - android.content.res.AssetManager::class.java, - clsOfflineRecognizerConfig - ) - } - } else { - try { - clsOfflineRecognizer!!.getDeclaredConstructor( - android.content.res.AssetManager::class.java, - clsOfflineRecognizerConfig - ) - } catch (t: Throwable) { - Log.d(TAG, "No AssetManager constructor, using single-param variant", t) - clsOfflineRecognizer!!.getDeclaredConstructor(clsOfflineRecognizerConfig) - } - } - return if (ctor.parameterCount == 2) { - ctor.newInstance(assetManager, recConfig) - } else { - ctor.newInstance(recConfig) } - } - private fun releaseRecognizerSafely(recognizer: ReflectiveRecognizer?, reason: String) { - if (recognizer == null) return - try { - recognizer.release() - } catch (t: Throwable) { - Log.e(TAG, "Failed to release recognizer ($reason)", t) + private fun releaseRecognizerSafely(recognizer: ReflectiveRecognizer?, reason: String) { + if (recognizer == null) return + try { + recognizer.release() + } catch (t: Throwable) { + Log.e(TAG, "Failed to release recognizer ($reason)", t) + } } - } - private fun invokeCallbackSafely(name: String, callback: (() -> Unit)?) { - if (callback == null) return - try { - callback() - } catch (t: Throwable) { - Log.e(TAG, "$name callback failed", t) + private fun invokeCallbackSafely(name: String, callback: (() -> Unit)?) { + if (callback == null) return + try { + callback() + } catch (t: Throwable) { + Log.e(TAG, "$name callback failed", t) + } } - } - - private suspend fun ensurePreparedLocked( - assetManager: android.content.res.AssetManager?, - cfg: RecognizerConfig, - onLoadStart: (() -> Unit)?, - onLoadDone: (() -> Unit)? - ): ReflectiveRecognizer? { - initClasses() - val cached = cachedRecognizer - if (cached != null && cachedConfig == cfg) return cached - - preparing = true - unloadJob?.cancel() - unloadJob = null - - var newRecognizer: ReflectiveRecognizer? = null - try { - currentCoroutineContext().ensureActive() - invokeCallbackSafely("onLoadStart", onLoadStart) - currentCoroutineContext().ensureActive() - - val recConfig = buildRecognizerConfig(cfg) - currentCoroutineContext().ensureActive() - val raw = createRecognizer(assetManager, recConfig) - newRecognizer = ReflectiveRecognizer(raw, clsOfflineRecognizer!!) - currentCoroutineContext().ensureActive() - - val oldRecognizer = cachedRecognizer - cachedRecognizer = newRecognizer - cachedConfig = cfg - invokeCallbackSafely("onLoadDone", onLoadDone) - if (oldRecognizer != null && oldRecognizer !== newRecognizer) { - releaseRecognizerSafely(oldRecognizer, "old") - } - return newRecognizer - } catch (t: CancellationException) { - releaseRecognizerSafely(newRecognizer, "canceled") - throw t - } catch (t: Throwable) { - releaseRecognizerSafely(newRecognizer, "failed") - throw t - } finally { - preparing = false + + private suspend fun ensurePreparedLocked( + assetManager: android.content.res.AssetManager?, + cfg: RecognizerConfig, + onLoadStart: (() -> Unit)?, + onLoadDone: (() -> Unit)?, + ): ReflectiveRecognizer? { + initClasses() + val cached = cachedRecognizer + if (cached != null && cachedConfig == cfg) return cached + + preparing = true + unloadJob?.cancel() + unloadJob = null + + var newRecognizer: ReflectiveRecognizer? = null + try { + currentCoroutineContext().ensureActive() + invokeCallbackSafely("onLoadStart", onLoadStart) + currentCoroutineContext().ensureActive() + + val recConfig = buildRecognizerConfig(cfg) + currentCoroutineContext().ensureActive() + val raw = createRecognizer(assetManager, recConfig) + newRecognizer = ReflectiveRecognizer(raw, clsOfflineRecognizer!!) + currentCoroutineContext().ensureActive() + + val oldRecognizer = cachedRecognizer + cachedRecognizer = newRecognizer + cachedConfig = cfg + invokeCallbackSafely("onLoadDone", onLoadDone) + if (oldRecognizer != null && oldRecognizer !== newRecognizer) { + releaseRecognizerSafely(oldRecognizer, "old") + } + return newRecognizer + } catch (t: CancellationException) { + releaseRecognizerSafely(newRecognizer, "canceled") + throw t + } catch (t: Throwable) { + releaseRecognizerSafely(newRecognizer, "failed") + throw t + } finally { + preparing = false + } } - } - - suspend fun decodeOffline( - assetManager: android.content.res.AssetManager?, - tokens: String, - model: String, - provider: String, - numThreads: Int, - samples: FloatArray, - sampleRate: Int, - keepAliveMs: Long, - alwaysKeep: Boolean, - onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null - ): String? = mutex.withLock { - try { - val cfg = RecognizerConfig( - tokens = tokens, - model = model, - provider = provider, - numThreads = numThreads, - sampleRate = sampleRate, - featureDim = 40 - ) - val recognizer = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) - ?: return@withLock null - lastKeepAliveMs = keepAliveMs - lastAlwaysKeep = alwaysKeep - val stream = recognizer.createStream() - try { - stream.acceptWaveform(samples, sampleRate) - val text = recognizer.decode(stream) - scheduleAutoUnload(keepAliveMs, alwaysKeep) - return@withLock text - } finally { - stream.release() - } - } catch (t: CancellationException) { - throw t - } catch (t: Throwable) { - Log.e(TAG, "Failed to decode offline TeleSpeech: ${t.message}", t) - return@withLock null + + suspend fun decodeOffline( + assetManager: android.content.res.AssetManager?, + tokens: String, + model: String, + provider: String, + numThreads: Int, + samples: FloatArray, + sampleRate: Int, + keepAliveMs: Long, + alwaysKeep: Boolean, + onLoadStart: (() -> Unit)? = null, + onLoadDone: (() -> Unit)? = null, + ): String? = mutex.withLock { + try { + val cfg = RecognizerConfig( + tokens = tokens, + model = model, + provider = provider, + numThreads = numThreads, + sampleRate = sampleRate, + featureDim = 40, + ) + val recognizer = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) + ?: return@withLock null + lastKeepAliveMs = keepAliveMs + lastAlwaysKeep = alwaysKeep + val stream = recognizer.createStream() + try { + stream.acceptWaveform(samples, sampleRate) + val text = recognizer.decode(stream) + scheduleAutoUnload(keepAliveMs, alwaysKeep) + return@withLock text + } finally { + stream.release() + } + } catch (t: CancellationException) { + throw t + } catch (t: Throwable) { + Log.e(TAG, "Failed to decode offline TeleSpeech: ${t.message}", t) + return@withLock null + } } - } - - suspend fun prepare( - assetManager: android.content.res.AssetManager?, - tokens: String, - model: String, - provider: String, - numThreads: Int, - keepAliveMs: Long, - alwaysKeep: Boolean, - onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null - ): Boolean = mutex.withLock { - try { - val cfg = RecognizerConfig( - tokens = tokens, - model = model, - provider = provider, - numThreads = numThreads, - sampleRate = 16000, - featureDim = 40 - ) - val ok = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) != null - if (!ok) return@withLock false - lastKeepAliveMs = keepAliveMs - lastAlwaysKeep = alwaysKeep - true - } catch (t: CancellationException) { - throw t - } catch (t: Throwable) { - Log.e(TAG, "Failed to prepare TeleSpeech recognizer: ${t.message}", t) - false + + suspend fun prepare( + assetManager: android.content.res.AssetManager?, + tokens: String, + model: String, + provider: String, + numThreads: Int, + keepAliveMs: Long, + alwaysKeep: Boolean, + onLoadStart: (() -> Unit)? = null, + onLoadDone: (() -> Unit)? = null, + ): Boolean = mutex.withLock { + try { + val cfg = RecognizerConfig( + tokens = tokens, + model = model, + provider = provider, + numThreads = numThreads, + sampleRate = 16000, + featureDim = 40, + ) + val ok = ensurePreparedLocked(assetManager, cfg, onLoadStart, onLoadDone) != null + if (!ok) return@withLock false + lastKeepAliveMs = keepAliveMs + lastAlwaysKeep = alwaysKeep + true + } catch (t: CancellationException) { + throw t + } catch (t: Throwable) { + Log.e(TAG, "Failed to prepare TeleSpeech recognizer: ${t.message}", t) + false + } } - } } /** * TeleSpeech 预加载:根据当前配置尝试构建本地识别器,便于降低首次点击等待 */ fun preloadTelespeechIfConfigured( - context: Context, - prefs: Prefs, - onLoadStart: (() -> Unit)? = null, - onLoadDone: (() -> Unit)? = null, - suppressToastOnStart: Boolean = false, - forImmediateUse: Boolean = false + context: Context, + prefs: Prefs, + onLoadStart: (() -> Unit)? = null, + onLoadDone: (() -> Unit)? = null, + suppressToastOnStart: Boolean = false, + forImmediateUse: Boolean = false, ) { - try { - val manager = TelespeechOnnxManager.getInstance() - if (!manager.isOnnxAvailable()) return - - val base = context.getExternalFilesDir(null) ?: context.filesDir - val probeRoot = java.io.File(base, "telespeech") - val variantDir = when (prefs.tsModelVariant) { - "full" -> java.io.File(probeRoot, "full") - else -> java.io.File(probeRoot, "int8") - } - val modelDir = findTsModelDir(variantDir) ?: findTsModelDir(probeRoot) ?: return - val tokensPath = java.io.File(modelDir, "tokens.txt").absolutePath - val int8File = java.io.File(modelDir, "model.int8.onnx") - val f32File = java.io.File(modelDir, "model.onnx") - val modelFile = when { - int8File.exists() -> int8File - f32File.exists() -> f32File - else -> return - } - val modelPath = modelFile.absolutePath - val minBytes = 8L * 1024L * 1024L - if (!java.io.File(tokensPath).exists() || modelFile.length() < minBytes) return - val keepMinutes = prefs.tsKeepAliveMinutes - val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L - val alwaysKeep = keepMinutes < 0 - - val numThreads = prefs.tsNumThreads - val key = "telespeech|tokens=$tokensPath|model=$modelPath|provider=cpu|threads=$numThreads" - - val mainHandler = Handler(Looper.getMainLooper()) - LocalModelLoadCoordinator.request(key) { - val t0 = android.os.SystemClock.uptimeMillis() - val ok = manager.prepare( - assetManager = null, - tokens = tokensPath, - model = modelPath, - provider = "cpu", - numThreads = numThreads, - keepAliveMs = keepMs, - alwaysKeep = alwaysKeep, - onLoadStart = { - if (!suppressToastOnStart) { - mainHandler.post { - Toast.makeText( - context, - context.getString(R.string.sv_loading_model), - Toast.LENGTH_SHORT - ).show() + try { + val manager = TelespeechOnnxManager.getInstance() + if (!manager.isOnnxAvailable()) return + + val base = context.getExternalFilesDir(null) ?: context.filesDir + val probeRoot = java.io.File(base, "telespeech") + val variantDir = when (prefs.tsModelVariant) { + "full" -> java.io.File(probeRoot, "full") + else -> java.io.File(probeRoot, "int8") + } + val modelDir = findTsModelDir(variantDir) ?: findTsModelDir(probeRoot) ?: return + val tokensPath = java.io.File(modelDir, "tokens.txt").absolutePath + val int8File = java.io.File(modelDir, "model.int8.onnx") + val f32File = java.io.File(modelDir, "model.onnx") + val modelFile = when { + int8File.exists() -> int8File + f32File.exists() -> f32File + else -> return + } + val modelPath = modelFile.absolutePath + val minBytes = 8L * 1024L * 1024L + if (!java.io.File(tokensPath).exists() || modelFile.length() < minBytes) return + val keepMinutes = prefs.tsKeepAliveMinutes + val keepMs = if (keepMinutes <= 0) 0L else keepMinutes.toLong() * 60_000L + val alwaysKeep = keepMinutes < 0 + + val numThreads = prefs.tsNumThreads + val key = "telespeech|tokens=$tokensPath|model=$modelPath|provider=cpu|threads=$numThreads" + + val mainHandler = Handler(Looper.getMainLooper()) + LocalModelLoadCoordinator.request(key) { + val t0 = android.os.SystemClock.uptimeMillis() + val ok = manager.prepare( + assetManager = null, + tokens = tokensPath, + model = modelPath, + provider = "cpu", + numThreads = numThreads, + keepAliveMs = keepMs, + alwaysKeep = alwaysKeep, + onLoadStart = { + if (!suppressToastOnStart) { + mainHandler.post { + Toast.makeText( + context, + context.getString(R.string.sv_loading_model), + Toast.LENGTH_SHORT, + ).show() + } + } + onLoadStart?.invoke() + }, + onLoadDone = onLoadDone, + ) + if (ok && !forImmediateUse) { + val dt = (android.os.SystemClock.uptimeMillis() - t0).coerceAtLeast(0) + mainHandler.post { + Toast.makeText( + context, + context.getString(R.string.sv_model_ready_with_ms, dt), + Toast.LENGTH_SHORT, + ).show() + } } - } - onLoadStart?.invoke() - }, - onLoadDone = onLoadDone - ) - if (ok && !forImmediateUse) { - val dt = (android.os.SystemClock.uptimeMillis() - t0).coerceAtLeast(0) - mainHandler.post { - Toast.makeText( - context, - context.getString(R.string.sv_model_ready_with_ms, dt), - Toast.LENGTH_SHORT - ).show() } - } + } catch (t: Throwable) { + Log.e("TelespeechFileAsrEngine", "Failed to preload TeleSpeech model", t) } - } catch (t: Throwable) { - Log.e("TelespeechFileAsrEngine", "Failed to preload TeleSpeech model", t) - } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamAsrEngine.kt index 3253b5e3..60917f23 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamAsrEngine.kt @@ -14,7 +14,7 @@ class TelespeechPseudoStreamAsrEngine( scope: CoroutineScope, prefs: Prefs, listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + onRequestDuration: ((Long) -> Unit)? = null, ) : LocalModelPseudoStreamAsrEngine(context, scope, prefs, listener, onRequestDuration) { companion object { @@ -28,7 +28,7 @@ class TelespeechPseudoStreamAsrEngine( listener = listener, sampleRate = sampleRate, onRequestDuration = onRequestDuration, - tag = TAG + tag = TAG, ) override fun ensureReady(): Boolean { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamDelegate.kt b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamDelegate.kt index db0c9fc9..7782fd7b 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamDelegate.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPseudoStreamDelegate.kt @@ -22,7 +22,7 @@ internal class TelespeechPseudoStreamDelegate( private val listener: StreamingAsrEngine.Listener, private val sampleRate: Int, private val onRequestDuration: ((Long) -> Unit)?, - private val tag: String + private val tag: String, ) { private val previewMutex = Mutex() @@ -157,8 +157,8 @@ internal class TelespeechPseudoStreamDelegate( listener.onError( context.getString( R.string.error_recognize_failed_with_reason, - t.message ?: "" - ) + t.message ?: "", + ), ) } catch (e: Throwable) { Log.e(tag, "Failed to notify final recognition error", e) @@ -191,7 +191,7 @@ internal class TelespeechPseudoStreamDelegate( Toast.makeText( context, context.getString(R.string.sv_loading_model), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() } catch (t: Throwable) { Log.e(tag, "Failed to show toast", t) @@ -216,7 +216,7 @@ internal class TelespeechPseudoStreamDelegate( private suspend fun decodeOnce( pcm: ByteArray, - reportErrorToUser: Boolean + reportErrorToUser: Boolean, ): String? { val manager = TelespeechOnnxManager.getInstance() if (!manager.isOnnxAvailable()) { @@ -325,7 +325,7 @@ internal class TelespeechPseudoStreamDelegate( keepAliveMs = keepMs, alwaysKeep = alwaysKeep, onLoadStart = { notifyLoadStart() }, - onLoadDone = { notifyLoadDone() } + onLoadDone = { notifyLoadDone() }, ) if (text.isNullOrBlank()) { diff --git a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPushPcmPseudoStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPushPcmPseudoStreamAsrEngine.kt index 35f44a8a..10683e1a 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPushPcmPseudoStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/TelespeechPushPcmPseudoStreamAsrEngine.kt @@ -12,59 +12,59 @@ import java.util.concurrent.atomic.AtomicLong * - finishPcm/stop 时对整段音频做一次离线识别(onFinal)。 */ class TelespeechPushPcmPseudoStreamAsrEngine( - context: Context, - scope: CoroutineScope, - prefs: Prefs, - listener: StreamingAsrEngine.Listener, - onRequestDuration: ((Long) -> Unit)? = null + context: Context, + scope: CoroutineScope, + prefs: Prefs, + listener: StreamingAsrEngine.Listener, + onRequestDuration: ((Long) -> Unit)? = null, ) : PushPcmPseudoStreamAsrEngine(context, scope, prefs, listener, onRequestDuration) { - companion object { - private const val TAG = "TsPushPcmPseudo" - } + companion object { + private const val TAG = "TsPushPcmPseudo" + } - private val delegate = TelespeechPseudoStreamDelegate( - context = context, - scope = scope, - prefs = prefs, - listener = listener, - sampleRate = sampleRate, - onRequestDuration = onRequestDuration, - tag = TAG - ) + private val delegate = TelespeechPseudoStreamDelegate( + context = context, + scope = scope, + prefs = prefs, + listener = listener, + sampleRate = sampleRate, + onRequestDuration = onRequestDuration, + tag = TAG, + ) - private val sessionIdGenerator = AtomicLong(0L) + private val sessionIdGenerator = AtomicLong(0L) - @Volatile - private var activeSessionId: Long = 0L + @Volatile + private var activeSessionId: Long = 0L - @Volatile - private var finishingSessionId: Long = 0L + @Volatile + private var finishingSessionId: Long = 0L - override fun start() { - val wasRunning = isRunning - super.start() - if (wasRunning || !isRunning) return - val sessionId = sessionIdGenerator.incrementAndGet() - activeSessionId = sessionId - finishingSessionId = sessionId - delegate.onSessionStart(sessionId) - } + override fun start() { + val wasRunning = isRunning + super.start() + if (wasRunning || !isRunning) return + val sessionId = sessionIdGenerator.incrementAndGet() + activeSessionId = sessionId + finishingSessionId = sessionId + delegate.onSessionStart(sessionId) + } - override fun stop() { - finishingSessionId = activeSessionId - super.stop() - } + override fun stop() { + finishingSessionId = activeSessionId + super.stop() + } - override fun ensureReady(): Boolean { - return delegate.ensureReady() - } + override fun ensureReady(): Boolean { + return delegate.ensureReady() + } - override fun onSegmentBoundary(pcmSegment: ByteArray) { - delegate.onSegmentBoundary(activeSessionId, pcmSegment) - } + override fun onSegmentBoundary(pcmSegment: ByteArray) { + delegate.onSegmentBoundary(activeSessionId, pcmSegment) + } - override suspend fun onSessionFinished(fullPcm: ByteArray) { - delegate.onSessionFinished(finishingSessionId, fullPcm) - } + override suspend fun onSessionFinished(fullPcm: ByteArray) { + delegate.onSessionFinished(finishingSessionId, fullPcm) + } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/VadDetector.kt b/app/src/main/java/com/brycewg/asrkb/asr/VadDetector.kt index d07928b7..17966544 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/VadDetector.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/VadDetector.kt @@ -3,7 +3,6 @@ package com.brycewg.asrkb.asr import android.content.Context import android.util.Log import com.brycewg.asrkb.store.Prefs -import com.brycewg.asrkb.ui.floating.FloatingAsrService import com.k2fsa.sherpa.onnx.TenVadModelConfig import com.k2fsa.sherpa.onnx.Vad import com.k2fsa.sherpa.onnx.VadModelConfig @@ -31,7 +30,7 @@ class VadDetector( private val context: Context, private val sampleRate: Int, private val windowMs: Int, - sensitivityLevel: Int + sensitivityLevel: Int, ) { /** * 单帧分析结果: @@ -40,7 +39,7 @@ class VadDetector( */ data class FrameResult( val isSpeech: Boolean, - val silenceStop: Boolean + val silenceStop: Boolean, ) companion object { @@ -53,26 +52,37 @@ class VadDetector( // 调整为更宽松的分段:低档位给更长的静音要求,减少“提前中断”。 // 规则:sensitivityLevel 越高,minSilenceDuration 越小(更敏感)。 private val MIN_SILENCE_DURATION_MAP: FloatArray = floatArrayOf( - 0.55f, // 1 非常不敏感:至少 0.60s 静音才算非语音 - 0.50f, // 2 - 0.42f, // 3 - 0.35f, // 4 - 0.30f, // 5 - 0.25f, // 6 - 0.20f, // 7(默认附近) - 0.16f, // 8 - 0.12f, // 9 - 0.08f // 10 更敏感 + // 1 非常不敏感:至少 0.60s 静音才算非语音 + 0.55f, + // 2 + 0.50f, + // 3 + 0.42f, + // 4 + 0.35f, + // 5 + 0.30f, + // 6 + 0.25f, + // 7(默认附近) + 0.20f, + // 8 + 0.16f, + // 9 + 0.12f, + // 10 更敏感 + 0.08f, ) private const val MAX_POOL_SIZE = 2 private data class VadPoolKey( val sampleRate: Int, - val sensitivityLevel: Int + val sensitivityLevel: Int, ) private val poolLock = Any() + @Volatile private var poolKey: VadPoolKey? = null private val vadPool: ArrayDeque = ArrayDeque() @@ -91,28 +101,28 @@ class VadDetector( threshold = threshold, minSilenceDuration = minSilenceDuration, minSpeechDuration = 0.25f, - windowSize = 256 + windowSize = 256, ) return VadModelConfig( tenVadModelConfig = tenConfig, sampleRate = sampleRate, numThreads = 1, provider = "cpu", - debug = false + debug = false, ) } private fun createVad(context: Context, sampleRate: Int, sensitivityLevel: Int): Vad { return Vad( assetManager = context.assets, - config = buildVadModelConfig(sampleRate, sensitivityLevel) + config = buildVadModelConfig(sampleRate, sensitivityLevel), ) } private fun acquireFromPool( context: Context, sampleRate: Int, - sensitivityLevel: Int + sensitivityLevel: Int, ): Vad { val lvl = sensitivityLevel.coerceIn(1, LEVELS) val key = VadPoolKey(sampleRate = sampleRate, sensitivityLevel = lvl) @@ -151,7 +161,7 @@ class VadDetector( private fun recycleToPool( key: VadPoolKey, - vad: Vad + vad: Vad, ) { val shouldPool = synchronized(poolLock) { (poolKey == key) && (vadPool.size < MAX_POOL_SIZE) @@ -249,6 +259,7 @@ class VadDetector( private val threshold: Float private val speechHangoverMs: Int private var speechHangoverRemainingMs: Int = 0 + // 录音开始阶段的初期防抖(仅在首次检测到语音之前生效) private val initialDebounceMs: Int private var initialDebounceRemainingMs: Int = 0 @@ -289,7 +300,7 @@ class VadDetector( initVad() Log.i( TAG, - "VadDetector initialized: windowMs=$windowMs, sensitivity=$lvl, minSilenceDuration=$minSilenceDuration, threshold=$threshold, hangoverMs=$speechHangoverMs, initialDebounceMs=$initialDebounceMs" + "VadDetector initialized: windowMs=$windowMs, sensitivity=$lvl, minSilenceDuration=$minSilenceDuration, threshold=$threshold, hangoverMs=$speechHangoverMs, initialDebounceMs=$initialDebounceMs", ) } catch (t: Throwable) { Log.e(TAG, "Failed to initialize VAD, will fallback to no detection", t) @@ -438,7 +449,7 @@ class VadDetector( // Little Endian: 低字节在前 val lo = pcm[i].toInt() and 0xFF val hi = pcm[i + 1].toInt() and 0xFF - val pcmValue = (hi shl 8) or lo // 0..65535 + val pcmValue = (hi shl 8) or lo // 0..65535 // 转为有符号 -32768..32767 val signed = if (pcmValue < 0x8000) pcmValue else pcmValue - 0x10000 @@ -446,8 +457,9 @@ class VadDetector( // 归一化到 -1.0 ~ 1.0 // 使用 32768.0f 避免 -32768 除法溢出,并限制范围 var normalized = signed / 32768.0f - if (normalized > 1.0f) normalized = 1.0f - else if (normalized < -1.0f) normalized = -1.0f + if (normalized > 1.0f) { + normalized = 1.0f + } else if (normalized < -1.0f) normalized = -1.0f samples[sampleIdx] = normalized diff --git a/app/src/main/java/com/brycewg/asrkb/asr/VolcFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/VolcFileAsrEngine.kt index c2daec5f..dffc86f7 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/VolcFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/VolcFileAsrEngine.kt @@ -2,7 +2,6 @@ package com.brycewg.asrkb.asr import android.content.Context import android.util.Base64 -import android.util.Log import com.brycewg.asrkb.R import com.brycewg.asrkb.store.Prefs import kotlinx.coroutines.CoroutineScope @@ -24,13 +23,14 @@ class VolcFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { // 文件识别模型 1.0 / 2.0 private const val FILE_RESOURCE_V1 = "volc.bigasr.auc" private const val FILE_RESOURCE_V2 = "volc.seedasr.auc" + // 文件识别极速版(仅 1.0) private const val FILE_RESOURCE_TURBO = "volc.bigasr.auc_turbo" private const val TAG = "VolcFileAsrEngine" @@ -73,7 +73,7 @@ class VolcFileAsrEngine( val msg = resp.header("X-Api-Message") ?: resp.message val detail = formatHttpDetail(msg) listener.onError( - context.getString(R.string.error_request_failed_http, resp.code, detail) + context.getString(R.string.error_request_failed_http, resp.code, detail), ) return } @@ -82,11 +82,17 @@ class VolcFileAsrEngine( val obj = JSONObject(bodyStr) if (obj.has("result")) { obj.getJSONObject("result").optString("text", "") - } else "" - } catch (_: Throwable) { "" } + } else { + "" + } + } catch (_: Throwable) { + "" + } if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -94,13 +100,15 @@ class VolcFileAsrEngine( } } catch (t: Throwable) { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } // 供“推送 PCM 适配器”调用,直接复用现有实现 - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } /** * 构建火山引擎 API 请求体 diff --git a/app/src/main/java/com/brycewg/asrkb/asr/VolcStandardFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/VolcStandardFileAsrEngine.kt index 1db038d4..12a26e11 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/VolcStandardFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/VolcStandardFileAsrEngine.kt @@ -25,7 +25,7 @@ class VolcStandardFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -77,7 +77,7 @@ class VolcStandardFileAsrEngine( val msg = resp.header("X-Api-Message") ?: resp.message val detail = formatHttpDetail(msg, "status=${status ?: "unknown"}") listener.onError( - context.getString(R.string.error_request_failed_http, resp.code, detail) + context.getString(R.string.error_request_failed_http, resp.code, detail), ) return } @@ -85,7 +85,7 @@ class VolcStandardFileAsrEngine( val msg = resp.header("X-Api-Message") ?: resp.message val detail = formatHttpDetail(msg, "status=${status ?: "unknown"}") listener.onError( - context.getString(R.string.error_request_failed_http, resp.code, detail) + context.getString(R.string.error_request_failed_http, resp.code, detail), ) return } @@ -114,7 +114,7 @@ class VolcStandardFileAsrEngine( } catch (t: Throwable) { Log.e(TAG, "Failed to recognize with Volc standard file API", t) listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } @@ -164,7 +164,7 @@ class VolcStandardFileAsrEngine( val msg = resp.header("X-Api-Message") ?: resp.message val detail = formatHttpDetail(msg, "status=${status ?: "unknown"}") return QueryResult.Failed( - context.getString(R.string.error_request_failed_http, resp.code, detail) + context.getString(R.string.error_request_failed_http, resp.code, detail), ) } val message = resp.header("X-Api-Message") ?: resp.message @@ -185,7 +185,7 @@ class VolcStandardFileAsrEngine( else -> { val detail = formatHttpDetail(message, "status=${status ?: "unknown"}") return QueryResult.Failed( - context.getString(R.string.error_request_failed_http, resp.code, detail) + context.getString(R.string.error_request_failed_http, resp.code, detail), ) } } diff --git a/app/src/main/java/com/brycewg/asrkb/asr/VolcStreamAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/VolcStreamAsrEngine.kt index b0427131..03aec35f 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/VolcStreamAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/VolcStreamAsrEngine.kt @@ -25,8 +25,8 @@ import org.json.JSONObject import java.io.ByteArrayOutputStream import java.nio.ByteBuffer import java.nio.ByteOrder -import java.util.UUID import java.util.ArrayDeque +import java.util.UUID import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean import java.util.zip.GZIPInputStream @@ -42,12 +42,13 @@ class VolcStreamAsrEngine( private val scope: CoroutineScope, private val prefs: Prefs, private val listener: StreamingAsrEngine.Listener, - private val externalPcmMode: Boolean = false + private val externalPcmMode: Boolean = false, ) : StreamingAsrEngine, ExternalPcmConsumer { companion object { private const val TAG = "VolcStreamAsrEngine" private const val WS_ENDPOINT_BIDI_ASYNC = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel_async" + // 流式识别模型 1.0 / 2.0 private const val STREAM_RESOURCE_V1 = "volc.bigasr.sauc.duration" private const val STREAM_RESOURCE_V2 = "volc.seedasr.sauc.duration" @@ -106,7 +107,7 @@ class VolcStreamAsrEngine( if (!externalPcmMode) { val hasPermission = ContextCompat.checkSelfPermission( context, - Manifest.permission.RECORD_AUDIO + Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED if (!hasPermission) { listener.onError(context.getString(R.string.error_record_permission_denied)) @@ -147,7 +148,9 @@ class VolcStreamAsrEngine( } } flushed?.forEach { b -> - try { sendAudioFrame(b, last = false) } catch (_: Throwable) { } + try { + sendAudioFrame(b, last = false) + } catch (_: Throwable) { } } // 发送最后一包标记(空载荷亦可作为结束信号) if (!audioLastSent.get()) { @@ -165,7 +168,9 @@ class VolcStreamAsrEngine( delay(50) left -= 50 } - try { ws?.close(1000, "stop") } catch (_: Throwable) { } + try { + ws?.close(1000, "stop") + } catch (_: Throwable) { } ws = null wsReady.set(false) } @@ -178,70 +183,79 @@ class VolcStreamAsrEngine( .url(WS_ENDPOINT_BIDI_ASYNC) .headers( Headers.headersOf( - "X-Api-App-Key", prefs.appKey, - "X-Api-Access-Key", prefs.accessKey, + "X-Api-App-Key", + prefs.appKey, + "X-Api-Access-Key", + prefs.accessKey, // 使用小时版资源(可根据需要切换并发版) - "X-Api-Resource-Id", streamResource, - "X-Api-Connect-Id", connectId - ) + "X-Api-Resource-Id", + streamResource, + "X-Api-Connect-Id", + connectId, + ), ) .build() - ws = http.newWebSocket(req, object : WebSocketListener() { - override fun onOpen(webSocket: WebSocket, response: Response) { - Log.d(TAG, "WebSocket opened, sending full client request") - // 发送 full client request - try { - val full = buildFullClientRequestJson() - val payload = gzip(full.toByteArray(Charsets.UTF_8)) - val frame = buildClientFrame( - messageType = MSG_TYPE_FULL_CLIENT_REQ, - flags = 0, - serialization = SERIALIZE_JSON, - compression = COMPRESS_GZIP, - payload = payload - ) - webSocket.send(ByteString.of(*frame)) - wsReady.set(true) - Log.d(TAG, "WebSocket ready, audio streaming can begin") - // 若用户在握手期间已 stop(),此处立即冲刷预缓冲并发送最后标记,避免尾段丢失 - if (!running.get() && awaitingFinal.get()) { - flushPrebufferAndSendLast() + ws = http.newWebSocket( + req, + object : WebSocketListener() { + override fun onOpen(webSocket: WebSocket, response: Response) { + Log.d(TAG, "WebSocket opened, sending full client request") + // 发送 full client request + try { + val full = buildFullClientRequestJson() + val payload = gzip(full.toByteArray(Charsets.UTF_8)) + val frame = buildClientFrame( + messageType = MSG_TYPE_FULL_CLIENT_REQ, + flags = 0, + serialization = SERIALIZE_JSON, + compression = COMPRESS_GZIP, + payload = payload, + ) + webSocket.send(ByteString.of(*frame)) + wsReady.set(true) + Log.d(TAG, "WebSocket ready, audio streaming can begin") + // 若用户在握手期间已 stop(),此处立即冲刷预缓冲并发送最后标记,避免尾段丢失 + if (!running.get() && awaitingFinal.get()) { + flushPrebufferAndSendLast() + } + } catch (t: Throwable) { + Log.e(TAG, "Failed to send full client request", t) + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + stop() + return } - } catch (t: Throwable) { - Log.e(TAG, "Failed to send full client request", t) - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) - stop() - return } - } - override fun onMessage(webSocket: WebSocket, bytes: ByteString) { - try { - handleServerMessage(bytes) - } catch (t: Throwable) { - Log.e(TAG, "Failed to handle server message", t) - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + override fun onMessage(webSocket: WebSocket, bytes: ByteString) { + try { + handleServerMessage(bytes) + } catch (t: Throwable) { + Log.e(TAG, "Failed to handle server message", t) + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + } } - } - override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { - Log.d(TAG, "WebSocket closing with code $code: $reason") - try { webSocket.close(code, reason) } catch (t: Throwable) { - Log.e(TAG, "Failed to close WebSocket", t) + override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { + Log.d(TAG, "WebSocket closing with code $code: $reason") + try { + webSocket.close(code, reason) + } catch (t: Throwable) { + Log.e(TAG, "Failed to close WebSocket", t) + } + wsReady.set(false) } - wsReady.set(false) - } - override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { - Log.e(TAG, "WebSocket failed: ${t.message}", t) - wsReady.set(false) - if (running.get()) { - listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) { + Log.e(TAG, "WebSocket failed: ${t.message}", t) + wsReady.set(false) + if (running.get()) { + listener.onError(context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "")) + } + running.set(false) } - running.set(false) - } - }) + }, + ) } // ========== ExternalPcmConsumer 实现(外部推流) ========== @@ -251,13 +265,18 @@ class VolcStreamAsrEngine( Log.w(TAG, "ignore frame: sr=$sampleRate ch=$channels") return } - try { listener.onAmplitude(com.brycewg.asrkb.asr.calculateNormalizedAmplitude(pcm)) } catch (_: Throwable) { } + try { + listener.onAmplitude(com.brycewg.asrkb.asr.calculateNormalizedAmplitude(pcm)) + } catch (_: Throwable) { } if (!wsReady.get()) { synchronized(prebufferLock) { prebuffer.addLast(pcm.copyOf()) } } else { var flushed: Array? = null synchronized(prebufferLock) { - if (prebuffer.isNotEmpty()) { flushed = prebuffer.toTypedArray(); prebuffer.clear() } + if (prebuffer.isNotEmpty()) { + flushed = prebuffer.toTypedArray() + prebuffer.clear() + } } flushed?.forEach { b -> kotlin.runCatching { sendAudioFrame(b, last = false) } } kotlin.runCatching { sendAudioFrame(pcm, last = false) } @@ -281,7 +300,7 @@ class VolcStreamAsrEngine( sampleRate = sampleRate, channelConfig = channelConfig, audioFormat = audioFormat, - chunkMillis = chunkMillis + chunkMillis = chunkMillis, ) if (!audioManager.hasPermission()) { @@ -292,9 +311,11 @@ class VolcStreamAsrEngine( } // 长按说话模式下由用户松手决定停止,绕过 VAD 自动判停 - val vadDetector = if (isVadAutoStopEnabled(context, prefs)) + val vadDetector = if (isVadAutoStopEnabled(context, prefs)) { VadDetector(context, sampleRate, prefs.autoStopSilenceWindowMs, prefs.autoStopSilenceSensitivity) - else null + } else { + null + } val maxFrames = (2000 / chunkMillis).coerceAtLeast(1) // 预缓冲上限≈2s @@ -314,7 +335,9 @@ class VolcStreamAsrEngine( // VAD 自动判停 if (vadDetector?.shouldStop(audioChunk, audioChunk.size) == true) { Log.d(TAG, "Silence detected, stopping recording") - try { listener.onStopped() } catch (t: Throwable) { + try { + listener.onStopped() + } catch (t: Throwable) { Log.e(TAG, "Failed to notify stopped", t) } stop() @@ -407,7 +430,7 @@ class VolcStreamAsrEngine( flags = flags, serialization = SERIALIZE_NONE, compression = COMPRESS_GZIP, - payload = payload + payload = payload, ) webSocket.send(ByteString.of(*frame)) } @@ -442,7 +465,6 @@ class VolcStreamAsrEngine( put("force_to_speech_time", 1000) // 说明:配置 end_window_size 后 vad_segment_duration 不生效,这里不再冗余设置 } - } return JSONObject().apply { put("user", user) @@ -484,13 +506,17 @@ class VolcStreamAsrEngine( Log.d(TAG, "Received final result, length: ${text.length}") // 重要:即使服务端最终文本为空也要回调 onFinal(""), // 以便上层(悬浮球/IME)清理 isProcessing 状态并给出友好提示。 - try { listener.onFinal(text) } catch (t: Throwable) { + try { + listener.onFinal(text) + } catch (t: Throwable) { Log.e(TAG, "Failed to notify final result", t) } running.set(false) awaitingFinal.set(false) scope.launch(Dispatchers.IO) { - try { ws?.close(1000, "final") } catch (t: Throwable) { + try { + ws?.close(1000, "final") + } catch (t: Throwable) { Log.e(TAG, "Failed to close WebSocket after final", t) } ws = null @@ -509,7 +535,9 @@ class VolcStreamAsrEngine( val size = readUInt32BE(arr, offset + 4) val start = offset + 8 val end = (start + size).coerceAtMost(arr.size) - val msg = try { String(arr.copyOfRange(start, end), Charsets.UTF_8) } catch (t: Throwable) { + val msg = try { + String(arr.copyOfRange(start, end), Charsets.UTF_8) + } catch (t: Throwable) { Log.e(TAG, "Failed to decode error message", t) "" } @@ -529,8 +557,12 @@ class VolcStreamAsrEngine( if (o.has("result")) { val r = o.getJSONObject("result") r.optString("text", "") - } else "" - } catch (_: Throwable) { "" } + } else { + "" + } + } catch (_: Throwable) { + "" + } } @Suppress("SameParameterValue") @@ -539,7 +571,7 @@ class VolcStreamAsrEngine( flags: Int, serialization: Int, compression: Int, - payload: ByteArray + payload: ByteArray, ): ByteArray { val header = ByteArray(4) header[0] = (((PROTOCOL_VERSION and 0x0F) shl 4) or (HEADER_SIZE_UNITS and 0x0F)).toByte() diff --git a/app/src/main/java/com/brycewg/asrkb/asr/ZhipuFileAsrEngine.kt b/app/src/main/java/com/brycewg/asrkb/asr/ZhipuFileAsrEngine.kt index 578dd6ba..86a35600 100644 --- a/app/src/main/java/com/brycewg/asrkb/asr/ZhipuFileAsrEngine.kt +++ b/app/src/main/java/com/brycewg/asrkb/asr/ZhipuFileAsrEngine.kt @@ -25,7 +25,7 @@ class ZhipuFileAsrEngine( prefs: Prefs, listener: StreamingAsrEngine.Listener, onRequestDuration: ((Long) -> Unit)? = null, - httpClient: OkHttpClient? = null + httpClient: OkHttpClient? = null, ) : BaseFileAsrEngine(context, scope, prefs, listener, onRequestDuration), PcmBatchRecognizer { companion object { @@ -70,7 +70,7 @@ class ZhipuFileAsrEngine( .addFormDataPart( "file", "audio.wav", - tmp.asRequestBody("audio/wav".toMediaType()) + tmp.asRequestBody("audio/wav".toMediaType()), ) // 添加可选的 prompt 参数(用于长文本场景的前文上下文) @@ -94,14 +94,16 @@ class ZhipuFileAsrEngine( val extra = extractErrorHint(bodyStr) val detail = formatHttpDetail(r.message, extra) listener.onError( - context.getString(R.string.error_request_failed_http, r.code, detail) + context.getString(R.string.error_request_failed_http, r.code, detail), ) return } val text = parseTextFromResponse(bodyStr) if (text.isNotBlank()) { val dt = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t0) - try { onRequestDuration?.invoke(dt) } catch (_: Throwable) {} + try { + onRequestDuration?.invoke(dt) + } catch (_: Throwable) {} listener.onFinal(text) } else { listener.onError(context.getString(R.string.error_asr_empty_result)) @@ -118,12 +120,14 @@ class ZhipuFileAsrEngine( } } catch (t: Throwable) { listener.onError( - context.getString(R.string.error_recognize_failed_with_reason, t.message ?: "") + context.getString(R.string.error_recognize_failed_with_reason, t.message ?: ""), ) } } - override suspend fun recognizeFromPcm(pcm: ByteArray) { recognize(pcm) } + override suspend fun recognizeFromPcm(pcm: ByteArray) { + recognize(pcm) + } /** * 从响应体中提取错误提示信息 diff --git a/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardFileManager.kt b/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardFileManager.kt index 91be20f5..cda0f515 100644 --- a/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardFileManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardFileManager.kt @@ -15,8 +15,8 @@ class ClipboardFileManager(private val context: Context) { companion object { private const val TAG = "ClipboardFileManager" private const val BIBI_FOLDER = "BiBi" - private const val MAX_CACHE_SIZE_MB = 500 // 最大缓存 500MB - private const val MAX_FILE_AGE_DAYS = 30 // 文件最长保留 30 天 + private const val MAX_CACHE_SIZE_MB = 500 // 最大缓存 500MB + private const val MAX_FILE_AGE_DAYS = 30 // 文件最长保留 30 天 } /** @@ -75,7 +75,7 @@ class ClipboardFileManager(private val context: Context) { fileName: String, inputStream: InputStream, totalBytes: Long = -1, - progressCallback: ((Long, Long) -> Unit)? = null + progressCallback: ((Long, Long) -> Unit)? = null, ): String? { return try { val file = getFile(fileName) diff --git a/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardHistoryStore.kt b/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardHistoryStore.kt index 39afdbc4..884e3e42 100644 --- a/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardHistoryStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/clipboard/ClipboardHistoryStore.kt @@ -13,7 +13,9 @@ import java.util.UUID */ @Serializable enum class EntryType { - TEXT, IMAGE, FILE + TEXT, + IMAGE, + FILE, } /** @@ -21,10 +23,10 @@ enum class EntryType { */ @Serializable enum class DownloadStatus { - NONE, // 未下载 + NONE, // 未下载 DOWNLOADING, // 下载中 - COMPLETED, // 已完成 - FAILED // 失败 + COMPLETED, // 已完成 + FAILED, // 失败 } /** @@ -40,17 +42,23 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre @Serializable data class Entry( val id: String, - val text: String = "", // 文本内容(保持向后兼容) + // 文本内容(保持向后兼容) + val text: String = "", val ts: Long, val pinned: Boolean, // 新增字段:支持文件类型 val type: EntryType = EntryType.TEXT, - val fileName: String? = null, // 文件名 - val fileSize: Long? = null, // 文件大小(字节) - val mimeType: String? = null, // MIME 类型 - val localFilePath: String? = null, // 本地文件路径 + // 文件名 + val fileName: String? = null, + // 文件大小(字节) + val fileSize: Long? = null, + // MIME 类型 + val mimeType: String? = null, + // 本地文件路径 + val localFilePath: String? = null, val downloadStatus: DownloadStatus = DownloadStatus.NONE, - val serverFileName: String? = null // 服务器上的文件名(用于下载) + // 服务器上的文件名(用于下载) + val serverFileName: String? = null, ) { /** * 用于列表 / 信息栏展示的文本。 @@ -72,7 +80,12 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre } private val sp by lazy { context.getSharedPreferences("asr_prefs", Context.MODE_PRIVATE) } - private val json by lazy { Json { ignoreUnknownKeys = true; encodeDefaults = true } } + private val json by lazy { + Json { + ignoreUnknownKeys = true + encodeDefaults = true + } + } companion object { private const val TAG = "ClipboardHistoryStore" @@ -206,7 +219,9 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre history.removeAt(idx) sp.edit().putString(KEY_CLIP_HISTORY_JSON, json.encodeToString(history)).apply() true - } else false + } else { + false + } } catch (t: Throwable) { Log.e(TAG, "deleteHistoryById failed", t) false @@ -243,7 +258,7 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre fileSize: Long? = null, mimeType: String? = null, localFilePath: String? = null, - downloadStatus: DownloadStatus = DownloadStatus.NONE + downloadStatus: DownloadStatus = DownloadStatus.NONE, ): Boolean { try { // 先清理旧的文件条目,保证「最多一个文件记录」 @@ -261,7 +276,7 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre mimeType = mimeType, localFilePath = localFilePath, downloadStatus = downloadStatus, - serverFileName = serverFileName + serverFileName = serverFileName, ) his.add(0, entry) @@ -280,7 +295,7 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre fun updateFileEntry( id: String, localFilePath: String?, - downloadStatus: DownloadStatus + downloadStatus: DownloadStatus, ): Boolean { try { val history = getHistory().toMutableList() @@ -289,7 +304,7 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre val old = history[idx] history[idx] = old.copy( localFilePath = localFilePath ?: old.localFilePath, - downloadStatus = downloadStatus + downloadStatus = downloadStatus, ) sp.edit().putString(KEY_CLIP_HISTORY_JSON, json.encodeToString(history)).apply() return true @@ -302,7 +317,7 @@ class ClipboardHistoryStore(private val context: Context, private val prefs: Pre val old = pinned[idxP] pinned[idxP] = old.copy( localFilePath = localFilePath ?: old.localFilePath, - downloadStatus = downloadStatus + downloadStatus = downloadStatus, ) sp.edit().putString(KEY_CLIP_PINNED_JSON, json.encodeToString(pinned)).apply() return true diff --git a/app/src/main/java/com/brycewg/asrkb/clipboard/SyncClipboardManager.kt b/app/src/main/java/com/brycewg/asrkb/clipboard/SyncClipboardManager.kt index f81f0e93..cc0df1bf 100644 --- a/app/src/main/java/com/brycewg/asrkb/clipboard/SyncClipboardManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/clipboard/SyncClipboardManager.kt @@ -28,9 +28,9 @@ import java.util.concurrent.TimeUnit */ @Serializable private data class UploadClipboardPayload( - val hasData: Boolean = false, - val text: String, - val type: String = "Text" + val hasData: Boolean = false, + val text: String, + val type: String = "Text", ) /** @@ -38,13 +38,13 @@ private data class UploadClipboardPayload( */ @Serializable private data class PullClipboardPayload( - val text: String? = null, - val type: String? = null, - val hasData: Boolean? = null, - val dataName: String? = null, - @SerialName("Clipboard") val legacyClipboard: String? = null, - @SerialName("Type") val legacyType: String? = null, - @SerialName("File") val legacyFile: String? = null + val text: String? = null, + val type: String? = null, + val hasData: Boolean? = null, + val dataName: String? = null, + @SerialName("Clipboard") val legacyClipboard: String? = null, + @SerialName("Type") val legacyType: String? = null, + @SerialName("File") val legacyFile: String? = null, ) /** @@ -55,680 +55,684 @@ private data class PullClipboardPayload( * 注意:服务端认证使用标准 HTTP Basic(`Authorization: Basic `)。 */ class SyncClipboardManager( - private val context: Context, - private val prefs: Prefs, - private val scope: CoroutineScope, - private val listener: Listener? = null, - private val clipboardStore: ClipboardHistoryStore? = null + private val context: Context, + private val prefs: Prefs, + private val scope: CoroutineScope, + private val listener: Listener? = null, + private val clipboardStore: ClipboardHistoryStore? = null, ) { - interface Listener { - fun onPulledNewContent(text: String) - fun onUploadSuccess() - fun onUploadFailed(reason: String? = null) - fun onFilePulled(type: EntryType, fileName: String, serverFileName: String) - } - - private val clipboard by lazy { context.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager } - private val client by lazy { - OkHttpClient.Builder() - .connectTimeout(8, TimeUnit.SECONDS) - .readTimeout(8, TimeUnit.SECONDS) - .writeTimeout(8, TimeUnit.SECONDS) - .build() - } - private val json by lazy { Json { ignoreUnknownKeys = true } } - private val fileManager by lazy { ClipboardFileManager(context) } - - companion object { - private const val TAG = "SyncClipboardManager" - } - - private var pullJob: Job? = null - private var listenerRegistered = false - @Volatile private var suppressNextChange = false - // 记录最近一次从服务端拉取的文本哈希,用于减少本地剪贴板读取次数 - @Volatile private var lastPulledServerHash: String? = null - - private val clipListener = ClipboardManager.OnPrimaryClipChangedListener { - if (suppressNextChange) { - // 忽略由我们主动写入导致的回调 - suppressNextChange = false - return@OnPrimaryClipChangedListener + interface Listener { + fun onPulledNewContent(text: String) + fun onUploadSuccess() + fun onUploadFailed(reason: String? = null) + fun onFilePulled(type: EntryType, fileName: String, serverFileName: String) } - if (!prefs.syncClipboardEnabled) return@OnPrimaryClipChangedListener - scope.launch(Dispatchers.IO) { - try { - uploadCurrentClipboardText() - } catch (e: Throwable) { - Log.e(TAG, "Failed to upload clipboard text on change", e) - } + + private val clipboard by lazy { context.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager } + private val client by lazy { + OkHttpClient.Builder() + .connectTimeout(8, TimeUnit.SECONDS) + .readTimeout(8, TimeUnit.SECONDS) + .writeTimeout(8, TimeUnit.SECONDS) + .build() + } + private val json by lazy { Json { ignoreUnknownKeys = true } } + private val fileManager by lazy { ClipboardFileManager(context) } + + companion object { + private const val TAG = "SyncClipboardManager" } - } - - fun start() { - if (!prefs.syncClipboardEnabled) return - ensureListener() - ensurePullLoop() - } - - fun stop() { - try { - if (listenerRegistered) clipboard.removePrimaryClipChangedListener(clipListener) - } catch (e: Throwable) { - Log.e(TAG, "Failed to remove clipboard listener", e) + + private var pullJob: Job? = null + private var listenerRegistered = false + + @Volatile private var suppressNextChange = false + + // 记录最近一次从服务端拉取的文本哈希,用于减少本地剪贴板读取次数 + @Volatile private var lastPulledServerHash: String? = null + + private val clipListener = ClipboardManager.OnPrimaryClipChangedListener { + if (suppressNextChange) { + // 忽略由我们主动写入导致的回调 + suppressNextChange = false + return@OnPrimaryClipChangedListener + } + if (!prefs.syncClipboardEnabled) return@OnPrimaryClipChangedListener + scope.launch(Dispatchers.IO) { + try { + uploadCurrentClipboardText() + } catch (e: Throwable) { + Log.e(TAG, "Failed to upload clipboard text on change", e) + } + } } - listenerRegistered = false - pullJob?.cancel() - pullJob = null - suppressNextChange = false - lastPulledServerHash = null - } - - private fun ensureListener() { - if (!listenerRegistered) { - try { - clipboard.addPrimaryClipChangedListener(clipListener) - listenerRegistered = true - } catch (e: Throwable) { - Log.e(TAG, "Failed to add clipboard listener", e) - } + + fun start() { + if (!prefs.syncClipboardEnabled) return + ensureListener() + ensurePullLoop() } - } - - private fun ensurePullLoop() { - pullJob?.cancel() - if (!prefs.syncClipboardAutoPullEnabled) return - val intervalSec = prefs.syncClipboardPullIntervalSec.coerceIn(1, 600) - pullJob = scope.launch(Dispatchers.IO) { - while (isActive && prefs.syncClipboardEnabled && prefs.syncClipboardAutoPullEnabled) { + + fun stop() { try { - pullNow(updateClipboard = true) + if (listenerRegistered) clipboard.removePrimaryClipChangedListener(clipListener) } catch (e: Throwable) { - Log.e(TAG, "Failed to pull clipboard in loop", e) + Log.e(TAG, "Failed to remove clipboard listener", e) } - delay(intervalSec * 1000L) - } + listenerRegistered = false + pullJob?.cancel() + pullJob = null + suppressNextChange = false + lastPulledServerHash = null } - } - - private fun buildUrl(): String? { - val raw = prefs.syncClipboardServerBase.trim() - if (raw.isBlank()) return null - val base = raw.trimEnd('/') - val lower = base.lowercase() - return if (lower.endsWith(".json")) base else "$base/SyncClipboard.json" - } - - private fun sha256Hex(s: String): String { - val md = MessageDigest.getInstance("SHA-256") - val bytes = md.digest(s.toByteArray(Charsets.UTF_8)) - val sb = StringBuilder(bytes.size * 2) - for (b in bytes) sb.append(String.format("%02x", b)) - return sb.toString() - } - - private fun authHeaderB64(): String? { - val u = prefs.syncClipboardUsername - val p = prefs.syncClipboardPassword - if (u.isBlank() || p.isBlank()) return null - val token = "$u:$p".toByteArray(Charsets.UTF_8) - val b64 = Base64.encodeToString(token, Base64.NO_WRAP) - return "Basic $b64" - } - - private fun readClipboardText(): String? { - val clip = try { - clipboard.primaryClip - } catch (e: Throwable) { - Log.e(TAG, "Failed to read clipboard", e) - null - } ?: return null - if (clip.itemCount <= 0) return null - val item = clip.getItemAt(0) - val text = try { - item.coerceToText(context)?.toString() - } catch (e: Throwable) { - Log.e(TAG, "Failed to coerce clipboard item to text", e) - null + + private fun ensureListener() { + if (!listenerRegistered) { + try { + clipboard.addPrimaryClipChangedListener(clipListener) + listenerRegistered = true + } catch (e: Throwable) { + Log.e(TAG, "Failed to add clipboard listener", e) + } + } } - return text?.takeIf { it.isNotEmpty() } - } - - private fun writeClipboardText(text: String) { - val clip = ClipData.newPlainText("SyncClipboard", text) - suppressNextChange = true - try { - clipboard.setPrimaryClip(clip) - } catch (e: Throwable) { - Log.e(TAG, "Failed to write clipboard text", e) - } finally { - suppressNextChange = false + + private fun ensurePullLoop() { + pullJob?.cancel() + if (!prefs.syncClipboardAutoPullEnabled) return + val intervalSec = prefs.syncClipboardPullIntervalSec.coerceIn(1, 600) + pullJob = scope.launch(Dispatchers.IO) { + while (isActive && prefs.syncClipboardEnabled && prefs.syncClipboardAutoPullEnabled) { + try { + pullNow(updateClipboard = true) + } catch (e: Throwable) { + Log.e(TAG, "Failed to pull clipboard in loop", e) + } + delay(intervalSec * 1000L) + } + } } - } - - private fun uploadCurrentClipboardText() { - val url = buildUrl() ?: return - val authB64 = authHeaderB64() ?: return - val text = readClipboardText() ?: return - if (text.isEmpty()) return - // 若与最近一次成功上传(或最近一次拉取写入)相同,则跳过上传,避免重复 - try { - val newHash = sha256Hex(text) - val last = try { prefs.syncClipboardLastUploadedHash } catch (e: Throwable) { - Log.e(TAG, "Failed to read last uploaded hash", e) - "" - } - if (newHash == last) return - } catch (e: Throwable) { - Log.e(TAG, "Failed to compute hash for clipboard text", e) - // 继续尝试上传 + + private fun buildUrl(): String? { + val raw = prefs.syncClipboardServerBase.trim() + if (raw.isBlank()) return null + val base = raw.trimEnd('/') + val lower = base.lowercase() + return if (lower.endsWith(".json")) base else "$base/SyncClipboard.json" } - // 仅使用标准 Basic Base64 认证 - uploadText(url, authB64, text) - } - - private fun uploadText(url: String, auth: String, text: String): Boolean { - return try { - val payload = UploadClipboardPayload(text = text) - val bodyJson = json.encodeToString(payload) - val req = Request.Builder() - .url(url) - .header("Authorization", auth) - .put(bodyJson.toRequestBody("application/json; charset=utf-8".toMediaType())) - .build() - client.newCall(req).execute().use { resp -> - if (resp.isSuccessful) { - // 记录最近一次成功上传内容的哈希,便于后续对比 - try { - prefs.syncClipboardLastUploadedHash = sha256Hex(text) - } catch (e: Throwable) { - Log.e(TAG, "Failed to save uploaded hash", e) - } - try { - listener?.onUploadSuccess() - } catch (e: Throwable) { - Log.e(TAG, "Failed to notify upload success listener", e) - } - true - } else { - Log.w(TAG, "Upload failed with status: ${resp.code}") - try { - listener?.onUploadFailed("HTTP ${resp.code}") - } catch (e: Throwable) { - Log.e(TAG, "Failed to notify upload failed listener", e) - } - false - } - } - } catch (e: Throwable) { - Log.e(TAG, "Failed to upload clipboard text", e) - try { - listener?.onUploadFailed(e.message) - } catch (t: Throwable) { - Log.e(TAG, "Failed to notify upload failed listener (exception)", t) - } - false + + private fun sha256Hex(s: String): String { + val md = MessageDigest.getInstance("SHA-256") + val bytes = md.digest(s.toByteArray(Charsets.UTF_8)) + val sb = StringBuilder(bytes.size * 2) + for (b in bytes) sb.append(String.format("%02x", b)) + return sb.toString() + } + + private fun authHeaderB64(): String? { + val u = prefs.syncClipboardUsername + val p = prefs.syncClipboardPassword + if (u.isBlank() || p.isBlank()) return null + val token = "$u:$p".toByteArray(Charsets.UTF_8) + val b64 = Base64.encodeToString(token, Base64.NO_WRAP) + return "Basic $b64" } - } - - /** - * 一次性上传当前系统粘贴板文本(不进行"与上次一致"跳过判断)。 - * 返回是否成功。 - */ - fun uploadOnce(): Boolean { - val url = buildUrl() ?: return false - val authB64 = authHeaderB64() ?: return false - val text = readClipboardText() ?: return false - if (text.isEmpty()) return false - return try { - val ok = uploadText(url, authB64, text) - ok - } catch (e: Throwable) { - Log.e(TAG, "uploadOnce failed", e) - false + + private fun readClipboardText(): String? { + val clip = try { + clipboard.primaryClip + } catch (e: Throwable) { + Log.e(TAG, "Failed to read clipboard", e) + null + } ?: return null + if (clip.itemCount <= 0) return null + val item = clip.getItemAt(0) + val text = try { + item.coerceToText(context)?.toString() + } catch (e: Throwable) { + Log.e(TAG, "Failed to coerce clipboard item to text", e) + null + } + return text?.takeIf { it.isNotEmpty() } } - } - - /** - * 执行带认证的请求(HTTP Basic)。 - */ - private fun executeRequestWithAuth( - requestBuilder: (auth: String) -> Request, - responseHandler: (okhttp3.Response) -> T? - ): T? { - val authB64 = authHeaderB64() ?: return null - return try { - val req = requestBuilder(authB64) - client.newCall(req).execute().use { resp -> - if (resp.isSuccessful) { - return responseHandler(resp) + + private fun writeClipboardText(text: String) { + val clip = ClipData.newPlainText("SyncClipboard", text) + suppressNextChange = true + try { + clipboard.setPrimaryClip(clip) + } catch (e: Throwable) { + Log.e(TAG, "Failed to write clipboard text", e) + } finally { + suppressNextChange = false } - Log.w(TAG, "Auth failed with status: ${resp.code}") - null - } - } catch (e: Throwable) { - Log.e(TAG, "Auth request failed", e) - null } - } - - fun pullNow(updateClipboard: Boolean): Pair { - val url = buildUrl() ?: return false to null - - val result = try { - executeRequestWithAuth( - requestBuilder = { auth -> - Request.Builder() - .url(url) - .header("Authorization", auth) - .get() - .build() - }, - responseHandler = { resp -> - val body = resp.body?.string()?.takeIf { it.isNotEmpty() } - if (body == null) { - Log.w(TAG, "Pull response body is empty") - return@executeRequestWithAuth null - } - - val payload = try { - json.decodeFromString(body) - } catch (e: Throwable) { - Log.e(TAG, "Failed to parse clipboard payload", e) - return@executeRequestWithAuth null - } - - val payloadType = resolvePayloadType(payload) - when (payloadType.lowercase()) { - "text" -> { - val textDataName = if (payload.hasData == true) { - payload.dataName?.takeIf { it.isNotEmpty() } - ?: payload.legacyFile?.takeIf { it.isNotEmpty() } - } else { - null - } - val text = if (!textDataName.isNullOrBlank()) { - downloadTextData(textDataName) - } else { - resolvePayloadText(payload) - } - val nonBlankText = text?.takeIf { it.isNotBlank() } - if (nonBlankText == null) { - Log.w(TAG, "Clipboard text is blank") - return@executeRequestWithAuth null - } - return@executeRequestWithAuth handleTextPayload(nonBlankText, updateClipboard) + + private fun uploadCurrentClipboardText() { + val url = buildUrl() ?: return + val authB64 = authHeaderB64() ?: return + val text = readClipboardText() ?: return + if (text.isEmpty()) return + // 若与最近一次成功上传(或最近一次拉取写入)相同,则跳过上传,避免重复 + try { + val newHash = sha256Hex(text) + val last = try { + prefs.syncClipboardLastUploadedHash + } catch (e: Throwable) { + Log.e(TAG, "Failed to read last uploaded hash", e) + "" } - "image", "file" -> { - val fileName = resolvePayloadFileName(payload) - val nonBlankFileName = fileName?.takeIf { it.isNotBlank() } - if (nonBlankFileName == null) { - Log.w(TAG, "File name is blank for type: $payloadType") - return@executeRequestWithAuth null - } - val normalizedType = if (payloadType.equals("image", ignoreCase = true)) "Image" else "File" - return@executeRequestWithAuth handleFilePayload(normalizedType, nonBlankFileName) + if (newHash == last) return + } catch (e: Throwable) { + Log.e(TAG, "Failed to compute hash for clipboard text", e) + // 继续尝试上传 + } + // 仅使用标准 Basic Base64 认证 + uploadText(url, authB64, text) + } + + private fun uploadText(url: String, auth: String, text: String): Boolean { + return try { + val payload = UploadClipboardPayload(text = text) + val bodyJson = json.encodeToString(payload) + val req = Request.Builder() + .url(url) + .header("Authorization", auth) + .put(bodyJson.toRequestBody("application/json; charset=utf-8".toMediaType())) + .build() + client.newCall(req).execute().use { resp -> + if (resp.isSuccessful) { + // 记录最近一次成功上传内容的哈希,便于后续对比 + try { + prefs.syncClipboardLastUploadedHash = sha256Hex(text) + } catch (e: Throwable) { + Log.e(TAG, "Failed to save uploaded hash", e) + } + try { + listener?.onUploadSuccess() + } catch (e: Throwable) { + Log.e(TAG, "Failed to notify upload success listener", e) + } + true + } else { + Log.w(TAG, "Upload failed with status: ${resp.code}") + try { + listener?.onUploadFailed("HTTP ${resp.code}") + } catch (e: Throwable) { + Log.e(TAG, "Failed to notify upload failed listener", e) + } + false + } } - else -> { - Log.w(TAG, "Unsupported payload type: $payloadType") - return@executeRequestWithAuth null + } catch (e: Throwable) { + Log.e(TAG, "Failed to upload clipboard text", e) + try { + listener?.onUploadFailed(e.message) + } catch (t: Throwable) { + Log.e(TAG, "Failed to notify upload failed listener (exception)", t) } - } + false } - ) - } catch (e: Throwable) { - Log.e(TAG, "pullNow failed", e) - null } - return if (result != null) { - true to result - } else { - false to null + /** + * 一次性上传当前系统粘贴板文本(不进行"与上次一致"跳过判断)。 + * 返回是否成功。 + */ + fun uploadOnce(): Boolean { + val url = buildUrl() ?: return false + val authB64 = authHeaderB64() ?: return false + val text = readClipboardText() ?: return false + if (text.isEmpty()) return false + return try { + val ok = uploadText(url, authB64, text) + ok + } catch (e: Throwable) { + Log.e(TAG, "uploadOnce failed", e) + false + } } - } - - /** - * 统一解析 payload 类型,优先新字段,兼容旧字段。 - */ - private fun resolvePayloadType(payload: PullClipboardPayload): String { - val explicitType = payload.type?.trim().takeUnless { it.isNullOrBlank() } - ?: payload.legacyType?.trim().takeUnless { it.isNullOrBlank() } - if (explicitType != null) return explicitType - if (payload.hasData == true && (!payload.dataName.isNullOrBlank() || !payload.legacyFile.isNullOrBlank())) { - return "File" + + /** + * 执行带认证的请求(HTTP Basic)。 + */ + private fun executeRequestWithAuth( + requestBuilder: (auth: String) -> Request, + responseHandler: (okhttp3.Response) -> T?, + ): T? { + val authB64 = authHeaderB64() ?: return null + return try { + val req = requestBuilder(authB64) + client.newCall(req).execute().use { resp -> + if (resp.isSuccessful) { + return responseHandler(resp) + } + Log.w(TAG, "Auth failed with status: ${resp.code}") + null + } + } catch (e: Throwable) { + Log.e(TAG, "Auth request failed", e) + null + } } - return "Text" - } - - /** - * 统一解析 payload 文本内容,优先新字段,兼容旧字段。 - */ - private fun resolvePayloadText(payload: PullClipboardPayload): String? { - return payload.text?.takeIf { it.isNotEmpty() } - ?: payload.legacyClipboard?.takeIf { it.isNotEmpty() } - } - - /** - * 统一解析 payload 文件名,优先新字段,兼容旧字段。 - */ - private fun resolvePayloadFileName(payload: PullClipboardPayload): String? { - return payload.dataName?.takeIf { it.isNotEmpty() } - ?: payload.legacyFile?.takeIf { it.isNotEmpty() } - ?: payload.text?.takeIf { payload.hasData == true && it.isNotEmpty() } - ?: payload.legacyClipboard?.takeIf { payload.hasData == true && it.isNotEmpty() } - } - - /** - * 拉取文本内容:当 `Text + hasData=true` 时,正文存放在 `/file/{dataName}`。 - */ - private fun downloadTextData(dataName: String): String? { - val fileUrl = buildFileUrl(dataName) ?: run { - Log.w(TAG, "Failed to build text data url for: $dataName") - return null + + fun pullNow(updateClipboard: Boolean): Pair { + val url = buildUrl() ?: return false to null + + val result = try { + executeRequestWithAuth( + requestBuilder = { auth -> + Request.Builder() + .url(url) + .header("Authorization", auth) + .get() + .build() + }, + responseHandler = { resp -> + val body = resp.body?.string()?.takeIf { it.isNotEmpty() } + if (body == null) { + Log.w(TAG, "Pull response body is empty") + return@executeRequestWithAuth null + } + + val payload = try { + json.decodeFromString(body) + } catch (e: Throwable) { + Log.e(TAG, "Failed to parse clipboard payload", e) + return@executeRequestWithAuth null + } + + val payloadType = resolvePayloadType(payload) + when (payloadType.lowercase()) { + "text" -> { + val textDataName = if (payload.hasData == true) { + payload.dataName?.takeIf { it.isNotEmpty() } + ?: payload.legacyFile?.takeIf { it.isNotEmpty() } + } else { + null + } + val text = if (!textDataName.isNullOrBlank()) { + downloadTextData(textDataName) + } else { + resolvePayloadText(payload) + } + val nonBlankText = text?.takeIf { it.isNotBlank() } + if (nonBlankText == null) { + Log.w(TAG, "Clipboard text is blank") + return@executeRequestWithAuth null + } + return@executeRequestWithAuth handleTextPayload(nonBlankText, updateClipboard) + } + "image", "file" -> { + val fileName = resolvePayloadFileName(payload) + val nonBlankFileName = fileName?.takeIf { it.isNotBlank() } + if (nonBlankFileName == null) { + Log.w(TAG, "File name is blank for type: $payloadType") + return@executeRequestWithAuth null + } + val normalizedType = if (payloadType.equals("image", ignoreCase = true)) "Image" else "File" + return@executeRequestWithAuth handleFilePayload(normalizedType, nonBlankFileName) + } + else -> { + Log.w(TAG, "Unsupported payload type: $payloadType") + return@executeRequestWithAuth null + } + } + }, + ) + } catch (e: Throwable) { + Log.e(TAG, "pullNow failed", e) + null + } + + return if (result != null) { + true to result + } else { + false to null + } } - return executeRequestWithAuth( - requestBuilder = { auth -> - Request.Builder() - .url(fileUrl) - .header("Authorization", auth) - .get() - .build() - }, - responseHandler = { resp -> - val text = resp.body?.string() - if (text.isNullOrEmpty()) { - Log.w(TAG, "Downloaded text data is empty: $dataName") - return@executeRequestWithAuth null + + /** + * 统一解析 payload 类型,优先新字段,兼容旧字段。 + */ + private fun resolvePayloadType(payload: PullClipboardPayload): String { + val explicitType = payload.type?.trim().takeUnless { it.isNullOrBlank() } + ?: payload.legacyType?.trim().takeUnless { it.isNullOrBlank() } + if (explicitType != null) return explicitType + if (payload.hasData == true && (!payload.dataName.isNullOrBlank() || !payload.legacyFile.isNullOrBlank())) { + return "File" } - text - } - ) - } - - /** - * 处理文本类型的 payload - */ - private fun handleTextPayload(text: String, updateClipboard: Boolean): String { - // 远端内容变为文本时,清除历史中的文件条目与最近文件名记录 - try { - clipboardStore?.clearFileEntries() - prefs.syncClipboardLastFileName = "" - } catch (t: Throwable) { - Log.e(TAG, "Failed to clear file entries on text payload", t) + return "Text" } - // 计算服务端文本哈希并与上次拉取缓存对比,未变化则避免读取系统剪贴板 - val newServerHash = try { - sha256Hex(text) - } catch (e: Throwable) { - Log.e(TAG, "Failed to compute hash for pulled text", e) - null + /** + * 统一解析 payload 文本内容,优先新字段,兼容旧字段。 + */ + private fun resolvePayloadText(payload: PullClipboardPayload): String? { + return payload.text?.takeIf { it.isNotEmpty() } + ?: payload.legacyClipboard?.takeIf { it.isNotEmpty() } } - val prevServerHash = lastPulledServerHash - lastPulledServerHash = newServerHash - if (updateClipboard) { - if (newServerHash != null && newServerHash == prevServerHash) { - // 服务端内容未变化:跳过本地剪贴板读取以降低读取频率 - return text - } - val cur = readClipboardText() - if (text.isNotEmpty() && text != cur) { - writeClipboardText(text) - // 将此次拉取的内容也记录到"最近一次上传哈希",避免后续补上传(减少不必要的上传) + /** + * 统一解析 payload 文件名,优先新字段,兼容旧字段。 + */ + private fun resolvePayloadFileName(payload: PullClipboardPayload): String? { + return payload.dataName?.takeIf { it.isNotEmpty() } + ?: payload.legacyFile?.takeIf { it.isNotEmpty() } + ?: payload.text?.takeIf { payload.hasData == true && it.isNotEmpty() } + ?: payload.legacyClipboard?.takeIf { payload.hasData == true && it.isNotEmpty() } + } + + /** + * 拉取文本内容:当 `Text + hasData=true` 时,正文存放在 `/file/{dataName}`。 + */ + private fun downloadTextData(dataName: String): String? { + val fileUrl = buildFileUrl(dataName) ?: run { + Log.w(TAG, "Failed to build text data url for: $dataName") + return null + } + return executeRequestWithAuth( + requestBuilder = { auth -> + Request.Builder() + .url(fileUrl) + .header("Authorization", auth) + .get() + .build() + }, + responseHandler = { resp -> + val text = resp.body?.string() + if (text.isNullOrEmpty()) { + Log.w(TAG, "Downloaded text data is empty: $dataName") + return@executeRequestWithAuth null + } + text + }, + ) + } + + /** + * 处理文本类型的 payload + */ + private fun handleTextPayload(text: String, updateClipboard: Boolean): String { + // 远端内容变为文本时,清除历史中的文件条目与最近文件名记录 try { - prefs.syncClipboardLastUploadedHash = sha256Hex(text) + clipboardStore?.clearFileEntries() + prefs.syncClipboardLastFileName = "" + } catch (t: Throwable) { + Log.e(TAG, "Failed to clear file entries on text payload", t) + } + + // 计算服务端文本哈希并与上次拉取缓存对比,未变化则避免读取系统剪贴板 + val newServerHash = try { + sha256Hex(text) } catch (e: Throwable) { - Log.e(TAG, "Failed to save pulled hash", e) + Log.e(TAG, "Failed to compute hash for pulled text", e) + null } + val prevServerHash = lastPulledServerHash + lastPulledServerHash = newServerHash + + if (updateClipboard) { + if (newServerHash != null && newServerHash == prevServerHash) { + // 服务端内容未变化:跳过本地剪贴板读取以降低读取频率 + return text + } + val cur = readClipboardText() + if (text.isNotEmpty() && text != cur) { + writeClipboardText(text) + // 将此次拉取的内容也记录到"最近一次上传哈希",避免后续补上传(减少不必要的上传) + try { + prefs.syncClipboardLastUploadedHash = sha256Hex(text) + } catch (e: Throwable) { + Log.e(TAG, "Failed to save pulled hash", e) + } + try { + listener?.onPulledNewContent(text) + } catch (e: Throwable) { + Log.e(TAG, "Failed to notify pulled content listener", e) + } + } + } + return text + } + + /** + * 处理文件类型的 payload + * 仅添加到历史记录,不自动下载 + */ + private fun handleFilePayload(type: String, fileName: String): String { try { - listener?.onPulledNewContent(text) + // 若文件名与最近一次处理的文件相同,则视为内容未更新,避免重复触发预览 + val prevName = try { + prefs.syncClipboardLastFileName + } catch (e: Throwable) { + Log.e(TAG, "Failed to read last file name", e) + "" + } + if (fileName.isNotEmpty() && fileName == prevName) { + Log.d(TAG, "File payload unchanged, skip preview: $fileName") + return fileName + } + + val entryType = when (type.lowercase()) { + "image" -> EntryType.IMAGE + "file" -> EntryType.FILE + else -> EntryType.FILE + } + + // 检查文件是否已下载 + val localFile = fileManager.getFile(fileName) + val downloadStatus = if (localFile.exists()) { + DownloadStatus.COMPLETED + } else { + DownloadStatus.NONE + } + + val localPath = if (localFile.exists()) localFile.absolutePath else null + + // 添加到历史记录(仅保留最新一条文件记录) + clipboardStore?.addFileEntry( + type = entryType, + fileName = fileName, + serverFileName = fileName, + fileSize = if (localFile.exists()) localFile.length() else null, + localFilePath = localPath, + downloadStatus = downloadStatus, + ) + + // 通知监听器有新文件 + try { + listener?.onFilePulled(entryType, fileName, fileName) + } catch (e: Throwable) { + Log.e(TAG, "Failed to notify file pulled listener", e) + } + + // 记录最近一次成功处理的文件名 + try { + prefs.syncClipboardLastFileName = fileName + } catch (e: Throwable) { + Log.e(TAG, "Failed to save last file name", e) + } + + Log.d(TAG, "File payload handled: $fileName (type: $type, status: $downloadStatus)") + return fileName } catch (e: Throwable) { - Log.e(TAG, "Failed to notify pulled content listener", e) + Log.e(TAG, "Failed to handle file payload: $fileName", e) + return fileName } - } - } - return text - } - - /** - * 处理文件类型的 payload - * 仅添加到历史记录,不自动下载 - */ - private fun handleFilePayload(type: String, fileName: String): String { - try { - // 若文件名与最近一次处理的文件相同,则视为内容未更新,避免重复触发预览 - val prevName = try { - prefs.syncClipboardLastFileName - } catch (e: Throwable) { - Log.e(TAG, "Failed to read last file name", e) - "" - } - if (fileName.isNotEmpty() && fileName == prevName) { - Log.d(TAG, "File payload unchanged, skip preview: $fileName") - return fileName - } - - val entryType = when (type.lowercase()) { - "image" -> EntryType.IMAGE - "file" -> EntryType.FILE - else -> EntryType.FILE - } - - // 检查文件是否已下载 - val localFile = fileManager.getFile(fileName) - val downloadStatus = if (localFile.exists()) { - DownloadStatus.COMPLETED - } else { - DownloadStatus.NONE - } - - val localPath = if (localFile.exists()) localFile.absolutePath else null - - // 添加到历史记录(仅保留最新一条文件记录) - clipboardStore?.addFileEntry( - type = entryType, - fileName = fileName, - serverFileName = fileName, - fileSize = if (localFile.exists()) localFile.length() else null, - localFilePath = localPath, - downloadStatus = downloadStatus - ) - - // 通知监听器有新文件 - try { - listener?.onFilePulled(entryType, fileName, fileName) - } catch (e: Throwable) { - Log.e(TAG, "Failed to notify file pulled listener", e) - } - - // 记录最近一次成功处理的文件名 - try { - prefs.syncClipboardLastFileName = fileName - } catch (e: Throwable) { - Log.e(TAG, "Failed to save last file name", e) - } - - Log.d(TAG, "File payload handled: $fileName (type: $type, status: $downloadStatus)") - return fileName - } catch (e: Throwable) { - Log.e(TAG, "Failed to handle file payload: $fileName", e) - return fileName - } - } - - /** - * 下载文件 - * @param entryId 条目 ID - * @param progressCallback 进度回调 - * @return 是否下载成功 - */ - fun downloadFile( - entryId: String, - progressCallback: ((Long, Long) -> Unit)? = null - ): Boolean { - val entry = clipboardStore?.getEntryById(entryId) ?: return false - val serverFileName = entry.serverFileName ?: entry.fileName ?: return false - - // 检查是否已下载 - if (fileManager.fileExists(serverFileName, entry.fileSize)) { - Log.d(TAG, "File already downloaded: $serverFileName") - clipboardStore?.updateFileEntry( - entryId, - fileManager.getFile(serverFileName).absolutePath, - DownloadStatus.COMPLETED - ) - return true } - // 更新状态为下载中 - clipboardStore?.updateFileEntry(entryId, null, DownloadStatus.DOWNLOADING) + /** + * 下载文件 + * @param entryId 条目 ID + * @param progressCallback 进度回调 + * @return 是否下载成功 + */ + fun downloadFile( + entryId: String, + progressCallback: ((Long, Long) -> Unit)? = null, + ): Boolean { + val entry = clipboardStore?.getEntryById(entryId) ?: return false + val serverFileName = entry.serverFileName ?: entry.fileName ?: return false + + // 检查是否已下载 + if (fileManager.fileExists(serverFileName, entry.fileSize)) { + Log.d(TAG, "File already downloaded: $serverFileName") + clipboardStore?.updateFileEntry( + entryId, + fileManager.getFile(serverFileName).absolutePath, + DownloadStatus.COMPLETED, + ) + return true + } - val (ok, localPath) = downloadFileDirectInternal( - serverFileName = serverFileName, - expectedSize = entry.fileSize, - progressCallback = progressCallback - ) + // 更新状态为下载中 + clipboardStore?.updateFileEntry(entryId, null, DownloadStatus.DOWNLOADING) - if (ok && localPath != null) { - clipboardStore?.updateFileEntry(entryId, localPath, DownloadStatus.COMPLETED) - return true - } + val (ok, localPath) = downloadFileDirectInternal( + serverFileName = serverFileName, + expectedSize = entry.fileSize, + progressCallback = progressCallback, + ) - clipboardStore?.updateFileEntry(entryId, null, DownloadStatus.FAILED) - return false - } - - /** - * 直接按文件名下载文件(不依赖剪贴板历史条目) - * @param fileName 服务器上的文件名 - * @param progressCallback 进度回调 - * @return Pair<是否成功, 本地路径(成功时非 null)> - */ - fun downloadFileDirect( - fileName: String, - progressCallback: ((Long, Long) -> Unit)? = null - ): Pair { - if (fileName.isBlank()) return false to null - - // 已存在则直接返回 - if (fileManager.fileExists(fileName)) { - val local = fileManager.getFile(fileName) - Log.d(TAG, "File already downloaded (direct): $fileName -> ${local.absolutePath}") - return true to local.absolutePath - } + if (ok && localPath != null) { + clipboardStore?.updateFileEntry(entryId, localPath, DownloadStatus.COMPLETED) + return true + } - return downloadFileDirectInternal( - serverFileName = fileName, - expectedSize = null, - progressCallback = progressCallback - ) - } - - /** - * 文件下载核心实现,供历史条目下载和直接按文件名下载复用 - */ - private fun downloadFileDirectInternal( - serverFileName: String, - expectedSize: Long?, - progressCallback: ((Long, Long) -> Unit)? - ): Pair { - val fileUrl = buildFileUrl(serverFileName) ?: run { - Log.w(TAG, "Failed to build file url for: $serverFileName") - return false to null + clipboardStore?.updateFileEntry(entryId, null, DownloadStatus.FAILED) + return false } - val authB64 = authHeaderB64() ?: run { - Log.w(TAG, "Missing auth header for file download") - return false to null - } + /** + * 直接按文件名下载文件(不依赖剪贴板历史条目) + * @param fileName 服务器上的文件名 + * @param progressCallback 进度回调 + * @return Pair<是否成功, 本地路径(成功时非 null)> + */ + fun downloadFileDirect( + fileName: String, + progressCallback: ((Long, Long) -> Unit)? = null, + ): Pair { + if (fileName.isBlank()) return false to null + + // 已存在则直接返回 + if (fileManager.fileExists(fileName)) { + val local = fileManager.getFile(fileName) + Log.d(TAG, "File already downloaded (direct): $fileName -> ${local.absolutePath}") + return true to local.absolutePath + } - // 若已存在且大小匹配,直接返回 - if (fileManager.fileExists(serverFileName, expectedSize)) { - val local = fileManager.getFile(serverFileName) - Log.d(TAG, "File already exists with expected size: $serverFileName -> ${local.absolutePath}") - return true to local.absolutePath + return downloadFileDirectInternal( + serverFileName = fileName, + expectedSize = null, + progressCallback = progressCallback, + ) } - return try { - val req = Request.Builder() - .url(fileUrl) - .header("Authorization", authB64) - .get() - .build() - - client.newCall(req).execute().use { resp -> - if (!resp.isSuccessful) { - Log.w(TAG, "Download failed: ${resp.code}") - return false to null + /** + * 文件下载核心实现,供历史条目下载和直接按文件名下载复用 + */ + private fun downloadFileDirectInternal( + serverFileName: String, + expectedSize: Long?, + progressCallback: ((Long, Long) -> Unit)?, + ): Pair { + val fileUrl = buildFileUrl(serverFileName) ?: run { + Log.w(TAG, "Failed to build file url for: $serverFileName") + return false to null } - val body = resp.body ?: run { - Log.w(TAG, "Download body is null for: $serverFileName") - return false to null + val authB64 = authHeaderB64() ?: run { + Log.w(TAG, "Missing auth header for file download") + return false to null } - val totalBytes = body.contentLength() - val localPath = fileManager.saveFile( - serverFileName, - body.byteStream(), - totalBytes, - progressCallback - ) + // 若已存在且大小匹配,直接返回 + if (fileManager.fileExists(serverFileName, expectedSize)) { + val local = fileManager.getFile(serverFileName) + Log.d(TAG, "File already exists with expected size: $serverFileName -> ${local.absolutePath}") + return true to local.absolutePath + } - if (localPath != null) { - Log.d(TAG, "File downloaded successfully: $serverFileName -> $localPath") - true to localPath - } else { - Log.w(TAG, "Failed to save downloaded file: $serverFileName") - false to null + return try { + val req = Request.Builder() + .url(fileUrl) + .header("Authorization", authB64) + .get() + .build() + + client.newCall(req).execute().use { resp -> + if (!resp.isSuccessful) { + Log.w(TAG, "Download failed: ${resp.code}") + return false to null + } + + val body = resp.body ?: run { + Log.w(TAG, "Download body is null for: $serverFileName") + return false to null + } + + val totalBytes = body.contentLength() + val localPath = fileManager.saveFile( + serverFileName, + body.byteStream(), + totalBytes, + progressCallback, + ) + + if (localPath != null) { + Log.d(TAG, "File downloaded successfully: $serverFileName -> $localPath") + true to localPath + } else { + Log.w(TAG, "Failed to save downloaded file: $serverFileName") + false to null + } + } + } catch (e: Exception) { + Log.e(TAG, "Download error: $serverFileName", e) + false to null } - } - } catch (e: Exception) { - Log.e(TAG, "Download error: $serverFileName", e) - false to null - } - } - - /** - * 构建文件下载 URL - */ - private fun buildFileUrl(fileName: String): String? { - val raw = prefs.syncClipboardServerBase.trim() - if (raw.isBlank()) return null - val base = raw.trimEnd('/') - // 文件在服务器的 /file/ 目录下 - val encodedFileName = Uri.encode(fileName) - return "$base/file/$encodedFileName" - } - - /** - * 在启动时调用:若系统剪贴板文本与上次成功上传不一致,则主动上传一次。 - */ - fun proactiveUploadIfChanged() { - val url = buildUrl() ?: return - val authB64 = authHeaderB64() ?: return - val text = readClipboardText() ?: return - if (text.isEmpty()) return - val newHash = try { - sha256Hex(text) - } catch (e: Throwable) { - Log.e(TAG, "Failed to compute hash for proactive upload", e) - return } - val last = try { - prefs.syncClipboardLastUploadedHash - } catch (e: Throwable) { - Log.e(TAG, "Failed to read last uploaded hash", e) - "" + + /** + * 构建文件下载 URL + */ + private fun buildFileUrl(fileName: String): String? { + val raw = prefs.syncClipboardServerBase.trim() + if (raw.isBlank()) return null + val base = raw.trimEnd('/') + // 文件在服务器的 /file/ 目录下 + val encodedFileName = Uri.encode(fileName) + return "$base/file/$encodedFileName" } - if (newHash != last) { - try { - uploadText(url, authB64, text) - } catch (e: Throwable) { - Log.e(TAG, "proactiveUploadIfChanged failed", e) - } + + /** + * 在启动时调用:若系统剪贴板文本与上次成功上传不一致,则主动上传一次。 + */ + fun proactiveUploadIfChanged() { + val url = buildUrl() ?: return + val authB64 = authHeaderB64() ?: return + val text = readClipboardText() ?: return + if (text.isEmpty()) return + val newHash = try { + sha256Hex(text) + } catch (e: Throwable) { + Log.e(TAG, "Failed to compute hash for proactive upload", e) + return + } + val last = try { + prefs.syncClipboardLastUploadedHash + } catch (e: Throwable) { + Log.e(TAG, "Failed to read last uploaded hash", e) + "" + } + if (newHash != last) { + try { + uploadText(url, authB64, text) + } catch (e: Throwable) { + Log.e(TAG, "proactiveUploadIfChanged failed", e) + } + } } - } } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/AiEditUseCase.kt b/app/src/main/java/com/brycewg/asrkb/ime/AiEditUseCase.kt index 56176ddc..d4bf4c0b 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/AiEditUseCase.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/AiEditUseCase.kt @@ -21,7 +21,7 @@ internal class AiEditUseCase( private val updateSessionContext: ((KeyboardSessionContext) -> KeyboardSessionContext) -> Unit, private val transitionToState: (KeyboardState) -> Unit, private val transitionToIdle: (keepMessage: Boolean) -> Unit, - private val transitionToIdleWithTiming: (showBackupUsedHint: Boolean) -> Unit + private val transitionToIdleWithTiming: (showBackupUsedHint: Boolean) -> Unit, ) { fun handleClick(ic: InputConnection?) { if (ic == null) { @@ -156,4 +156,3 @@ internal class AiEditUseCase( private const val TAG = "AiEditUseCase" } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/AsrCommitRecorder.kt b/app/src/main/java/com/brycewg/asrkb/ime/AsrCommitRecorder.kt index d0e14fa9..0d80f5fb 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/AsrCommitRecorder.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/AsrCommitRecorder.kt @@ -11,13 +11,13 @@ internal class AsrCommitRecorder( private val context: Context, private val prefs: Prefs, private val asrManager: AsrSessionManager, - private val logTag: String + private val logTag: String, ) { fun record( text: String, aiProcessed: Boolean, aiPostMs: Long = 0L, - aiPostStatus: AsrHistoryStore.AiPostStatus = AsrHistoryStore.AiPostStatus.NONE + aiPostStatus: AsrHistoryStore.AiPostStatus = AsrHistoryStore.AiPostStatus.NONE, ) { try { val chars = TextSanitizer.countEffectiveChars(text) @@ -42,7 +42,7 @@ internal class AsrCommitRecorder( procMs = procMs, source = "ime", aiProcessed = aiProcessed, - charCount = chars + charCount = chars, ) if (!prefs.disableUsageStats) { @@ -64,8 +64,8 @@ internal class AsrCommitRecorder( aiProcessed = aiProcessed, aiPostMs = aiPostMs, aiPostStatus = aiPostStatus, - charCount = chars - ) + charCount = chars, + ), ) } catch (e: Exception) { Log.e(logTag, "Failed to add ASR history", e) @@ -79,4 +79,3 @@ internal class AsrCommitRecorder( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/AsrKeyboardService.kt b/app/src/main/java/com/brycewg/asrkb/ime/AsrKeyboardService.kt index b6a295f0..a09944f1 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/AsrKeyboardService.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/AsrKeyboardService.kt @@ -1,36 +1,36 @@ package com.brycewg.asrkb.ime import android.Manifest +import android.annotation.SuppressLint import android.content.BroadcastReceiver -import android.content.IntentFilter import android.content.Intent -import android.annotation.SuppressLint +import android.content.IntentFilter import android.content.pm.PackageManager import android.inputmethodservice.InputMethodService -import android.view.LayoutInflater import android.view.ContextThemeWrapper +import android.view.LayoutInflater import android.view.View -import android.view.inputmethod.InputMethodManager import android.view.inputmethod.EditorInfo +import android.view.inputmethod.InputMethodManager +import androidx.appcompat.widget.PopupMenu import androidx.core.content.ContextCompat +import com.brycewg.asrkb.LocaleHelper import com.brycewg.asrkb.R +import com.brycewg.asrkb.UiColors import com.brycewg.asrkb.asr.AsrVendor import com.brycewg.asrkb.asr.BluetoothRouteManager import com.brycewg.asrkb.asr.LlmPostProcessor import com.brycewg.asrkb.asr.partitionAsrVendorsByConfigured import com.brycewg.asrkb.store.Prefs -import com.brycewg.asrkb.util.HapticFeedbackHelper -import com.brycewg.asrkb.ui.SettingsActivity +import com.brycewg.asrkb.store.debug.DebugLogManager import com.brycewg.asrkb.ui.AsrVendorUi +import com.brycewg.asrkb.ui.SettingsActivity +import com.brycewg.asrkb.util.HapticFeedbackHelper import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.launch -import com.brycewg.asrkb.LocaleHelper -import com.brycewg.asrkb.UiColors -import com.brycewg.asrkb.store.debug.DebugLogManager -import androidx.appcompat.widget.PopupMenu /** * ASR 键盘服务 @@ -93,9 +93,11 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene // ========== 剪贴板和其他辅助功能 ========== private var prefsReceiver: BroadcastReceiver? = null + // 本地模型首次出现预热仅触发一次 private var localPreloadTriggered: Boolean = false private var suppressReturnPrevImeOnHideOnce: Boolean = false + // 记录最近一次在 IME 内弹出菜单的时间,用于限制“防误收起”逻辑的作用窗口 private var lastPopupMenuShownAt: Long = 0L @@ -115,7 +117,7 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene prefs, asrManager, inputHelper, - LlmPostProcessor() + LlmPostProcessor(), ) backspaceGestureHandler = BackspaceGestureHandler(inputHelper) @@ -153,12 +155,16 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene prefsReceiver = r try { androidx.core.content.ContextCompat.registerReceiver( - /* context = */ this, - /* receiver = */ r, - /* filter = */ IntentFilter().apply { + /* context = */ + this, + /* receiver = */ + r, + /* filter = */ + IntentFilter().apply { addAction(ACTION_REFRESH_IME_UI) }, - /* flags = */ androidx.core.content.ContextCompat.RECEIVER_NOT_EXPORTED + /* flags = */ + androidx.core.content.ContextCompat.RECEIVER_NOT_EXPORTED, ) } catch (e: Throwable) { android.util.Log.e("AsrKeyboardService", "Failed to register prefsReceiver", e) @@ -184,18 +190,18 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene return createKeyboardView() } - private fun createKeyboardView(): View { - val themedContext = ContextThemeWrapper(this, R.style.Theme_ASRKeyboard_Ime) - val dynamicContext = com.google.android.material.color.DynamicColors.wrapContextIfAvailable(themedContext) - val view = LayoutInflater.from(dynamicContext).inflate(R.layout.keyboard_view, null, false) - return setupKeyboardView(view) - } + private fun createKeyboardView(): View { + val themedContext = ContextThemeWrapper(this, R.style.Theme_ASRKeyboard_Ime) + val dynamicContext = com.google.android.material.color.DynamicColors.wrapContextIfAvailable(themedContext) + val view = LayoutInflater.from(dynamicContext).inflate(R.layout.keyboard_view, null, false) + return setupKeyboardView(view) + } - private fun setupKeyboardView(view: View): View { - rootView = view + private fun setupKeyboardView(view: View): View { + rootView = view - // 根据主题动态调整键盘背景色,使其略浅于当前容器色但仍明显深于普通按键与麦克风按钮 - themeStyler.applyKeyboardBackgroundColor(view) + // 根据主题动态调整键盘背景色,使其略浅于当前容器色但仍明显深于普通按键与麦克风按钮 + themeStyler.applyKeyboardBackgroundColor(view) // 应用 Window Insets 以适配 Android 15 边缘到边缘显示 layoutController?.installKeyboardInsetsListener(view) @@ -241,14 +247,13 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene "imeOptions" to (info?.imeOptions ?: 0), "icNull" to (currentInputConnection == null), "isMultiLine" to ((info?.inputType ?: 0) and android.text.InputType.TYPE_TEXT_FLAG_MULTI_LINE != 0), - "actionId" to ((info?.imeOptions ?: 0) and android.view.inputmethod.EditorInfo.IME_MASK_ACTION) - ) + "actionId" to ((info?.imeOptions ?: 0) and android.view.inputmethod.EditorInfo.IME_MASK_ACTION), + ), ) // 键盘面板首次出现时,按需异步预加载本地模型(SenseVoice/FunASR Nano/Paraformer) tryPreloadLocalModel() - // 刷新 UI viewRefs?.btnImeSwitcher?.visibility = View.VISIBLE mainKeyboardBinder?.applyPunctuationLabels() @@ -260,7 +265,6 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene onStateChanged(actionHandler.getCurrentState()) } - // 同步系统栏颜色 rootView?.post { syncSystemBarsToKeyboardBackground(rootView) } @@ -276,7 +280,11 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene clipboardCoordinator?.startClipboardPreviewListener() // 预热耳机路由(键盘显示) - try { BluetoothRouteManager.setImeActive(this, true) } catch (t: Throwable) { android.util.Log.w("AsrKeyboardService", "BluetoothRouteManager setImeActive(true)", t) } + try { + BluetoothRouteManager.setImeActive(this, true) + } catch (t: Throwable) { + android.util.Log.w("AsrKeyboardService", "BluetoothRouteManager setImeActive(true)", t) + } // 自动启动录音(如果开启了设置) if (prefs.autoStartRecordingOnShow) { @@ -294,7 +302,6 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene }, 100) } } - } override fun onUpdateSelection( @@ -303,7 +310,7 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene newSelStart: Int, newSelEnd: Int, candidatesStart: Int, - candidatesEnd: Int + candidatesEnd: Int, ) { super.onUpdateSelection(oldSelStart, oldSelEnd, newSelStart, newSelEnd, candidatesStart, candidatesEnd) aiEditPanelController?.onSelectionChanged(newSelStart, newSelEnd) @@ -321,7 +328,11 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene resetPanelsToMainKeyboard() // 键盘收起,解除预热(若未在录音) - try { BluetoothRouteManager.setImeActive(this, false) } catch (t: Throwable) { android.util.Log.w("AsrKeyboardService", "BluetoothRouteManager setImeActive(false)", t) } + try { + BluetoothRouteManager.setImeActive(this, false) + } catch (t: Throwable) { + android.util.Log.w("AsrKeyboardService", "BluetoothRouteManager setImeActive(false)", t) + } // 如开启:键盘收起后自动切回上一个输入法 if (prefs.returnPrevImeOnHide) { @@ -678,7 +689,7 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene private fun hasRecordAudioPermission(): Boolean { return ContextCompat.checkSelfPermission( this, - Manifest.permission.RECORD_AUDIO + Manifest.permission.RECORD_AUDIO, ) == PackageManager.PERMISSION_GRANTED } @@ -856,7 +867,10 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene else -> false } if (!enabled) return - if (com.brycewg.asrkb.asr.isLocalAsrPrepared(p)) { localPreloadTriggered = true; return } + if (com.brycewg.asrkb.asr.isLocalAsrPrepared(p)) { + localPreloadTriggered = true + return + } // 信息栏显示"加载中…",完成后回退状态 rootView?.post { @@ -883,7 +897,7 @@ class AsrKeyboardService : InputMethodService(), KeyboardActionHandler.UiListene }, 1200) } }, - suppressToastOnStart = true + suppressToastOnStart = true, ) } } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/AsrSessionManager.kt b/app/src/main/java/com/brycewg/asrkb/ime/AsrSessionManager.kt index 92ef668e..0bfa367e 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/AsrSessionManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/AsrSessionManager.kt @@ -1,15 +1,47 @@ package com.brycewg.asrkb.ime import android.content.Context -import android.util.Log -import android.os.SystemClock +import android.media.AudioAttributes +import android.media.AudioFocusRequest +import android.media.AudioManager import android.os.Handler import android.os.Looper -import android.media.AudioManager -import android.media.AudioFocusRequest -import android.media.AudioAttributes -import com.brycewg.asrkb.asr.* +import android.os.SystemClock +import android.util.Log +import com.brycewg.asrkb.asr.AsrErrorMessageMapper +import com.brycewg.asrkb.asr.AsrVendor +import com.brycewg.asrkb.asr.BaseFileAsrEngine import com.brycewg.asrkb.asr.BluetoothRouteManager +import com.brycewg.asrkb.asr.DashscopeFileAsrEngine +import com.brycewg.asrkb.asr.DashscopeStreamAsrEngine +import com.brycewg.asrkb.asr.ElevenLabsFileAsrEngine +import com.brycewg.asrkb.asr.ElevenLabsStreamAsrEngine +import com.brycewg.asrkb.asr.FunAsrNanoFileAsrEngine +import com.brycewg.asrkb.asr.GeminiFileAsrEngine +import com.brycewg.asrkb.asr.LOCAL_MODEL_READY_WAIT_MAX_MS +import com.brycewg.asrkb.asr.OpenAiFileAsrEngine +import com.brycewg.asrkb.asr.ParaformerStreamAsrEngine +import com.brycewg.asrkb.asr.ParallelAsrEngine +import com.brycewg.asrkb.asr.SenseVoiceFileAsrEngine +import com.brycewg.asrkb.asr.SenseVoicePseudoStreamAsrEngine +import com.brycewg.asrkb.asr.SiliconFlowFileAsrEngine +import com.brycewg.asrkb.asr.SonioxFileAsrEngine +import com.brycewg.asrkb.asr.SonioxStreamAsrEngine +import com.brycewg.asrkb.asr.StreamingAsrEngine +import com.brycewg.asrkb.asr.TelespeechFileAsrEngine +import com.brycewg.asrkb.asr.TelespeechPseudoStreamAsrEngine +import com.brycewg.asrkb.asr.VolcFileAsrEngine +import com.brycewg.asrkb.asr.VolcStandardFileAsrEngine +import com.brycewg.asrkb.asr.VolcStreamAsrEngine +import com.brycewg.asrkb.asr.ZhipuFileAsrEngine +import com.brycewg.asrkb.asr.awaitLocalAsrReady +import com.brycewg.asrkb.asr.isFunAsrNanoPrepared +import com.brycewg.asrkb.asr.isLocalAsrReady +import com.brycewg.asrkb.asr.isLocalAsrVendor +import com.brycewg.asrkb.asr.isSenseVoicePrepared +import com.brycewg.asrkb.asr.isTelespeechPrepared +import com.brycewg.asrkb.asr.preloadFunAsrNanoIfConfigured +import com.brycewg.asrkb.asr.preloadTelespeechIfConfigured import com.brycewg.asrkb.store.Prefs import com.brycewg.asrkb.store.debug.DebugLogManager import kotlinx.coroutines.CoroutineScope @@ -31,7 +63,7 @@ import java.util.concurrent.atomic.AtomicLong class AsrSessionManager( private val context: Context, private val scope: CoroutineScope, - private val prefs: Prefs + private val prefs: Prefs, ) : StreamingAsrEngine.Listener, SenseVoiceFileAsrEngine.LocalModelLoadUi { companion object { @@ -95,7 +127,11 @@ class AsrSessionManager( private var lastRequestDurationMs: Long? = null // 统计/历史:本次会话主/备供应商快照(避免设置变更导致 vendorId 串台) - private var sessionPrimaryVendor: AsrVendor = try { prefs.asrVendor } catch (_: Throwable) { AsrVendor.Volc } + private var sessionPrimaryVendor: AsrVendor = try { + prefs.asrVendor + } catch (_: Throwable) { + AsrVendor.Volc + } private var lastFinalVendorForStats: AsrVendor? = null // 本地模型:Processing 阶段等待“模型就绪”的耗时(用于将处理耗时统计从模型就绪开始) @@ -107,6 +143,7 @@ class AsrSessionManager( // 会话录音时长统计(毫秒) private var sessionStartUptimeMs: Long = 0L private var lastAudioMsForStats: Long = 0L + // 统计/历史:端到端耗时起点(从开始录音到最终提交完成) private var sessionStartTotalUptimeMs: Long = 0L @@ -160,7 +197,7 @@ class AsrSessionManager( listener = this, primaryVendor = primaryVendor, backupVendor = backupVendor, - onPrimaryRequestDuration = ::onRequestDuration + onPrimaryRequestDuration = ::onRequestDuration, ) } return when (prefs.asrVendor) { @@ -174,11 +211,15 @@ class AsrSessionManager( VolcFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) } } - } else null + } else { + null + } AsrVendor.SiliconFlow -> if (prefs.hasSfKeys()) { SiliconFlowFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) - } else null + } else { + null + } AsrVendor.ElevenLabs -> if (prefs.hasElevenKeys()) { if (prefs.elevenStreamingEnabled) { @@ -186,11 +227,15 @@ class AsrSessionManager( } else { ElevenLabsFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) } - } else null + } else { + null + } AsrVendor.OpenAI -> if (prefs.hasOpenAiKeys()) { OpenAiFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) - } else null + } else { + null + } AsrVendor.DashScope -> if (prefs.hasDashKeys()) { if (prefs.isDashStreamingModelSelected()) { @@ -198,11 +243,15 @@ class AsrSessionManager( } else { DashscopeFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) } - } else null + } else { + null + } AsrVendor.Gemini -> if (prefs.hasGeminiKeys()) { GeminiFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) - } else null + } else { + null + } AsrVendor.Soniox -> if (prefs.hasSonioxKeys()) { if (prefs.sonioxStreamingEnabled) { @@ -210,11 +259,15 @@ class AsrSessionManager( } else { SonioxFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) } - } else null + } else { + null + } AsrVendor.Zhipu -> if (prefs.hasZhipuKeys()) { ZhipuFileAsrEngine(context, scope, prefs, this, ::onRequestDuration) - } else null + } else { + null + } AsrVendor.SenseVoice -> { if (prefs.svPseudoStreamEnabled) { @@ -245,7 +298,11 @@ class AsrSessionManager( } private fun shouldUseBackupAsr(primaryVendor: AsrVendor, backupVendor: AsrVendor): Boolean { - val enabled = try { prefs.backupAsrEnabled } catch (_: Throwable) { false } + val enabled = try { + prefs.backupAsrEnabled + } catch (_: Throwable) { + false + } if (!enabled) return false if (backupVendor == primaryVendor) return false return try { @@ -273,7 +330,9 @@ class AsrSessionManager( val matched = when (current) { is ParallelAsrEngine -> if (current.primaryVendor == primaryVendor && current.backupVendor == backupVendor) { current - } else null + } else { + null + } else -> null } val engine = matched ?: buildEngine() @@ -373,7 +432,9 @@ class AsrSessionManager( Log.w(TAG, "Cancel local model wait job failed on startRecording", t) } localModelReadyWaitJob = null - try { sessionStartUptimeMs = SystemClock.uptimeMillis() } catch (t: Throwable) { + try { + sessionStartUptimeMs = SystemClock.uptimeMillis() + } catch (t: Throwable) { Log.w(TAG, "Failed to get uptime for session start", t) sessionStartUptimeMs = 0L } @@ -391,8 +452,8 @@ class AsrSessionManager( "vendor" to prefs.asrVendor.name, "engine" to (eng?.javaClass?.simpleName ?: "null"), "state" to state::class.java.simpleName, - "duckMedia" to prefs.duckMediaOnRecordEnabled - ) + "duckMedia" to prefs.duckMediaOnRecordEnabled, + ), ) } catch (_: Throwable) { } // 开始录音前根据设置决定是否请求短时独占音频焦点(音频避让) @@ -428,7 +489,7 @@ class AsrSessionManager( onLoadStart = { onLocalModelLoadStart() }, onLoadDone = { onLocalModelLoadDone() }, suppressToastOnStart = true, - forImmediateUse = true + forImmediateUse = true, ) AsrVendor.Telespeech -> com.brycewg.asrkb.asr.preloadTelespeechIfConfigured( context, @@ -436,7 +497,7 @@ class AsrSessionManager( onLoadStart = { onLocalModelLoadStart() }, onLoadDone = { onLocalModelLoadDone() }, suppressToastOnStart = true, - forImmediateUse = true + forImmediateUse = true, ) else -> com.brycewg.asrkb.asr.preloadSenseVoiceIfConfigured( context, @@ -444,7 +505,7 @@ class AsrSessionManager( onLoadStart = { onLocalModelLoadStart() }, onLoadDone = { onLocalModelLoadDone() }, suppressToastOnStart = true, - forImmediateUse = true + forImmediateUse = true, ) } } catch (t: Throwable) { @@ -462,12 +523,16 @@ class AsrSessionManager( event = "start_state", data = mapOf( "engine" to (asrEngine?.javaClass?.simpleName ?: "null"), - "running" to (asrEngine?.isRunning == true) - ) + "running" to (asrEngine?.isRunning == true), + ), ) } catch (_: Throwable) { } // 录音期间保持耳机路由 - try { BluetoothRouteManager.onRecordingStarted(context) } catch (t: Throwable) { Log.w(TAG, "BluetoothRouteManager onRecordingStarted", t) } + try { + BluetoothRouteManager.onRecordingStarted(context) + } catch (t: Throwable) { + Log.w(TAG, "BluetoothRouteManager onRecordingStarted", t) + } } /** @@ -483,8 +548,8 @@ class AsrSessionManager( event = "stop", data = mapOf( "state" to currentState::class.java.simpleName, - "engineRunning" to (asrEngine?.isRunning == true) - ) + "engineRunning" to (asrEngine?.isRunning == true), + ), ) } catch (_: Throwable) { } // 归还音频焦点 @@ -494,7 +559,11 @@ class AsrSessionManager( Log.w(TAG, "abandonAudioFocusIfNeeded failed on stopRecording", t) } // 若无键盘可见,录音结束后可撤销预热 - try { BluetoothRouteManager.onRecordingStopped(context) } catch (t: Throwable) { Log.w(TAG, "BluetoothRouteManager onRecordingStopped", t) } + try { + BluetoothRouteManager.onRecordingStopped(context) + } catch (t: Throwable) { + Log.w(TAG, "BluetoothRouteManager onRecordingStopped", t) + } } /** @@ -628,8 +697,8 @@ class AsrSessionManager( event = "final", data = mapOf( "len" to text.length, - "state" to currentState::class.java.simpleName - ) + "state" to currentState::class.java.simpleName, + ), ) } catch (_: Throwable) { } listener?.onAsrFinal(text, currentState) @@ -660,8 +729,8 @@ class AsrSessionManager( event = "error", data = mapOf( "state" to currentState::class.java.simpleName, - "msgType" to if (friendlyMessage != null) "friendly" else "raw" - ) + "msgType" to if (friendlyMessage != null) "friendly" else "raw", + ), ) } catch (_: Throwable) { } listener?.onAsrError(friendlyMessage ?: message) @@ -696,8 +765,8 @@ class AsrSessionManager( event = "stopped", data = mapOf( "audioMs" to ms, - "state" to currentState::class.java.simpleName - ) + "state" to currentState::class.java.simpleName, + ), ) } catch (_: Throwable) { } listener?.onAsrStopped() @@ -713,7 +782,9 @@ class AsrSessionManager( Log.d(TAG, "onLocalModelLoadStart") try { Handler(Looper.getMainLooper()).post { - try { listener?.onLocalModelLoadStart() } catch (t: Throwable) { + try { + listener?.onLocalModelLoadStart() + } catch (t: Throwable) { Log.e(TAG, "Failed to deliver onLocalModelLoadStart to UI", t) } } @@ -726,7 +797,9 @@ class AsrSessionManager( Log.d(TAG, "onLocalModelLoadDone") try { Handler(Looper.getMainLooper()).post { - try { listener?.onLocalModelLoadDone() } catch (t: Throwable) { + try { + listener?.onLocalModelLoadDone() + } catch (t: Throwable) { Log.e(TAG, "Failed to deliver onLocalModelLoadDone to UI", t) } } @@ -738,14 +811,18 @@ class AsrSessionManager( // ========== 私有方法 ========== private fun markLocalModelProcessingStartIfNeeded() { - val vendor = try { prefs.asrVendor } catch (t: Throwable) { + val vendor = try { + prefs.asrVendor + } catch (t: Throwable) { Log.w(TAG, "Failed to read asrVendor for local model timing", t) return } if (!isLocalAsrVendor(vendor)) return if (localModelWaitStartUptimeMs != 0L) return - val startMs = try { SystemClock.uptimeMillis() } catch (t: Throwable) { + val startMs = try { + SystemClock.uptimeMillis() + } catch (t: Throwable) { Log.w(TAG, "Failed to read uptime for local model timing", t) 0L } @@ -765,7 +842,11 @@ class AsrSessionManager( val ok = awaitLocalAsrReady(prefs, maxWaitMs = LOCAL_MODEL_READY_WAIT_MAX_MS) if (!ok) return@launch if (sessionSeq != seq) return@launch - val readyAt = try { SystemClock.uptimeMillis() } catch (_: Throwable) { 0L } + val readyAt = try { + SystemClock.uptimeMillis() + } catch (_: Throwable) { + 0L + } if (readyAt > 0L && startMs > 0L && readyAt >= startMs) { localModelReadyWaitMs.compareAndSet(0L, (readyAt - startMs).coerceAtLeast(0L)) } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/BackspaceGestureHandler.kt b/app/src/main/java/com/brycewg/asrkb/ime/BackspaceGestureHandler.kt index 5c4d802a..89cb7090 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/BackspaceGestureHandler.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/BackspaceGestureHandler.kt @@ -17,7 +17,7 @@ import kotlin.math.abs * - 清空后向下滑动:恢复清空前的文本 */ class BackspaceGestureHandler( - private val inputHelper: InputConnectionHelper + private val inputHelper: InputConnectionHelper, ) { // 回调接口 interface Listener { @@ -34,9 +34,11 @@ class BackspaceGestureHandler( private var startY: Float = 0f private var isPressed: Boolean = false private var clearedInGesture: Boolean = false + // 单次手势内去抖标记:避免在 MOVE 过程中重复触发撤销/恢复 private var undoTriggeredInGesture: Boolean = false private var restoredAfterClearInGesture: Boolean = false + // 清空发生时的参考坐标:用于在同一手势中判定“清空后向下滑动恢复” private var clearRefX: Float = 0f private var clearRefY: Float = 0f @@ -80,8 +82,8 @@ class BackspaceGestureHandler( // 向上/向左滑动:清空所有文本(要求方向占优,减少误触) if (!clearedInGesture && ( (dy <= -slop && absDy >= absDx) || - (dx <= -slop && absDx >= absDy) - ) + (dx <= -slop && absDx >= absDy) + ) ) { // 记录清空时的参考坐标(用于后续“清空后下滑恢复”的判定) clearRefX = event.x @@ -115,7 +117,7 @@ class BackspaceGestureHandler( val dx = event.x - startX val dy = event.y - startY val isTap = abs(dx) < slop && abs(dy) < slop && - !clearedInGesture && !longPressStarted + !clearedInGesture && !longPressStarted onActionUp(view, isTap, event.actionMasked == MotionEvent.ACTION_UP) return true @@ -171,7 +173,11 @@ class BackspaceGestureHandler( listener?.onVibrateRequest() // 离开按压态 - try { view.isPressed = false } catch (e: Throwable) { android.util.Log.w("BackspaceGestureHandler", "Failed to set pressed=false (clear)", e) } + try { + view.isPressed = false + } catch (e: Throwable) { + android.util.Log.w("BackspaceGestureHandler", "Failed to set pressed=false (clear)", e) + } } private fun onSwipeToUndo(view: View) { @@ -183,7 +189,11 @@ class BackspaceGestureHandler( listener?.onUndo() // 离开按压态 - try { view.isPressed = false } catch (e: Throwable) { android.util.Log.w("BackspaceGestureHandler", "Failed to set pressed=false (undo)", e) } + try { + view.isPressed = false + } catch (e: Throwable) { + android.util.Log.w("BackspaceGestureHandler", "Failed to set pressed=false (undo)", e) + } } private fun onRestoreAfterClear(view: View, ic: InputConnection) { @@ -243,7 +253,10 @@ class BackspaceGestureHandler( repeatRunnable = null // 释放按压态 - try { view.isPressed = false } catch (e: Throwable) { android.util.Log.w("BackspaceGestureHandler", "Failed to set pressed=false (up)", e) } + try { + view.isPressed = false + } catch (e: Throwable) { + android.util.Log.w("BackspaceGestureHandler", "Failed to set pressed=false (up)", e) + } } - } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelAdapter.kt b/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelAdapter.kt index c4a9ca6f..17aff392 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelAdapter.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelAdapter.kt @@ -11,7 +11,7 @@ import com.brycewg.asrkb.R import com.brycewg.asrkb.clipboard.ClipboardHistoryStore class ClipboardPanelAdapter( - private val onItemClick: (ClipboardHistoryStore.Entry) -> Unit + private val onItemClick: (ClipboardHistoryStore.Entry) -> Unit, ) : ListAdapter(DIFF) { companion object { @@ -40,7 +40,7 @@ class ClipboardPanelAdapter( fun bind( e: ClipboardHistoryStore.Entry, - onClick: (ClipboardHistoryStore.Entry) -> Unit + onClick: (ClipboardHistoryStore.Entry) -> Unit, ) { // 文本与文件统一使用 Entry 自带的展示文案,文件为「EXT-名称」形式 tv.text = e.getDisplayLabel() @@ -49,4 +49,3 @@ class ClipboardPanelAdapter( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelController.kt b/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelController.kt index 7c004c58..52ab1ad9 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelController.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ClipboardPanelController.kt @@ -127,7 +127,7 @@ internal class ClipboardPanelController( override fun onMove( recyclerView: RecyclerView, viewHolder: RecyclerView.ViewHolder, - target: RecyclerView.ViewHolder + target: RecyclerView.ViewHolder, ): Boolean = false override fun getSwipeDirs(recyclerView: RecyclerView, viewHolder: RecyclerView.ViewHolder): Int { @@ -190,4 +190,3 @@ internal class ClipboardPanelController( showPopupMenuKeepingIme(popup) } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/DictationUseCase.kt b/app/src/main/java/com/brycewg/asrkb/ime/DictationUseCase.kt index a32e987e..aefb8536 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/DictationUseCase.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/DictationUseCase.kt @@ -23,13 +23,13 @@ internal class DictationUseCase( private val transitionToState: (KeyboardState) -> Unit, private val transitionToIdle: (keepMessage: Boolean) -> Unit, private val transitionToIdleWithTiming: (showBackupUsedHint: Boolean) -> Unit, - private val scheduleProcessingTimeout: (audioMsOverride: Long?) -> Unit + private val scheduleProcessingTimeout: (audioMsOverride: Long?) -> Unit, ) { suspend fun handleFinal( ic: InputConnection, text: String, state: KeyboardState.Listening, - seq: Long + seq: Long, ) { if (prefs.postProcessEnabled && prefs.hasLlmKeys()) { handleWithPostprocess(ic, text, state, seq) @@ -42,7 +42,7 @@ internal class DictationUseCase( ic: InputConnection, text: String, state: KeyboardState.Listening, - seq: Long + seq: Long, ) { if (isCancelled(seq)) return @@ -56,7 +56,7 @@ internal class DictationUseCase( onFinalReady = { processingTimeoutController.cancel() }, onPostprocFailed = { uiListenerProvider()?.onStatusMessage(context.getString(R.string.status_llm_failed_used_raw)) - } + }, ) ?: return if (isCancelled(seq)) return @@ -86,7 +86,7 @@ internal class DictationUseCase( PostprocCommit(finalOut, rawText) } else { null - } + }, ) } @@ -94,7 +94,7 @@ internal class DictationUseCase( text = finalOut, aiProcessed = aiUsed, aiPostMs = aiPostMs, - aiPostStatus = aiPostStatus + aiPostStatus = aiPostStatus, ) uiListenerProvider()?.onVibrate() @@ -123,7 +123,7 @@ internal class DictationUseCase( ic: InputConnection, text: String, state: KeyboardState.Listening, - seq: Long + seq: Long, ) { val finalToCommit = com.brycewg.asrkb.util.AsrFinalFilters.applySimple(context, prefs, text) @@ -164,7 +164,7 @@ internal class DictationUseCase( updateSessionContext { prev -> prev.copy( lastAsrCommitText = finalToCommit, - lastPostprocCommit = null + lastPostprocCommit = null, ) } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonAction.kt b/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonAction.kt index d0b24e4e..79b71775 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonAction.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonAction.kt @@ -8,7 +8,7 @@ import com.brycewg.asrkb.R enum class ExtensionButtonAction( val id: String, val titleResId: Int, - val iconResId: Int + val iconResId: Int, ) { /** * 禁用(不显示按钮或显示为灰色) @@ -16,7 +16,7 @@ enum class ExtensionButtonAction( NONE( id = "none", titleResId = R.string.ext_btn_none, - iconResId = R.drawable.dots_three_outline + iconResId = R.drawable.dots_three_outline, ), /** @@ -25,7 +25,7 @@ enum class ExtensionButtonAction( SELECT( id = "select", titleResId = R.string.ext_btn_select, - iconResId = R.drawable.selection_toggle + iconResId = R.drawable.selection_toggle, ), /** @@ -34,7 +34,7 @@ enum class ExtensionButtonAction( SELECT_ALL( id = "select_all", titleResId = R.string.ext_btn_select_all, - iconResId = R.drawable.selection_all_toggle + iconResId = R.drawable.selection_all_toggle, ), /** @@ -43,7 +43,7 @@ enum class ExtensionButtonAction( COPY( id = "copy", titleResId = R.string.ext_btn_copy, - iconResId = R.drawable.copy_toggle + iconResId = R.drawable.copy_toggle, ), /** @@ -52,7 +52,7 @@ enum class ExtensionButtonAction( PASTE( id = "paste", titleResId = R.string.ext_btn_paste, - iconResId = R.drawable.selection_background_toggle + iconResId = R.drawable.selection_background_toggle, ), /** @@ -61,7 +61,7 @@ enum class ExtensionButtonAction( CLIPBOARD( id = "clipboard", titleResId = R.string.ext_btn_clipboard, - iconResId = R.drawable.clipboard_toggle + iconResId = R.drawable.clipboard_toggle, ), /** @@ -70,7 +70,7 @@ enum class ExtensionButtonAction( HIDE_KEYBOARD( id = "hide_keyboard", titleResId = R.string.ext_btn_hide, - iconResId = R.drawable.caret_circle_down_toggle + iconResId = R.drawable.caret_circle_down_toggle, ), /** @@ -79,17 +79,16 @@ enum class ExtensionButtonAction( SILENCE_AUTOSTOP_TOGGLE( id = "silence_autostop_toggle", titleResId = R.string.ext_btn_silence_autostop, - iconResId = R.drawable.hand_palm + iconResId = R.drawable.hand_palm, ), - /** * 光标左移一位(长按连发) */ CURSOR_LEFT( id = "cursor_left", titleResId = R.string.ext_btn_cursor_left, - iconResId = R.drawable.arrow_left_toggle + iconResId = R.drawable.arrow_left_toggle, ), /** @@ -98,7 +97,7 @@ enum class ExtensionButtonAction( CURSOR_RIGHT( id = "cursor_right", titleResId = R.string.ext_btn_cursor_right, - iconResId = R.drawable.arrow_right_toggle + iconResId = R.drawable.arrow_right_toggle, ), /** @@ -107,7 +106,7 @@ enum class ExtensionButtonAction( MOVE_START( id = "move_start", titleResId = R.string.ext_btn_move_start, - iconResId = R.drawable.arrow_line_left_toggle + iconResId = R.drawable.arrow_line_left_toggle, ), /** @@ -116,7 +115,7 @@ enum class ExtensionButtonAction( MOVE_END( id = "move_end", titleResId = R.string.ext_btn_move_end, - iconResId = R.drawable.arrow_line_right_toggle + iconResId = R.drawable.arrow_line_right_toggle, ), /** @@ -125,7 +124,7 @@ enum class ExtensionButtonAction( NUMPAD( id = "numpad", titleResId = R.string.ext_btn_numpad, - iconResId = R.drawable.numpad_toggle + iconResId = R.drawable.numpad_toggle, ), /** @@ -134,8 +133,9 @@ enum class ExtensionButtonAction( UNDO( id = "undo", titleResId = R.string.ext_btn_undo, - iconResId = R.drawable.arrow_u_up_left_toggle - ); + iconResId = R.drawable.arrow_u_up_left_toggle, + ), + ; companion object { /** @@ -154,4 +154,3 @@ enum class ExtensionButtonAction( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonActionDispatcher.kt b/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonActionDispatcher.kt index 417d9a18..23956711 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonActionDispatcher.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ExtensionButtonActionDispatcher.kt @@ -12,11 +12,11 @@ internal class ExtensionButtonActionDispatcher( private val inputHelper: InputConnectionHelper, private val uiListenerProvider: () -> KeyboardActionHandler.UiListener?, private val handleUndo: (InputConnection) -> Boolean, - private val logTag: String + private val logTag: String, ) { fun dispatch( action: ExtensionButtonAction, - ic: InputConnection? + ic: InputConnection?, ): KeyboardActionHandler.ExtensionButtonActionResult { return when (action) { ExtensionButtonAction.NONE -> KeyboardActionHandler.ExtensionButtonActionResult.SUCCESS @@ -122,4 +122,3 @@ internal class ExtensionButtonActionDispatcher( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ImeClipboardCoordinator.kt b/app/src/main/java/com/brycewg/asrkb/ime/ImeClipboardCoordinator.kt index 6e0c4929..901d0c74 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ImeClipboardCoordinator.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ImeClipboardCoordinator.kt @@ -30,6 +30,7 @@ internal class ImeClipboardCoordinator( ) { private var clipboardManager: ClipboardManager? = null private var clipboardChangeListener: ClipboardManager.OnPrimaryClipChangedListener? = null + @Volatile private var lastShownClipboardHash: String? = null private var syncClipboardManager: SyncClipboardManager? = null @@ -79,7 +80,7 @@ internal class ImeClipboardCoordinator( } } }, - clipStoreProvider() + clipStoreProvider(), ) } syncClipboardManager?.start() @@ -109,7 +110,7 @@ internal class ImeClipboardCoordinator( Toast.makeText( context, context.getString(R.string.clip_file_download_success), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() // 刷新列表显示下载完成状态 if (isClipboardPanelVisible()) { @@ -119,7 +120,7 @@ internal class ImeClipboardCoordinator( Toast.makeText( context, context.getString(R.string.clip_file_download_failed), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() // 刷新列表显示失败状态 if (isClipboardPanelVisible()) { @@ -133,7 +134,7 @@ internal class ImeClipboardCoordinator( Toast.makeText( context, context.getString(R.string.clip_file_download_error, e.message ?: ""), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() } } @@ -153,7 +154,7 @@ internal class ImeClipboardCoordinator( Toast.makeText( context, context.getString(R.string.clip_file_not_found), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() return } @@ -161,7 +162,7 @@ internal class ImeClipboardCoordinator( val uri = FileProvider.getUriForFile( context, "${context.packageName}.fileprovider", - file + file, ) val intent = Intent(Intent.ACTION_VIEW).apply { @@ -183,7 +184,7 @@ internal class ImeClipboardCoordinator( context.startActivity( Intent.createChooser(shareIntent, context.getString(R.string.clip_file_open_chooser_title)).apply { addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) - } + }, ) } } catch (e: Exception) { @@ -191,7 +192,7 @@ internal class ImeClipboardCoordinator( Toast.makeText( context, context.getString(R.string.clip_file_open_failed, e.message ?: ""), - Toast.LENGTH_SHORT + Toast.LENGTH_SHORT, ).show() } } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ImeExtensionButtonsController.kt b/app/src/main/java/com/brycewg/asrkb/ime/ImeExtensionButtonsController.kt index d0b8325c..63c2da19 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ImeExtensionButtonsController.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ImeExtensionButtonsController.kt @@ -104,7 +104,8 @@ internal class ImeExtensionButtonsController( } KeyboardActionHandler.ExtensionButtonActionResult.NEED_CURSOR_LEFT, - KeyboardActionHandler.ExtensionButtonActionResult.NEED_CURSOR_RIGHT -> { + KeyboardActionHandler.ExtensionButtonActionResult.NEED_CURSOR_RIGHT, + -> { // 光标移动已在长按处理中完成 } @@ -176,4 +177,3 @@ internal class ImeExtensionButtonsController( updateBtn(views.btnExt4, prefs.extBtn4) } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ImeLayoutController.kt b/app/src/main/java/com/brycewg/asrkb/ime/ImeLayoutController.kt index ef2a6bed..4ca086b3 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ImeLayoutController.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ImeLayoutController.kt @@ -205,7 +205,7 @@ internal class ImeLayoutController( R.id.btnAiPanelMoveEnd, // 剪贴板面板按钮 R.id.clip_btnBack, - R.id.clip_btnDelete + R.id.clip_btnDelete, ) ids40.forEach { scaleSquareButton(it) } scaleGestureButton(refs?.btnGestureCancel) @@ -417,8 +417,8 @@ internal class ImeLayoutController( "needsColdStartFix" to needsColdStartFix, "needsOverlapFix" to needsOverlapFix, "correctionThresholdPx" to correctionThresholdPx, - "afterTop" to top - ) + "afterTop" to top, + ), ) } } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ImeMainKeyboardBinder.kt b/app/src/main/java/com/brycewg/asrkb/ime/ImeMainKeyboardBinder.kt index 5f1e9f77..fb74d39e 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ImeMainKeyboardBinder.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ImeMainKeyboardBinder.kt @@ -159,8 +159,8 @@ internal class ImeMainKeyboardBinder( views.btnPunct2?.setOnTouchListener( createSwipeUpToAltListener( primary = { prefs.punct1 }, - secondary = { prefs.punct2 } - ) + secondary = { prefs.punct2 }, + ), ) // 右侧合并标点键(3/4) @@ -171,8 +171,8 @@ internal class ImeMainKeyboardBinder( views.btnPunct3?.setOnTouchListener( createSwipeUpToAltListener( primary = { prefs.punct3 }, - secondary = { prefs.punct4 } - ) + secondary = { prefs.punct4 }, + ), ) // 第四个按键:供应商切换按钮(样式与 Prompt 选择类似) @@ -218,7 +218,7 @@ internal class ImeMainKeyboardBinder( */ private fun createSwipeUpToAltListener( primary: () -> String, - secondary: () -> String + secondary: () -> String, ): View.OnTouchListener { val touchSlop = ViewConfiguration.get(context).scaledTouchSlop val thresholdPx = (24f * context.resources.displayMetrics.density).toInt().coerceAtLeast(touchSlop) @@ -249,4 +249,3 @@ internal class ImeMainKeyboardBinder( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ImeThemeStyler.kt b/app/src/main/java/com/brycewg/asrkb/ime/ImeThemeStyler.kt index 2217d299..dea81f65 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ImeThemeStyler.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ImeThemeStyler.kt @@ -4,8 +4,8 @@ import android.content.Context import android.view.View import android.view.Window import androidx.core.graphics.ColorUtils -import androidx.core.view.WindowInsetsControllerCompat import androidx.core.view.ViewCompat +import androidx.core.view.WindowInsetsControllerCompat import com.brycewg.asrkb.UiColorTokens import com.brycewg.asrkb.UiColors @@ -34,7 +34,7 @@ internal class ImeThemeStyler { fun installKeyboardInsetsListener( rootView: View, - onSystemBarsBottomInsetChanged: (bottom: Int) -> Unit + onSystemBarsBottomInsetChanged: (bottom: Int) -> Unit, ) { ViewCompat.setOnApplyWindowInsetsListener(rootView) { _, windowInsets -> val bottom = ImeInsetsResolver.resolveBottomInset(windowInsets, rootView.resources) diff --git a/app/src/main/java/com/brycewg/asrkb/ime/InputConnectionHelper.kt b/app/src/main/java/com/brycewg/asrkb/ime/InputConnectionHelper.kt index b93ae6ad..7f87dc87 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/InputConnectionHelper.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/InputConnectionHelper.kt @@ -12,7 +12,7 @@ import android.view.inputmethod.InputConnection * 为所有操作添加详细日志,以便在特定应用中功能静默失败时能够快速定位问题。 */ class InputConnectionHelper( - private val tag: String = "InputConnectionHelper" + private val tag: String = "InputConnectionHelper", ) { /** * 提交文本到输入框 @@ -248,7 +248,7 @@ class InputConnectionHelper( val imeOptions = editorInfo?.imeOptions ?: 0 val action = imeOptions and EditorInfo.IME_MASK_ACTION val isMultiLine = (editorInfo?.inputType ?: 0) and - android.text.InputType.TYPE_TEXT_FLAG_MULTI_LINE != 0 + android.text.InputType.TYPE_TEXT_FLAG_MULTI_LINE != 0 val flagNoEnterAction = (imeOptions and EditorInfo.IME_FLAG_NO_ENTER_ACTION) != 0 // 根据 action 类型和输入框特性决定行为 @@ -257,11 +257,11 @@ class InputConnectionHelper( isMultiLine && flagNoEnterAction -> false // 特定的 action 类型需要执行 performEditorAction action == EditorInfo.IME_ACTION_SEND || - action == EditorInfo.IME_ACTION_GO || - action == EditorInfo.IME_ACTION_SEARCH || - action == EditorInfo.IME_ACTION_DONE || - action == EditorInfo.IME_ACTION_NEXT || - action == EditorInfo.IME_ACTION_PREVIOUS -> true + action == EditorInfo.IME_ACTION_GO || + action == EditorInfo.IME_ACTION_SEARCH || + action == EditorInfo.IME_ACTION_DONE || + action == EditorInfo.IME_ACTION_NEXT || + action == EditorInfo.IME_ACTION_PREVIOUS -> true // 其他情况发送普通回车 else -> false } @@ -379,20 +379,18 @@ class InputConnectionHelper( ic.beginBatchEdit() var replaced = false - // 尝试在光标前查找并替换 if (!before.isNullOrEmpty() && before.endsWith(oldText)) { + // 尝试在光标前查找并替换 ic.deleteSurroundingText(oldText.length, 0) ic.commitText(newText, 1) replaced = true - } - // 尝试在光标后查找并替换 - else if (!after.isNullOrEmpty() && after.startsWith(oldText)) { + } else if (!after.isNullOrEmpty() && after.startsWith(oldText)) { + // 尝试在光标后查找并替换 ic.deleteSurroundingText(0, oldText.length) ic.commitText(newText, 1) replaced = true - } - // 尝试在整个上下文中查找 - else if (before != null && after != null) { + } else if (before != null && after != null) { + // 尝试在整个上下文中查找并替换 val combined = before + after val pos = combined.lastIndexOf(oldText) if (pos >= 0) { diff --git a/app/src/main/java/com/brycewg/asrkb/ime/KeyboardActionHandler.kt b/app/src/main/java/com/brycewg/asrkb/ime/KeyboardActionHandler.kt index c474b586..a1f07dc3 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/KeyboardActionHandler.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/KeyboardActionHandler.kt @@ -8,10 +8,10 @@ import com.brycewg.asrkb.R import com.brycewg.asrkb.asr.LlmPostProcessor import com.brycewg.asrkb.asr.VadAutoStopGuard import com.brycewg.asrkb.store.Prefs +import com.brycewg.asrkb.store.debug.DebugLogManager import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.launch import kotlinx.coroutines.delay -import com.brycewg.asrkb.store.debug.DebugLogManager +import kotlinx.coroutines.launch /** * 键盘动作处理器:作为控制器/ViewModel 管理键盘的核心状态和业务逻辑 @@ -30,7 +30,7 @@ class KeyboardActionHandler( private val prefs: Prefs, private val asrManager: AsrSessionManager, private val inputHelper: InputConnectionHelper, - private val llmPostProcessor: LlmPostProcessor + private val llmPostProcessor: LlmPostProcessor, ) : AsrSessionManager.Listener { companion object { @@ -59,6 +59,7 @@ class KeyboardActionHandler( // 强制停止标记:用于忽略上一会话迟到的 onFinal/onStopped private var dropPendingFinal: Boolean = false + // 操作序列号:用于取消在途处理(强制停止/新会话开始都会递增) private var opSeq: Long = 0L @@ -72,14 +73,14 @@ class KeyboardActionHandler( opSeqProvider = { opSeq }, audioMsProvider = { asrManager.peekLastAudioMsForStats() }, usingBackupEngineProvider = { asrManager.getEngine() is com.brycewg.asrkb.asr.ParallelAsrEngine }, - onTimeout = { transitionToIdle() } + onTimeout = { transitionToIdle() }, ) private val commitRecorder = AsrCommitRecorder( context = context, prefs = prefs, asrManager = asrManager, - logTag = TAG + logTag = TAG, ) private val postprocessPipeline = PostprocessPipeline( @@ -88,7 +89,7 @@ class KeyboardActionHandler( prefs = prefs, inputHelper = inputHelper, llmPostProcessor = llmPostProcessor, - logTag = TAG + logTag = TAG, ) private val dictationUseCase = DictationUseCase( @@ -107,7 +108,7 @@ class KeyboardActionHandler( transitionToState = { transitionToState(it) }, transitionToIdle = { keepMessage -> transitionToIdle(keepMessage = keepMessage) }, transitionToIdleWithTiming = { showBackupUsedHint -> transitionToIdleWithTiming(showBackupUsedHint) }, - scheduleProcessingTimeout = { audioMsOverride -> scheduleProcessingTimeout(audioMsOverride) } + scheduleProcessingTimeout = { audioMsOverride -> scheduleProcessingTimeout(audioMsOverride) }, ) private val aiEditUseCase = AiEditUseCase( @@ -124,7 +125,7 @@ class KeyboardActionHandler( updateSessionContext = { transform -> sessionContext = transform(sessionContext) }, transitionToState = { transitionToState(it) }, transitionToIdle = { keepMessage -> transitionToIdle(keepMessage = keepMessage) }, - transitionToIdleWithTiming = { showBackupUsedHint -> transitionToIdleWithTiming(showBackupUsedHint) } + transitionToIdleWithTiming = { showBackupUsedHint -> transitionToIdleWithTiming(showBackupUsedHint) }, ) private val promptApplyUseCase = PromptApplyUseCase( @@ -137,7 +138,7 @@ class KeyboardActionHandler( saveUndoSnapshot = { ic -> saveUndoSnapshot(ic) }, getLastAsrCommitText = { sessionContext.lastAsrCommitText }, updateSessionContext = { transform -> sessionContext = transform(sessionContext) }, - logTag = TAG + logTag = TAG, ) private val extensionButtonDispatcher = ExtensionButtonActionDispatcher( @@ -146,7 +147,7 @@ class KeyboardActionHandler( inputHelper = inputHelper, uiListenerProvider = { uiListener }, handleUndo = { ic -> handleUndo(ic) }, - logTag = TAG + logTag = TAG, ) private val retryUseCase = RetryUseCase( @@ -156,12 +157,14 @@ class KeyboardActionHandler( transitionToState = { transitionToState(it) }, transitionToIdle = { transitionToIdle() }, scheduleProcessingTimeout = { scheduleProcessingTimeout() }, - logTag = TAG + logTag = TAG, ) + // 长按期间的"按住状态"和自动重启计数(用于应对录音被系统提前中断的设备差异) private var micHoldActive: Boolean = false private var micHoldRestartCount: Int = 0 private var autoStopSuppression: AutoCloseable? = null + // 自动启动录音标志:标识当前录音是否由键盘面板自动启动 private var isAutoStartedRecording: Boolean = false @@ -226,25 +229,29 @@ class KeyboardActionHandler( "state" to currentState::class.java.simpleName, "opSeq" to opSeq, "dropPendingFinal" to dropPendingFinal, - "isAutoStarted" to isAutoStartedRecording - ) + "isAutoStarted" to isAutoStartedRecording, + ), ) } catch (_: Throwable) { } when (currentState) { is KeyboardState.Idle -> { // 开始录音 startNormalListening() - try { DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "start_listening", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "start_listening", "opSeq" to opSeq)) + } catch (_: Throwable) { } } is KeyboardState.Listening -> { // 停止录音:如果是自动启动的录音,或者正常的点按模式,都执行停止 // 统一进入 Processing,显示"识别中"直到最终结果(即使未开启后处理) - isAutoStartedRecording = false // 清除自动启动标志 + isAutoStartedRecording = false // 清除自动启动标志 asrManager.stopRecording() transitionToState(KeyboardState.Processing) scheduleProcessingTimeout() uiListener?.onStatusMessage(context.getString(R.string.status_recognizing)) - try { DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "stop_and_process", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "stop_and_process", "opSeq" to opSeq)) + } catch (_: Throwable) { } } is KeyboardState.Processing -> { // 强制停止:立即回到 Idle,并忽略本会话迟到的 onFinal/onStopped @@ -252,12 +259,16 @@ class KeyboardActionHandler( dropPendingFinal = true transitionToIdle(keepMessage = true) uiListener?.onStatusMessage(context.getString(R.string.status_cancelled)) - try { DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "force_stop", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "force_stop", "opSeq" to opSeq)) + } catch (_: Throwable) { } } else -> { // 其他状态忽略 Log.w(TAG, "handleMicTapToggle: ignored in state $currentState") - try { DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "ignored", "state" to currentState::class.java.simpleName)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_tap_action", mapOf("action" to "ignored", "state" to currentState::class.java.simpleName)) + } catch (_: Throwable) { } } } } @@ -277,15 +288,15 @@ class KeyboardActionHandler( "state" to currentState::class.java.simpleName, "opSeq" to opSeq, "dropPendingFinal" to dropPendingFinal, - "isAutoStarted" to isAutoStartedRecording - ) + "isAutoStarted" to isAutoStartedRecording, + ), ) } catch (_: Throwable) { } when (currentState) { is KeyboardState.Idle -> startNormalListening() is KeyboardState.Listening -> { // 如果正在录音(可能是自动启动的),长按应该停止并重新开始 - isAutoStartedRecording = false // 清除自动启动标志 + isAutoStartedRecording = false // 清除自动启动标志 } is KeyboardState.Processing -> { // 强制停止:根据模式决定后续动作 @@ -295,17 +306,23 @@ class KeyboardActionHandler( if (!prefs.micTapToggleEnabled) { // 长按模式:直接开始新一轮录音 startNormalListening() - try { DebugLogManager.log("ime", "mic_down_action", mapOf("action" to "force_stop_and_restart", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_down_action", mapOf("action" to "force_stop_and_restart", "opSeq" to opSeq)) + } catch (_: Throwable) { } } else { // 点按切换模式:仅取消并回到空闲 transitionToIdle(keepMessage = true) uiListener?.onStatusMessage(context.getString(R.string.status_cancelled)) - try { DebugLogManager.log("ime", "mic_down_action", mapOf("action" to "force_stop_to_idle", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_down_action", mapOf("action" to "force_stop_to_idle", "opSeq" to opSeq)) + } catch (_: Throwable) { } } } else -> { Log.w(TAG, "handleMicPressDown: ignored in state $currentState") - try { DebugLogManager.log("ime", "mic_down_action", mapOf("action" to "ignored", "state" to currentState::class.java.simpleName)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_down_action", mapOf("action" to "ignored", "state" to currentState::class.java.simpleName)) + } catch (_: Throwable) { } } } } @@ -364,7 +381,7 @@ class KeyboardActionHandler( autoEnterOnce = autoEnterAfterFinal micHoldActive = false releaseAutoStopSuppression() - isAutoStartedRecording = false // 清除自动启动标志 + isAutoStartedRecording = false // 清除自动启动标志 try { DebugLogManager.log( category = "ime", @@ -372,8 +389,8 @@ class KeyboardActionHandler( data = mapOf( "autoEnter" to autoEnterAfterFinal, "state" to currentState::class.java.simpleName, - "opSeq" to opSeq - ) + "opSeq" to opSeq, + ), ) } catch (_: Throwable) { } if (asrManager.isRunning()) { @@ -382,16 +399,22 @@ class KeyboardActionHandler( transitionToState(KeyboardState.Processing) scheduleProcessingTimeout() uiListener?.onStatusMessage(context.getString(R.string.status_recognizing)) - try { DebugLogManager.log("ime", "mic_up_action", mapOf("action" to "stop_and_process", "autoEnter" to autoEnterAfterFinal, "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_up_action", mapOf("action" to "stop_and_process", "autoEnter" to autoEnterAfterFinal, "opSeq" to opSeq)) + } catch (_: Throwable) { } } else { // 异常:UI 处于 Listening,但引擎未在运行(例如启动失败/被系统打断)。 // 为避免卡住“正在聆听”,直接归位到 Idle 并提示“已取消”。 if (currentState is KeyboardState.Listening || currentState is KeyboardState.AiEditListening) { // 确保释放音频焦点/路由(即使引擎未在运行) - try { asrManager.stopRecording() } catch (_: Throwable) { } + try { + asrManager.stopRecording() + } catch (_: Throwable) { } transitionToIdle(keepMessage = true) uiListener?.onStatusMessage(context.getString(R.string.status_cancelled)) - try { DebugLogManager.log("ime", "mic_up_action", mapOf("action" to "not_running_cancel", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "mic_up_action", mapOf("action" to "not_running_cancel", "opSeq" to opSeq)) + } catch (_: Throwable) { } } } } @@ -459,7 +482,7 @@ class KeyboardActionHandler( */ fun handleExtensionButtonClick( action: com.brycewg.asrkb.ime.ExtensionButtonAction, - ic: InputConnection? + ic: InputConnection?, ): ExtensionButtonActionResult { return extensionButtonDispatcher.dispatch(action, ic) } @@ -468,15 +491,15 @@ class KeyboardActionHandler( * 扩展按钮动作结果 */ enum class ExtensionButtonActionResult { - SUCCESS, // 成功完成 - FAILED, // 失败 - NEED_TOGGLE_SELECTION, // 需要 IME 切换选择模式 - NEED_CURSOR_LEFT, // 需要 IME 处理左移(支持长按) - NEED_CURSOR_RIGHT, // 需要 IME 处理右移(支持长按) - NEED_SHOW_NUMPAD, // 需要 IME 显示数字键盘 - NEED_SHOW_CLIPBOARD, // 需要 IME 显示剪贴板面板 - NEED_HIDE_KEYBOARD, // 需要 IME 收起键盘 - NEED_TOGGLE_CONTINUOUS_TALK // 需要 IME 切换畅说模式 + SUCCESS, // 成功完成 + FAILED, // 失败 + NEED_TOGGLE_SELECTION, // 需要 IME 切换选择模式 + NEED_CURSOR_LEFT, // 需要 IME 处理左移(支持长按) + NEED_CURSOR_RIGHT, // 需要 IME 处理右移(支持长按) + NEED_SHOW_NUMPAD, // 需要 IME 显示数字键盘 + NEED_SHOW_CLIPBOARD, // 需要 IME 显示剪贴板面板 + NEED_HIDE_KEYBOARD, // 需要 IME 收起键盘 + NEED_TOGGLE_CONTINUOUS_TALK, // 需要 IME 切换畅说模式 } /** @@ -536,7 +559,7 @@ class KeyboardActionHandler( fullText = label, displaySnippet = label, type = ClipboardPreviewType.FILE, - fileEntryId = entry.id + fileEntryId = entry.id, ) sessionContext = sessionContext.copy(clipboardPreview = preview) uiListener?.onShowClipboardPreview(preview) @@ -678,51 +701,77 @@ class KeyboardActionHandler( // 若仍在长按且为非点按模式,并且上一轮录音时长极短,则判定为系统提前中断,自动重启一次录音 // 这样用户的“持续按住说话”不会因为系统打断而直接被判定为取消 // 仅用于早停判定:读取但不清空,避免影响后续统计与历史写入 - val earlyMs = try { asrManager.peekLastAudioMsForStats() } catch (t: Throwable) { 0L } + val earlyMs = try { + asrManager.peekLastAudioMsForStats() + } catch (t: Throwable) { + 0L + } if (!prefs.micTapToggleEnabled && micHoldActive && earlyMs in 1..250) { if (micHoldRestartCount < 1) { micHoldRestartCount += 1 - try { DebugLogManager.log("ime", "auto_restart_after_early_stop", mapOf("audioMs" to earlyMs, "count" to micHoldRestartCount, "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "auto_restart_after_early_stop", mapOf("audioMs" to earlyMs, "count" to micHoldRestartCount, "opSeq" to opSeq)) + } catch (_: Throwable) { } startNormalListening() return@launch } else { - try { DebugLogManager.log("ime", "auto_restart_skip", mapOf("audioMs" to earlyMs, "count" to micHoldRestartCount, "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "auto_restart_skip", mapOf("audioMs" to earlyMs, "count" to micHoldRestartCount, "opSeq" to opSeq)) + } catch (_: Throwable) { } } } // 若强制停止,忽略迟到的 onStopped if (dropPendingFinal) return@launch // 若此时已经开始了新的录音(引擎运行中),则将本次 onStopped 视为上一会话的迟到事件并忽略。 if (asrManager.isRunning()) { - try { asrManager.popLastAudioMsForStats() } catch (_: Throwable) { } - try { DebugLogManager.log("ime", "asr_stopped_ignored", mapOf("reason" to "new_session_running", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + asrManager.popLastAudioMsForStats() + } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "asr_stopped_ignored", mapOf("reason" to "new_session_running", "opSeq" to opSeq)) + } catch (_: Throwable) { } return@launch } // 误触极短录音:直接取消,避免进入“识别中…”阻塞后续长按 val audioMs = earlyMs // 若前面 earlyMs==0(例如未知或异常),再尝试一次以兼容既有逻辑 - val audioMsVal = if (audioMs != 0L) audioMs else try { asrManager.peekLastAudioMsForStats() } catch (t: Throwable) { - Log.w(TAG, "popLastAudioMsForStats failed", t) - 0L + val audioMsVal = if (audioMs != 0L) { + audioMs + } else { + try { + asrManager.peekLastAudioMsForStats() + } catch (t: Throwable) { + Log.w(TAG, "popLastAudioMsForStats failed", t) + 0L + } } if (audioMsVal in 1..250) { // 将后续迟到回调丢弃并归位 dropPendingFinal = true transitionToIdle() uiListener?.onStatusMessage(context.getString(R.string.status_cancelled)) - try { DebugLogManager.log("ime", "asr_stopped", mapOf("audioMs" to audioMsVal, "action" to "cancel_short", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "asr_stopped", mapOf("audioMs" to audioMsVal, "action" to "cancel_short", "opSeq" to opSeq)) + } catch (_: Throwable) { } return@launch } // 正常流程:进入 Processing,等待最终结果或兜底 transitionToState(KeyboardState.Processing) scheduleProcessingTimeout(audioMsVal) uiListener?.onStatusMessage(context.getString(R.string.status_recognizing)) - try { DebugLogManager.log("ime", "asr_stopped", mapOf("audioMs" to audioMs, "action" to "enter_processing", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "asr_stopped", mapOf("audioMs" to audioMs, "action" to "enter_processing", "opSeq" to opSeq)) + } catch (_: Throwable) { } } } override fun onLocalModelLoadStart() { // 记录开始时间,用于计算加载耗时 - try { modelLoadStartUptimeMs = android.os.SystemClock.uptimeMillis() } catch (_: Throwable) { modelLoadStartUptimeMs = 0L } + try { + modelLoadStartUptimeMs = android.os.SystemClock.uptimeMillis() + } catch (_: Throwable) { + modelLoadStartUptimeMs = 0L + } val resId = if (currentState is KeyboardState.Listening || currentState is KeyboardState.AiEditListening) { R.string.sv_loading_model_while_listening } else { @@ -735,7 +784,9 @@ class KeyboardActionHandler( val dt = try { val now = android.os.SystemClock.uptimeMillis() if (modelLoadStartUptimeMs > 0L && now >= modelLoadStartUptimeMs) now - modelLoadStartUptimeMs else -1L - } catch (_: Throwable) { -1L } + } catch (_: Throwable) { + -1L + } if (dt > 0) { uiListener?.onStatusMessage(context.getString(R.string.sv_model_ready_with_ms, dt)) } else { @@ -760,19 +811,22 @@ class KeyboardActionHandler( data = mapOf( "from" to prev::class.java.simpleName, "to" to newState::class.java.simpleName, - "opSeq" to opSeq - ) + "opSeq" to opSeq, + ), ) } catch (_: Throwable) { } // 仅在携带文本上下文的状态下同步到 AsrSessionManager, // 避免切到 Processing 后丢失 partialText 影响最终合并 when (newState) { is KeyboardState.Listening, - is KeyboardState.AiEditListening -> asrManager.setCurrentState(newState) + is KeyboardState.AiEditListening, + -> asrManager.setCurrentState(newState) else -> { /* keep previous contextual state in AsrSessionManager */ } } if (newState !is KeyboardState.Idle) { - try { uiListener?.onHideRetryChip() } catch (_: Throwable) {} + try { + uiListener?.onHideRetryChip() + } catch (_: Throwable) {} } uiListener?.onStateChanged(newState) } @@ -780,10 +834,12 @@ class KeyboardActionHandler( private fun transitionToIdle(keepMessage: Boolean = false) { // 新的显式归位:递增操作序列,取消在途处理 opSeq++ - try { DebugLogManager.log("ime", "opseq_inc", mapOf("at" to "to_idle", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "opseq_inc", mapOf("at" to "to_idle", "opSeq" to opSeq)) + } catch (_: Throwable) { } processingTimeoutController.cancel() autoEnterOnce = false - isAutoStartedRecording = false // 清除自动启动标志 + isAutoStartedRecording = false // 清除自动启动标志 transitionToState(KeyboardState.Idle) if (!keepMessage) { uiListener?.onStatusMessage(context.getString(R.string.status_idle)) @@ -793,11 +849,15 @@ class KeyboardActionHandler( private fun startNormalListening() { // 开启新一轮录音:递增操作序列,取消在途处理 opSeq++ - try { DebugLogManager.log("ime", "opseq_inc", mapOf("at" to "start_listening", "opSeq" to opSeq)) } catch (_: Throwable) { } + try { + DebugLogManager.log("ime", "opseq_inc", mapOf("at" to "start_listening", "opSeq" to opSeq)) + } catch (_: Throwable) { } processingTimeoutController.cancel() dropPendingFinal = false autoEnterOnce = false - try { uiListener?.onHideRetryChip() } catch (_: Throwable) {} + try { + uiListener?.onHideRetryChip() + } catch (_: Throwable) {} val state = KeyboardState.Listening() transitionToState(state) asrManager.startRecording(state) @@ -864,7 +924,6 @@ class KeyboardActionHandler( } } - /** * 获取当前输入连接(需要从外部注入) * 这是一个临时方案,实际应该通过参数传递 diff --git a/app/src/main/java/com/brycewg/asrkb/ime/KeyboardState.kt b/app/src/main/java/com/brycewg/asrkb/ime/KeyboardState.kt index 484f9b44..cea14a87 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/KeyboardState.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/KeyboardState.kt @@ -18,7 +18,7 @@ sealed class KeyboardState { data class Listening( val partialText: String? = null, val committedStableLen: Int = 0, - val lockedBySwipe: Boolean = false + val lockedBySwipe: Boolean = false, ) : KeyboardState() /** @@ -41,7 +41,7 @@ sealed class KeyboardState { data class AiEditListening( val targetIsSelection: Boolean, val targetText: String, - val instruction: String? = null + val instruction: String? = null, ) : KeyboardState() /** @@ -53,7 +53,7 @@ sealed class KeyboardState { data class AiEditProcessing( val targetIsSelection: Boolean, val targetText: String, - val instruction: String + val instruction: String, ) : KeyboardState() } @@ -62,7 +62,7 @@ sealed class KeyboardState { */ data class UndoSnapshot( val beforeCursor: CharSequence, - val afterCursor: CharSequence + val afterCursor: CharSequence, ) /** @@ -70,7 +70,7 @@ data class UndoSnapshot( */ data class PostprocCommit( val processed: String, - val raw: String + val raw: String, ) /** @@ -78,7 +78,7 @@ data class PostprocCommit( */ enum class ClipboardPreviewType { TEXT, - FILE + FILE, } /** @@ -88,7 +88,7 @@ data class ClipboardPreview( val fullText: String, val displaySnippet: String, val type: ClipboardPreviewType = ClipboardPreviewType.TEXT, - val fileEntryId: String? = null + val fileEntryId: String? = null, ) /** @@ -99,5 +99,5 @@ data class KeyboardSessionContext( val lastPostprocCommit: PostprocCommit? = null, val undoSnapshot: UndoSnapshot? = null, val lastRequestDurationMs: Long? = null, - val clipboardPreview: ClipboardPreview? = null + val clipboardPreview: ClipboardPreview? = null, ) diff --git a/app/src/main/java/com/brycewg/asrkb/ime/MicGestureController.kt b/app/src/main/java/com/brycewg/asrkb/ime/MicGestureController.kt index 4f15b060..4d1337ef 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/MicGestureController.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/MicGestureController.kt @@ -16,7 +16,10 @@ internal class MicGestureController( private val onLockedBySwipeChanged: () -> Unit, ) { private enum class GestureState { - None, PendingCancel, PendingSend, PendingLock + None, + PendingCancel, + PendingSend, + PendingLock, } private var state: GestureState = GestureState.None @@ -31,8 +34,8 @@ internal class MicGestureController( event = "mic_click_locked", data = mapOf( "state" to actionHandler.getCurrentState()::class.java.simpleName, - "running" to (actionHandler.getCurrentState() is KeyboardState.Listening) - ) + "running" to (actionHandler.getCurrentState() is KeyboardState.Listening), + ), ) actionHandler.handleLockedMicTap() return@setOnClickListener @@ -44,8 +47,8 @@ internal class MicGestureController( "tapToggle" to true, "state" to actionHandler.getCurrentState()::class.java.simpleName, "running" to (actionHandler.getCurrentState() is KeyboardState.Listening), - "aiPanel" to isAiEditPanelVisible() - ) + "aiPanel" to isAiEditPanelVisible(), + ), ) if (isAiEditPanelVisible()) { actionHandler.handleAiEditClick(inputConnectionProvider()) @@ -90,8 +93,8 @@ internal class MicGestureController( data = mapOf( "tapToggle" to false, "aiPanel" to true, - "state" to actionHandler.getCurrentState()::class.java.simpleName - ) + "state" to actionHandler.getCurrentState()::class.java.simpleName, + ), ) v.performClick() return true @@ -101,8 +104,8 @@ internal class MicGestureController( data = mapOf( "tapToggle" to false, "state" to actionHandler.getCurrentState()::class.java.simpleName, - "running" to (actionHandler.getCurrentState() is KeyboardState.Listening) - ) + "running" to (actionHandler.getCurrentState() is KeyboardState.Listening), + ), ) actionHandler.handleAiEditClick(inputConnectionProvider()) return true @@ -113,8 +116,8 @@ internal class MicGestureController( data = mapOf( "tapToggle" to false, "state" to actionHandler.getCurrentState()::class.java.simpleName, - "running" to (actionHandler.getCurrentState() is KeyboardState.Listening) - ) + "running" to (actionHandler.getCurrentState() is KeyboardState.Listening), + ), ) if (actionHandler.getCurrentState() is KeyboardState.AiEditListening) { actionHandler.handleAiEditClick(inputConnectionProvider()) @@ -127,8 +130,8 @@ internal class MicGestureController( event = "ai_mic_cancel", data = mapOf( "tapToggle" to false, - "state" to actionHandler.getCurrentState()::class.java.simpleName - ) + "state" to actionHandler.getCurrentState()::class.java.simpleName, + ), ) if (actionHandler.getCurrentState() is KeyboardState.AiEditListening) { actionHandler.handleAiEditClick(inputConnectionProvider()) @@ -149,8 +152,8 @@ internal class MicGestureController( event = "mic_down_blocked", data = mapOf( "tapToggle" to false, - "state" to actionHandler.getCurrentState()::class.java.simpleName - ) + "state" to actionHandler.getCurrentState()::class.java.simpleName, + ), ) v.performClick() return true @@ -161,8 +164,8 @@ internal class MicGestureController( data = mapOf( "tapToggle" to false, "state" to actionHandler.getCurrentState()::class.java.simpleName, - "running" to (actionHandler.getCurrentState() is KeyboardState.Listening) - ) + "running" to (actionHandler.getCurrentState() is KeyboardState.Listening), + ), ) actionHandler.handleMicPressDown() return true @@ -209,8 +212,8 @@ internal class MicGestureController( data = mapOf( "tapToggle" to false, "state" to actionHandler.getCurrentState()::class.java.simpleName, - "running" to (actionHandler.getCurrentState() is KeyboardState.Listening) - ) + "running" to (actionHandler.getCurrentState() is KeyboardState.Listening), + ), ) actionHandler.handleMicPressUp(false) v.performClick() @@ -224,8 +227,8 @@ internal class MicGestureController( data = mapOf( "tapToggle" to false, "state" to actionHandler.getCurrentState()::class.java.simpleName, - "running" to (actionHandler.getCurrentState() is KeyboardState.Listening) - ) + "running" to (actionHandler.getCurrentState() is KeyboardState.Listening), + ), ) state = GestureState.None updatePressedState(GestureState.None) @@ -262,4 +265,3 @@ internal class MicGestureController( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/PostprocessPipeline.kt b/app/src/main/java/com/brycewg/asrkb/ime/PostprocessPipeline.kt index 65d25dcc..0a9f8f06 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/PostprocessPipeline.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/PostprocessPipeline.kt @@ -18,7 +18,7 @@ internal class PostprocessPipeline( private val prefs: Prefs, private val inputHelper: InputConnectionHelper, private val llmPostProcessor: LlmPostProcessor, - private val logTag: String + private val logTag: String, ) { data class Result( val finalText: String, @@ -26,7 +26,7 @@ internal class PostprocessPipeline( val postprocFailed: Boolean, val aiUsed: Boolean, val aiPostMs: Long, - val aiPostStatus: AsrHistoryStore.AiPostStatus + val aiPostStatus: AsrHistoryStore.AiPostStatus, ) suspend fun process( @@ -34,7 +34,7 @@ internal class PostprocessPipeline( text: String, isCancelled: () -> Boolean, onFinalReady: () -> Unit, - onPostprocFailed: () -> Unit + onPostprocFailed: () -> Unit, ): Result? { val rawText = try { if (prefs.trimFinalTrailingPunct) TextSanitizer.trimTrailingPunctAndEmoji(text) else text @@ -55,7 +55,7 @@ internal class PostprocessPipeline( onEmit = emit@{ typed -> if (isCancelled() || committed) return@emit inputHelper.setComposingText(ic, typed) - } + }, ) } else { null @@ -82,7 +82,7 @@ internal class PostprocessPipeline( prefs, text, llmPostProcessor, - onStreamingUpdate = onStreamingUpdate + onStreamingUpdate = onStreamingUpdate, ) } catch (t: Throwable) { Log.e(logTag, "applyWithAi failed", t) @@ -161,7 +161,7 @@ internal class PostprocessPipeline( postprocFailed = postprocFailed, aiUsed = aiUsed, aiPostMs = aiPostMs, - aiPostStatus = aiPostStatus + aiPostStatus = aiPostStatus, ) } } diff --git a/app/src/main/java/com/brycewg/asrkb/ime/ProcessingTimeoutController.kt b/app/src/main/java/com/brycewg/asrkb/ime/ProcessingTimeoutController.kt index 531ef909..382afdea 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/ProcessingTimeoutController.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/ProcessingTimeoutController.kt @@ -20,7 +20,7 @@ internal class ProcessingTimeoutController( private val opSeqProvider: () -> Long, private val audioMsProvider: () -> Long, private val usingBackupEngineProvider: () -> Boolean, - private val onTimeout: () -> Unit + private val onTimeout: () -> Unit, ) { private var job: Job? = null @@ -51,7 +51,7 @@ internal class ProcessingTimeoutController( // 读取配置失败等异常场景:回退为原有策略(不阻塞、继续计时) Log.w( logTag, - "awaitLocalAsrReady returned false, fallback to immediate timeout countdown" + "awaitLocalAsrReady returned false, fallback to immediate timeout countdown", ) } // 若等待期间状态已变化,则不再继续计时 @@ -62,14 +62,14 @@ internal class ProcessingTimeoutController( if (currentStateProvider() is KeyboardState.Processing) { debugLog( "processing_timeout_fired", - mapOf("opSeq" to opSeqProvider(), "audioMs" to audioMs, "timeoutMs" to timeoutMs) + mapOf("opSeq" to opSeqProvider(), "audioMs" to audioMs, "timeoutMs" to timeoutMs), ) onTimeout() } } debugLog( "processing_timeout_scheduled", - mapOf("opSeq" to opSeqProvider(), "audioMs" to audioMs, "timeoutMs" to timeoutMs) + mapOf("opSeq" to opSeqProvider(), "audioMs" to audioMs, "timeoutMs" to timeoutMs), ) } @@ -101,4 +101,3 @@ internal class ProcessingTimeoutController( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/PromptApplyUseCase.kt b/app/src/main/java/com/brycewg/asrkb/ime/PromptApplyUseCase.kt index 92f07a4d..e075346a 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/PromptApplyUseCase.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/PromptApplyUseCase.kt @@ -19,7 +19,7 @@ internal class PromptApplyUseCase( private val saveUndoSnapshot: (InputConnection) -> Unit, private val getLastAsrCommitText: () -> String?, private val updateSessionContext: ((KeyboardSessionContext) -> KeyboardSessionContext) -> Unit, - private val logTag: String + private val logTag: String, ) { fun apply(ic: InputConnection?, promptOverride: String?) { if (ic == null) return @@ -40,7 +40,7 @@ internal class PromptApplyUseCase( target.text, llmPostProcessor, promptOverride = promptOverride, - forceAi = true + forceAi = true, ) } catch (t: Throwable) { Log.e(logTag, "apply prompt failed", t) @@ -78,7 +78,7 @@ internal class PromptApplyUseCase( PostprocCommit(processed = out, raw = target.text) } else { null - } + }, ) } @@ -92,13 +92,13 @@ internal class PromptApplyUseCase( private data class Target( val mode: TargetMode, - val text: String + val text: String, ) private enum class TargetMode { SELECTION, LAST_ASR, - ENTIRE + ENTIRE, } private fun resolveTargetText(ic: InputConnection): Target? { @@ -124,4 +124,3 @@ internal class PromptApplyUseCase( return Target(TargetMode.ENTIRE, all) } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/RetryUseCase.kt b/app/src/main/java/com/brycewg/asrkb/ime/RetryUseCase.kt index 3ff8bc21..5e61305b 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/RetryUseCase.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/RetryUseCase.kt @@ -12,7 +12,7 @@ internal class RetryUseCase( private val transitionToState: (KeyboardState) -> Unit, private val transitionToIdle: () -> Unit, private val scheduleProcessingTimeout: () -> Unit, - private val logTag: String + private val logTag: String, ) { fun shouldOfferRetry(message: String): Boolean { val engine = try { @@ -38,7 +38,7 @@ internal class RetryUseCase( "host", "unreachable", "rate", - "too many requests" + "too many requests", ) val looksNetwork = networkKeywords.any { kw -> kw in message || kw in msgLower } if (!looksNetwork) return false @@ -69,4 +69,3 @@ internal class RetryUseCase( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ime/UndoManager.kt b/app/src/main/java/com/brycewg/asrkb/ime/UndoManager.kt index 704ca28a..1ca1543e 100644 --- a/app/src/main/java/com/brycewg/asrkb/ime/UndoManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/ime/UndoManager.kt @@ -6,7 +6,7 @@ import android.view.inputmethod.InputConnection internal class UndoManager( private val inputHelper: InputConnectionHelper, private val logTag: String, - private val maxSnapshots: Int = 3 + private val maxSnapshots: Int = 3, ) { private val snapshots = ArrayDeque(maxSnapshots) @@ -62,4 +62,3 @@ internal class UndoManager( } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/store/AnalyticsStore.kt b/app/src/main/java/com/brycewg/asrkb/store/AnalyticsStore.kt index 71e81dcf..b69ca86d 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/AnalyticsStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/AnalyticsStore.kt @@ -3,11 +3,11 @@ package com.brycewg.asrkb.store import android.content.Context import android.content.SharedPreferences import android.util.Log -import java.util.UUID import kotlinx.serialization.Serializable import kotlinx.serialization.decodeFromString import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json +import java.util.UUID /** * 统计事件本地缓存。 @@ -16,111 +16,112 @@ import kotlinx.serialization.json.Json * - 仅保存必要字段,不包含识别文本内容 */ class AnalyticsStore(context: Context) { - companion object { - private const val TAG = "AnalyticsStore" - private const val SP_NAME = "asr_prefs" - private const val KEY_ASR_EVENTS_JSON = "analytics_asr_events" - private const val KEY_APP_STARTS_JSON = "analytics_app_starts" - } + companion object { + private const val TAG = "AnalyticsStore" + private const val SP_NAME = "asr_prefs" + private const val KEY_ASR_EVENTS_JSON = "analytics_asr_events" + private const val KEY_APP_STARTS_JSON = "analytics_app_starts" + } - private val sp: SharedPreferences = context.getSharedPreferences(SP_NAME, Context.MODE_PRIVATE) - private val json = Json { - ignoreUnknownKeys = true - isLenient = true - encodeDefaults = true - } + private val sp: SharedPreferences = context.getSharedPreferences(SP_NAME, Context.MODE_PRIVATE) + private val json = Json { + ignoreUnknownKeys = true + isLenient = true + encodeDefaults = true + } - @Serializable - data class AsrEvent( - val id: String = UUID.randomUUID().toString(), - val timestamp: Long, - val vendorId: String, - val audioMs: Long, - val procMs: Long, - val source: String, // "ime" | "floating" - val aiProcessed: Boolean, - val charCount: Int - ) + @Serializable + data class AsrEvent( + val id: String = UUID.randomUUID().toString(), + val timestamp: Long, + val vendorId: String, + val audioMs: Long, + val procMs: Long, + // "ime" | "floating" + val source: String, + val aiProcessed: Boolean, + val charCount: Int, + ) - @Serializable - data class AppStartEvent( - val id: String = UUID.randomUUID().toString(), - val timestamp: Long - ) + @Serializable + data class AppStartEvent( + val id: String = UUID.randomUUID().toString(), + val timestamp: Long, + ) - private fun readAsrEventsInternal(): MutableList { - val raw = sp.getString(KEY_ASR_EVENTS_JSON, "").orEmpty() - if (raw.isBlank()) return mutableListOf() - return try { - json.decodeFromString>(raw).toMutableList() - } catch (e: Exception) { - Log.e(TAG, "Failed to parse ASR events JSON", e) - mutableListOf() + private fun readAsrEventsInternal(): MutableList { + val raw = sp.getString(KEY_ASR_EVENTS_JSON, "").orEmpty() + if (raw.isBlank()) return mutableListOf() + return try { + json.decodeFromString>(raw).toMutableList() + } catch (e: Exception) { + Log.e(TAG, "Failed to parse ASR events JSON", e) + mutableListOf() + } } - } - private fun writeAsrEventsInternal(list: List) { - try { - val text = json.encodeToString(list) - sp.edit().putString(KEY_ASR_EVENTS_JSON, text).apply() - } catch (e: Exception) { - Log.e(TAG, "Failed to write ASR events JSON", e) + private fun writeAsrEventsInternal(list: List) { + try { + val text = json.encodeToString(list) + sp.edit().putString(KEY_ASR_EVENTS_JSON, text).apply() + } catch (e: Exception) { + Log.e(TAG, "Failed to write ASR events JSON", e) + } } - } - private fun readAppStartsInternal(): MutableList { - val raw = sp.getString(KEY_APP_STARTS_JSON, "").orEmpty() - if (raw.isBlank()) return mutableListOf() - return try { - json.decodeFromString>(raw).toMutableList() - } catch (e: Exception) { - Log.e(TAG, "Failed to parse app starts JSON", e) - mutableListOf() + private fun readAppStartsInternal(): MutableList { + val raw = sp.getString(KEY_APP_STARTS_JSON, "").orEmpty() + if (raw.isBlank()) return mutableListOf() + return try { + json.decodeFromString>(raw).toMutableList() + } catch (e: Exception) { + Log.e(TAG, "Failed to parse app starts JSON", e) + mutableListOf() + } } - } - private fun writeAppStartsInternal(list: List) { - try { - val text = json.encodeToString(list) - sp.edit().putString(KEY_APP_STARTS_JSON, text).apply() - } catch (e: Exception) { - Log.e(TAG, "Failed to write app starts JSON", e) + private fun writeAppStartsInternal(list: List) { + try { + val text = json.encodeToString(list) + sp.edit().putString(KEY_APP_STARTS_JSON, text).apply() + } catch (e: Exception) { + Log.e(TAG, "Failed to write app starts JSON", e) + } } - } - fun addAsrEvent(event: AsrEvent) { - val list = readAsrEventsInternal() - list.add(event) - writeAsrEventsInternal(list) - } + fun addAsrEvent(event: AsrEvent) { + val list = readAsrEventsInternal() + list.add(event) + writeAsrEventsInternal(list) + } - fun addAppStart(event: AppStartEvent) { - val list = readAppStartsInternal() - list.add(event) - writeAppStartsInternal(list) - } + fun addAppStart(event: AppStartEvent) { + val list = readAppStartsInternal() + list.add(event) + writeAppStartsInternal(list) + } - fun listAsrEvents(): List = - readAsrEventsInternal().sortedBy { it.timestamp } + fun listAsrEvents(): List = + readAsrEventsInternal().sortedBy { it.timestamp } - fun listAppStarts(): List = - readAppStartsInternal().sortedBy { it.timestamp } + fun listAppStarts(): List = + readAppStartsInternal().sortedBy { it.timestamp } - fun deleteAsrEventsByIds(ids: Set): Int { - if (ids.isEmpty()) return 0 - val list = readAsrEventsInternal() - val before = list.size - val remained = list.filterNot { ids.contains(it.id) } - writeAsrEventsInternal(remained) - return (before - remained.size).coerceAtLeast(0) - } + fun deleteAsrEventsByIds(ids: Set): Int { + if (ids.isEmpty()) return 0 + val list = readAsrEventsInternal() + val before = list.size + val remained = list.filterNot { ids.contains(it.id) } + writeAsrEventsInternal(remained) + return (before - remained.size).coerceAtLeast(0) + } - fun deleteAppStartsByIds(ids: Set): Int { - if (ids.isEmpty()) return 0 - val list = readAppStartsInternal() - val before = list.size - val remained = list.filterNot { ids.contains(it.id) } - writeAppStartsInternal(remained) - return (before - remained.size).coerceAtLeast(0) - } + fun deleteAppStartsByIds(ids: Set): Int { + if (ids.isEmpty()) return 0 + val list = readAppStartsInternal() + val before = list.size + val remained = list.filterNot { ids.contains(it.id) } + writeAppStartsInternal(remained) + return (before - remained.size).coerceAtLeast(0) + } } diff --git a/app/src/main/java/com/brycewg/asrkb/store/AsrHistoryStore.kt b/app/src/main/java/com/brycewg/asrkb/store/AsrHistoryStore.kt index 77b52d66..ef004f6d 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/AsrHistoryStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/AsrHistoryStore.kt @@ -4,8 +4,8 @@ import android.content.Context import android.content.SharedPreferences import android.util.Log import kotlinx.serialization.Serializable -import kotlinx.serialization.encodeToString import kotlinx.serialization.decodeFromString +import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json import java.util.UUID @@ -15,91 +15,92 @@ import java.util.UUID * - 提供新增、查询、删除(单个/批量) */ class AsrHistoryStore(context: Context) { - companion object { - private const val TAG = "AsrHistoryStore" - private const val SP_NAME = "asr_prefs" - private const val KEY_ASR_HISTORY_JSON = "asr_history" - // 防止无限增长,保留最近 N 条 - private const val MAX_RECORDS = 2000 - } + companion object { + private const val TAG = "AsrHistoryStore" + private const val SP_NAME = "asr_prefs" + private const val KEY_ASR_HISTORY_JSON = "asr_history" - private val sp: SharedPreferences = context.getSharedPreferences(SP_NAME, Context.MODE_PRIVATE) - private val json = Json { - ignoreUnknownKeys = true - isLenient = true - encodeDefaults = true - } + // 防止无限增长,保留最近 N 条 + private const val MAX_RECORDS = 2000 + } + + private val sp: SharedPreferences = context.getSharedPreferences(SP_NAME, Context.MODE_PRIVATE) + private val json = Json { + ignoreUnknownKeys = true + isLenient = true + encodeDefaults = true + } - @Serializable - enum class AiPostStatus { - NONE, - SUCCESS, - FAILED - } + @Serializable + enum class AiPostStatus { + NONE, + SUCCESS, + FAILED, + } - @Serializable - data class AsrHistoryRecord( - val id: String = UUID.randomUUID().toString(), - val timestamp: Long, - val text: String, - val vendorId: String, - val audioMs: Long, - // 端到端总耗时(毫秒):从开始录音到最终提交完成(含识别/后处理/打字机动画等待等)。 - // 旧记录无该字段时视为 0。 - val totalElapsedMs: Long = 0, - // 供应商处理耗时(非流式文件识别时有效,毫秒)。OSS 旧记录无该字段时视为 0。 - val procMs: Long = 0, - val source: String, // "ime" | "floating" | "external" - val aiProcessed: Boolean, - // AI 后处理耗时(毫秒)。未尝试或旧记录无该字段时视为 0。 - val aiPostMs: Long = 0, - // AI 后处理状态。旧记录无该字段时视为 NONE。 - val aiPostStatus: AiPostStatus = AiPostStatus.NONE, - val charCount: Int - ) + @Serializable + data class AsrHistoryRecord( + val id: String = UUID.randomUUID().toString(), + val timestamp: Long, + val text: String, + val vendorId: String, + val audioMs: Long, + // 端到端总耗时(毫秒):从开始录音到最终提交完成(含识别/后处理/打字机动画等待等)。 + // 旧记录无该字段时视为 0。 + val totalElapsedMs: Long = 0, + // 供应商处理耗时(非流式文件识别时有效,毫秒)。OSS 旧记录无该字段时视为 0。 + val procMs: Long = 0, + // "ime" | "floating" | "external" + val source: String, + val aiProcessed: Boolean, + // AI 后处理耗时(毫秒)。未尝试或旧记录无该字段时视为 0。 + val aiPostMs: Long = 0, + // AI 后处理状态。旧记录无该字段时视为 NONE。 + val aiPostStatus: AiPostStatus = AiPostStatus.NONE, + val charCount: Int, + ) - private fun readAllInternal(): MutableList { - val raw = sp.getString(KEY_ASR_HISTORY_JSON, "").orEmpty() - if (raw.isBlank()) return mutableListOf() - return try { - json.decodeFromString>(raw).toMutableList() - } catch (e: Exception) { - Log.e(TAG, "Failed to parse history JSON", e) - mutableListOf() + private fun readAllInternal(): MutableList { + val raw = sp.getString(KEY_ASR_HISTORY_JSON, "").orEmpty() + if (raw.isBlank()) return mutableListOf() + return try { + json.decodeFromString>(raw).toMutableList() + } catch (e: Exception) { + Log.e(TAG, "Failed to parse history JSON", e) + mutableListOf() + } } - } - private fun writeAllInternal(list: List) { - try { - val text = json.encodeToString(list) - sp.edit().putString(KEY_ASR_HISTORY_JSON, text).apply() - } catch (e: Exception) { - Log.e(TAG, "Failed to write history JSON", e) + private fun writeAllInternal(list: List) { + try { + val text = json.encodeToString(list) + sp.edit().putString(KEY_ASR_HISTORY_JSON, text).apply() + } catch (e: Exception) { + Log.e(TAG, "Failed to write history JSON", e) + } } - } - fun add(record: AsrHistoryRecord) { - val list = readAllInternal() - list.add(record) - // 按时间倒序裁剪 - val ordered = list.sortedByDescending { it.timestamp } - val pruned = if (ordered.size > MAX_RECORDS) ordered.take(MAX_RECORDS) else ordered - writeAllInternal(pruned) - } + fun add(record: AsrHistoryRecord) { + val list = readAllInternal() + list.add(record) + // 按时间倒序裁剪 + val ordered = list.sortedByDescending { it.timestamp } + val pruned = if (ordered.size > MAX_RECORDS) ordered.take(MAX_RECORDS) else ordered + writeAllInternal(pruned) + } - fun listAll(): List = readAllInternal().sortedByDescending { it.timestamp } + fun listAll(): List = readAllInternal().sortedByDescending { it.timestamp } - fun deleteByIds(ids: Set): Int { - if (ids.isEmpty()) return 0 - val list = readAllInternal() - val before = list.size - val remained = list.filterNot { ids.contains(it.id) } - writeAllInternal(remained) - return (before - remained.size).coerceAtLeast(0) - } + fun deleteByIds(ids: Set): Int { + if (ids.isEmpty()) return 0 + val list = readAllInternal() + val before = list.size + val remained = list.filterNot { ids.contains(it.id) } + writeAllInternal(remained) + return (before - remained.size).coerceAtLeast(0) + } - fun clearAll() { - writeAllInternal(emptyList()) - } + fun clearAll() { + writeAllInternal(emptyList()) + } } - diff --git a/app/src/main/java/com/brycewg/asrkb/store/DashScopePrefsCompat.kt b/app/src/main/java/com/brycewg/asrkb/store/DashScopePrefsCompat.kt index faa7a524..deda8f2f 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/DashScopePrefsCompat.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/DashScopePrefsCompat.kt @@ -21,4 +21,3 @@ internal object DashScopePrefsCompat { return if (funAsr) Prefs.DASH_MODEL_FUN_ASR_REALTIME else Prefs.DASH_MODEL_QWEN3_REALTIME } } - diff --git a/app/src/main/java/com/brycewg/asrkb/store/Prefs.kt b/app/src/main/java/com/brycewg/asrkb/store/Prefs.kt index 5d8ba545..2d7670e3 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/Prefs.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/Prefs.kt @@ -99,7 +99,6 @@ class Prefs(context: Context) { // 移除:键盘内“切换输入法”按钮显示开关(按钮始终显示) - // 输入/点击触觉反馈强度 var hapticFeedbackLevel: Int get() { @@ -117,7 +116,7 @@ class Prefs(context: Context) { set(value) = sp.edit { putInt( KEY_HAPTIC_FEEDBACK_LEVEL, - value.coerceIn(HAPTIC_FEEDBACK_LEVEL_OFF, HAPTIC_FEEDBACK_LEVEL_HEAVY) + value.coerceIn(HAPTIC_FEEDBACK_LEVEL_OFF, HAPTIC_FEEDBACK_LEVEL_HEAVY), ) } @@ -227,7 +226,6 @@ class Prefs(context: Context) { get() = sp.getBoolean(KEY_HIDE_RECENT_TASK_CARD, false) set(value) = sp.edit { putBoolean(KEY_HIDE_RECENT_TASK_CARD, value) } - // 应用内语言(空字符串表示跟随系统;如:"zh-Hans"、"en") var appLanguageTag: String get() = sp.getString(KEY_APP_LANGUAGE_TAG, "") ?: "" @@ -285,7 +283,7 @@ class Prefs(context: Context) { set(value) = sp.edit { putFloat( KEY_FLOATING_DOCK_FRACTION, - if (value < 0f) -1f else value.coerceIn(0f, 1f) + if (value < 0f) -1f else value.coerceIn(0f, 1f), ) } @@ -384,7 +382,7 @@ class Prefs(context: Context) { val models: List = emptyList(), val enableReasoning: Boolean = false, val reasoningParamsOnJson: String = "", - val reasoningParamsOffJson: String = "" + val reasoningParamsOffJson: String = "", ) fun getLlmProviders(): List = PrefsLlmProviderStore.getLlmProviders(this, json) @@ -424,7 +422,7 @@ class Prefs(context: Context) { localizedDefaults = defaults, knownDefaultVariants = knownDefaultVariants, legacyPrompt = legacyPrompt, - initializedFromDefaults = initializedFromDefaults + initializedFromDefaults = initializedFromDefaults, ) } val parsed = try { @@ -438,7 +436,7 @@ class Prefs(context: Context) { prefs = this, current = parsed, localizedDefaults = defaults, - knownDefaultVariants = knownDefaultVariants + knownDefaultVariants = knownDefaultVariants, ) val migrated = PromptPresetMigrations.migrateLegacyPromptIfNeeded( prefs = this, @@ -446,13 +444,13 @@ class Prefs(context: Context) { localizedDefaults = defaults, knownDefaultVariants = knownDefaultVariants, legacyPrompt = legacyPrompt, - initializedFromDefaults = initializedFromDefaults + initializedFromDefaults = initializedFromDefaults, ) val synced = PromptPresetMigrations.syncDefaultsForLanguageIfNeeded( prefs = this, current = migrated, localizedDefaults = defaults, - knownDefaultVariants = knownDefaultVariants + knownDefaultVariants = knownDefaultVariants, ) if (synced != parsed) { setPromptPresets(synced) @@ -631,13 +629,12 @@ class Prefs(context: Context) { val enableReasoning: Boolean, val useCustomReasoningParams: Boolean, val reasoningParamsOnJson: String, - val reasoningParamsOffJson: String + val reasoningParamsOffJson: String, ) // 阿里云百炼(DashScope)凭证 var dashApiKey: String by stringPref(KEY_DASH_API_KEY, "") - // DashScope:自定义识别上下文(提示词) var dashPrompt: String by stringPref(KEY_DASH_PROMPT, "") @@ -988,6 +985,7 @@ class Prefs(context: Context) { var zipformerCleanupDone: Boolean get() = sp.getBoolean(KEY_ZIPFORMER_CLEANUP_DONE, false) set(value) = sp.edit { putBoolean(KEY_ZIPFORMER_CLEANUP_DONE, value) } + // --- 供应商配置通用化 --- internal val vendorFields: Map> = PrefsAsrVendorFields.vendorFields @@ -1020,7 +1018,7 @@ class Prefs(context: Context) { } fun hasVolcKeys(): Boolean = hasVendorKeys(AsrVendor.Volc) - fun hasSfKeys(): Boolean = sfFreeAsrEnabled || sfApiKey.isNotBlank() // 免费服务启用或有 API Key + fun hasSfKeys(): Boolean = sfFreeAsrEnabled || sfApiKey.isNotBlank() // 免费服务启用或有 API Key fun hasDashKeys(): Boolean = hasVendorKeys(AsrVendor.DashScope) fun hasElevenKeys(): Boolean = hasVendorKeys(AsrVendor.ElevenLabs) fun hasOpenAiKeys(): Boolean = hasVendorKeys(AsrVendor.OpenAI) @@ -1269,11 +1267,14 @@ class Prefs(context: Context) { const val SF_CHAT_COMPLETIONS_ENDPOINT = "https://api.siliconflow.cn/v1/chat/completions" const val DEFAULT_SF_MODEL = "FunAudioLLM/SenseVoiceSmall" const val DEFAULT_SF_OMNI_MODEL = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + // SiliconFlow 免费服务模型配置 - const val DEFAULT_SF_FREE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall" // 免费 ASR 默认模型 - const val DEFAULT_SF_FREE_LLM_MODEL = "Qwen/Qwen3-8B" // 免费 LLM 默认模型 + const val DEFAULT_SF_FREE_ASR_MODEL = "FunAudioLLM/SenseVoiceSmall" // 免费 ASR 默认模型 + const val DEFAULT_SF_FREE_LLM_MODEL = "Qwen/Qwen3-8B" // 免费 LLM 默认模型 + // 免费 ASR 可选模型列表 val SF_FREE_ASR_MODELS: List = PrefsOptionLists.SF_FREE_ASR_MODELS + // 免费 LLM 可选模型列表 val SF_FREE_LLM_MODELS: List = PrefsOptionLists.SF_FREE_LLM_MODELS @@ -1285,9 +1286,11 @@ class Prefs(context: Context) { const val DEFAULT_DASH_MODEL = "qwen3-asr-flash" const val DASH_MODEL_QWEN3_REALTIME = "qwen3-asr-flash-realtime-2026-02-10" const val DASH_MODEL_FUN_ASR_REALTIME = "fun-asr-realtime" + // Gemini 默认 const val DEFAULT_GEM_ENDPOINT = "https://generativelanguage.googleapis.com/v1beta" const val DEFAULT_GEM_MODEL = "gemini-2.5-flash" + // Zhipu GLM ASR 默认 const val DEFAULT_ZHIPU_TEMPERATURE = 0.95f @@ -1310,16 +1313,15 @@ class Prefs(context: Context) { // 悬浮球默认大小(dp) const val DEFAULT_FLOATING_BALL_SIZE_DP = 44 + // 悬浮写入兼容:默认目标包名(每行一个;支持前缀匹配) const val DEFAULT_FLOATING_WRITE_COMPAT_PACKAGES = "org.telegram.messenger\nnu.gpu.nagram" - // Soniox 默认端点 const val SONIOX_API_BASE_URL = "https://api.soniox.com" const val SONIOX_FILES_ENDPOINT = "$SONIOX_API_BASE_URL/v1/files" const val SONIOX_TRANSCRIPTIONS_ENDPOINT = "$SONIOX_API_BASE_URL/v1/transcriptions" const val SONIOX_WS_URL = "wss://stt-rt.soniox.com/transcribe-websocket" - } // 导出全部设置为 JSON 字符串(包含密钥,仅用于本地备份/迁移) @@ -1333,5 +1335,4 @@ class Prefs(context: Context) { Log.i(TAG, "Successfully imported settings from JSON") return true } - } diff --git a/app/src/main/java/com/brycewg/asrkb/store/PrefsAsrVendorFields.kt b/app/src/main/java/com/brycewg/asrkb/store/PrefsAsrVendorFields.kt index 74199b4d..c63d1912 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PrefsAsrVendorFields.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PrefsAsrVendorFields.kt @@ -13,16 +13,17 @@ internal object PrefsAsrVendorFields { internal val vendorFields: Map> = mapOf( AsrVendor.Volc to listOf( VendorField(KEY_APP_KEY, required = true), - VendorField(KEY_ACCESS_KEY, required = true) + VendorField(KEY_ACCESS_KEY, required = true), ), // SiliconFlow:免费服务启用时无需 API Key AsrVendor.SiliconFlow to listOf( - VendorField(KEY_SF_API_KEY, required = false), // 免费服务时无需 API Key - VendorField(KEY_SF_MODEL, default = Prefs.DEFAULT_SF_MODEL) + // 免费服务时无需 API Key + VendorField(KEY_SF_API_KEY, required = false), + VendorField(KEY_SF_MODEL, default = Prefs.DEFAULT_SF_MODEL), ), AsrVendor.ElevenLabs to listOf( VendorField(KEY_ELEVEN_API_KEY, required = true), - VendorField(KEY_ELEVEN_LANGUAGE_CODE) + VendorField(KEY_ELEVEN_LANGUAGE_CODE), ), AsrVendor.OpenAI to listOf( VendorField(KEY_OA_ASR_ENDPOINT, required = true, default = Prefs.DEFAULT_OA_ASR_ENDPOINT), @@ -31,24 +32,24 @@ internal object PrefsAsrVendorFields { // 可选 Prompt 字段(字符串);开关为布尔,单独在导入/导出处理 VendorField(KEY_OA_ASR_PROMPT, required = false, default = ""), // 可选语言字段(字符串) - VendorField(KEY_OA_ASR_LANGUAGE, required = false, default = "") + VendorField(KEY_OA_ASR_LANGUAGE, required = false, default = ""), ), AsrVendor.DashScope to listOf( VendorField(KEY_DASH_API_KEY, required = true), VendorField(KEY_DASH_PROMPT, default = ""), - VendorField(KEY_DASH_LANGUAGE, default = "") + VendorField(KEY_DASH_LANGUAGE, default = ""), ), AsrVendor.Gemini to listOf( VendorField(KEY_GEM_ENDPOINT, required = true, default = Prefs.DEFAULT_GEM_ENDPOINT), VendorField(KEY_GEM_API_KEY, required = true), VendorField(KEY_GEM_MODEL, required = true, default = Prefs.DEFAULT_GEM_MODEL), - VendorField(KEY_GEM_PROMPT, default = "") + VendorField(KEY_GEM_PROMPT, default = ""), ), AsrVendor.Soniox to listOf( - VendorField(KEY_SONIOX_API_KEY, required = true) + VendorField(KEY_SONIOX_API_KEY, required = true), ), AsrVendor.Zhipu to listOf( - VendorField(KEY_ZHIPU_API_KEY, required = true) + VendorField(KEY_ZHIPU_API_KEY, required = true), ), // 本地 SenseVoice(sherpa-onnx)无需鉴权 AsrVendor.SenseVoice to emptyList(), @@ -57,12 +58,12 @@ internal object PrefsAsrVendorFields { // 本地 TeleSpeech(sherpa-onnx)无需鉴权 AsrVendor.Telespeech to emptyList(), // 本地 Paraformer(sherpa-onnx)无需鉴权 - AsrVendor.Paraformer to emptyList() + AsrVendor.Paraformer to emptyList(), ) } internal data class VendorField( val key: String, val required: Boolean = false, - val default: String = "" + val default: String = "", ) diff --git a/app/src/main/java/com/brycewg/asrkb/store/PrefsInitTasks.kt b/app/src/main/java/com/brycewg/asrkb/store/PrefsInitTasks.kt index b1292b99..660131ec 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PrefsInitTasks.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PrefsInitTasks.kt @@ -23,6 +23,7 @@ internal object PrefsInitTasks { private const val TAG = "Prefs" @Volatile private var toggleListenerRegistered: Boolean = false + @Volatile private var fnLegacyCleanupStarted: Boolean = false private val globalToggleListener = SharedPreferences.OnSharedPreferenceChangeListener { prefs, key -> @@ -173,7 +174,7 @@ internal object PrefsInitTasks { val legacyTargets = listOf( File(File(base, "sensevoice"), "nano-int8"), - File(File(base, "sensevoice"), "nano-full") + File(File(base, "sensevoice"), "nano-full"), ) CoroutineScope(SupervisorJob() + Dispatchers.IO).launch { diff --git a/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmProviderStore.kt b/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmProviderStore.kt index 70775747..984802f5 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmProviderStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmProviderStore.kt @@ -18,7 +18,7 @@ internal object PrefsLlmProviderStore { endpoint = prefs.llmEndpoint.ifBlank { Prefs.DEFAULT_LLM_ENDPOINT }, apiKey = prefs.llmApiKey, model = prefs.llmModel.ifBlank { Prefs.DEFAULT_LLM_MODEL }, - temperature = prefs.llmTemperature + temperature = prefs.llmTemperature, ) setLlmProviders(prefs, json, listOf(migrated)) } @@ -47,4 +47,3 @@ internal object PrefsLlmProviderStore { return list.firstOrNull { it.id == id } ?: list.firstOrNull() } } - diff --git a/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmVendorStore.kt b/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmVendorStore.kt index 520047e0..24383324 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmVendorStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PrefsLlmVendorStore.kt @@ -62,7 +62,7 @@ internal object PrefsLlmVendorStore { sp: SharedPreferences, json: Json, vendor: LlmVendor, - sfFreeLlmUsePaidKey: Boolean + sfFreeLlmUsePaidKey: Boolean, ): List { val key = "llm_vendor_${vendor.id}_models_json" if (vendor == LlmVendor.SF_FREE && !sfFreeLlmUsePaidKey) { @@ -142,7 +142,7 @@ internal object PrefsLlmVendorStore { enableReasoning = getLlmVendorReasoningEnabled(sp, vendor), useCustomReasoningParams = !isBuiltinLlmPresetModel(vendor, model, prefs.sfFreeLlmUsePaidKey), reasoningParamsOnJson = getLlmVendorReasoningParamsOnJson(sp, vendor), - reasoningParamsOffJson = getLlmVendorReasoningParamsOffJson(sp, vendor) + reasoningParamsOffJson = getLlmVendorReasoningParamsOffJson(sp, vendor), ) } } else { @@ -150,14 +150,15 @@ internal object PrefsLlmVendorStore { // 实际 API Key 在 LlmPostProcessor 中注入 Prefs.EffectiveLlmConfig( endpoint = vendor.endpoint, - apiKey = "", // 免费服务在调用层注入内置 Key + // 免费服务在调用层注入内置 Key + apiKey = "", model = model, temperature = Prefs.DEFAULT_LLM_TEMPERATURE, vendor = vendor, enableReasoning = getLlmVendorReasoningEnabled(sp, vendor), useCustomReasoningParams = !isBuiltinLlmPresetModel(vendor, model, prefs.sfFreeLlmUsePaidKey), reasoningParamsOnJson = getLlmVendorReasoningParamsOnJson(sp, vendor), - reasoningParamsOffJson = getLlmVendorReasoningParamsOffJson(sp, vendor) + reasoningParamsOffJson = getLlmVendorReasoningParamsOffJson(sp, vendor), ) } } @@ -174,9 +175,11 @@ internal object PrefsLlmVendorStore { enableReasoning = provider.enableReasoning, useCustomReasoningParams = true, reasoningParamsOnJson = provider.reasoningParamsOnJson, - reasoningParamsOffJson = provider.reasoningParamsOffJson + reasoningParamsOffJson = provider.reasoningParamsOffJson, ) - } else null + } else { + null + } } else -> { // 内置供应商:使用预设端点 + 用户 API Key + 用户选择的模型 @@ -194,7 +197,7 @@ internal object PrefsLlmVendorStore { enableReasoning = getLlmVendorReasoningEnabled(sp, vendor), useCustomReasoningParams = !isBuiltinLlmPresetModel(vendor, model, prefs.sfFreeLlmUsePaidKey), reasoningParamsOnJson = getLlmVendorReasoningParamsOnJson(sp, vendor), - reasoningParamsOffJson = getLlmVendorReasoningParamsOffJson(sp, vendor) + reasoningParamsOffJson = getLlmVendorReasoningParamsOffJson(sp, vendor), ) } } @@ -207,7 +210,8 @@ internal object PrefsLlmVendorStore { ReasoningMode.THINKING_TYPE -> """{"thinking":{"type":"enabled"}}""" ReasoningMode.REASONING_EFFORT -> """{"reasoning_effort":"medium"}""" ReasoningMode.MODEL_SELECTION, - ReasoningMode.NONE -> Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_ON_JSON + ReasoningMode.NONE, + -> Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_ON_JSON } } @@ -219,7 +223,8 @@ internal object PrefsLlmVendorStore { ReasoningMode.THINKING_TYPE -> """{"thinking":{"type":"disabled"}}""" ReasoningMode.REASONING_EFFORT -> """{"reasoning_effort":"none"}""" ReasoningMode.MODEL_SELECTION, - ReasoningMode.NONE -> Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_OFF_JSON + ReasoningMode.NONE, + -> Prefs.DEFAULT_CUSTOM_REASONING_PARAMS_OFF_JSON } } } diff --git a/app/src/main/java/com/brycewg/asrkb/store/PrefsOptionLists.kt b/app/src/main/java/com/brycewg/asrkb/store/PrefsOptionLists.kt index 1e5d447a..5044a2ed 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PrefsOptionLists.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PrefsOptionLists.kt @@ -9,12 +9,11 @@ package com.brycewg.asrkb.store internal object PrefsOptionLists { val SF_FREE_ASR_MODELS = listOf( "FunAudioLLM/SenseVoiceSmall", - "TeleAI/TeleSpeechASR" + "TeleAI/TeleSpeechASR", ) val SF_FREE_LLM_MODELS = listOf( "Qwen/Qwen3-8B", - "THUDM/GLM-4-9B-0414" + "THUDM/GLM-4-9B-0414", ) } - diff --git a/app/src/main/java/com/brycewg/asrkb/store/PromptPreset.kt b/app/src/main/java/com/brycewg/asrkb/store/PromptPreset.kt index 3e1ae512..07eb3673 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PromptPreset.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PromptPreset.kt @@ -6,5 +6,5 @@ import kotlinx.serialization.Serializable data class PromptPreset( val id: String, val title: String, - val content: String + val content: String, ) diff --git a/app/src/main/java/com/brycewg/asrkb/store/PromptPresetDefaults.kt b/app/src/main/java/com/brycewg/asrkb/store/PromptPresetDefaults.kt index 33fa3055..8e62d3ed 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PromptPresetDefaults.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PromptPresetDefaults.kt @@ -16,28 +16,28 @@ internal fun buildDefaultPromptPresets(context: Context): List { PromptPreset( id = DEFAULT_PRESET_GENERAL_ID, title = context.getString(R.string.llm_prompt_preset_default_general_title), - content = context.getString(R.string.llm_prompt_preset_default_general_content) + content = context.getString(R.string.llm_prompt_preset_default_general_content), ), PromptPreset( id = DEFAULT_PRESET_POLISH_ID, title = context.getString(R.string.llm_prompt_preset_default_polish_title), - content = context.getString(R.string.llm_prompt_preset_default_polish_content) + content = context.getString(R.string.llm_prompt_preset_default_polish_content), ), PromptPreset( id = DEFAULT_PRESET_TRANSLATE_EN_ID, title = context.getString(R.string.llm_prompt_preset_default_translate_en_title), - content = context.getString(R.string.llm_prompt_preset_default_translate_en_content) + content = context.getString(R.string.llm_prompt_preset_default_translate_en_content), ), PromptPreset( id = DEFAULT_PRESET_KEY_POINTS_ID, title = context.getString(R.string.llm_prompt_preset_default_key_points_title), - content = context.getString(R.string.llm_prompt_preset_default_key_points_content) + content = context.getString(R.string.llm_prompt_preset_default_key_points_content), ), PromptPreset( id = DEFAULT_PRESET_TODO_ID, title = context.getString(R.string.llm_prompt_preset_default_todo_title), - content = context.getString(R.string.llm_prompt_preset_default_todo_content) - ) + content = context.getString(R.string.llm_prompt_preset_default_todo_content), + ), ) } diff --git a/app/src/main/java/com/brycewg/asrkb/store/PromptPresetMigrations.kt b/app/src/main/java/com/brycewg/asrkb/store/PromptPresetMigrations.kt index ae28e457..c3c1fe72 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/PromptPresetMigrations.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/PromptPresetMigrations.kt @@ -20,7 +20,7 @@ internal object PromptPresetMigrations { localizedDefaults: List, knownDefaultVariants: List>, legacyPrompt: String, - initializedFromDefaults: Boolean + initializedFromDefaults: Boolean, ): List { if (legacyPrompt.isBlank()) return current val knownBuiltinContents = knownDefaultVariants @@ -33,7 +33,7 @@ internal object PromptPresetMigrations { val migratedPreset = PromptPreset( id = java.util.UUID.randomUUID().toString(), title = prefs.getLocalizedString(R.string.llm_prompt_preset_mine_title), - content = legacyPrompt + content = legacyPrompt, ) val updated = current + migratedPreset val shouldActivate = initializedFromDefaults || @@ -49,7 +49,7 @@ internal object PromptPresetMigrations { private fun matchesDefaultPromptPresets( presets: List, - defaults: List + defaults: List, ): Boolean { if (presets.size != defaults.size) return false return presets.map { it.title to it.content } == defaults.map { it.title to it.content } @@ -58,7 +58,7 @@ internal object PromptPresetMigrations { private fun matchesAnyDefaultPromptPresets( presets: List, localizedDefaults: List, - knownDefaultVariants: List> + knownDefaultVariants: List>, ): Boolean { if (matchesDefaultPromptPresets(presets, localizedDefaults)) return true return knownDefaultVariants.any { defaults -> matchesDefaultPromptPresets(presets, defaults) } @@ -68,7 +68,7 @@ internal object PromptPresetMigrations { prefs: Prefs, current: List, localizedDefaults: List, - knownDefaultVariants: List> + knownDefaultVariants: List>, ): List { if (current.isEmpty() || localizedDefaults.isEmpty()) return current val builtinIds = localizedDefaults.map { it.id }.toSet() @@ -114,7 +114,7 @@ internal object PromptPresetMigrations { prefs: Prefs, current: List, localizedDefaults: List, - knownDefaultVariants: List> + knownDefaultVariants: List>, ): List { if (current.isEmpty() || localizedDefaults.isEmpty()) return current val defaultById = localizedDefaults.associateBy { it.id } @@ -135,7 +135,7 @@ internal object PromptPresetMigrations { changed = true preset.copy( title = localized.title, - content = localized.content + content = localized.content, ) } if (!changed) return current diff --git a/app/src/main/java/com/brycewg/asrkb/store/SonioxLanguagesStore.kt b/app/src/main/java/com/brycewg/asrkb/store/SonioxLanguagesStore.kt index 716ab9c2..c4fe1891 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/SonioxLanguagesStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/SonioxLanguagesStore.kt @@ -36,4 +36,3 @@ internal object SonioxLanguagesStore { } } } - diff --git a/app/src/main/java/com/brycewg/asrkb/store/SpeechPreset.kt b/app/src/main/java/com/brycewg/asrkb/store/SpeechPreset.kt index 5f393a64..10e5256e 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/SpeechPreset.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/SpeechPreset.kt @@ -6,5 +6,5 @@ import kotlinx.serialization.Serializable data class SpeechPreset( val id: String, val name: String, - val content: String + val content: String, ) diff --git a/app/src/main/java/com/brycewg/asrkb/store/SpeechPresetStore.kt b/app/src/main/java/com/brycewg/asrkb/store/SpeechPresetStore.kt index 1026feb8..49c7905b 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/SpeechPresetStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/SpeechPresetStore.kt @@ -48,4 +48,3 @@ internal object SpeechPresetStore { return match?.content } } - diff --git a/app/src/main/java/com/brycewg/asrkb/store/UsageStatsStore.kt b/app/src/main/java/com/brycewg/asrkb/store/UsageStatsStore.kt index e9cb3d16..69c0009b 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/UsageStatsStore.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/UsageStatsStore.kt @@ -81,7 +81,7 @@ internal object UsageStatsStore { vendor: AsrVendor, audioMs: Long, chars: Int, - procMs: Long = 0L + procMs: Long = 0L, ) { if (chars <= 0 && audioMs <= 0) return val today = LocalDate.now().format(DateTimeFormatter.BASIC_ISO_DATE) @@ -151,7 +151,7 @@ data class VendorAgg( var chars: Long = 0, var audioMs: Long = 0, // 非流式请求的供应商处理耗时聚合(毫秒) - var procMs: Long = 0 + var procMs: Long = 0, ) @Serializable @@ -159,7 +159,7 @@ data class DayAgg( var sessions: Long = 0, var chars: Long = 0, var audioMs: Long = 0, - var procMs: Long = 0 + var procMs: Long = 0, ) @Serializable @@ -170,6 +170,5 @@ data class UsageStats( var totalProcMs: Long = 0, var perVendor: MutableMap = mutableMapOf(), var daily: MutableMap = mutableMapOf(), - var firstUseDate: String = "" + var firstUseDate: String = "", ) - diff --git a/app/src/main/java/com/brycewg/asrkb/store/debug/DebugLogManager.kt b/app/src/main/java/com/brycewg/asrkb/store/debug/DebugLogManager.kt index 743a6778..7c8da801 100644 --- a/app/src/main/java/com/brycewg/asrkb/store/debug/DebugLogManager.kt +++ b/app/src/main/java/com/brycewg/asrkb/store/debug/DebugLogManager.kt @@ -12,7 +12,11 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.launch -import java.io.* +import java.io.BufferedOutputStream +import java.io.File +import java.io.FileInputStream +import java.io.FileOutputStream +import java.io.RandomAccessFile import java.text.SimpleDateFormat import java.util.Date import java.util.Locale @@ -26,352 +30,354 @@ import java.util.Locale * 敏感信息约束:禁止记录识别文本/输入内容/剪贴板内容/密钥等,仅记录状态摘要。 */ object DebugLogManager { - private const val TAG = "DebugLogManager" - private const val DIR_NAME = "debug" - private const val FILE_NAME = "recording.log" - private const val MAX_BYTES: Long = 3L * 1024L * 1024L // 3 MB + private const val TAG = "DebugLogManager" + private const val DIR_NAME = "debug" + private const val FILE_NAME = "recording.log" + private const val MAX_BYTES: Long = 3L * 1024L * 1024L // 3 MB - @Volatile - private var recording: Boolean = false + @Volatile + private var recording: Boolean = false - private var scope: CoroutineScope? = null - private var writerJob: Job? = null - private var channel: Channel? = null - private var output: BufferedOutputStream? = null - private var logFile: File? = null + private var scope: CoroutineScope? = null + private var writerJob: Job? = null + private var channel: Channel? = null + private var output: BufferedOutputStream? = null + private var logFile: File? = null - private data class PersistentLogEntry( - val context: Context, - val line: String - ) - private val persistentScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) - private val persistentChannel = Channel(capacity = 256) - @Volatile - private var persistentWriterStarted: Boolean = false + private data class PersistentLogEntry( + val context: Context, + val line: String, + ) + private val persistentScope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + private val persistentChannel = Channel(capacity = 256) - fun isRecording(): Boolean = recording + @Volatile + private var persistentWriterStarted: Boolean = false - @Synchronized - fun start(context: Context) { - try { - if (recording) return - val dir = File(context.noBackupFilesDir, DIR_NAME) - if (!dir.exists()) dir.mkdirs() + fun isRecording(): Boolean = recording - val file = File(dir, FILE_NAME) - // 清空旧记录 - if (file.exists()) { + @Synchronized + fun start(context: Context) { try { - FileOutputStream(file, false).use { /* truncate */ } + if (recording) return + val dir = File(context.noBackupFilesDir, DIR_NAME) + if (!dir.exists()) dir.mkdirs() + + val file = File(dir, FILE_NAME) + // 清空旧记录 + if (file.exists()) { + try { + FileOutputStream(file, false).use { /* truncate */ } + } catch (e: Throwable) { + Log.e(TAG, "Failed to truncate old recording", e) + } + } + + val fos = FileOutputStream(file, true) + output = BufferedOutputStream(fos) + logFile = file + channel = Channel(capacity = 256) + scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) + writerJob = scope?.launch { + try { + for (line in channel!!) { + writeLine(line) + } + } catch (t: Throwable) { + Log.e(TAG, "Writer loop error", t) + } finally { + try { + output?.flush() + } catch (e: Throwable) { + Log.e(TAG, "Error flushing output", e) + } + try { + output?.close() + } catch (e: Throwable) { + Log.e(TAG, "Error closing output", e) + } + } + } + recording = true + // 记录环境摘要 + log( + category = "debug", + event = "recording_started", + data = mapOf( + "sdk" to Build.VERSION.SDK_INT, + "brand" to Build.BRAND, + "model" to Build.MODEL, + "fingerprint" to safeFingerprint(), + ), + ) } catch (e: Throwable) { - Log.e(TAG, "Failed to truncate old recording", e) + Log.e(TAG, "Failed to start recording", e) + stop() // 尝试清理 } - } + } - val fos = FileOutputStream(file, true) - output = BufferedOutputStream(fos) - logFile = file - channel = Channel(capacity = 256) - scope = CoroutineScope(SupervisorJob() + Dispatchers.IO) - writerJob = scope?.launch { + @Synchronized + fun stop() { try { - for (line in channel!!) { - writeLine(line) - } - } catch (t: Throwable) { - Log.e(TAG, "Writer loop error", t) + if (!recording) return + log(category = "debug", event = "recording_stopping") + recording = false + try { + channel?.close() + } catch (e: Throwable) { + Log.e(TAG, "Error closing channel", e) + } + try { + writerJob?.cancel() + } catch (e: Throwable) { + Log.e(TAG, "Error canceling writer job", e) + } + } catch (e: Throwable) { + Log.e(TAG, "Failed to stop recording", e) } finally { - try { - output?.flush() - } catch (e: Throwable) { - Log.e(TAG, "Error flushing output", e) - } - try { - output?.close() - } catch (e: Throwable) { - Log.e(TAG, "Error closing output", e) - } + channel = null + writerJob = null + scope = null + try { + output?.close() + } catch (e: Throwable) { + Log.e(TAG, "Error closing output in finally", e) + } + output = null } - } - recording = true - // 记录环境摘要 - log( - category = "debug", - event = "recording_started", - data = mapOf( - "sdk" to Build.VERSION.SDK_INT, - "brand" to Build.BRAND, - "model" to Build.MODEL, - "fingerprint" to safeFingerprint() - ) - ) - } catch (e: Throwable) { - Log.e(TAG, "Failed to start recording", e) - stop() // 尝试清理 } - } - @Synchronized - fun stop() { - try { - if (!recording) return - log(category = "debug", event = "recording_stopping") - recording = false - try { - channel?.close() - } catch (e: Throwable) { - Log.e(TAG, "Error closing channel", e) - } - try { - writerJob?.cancel() - } catch (e: Throwable) { - Log.e(TAG, "Error canceling writer job", e) - } - } catch (e: Throwable) { - Log.e(TAG, "Failed to stop recording", e) - } finally { - channel = null - writerJob = null - scope = null - try { - output?.close() - } catch (e: Throwable) { - Log.e(TAG, "Error closing output in finally", e) - } - output = null + fun log(category: String, event: String, data: Map = emptyMap()) { + if (!recording) return + val ch = channel ?: return + try { + val line = buildJsonLine(category, event, data) + ch.trySend(line).isSuccess + } catch (e: Throwable) { + Log.e(TAG, "Failed to enqueue log", e) + } } - } - fun log(category: String, event: String, data: Map = emptyMap()) { - if (!recording) return - val ch = channel ?: return - try { - val line = buildJsonLine(category, event, data) - ch.trySend(line).isSuccess - } catch (e: Throwable) { - Log.e(TAG, "Failed to enqueue log", e) + /** + * 无需手动开始录制也可写入导出日志;若正在录制则复用录制通道。 + */ + fun logPersistent(context: Context, category: String, event: String, data: Map = emptyMap()) { + if (recording) { + log(category, event, data) + return + } + try { + val line = buildJsonLine(category, event, data) + ensurePersistentWriter() + val queued = persistentChannel.trySend(PersistentLogEntry(context.applicationContext, line)).isSuccess + if (!queued) { + // 队列瞬时拥塞时,降级到 IO 协程直写,避免主线程阻塞 + persistentScope.launch { + appendPersistentLine(context.applicationContext, line) + } + } + } catch (e: Throwable) { + Log.e(TAG, "Failed to append persistent log", e) + } } - } - /** - * 无需手动开始录制也可写入导出日志;若正在录制则复用录制通道。 - */ - fun logPersistent(context: Context, category: String, event: String, data: Map = emptyMap()) { - if (recording) { - log(category, event, data) - return - } - try { - val line = buildJsonLine(category, event, data) - ensurePersistentWriter() - val queued = persistentChannel.trySend(PersistentLogEntry(context.applicationContext, line)).isSuccess - if (!queued) { - // 队列瞬时拥塞时,降级到 IO 协程直写,避免主线程阻塞 + @Synchronized + private fun ensurePersistentWriter() { + if (persistentWriterStarted) return + persistentWriterStarted = true persistentScope.launch { - appendPersistentLine(context.applicationContext, line) + try { + for (entry in persistentChannel) { + appendPersistentLine(entry.context, entry.line) + } + } catch (e: Throwable) { + Log.e(TAG, "Persistent writer loop error", e) + persistentWriterStarted = false + } } - } - } catch (e: Throwable) { - Log.e(TAG, "Failed to append persistent log", e) } - } - @Synchronized - private fun ensurePersistentWriter() { - if (persistentWriterStarted) return - persistentWriterStarted = true - persistentScope.launch { - try { - for (entry in persistentChannel) { - appendPersistentLine(entry.context, entry.line) + @Synchronized + private fun appendPersistentLine(context: Context, line: String) { + try { + val dir = File(context.noBackupFilesDir, DIR_NAME) + if (!dir.exists()) dir.mkdirs() + val file = File(dir, FILE_NAME) + if (!file.exists()) { + file.createNewFile() + } + if (file.length() > MAX_BYTES) { + truncateKeepTail(file, keepBytes = 2L * 1024L * 1024L) + } + FileOutputStream(file, true).use { out -> + out.write(line.toByteArray(Charsets.UTF_8)) + out.write('\n'.code) + out.flush() + } + logFile = file + } catch (e: Throwable) { + Log.e(TAG, "Failed writing persistent log line", e) } - } catch (e: Throwable) { - Log.e(TAG, "Persistent writer loop error", e) - persistentWriterStarted = false - } } - } - @Synchronized - private fun appendPersistentLine(context: Context, line: String) { - try { - val dir = File(context.noBackupFilesDir, DIR_NAME) - if (!dir.exists()) dir.mkdirs() - val file = File(dir, FILE_NAME) - if (!file.exists()) { - file.createNewFile() - } - if (file.length() > MAX_BYTES) { - truncateKeepTail(file, keepBytes = 2L * 1024L * 1024L) - } - FileOutputStream(file, true).use { out -> - out.write(line.toByteArray(Charsets.UTF_8)) - out.write('\n'.code) - out.flush() - } - logFile = file - } catch (e: Throwable) { - Log.e(TAG, "Failed writing persistent log line", e) - } - } + /** + * 复制日志到 cache 并构造分享 Intent。若正在录制,返回 RecordingActive 错误。 + */ + fun buildShareIntent(context: Context): ShareIntentResult { + try { + if (recording) return ShareIntentResult.Error(ShareError.RecordingActive) + val src = logFile ?: File(File(context.noBackupFilesDir, DIR_NAME), FILE_NAME) + if (!src.exists() || src.length() <= 0) return ShareIntentResult.Error(ShareError.NoLog) - /** - * 复制日志到 cache 并构造分享 Intent。若正在录制,返回 RecordingActive 错误。 - */ - fun buildShareIntent(context: Context): ShareIntentResult { - try { - if (recording) return ShareIntentResult.Error(ShareError.RecordingActive) - val src = logFile ?: File(File(context.noBackupFilesDir, DIR_NAME), FILE_NAME) - if (!src.exists() || src.length() <= 0) return ShareIntentResult.Error(ShareError.NoLog) + val stamp = SimpleDateFormat("yyyyMMdd_HHmmss", Locale.getDefault()).format(Date()) + val name = "asrkb_debug_log_$stamp.txt" + val dst = File(context.cacheDir, name) + try { + FileInputStream(src).use { ins -> + FileOutputStream(dst).use { outs -> + ins.copyTo(outs) + } + } + } catch (e: Throwable) { + Log.e(TAG, "Failed to copy log to cache", e) + return ShareIntentResult.Error(ShareError.Failed) + } - val stamp = SimpleDateFormat("yyyyMMdd_HHmmss", Locale.getDefault()).format(Date()) - val name = "asrkb_debug_log_${stamp}.txt" - val dst = File(context.cacheDir, name) - try { - FileInputStream(src).use { ins -> - FileOutputStream(dst).use { outs -> - ins.copyTo(outs) - } - } - } catch (e: Throwable) { - Log.e(TAG, "Failed to copy log to cache", e) - return ShareIntentResult.Error(ShareError.Failed) - } + val uri: Uri = try { + FileProvider.getUriForFile(context, context.packageName + ".fileprovider", dst) + } catch (e: Throwable) { + Log.e(TAG, "Failed to get Uri from FileProvider", e) + return ShareIntentResult.Error(ShareError.Failed) + } - val uri: Uri = try { - FileProvider.getUriForFile(context, context.packageName + ".fileprovider", dst) - } catch (e: Throwable) { - Log.e(TAG, "Failed to get Uri from FileProvider", e) - return ShareIntentResult.Error(ShareError.Failed) - } - - val intent = Intent(Intent.ACTION_SEND).apply { - type = "text/plain" - putExtra(Intent.EXTRA_STREAM, uri) - putExtra(Intent.EXTRA_SUBJECT, "ASRKB Debug Log") - addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION) - } - return ShareIntentResult.Success(intent, name) - } catch (t: Throwable) { - Log.e(TAG, "Error building share intent", t) - return ShareIntentResult.Error(ShareError.Failed) + val intent = Intent(Intent.ACTION_SEND).apply { + type = "text/plain" + putExtra(Intent.EXTRA_STREAM, uri) + putExtra(Intent.EXTRA_SUBJECT, "ASRKB Debug Log") + addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION) + } + return ShareIntentResult.Success(intent, name) + } catch (t: Throwable) { + Log.e(TAG, "Error building share intent", t) + return ShareIntentResult.Error(ShareError.Failed) + } } - } - private fun writeLine(line: String) { - try { - val out = output ?: return - // 尺寸控制:超过上限则截断为末尾 2MB - val f = logFile - if (f != null && f.length() > MAX_BYTES) { - truncateKeepTail(f, keepBytes = 2L * 1024L * 1024L) - } - out.write(line.toByteArray(Charsets.UTF_8)) - out.write('\n'.code) - out.flush() - } catch (e: Throwable) { - Log.e(TAG, "Error writing log line", e) + private fun writeLine(line: String) { + try { + val out = output ?: return + // 尺寸控制:超过上限则截断为末尾 2MB + val f = logFile + if (f != null && f.length() > MAX_BYTES) { + truncateKeepTail(f, keepBytes = 2L * 1024L * 1024L) + } + out.write(line.toByteArray(Charsets.UTF_8)) + out.write('\n'.code) + out.flush() + } catch (e: Throwable) { + Log.e(TAG, "Error writing log line", e) + } } - } - private fun truncateKeepTail(file: File, keepBytes: Long) { - try { - val size = file.length() - if (size <= keepBytes) return - val tmp = File(file.parentFile, file.name + ".tmp") - RandomAccessFile(file, "r").use { raf -> - raf.seek(size - keepBytes) - FileOutputStream(tmp).use { outs -> - val buf = ByteArray(32 * 1024) - while (true) { - val n = raf.read(buf) - if (n <= 0) break - outs.write(buf, 0, n) - } + private fun truncateKeepTail(file: File, keepBytes: Long) { + try { + val size = file.length() + if (size <= keepBytes) return + val tmp = File(file.parentFile, file.name + ".tmp") + RandomAccessFile(file, "r").use { raf -> + raf.seek(size - keepBytes) + FileOutputStream(tmp).use { outs -> + val buf = ByteArray(32 * 1024) + while (true) { + val n = raf.read(buf) + if (n <= 0) break + outs.write(buf, 0, n) + } + } + } + if (!file.delete()) { + Log.w(TAG, "Failed to delete original during truncate") + } + if (!tmp.renameTo(file)) { + Log.w(TAG, "Failed to rename tmp during truncate") + } + } catch (e: Throwable) { + Log.e(TAG, "Error truncating log file", e) } - } - if (!file.delete()) { - Log.w(TAG, "Failed to delete original during truncate") - } - if (!tmp.renameTo(file)) { - Log.w(TAG, "Failed to rename tmp during truncate") - } - } catch (e: Throwable) { - Log.e(TAG, "Error truncating log file", e) } - } - private fun buildJsonLine(category: String, event: String, data: Map): String { - val sb = StringBuilder(128) - sb.append('{') - appendField(sb, "ts", isoNow()); sb.append(',') - appendField(sb, "cat", category); sb.append(',') - appendField(sb, "evt", event) - for ((k, v) in data) { - sb.append(',') - appendFieldName(sb, k) - sb.append(':') - appendValue(sb, v) + private fun buildJsonLine(category: String, event: String, data: Map): String { + val sb = StringBuilder(128) + sb.append('{') + appendField(sb, "ts", isoNow()) + sb.append(',') + appendField(sb, "cat", category) + sb.append(',') + appendField(sb, "evt", event) + for ((k, v) in data) { + sb.append(',') + appendFieldName(sb, k) + sb.append(':') + appendValue(sb, v) + } + sb.append('}') + return sb.toString() } - sb.append('}') - return sb.toString() - } - private fun appendField(sb: StringBuilder, name: String, value: String) { - appendFieldName(sb, name) - sb.append(':') - appendString(sb, value) - } + private fun appendField(sb: StringBuilder, name: String, value: String) { + appendFieldName(sb, name) + sb.append(':') + appendString(sb, value) + } private fun appendFieldName(sb: StringBuilder, name: String) { - appendString(sb, name) - } + appendString(sb, name) + } - private fun appendValue(sb: StringBuilder, v: Any?) { - when (v) { - null -> sb.append("null") - is Number, is Boolean -> sb.append(v.toString()) - else -> appendString(sb, v.toString()) + private fun appendValue(sb: StringBuilder, v: Any?) { + when (v) { + null -> sb.append("null") + is Number, is Boolean -> sb.append(v.toString()) + else -> appendString(sb, v.toString()) + } } - } - private fun appendString(sb: StringBuilder, s: String) { - sb.append('"') - for (ch in s) { - when (ch) { - '\\' -> sb.append("\\\\") - '"' -> sb.append("\\\"") - '\n' -> sb.append("\\n") - '\r' -> sb.append("\\r") - '\t' -> sb.append("\\t") - else -> sb.append(ch) - } + private fun appendString(sb: StringBuilder, s: String) { + sb.append('"') + for (ch in s) { + when (ch) { + '\\' -> sb.append("\\\\") + '"' -> sb.append("\\\"") + '\n' -> sb.append("\\n") + '\r' -> sb.append("\\r") + '\t' -> sb.append("\\t") + else -> sb.append(ch) + } + } + sb.append('"') } - sb.append('"') - } - private fun isoNow(): String { - return try { - val sdf = SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSX", Locale.US) - sdf.format(Date()) - } catch (_: Throwable) { - System.currentTimeMillis().toString() + private fun isoNow(): String { + return try { + val sdf = SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSX", Locale.US) + sdf.format(Date()) + } catch (_: Throwable) { + System.currentTimeMillis().toString() + } } - } - private fun safeFingerprint(): String { - return try { - Build.FINGERPRINT.take(24) - } catch (_: Throwable) { - "" + private fun safeFingerprint(): String { + return try { + Build.FINGERPRINT.take(24) + } catch (_: Throwable) { + "" + } } - } - sealed class ShareIntentResult { - data class Success(val intent: Intent, val displayName: String) : ShareIntentResult() - data class Error(val error: ShareError) : ShareIntentResult() - } + sealed class ShareIntentResult { + data class Success(val intent: Intent, val displayName: String) : ShareIntentResult() + data class Error(val error: ShareError) : ShareIntentResult() + } - enum class ShareError { RecordingActive, NoLog, Failed } + enum class ShareError { RecordingActive, NoLog, Failed } } - diff --git a/app/src/main/java/com/brycewg/asrkb/ui/AsrAccessibilityService.kt b/app/src/main/java/com/brycewg/asrkb/ui/AsrAccessibilityService.kt index 72d87cb4..5bd4a102 100644 --- a/app/src/main/java/com/brycewg/asrkb/ui/AsrAccessibilityService.kt +++ b/app/src/main/java/com/brycewg/asrkb/ui/AsrAccessibilityService.kt @@ -5,6 +5,7 @@ import android.content.ClipData import android.content.ClipboardManager import android.content.Context import android.content.Intent +import android.graphics.Rect import android.os.Bundle import android.os.Handler import android.os.Looper @@ -13,12 +14,11 @@ import android.view.accessibility.AccessibilityEvent import android.view.accessibility.AccessibilityNodeInfo import android.view.accessibility.AccessibilityWindowInfo import android.widget.Toast -import android.graphics.Rect -import com.brycewg.asrkb.store.Prefs import com.brycewg.asrkb.LocaleHelper +import com.brycewg.asrkb.store.Prefs +import com.brycewg.asrkb.store.debug.DebugLogManager import com.brycewg.asrkb.ui.floating.FloatingAsrService import com.brycewg.asrkb.ui.floating.FloatingImeHints -import com.brycewg.asrkb.store.debug.DebugLogManager /** * 无障碍服务,用于悬浮球语音识别后将文本插入到当前焦点的输入框中 @@ -37,7 +37,7 @@ class AsrAccessibilityService : AccessibilityService() { */ data class FocusContext( val prefix: String, - val suffix: String + val suffix: String, ) companion object { @@ -68,7 +68,7 @@ class AsrAccessibilityService : AccessibilityService() { FocusContext( prefix = full.substring(0, s), - suffix = full.substring(e, full.length) + suffix = full.substring(e, full.length), ) } } @@ -262,11 +262,11 @@ class AsrAccessibilityService : AccessibilityService() { */ private fun isRelevantEventType(eventType: Int): Boolean { return eventType == AccessibilityEvent.TYPE_WINDOW_STATE_CHANGED || - eventType == AccessibilityEvent.TYPE_WINDOW_CONTENT_CHANGED || - eventType == AccessibilityEvent.TYPE_VIEW_FOCUSED || - eventType == AccessibilityEvent.TYPE_VIEW_TEXT_SELECTION_CHANGED || - eventType == AccessibilityEvent.TYPE_VIEW_TEXT_CHANGED || - eventType == AccessibilityEvent.TYPE_WINDOWS_CHANGED + eventType == AccessibilityEvent.TYPE_WINDOW_CONTENT_CHANGED || + eventType == AccessibilityEvent.TYPE_VIEW_FOCUSED || + eventType == AccessibilityEvent.TYPE_VIEW_TEXT_SELECTION_CHANGED || + eventType == AccessibilityEvent.TYPE_VIEW_TEXT_CHANGED || + eventType == AccessibilityEvent.TYPE_WINDOWS_CHANGED } private fun tryDispatchImeVisibilityHint() { @@ -291,7 +291,11 @@ class AsrAccessibilityService : AccessibilityService() { val now = System.currentTimeMillis() if (now - lastA11yAggEmitAt >= 1000L) { lastA11yAggEmitAt = now - val pkg = try { getActiveWindowPackage() } catch (_: Throwable) { null } ?: "" + val pkg = try { + getActiveWindowPackage() + } catch (_: Throwable) { + null + } ?: "" val d = mapOf( "pkgTop" to pkg, "winStateChanged" to aggWinStateChanged, @@ -299,7 +303,7 @@ class AsrAccessibilityService : AccessibilityService() { "viewFocused" to aggViewFocused, "textSelChanged" to aggTextSelChanged, "textChanged" to aggTextChanged, - "windowsChanged" to aggWindowsChanged + "windowsChanged" to aggWindowsChanged, ) DebugLogManager.log("a11y", "events", d) aggWinStateChanged = 0 @@ -330,7 +334,6 @@ class AsrAccessibilityService : AccessibilityService() { return mWindow || hold } - /** * 更新输入法可见性状态,并在状态变化时通知相关服务。 */ @@ -343,8 +346,8 @@ class AsrAccessibilityService : AccessibilityService() { event = if (active) "scene_active" else "scene_inactive", data = mapOf( "by" to "a11y", - "pkg" to (getActiveWindowPackage() ?: "") - ) + "pkg" to (getActiveWindowPackage() ?: ""), + ), ) // 附带一次决策解释 try { @@ -372,7 +375,7 @@ class AsrAccessibilityService : AccessibilityService() { "holdByFocus" to holdByFocus, "strategyUsed" to strategyUsed, "activePkg" to (activePkg ?: ""), - "resultActive" to resultActive + "resultActive" to resultActive, ) } @@ -424,16 +427,37 @@ class AsrAccessibilityService : AccessibilityService() { if (target != null) { Log.d(TAG, "Found editable/focusable node; trying ACTION_SET_TEXT") try { - val nodeClass = try { target.className?.toString() } catch (_: Throwable) { null } ?: "" - val editable = try { target.isEditable } catch (_: Throwable) { false } + val nodeClass = try { + target.className?.toString() + } catch (_: Throwable) { + null + } ?: "" + val editable = try { + target.isEditable + } catch (_: Throwable) { + false + } val hasSetText = nodeHasAction(target, AccessibilityNodeInfo.ACTION_SET_TEXT) val hasPaste = nodeHasAction(target, AccessibilityNodeInfo.ACTION_PASTE) val hasLongClick = nodeHasAction(target, AccessibilityNodeInfo.ACTION_LONG_CLICK) - val textLen = try { target.text?.length ?: 0 } catch (_: Throwable) { 0 } - val selStart = try { target.textSelectionStart } catch (_: Throwable) { -1 } - val selEnd = try { target.textSelectionEnd } catch (_: Throwable) { -1 } + val textLen = try { + target.text?.length ?: 0 + } catch (_: Throwable) { + 0 + } + val selStart = try { + target.textSelectionStart + } catch (_: Throwable) { + -1 + } + val selEnd = try { + target.textSelectionEnd + } catch (_: Throwable) { + -1 + } DebugLogManager.log( - "insert", "cap", + "insert", + "cap", mapOf( "nodeClass" to nodeClass, "editable" to editable, @@ -442,8 +466,8 @@ class AsrAccessibilityService : AccessibilityService() { "hasLongClick" to hasLongClick, "textLen" to textLen, "selStart" to selStart, - "selEnd" to selEnd - ) + "selEnd" to selEnd, + ), ) } catch (t: Throwable) { Log.w(TAG, "Failed to log insert capabilities", t) @@ -498,8 +522,8 @@ class AsrAccessibilityService : AccessibilityService() { "fallback_clipboard", mapOf( "reason" to (e::class.java.simpleName), - "msg" to (e.message?.take(80) ?: "") - ) + "msg" to (e.message?.take(80) ?: ""), + ), ) copyToClipboard(this, text) Toast.makeText(this, getString(com.brycewg.asrkb.R.string.floating_asr_copied), Toast.LENGTH_SHORT).show() @@ -609,7 +633,8 @@ class AsrAccessibilityService : AccessibilityService() { private fun findFocusedEditableNode(root: AccessibilityNodeInfo): AccessibilityNodeInfo? { root.findFocus(AccessibilityNodeInfo.FOCUS_INPUT)?.let { f -> if (isEditableLike(f)) return f - @Suppress("DEPRECATION") f.recycle() + @Suppress("DEPRECATION") + f.recycle() } return findEditableNodeRecursive(root) } @@ -620,10 +645,12 @@ class AsrAccessibilityService : AccessibilityService() { val child = node.getChild(i) ?: continue val result = findEditableNodeRecursive(child) if (result != null) { - @Suppress("DEPRECATION") child.recycle() + @Suppress("DEPRECATION") + child.recycle() return result } - @Suppress("DEPRECATION") child.recycle() + @Suppress("DEPRECATION") + child.recycle() } return null } @@ -709,7 +736,8 @@ class AsrAccessibilityService : AccessibilityService() { val root = w.root ?: continue val node = findFocusedEditableNode(root) if (node != null) { - @Suppress("DEPRECATION") node.recycle() + @Suppress("DEPRECATION") + node.recycle() return true } } catch (t: Throwable) { @@ -723,7 +751,8 @@ class AsrAccessibilityService : AccessibilityService() { val node = findFocusedEditableNode(root) val ok = node != null if (node != null) { - @Suppress("DEPRECATION") node.recycle() + @Suppress("DEPRECATION") + node.recycle() } ok } catch (e: Throwable) { @@ -743,7 +772,9 @@ class AsrAccessibilityService : AccessibilityService() { val root = w.root if (root != null) return true val r = Rect() - try { w.getBoundsInScreen(r) } catch (_: Throwable) {} + try { + w.getBoundsInScreen(r) + } catch (_: Throwable) {} if (r.width() > 0 && r.height() > 0) return true if (w.isActive || w.isFocused) return true } diff --git a/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorTag.kt b/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorTag.kt index 98d3b1c1..459dffb1 100644 --- a/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorTag.kt +++ b/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorTag.kt @@ -5,48 +5,48 @@ import androidx.annotation.StringRes import com.brycewg.asrkb.R enum class AsrVendorTag( - @StringRes val labelResId: Int, - @ColorRes val bgColorResId: Int, - @ColorRes val textColorResId: Int + @StringRes val labelResId: Int, + @ColorRes val bgColorResId: Int, + @ColorRes val textColorResId: Int, ) { - Online( - labelResId = R.string.asr_vendor_tag_online, - bgColorResId = R.color.asr_tag_bg_online, - textColorResId = R.color.asr_tag_fg_online - ), - Local( - labelResId = R.string.asr_vendor_tag_local, - bgColorResId = R.color.asr_tag_bg_local, - textColorResId = R.color.asr_tag_fg_local - ), - Streaming( - labelResId = R.string.asr_vendor_tag_streaming, - bgColorResId = R.color.asr_tag_bg_streaming, - textColorResId = R.color.asr_tag_fg_streaming - ), - NonStreaming( - labelResId = R.string.asr_vendor_tag_non_streaming, - bgColorResId = R.color.asr_tag_bg_non_streaming, - textColorResId = R.color.asr_tag_fg_non_streaming - ), - PseudoStreaming( - labelResId = R.string.asr_vendor_tag_pseudo_streaming, - bgColorResId = R.color.asr_tag_bg_pseudo_streaming, - textColorResId = R.color.asr_tag_fg_pseudo_streaming - ), - Custom( - labelResId = R.string.asr_vendor_tag_custom, - bgColorResId = R.color.asr_tag_bg_custom, - textColorResId = R.color.asr_tag_fg_custom - ), - ChineseDialect( - labelResId = R.string.asr_vendor_tag_chinese_dialect, - bgColorResId = R.color.asr_tag_bg_cn_dialect, - textColorResId = R.color.asr_tag_fg_cn_dialect - ), - Accurate( - labelResId = R.string.asr_vendor_tag_accurate, - bgColorResId = R.color.asr_tag_bg_accurate, - textColorResId = R.color.asr_tag_fg_accurate - ), + Online( + labelResId = R.string.asr_vendor_tag_online, + bgColorResId = R.color.asr_tag_bg_online, + textColorResId = R.color.asr_tag_fg_online, + ), + Local( + labelResId = R.string.asr_vendor_tag_local, + bgColorResId = R.color.asr_tag_bg_local, + textColorResId = R.color.asr_tag_fg_local, + ), + Streaming( + labelResId = R.string.asr_vendor_tag_streaming, + bgColorResId = R.color.asr_tag_bg_streaming, + textColorResId = R.color.asr_tag_fg_streaming, + ), + NonStreaming( + labelResId = R.string.asr_vendor_tag_non_streaming, + bgColorResId = R.color.asr_tag_bg_non_streaming, + textColorResId = R.color.asr_tag_fg_non_streaming, + ), + PseudoStreaming( + labelResId = R.string.asr_vendor_tag_pseudo_streaming, + bgColorResId = R.color.asr_tag_bg_pseudo_streaming, + textColorResId = R.color.asr_tag_fg_pseudo_streaming, + ), + Custom( + labelResId = R.string.asr_vendor_tag_custom, + bgColorResId = R.color.asr_tag_bg_custom, + textColorResId = R.color.asr_tag_fg_custom, + ), + ChineseDialect( + labelResId = R.string.asr_vendor_tag_chinese_dialect, + bgColorResId = R.color.asr_tag_bg_cn_dialect, + textColorResId = R.color.asr_tag_fg_cn_dialect, + ), + Accurate( + labelResId = R.string.asr_vendor_tag_accurate, + bgColorResId = R.color.asr_tag_bg_accurate, + textColorResId = R.color.asr_tag_fg_accurate, + ), } diff --git a/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorUi.kt b/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorUi.kt index 566cab95..a519e0ed 100644 --- a/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorUi.kt +++ b/app/src/main/java/com/brycewg/asrkb/ui/AsrVendorUi.kt @@ -8,113 +8,113 @@ import com.brycewg.asrkb.asr.AsrVendor * 统一提供 ASR 供应商的顺序与显示名,避免各处硬编码 listOf。 */ object AsrVendorUi { - /** 固定的供应商顺序(设置页/菜单统一使用) */ - fun ordered(): List = listOf( - AsrVendor.SiliconFlow, - AsrVendor.Volc, - AsrVendor.ElevenLabs, - AsrVendor.OpenAI, - AsrVendor.DashScope, - AsrVendor.Gemini, - AsrVendor.Soniox, - AsrVendor.Zhipu, - AsrVendor.SenseVoice, - AsrVendor.FunAsrNano, - AsrVendor.Telespeech, - AsrVendor.Paraformer - ) + /** 固定的供应商顺序(设置页/菜单统一使用) */ + fun ordered(): List = listOf( + AsrVendor.SiliconFlow, + AsrVendor.Volc, + AsrVendor.ElevenLabs, + AsrVendor.OpenAI, + AsrVendor.DashScope, + AsrVendor.Gemini, + AsrVendor.Soniox, + AsrVendor.Zhipu, + AsrVendor.SenseVoice, + AsrVendor.FunAsrNano, + AsrVendor.Telespeech, + AsrVendor.Paraformer, + ) - /** 指定 vendor 的多语言显示名 */ - fun name(context: Context, v: AsrVendor): String = when (v) { - AsrVendor.Volc -> context.getString(R.string.vendor_volc) - AsrVendor.SiliconFlow -> context.getString(R.string.vendor_sf) - AsrVendor.ElevenLabs -> context.getString(R.string.vendor_eleven) - AsrVendor.OpenAI -> context.getString(R.string.vendor_openai) - AsrVendor.DashScope -> context.getString(R.string.vendor_dashscope) - AsrVendor.Gemini -> context.getString(R.string.vendor_gemini) - AsrVendor.Soniox -> context.getString(R.string.vendor_soniox) - AsrVendor.Zhipu -> context.getString(R.string.vendor_zhipu) - AsrVendor.SenseVoice -> context.getString(R.string.vendor_sensevoice) - AsrVendor.FunAsrNano -> context.getString(R.string.vendor_funasr_nano) - AsrVendor.Telespeech -> context.getString(R.string.vendor_telespeech) - AsrVendor.Paraformer -> context.getString(R.string.vendor_paraformer) - } + /** 指定 vendor 的多语言显示名 */ + fun name(context: Context, v: AsrVendor): String = when (v) { + AsrVendor.Volc -> context.getString(R.string.vendor_volc) + AsrVendor.SiliconFlow -> context.getString(R.string.vendor_sf) + AsrVendor.ElevenLabs -> context.getString(R.string.vendor_eleven) + AsrVendor.OpenAI -> context.getString(R.string.vendor_openai) + AsrVendor.DashScope -> context.getString(R.string.vendor_dashscope) + AsrVendor.Gemini -> context.getString(R.string.vendor_gemini) + AsrVendor.Soniox -> context.getString(R.string.vendor_soniox) + AsrVendor.Zhipu -> context.getString(R.string.vendor_zhipu) + AsrVendor.SenseVoice -> context.getString(R.string.vendor_sensevoice) + AsrVendor.FunAsrNano -> context.getString(R.string.vendor_funasr_nano) + AsrVendor.Telespeech -> context.getString(R.string.vendor_telespeech) + AsrVendor.Paraformer -> context.getString(R.string.vendor_paraformer) + } - /** 指定 vendor 的标签(用于选择器展示;可后续按需调整) */ - fun tags(v: AsrVendor): List = when (v) { - AsrVendor.SiliconFlow -> listOf( - AsrVendorTag.Online, - AsrVendorTag.NonStreaming, - AsrVendorTag.Custom - ) - AsrVendor.Volc -> listOf( - AsrVendorTag.Online, - AsrVendorTag.Streaming, - AsrVendorTag.NonStreaming, - AsrVendorTag.ChineseDialect, - AsrVendorTag.Accurate - ) - AsrVendor.ElevenLabs -> listOf( - AsrVendorTag.Online, - AsrVendorTag.Streaming, - AsrVendorTag.NonStreaming - ) - AsrVendor.OpenAI -> listOf( - AsrVendorTag.Online, - AsrVendorTag.NonStreaming, - AsrVendorTag.Custom - ) - AsrVendor.DashScope -> listOf( - AsrVendorTag.Online, - AsrVendorTag.Streaming, - AsrVendorTag.NonStreaming, - AsrVendorTag.ChineseDialect, - AsrVendorTag.Accurate - ) - AsrVendor.Gemini -> listOf( - AsrVendorTag.Online, - AsrVendorTag.NonStreaming, - AsrVendorTag.Accurate, - AsrVendorTag.Custom - ) - AsrVendor.Soniox -> listOf( - AsrVendorTag.Online, - AsrVendorTag.Streaming, - AsrVendorTag.NonStreaming, - AsrVendorTag.Accurate - ) - AsrVendor.Zhipu -> listOf( - AsrVendorTag.Online, - AsrVendorTag.NonStreaming, - AsrVendorTag.ChineseDialect - ) - AsrVendor.SenseVoice -> listOf( - AsrVendorTag.Local, - AsrVendorTag.NonStreaming, - AsrVendorTag.PseudoStreaming - ) - AsrVendor.FunAsrNano -> listOf( - AsrVendorTag.Local, - AsrVendorTag.NonStreaming, - AsrVendorTag.ChineseDialect, - AsrVendorTag.Accurate - ) - AsrVendor.Telespeech -> listOf( - AsrVendorTag.Local, - AsrVendorTag.NonStreaming, - AsrVendorTag.PseudoStreaming, - AsrVendorTag.ChineseDialect - ) - AsrVendor.Paraformer -> listOf( - AsrVendorTag.Local, - AsrVendorTag.Streaming, - AsrVendorTag.Accurate - ) - } + /** 指定 vendor 的标签(用于选择器展示;可后续按需调整) */ + fun tags(v: AsrVendor): List = when (v) { + AsrVendor.SiliconFlow -> listOf( + AsrVendorTag.Online, + AsrVendorTag.NonStreaming, + AsrVendorTag.Custom, + ) + AsrVendor.Volc -> listOf( + AsrVendorTag.Online, + AsrVendorTag.Streaming, + AsrVendorTag.NonStreaming, + AsrVendorTag.ChineseDialect, + AsrVendorTag.Accurate, + ) + AsrVendor.ElevenLabs -> listOf( + AsrVendorTag.Online, + AsrVendorTag.Streaming, + AsrVendorTag.NonStreaming, + ) + AsrVendor.OpenAI -> listOf( + AsrVendorTag.Online, + AsrVendorTag.NonStreaming, + AsrVendorTag.Custom, + ) + AsrVendor.DashScope -> listOf( + AsrVendorTag.Online, + AsrVendorTag.Streaming, + AsrVendorTag.NonStreaming, + AsrVendorTag.ChineseDialect, + AsrVendorTag.Accurate, + ) + AsrVendor.Gemini -> listOf( + AsrVendorTag.Online, + AsrVendorTag.NonStreaming, + AsrVendorTag.Accurate, + AsrVendorTag.Custom, + ) + AsrVendor.Soniox -> listOf( + AsrVendorTag.Online, + AsrVendorTag.Streaming, + AsrVendorTag.NonStreaming, + AsrVendorTag.Accurate, + ) + AsrVendor.Zhipu -> listOf( + AsrVendorTag.Online, + AsrVendorTag.NonStreaming, + AsrVendorTag.ChineseDialect, + ) + AsrVendor.SenseVoice -> listOf( + AsrVendorTag.Local, + AsrVendorTag.NonStreaming, + AsrVendorTag.PseudoStreaming, + ) + AsrVendor.FunAsrNano -> listOf( + AsrVendorTag.Local, + AsrVendorTag.NonStreaming, + AsrVendorTag.ChineseDialect, + AsrVendorTag.Accurate, + ) + AsrVendor.Telespeech -> listOf( + AsrVendorTag.Local, + AsrVendorTag.NonStreaming, + AsrVendorTag.PseudoStreaming, + AsrVendorTag.ChineseDialect, + ) + AsrVendor.Paraformer -> listOf( + AsrVendorTag.Local, + AsrVendorTag.Streaming, + AsrVendorTag.Accurate, + ) + } - /** 顺序化的 (Vendor, 显示名) 列表 */ - fun pairs(context: Context): List> = ordered().map { it to name(context, it) } + /** 顺序化的 (Vendor, 显示名) 列表 */ + fun pairs(context: Context): List> = ordered().map { it to name(context, it) } - /** 顺序化的显示名列表 */ - fun names(context: Context): List = ordered().map { name(context, it) } + /** 顺序化的显示名列表 */ + fun names(context: Context): List = ordered().map { name(context, it) } } diff --git a/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceConfig.kt b/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceConfig.kt index 435b90fd..8a51b36b 100644 --- a/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceConfig.kt +++ b/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceConfig.kt @@ -4,42 +4,42 @@ import android.content.Context import com.brycewg.asrkb.R object DownloadSourceConfig { - private data class Mirror(val labelRes: Int, val prefix: String) + private data class Mirror(val labelRes: Int, val prefix: String) - private val mirrors = listOf( - Mirror(R.string.download_source_mirror_1, "https://ghproxy.net/"), - Mirror(R.string.download_source_mirror_2, "https://hub.gitmirror.com/"), - Mirror(R.string.download_source_mirror_3, "https://fastgit.cc/") - ) - - fun buildOptions( - context: Context, - officialUrl: String, - officialLabelRes: Int = R.string.download_source_github_official - ): List { - val options = ArrayList(mirrors.size + 1) - options.add( - DownloadSourceDialog.Option( - context.getString(officialLabelRes), - officialUrl - ) + private val mirrors = listOf( + Mirror(R.string.download_source_mirror_1, "https://ghproxy.net/"), + Mirror(R.string.download_source_mirror_2, "https://hub.gitmirror.com/"), + Mirror(R.string.download_source_mirror_3, "https://fastgit.cc/"), ) - mirrors.forEach { mirror -> - options.add( - DownloadSourceDialog.Option( - context.getString(mirror.labelRes), - applyMirrorPrefix(officialUrl, mirror.prefix) + + fun buildOptions( + context: Context, + officialUrl: String, + officialLabelRes: Int = R.string.download_source_github_official, + ): List { + val options = ArrayList(mirrors.size + 1) + options.add( + DownloadSourceDialog.Option( + context.getString(officialLabelRes), + officialUrl, + ), ) - ) + mirrors.forEach { mirror -> + options.add( + DownloadSourceDialog.Option( + context.getString(mirror.labelRes), + applyMirrorPrefix(officialUrl, mirror.prefix), + ), + ) + } + return options } - return options - } - private fun applyMirrorPrefix(originalUrl: String, mirrorPrefix: String): String { - return if (originalUrl.startsWith("https://github.com/")) { - mirrorPrefix + originalUrl - } else { - originalUrl + private fun applyMirrorPrefix(originalUrl: String, mirrorPrefix: String): String { + return if (originalUrl.startsWith("https://github.com/")) { + mirrorPrefix + originalUrl + } else { + originalUrl + } } - } } diff --git a/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceDialog.kt b/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceDialog.kt index 7e3c3611..b2210434 100644 --- a/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceDialog.kt +++ b/app/src/main/java/com/brycewg/asrkb/ui/DownloadSourceDialog.kt @@ -25,176 +25,176 @@ import java.net.Socket import java.net.SocketTimeoutException object DownloadSourceDialog { - private const val TAG = "DownloadSourceDialog" - private const val LATENCY_TIMEOUT_MS = 3000 - - data class Option(val label: String, val url: String) - - fun show( - context: Context, - titleRes: Int, - options: List