diff --git a/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt b/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt index fb03050..7383bb6 100644 --- a/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt +++ b/packages/react-native-background-remover/android/src/main/java/com/backgroundremover/BackgroundRemoverModule.kt @@ -19,7 +19,10 @@ import com.google.mlkit.vision.segmentation.Segmenter import com.google.mlkit.vision.segmentation.selfie.SelfieSegmenterOptions import java.io.File import java.io.FileOutputStream +import java.io.InputStream +import java.net.HttpURLConnection import java.net.URI +import java.net.URL import kotlin.math.pow class BackgroundRemoverModule internal constructor(context: ReactApplicationContext) : @@ -33,31 +36,47 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont @ReactMethod override fun removeBackground(imageURI: String, promise: Promise) { val segmenter = this.segmenter ?: createSegmenter() - val image = getImageBitmap(imageURI) - - val inputImage = InputImage.fromBitmap(image, 0) - - segmenter.process(inputImage).addOnFailureListener { e -> - promise.reject(e) - }.addOnSuccessListener { result -> - val maskBuffer = result.buffer - val mask = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888) - - for (y in 0 until result.height) { - for (x in 0 until result.width) { - val alpha = maskBuffer.getFloat().pow(4) - mask.setPixel(x, y, Color.argb((alpha * 255).toInt(), 0, 0, 0)) + try { + val image = getImageBitmap(imageURI) + val inputImage = InputImage.fromBitmap(image, 0) + + segmenter.process(inputImage).addOnFailureListener { e -> + promise.reject(e) + }.addOnSuccessListener { result -> + val maskBuffer = result.buffer + maskBuffer.rewind() // Reset buffer position + + // Create a new bitmap with transparent background + val resultBitmap = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888) + + // Get pixel arrays for processing + val imagePixels = IntArray(result.width * result.height) + image.getPixels(imagePixels, 0, result.width, 0, 0, result.width, result.height) + + val resultPixels = IntArray(result.width * result.height) + + // Process each pixel + for (i in 0 until result.width * result.height) { + val confidence = maskBuffer.getFloat() + + if (confidence > 0.5f) { // Threshold for foreground detection + // Keep the original pixel with full opacity + resultPixels[i] = imagePixels[i] + } else { + // Make pixel completely transparent + resultPixels[i] = Color.TRANSPARENT + } } - } - - val paint = Paint(Paint.ANTI_ALIAS_FLAG) - paint.setXfermode(PorterDuffXfermode(PorterDuff.Mode.DST_IN)) - val canvas = Canvas(image) - canvas.drawBitmap(mask, 0f, 0f, paint) + + // Set the processed pixels to result bitmap + resultBitmap.setPixels(resultPixels, 0, result.width, 0, 0, result.width, result.height) - val fileName = URI(imageURI).path.split("/").last() - val savedImageURI = saveImage(image, fileName) - promise.resolve(savedImageURI) + val fileName = URI(imageURI).path.split("/").last() + val savedImageURI = saveImage(resultBitmap, fileName) + promise.resolve(savedImageURI) + } + } catch (e: Exception) { + promise.reject(e) } } @@ -76,7 +95,16 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont private fun getImageBitmap(imageURI: String): Bitmap { val uri = Uri.parse(imageURI) - val bitmap = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { + return if (uri.scheme == "http" || uri.scheme == "https") { + val localFile = downloadImage(uri.toString()) + decodeBitmapFromUri(Uri.fromFile(localFile)) + } else { + decodeBitmapFromUri(uri) + } + } + + private fun decodeBitmapFromUri(uri: Uri): Bitmap { + return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) { ImageDecoder.decodeBitmap( ImageDecoder.createSource( reactApplicationContext.contentResolver, @@ -86,16 +114,43 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont } else { MediaStore.Images.Media.getBitmap(reactApplicationContext.contentResolver, uri) } + } + + private fun downloadImage(imageUrl: String): File { + val url = URL(imageUrl) + val connection = url.openConnection() as HttpURLConnection + connection.connect() + + if (connection.responseCode != HttpURLConnection.HTTP_OK) { + throw Exception("Failed to download image: ${connection.responseMessage}") + } - return bitmap + val inputStream: InputStream = connection.inputStream + val tempFile = File.createTempFile("downloaded_image", ".jpg", reactApplicationContext.cacheDir) + tempFile.outputStream().use { outputStream -> + inputStream.copyTo(outputStream) + } + inputStream.close() + connection.disconnect() + + return tempFile } private fun saveImage(bitmap: Bitmap, fileName: String): String { - val file = File(reactApplicationContext.filesDir, fileName) + val updatedFileName = if (fileName.endsWith(".jpg", ignoreCase = true)) { + fileName.replace(".jpg", ".png", true) + } else if (!fileName.endsWith(".png", ignoreCase = true)) { + "$fileName.png" + } else { + fileName + } + val file = File(reactApplicationContext.filesDir, updatedFileName) val fileOutputStream = FileOutputStream(file) - bitmap.compress(Bitmap.CompressFormat.JPEG, 100, fileOutputStream) + bitmap.compress(Bitmap.CompressFormat.PNG, 100, fileOutputStream) + fileOutputStream.flush() fileOutputStream.close() - return file.toURI().toString() + + return "file://${file.absolutePath}" } companion object {