1 pb-файл преобразования модели CKPT
Использовать предыдущий блог«Использование официальной модели предварительного обучения MobileNet V1»Загрузите официальную предварительно обученную модель MobileNet V1 в"Мобайлнет_v1_1.0_192". Несмотря на то, что упакованный загруженный файл содержит преобразованныйpb
документ, но официально предоставленныйpb
Выход модели1001
Вероятность, соответствующая категории, нам нужны 3 категории с наибольшей вероятностью. Функции можно использовать в исходной сетиtf.nn.top_k
Получите 3 класса с наибольшей вероятностью, поставьте функциюtf.nn.top_k
как вычислительный узел в сети. Код преобразования модели показан ниже.
import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope
import numpy as np
slim = tf.contrib.slim
CKPT = 'mobilenet_v1_1.0_192.ckpt'
def build_model(inputs):
with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
scores = end_points['Predictions']
print(scores)
#取概率最大的5个类别及其对应概率
output = tf.nn.top_k(scores, k=3, sorted=True)
#indices为类别索引,values为概率值
return output.indices,output.values
def load_model(sess):
loader = tf.train.Saver()
loader.restore(sess,CKPT)
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3),name='input')
classes_tf,scores_tf = build_model(inputs)
classes = tf.identity(classes_tf, name='classes')
scores = tf.identity(scores_tf, name='scores')
with tf.Session() as sess:
load_model(sess)
graph = tf.get_default_graph()
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), [classes.op.name,scores.op.name])
tf.train.write_graph(output_graph_def, 'model', 'mobilenet_v1_1.0_192.pb', as_text=False)
В приведенном выше коде единственная вероятность всех категорий проходит через расчетный узелtf.nn.top_k
После этого он делится на два выхода: 3 категории с наибольшей вероятностьюclasses
, вероятность 3 категорий с наибольшей вероятностьюscores
. После выполнения приведенного выше кода в каталоге“model”
получить файлmobilenet_v1_1.0_192.pb
.
2 Портирование на Android
2.1 Использование Tensorflow Mobile в AndroidStudio
Во-первых,AndroidStudio
версия должна быть3.0
и выше. СоздайтеAndroid Project
после, вModule:app
изbuild.gradle
Добавьте следующее в зависимости в файле:
compile 'org.tensorflow:tensorflow-android:+'
2.2 Мобильный интерфейс Tensorflow
Вызовите класс-оболочку, используя модель в библиотеке Tensorflow Mobile.org.tensorflow.contrib.android.TensorFlowInferenceInterface
Для завершения вызова модели в основном используются следующие функции.
public TensorFlowInferenceInterface(AssetManager assetManager, String model){...}
public void feed(String inputName, float[] src, long... dims) {...}
public void run(String[] outputNames) {...}
public void fetch(String outputName, int[] dst) {...}
Среди них параметры в конструктореmodel
Представляет каталог“assets”
в названии модели.feed
Параметры в функцииinputName
Указывает имя входного узла, то есть имя входного узла, указанное при преобразовании соответствующей модели“input”
,параметрsrc
Указывает массив входных данных, параметр переменной длиныdims
Представляет размер ввода, например входящий1,192,192,3
представляет входные данныеShape=[1,192,192,3]
. функцияrun
параметрыoutputNames
Указывает выполнение от входного узла доoutputNames
Все пути к узлам в . функцияfetch
Средний параметрoutputName
Указывает имя выходного узла и копирует данные указанного выходного узла вdst
середина.
2.3 Растровый объект для float[]
Обратите внимание, что в разделе 2.1 функцияfeed
Объект данных, переданный во входной узел,float[]
. Поэтому необходимоBitmap
Преобразовать вfloat[]
объект, пример кода показан ниже.
//读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1]
private float[] getFloatImage(Bitmap bitmap){
Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
for (int i = 0; i < inputIntData.length; ++i) {
final int val = inputIntData[i];
inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
}
return inputFloatData;
}
так какMobileNet V1
Входные данные предварительно обученной модели нормализуются к[-1,1]
, поэтому в функцииgetFloatImage
нормализовать данные, чтобы[-1,1]
.
2.4 Инкапсуляция вызовов модели
Чтобы облегчить вызов, инкапсулируйте функции вызова, связанные с моделью, в классы.TFModelUtils
в, черезTFModelUtils
изrun
Функция завершает вызов модели, и пример кода показан ниже.
package com.huachao.mn_v1_192;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Log;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
public class TFModelUtils {
private TensorFlowInferenceInterface inferenceInterface;
private int[] inputIntData ;
private float[] inputFloatData ;
private int inputWH;
private String inputName;
private String[] outputNames;
private Map<Integer,String> dict;
public TFModelUtils(AssetManager assetMngr,int inputWH,String inputName,String[]outputNames,String modelName){
this.inputWH=inputWH;
this.inputName=inputName;
this.outputNames=outputNames;
this.inputIntData=new int[inputWH*inputWH];
this.inputFloatData = new float[inputWH*inputWH*3];
//从assets目录加载模型
inferenceInterface= new TensorFlowInferenceInterface(assetMngr, modelName);
this.loadLabel(assetMngr);
}
public Map<String,Object> run(Bitmap bitmap){
float[] inputData = getFloatImage(bitmap);
//将输入数据复制到TensorFlow中,指定输入Shape=[1,INPUT_WH,INPUT_WH,3]
inferenceInterface.feed(inputName, inputData, 1, inputWH, inputWH, 3);
// 执行模型
inferenceInterface.run( outputNames );
//将输出Tensor对象复制到指定数组中
int[] classes=new int[3];
float[] scores=new float[3];
inferenceInterface.fetch(outputNames[0], classes);
inferenceInterface.fetch(outputNames[1], scores);
Map<String,Object> results=new HashMap<>();
results.put("scores",scores);
String[] classesLabel = new String[3];
for(int i =0;i<3;i++){
int idx=classes[i];
classesLabel[i]=dict.get(idx);
// System.out.printf("classes:"+dict.get(idx)+",scores:"+scores[i]+"\n");
}
results.put("classes",classesLabel);
return results;
}
//读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1]
private float[] getFloatImage(Bitmap bitmap){
Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
for (int i = 0; i < inputIntData.length; ++i) {
final int val = inputIntData[i];
inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
}
return inputFloatData;
}
//对图像做Resize
public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
int width = bm.getWidth();
int height = bm.getHeight();
float scaleWidth = ((float) newWidth) / width;
float scaleHeight = ((float) newHeight) / height;
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
Bitmap resizedBitmap = Bitmap.createBitmap( bm, 0, 0, width, height, matrix, false);
bm.recycle();
return resizedBitmap;
}
private void loadLabel( AssetManager assetManager ) {
dict=new HashMap<>();
try {
InputStream stream = assetManager.open("label.txt");
InputStreamReader isr=new InputStreamReader(stream);
BufferedReader br=new BufferedReader(isr);
String line;
while((line=br.readLine())!=null){
line=line.trim();
String[] arr = line.split(",");
if(arr.length!=2)
continue;
int key=Integer.parseInt(arr[0]);
String value = arr[1];
dict.put(key,value);
}
}catch (Exception e){
e.printStackTrace();
Log.e("ERROR",e.getMessage());
}
}
}