diff --git a/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt b/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt index fb03050..17aefb6 100644 --- a/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt +++ b/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt @@ -19,8 +19,13 @@ import com.google.mlkit.vision.segmentation.Segmenter import com.google.mlkit.vision.segmentation.selfie.SelfieSegmenterOptions import java.io.File import java.io.FileOutputStream +import java.io.InputStream +import java.net.HttpURLConnection import java.net.URI +import java.net.URL import kotlin.math.pow +import android.webkit.MimeTypeMap +import kotlinx.coroutines.* class BackgroundRemoverModule internal constructor(context: ReactApplicationContext) : BackgroundRemoverSpec(context) { @@ -33,34 +38,42 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont @ReactMethod override fun removeBackground(imageURI: String, promise: Promise) { val segmenter = this.segmenter ?: createSegmenter() - val image = getImageBitmap(imageURI) - - val inputImage = InputImage.fromBitmap(image, 0) - - segmenter.process(inputImage).addOnFailureListener { e -> - promise.reject(e) - }.addOnSuccessListener { result -> - val maskBuffer = result.buffer - val mask = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888) - - for (y in 0 until result.height) { - for (x in 0 until result.width) { - val alpha = maskBuffer.getFloat().pow(4) - mask.setPixel(x, y, Color.argb((alpha * 255).toInt(), 0, 0, 0)) + try { + val image = getImageBitmap(imageURI).copy(Bitmap.Config.ARGB_8888, true) + val inputImage = InputImage.fromBitmap(image, 0) + + segmenter.process(inputImage).addOnFailureListener { e -> + promise.reject(e) + }.addOnSuccessListener { result -> + val maskBuffer = result.buffer + val mask = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888) + + for (y in 0 until result.height) { + for (x in 0 until result.width) { + val alpha = maskBuffer.getFloat().pow(4) + val color = if (alpha > 0.1) Color.argb((alpha * 255).toInt(), 0, 0, 0) else Color.TRANSPARENT + mask.setPixel(x, y, color) + } } - } - val paint = Paint(Paint.ANTI_ALIAS_FLAG) - paint.setXfermode(PorterDuffXfermode(PorterDuff.Mode.DST_IN)) - val canvas = Canvas(image) - canvas.drawBitmap(mask, 0f, 0f, paint) - - val fileName = URI(imageURI).path.split("/").last() - val savedImageURI = saveImage(image, fileName) - promise.resolve(savedImageURI) + val mutableImage = Bitmap.createBitmap(image.width, image.height, Bitmap.Config.ARGB_8888) + val canvas = Canvas(mutableImage) + canvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR) + canvas.drawBitmap(image, 0f, 0f, null) + val paint = Paint(Paint.ANTI_ALIAS_FLAG) + paint.xfermode = PorterDuffXfermode(PorterDuff.Mode.DST_IN) + canvas.drawBitmap(mask, 0f, 0f, paint) + + val fileName = URI(imageURI).path.split("/").last() + val savedImageURI = saveImage(mutableImage, fileName) + promise.resolve(savedImageURI) + } + } catch (e: Exception) { + promise.reject(e) } } + private fun createSegmenter(): Segmenter { val options = SelfieSegmenterOptions.Builder() @@ -76,7 +89,16 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont private fun getImageBitmap(imageURI: String): Bitmap { val uri = Uri.parse(imageURI) - val bitmap = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { + return if (uri.scheme == "http" || uri.scheme == "https") { + val localFile = downloadImage(uri.toString()) + decodeBitmapFromUri(Uri.fromFile(localFile)) + } else { + decodeBitmapFromUri(uri) + } + } + + private fun decodeBitmapFromUri(uri: Uri): Bitmap { + return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { ImageDecoder.decodeBitmap( ImageDecoder.createSource( reactApplicationContext.contentResolver, @@ -86,16 +108,62 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont } else { MediaStore.Images.Media.getBitmap(reactApplicationContext.contentResolver, uri) } + } + + private fun downloadImage(imageUrl: String): File { + return runBlocking { + withContext(Dispatchers.IO) { + val url = URL(imageUrl) + val connection = url.openConnection() as HttpURLConnection + connection.connect() - return bitmap + if (connection.responseCode != HttpURLConnection.HTTP_OK) { + throw Exception("Failed to download image: ${connection.responseMessage}") + } + + var extension = MimeTypeMap.getFileExtensionFromUrl(imageUrl) + if (extension.isNullOrEmpty()) { + val contentType = connection.contentType + extension = MimeTypeMap.getSingleton().getExtensionFromMimeType(contentType) ?: "png" + } + + val inputStream: InputStream = connection.inputStream + val tempFile = File.createTempFile("downloaded_image", ".$extension", reactApplicationContext.cacheDir) + tempFile.outputStream().use { outputStream -> + inputStream.copyTo(outputStream) + } + inputStream.close() + connection.disconnect() + + tempFile + } + } } private fun saveImage(bitmap: Bitmap, fileName: String): String { - val file = File(reactApplicationContext.filesDir, fileName) + val updatedFileName = if (fileName.endsWith(".jpg", ignoreCase = true)) { + fileName.replace(".jpg", ".png", true) + } else { + fileName + } + val file = File(reactApplicationContext.filesDir, updatedFileName) + + val safeBitmap = if (!bitmap.hasAlpha()) { + val newBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888) + val canvas = Canvas(newBitmap) + canvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR) + canvas.drawBitmap(bitmap, 0f, 0f, null) + newBitmap + } else { + bitmap + } + val fileOutputStream = FileOutputStream(file) - bitmap.compress(Bitmap.CompressFormat.JPEG, 100, fileOutputStream) + safeBitmap.compress(Bitmap.CompressFormat.PNG, 100, fileOutputStream) + fileOutputStream.flush() fileOutputStream.close() - return file.toURI().toString() + + return "file://${file.absolutePath}" } companion object {