Tensorflow MobileNet портирован на Android

TensorFlow Android

1 pb-файл преобразования модели CKPT

Использовать предыдущий блог«Использование официальной модели предварительного обучения MobileNet V1»Загрузите официальную предварительно обученную модель MobileNet V1 в"Мобайлнет_v1_1.0_192". Несмотря на то, что упакованный загруженный файл содержит преобразованныйpbдокумент, но официально предоставленныйpbВыход модели1001Вероятность, соответствующая категории, нам нужны 3 категории с наибольшей вероятностью. Функции можно использовать в исходной сетиtf.nn.top_kПолучите 3 класса с наибольшей вероятностью, поставьте функциюtf.nn.top_kкак вычислительный узел в сети. Код преобразования модели показан ниже.

import tensorflow as tf
from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope 
import numpy as np
slim = tf.contrib.slim
CKPT = 'mobilenet_v1_1.0_192.ckpt'  

def build_model(inputs):   
    with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)):
        logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001)
    scores = end_points['Predictions']
    print(scores)
    #取概率最大的5个类别及其对应概率
    output = tf.nn.top_k(scores, k=3, sorted=True)
    #indices为类别索引,values为概率值
    return output.indices,output.values

def load_model(sess):
    loader = tf.train.Saver()
    loader.restore(sess,CKPT)
   
inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3),name='input')
classes_tf,scores_tf = build_model(inputs)  
classes = tf.identity(classes_tf, name='classes')
scores = tf.identity(scores_tf, name='scores') 
with tf.Session() as sess:
    load_model(sess)
    graph = tf.get_default_graph()
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), [classes.op.name,scores.op.name]) 
    tf.train.write_graph(output_graph_def, 'model', 'mobilenet_v1_1.0_192.pb', as_text=False)

В приведенном выше коде единственная вероятность всех категорий проходит через расчетный узелtf.nn.top_kПосле этого он делится на два выхода: 3 категории с наибольшей вероятностьюclasses, вероятность 3 категорий с наибольшей вероятностьюscores. После выполнения приведенного выше кода в каталоге“model”получить файлmobilenet_v1_1.0_192.pb.

2 Портирование на Android

2.1 Использование Tensorflow Mobile в AndroidStudio

Во-первых,AndroidStudioверсия должна быть3.0и выше. СоздайтеAndroid Projectпосле, вModule:appизbuild.gradleДобавьте следующее в зависимости в файле:

 compile 'org.tensorflow:tensorflow-android:+'

2.2 Мобильный интерфейс Tensorflow

Вызовите класс-оболочку, используя модель в библиотеке Tensorflow Mobile.org.tensorflow.contrib.android.TensorFlowInferenceInterfaceДля завершения вызова модели в основном используются следующие функции.

 public TensorFlowInferenceInterface(AssetManager assetManager, String model){...}
 public void feed(String inputName, float[] src, long... dims) {...}
 public void run(String[] outputNames) {...}
 public void fetch(String outputName, int[] dst) {...}

Среди них параметры в конструктореmodelПредставляет каталог“assets”в названии модели.feedПараметры в функцииinputNameУказывает имя входного узла, то есть имя входного узла, указанное при преобразовании соответствующей модели“input”,параметрsrcУказывает массив входных данных, параметр переменной длиныdimsПредставляет размер ввода, например входящий1,192,192,3представляет входные данныеShape=[1,192,192,3]. функцияrunпараметрыoutputNamesУказывает выполнение от входного узла доoutputNamesВсе пути к узлам в . функцияfetchСредний параметрoutputNameУказывает имя выходного узла и копирует данные указанного выходного узла вdstсередина.

2.3 Растровый объект для float[]

Обратите внимание, что в разделе 2.1 функцияfeedОбъект данных, переданный во входной узел,float[]. Поэтому необходимоBitmapПреобразовать вfloat[]объект, пример кода показан ниже.

    //读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1]
    private float[] getFloatImage(Bitmap bitmap){
        Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
        bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
        for (int i = 0; i < inputIntData.length; ++i) {
            final int val = inputIntData[i];
            inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
        }
        return inputFloatData;
    }

