Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ enum class ModelOption(
val size: String? = null,
val supportsScreenshot: Boolean = true,
val isOfflineModel: Boolean = false,
val offlineModelFilename: String? = null
val offlineModelFilename: String? = null,
val offlineRequiredFilenames: List<String> = emptyList(),
val additionalDownloadUrls: List<String> = emptyList(),
val requiresVisionBackend: Boolean = false
) {
PUTER_GLM5("GLM-5V Turbo (Puter)", "openrouter:z-ai/glm-5v-turbo", ApiProvider.PUTER, supportsScreenshot = true),
MISTRAL_LARGE_3("Mistral Large 3", "mistral-large-latest", ApiProvider.MISTRAL),
Expand All @@ -53,15 +56,37 @@ enum class ModelOption(
"https://huggingface.co/na5h13/gemma-3n-E4B-it-litert-lm/resolve/main/gemma-3n-E4B-it-int4.litertlm?download=true",
"4.92 GB",
isOfflineModel = true,
offlineModelFilename = "gemma-3n-e4b-it-int4.litertlm"
offlineModelFilename = "gemma-3n-e4b-it-int4.litertlm",
offlineRequiredFilenames = listOf("gemma-3n-e4b-it-int4.litertlm")
),
GEMMA_4_E4B_IT(
"Gemma 4 E4B it (offline)",
"gemma-4-e4b-it",
ApiProvider.GOOGLE,
"https://huggingface.co/litert-community/gemma-4-E4B-it-litert-lm/resolve/main/gemma-4-E4B-it.litertlm?download=true",
"3.40 GB",
isOfflineModel = true,
offlineModelFilename = "gemma-4-E4B-it.litertlm"
offlineModelFilename = "gemma-4-E4B-it.litertlm",
offlineRequiredFilenames = listOf("gemma-4-E4B-it.litertlm")
),
QWEN3_5_4B_OFFLINE(
"Qwen3.5 4B (offline)",
"qwen3.5-4b-offline",
ApiProvider.GOOGLE,
"https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/model_quantized.litertlm?download=true",
"4.27 GB",
isOfflineModel = true,
offlineModelFilename = "model_quantized.litertlm",
offlineRequiredFilenames = listOf(
"model_quantized.litertlm",
"sentencepiece.model"
),
additionalDownloadUrls = listOf(
"https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/sentencepiece.model?download=true",
"https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/tokenizer.json?download=true",
"https://huggingface.co/Yoursmiling/Qwen3.5-4B-LiteRT/resolve/main/tokenizer_config.json?download=true"
),
requiresVisionBackend = false
),
HUMAN_EXPERT("Human Expert", "human-expert", ApiProvider.HUMAN_EXPERT);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,16 @@ object ModelDownloadManager {
private var downloadJob: Job? = null
private var isPaused = false

private data class DownloadTarget(
val finalFile: File,
val tempFile: File,
val url: String,
val label: String
)

fun isModelDownloaded(context: Context, model: ModelOption = GenerativeAiViewModelFactory.getCurrentModel()): Boolean {
val file = getModelFile(context, model)
return file != null && file.exists() && file.length() > 0
val required = getRequiredFiles(context, model)
return required.isNotEmpty() && required.all { it.exists() && it.length() > 0 }
}

fun getModelFile(context: Context, model: ModelOption = GenerativeAiViewModelFactory.getCurrentModel()): File? {
Expand All @@ -74,13 +81,26 @@ object ModelDownloadManager {
}
}

private fun getTempFile(context: Context, model: ModelOption): File? {
val modelFilename = model.offlineModelFilename ?: return null
val externalFilesDir = context.getExternalFilesDir(null)
return if (externalFilesDir != null) {
File(externalFilesDir, modelFilename + TEMP_SUFFIX)
private fun getRequiredFiles(context: Context, model: ModelOption): List<File> {
val externalFilesDir = context.getExternalFilesDir(null) ?: return emptyList()
val requiredNames = if (model.offlineRequiredFilenames.isNotEmpty()) {
model.offlineRequiredFilenames
} else {
null
listOfNotNull(model.offlineModelFilename)
}
return requiredNames.map { File(externalFilesDir, it) }
}

fun getMissingRequiredFiles(context: Context, model: ModelOption): List<String> {
val externalFilesDir = context.getExternalFilesDir(null) ?: return model.offlineRequiredFilenames
val requiredNames = if (model.offlineRequiredFilenames.isNotEmpty()) {
model.offlineRequiredFilenames
} else {
listOfNotNull(model.offlineModelFilename)
}
return requiredNames.filter { name ->
val f = File(externalFilesDir, name)
!f.exists() || f.length() <= 0
}
}

Expand Down Expand Up @@ -147,7 +167,7 @@ object ModelDownloadManager {

isPaused = false
downloadJob = CoroutineScope(Dispatchers.IO).launch {
downloadWithResume(context, model, url)
downloadModelPackage(context, model, url)
}
}

Expand All @@ -164,7 +184,7 @@ object ModelDownloadManager {

isPaused = false
downloadJob = CoroutineScope(Dispatchers.IO).launch {
downloadWithResume(context, model, url)
downloadModelPackage(context, model, url)
}
}

Expand All @@ -174,11 +194,16 @@ object ModelDownloadManager {
downloadJob?.cancel()
downloadJob = null

// Delete temp file
val tempFile = getTempFile(context, model)
if (tempFile != null && tempFile.exists()) {
tempFile.delete()
Log.d(TAG, "Temp file deleted.")
// Delete temp files for full package
val externalFilesDir = context.getExternalFilesDir(null)
if (externalFilesDir != null) {
val targets = buildDownloadTargets(context, model, model.downloadUrl ?: "")
targets.forEach { target ->
if (target.tempFile.exists()) {
target.tempFile.delete()
}
}
Log.d(TAG, "Temporary package files deleted.")
}

_downloadState.value = DownloadState.Idle
Expand All @@ -188,21 +213,79 @@ object ModelDownloadManager {
}
}

private suspend fun downloadWithResume(context: Context, model: ModelOption, url: String) {
val tempFile = getTempFile(context, model) ?: run {
private suspend fun downloadModelPackage(context: Context, model: ModelOption, primaryUrl: String) {
val targets = buildDownloadTargets(context, model, primaryUrl)
if (targets.isEmpty()) {
_downloadState.value = DownloadState.Error("Storage not available.")
return
}
val finalFile = getModelFile(context, model) ?: run {
_downloadState.value = DownloadState.Error("Storage not available.")
return

for ((index, target) in targets.withIndex()) {
if (!coroutineContext.isActive) return
Log.i(TAG, "Downloading package file ${index + 1}/${targets.size}: ${target.label}")
val error = downloadSingleFileWithResume(context, target, index, targets.size)
if (error != null) {
_downloadState.value = DownloadState.Error(error)
cancelDownloadNotification(context)
return
}
}

_downloadState.value = DownloadState.Completed
showDownloadCompleteNotification(context)
withContext(Dispatchers.Main) {
Toast.makeText(context, "Model download complete!", Toast.LENGTH_SHORT).show()
}
}

private fun buildDownloadTargets(context: Context, model: ModelOption, primaryUrl: String): List<DownloadTarget> {
val externalFilesDir = context.getExternalFilesDir(null) ?: return emptyList()
val primaryFilename = model.offlineModelFilename ?: return emptyList()
val urls = listOf(primaryUrl) + model.additionalDownloadUrls
val filenames = urls.mapIndexedNotNull { idx, url ->
if (idx == 0) primaryFilename else filenameFromUrl(url)
}
if (urls.size != filenames.size) {
Log.e(TAG, "Could not resolve filename for at least one download URL.")
return emptyList()
}
return urls.zip(filenames).map { (url, filename) ->
val finalFile = File(externalFilesDir, filename)
DownloadTarget(
finalFile = finalFile,
tempFile = File(externalFilesDir, "$filename$TEMP_SUFFIX"),
url = url,
label = filename
)
}
}

private fun filenameFromUrl(url: String): String? {
val clean = url.substringBefore('?')
val slash = clean.lastIndexOf('/')
return if (slash >= 0 && slash + 1 < clean.length) clean.substring(slash + 1) else null
}

private suspend fun downloadSingleFileWithResume(
context: Context,
target: DownloadTarget,
fileIndex: Int,
fileCount: Int
): String? {
val tempFile = target.tempFile
val finalFile = target.finalFile
val url = target.url

if (finalFile.exists() && finalFile.length() > 0L) {
Log.d(TAG, "Skipping already downloaded file: ${target.label}")
return null
}

var retryCount = 0
var bytesDownloaded = if (tempFile.exists()) tempFile.length() else 0L

while (retryCount <= MAX_RETRIES) {
if (!coroutineContext.isActive) return // Coroutine was cancelled
if (!coroutineContext.isActive) return null // Coroutine was cancelled

var connection: HttpURLConnection? = null
try {
Expand Down Expand Up @@ -240,9 +323,7 @@ object ModelDownloadManager {
}
}
else -> {
_downloadState.value = DownloadState.Error("Server error: $responseCode")
cancelDownloadNotification(context)
return
return "Server error for ${target.label}: $responseCode"
}
}

Expand All @@ -264,7 +345,7 @@ object ModelDownloadManager {
if (!coroutineContext.isActive) {
Log.d(TAG, "Download cancelled during read.")
cancelDownloadNotification(context)
return
return null
}

if (isPaused) {
Expand All @@ -275,7 +356,7 @@ object ModelDownloadManager {
)
// Keep notification showing paused state
showDownloadNotification(context, bytesDownloaded.toFloat() / totalBytes, bytesDownloaded, totalBytes)
return
return null
}

output.write(buffer, 0, bytesRead)
Expand All @@ -286,13 +367,14 @@ object ModelDownloadManager {
if (now - lastProgressUpdate >= PROGRESS_UPDATE_INTERVAL_MS) {
lastProgressUpdate = now
val progress = if (totalBytes > 0) bytesDownloaded.toFloat() / totalBytes else 0f
val aggregateProgress = (fileIndex + progress) / fileCount.toFloat()
_downloadState.value = DownloadState.Downloading(
progress = progress,
progress = aggregateProgress,
bytesDownloaded = bytesDownloaded,
totalBytes = totalBytes
)
// Point 18: Update notification with progress
showDownloadNotification(context, progress, bytesDownloaded, totalBytes)
showDownloadNotification(context, aggregateProgress, bytesDownloaded, totalBytes)
}
}
}
Expand All @@ -303,30 +385,20 @@ object ModelDownloadManager {
finalFile.delete()
if (tempFile.renameTo(finalFile)) {
Log.i(TAG, "Download complete! File: ${finalFile.absolutePath} (${finalFile.length()} bytes)")
_downloadState.value = DownloadState.Completed
showDownloadCompleteNotification(context)
withContext(Dispatchers.Main) {
Toast.makeText(context, "Model download complete!", Toast.LENGTH_SHORT).show()
}
} else {
_downloadState.value = DownloadState.Error("Failed to save model file.")
cancelDownloadNotification(context)
return "Failed to save ${target.label}."
}
}
return // Success, exit retry loop
return null // Success, exit retry loop

} catch (e: IOException) {
Log.e(TAG, "Download error (attempt ${retryCount + 1}): ${e.message}")
retryCount++
if (retryCount > MAX_RETRIES) {
_downloadState.value = DownloadState.Error("Download failed after $MAX_RETRIES retries: ${e.message}")
cancelDownloadNotification(context)
withContext(Dispatchers.Main) {
Toast.makeText(context, "Download failed: ${e.message}", Toast.LENGTH_LONG).show()
}
return "Download failed for ${target.label} after $MAX_RETRIES retries: ${e.message}"
} else {
_downloadState.value = DownloadState.Downloading(
progress = if (bytesDownloaded > 0) 0f else 0f,
progress = fileIndex.toFloat() / fileCount.toFloat(),
bytesDownloaded = bytesDownloaded,
totalBytes = -1
)
Expand All @@ -337,6 +409,8 @@ object ModelDownloadManager {
connection?.disconnect()
}
}

return "Download failed for ${target.label}."
}

/**
Expand Down
Loading
Loading