本文基于TensorFlow并通过雅虎的 open_nsfw 简单地实现Android上的图片鉴黄效果。
TensorFlow是一个软件库或框架,由Google团队设计,以最简单的方式实现机器学习和深度学习概念。他结合了优化技术的计算机代数,便于计算许多数学表达式。
TensorFlow有详细记录,包含大量机器学习库,并提供了一些重要的功能和方法。
TensorFlow 也是一个Google产品,它包括各种机器学习和深度学习算法。TensorFlow可以训练和运行深度神经网络,用于手写数字分类、图像识别和各种序列模型的创建。
现在越来越多的移动设备集成了定制硬件来更有效地处理机器学习带来的工作负载。TensorFlow Lite支持Android神经网络API(Android Neural Networks API)利用这些新的加速器硬件。当加速器硬件不可用的时候,TensorFlow Lite会执行优化CPU,这可以确保模型仍然可以很快地运行在一个大的设备上。
TensorFlow Lite特点如下:
- 轻量级:允许在具有很小的二进制大小和快速初始化启动的机器学习模型设备上进行推理;
- 跨平台:能够运行在许多不同的平台上,首先支持
Android和IOS平台;- 快速:针对移动设备进行了优化,包括显著提高模型加载时间和支持硬件加速。
先看一下实现效果吧:
图片鉴黄主要思路是通过nsfw.tflite模型文件生成Interpreter,然后通过Interpreter获取python中定义的入口ByteBuffer的张量(Tensor),然后把要鉴别的文件做归一化处理,输入到ByteBuffer中,通过运行Interpreter获取结果即可,其流程如下图所示:
在build.gradle中添加TensorFlow依赖:
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
implementation 'org.tensorflow:tensorflow-lite-gpu:0.0.0-nightly'将resources目录下的nsfw.tflite文件拷贝到手机SD存储卡上。
public void init() {
File file = new File(Environment.getExternalStorageDirectory() + "/sty/tensorflow/",
"nsfw.tflite");
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
//加载模型
tflite = new Interpreter(file, options);
//获取到Python中定义的变量input,input为入口的意思
//张量
Tensor tensor = tflite.getInputTensor(tflite.getInputIndex("input"));
//申请并清空内存
imgData = ByteBuffer.allocateDirect(224 * 224 * 3 * 4);
imgData.order(ByteOrder.LITTLE_ENDIAN);
isInitialized = true;
}public void run(Bitmap bitmap, Context context) {
imgData.rewind(); //清空
Bitmap scaleBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true);
intValues = new int[224 * 224];
//bitmap --> int的数组
scaleBitmap.getPixels(intValues, 0, 224, 0, 0, 224, 224);
//intValues --> 赋值给imgData
for (int color : intValues) {
int r = Color.red(color);
int g = Color.green(color);
int b = Color.blue(color);
imgData.putFloat(b);
imgData.putFloat(g);
imgData.putFloat(r);
}
//最终获取的结果
float[][] outArray = new float[1][2];
//把程序传给GPU,然后GPU判断和执行
tflite.run(imgData, outArray);
//保留4位小数
DecimalFormat df = new DecimalFormat("#0.0000");
//outArray:入参出参对象
//正常图像:outArray[0][0]
//敏感图片:outArray[0][1]
ToastUtil.show(context, "\n黄色图片:" + df.format(outArray[0][1])
+ "\n正常图片:" + df.format(outArray[0][0]));
}
