CV深度学习模型Android端落地第三步:使用Tensorflow Lite 将自己的训练得到的模型移植到Android上
这个系列的博客主要介绍如何在Android设备上移植你训练的cv神经网络模型。
主要过程如下: 1、使用Android Camera2 APIs获得摄像头实时预览的画面。
2、如果是对人脸图像进行处理,使用Android Camera2自带的Face类来对人脸检测,并完成在预览画面上画框将人脸框出、添加文字显示神经网络处理结果的功能。
3、使用Tensorflow Lite 将自己的训练得到的模型移植到Android上。
以上三个步骤会分为三个博客,同时也会提供示例代码。步骤二可以根据你的实际需求跳过或修改。这是这个系列的第三篇博客。
准备
- 新建一个C++ suport的AS项目
- 一个你训练好的以pb结尾的Tensorflow模型,如果你的模型是caffemodel可以使用代码将其转换成pb模型。
- 一个能调用你的模型完成你想要功能的Python脚本,以确保你的模型可以使用,以手写字体模型Minist为例:
1 |
|
- 清楚你的输入lable(inputName)和输出lable(outputName),如果不清楚,可以使用如下代码输出pb模型的层级结构
1 |
|
- 清楚你的输入向量和输出向量的关系
Input | Output | IN_COL | IN_ROW | OUT_COL | OUT_ROW | Code | |
---|---|---|---|---|---|---|---|
输入:单通道28*28 | 输出:1*1 | 1 | 28*28 | 1 | 1 | inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW); |
|
输入:三通道224*224 | 输出:1*7 | 3 | 224*224 | 1 | 7 | inferenceInterface.feed(inputName, inputdata, (1,3,224,224)); |
Android Studio配置
(1)新建一个Android项目。
(2)把训练好的pb文件(mnist.pb)放入Android项目中app/src/main/assets下,若不存在assets目录,右键main->new->Directory,输入assets。
(3)将下载的libtensorflow_inference.so和libandroid_tensorflow_inference_java.jar如下结构放在libs文件夹下。
(4)app.gradle配置
在defaultConfig中添加 1
2
3
4multiDexEnabled true
ndk {
abiFilters "armeabi-v7a"
}1
2
3
4
5sourceSets {
main {
jniLibs.srcDirs = ['libs']
}
}
在dependencies中增加TensoFlow编译的jar文件libandroid_tensorflow_inference_java.jar: 1
compile files('libs/libandroid_tensorflow_inference_java.jar')
- 测试图片配置 将一张28*28的手写字体图片放到
app/src/main/res/drawabble
中,在AS中复制,会自动帮你建立xml配置关系的。
代码
在需要调用TensoFlow的地方,加载so库 System.loadLibrary("tensorflow_inference");
并 import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
就可以使用了
注意,旧版的TensoFlow,是如下方式进行,该方法可参考博客基于TensorFlow的MNIST手写数字识别与Android移植:
1 |
|
但在最新的libandroid_tensorflow_inference_java.jar中,已经没有这些方法了,换为 1
2
3TensorFlowInferenceInterface.feed()
TensorFlowInferenceInterface.run()
TensorFlowInferenceInterface.fetch()
下面是以MNIST手写数字识别为例,其实现方法如下: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92package com.example.jinquan.pan.mnist_ensorflow_androiddemo;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Color;
import android.graphics.Matrix;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
public class PredictionTF {
private static final String TAG = "PredictionTF";
//设置模型输入/输出节点的数据维度
private static final int IN_COL = 1;
private static final int IN_ROW = 28*28;
private static final int OUT_COL = 1;
private static final int OUT_ROW = 1;
//模型中输入变量的名称
private static final String inputName = "input/x_input";
//模型中输出变量的名称
private static final String outputName = "output";
TensorFlowInferenceInterface inferenceInterface;
static {
//加载libtensorflow_inference.so库文件
System.loadLibrary("tensorflow_inference");
Log.e(TAG,"libtensorflow_inference.so库加载成功");
}
PredictionTF(AssetManager assetManager, String modePath) {
//初始化TensorFlowInferenceInterface对象
inferenceInterface = new TensorFlowInferenceInterface(assetManager,modePath);
Log.e(TAG,"TensoFlow模型文件加载成功");
}
/**
* 利用训练好的TensoFlow模型预测结果
* @param bitmap 输入被测试的bitmap图
* @return 返回预测结果,int数组
*/
public int[] getPredict(Bitmap bitmap) {
float[] inputdata = bitmapToFloatArray(bitmap,28, 28);//需要将图片缩放带28*28
//将数据feed给tensorflow的输入节点
inferenceInterface.feed(inputName, inputdata, IN_COL, IN_ROW);
//运行tensorflow
String[] outputNames = new String[] {outputName};
inferenceInterface.run(outputNames);
///获取输出节点的输出信息
int[] outputs = new int[OUT_COL*OUT_ROW]; //用于存储模型的输出数据
inferenceInterface.fetch(outputName, outputs);
return outputs;
}
/**
* 将bitmap转为(按行优先)一个float数组,并且每个像素点都归一化到0~1之间。
* @param bitmap 输入被测试的bitmap图片
* @param rx 将图片缩放到指定的大小(列)->28
* @param ry 将图片缩放到指定的大小(行)->28
* @return 返回归一化后的一维float数组 ->28*28
*/
public static float[] bitmapToFloatArray(Bitmap bitmap, int rx, int ry){
int height = bitmap.getHeight();
int width = bitmap.getWidth();
// 计算缩放比例
float scaleWidth = ((float) rx) / width;
float scaleHeight = ((float) ry) / height;
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
bitmap = Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
Log.i(TAG,"bitmap width:"+bitmap.getWidth()+",height:"+bitmap.getHeight());
Log.i(TAG,"bitmap.getConfig():"+bitmap.getConfig());
height = bitmap.getHeight();
width = bitmap.getWidth();
float[] result = new float[height*width];
int k = 0;
//行优先
for(int j = 0;j < height;j++){
for (int i = 0;i < width;i++){
int argb = bitmap.getPixel(i,j);
int r = Color.red(argb);
int g = Color.green(argb);
int b = Color.blue(argb);
int a = Color.alpha(argb);
//由于是灰度图,所以r,g,b分量是相等的。
assert(r==g && g==b);
result[k++] = r / 255.0f;
}
}
return result;
}
}
- 简单说明一下:项目新建了一个
PredictionTF
类,该类会先加载libtensorflow_inference.so
库文件;PredictionTF(AssetManager assetManager, String modePath)
构造方法需要传入AssetManager
对象和pb文件的路径;
- 从资源文件中获取BitMap图片,并传入
getPredict(Bitmap bitmap)
方法,该方法首先将BitMap图像缩放到28*28的大小,由于原图是灰度图,我们需要获取灰度图的像素值,并将28*28的像素转存为行向量的一个float数组,并且每个像素点都归一化到0~1之间,这个就是bitmapToFloatArray(Bitmap bitmap, int rx, int ry)
方法的作用;
- 然后将数据 喂给(feed) tensorflow的输入节点,并 运行(run) tensorflow,最后 获取(fetch) 输出节点的输出信息。
MainActivity很简单,一个单击事件获取预测结果: 1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.ImageView;
import android.widget.TextView;
public class MainActivity extends AppCompatActivity {
// Used to load the 'native-lib' library on application startup.
static {
System.loadLibrary("native-lib");//可以去掉
}
private static final String TAG = "MainActivity";
private static final String MODEL_FILE = "file:///android_asset/mnist.pb"; //模型存放路径
TextView txt;
TextView tv;
ImageView imageView;
Bitmap bitmap;
PredictionTF preTF;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
// Example of a call to a native method
tv = (TextView) findViewById(R.id.sample_text);
txt=(TextView)findViewById(R.id.txt_id);
imageView =(ImageView)findViewById(R.id.imageView1);
bitmap = BitmapFactory.decodeStream(getClass().getResourceAsStream("/res/drawable/test.bmp"));
imageView.setImageBitmap(bitmap);
preTF = new PredictionTF(getAssets(),MODEL_FILE);//输入模型存放路径,并加载TensoFlow模型
}
public void click01(View v){
String res="预测结果为:";
int[] result= preTF.getPredict(bitmap);
for (int i=0;i<result.length;i++){
Log.i(TAG, res+result[i] );
res=res+String.valueOf(result[i])+" ";
}
txt.setText(res);
}
/**
* A native method that is implemented by the 'native-lib' native library,
* which is packaged with this application.
*/
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:paddingBottom="16dp"
android:paddingLeft="16dp"
android:paddingRight="16dp"
android:paddingTop="16dp">
<TextView
android:id="@+id/sample_text"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="https://blog.csdn.net/guyuealian"
android:layout_gravity="center"/>
<Button
android:onClick="click01"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:text="click" />
<TextView
android:id="@+id/txt_id"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:gravity="center"
android:text="结果为:"/>
<ImageView
android:id="@+id/imageView1"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_gravity="center"/>
</LinearLayout>
注意事项
不同的神经网络对于输入图片的要求也不一样,有些需要转化成灰度图,有些需要归一化,有些要对尺寸进行裁剪,有些要减去均值。最好是先写一个Python脚本,将你的模型跑起来,完成需要的预处理,检测神经网络的输出是否正确。通过这个脚本了解输入输出数据的维度,再根据这个Python脚本去改写出对应的java代码。
特别要注意的是:
1 |
|
写好inferenceInterface.feed,才能成功起调模型,得到你想要的结果。
本博客主要依据这位博主的博客进行二次编写
演示Demo
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!