так какMobileNet V1Входные данные предварительно обученной модели нормализуются к[-1,1], поэтому в функцииgetFloatImageнормализовать данные, чтобы[-1,1].

2.4 Инкапсуляция вызовов модели

Чтобы облегчить вызов, инкапсулируйте функции вызова, связанные с моделью, в классы.TFModelUtilsв, черезTFModelUtilsизrunФункция завершает вызов модели, и пример кода показан ниже.

package com.huachao.mn_v1_192; 
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.util.Log; 
import org.tensorflow.contrib.android.TensorFlowInferenceInterface; 
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map; 
public class TFModelUtils {
    private TensorFlowInferenceInterface inferenceInterface;
    private int[] inputIntData ;
    private float[] inputFloatData ;
    private int inputWH;
    private String inputName;
    private String[] outputNames; 
    private Map<Integer,String> dict;
    public TFModelUtils(AssetManager assetMngr,int inputWH,String inputName,String[]outputNames,String modelName){
        this.inputWH=inputWH;
        this.inputName=inputName;
        this.outputNames=outputNames;
        this.inputIntData=new int[inputWH*inputWH];
        this.inputFloatData = new float[inputWH*inputWH*3];
        //从assets目录加载模型
        inferenceInterface= new TensorFlowInferenceInterface(assetMngr, modelName);
        this.loadLabel(assetMngr);
    }
    public Map<String,Object> run(Bitmap bitmap){ 
        float[] inputData = getFloatImage(bitmap);
        //将输入数据复制到TensorFlow中,指定输入Shape=[1,INPUT_WH,INPUT_WH,3]
        inferenceInterface.feed(inputName, inputData, 1, inputWH, inputWH, 3); 
        // 执行模型
        inferenceInterface.run( outputNames ); 
        //将输出Tensor对象复制到指定数组中
        int[] classes=new int[3];
        float[] scores=new float[3];
        inferenceInterface.fetch(outputNames[0], classes);
        inferenceInterface.fetch(outputNames[1], scores);
        Map<String,Object> results=new HashMap<>();
        results.put("scores",scores);
        String[] classesLabel = new String[3];
        for(int i =0;i<3;i++){
            int idx=classes[i];
            classesLabel[i]=dict.get(idx);
//            System.out.printf("classes:"+dict.get(idx)+",scores:"+scores[i]+"\n");
        }
        results.put("classes",classesLabel);
        return results;


    }
    //读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1]
    private float[] getFloatImage(Bitmap bitmap){
        Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH);
        bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
        for (int i = 0; i < inputIntData.length; ++i) {
            final int val = inputIntData[i];
            inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2;
            inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ;
        }
        return inputFloatData;
    }
    //对图像做Resize
    public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
        int width = bm.getWidth();
        int height = bm.getHeight();
        float scaleWidth = ((float) newWidth) / width;
        float scaleHeight = ((float) newHeight) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        Bitmap resizedBitmap = Bitmap.createBitmap( bm, 0, 0, width, height, matrix, false);
        bm.recycle();
        return resizedBitmap;
    }

    private void loadLabel( AssetManager assetManager ) {
        dict=new HashMap<>();
        try {
            InputStream stream = assetManager.open("label.txt");
            InputStreamReader isr=new InputStreamReader(stream);
            BufferedReader br=new BufferedReader(isr);
            String line;
            while((line=br.readLine())!=null){
                line=line.trim();
                String[] arr = line.split(",");
                if(arr.length!=2)
                    continue;
                int key=Integer.parseInt(arr[0]);
                String value = arr[1];
                dict.put(key,value);
            }
        }catch (Exception e){
            e.printStackTrace();
            Log.e("ERROR",e.getMessage());
        }
    }
}

3 Тестирование модели

тестовая модель