[Stove AI] Глубокое обучение 008-Keras для решения проблем мультиклассификации

искусственный интеллект глубокое обучение Python Keras

[Stove AI] Глубокое обучение 008-Keras для решения проблем мультиклассификации

(Библиотеки Python и номера версий, используемые в этой статье: Python 3.6, Numpy 1.14, scikit-learn 0.19, matplotlib 2.2, Keras 2.1.6, Tensorflow 1.9.0)

статья передо мной[Stove AI] Глубокое обучение 005 — простые строки кода Keras для решения проблем с двумя классификациямиВ разделе мы представляем использование Keras для решения задач бинарной классификации. Как решить столько проблем с классификацией? Каковы различия?


1. Подготовьте набор данных

Для демонстрации на этот раз я выбрал запись в блогеСерия keras︱Обучение множественной классификации изображений и точная настройка с использованием узких мест (3)Для упомянутого набора данных исходный набор данных помещает все категории фотографий поездов в папку поезда, все тестовые фотографии — в тестовую папку и начинается с разных номеров для представления разных категорий, например, начиная с 3. Фотография — это класс автобуса и так далее. Сначала поместите эти разные категории фотографий в разные папки, последняя папка поезда имеет 5 подпапок с 80 изображениями в каждой подпапке, а конечная тестовая папка имеет 5 подпапок, каждая подпапка. В папке 20 изображений. Всего изображений всего 500.

В коде нам нужно использовать ImageDataGenerator для увеличения данных и flow_from_directory для создания потока данных из папок.

Код в основном такой же, как и в статье с двумя категориями, поэтому я не буду размещать его здесь, вы можете перейти намой гитхабСмотрите полный код напрямую.

Единственное отличие состоит в том, что для исходной задачи с двумя категориями нужно установить class_mode='categorical' вместо class_mode='binary'.


2. Построение модели и обучение

В основном то же самое, что и бинарная классификация, следующая часть построения модели:

# 4,建立Keras模型:模型的建立主要包括模型的搭建,模型的配置
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import optimizers
def build_model(input_shape):
    # 模型的搭建:此处构建三个CNN层+2个全连接层的结构
    model = Sequential()
    model.add(Conv2D(32, (3, 3), input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dropout(0.5)) # Dropout防止过拟合
    model.add(Dense(class_num)) # 此处多分类问题,用Dense(class_num)
    model.add(Activation('softmax')) #多分类问题用softmax作为activation function
    
    # 模型的配置
    model.compile(loss='categorical_crossentropy', # 定义模型的loss func,optimizer,
                  optimizer=optimizers.RMSprop(), # 使用默认的lr=0.001
                  metrics=['accuracy'])# 主要优化accuracy

    return model # 返回构建好的模型

Изменение заключается в следующем: последний слой Dense должен использовать Dense(class_num) вместо Dense(1), а затем использовать стандартную функцию активации нескольких классов: softmax. Что касается конфигурации модели, функцию потерь также необходимо изменить на «categorical_crossentropy».

После обучения модели окончательный результат выглядит следующим образом:

Из результатов: явления переобучения нет, но тест не стабилен и изменение относительно велико. Тест акк после плато около 0,85.

########################резюме########################## ######

1. Проблема с несколькими классами и проблема с двумя классами в основном одинаковы, различия заключаются в следующем: 1. При настройке каталога потока_потока установите class_mode='categorical'. 2. Последний слой модели использует Dense (class_num) и softmax, специальную функцию активации нескольких классов. 3. Функция потерь модели должна использовать categorical_crossentropy.

#################################################################


Примечание. Эта часть кода была загружена в (мой гитхаб), добро пожаловать на скачивание.