предисловие
Последняя статья кратко представила Hello World от TensorFlow. Прогнозирование линейной модели было выполнено с помощью TensorFlow.
Сегодня мы собираемся выполнить одну из самых простых задач классификации изображений.
Введение в набор данных
Мы используем набор данных fashion-mnist. Его адрес на гитхабе:GitHub.com/smashed руды и…
Набор данных содержит много изображений:
Каждые обучающие данные соответствуют классу. Диапазон значений его метки составляет от 0 до 9, всего 10 категорий. Фактическое значение каждой категории следующее:
Label | Description |
---|---|
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |
Предварительное изучение наборов данных
Давайте начнем кодировать, используя TensorFlow для этой задачи классификации изображений.
Сначала загрузите набор данных:
import tensorflow as tf
minst = tf.keras.datasets.fashion_mnist.load_data()
Получите обучающий набор, обучающую метку, тестовый набор и тестовую метку из набора данных:
(training_images, training_labels), (test_images, test_labels) = minst
Давайте посмотрим, какая первая картинка в тренировочном наборе.
import matplotlib.pyplot as plt
training_images[0]
Мы нашли многомерный массив.
array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
0, 0, 13, 73, 0, 0, 1, 4, 0, 0, 0, 0, 1,
1, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,
0, 36, 136, 127, 62, 54, 0, 0, 0, 1, 3, 4, 0,
0, 3],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6,
0, 102, 204, 176, 134, 144, 123, 23, 0, 0, 0, 0, 12,
10, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 155, 236, 207, 178, 107, 156, 161, 109, 64, 23, 77, 130,
72, 15],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
69, 207, 223, 218, 216, 216, 163, 127, 121, 122, 146, 141, 88,
172, 66],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0,
200, 232, 232, 233, 229, 223, 223, 215, 213, 164, 127, 123, 196,
229, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
183, 225, 216, 223, 228, 235, 227, 224, 222, 224, 221, 223, 245,
173, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
193, 228, 218, 213, 198, 180, 212, 210, 211, 213, 223, 220, 243,
202, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 12,
219, 220, 212, 218, 192, 169, 227, 208, 218, 224, 212, 226, 197,
209, 52],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 99,
244, 222, 220, 218, 203, 198, 221, 215, 213, 222, 220, 245, 119,
167, 56],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 55,
236, 228, 230, 228, 240, 232, 213, 218, 223, 234, 217, 217, 209,
92, 0],
[ 0, 0, 1, 4, 6, 7, 2, 0, 0, 0, 0, 0, 237,
226, 217, 223, 222, 219, 222, 221, 216, 223, 229, 215, 218, 255,
77, 0],
[ 0, 3, 0, 0, 0, 0, 0, 0, 0, 62, 145, 204, 228,
207, 213, 221, 218, 208, 211, 218, 224, 223, 219, 215, 224, 244,
159, 0],
[ 0, 0, 0, 0, 18, 44, 82, 107, 189, 228, 220, 222, 217,
226, 200, 205, 211, 230, 224, 234, 176, 188, 250, 248, 233, 238,
215, 0],
[ 0, 57, 187, 208, 224, 221, 224, 208, 204, 214, 208, 209, 200,
159, 245, 193, 206, 223, 255, 255, 221, 234, 221, 211, 220, 232,
246, 0],
[ 3, 202, 228, 224, 221, 211, 211, 214, 205, 205, 205, 220, 240,
80, 150, 255, 229, 221, 188, 154, 191, 210, 204, 209, 222, 228,
225, 0],
[ 98, 233, 198, 210, 222, 229, 229, 234, 249, 220, 194, 215, 217,
241, 65, 73, 106, 117, 168, 219, 221, 215, 217, 223, 223, 224,
229, 29],
[ 75, 204, 212, 204, 193, 205, 211, 225, 216, 185, 197, 206, 198,
213, 240, 195, 227, 245, 239, 223, 218, 212, 209, 222, 220, 221,
230, 67],
[ 48, 203, 183, 194, 213, 197, 185, 190, 194, 192, 202, 214, 219,
221, 220, 236, 225, 216, 199, 206, 186, 181, 177, 172, 181, 205,
206, 115],
[ 0, 122, 219, 193, 179, 171, 183, 196, 204, 210, 213, 207, 211,
210, 200, 196, 194, 191, 195, 191, 198, 192, 176, 156, 167, 177,
210, 92],
[ 0, 0, 74, 189, 212, 191, 175, 172, 175, 181, 185, 188, 189,
188, 193, 198, 204, 209, 210, 210, 211, 188, 188, 194, 192, 216,
170, 0],
[ 2, 0, 0, 0, 66, 200, 222, 237, 239, 242, 246, 243, 244,
221, 220, 193, 191, 179, 182, 182, 181, 176, 166, 168, 99, 58,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 40, 61, 44, 72, 41, 35,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]], dtype=uint8)
Далее посмотрите на конкретные размеры:
training_images[0].shape
Результатом является двузначный массив 28*28.
(28, 28)
Изображение показано ниже как изображение:
print(training_labels[0])
plt.imshow(training_images[0])
Результаты следующие: где 9 соответствует ботильонам.
9
<matplotlib.image.AxesImage at 0x19eb8ff2310>
обработка данных
Прежде чем классифицировать изображение, давайте нормализуем его:
training_images = training_images / 255
test_images = test_images/255
Затем мы видим, изменилось ли содержимое первого изображения:
training_images[0]
Мы видим, что все числа, содержащиеся на первой картинке, стали числами от 0 до 1. Это также отвечает нашим потребностям в нормализации:
array([[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.00392157, 0. , 0. ,
0.05098039, 0.28627451, 0. , 0. , 0.00392157,
0.01568627, 0. , 0. , 0. , 0. ,
0.00392157, 0.00392157, 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.01176471, 0. , 0.14117647,
0.53333333, 0.49803922, 0.24313725, 0.21176471, 0. ,
0. , 0. , 0.00392157, 0.01176471, 0.01568627,
0. , 0. , 0.01176471],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.02352941, 0. , 0.4 ,
0.8 , 0.69019608, 0.5254902 , 0.56470588, 0.48235294,
0.09019608, 0. , 0. , 0. , 0. ,
0.04705882, 0.03921569, 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.60784314,
0.9254902 , 0.81176471, 0.69803922, 0.41960784, 0.61176471,
0.63137255, 0.42745098, 0.25098039, 0.09019608, 0.30196078,
0.50980392, 0.28235294, 0.05882353],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0.00392157, 0. , 0.27058824, 0.81176471,
0.8745098 , 0.85490196, 0.84705882, 0.84705882, 0.63921569,
0.49803922, 0.4745098 , 0.47843137, 0.57254902, 0.55294118,
0.34509804, 0.6745098 , 0.25882353],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.00392157,
0.00392157, 0.00392157, 0. , 0.78431373, 0.90980392,
0.90980392, 0.91372549, 0.89803922, 0.8745098 , 0.8745098 ,
0.84313725, 0.83529412, 0.64313725, 0.49803922, 0.48235294,
0.76862745, 0.89803922, 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.71764706, 0.88235294,
0.84705882, 0.8745098 , 0.89411765, 0.92156863, 0.89019608,
0.87843137, 0.87058824, 0.87843137, 0.86666667, 0.8745098 ,
0.96078431, 0.67843137, 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0.75686275, 0.89411765,
0.85490196, 0.83529412, 0.77647059, 0.70588235, 0.83137255,
0.82352941, 0.82745098, 0.83529412, 0.8745098 , 0.8627451 ,
0.95294118, 0.79215686, 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.00392157,
0.01176471, 0. , 0.04705882, 0.85882353, 0.8627451 ,
0.83137255, 0.85490196, 0.75294118, 0.6627451 , 0.89019608,
0.81568627, 0.85490196, 0.87843137, 0.83137255, 0.88627451,
0.77254902, 0.81960784, 0.20392157],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0.02352941, 0. , 0.38823529, 0.95686275, 0.87058824,
0.8627451 , 0.85490196, 0.79607843, 0.77647059, 0.86666667,
0.84313725, 0.83529412, 0.87058824, 0.8627451 , 0.96078431,
0.46666667, 0.65490196, 0.21960784],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.01568627,
0. , 0. , 0.21568627, 0.9254902 , 0.89411765,
0.90196078, 0.89411765, 0.94117647, 0.90980392, 0.83529412,
0.85490196, 0.8745098 , 0.91764706, 0.85098039, 0.85098039,
0.81960784, 0.36078431, 0. ],
[0. , 0. , 0.00392157, 0.01568627, 0.02352941,
0.02745098, 0.00784314, 0. , 0. , 0. ,
0. , 0. , 0.92941176, 0.88627451, 0.85098039,
0.8745098 , 0.87058824, 0.85882353, 0.87058824, 0.86666667,
0.84705882, 0.8745098 , 0.89803922, 0.84313725, 0.85490196,
1. , 0.30196078, 0. ],
[0. , 0.01176471, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0.24313725,
0.56862745, 0.8 , 0.89411765, 0.81176471, 0.83529412,
0.86666667, 0.85490196, 0.81568627, 0.82745098, 0.85490196,
0.87843137, 0.8745098 , 0.85882353, 0.84313725, 0.87843137,
0.95686275, 0.62352941, 0. ],
[0. , 0. , 0. , 0. , 0.07058824,
0.17254902, 0.32156863, 0.41960784, 0.74117647, 0.89411765,
0.8627451 , 0.87058824, 0.85098039, 0.88627451, 0.78431373,
0.80392157, 0.82745098, 0.90196078, 0.87843137, 0.91764706,
0.69019608, 0.7372549 , 0.98039216, 0.97254902, 0.91372549,
0.93333333, 0.84313725, 0. ],
[0. , 0.22352941, 0.73333333, 0.81568627, 0.87843137,
0.86666667, 0.87843137, 0.81568627, 0.8 , 0.83921569,
0.81568627, 0.81960784, 0.78431373, 0.62352941, 0.96078431,
0.75686275, 0.80784314, 0.8745098 , 1. , 1. ,
0.86666667, 0.91764706, 0.86666667, 0.82745098, 0.8627451 ,
0.90980392, 0.96470588, 0. ],
[0.01176471, 0.79215686, 0.89411765, 0.87843137, 0.86666667,
0.82745098, 0.82745098, 0.83921569, 0.80392157, 0.80392157,
0.80392157, 0.8627451 , 0.94117647, 0.31372549, 0.58823529,
1. , 0.89803922, 0.86666667, 0.7372549 , 0.60392157,
0.74901961, 0.82352941, 0.8 , 0.81960784, 0.87058824,
0.89411765, 0.88235294, 0. ],
[0.38431373, 0.91372549, 0.77647059, 0.82352941, 0.87058824,
0.89803922, 0.89803922, 0.91764706, 0.97647059, 0.8627451 ,
0.76078431, 0.84313725, 0.85098039, 0.94509804, 0.25490196,
0.28627451, 0.41568627, 0.45882353, 0.65882353, 0.85882353,
0.86666667, 0.84313725, 0.85098039, 0.8745098 , 0.8745098 ,
0.87843137, 0.89803922, 0.11372549],
[0.29411765, 0.8 , 0.83137255, 0.8 , 0.75686275,
0.80392157, 0.82745098, 0.88235294, 0.84705882, 0.7254902 ,
0.77254902, 0.80784314, 0.77647059, 0.83529412, 0.94117647,
0.76470588, 0.89019608, 0.96078431, 0.9372549 , 0.8745098 ,
0.85490196, 0.83137255, 0.81960784, 0.87058824, 0.8627451 ,
0.86666667, 0.90196078, 0.2627451 ],
[0.18823529, 0.79607843, 0.71764706, 0.76078431, 0.83529412,
0.77254902, 0.7254902 , 0.74509804, 0.76078431, 0.75294118,
0.79215686, 0.83921569, 0.85882353, 0.86666667, 0.8627451 ,
0.9254902 , 0.88235294, 0.84705882, 0.78039216, 0.80784314,
0.72941176, 0.70980392, 0.69411765, 0.6745098 , 0.70980392,
0.80392157, 0.80784314, 0.45098039],
[0. , 0.47843137, 0.85882353, 0.75686275, 0.70196078,
0.67058824, 0.71764706, 0.76862745, 0.8 , 0.82352941,
0.83529412, 0.81176471, 0.82745098, 0.82352941, 0.78431373,
0.76862745, 0.76078431, 0.74901961, 0.76470588, 0.74901961,
0.77647059, 0.75294118, 0.69019608, 0.61176471, 0.65490196,
0.69411765, 0.82352941, 0.36078431],
[0. , 0. , 0.29019608, 0.74117647, 0.83137255,
0.74901961, 0.68627451, 0.6745098 , 0.68627451, 0.70980392,
0.7254902 , 0.7372549 , 0.74117647, 0.7372549 , 0.75686275,
0.77647059, 0.8 , 0.81960784, 0.82352941, 0.82352941,
0.82745098, 0.7372549 , 0.7372549 , 0.76078431, 0.75294118,
0.84705882, 0.66666667, 0. ],
[0.00784314, 0. , 0. , 0. , 0.25882353,
0.78431373, 0.87058824, 0.92941176, 0.9372549 , 0.94901961,
0.96470588, 0.95294118, 0.95686275, 0.86666667, 0.8627451 ,
0.75686275, 0.74901961, 0.70196078, 0.71372549, 0.71372549,
0.70980392, 0.69019608, 0.65098039, 0.65882353, 0.38823529,
0.22745098, 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0.15686275, 0.23921569, 0.17254902,
0.28235294, 0.16078431, 0.1372549 , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ],
[0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. ]])
Моделирование
model = tf.keras.models.Sequential(
[
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation = tf.nn.relu),
tf.keras.layers.Dense(10, activation = tf.nn.softmax)
]
)
Здесь мы строим модель нейронной сети, В модели 3 слоя, Слой 1: Используйте Flatten, чтобы поместить 28Изображение числа 28 разбито на одномерный массив, равный 768.1 данные. Слой 2: полностью связанный слой со 128 нейронами. Функция активации использует Relu. Третий слой: полносвязный слой, который также является результирующим слоем, имеет 10 нейронов. Это связано с тем, что нам нужно выполнить задачу с несколькими классификациями, а softmax специально используется для задач с несколькими классификациями. Итак, функция активации использует softmax.
Затем скомпилируйте модель и обучите ее:
model.compile(
optimizer = tf.optimizers.Adam(),
loss="sparse_categorical_crossentropy",
metrics = ['accuracy']
)
model.fit(training_images, training_labels, epochs = 5)
epochs = 5 — количество эпох для обучения.
Результат выглядит следующим образом:
Epoch 1/5
1875/1875 [==============================] - 2s 692us/step - loss: 0.5020 - accuracy: 0.8235
Epoch 2/5
1875/1875 [==============================] - 1s 680us/step - loss: 0.3727 - accuracy: 0.8664
Epoch 3/5
1875/1875 [==============================] - 1s 735us/step - loss: 0.3389 - accuracy: 0.8765
Epoch 4/5
1875/1875 [==============================] - 1s 672us/step - loss: 0.3146 - accuracy: 0.8853
Epoch 5/5
1875/1875 [==============================] - 1s 686us/step - loss: 0.2960 - accuracy: 0.8910
После обучения проверим точность модели:
model.evaluate(test_images, test_labels)
Результат около 0,8668
313/313 [==============================] - 0s 544us/step - loss: 0.3702 - accuracy: 0.8668
Если вы хотите получить более точные данные, вы можете соответствующим образом увеличить эпохи. Делайте больше раундов тренировок.
Давайте посмотрим на первое изображение тестового набора и убедимся, что алгоритм точен:
classifications = model.predict(test_images)
print(classifications[0])
Результатом является массив длиной 10, где каждое число соответствует вероятности класса 0-9. Самым большим здесь является число с индексом 9. Таким образом, прогнозируемый результат должен быть 9.
[3.6468668e-06 5.4612696e-07 4.5575359e-08 2.6194398e-08 2.4610272e-06
2.0921556e-02 4.4843373e-06 5.3708768e-01 9.6340482e-05 4.4188327e-01]
Давайте посмотрим на его реальную этикетку:
test_labels[0]
Результат выглядит следующим образом:
9
как мы и предсказывали.
постскриптум
На этом простая задача мультиклассификации изображений завершена. Но фактический результат является переоснащением. Точность предсказания на тестовом наборе составляет 86%. Тренировочный набор составляет 89%.