Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ import com.google.mlkit.vision.segmentation.Segmenter
import com.google.mlkit.vision.segmentation.selfie.SelfieSegmenterOptions
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
import java.net.HttpURLConnection
import java.net.URI
import java.net.URL
import kotlin.math.pow
import android.webkit.MimeTypeMap
import kotlinx.coroutines.*

class BackgroundRemoverModule internal constructor(context: ReactApplicationContext) :
BackgroundRemoverSpec(context) {
Expand All @@ -33,34 +38,42 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont
@ReactMethod
override fun removeBackground(imageURI: String, promise: Promise) {
val segmenter = this.segmenter ?: createSegmenter()
val image = getImageBitmap(imageURI)

val inputImage = InputImage.fromBitmap(image, 0)

segmenter.process(inputImage).addOnFailureListener { e ->
promise.reject(e)
}.addOnSuccessListener { result ->
val maskBuffer = result.buffer
val mask = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888)

for (y in 0 until result.height) {
for (x in 0 until result.width) {
val alpha = maskBuffer.getFloat().pow(4)
mask.setPixel(x, y, Color.argb((alpha * 255).toInt(), 0, 0, 0))
try {
val image = getImageBitmap(imageURI).copy(Bitmap.Config.ARGB_8888, true)
val inputImage = InputImage.fromBitmap(image, 0)

segmenter.process(inputImage).addOnFailureListener { e ->
promise.reject(e)
}.addOnSuccessListener { result ->
val maskBuffer = result.buffer
val mask = Bitmap.createBitmap(result.width, result.height, Bitmap.Config.ARGB_8888)

for (y in 0 until result.height) {
for (x in 0 until result.width) {
val alpha = maskBuffer.getFloat().pow(4)
val color = if (alpha > 0.1) Color.argb((alpha * 255).toInt(), 0, 0, 0) else Color.TRANSPARENT
mask.setPixel(x, y, color)
}
}
}

val paint = Paint(Paint.ANTI_ALIAS_FLAG)
paint.setXfermode(PorterDuffXfermode(PorterDuff.Mode.DST_IN))
val canvas = Canvas(image)
canvas.drawBitmap(mask, 0f, 0f, paint)

val fileName = URI(imageURI).path.split("/").last()
val savedImageURI = saveImage(image, fileName)
promise.resolve(savedImageURI)
val mutableImage = Bitmap.createBitmap(image.width, image.height, Bitmap.Config.ARGB_8888)
val canvas = Canvas(mutableImage)
canvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR)
canvas.drawBitmap(image, 0f, 0f, null)
val paint = Paint(Paint.ANTI_ALIAS_FLAG)
paint.xfermode = PorterDuffXfermode(PorterDuff.Mode.DST_IN)
canvas.drawBitmap(mask, 0f, 0f, paint)

val fileName = URI(imageURI).path.split("/").last()
val savedImageURI = saveImage(mutableImage, fileName)
promise.resolve(savedImageURI)
}
} catch (e: Exception) {
promise.reject(e)
}
}


private fun createSegmenter(): Segmenter {
val options =
SelfieSegmenterOptions.Builder()
Expand All @@ -76,7 +89,16 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont
private fun getImageBitmap(imageURI: String): Bitmap {
val uri = Uri.parse(imageURI)

val bitmap = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
return if (uri.scheme == "http" || uri.scheme == "https") {
val localFile = downloadImage(uri.toString())
decodeBitmapFromUri(Uri.fromFile(localFile))
} else {
decodeBitmapFromUri(uri)
}
}

private fun decodeBitmapFromUri(uri: Uri): Bitmap {
return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.P) {
ImageDecoder.decodeBitmap(
ImageDecoder.createSource(
reactApplicationContext.contentResolver,
Expand All @@ -86,16 +108,62 @@ class BackgroundRemoverModule internal constructor(context: ReactApplicationCont
} else {
MediaStore.Images.Media.getBitmap(reactApplicationContext.contentResolver, uri)
}
}

private fun downloadImage(imageUrl: String): File {
return runBlocking {
withContext(Dispatchers.IO) {
val url = URL(imageUrl)
val connection = url.openConnection() as HttpURLConnection
connection.connect()

return bitmap
if (connection.responseCode != HttpURLConnection.HTTP_OK) {
throw Exception("Failed to download image: ${connection.responseMessage}")
}

var extension = MimeTypeMap.getFileExtensionFromUrl(imageUrl)
if (extension.isNullOrEmpty()) {
val contentType = connection.contentType
extension = MimeTypeMap.getSingleton().getExtensionFromMimeType(contentType) ?: "png"
}

val inputStream: InputStream = connection.inputStream
val tempFile = File.createTempFile("downloaded_image", ".$extension", reactApplicationContext.cacheDir)
tempFile.outputStream().use { outputStream ->
inputStream.copyTo(outputStream)
}
inputStream.close()
connection.disconnect()

tempFile
}
}
}

private fun saveImage(bitmap: Bitmap, fileName: String): String {
val file = File(reactApplicationContext.filesDir, fileName)
val updatedFileName = if (fileName.endsWith(".jpg", ignoreCase = true)) {
fileName.replace(".jpg", ".png", true)
} else {
fileName
}
val file = File(reactApplicationContext.filesDir, updatedFileName)

val safeBitmap = if (!bitmap.hasAlpha()) {
val newBitmap = Bitmap.createBitmap(bitmap.width, bitmap.height, Bitmap.Config.ARGB_8888)
val canvas = Canvas(newBitmap)
canvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR)
canvas.drawBitmap(bitmap, 0f, 0f, null)
newBitmap
} else {
bitmap
}

val fileOutputStream = FileOutputStream(file)
bitmap.compress(Bitmap.CompressFormat.JPEG, 100, fileOutputStream)
safeBitmap.compress(Bitmap.CompressFormat.PNG, 100, fileOutputStream)
fileOutputStream.flush()
fileOutputStream.close()
return file.toURI().toString()

return "file://${file.absolutePath}"
}

companion object {
Expand Down
Loading