Сегодня пользователь сетиНаучите, как собрать tensorflow Lite2.0 на AndroidКомментарий к этой статье
Спросите, как ввести изображение и вывести массив?
Я думаю, что это также является проблемой для многих новичков, многие студенты начального уровня не завершили создание модели, обучение, преобразование в TensorFlow Lite и фактическое использование в Android.
Так что я дал ему демо, которое я написал раньше, думая об этом или находя время, чтобы поместить это.demoНапишите статью, надеясь помочь большему количеству студентов начального уровня.
Хотя статей о рукописном вводе на основе TensorFlow много, все же нужно быть многословным, ведь это хороший пример искусственного интеллекта начального уровня.
Я не обращаю внимания на детали алгоритма распознавания рукописного текста, я обращаю внимание на весь процесс от модели до приложения, если вы хотите понять алгоритм, пожалуйста, изучите его самостоятельно.
Заинтересованные студенты могут следить за моей серией блоговСерия «Искусственный интеллект» (обновление...), Я тоже осваиваю эти знания, учусь и общаюсь вместе.
1 Основы почерка
1.1 Изучение набора данных MINIST
Используемый набор данных MNIST получен из Национального института стандартов и технологий (NIST). Обучающая выборка состояла из рукописных чисел 250 разных людей, 50% старшеклассников и 50% сотрудников Бюро переписи населения.Тестовая выборка также была написана от руки в тех же пропорциях цифровых данных.
Как выглядит каждое изображение в наборе данных?
Именно так:Получено следующим кодом:
# Plot ad hoc mnist instances
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
# load (downloaded if needed) the MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# plot 4 images as gray scale
plt.subplot(221)
plt.imshow(X_train[0], cmap=plt.get_cmap("gray"))
plt.subplot(222)
plt.imshow(X_train[1], cmap=plt.get_cmap("gray"))
plt.subplot(223)
plt.imshow(X_train[2], cmap=plt.get_cmap("gray"))
plt.subplot(224)
plt.imshow(X_train[3], cmap=plt.get_cmap("gray"))
# show the plot
plt.show()
Но что такое хранилище на самом деле?Вы можете обнаружить, что это слово 0, а хранилище - это значение RGB изображения, где значение равно нулю, оно черное, а ненулевые места - разные уровни серого. Это изображение RGB-матрицы в градациях серого.
1.2 Основное введение в CNN
Алгоритм распознавания рукописного ввода, используемый на этот раз, — это CNN (Convolutional Neural Network), который широко используется в компьютерном зрении.Наиболее классическая схема распознавания рукописного ввода CNN описывает весь процесс распознавания рукописного ввода, конкретные детали обсуждаться не будут, есть возможность написать статью о деталях этого алгоритма, но структура модели нейронной сети этой статьи выглядит следующим образом. :
1.3 Распознавание рукописного ввода на основе TensorFlow
Используется интерфейс Keras в TensorFlow, который больше подходит для новичков. Заставляет вас чувствовать, что построение модели нейронной сети похоже на строительные блоки.
Код выглядит следующим образом, обратите внимание на комментарии.
import numpy
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.python.keras.utils import np_utils
import tensorflow as tf
import pathlib
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][channels][width][height]
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')
# normalize inputs from 0-255 to 0-1
X_train = X_train / 255
X_test = X_test / 255
print(X_train.shape)
# one hot encode outputs
y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)
print(X_train[0])
num_classes = y_test.shape[1]
def baseline_model():
# create model
model = Sequential()
model.add(Conv2D(32, kernel_size=(5, 5),
input_shape=(28, 28, 1),//采用单通道的图片
activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))
# Compile model
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer='adam',
metrics=['accuracy'])
return model
model = baseline_model()
# Fit the model
model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=200, verbose=2)
# Final evaluation of the model
scores = model.evaluate(X_test, y_test, verbose=0)
print("CNN Error: %.2f%%" % (100 - scores[1] * 100))
# 上面升级网络训练的过程
# 下面需要将其转换tensorflow Lite模型,便于在Android中使用。
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
tflite_model_file = pathlib.Path('saved_model/model.tflite')
tflite_model_file.write_bytes(tflite_model)
2 Реализовать распознавание рукописного ввода на Android
Если вы не знаете, как настроить среду Android, см.Научите, как собрать tensorflow Lite2.0 на Android
2.1 Загрузите модель
Поместите обученный файл TensorFlow Lite в папку ресурсов Android.
public class TF {
private static Context mContext;
Interpreter mInterpreter;
private static TF instance;
public static TF newInstance(Context context) {
mContext = context;
if (instance == null) {
instance = new TF();
}
return instance;
}
Interpreter get() {
try {
if (Objects.isNull(mInterpreter))
mInterpreter = new Interpreter(loadModelFile(mContext));
} catch (IOException e) {
e.printStackTrace();
}
return mInterpreter;
}
// 获取文件
private MappedByteBuffer loadModelFile(Context context) throws IOException {
AssetFileDescriptor fileDescriptor = context.getAssets().openFd("model.tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
}
2.2 Настройка вида чертежа
public class HandWriteView extends View {
Path mPath = new Path();
Paint mPaint;
Bitmap mBitmap;
Canvas mCanvas;
public HandWriteView(Context context) {
super(context);
init();
}
public HandWriteView(Context context, AttributeSet attrs) {
super(context, attrs);
init();
}
void init() {
mPaint = new Paint();
mPaint.setColor(Color.WHITE);
mPaint.setStyle(Paint.Style.STROKE);
mPaint.setStrokeJoin(Paint.Join.ROUND);
mPaint.setStrokeCap(Paint.Cap.ROUND);
mPaint.setStrokeWidth(30);
}
@Override
protected void onDraw(Canvas canvas) {
super.onDraw(canvas);
mBitmap = Bitmap.createBitmap(getWidth(), getHeight(), Bitmap.Config.ARGB_8888);
mCanvas = new Canvas(mBitmap);
mCanvas.drawColor(Color.BLACK);
canvas.drawPath(mPath, mPaint);
mCanvas.drawPath(mPath, mPaint);
}
@Override
public boolean onTouchEvent(MotionEvent event) {
switch (event.getAction()) {
case MotionEvent.ACTION_DOWN:
mPath.moveTo(event.getX(), event.getY());
break;
case MotionEvent.ACTION_MOVE:
mPath.lineTo(event.getX(), event.getY());
break;
case MotionEvent.ACTION_UP:
case MotionEvent.ACTION_CANCEL:
break;
}
postInvalidate();
return true;
}
Bitmap getBitmap() {
mPath.reset();
return mBitmap;
}
}
2.3 Преобразование растрового изображения в формат, требуемый сетью
Поскольку данные в наборе данных имеют размер 28 * 28 * 3, 28 — это ширина и высота изображения, 3 — это три канала R, G и B, поэтому перед вводом в сеть нам нужно преобразовать растровое изображение в формат, требуемый сетью.
private ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) {
int inputShape[] = TF.newInstance(getApplicationContext()).get().getInputTensor(0).shape();
int inputImageWidth = inputShape[1];
int inputImageHeight = inputShape[2];
Bitmap bs = Bitmap.createScaledBitmap(bitmap, inputImageWidth, inputImageHeight, true);
mImageView.setImageBitmap(bs);
ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * inputImageHeight * inputImageWidth);
byteBuffer.order(ByteOrder.nativeOrder());
int[] pixels = new int[inputImageWidth * inputImageHeight];
bs.getPixels(pixels, 0, bs.getWidth(), 0, 0, bs.getWidth(), bs.getHeight());
for (int pixelValue : pixels) {
int r = (pixelValue >> 16 & 0xFF);
int g = (pixelValue >> 8 & 0xFF);
int b = (pixelValue & 0xFF);
// Convert RGB to grayscale and normalize pixel value to [0..1]
float normalizedPixelValue = (r + g + b) / 3.0f / 255.0f;
byteBuffer.putFloat(normalizedPixelValue);
}
return byteBuffer;
}
2.4 Вывод результатов распознавания
Результат распознавания оценивается по вероятности 0-9, а наибольшей вероятностью является результат распознавания.
float[][] input = new float[1][10];
TF.newInstance(getApplicationContext()).get().run(convertBitmapToByteBuffer(mHandWriteView.getBitmap()), input);
int result = -1;
float value = 0f;
for (int j = 0; j < 10; j++) {
if (input[0][j] > value) {
value = input[0][j];
result = j;
}
Log.i("TAG", "result: " + j + " " + input[0][j]);
}
if (input[0][result] < 0.2f) {
mTextView.setText("结果为:未识别");
} else {
mTextView.setText("结果为:" + result);
}
Результат распознавания:
При необходимости нажмитеdemoскачать.
3 Резюме
Основной процесс разработки приложения искусственного интеллекта так много.Ключ заключается в алгоритме.Чтобы получить более точную модель, в дополнение к использованию лучшей модели, данные необходимо вращать, улучшать или отбеливать, чтобы улучшить данные, разнообразие.
Приветствую всех для общения! ! ! !