Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ import java.io.File
* It does not perform any actual file operations.
*/
class TestFileProvider : LocalFileProvider {
override fun saveBitmapToFile(
override suspend fun saveBitmapToFile(
bitmap: Bitmap,
file: File,
): File {
) {
TODO("Not yet implemented")
}

override fun getFileFromCache(fileName: String): File {
override suspend fun getFileFromCache(fileName: String): File {
TODO("Not yet implemented")
}

override fun createCacheFile(fileName: String): File {
override suspend fun createCacheFile(fileName: String): File {
TODO("Not yet implemented")
}

override fun saveToSharedStorage(
override suspend fun saveToSharedStorage(
file: File,
fileName: String,
mimeType: String,
Expand All @@ -53,11 +53,11 @@ class TestFileProvider : LocalFileProvider {
TODO("Not yet implemented")
}

override fun copyToInternalStorage(uri: Uri): File {
override suspend fun copyToInternalStorage(uri: Uri): File {
return File("")
}

override fun saveUriToSharedStorage(
override suspend fun saveUriToSharedStorage(
inputUri: Uri,
fileName: String,
mimeType: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,109 +22,114 @@ import android.net.Uri
import android.os.Build
import android.os.Environment
import android.provider.MediaStore
import androidx.annotation.WorkerThread
import androidx.core.content.FileProvider
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.IOException
import java.nio.file.Files
import java.util.UUID
import javax.inject.Inject
import javax.inject.Named
import javax.inject.Singleton

interface LocalFileProvider {
fun saveBitmapToFile(bitmap: Bitmap, file: File): File
fun getFileFromCache(fileName: String): File
fun createCacheFile(fileName: String): File
fun saveToSharedStorage(file: File, fileName: String, mimeType: String): Uri
@WorkerThread
suspend fun saveBitmapToFile(bitmap: Bitmap, file: File)
@WorkerThread
suspend fun getFileFromCache(fileName: String): File
@WorkerThread
suspend fun createCacheFile(fileName: String): File
@WorkerThread
suspend fun saveToSharedStorage(file: File, fileName: String, mimeType: String): Uri
fun sharingUriForFile(file: File): Uri
fun copyToInternalStorage(uri: Uri): File
fun saveUriToSharedStorage(
inputUri: Uri,
fileName: String,
mimeType: String,
): Uri
@WorkerThread
suspend fun copyToInternalStorage(uri: Uri): File
@WorkerThread
suspend fun saveUriToSharedStorage(inputUri: Uri, fileName: String, mimeType: String): Uri
}

@Singleton
open class LocalFileProviderImpl @Inject constructor(val application: Context) : LocalFileProvider {
open class LocalFileProviderImpl @Inject constructor(
val application: Context,
@Named("IO")
val ioDispatcher: CoroutineDispatcher,
) : LocalFileProvider {

override fun saveBitmapToFile(bitmap: Bitmap, file: File): File {
override suspend fun saveBitmapToFile(bitmap: Bitmap, file: File) = withContext(ioDispatcher) {
var outputStream: FileOutputStream? = null
try {
outputStream = FileOutputStream(file)
bitmap.compress(Bitmap.CompressFormat.JPEG, 100, outputStream)
outputStream.flush()
return file
} catch (e: IOException) {
throw e
} finally {
outputStream?.close()
}
}
override fun getFileFromCache(fileName: String): File {
return File(application.cacheDir, fileName)

override suspend fun getFileFromCache(fileName: String): File = withContext(ioDispatcher) {
File(application.cacheDir, fileName)
}

@Throws(IOException::class)
override fun createCacheFile(fileName: String): File {
override suspend fun createCacheFile(fileName: String): File = withContext(ioDispatcher) {
val cacheDir = application.cacheDir
val imageFile = File(cacheDir, fileName)
if (!imageFile.createNewFile()) {
throw IOException("Unable to create file: ${imageFile.absolutePath}")
}
return imageFile
return@withContext imageFile
}

override fun saveToSharedStorage(
override suspend fun saveToSharedStorage(
file: File,
fileName: String,
mimeType: String,
): Uri {
): Uri = withContext(ioDispatcher) {
val (uri, contentValues) = createSharedStorageEntry(fileName, mimeType)
saveFileToUri(file, uri)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
contentValues.put(MediaStore.Images.ImageColumns.IS_PENDING, 0)
}
application.contentResolver.update(uri, contentValues, null, null)
return uri
return@withContext uri
}

override fun saveUriToSharedStorage(
override suspend fun saveUriToSharedStorage(
inputUri: Uri,
fileName: String,
mimeType: String,
): Uri {
): Uri = withContext(ioDispatcher) {
val (newUri, contentValues) = createSharedStorageEntry(fileName, mimeType)
application.contentResolver.openOutputStream(newUri)?.use { outputStream ->
application.contentResolver.openInputStream(inputUri)?.use { inputStream ->
val buffer = ByteArray(4 * 1024) // 4 KB buffer size - adjust as needed
var bytesRead: Int
while (inputStream.read(buffer).also { bytesRead = it } != -1) {
outputStream.write(buffer, 0, bytesRead)
}
inputStream.copyTo(outputStream)
}
} ?: throw IOException("Failed to open output stream.")
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
contentValues.put(MediaStore.Images.ImageColumns.IS_PENDING, 0)
}
application.contentResolver.update(newUri, contentValues, null, null)
return newUri
return@withContext newUri
}

@Throws(IOException::class)
@WorkerThread
private fun saveFileToUri(file: File, uri: Uri) {
application.contentResolver.openOutputStream(uri)?.use { outputStream ->
FileInputStream(file).use { inputStream ->
val buffer = ByteArray(4 * 1024) // 4 KB buffer size - adjust as needed
var bytesRead: Int
while (inputStream.read(buffer).also { bytesRead = it } != -1) {
outputStream.write(buffer, 0, bytesRead)
}
file.inputStream().use { inputStream ->
inputStream.copyTo(outputStream)
}
} ?: throw IOException("Failed to open output stream for uri: $uri")
}

@Throws(IOException::class)
@WorkerThread
private fun createSharedStorageEntry(fileName: String, mimeType: String): Pair<Uri, ContentValues> {
val resolver = application.contentResolver
val contentValues = ContentValues().apply {
Expand Down Expand Up @@ -160,14 +165,14 @@ open class LocalFileProviderImpl @Inject constructor(val application: Context) :
}

@Throws(IOException::class)
override fun copyToInternalStorage(uri: Uri): File {
override suspend fun copyToInternalStorage(uri: Uri): File = withContext(ioDispatcher) {
val uuid = UUID.randomUUID()
val file = File(application.cacheDir, "temp_file_$uuid")
application.contentResolver.openInputStream(uri)?.use { inputStream ->
file.outputStream().use { outputStream ->
inputStream.copyTo(outputStream)
}
}
return file
return@withContext file
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ internal object DataModule {

@Provides
@Singleton
fun provideLocalFileProvider(@ApplicationContext appContext: Context): LocalFileProvider =
LocalFileProviderImpl(appContext)
fun provideLocalFileProvider(@ApplicationContext appContext: Context, @Named("IO") ioDispatcher: CoroutineDispatcher): LocalFileProvider =
LocalFileProviderImpl(appContext, ioDispatcher)

@Provides
@Singleton
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ internal class ImageGenerationRepositoryImpl @Inject constructor(

override suspend fun saveImage(imageBitmap: Bitmap): Uri {
val cacheFile = localFileProvider.createCacheFile("shared_image_${UUID.randomUUID()}.jpg")
val file = localFileProvider.saveBitmapToFile(imageBitmap, cacheFile)
return localFileProvider.sharingUriForFile(file)
localFileProvider.saveBitmapToFile(imageBitmap, cacheFile)
return localFileProvider.sharingUriForFile(cacheFile)
}

override suspend fun saveImageToExternalStorage(imageBitmap: Bitmap): Uri {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,19 @@ import com.android.developers.androidify.data.TextGenerationRepository
import com.android.developers.androidify.util.LocalFileProvider
import dagger.hilt.android.lifecycle.HiltViewModel
import dagger.hilt.android.qualifiers.ApplicationContext
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.cancel
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import javax.inject.Inject
import javax.inject.Named

@HiltViewModel
class CreationViewModel @Inject constructor(
val internetConnectivityManager: InternetConnectivityManager,
val imageGenerationRepository: ImageGenerationRepository,
val textGenerationRepository: TextGenerationRepository,
val fileProvider: LocalFileProvider,
@Named("IO")
val ioDispatcher: CoroutineDispatcher = Dispatchers.IO,
@ApplicationContext
val context: Context,
) : ViewModel() {
Expand All @@ -74,6 +69,9 @@ class CreationViewModel @Inject constructor(
val snackbarHostState: StateFlow<SnackbarHostState>
get() = _snackbarHostState

private var promptGenerationJob: Job? = null
private var imageGenerationJob: Job? = null

fun onImageSelected(uri: Uri?) {
_uiState.update {
it.copy(
Expand All @@ -96,7 +94,8 @@ class CreationViewModel @Inject constructor(
}

fun onPromptGenerationClicked() {
viewModelScope.launch(ioDispatcher) {
promptGenerationJob?.cancel()
promptGenerationJob = viewModelScope.launch {
Log.d("CreationViewModel", "Generating prompt...")
_uiState.update {
it.copy(promptGenerationInProgress = true)
Expand All @@ -122,7 +121,8 @@ class CreationViewModel @Inject constructor(
}

fun startClicked() {
viewModelScope.launch(ioDispatcher) {
imageGenerationJob?.cancel()
imageGenerationJob = viewModelScope.launch {
if (internetConnectivityManager.isInternetAvailable()) {
try {
_uiState.update {
Expand Down Expand Up @@ -196,7 +196,8 @@ class CreationViewModel @Inject constructor(
}

fun cancelInProgressTask() {
ioDispatcher.cancel()
promptGenerationJob?.cancel()
imageGenerationJob?.cancel()
_uiState.update {
it.copy(screenState = ScreenState.EDIT)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ class CreationViewModelTest {
imageGenerationRepository,
TestTextGenerationRepository(),
TestFileProvider(),
UnconfinedTestDispatcher(),
context = RuntimeEnvironment.getApplication(),
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,16 @@ import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.android.developers.androidify.data.ImageGenerationRepository
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import javax.inject.Inject
import javax.inject.Named

@HiltViewModel
class ResultsViewModel @Inject constructor(
val imageGenerationRepository: ImageGenerationRepository,
@Named("IO")
val ioDispatcher: CoroutineDispatcher = Dispatchers.IO,
) : ViewModel() {

private val _state = MutableStateFlow(ResultState())
Expand All @@ -58,7 +53,7 @@ class ResultsViewModel @Inject constructor(
}

fun shareClicked() {
viewModelScope.launch(ioDispatcher) {
viewModelScope.launch {
val resultUrl = state.value.resultImageBitmap
if (resultUrl != null) {
val imageFileUri = imageGenerationRepository.saveImage(resultUrl)
Expand All @@ -70,7 +65,7 @@ class ResultsViewModel @Inject constructor(
}
}
fun downloadClicked() {
viewModelScope.launch(ioDispatcher) {
viewModelScope.launch {
val resultBitmap = state.value.resultImageBitmap
val originalImage = state.value.originalImageUrl
if (originalImage != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class ResultsViewModelTest {
fun setup() {
viewModel = ResultsViewModel(
FakeImageGenerationRepository(),
UnconfinedTestDispatcher(),
)
}

Expand Down