diff --git a/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt b/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt index a55b193..06f94f5 100644 --- a/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt +++ b/app/src/main/kotlin/com/google/ai/sample/GenerativeAiViewModelFactory.kt @@ -29,7 +29,10 @@ enum class ModelOption( val size: String? = null, val supportsScreenshot: Boolean = true, val isOfflineModel: Boolean = false, - val offlineModelFilename: String? = null + val offlineModelFilename: String? = null, + val offlineRequiredFilenames: List<String> = emptyList(), + val additionalDownloadUrls: List<String> = emptyList(), + val requiresVisionBackend: Boolean = false ) { PUTER_GLM5("GLM-5V Turbo (Puter)", "openrouter:z-ai/glm-5v-turbo", ApiProvider.PUTER, supportsScreenshot = true), MISTRAL_LARGE_3("Mistral Large 3", "mistral-large-latest", ApiProvider.MISTRAL), @@ -53,15 +56,37 @@ enum class ModelOption( "https://huggingface.co/na5h13/gemma-3n-E4B-it-litert-lm/resolve/main/gemma-3n-E4B-it-int4.litertlm?download=true", "4.92 GB", isOfflineModel = true, - offlineModelFilename = "gemma-3n-e4b-it-int4.litertlm" + offlineModelFilename = "gemma-3n-e4b-it-int4.litertlm", + offlineRequiredFilenames = listOf("gemma-3n-e4b-it-int4.litertlm") ), GEMMA_4_E4B_IT( "Gemma 4 E4B it (offline)", "gemma-4-e4b-it", ApiProvider.GOOGLE, "https://huggingface.co/litert-community/gemma-4-E4B-it-litert-lm/resolve/main/gemma-4-E4B-it.litertlm?download=true", + "3.40 GB", isOfflineModel = true, - offlineModelFilename = "gemma-4-E4B-it.litertlm" + offlineModelFilename = "gemma-4-E4B-it.litertlm", + offlineRequiredFilenames = listOf("gemma-4-E4B-it.litertlm") + ), + QWEN3_5_4B_OFFLINE( + "Qwen3.5 4B (offline)", + "qwen3.5-4b-offline", + ApiProvider.GOOGLE, + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/model_quantized.litertlm?download=true", + "4.27 GB", + isOfflineModel = true, + offlineModelFilename = 
"model_quantized.litertlm", + offlineRequiredFilenames = listOf( + "model_quantized.litertlm", + "sentencepiece.model" + ), + additionalDownloadUrls = listOf( + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/sentencepiece.model?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/tokenizer_config.json?download=true" + ), + requiresVisionBackend = false ), HUMAN_EXPERT("Human Expert", "human-expert", ApiProvider.HUMAN_EXPERT); diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt index 6bc58cd..17cdaa7 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/ModelDownloadManager.kt @@ -58,9 +58,16 @@ object ModelDownloadManager { private var downloadJob: Job? = null private var isPaused = false + private data class DownloadTarget( + val finalFile: File, + val tempFile: File, + val url: String, + val label: String + ) + fun isModelDownloaded(context: Context, model: ModelOption = GenerativeAiViewModelFactory.getCurrentModel()): Boolean { - val file = getModelFile(context, model) - return file != null && file.exists() && file.length() > 0 + val required = getRequiredFiles(context, model) + return required.isNotEmpty() && required.all { it.exists() && it.length() > 0 } } fun getModelFile(context: Context, model: ModelOption = GenerativeAiViewModelFactory.getCurrentModel()): File? { @@ -74,13 +81,26 @@ object ModelDownloadManager { } } - private fun getTempFile(context: Context, model: ModelOption): File? 
{ - val modelFilename = model.offlineModelFilename ?: return null - val externalFilesDir = context.getExternalFilesDir(null) - return if (externalFilesDir != null) { - File(externalFilesDir, modelFilename + TEMP_SUFFIX) + private fun getRequiredFiles(context: Context, model: ModelOption): List<File> { + val externalFilesDir = context.getExternalFilesDir(null) ?: return emptyList() + val requiredNames = if (model.offlineRequiredFilenames.isNotEmpty()) { + model.offlineRequiredFilenames } else { - null + listOfNotNull(model.offlineModelFilename) + } + return requiredNames.map { File(externalFilesDir, it) } + } + + fun getMissingRequiredFiles(context: Context, model: ModelOption): List<String> { + val externalFilesDir = context.getExternalFilesDir(null) ?: return model.offlineRequiredFilenames + val requiredNames = if (model.offlineRequiredFilenames.isNotEmpty()) { + model.offlineRequiredFilenames + } else { + listOfNotNull(model.offlineModelFilename) + } + return requiredNames.filter { name -> + val f = File(externalFilesDir, name) + !f.exists() || f.length() <= 0 } } @@ -147,7 +167,7 @@ object ModelDownloadManager { isPaused = false downloadJob = CoroutineScope(Dispatchers.IO).launch { - downloadWithResume(context, model, url) + downloadModelPackage(context, model, url) } } @@ -164,7 +184,7 @@ object ModelDownloadManager { isPaused = false downloadJob = CoroutineScope(Dispatchers.IO).launch { - downloadWithResume(context, model, url) + downloadModelPackage(context, model, url) } } @@ -174,11 +194,16 @@ object ModelDownloadManager { downloadJob?.cancel() downloadJob = null - // Delete temp file - val tempFile = getTempFile(context, model) - if (tempFile != null && tempFile.exists()) { - tempFile.delete() - Log.d(TAG, "Temp file deleted.") + // Delete temp files for full package + val externalFilesDir = context.getExternalFilesDir(null) + if (externalFilesDir != null) { + val targets = buildDownloadTargets(context, model, model.downloadUrl ?: "") + targets.forEach { target -> + if 
(target.tempFile.exists()) { + target.tempFile.delete() + } + } + Log.d(TAG, "Temporary package files deleted.") } _downloadState.value = DownloadState.Idle @@ -188,21 +213,79 @@ } } - private suspend fun downloadWithResume(context: Context, model: ModelOption, url: String) { - val tempFile = getTempFile(context, model) ?: run { + private suspend fun downloadModelPackage(context: Context, model: ModelOption, primaryUrl: String) { + val targets = buildDownloadTargets(context, model, primaryUrl) + if (targets.isEmpty()) { _downloadState.value = DownloadState.Error("Storage not available.") return } - val finalFile = getModelFile(context, model) ?: run { - _downloadState.value = DownloadState.Error("Storage not available.") - return + + for ((index, target) in targets.withIndex()) { + if (!coroutineContext.isActive) return + Log.i(TAG, "Downloading package file ${index + 1}/${targets.size}: ${target.label}") + val error = downloadSingleFileWithResume(context, target, index, targets.size) + if (error != null) { + _downloadState.value = DownloadState.Error(error) + cancelDownloadNotification(context) + return + } + } + + _downloadState.value = DownloadState.Completed + showDownloadCompleteNotification(context) + withContext(Dispatchers.Main) { + Toast.makeText(context, "Model download complete!", Toast.LENGTH_SHORT).show() + } + } + + private fun buildDownloadTargets(context: Context, model: ModelOption, primaryUrl: String): List<DownloadTarget> { + val externalFilesDir = context.getExternalFilesDir(null) ?: return emptyList() + val primaryFilename = model.offlineModelFilename ?: return emptyList() + val urls = listOf(primaryUrl) + model.additionalDownloadUrls + val filenames = urls.mapIndexedNotNull { idx, url -> + if (idx == 0) primaryFilename else filenameFromUrl(url) + } + if (urls.size != filenames.size) { + Log.e(TAG, "Could not resolve filename for at least one download URL.") + return emptyList() + } + return urls.zip(filenames).map { (url, 
filename) -> + val finalFile = File(externalFilesDir, filename) + DownloadTarget( + finalFile = finalFile, + tempFile = File(externalFilesDir, "$filename$TEMP_SUFFIX"), + url = url, + label = filename + ) + } + } + + private fun filenameFromUrl(url: String): String? { + val clean = url.substringBefore('?') + val slash = clean.lastIndexOf('/') + return if (slash >= 0 && slash + 1 < clean.length) clean.substring(slash + 1) else null + } + + private suspend fun downloadSingleFileWithResume( + context: Context, + target: DownloadTarget, + fileIndex: Int, + fileCount: Int + ): String? { + val tempFile = target.tempFile + val finalFile = target.finalFile + val url = target.url + + if (finalFile.exists() && finalFile.length() > 0L) { + Log.d(TAG, "Skipping already downloaded file: ${target.label}") + return null } var retryCount = 0 var bytesDownloaded = if (tempFile.exists()) tempFile.length() else 0L while (retryCount <= MAX_RETRIES) { - if (!coroutineContext.isActive) return // Coroutine was cancelled + if (!coroutineContext.isActive) return null // Coroutine was cancelled var connection: HttpURLConnection? 
= null try { @@ -240,9 +323,7 @@ object ModelDownloadManager { } } else -> { - _downloadState.value = DownloadState.Error("Server error: $responseCode") - cancelDownloadNotification(context) - return + return "Server error for ${target.label}: $responseCode" } } @@ -264,7 +345,7 @@ object ModelDownloadManager { if (!coroutineContext.isActive) { Log.d(TAG, "Download cancelled during read.") cancelDownloadNotification(context) - return + return null } if (isPaused) { @@ -275,7 +356,7 @@ object ModelDownloadManager { ) // Keep notification showing paused state showDownloadNotification(context, bytesDownloaded.toFloat() / totalBytes, bytesDownloaded, totalBytes) - return + return null } output.write(buffer, 0, bytesRead) @@ -286,13 +367,14 @@ object ModelDownloadManager { if (now - lastProgressUpdate >= PROGRESS_UPDATE_INTERVAL_MS) { lastProgressUpdate = now val progress = if (totalBytes > 0) bytesDownloaded.toFloat() / totalBytes else 0f + val aggregateProgress = (fileIndex + progress) / fileCount.toFloat() _downloadState.value = DownloadState.Downloading( - progress = progress, + progress = aggregateProgress, bytesDownloaded = bytesDownloaded, totalBytes = totalBytes ) // Point 18: Update notification with progress - showDownloadNotification(context, progress, bytesDownloaded, totalBytes) + showDownloadNotification(context, aggregateProgress, bytesDownloaded, totalBytes) } } } @@ -303,30 +385,20 @@ object ModelDownloadManager { finalFile.delete() if (tempFile.renameTo(finalFile)) { Log.i(TAG, "Download complete! File: ${finalFile.absolutePath} (${finalFile.length()} bytes)") - _downloadState.value = DownloadState.Completed - showDownloadCompleteNotification(context) - withContext(Dispatchers.Main) { - Toast.makeText(context, "Model download complete!", Toast.LENGTH_SHORT).show() - } } else { - _downloadState.value = DownloadState.Error("Failed to save model file.") - cancelDownloadNotification(context) + return "Failed to save ${target.label}." 
} } - return // Success, exit retry loop + return null // Success, exit retry loop } catch (e: IOException) { Log.e(TAG, "Download error (attempt ${retryCount + 1}): ${e.message}") retryCount++ if (retryCount > MAX_RETRIES) { - _downloadState.value = DownloadState.Error("Download failed after $MAX_RETRIES retries: ${e.message}") - cancelDownloadNotification(context) - withContext(Dispatchers.Main) { - Toast.makeText(context, "Download failed: ${e.message}", Toast.LENGTH_LONG).show() - } + return "Download failed for ${target.label} after $MAX_RETRIES retries: ${e.message}" } else { _downloadState.value = DownloadState.Downloading( - progress = if (bytesDownloaded > 0) 0f else 0f, + progress = fileIndex.toFloat() / fileCount.toFloat(), bytesDownloaded = bytesDownloaded, totalBytes = -1 ) @@ -337,6 +409,8 @@ object ModelDownloadManager { connection?.disconnect() } } + + return "Download failed for ${target.label}." } /** diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt index a059839..9348426 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt @@ -327,25 +327,34 @@ class PhotoReasoningViewModel( private fun initializeOfflineModel(context: Context): String? { try { val currentModel = com.google.ai.sample.GenerativeAiViewModelFactory.getCurrentModel() + val missingFiles = ModelDownloadManager.getMissingRequiredFiles(context, currentModel) + if (missingFiles.isNotEmpty()) { + return "Offline model files missing: ${missingFiles.joinToString(", ")}. Please redownload the model package." 
+ } val modelFile = ModelDownloadManager.getModelFile(context, currentModel) if (modelFile != null && modelFile.exists()) { // Load backend preference GenerativeAiViewModelFactory.loadBackendPreference(context) val backend = GenerativeAiViewModelFactory.getBackend() + val isLiteRtModel = currentModel.offlineModelFilename?.endsWith(".litertlm", ignoreCase = true) == true - if (currentModel == ModelOption.GEMMA_4_E4B_IT) { + if (isLiteRtModel) { if (!isLiteRtAbiSupported()) { - return "Gemma 4 offline is only supported on arm64-v8a or x86_64 devices." + return "Offline LiteRT models are only supported on arm64-v8a or x86_64 devices." } Log.i( TAG, - "Initializing Gemma 4 LiteRT engine. preferredBackend=$backend, " + + "Initializing LiteRT engine for ${currentModel.displayName}. preferredBackend=$backend, " + "abis=${Build.SUPPORTED_ABIS?.joinToString() ?: "unknown"}, " + "modelPath=${modelFile.absolutePath}, modelSizeBytes=${modelFile.length()}" ) if (liteRtEngine == null) { val preferredBackend = if (backend == InferenceBackend.GPU) Backend.GPU() else Backend.CPU() - val preferredVisionBackend = if (currentModel.supportsScreenshot) Backend.GPU() else null + val preferredVisionBackend = if (currentModel.requiresVisionBackend) { + if (backend == InferenceBackend.GPU) Backend.GPU() else Backend.CPU() + } else { + null + } val audioBackend = null val cacheDir = if (modelFile.absolutePath.startsWith("/data/local/tmp")) { @@ -401,6 +410,11 @@ class PhotoReasoningViewModel( ) { return "LiteRT native runtime is not available on this device/ABI. Use an arm64-v8a or x86_64 build." } + if (msg.contains("litert_compiled_model", ignoreCase = true) || + msg.contains("litert_tensor_buffer", ignoreCase = true) + ) { + return "Offline model could not be initialized: LiteRT cannot compile this model package on this device. This usually means the file set is incompatible with the selected artifact/backend." 
+ } return if (msg.contains("memory", ignoreCase = true) || msg.contains("RAM", ignoreCase = true) || msg.contains("OOM", ignoreCase = true) || msg.contains("alloc", ignoreCase = true) || msg.contains("out of", ignoreCase = true)) { "Not enough RAM to load the model on GPU. Try switching to CPU." } else { @@ -835,7 +849,8 @@ class PhotoReasoningViewModel( // Initialize model if needed var initError: String? = null val selectedOfflineModel = GenerativeAiViewModelFactory.getCurrentModel() - if (selectedOfflineModel == ModelOption.GEMMA_4_E4B_IT) { + val useLiteRt = selectedOfflineModel.offlineModelFilename?.endsWith(".litertlm", ignoreCase = true) == true + if (useLiteRt) { if (liteRtEngine == null) { withContext(Dispatchers.Main) { replaceAiMessageText("Initializing offline model...", isPending = true) @@ -860,7 +875,7 @@ class PhotoReasoningViewModel( _isInitializingOfflineModelFlow.value = false } - if (selectedOfflineModel == ModelOption.GEMMA_4_E4B_IT && liteRtEngine == null) { + if (useLiteRt && liteRtEngine == null) { val errorMsg = initError ?: "Offline model could not be initialized." withContext(Dispatchers.Main) { _uiState.value = PhotoReasoningUiState.Error(errorMsg) @@ -875,7 +890,7 @@ class PhotoReasoningViewModel( refreshStopButtonState() } return@launch - } else if (selectedOfflineModel != ModelOption.GEMMA_4_E4B_IT && llmInference == null) { + } else if (!useLiteRt && llmInference == null) { val errorMsg = initError ?: "Offline model could not be initialized." withContext(Dispatchers.Main) { _uiState.value = PhotoReasoningUiState.Error(errorMsg) @@ -896,7 +911,7 @@ class PhotoReasoningViewModel( Log.d(TAG, "Sending streaming prompt to offline model (length: ${fullPrompt.length})") - val finalResponse = if (selectedOfflineModel == ModelOption.GEMMA_4_E4B_IT) { + val finalResponse = if (useLiteRt) { val engine = liteRtEngine if (engine == null) { withContext(Dispatchers.Main) {