diff --git a/Yolo/android/app/src/main/AndroidManifest.xml b/Yolo/android/app/src/main/AndroidManifest.xml
index e05f5de4..5e4a6f0c 100644
--- a/Yolo/android/app/src/main/AndroidManifest.xml
+++ b/Yolo/android/app/src/main/AndroidManifest.xml
@@ -5,6 +5,7 @@
+
@@ -21,9 +22,8 @@
tools:targetApi="34">
@@ -31,6 +31,13 @@
+
+
+
diff --git a/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionActivity.java b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionActivity.java
index da78d11f..0ad59d5f 100644
--- a/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionActivity.java
+++ b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionActivity.java
@@ -81,7 +81,8 @@ protected void onCreate(Bundle savedInstanceState) {
// Initialize bird detection pipeline
try {
- birdPipeline = new BirdDetectionPipeline(this);
+ String modelDir = getIntent().getStringExtra("model_dir");
+ birdPipeline = new BirdDetectionPipeline(this, modelDir);
Log.d(TAG, "Bird detection pipeline initialized successfully");
} catch (Exception e) {
Log.e(TAG, "Failed to initialize bird pipeline", e);
diff --git a/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionPipeline.java b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionPipeline.java
index ac729d68..dcc66cff 100644
--- a/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionPipeline.java
+++ b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/BirdDetectionPipeline.java
@@ -115,9 +115,21 @@ private String generateLocationKey(RectF box) {
}
public BirdDetectionPipeline(Context context) throws IOException {
+ this(context, null);
+ }
+
+ public BirdDetectionPipeline(Context context, String modelDir) throws IOException {
try {
- String yoloPath = Utils.assetFilePath(context, "yolo_detector.pte");
- String classifierPath = Utils.assetFilePath(context, "bird_classifier.pte");
+ String yoloPath;
+ String classifierPath;
+
+ if (modelDir != null) {
+ yoloPath = modelDir + "/yolo_detector.pte";
+ classifierPath = modelDir + "/bird_classifier.pte";
+ } else {
+ yoloPath = Utils.assetFilePath(context, "yolo_detector.pte");
+ classifierPath = Utils.assetFilePath(context, "bird_classifier.pte");
+ }
yoloModule = Module.load(yoloPath);
classifierModule = Module.load(classifierPath);
diff --git a/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadActivity.kt b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadActivity.kt
new file mode 100644
index 00000000..45c9960c
--- /dev/null
+++ b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadActivity.kt
@@ -0,0 +1,48 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchyolodemo
+
+import android.content.Intent
+import android.os.Bundle
+import androidx.activity.ComponentActivity
+import androidx.activity.compose.setContent
+import androidx.activity.viewModels
+import androidx.compose.material3.MaterialTheme
+
+class ModelDownloadActivity : ComponentActivity() {
+
+ private val downloadViewModel: ModelDownloadViewModel by viewModels()
+
+ override fun onCreate(savedInstanceState: Bundle?) {
+ super.onCreate(savedInstanceState)
+ downloadViewModel.initialize(filesDir.absolutePath)
+
+ // Skip download screen if all models already downloaded
+ if (downloadViewModel.allModelsDownloaded()) {
+ launchBirdDetection()
+ return
+ }
+
+ setContent {
+ MaterialTheme {
+ ModelDownloadScreen(
+ downloadViewModel = downloadViewModel,
+ onDownloadComplete = { launchBirdDetection() }
+ )
+ }
+ }
+ }
+
+ private fun launchBirdDetection() {
+ val intent = Intent(this, BirdDetectionActivity::class.java)
+ intent.putExtra("model_dir", downloadViewModel.getModelDir())
+ startActivity(intent)
+ finish()
+ }
+}
diff --git a/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadScreen.kt b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadScreen.kt
new file mode 100644
index 00000000..f7c4f817
--- /dev/null
+++ b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadScreen.kt
@@ -0,0 +1,186 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchyolodemo
+
+import androidx.compose.foundation.layout.Arrangement
+import androidx.compose.foundation.layout.Column
+import androidx.compose.foundation.layout.Row
+import androidx.compose.foundation.layout.Spacer
+import androidx.compose.foundation.layout.fillMaxSize
+import androidx.compose.foundation.layout.fillMaxWidth
+import androidx.compose.foundation.layout.height
+import androidx.compose.foundation.layout.padding
+import androidx.compose.foundation.layout.size
+import androidx.compose.foundation.layout.width
+import androidx.compose.material3.Button
+import androidx.compose.material3.CircularProgressIndicator
+import androidx.compose.material3.LinearProgressIndicator
+import androidx.compose.material3.MaterialTheme
+import androidx.compose.material3.Surface
+import androidx.compose.material3.Text
+import androidx.compose.runtime.Composable
+import androidx.compose.ui.Alignment
+import androidx.compose.ui.Modifier
+import androidx.compose.ui.unit.dp
+
+@Composable
+fun ModelDownloadScreen(
+ downloadViewModel: ModelDownloadViewModel,
+ onDownloadComplete: () -> Unit
+) {
+ val status = downloadViewModel.downloadStatus
+ val progress = downloadViewModel.downloadProgress
+ val currentFileName = downloadViewModel.currentFileName
+ val error = downloadViewModel.errorMessage
+ val isDownloading = status == DownloadStatus.DOWNLOADING
+
+ Column(
+ modifier = Modifier
+ .fillMaxSize()
+ .padding(16.dp),
+ horizontalAlignment = Alignment.CenterHorizontally,
+ verticalArrangement = Arrangement.Center
+ ) {
+ Text(
+ text = "Bird Detection Models",
+ style = MaterialTheme.typography.headlineMedium
+ )
+
+ Spacer(modifier = Modifier.height(8.dp))
+
+ Text(
+ text = "Download the required models to get started",
+ style = MaterialTheme.typography.bodyMedium,
+ color = MaterialTheme.colorScheme.onSurfaceVariant
+ )
+
+ Spacer(modifier = Modifier.height(24.dp))
+
+ // Files list
+ Surface(
+ modifier = Modifier.fillMaxWidth(),
+ color = MaterialTheme.colorScheme.surfaceVariant,
+ shape = MaterialTheme.shapes.medium
+ ) {
+ Column(modifier = Modifier.padding(16.dp)) {
+ Text(
+ text = "Model Files",
+ style = MaterialTheme.typography.titleSmall
+ )
+ Spacer(modifier = Modifier.height(8.dp))
+
+ ModelDownloadViewModel.MODEL_FILES.forEach { fileInfo ->
+ FileStatusRow(fileInfo, downloadViewModel)
+ Spacer(modifier = Modifier.height(4.dp))
+ }
+
+ // Download progress
+ if (isDownloading) {
+ Spacer(modifier = Modifier.height(12.dp))
+ LinearProgressIndicator(
+ progress = progress,
+ modifier = Modifier.fillMaxWidth(),
+ )
+ Spacer(modifier = Modifier.height(8.dp))
+ Row(verticalAlignment = Alignment.CenterVertically) {
+ CircularProgressIndicator(modifier = Modifier.size(16.dp), strokeWidth = 2.dp)
+ Spacer(modifier = Modifier.width(8.dp))
+ Text(
+ text = "Downloading $currentFileName...",
+ style = MaterialTheme.typography.bodySmall,
+ color = MaterialTheme.colorScheme.onSurfaceVariant
+ )
+ }
+ }
+
+ if (status == DownloadStatus.COMPLETED) {
+ Spacer(modifier = Modifier.height(12.dp))
+ Text(
+ text = "All models downloaded!",
+ style = MaterialTheme.typography.bodyMedium,
+ color = MaterialTheme.colorScheme.primary
+ )
+ }
+
+ if (error != null) {
+ Spacer(modifier = Modifier.height(12.dp))
+ Text(
+ text = error,
+ style = MaterialTheme.typography.bodySmall,
+ color = MaterialTheme.colorScheme.error
+ )
+ }
+ }
+ }
+
+ Spacer(modifier = Modifier.height(24.dp))
+
+ when (status) {
+ DownloadStatus.NOT_STARTED, DownloadStatus.FAILED -> {
+ Button(
+ onClick = { downloadViewModel.downloadModels() },
+ modifier = Modifier.fillMaxWidth()
+ ) {
+ Text(if (status == DownloadStatus.FAILED) "Retry Download" else "Download Models")
+ }
+ }
+
+ DownloadStatus.DOWNLOADING -> {
+ Button(
+ onClick = {},
+ enabled = false,
+ modifier = Modifier.fillMaxWidth()
+ ) {
+ Text("Downloading...")
+ }
+ }
+
+ DownloadStatus.COMPLETED -> {
+ Button(
+ onClick = onDownloadComplete,
+ modifier = Modifier.fillMaxWidth()
+ ) {
+ Text("Continue")
+ }
+ }
+ }
+ }
+}
+
+@Composable
+private fun FileStatusRow(
+ fileInfo: ModelFileInfo,
+ downloadViewModel: ModelDownloadViewModel
+) {
+ val status = downloadViewModel.downloadStatus
+ val currentFileName = downloadViewModel.currentFileName
+ val fileExists = downloadViewModel.isFileDownloaded(fileInfo.filename)
+
+ Row(
+ modifier = Modifier.fillMaxWidth(),
+ horizontalArrangement = Arrangement.SpaceBetween
+ ) {
+ Column(modifier = Modifier.weight(1f)) {
+ Text(text = fileInfo.description, style = MaterialTheme.typography.bodyMedium)
+ Text(
+ text = fileInfo.filename,
+ style = MaterialTheme.typography.bodySmall,
+ color = MaterialTheme.colorScheme.onSurfaceVariant
+ )
+ }
+ Text(
+ text = when {
+ fileExists || status == DownloadStatus.COMPLETED -> "✓"
+ status == DownloadStatus.DOWNLOADING && currentFileName == fileInfo.filename -> "⬇"
+ else -> "○"
+ },
+ style = MaterialTheme.typography.bodyMedium
+ )
+ }
+}
diff --git a/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadViewModel.kt b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadViewModel.kt
new file mode 100644
index 00000000..e6288cc3
--- /dev/null
+++ b/Yolo/android/app/src/main/java/com/example/executorchyolodemo/ModelDownloadViewModel.kt
@@ -0,0 +1,189 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package com.example.executorchyolodemo
+
+import android.util.Log
+import androidx.compose.runtime.getValue
+import androidx.compose.runtime.mutableStateOf
+import androidx.compose.runtime.setValue
+import androidx.lifecycle.ViewModel
+import androidx.lifecycle.viewModelScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+import java.io.File
+import java.io.FileOutputStream
+import java.net.HttpURLConnection
+import java.net.URL
+
+data class ModelFileInfo(
+ val url: String,
+ val filename: String,
+ val description: String
+)
+
+enum class DownloadStatus {
+ NOT_STARTED,
+ DOWNLOADING,
+ COMPLETED,
+ FAILED
+}
+
+class ModelDownloadViewModel : ViewModel() {
+
+ companion object {
+ private const val TAG = "ModelDownloadViewModel"
+ const val MODELS_SUBDIRECTORY = "yolo"
+
+ private const val YOLO_DETECTOR_URL =
+ "https://huggingface.co/larryliu0820/yolo26s-ExecuTorch-XNNPACK/resolve/main/yolo26s_dynamic_xnnpack.pte"
+ private const val BIRD_CLASSIFIER_URL =
+ "https://huggingface.co/psiddh/bird-classifier-executorch/resolve/main/bird_classifier.pte"
+
+ val MODEL_FILES = listOf(
+ ModelFileInfo(
+ url = YOLO_DETECTOR_URL,
+ filename = "yolo_detector.pte",
+ description = "YOLO Bird Detector"
+ ),
+ ModelFileInfo(
+ url = BIRD_CLASSIFIER_URL,
+ filename = "bird_classifier.pte",
+ description = "Bird Species Classifier"
+ )
+ )
+ }
+
+ var downloadStatus by mutableStateOf(DownloadStatus.NOT_STARTED)
+ private set
+
+ var downloadProgress by mutableStateOf(0f)
+ private set
+
+ var currentFileIndex by mutableStateOf(0)
+ private set
+
+ var totalFileCount by mutableStateOf(0)
+ private set
+
+ var currentFileName by mutableStateOf("")
+ private set
+
+ var errorMessage by mutableStateOf(null)
+ private set
+
+ private lateinit var modelsDir: String
+
+ fun initialize(filesDir: String) {
+ modelsDir = filesDir + "/" + MODELS_SUBDIRECTORY
+ }
+
+ fun getModelDir(): String = modelsDir
+
+ fun allModelsDownloaded(): Boolean {
+ return MODEL_FILES.all { File("$modelsDir/${it.filename}").exists() }
+ }
+
+ fun isFileDownloaded(filename: String): Boolean {
+ return File("$modelsDir/$filename").exists()
+ }
+
+ fun downloadModels() {
+ if (downloadStatus == DownloadStatus.DOWNLOADING) return
+
+ val filesToDownload = MODEL_FILES.filter {
+ !File("$modelsDir/${it.filename}").exists()
+ }
+
+ if (filesToDownload.isEmpty()) {
+ downloadStatus = DownloadStatus.COMPLETED
+ return
+ }
+
+ downloadStatus = DownloadStatus.DOWNLOADING
+ downloadProgress = 0f
+ currentFileIndex = 0
+ totalFileCount = filesToDownload.size
+ errorMessage = null
+
+ viewModelScope.launch {
+ try {
+ val dir = File(modelsDir)
+ if (!dir.exists()) {
+ dir.mkdirs()
+ }
+
+ for ((index, fileInfo) in filesToDownload.withIndex()) {
+ currentFileIndex = index
+ currentFileName = fileInfo.filename
+ val targetFile = File("$modelsDir/${fileInfo.filename}")
+
+ val success = downloadFile(fileInfo, targetFile)
+ if (!success) {
+ downloadStatus = DownloadStatus.FAILED
+ return@launch
+ }
+ downloadProgress = (index + 1).toFloat() / filesToDownload.size
+ }
+
+ downloadStatus = DownloadStatus.COMPLETED
+ } catch (e: Exception) {
+ Log.e(TAG, "Download failed", e)
+ downloadStatus = DownloadStatus.FAILED
+ errorMessage = "Download failed: ${e.message}"
+ }
+ }
+ }
+
+ fun resetStatus() {
+ downloadStatus = DownloadStatus.NOT_STARTED
+ errorMessage = null
+ }
+
+ private suspend fun downloadFile(
+ fileInfo: ModelFileInfo,
+ targetFile: File
+ ): Boolean = withContext(Dispatchers.IO) {
+ try {
+ Log.i(TAG, "Downloading ${fileInfo.filename} from ${fileInfo.url}")
+ val url = URL(fileInfo.url)
+ val connection = url.openConnection() as HttpURLConnection
+ connection.requestMethod = "GET"
+ connection.instanceFollowRedirects = true
+ connection.connectTimeout = 30000
+ connection.readTimeout = 30000
+ connection.connect()
+
+ if (connection.responseCode != HttpURLConnection.HTTP_OK) {
+ throw Exception("Server returned HTTP ${connection.responseCode}")
+ }
+
+ val tempFile = File(targetFile.absolutePath + ".tmp")
+ connection.inputStream.use { input ->
+ FileOutputStream(tempFile).use { output ->
+ val buffer = ByteArray(8192)
+ var bytesRead: Int
+ while (input.read(buffer).also { bytesRead = it } != -1) {
+ output.write(buffer, 0, bytesRead)
+ }
+ }
+ }
+ tempFile.renameTo(targetFile)
+
+ Log.i(TAG, "Downloaded ${fileInfo.filename} successfully")
+ true
+ } catch (e: Exception) {
+ Log.e(TAG, "Failed to download ${fileInfo.filename}", e)
+ withContext(Dispatchers.Main) {
+ errorMessage = "Failed to download ${fileInfo.filename}: ${e.message}"
+ }
+ false
+ }
+ }
+}