Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ import com.google.mlkit.vision.segmentation.Segmenter
import com.google.mlkit.vision.segmentation.selfie.SelfieSegmenterOptions
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
import java.net.HttpURLConnection
import java.net.URI
import java.net.URL
import kotlin.math.pow

class BackgroundRemoverModule internal constructor(context: ReactApplicationContext) :
Expand All @@ -33,31 +36,47 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont
@ReactMethod
override fun removeBackground(imageURI: String, promise: Promise) {
val segmenter = this.segmenter ?: createSegmenter()
val image = getImageBitmap(imageURI)

val inputImage = InputImage.fromBitmap(image, 0)

segmenter.process(inputImage).addOnFailureListener { e ->
promise.reject(e)
}.addOnSuccessListener { result ->
val maskBuffer = result.buffer
val mask = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888)

for (y in 0 until result.height) {
for (x in 0 until result.width) {
val alpha = maskBuffer.getFloat().pow(4)
mask.setPixel(x, y, Color.argb((alpha * 255).toInt(), 0, 0, 0))
try {
val image = getImageBitmap(imageURI)
val inputImage = InputImage.fromBitmap(image, 0)

segmenter.process(inputImage).addOnFailureListener { e ->
promise.reject(e)
}.addOnSuccessListener { result ->
val maskBuffer = result.buffer
maskBuffer.rewind() // Reset buffer position

// Create a new bitmap with transparent background
val resultBitmap = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888)

// Get pixel arrays for processing
val imagePixels = IntArray(result.width * result.height)
image.getPixels(imagePixels, 0, result.width, 0, 0, result.width, result.height)

val resultPixels = IntArray(result.width * result.height)

// Process each pixel
for (i in 0 until result.width * result.height) {
val confidence = maskBuffer.getFloat()

if (confidence > 0.5f) { // Threshold for foreground detection
// Keep the original pixel with full opacity
resultPixels[i] = imagePixels[i]
} else {
// Make pixel completely transparent
resultPixels[i] = Color.TRANSPARENT
}
}
}

val paint = Paint(Paint.ANTI_ALIAS_FLAG)
paint.setXfermode(PorterDuffXfermode(PorterDuff.Mode.DST_IN))
val canvas = Canvas(image)
canvas.drawBitmap(mask, 0f, 0f, paint)

// Set the processed pixels to result bitmap
resultBitmap.setPixels(resultPixels, 0, result.width, 0, 0, result.width, result.height)

val fileName = URI(imageURI).path.split("/").last()
val savedImageURI = saveImage(image, fileName)
promise.resolve(savedImageURI)
val fileName = URI(imageURI).path.split("/").last()
val savedImageURI = saveImage(resultBitmap, fileName)
promise.resolve(savedImageURI)
}
} catch (e: Exception) {
promise.reject(e)
}
}

Expand All @@ -76,7 +95,16 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont
private fun getImageBitmap(imageURI: String): Bitmap {
val uri = Uri.parse(imageURI)

val bitmap = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
return if (uri.scheme == "http" || uri.scheme == "https") {
val localFile = downloadImage(uri.toString())
decodeBitmapFromUri(Uri.fromFile(localFile))
} else {
decodeBitmapFromUri(uri)
}
}

private fun decodeBitmapFromUri(uri: Uri): Bitmap {
return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
ImageDecoder.decodeBitmap(
ImageDecoder.createSource(
reactApplicationContext.contentResolver,
Expand All @@ -86,16 +114,43 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont
} else {
MediaStore.Images.Media.getBitmap(reactApplicationContext.contentResolver, uri)
}
}

private fun downloadImage(imageUrl: String): File {
val url = URL(imageUrl)
val connection = url.openConnection() as HttpURLConnection
connection.connect()

if (connection.responseCode != HttpURLConnection.HTTP_OK) {
throw Exception("Failed to download image: ${connection.responseMessage}")
}

return bitmap
val inputStream: InputStream = connection.inputStream
val tempFile = File.createTempFile("downloaded_image", ".jpg", reactApplicationContext.cacheDir)
tempFile.outputStream().use { outputStream ->
inputStream.copyTo(outputStream)
}
inputStream.close()
connection.disconnect()

return tempFile
}

private fun saveImage(bitmap: Bitmap, fileName: String): String {
val file = File(reactApplicationContext.filesDir, fileName)
val updatedFileName = if (fileName.endsWith(".jpg", ignoreCase = true)) {
fileName.replace(".jpg", ".png", true)
} else if (!fileName.endsWith(".png", ignoreCase = true)) {
"$fileName.png"
} else {
fileName
}
val file = File(reactApplicationContext.filesDir, updatedFileName)
val fileOutputStream = FileOutputStream(file)
bitmap.compress(Bitmap.CompressFormat.JPEG, 100, fileOutputStream)
bitmap.compress(Bitmap.CompressFormat.PNG, 100, fileOutputStream)
fileOutputStream.flush()
fileOutputStream.close()
return file.toURI().toString()

return "file://${file.absolutePath}"
}

companion object {
Expand Down