Машинное обучение распознаванию цифр на картинках

машинное обучение искусственный интеллект

Сегодня мы используем логистическую регрессию для машинного обучения изображений, используя два набора данных, цифры и MNIST, которые представляют небольшие наборы данных и большие наборы данных соответственно.

1. Набор данных цифр

1.1 Импорт данных

Набор данных digits — это набор данных в scikit-learn, который не нужно загружать из Интернета.

from sklearn.datasets import load_digitsdigits = load_digits()скопировать код

Теперь давайте посмотрим на статистику набора данных цифр.

#一共有1797个数据和1797标签print('照片数据形状(维度): ', digits.data.shape)print('标签数据形状(维度): ', digits.target.shape)скопировать код

бегать

    照片数据形状(维度):  (1797, 64)    标签数据形状(维度):  (1797,)скопировать код

1.2 Печать фотографий и их этикеток

Поскольку размер данных равен 1797, всего имеется 64 измерения. Тогда каждый фрагмент данных представляет собой список. Но мы знаем, что изображение представляет собой двумерную структуру, и мы знаем, что изображение набора данных digits представляет собой квадрат, поэтому нам нужно восстановить (изменить форму) исходные данные изображения в массив (8, 8) .

Чтобы дать вам более интуитивное представление о наборе данных, мы печатаем здесь первые 5 фотографий набора данных цифр.

#先查看图片是什么样子print(digits.data[0])#重构图片数据为(8,8)的数组import numpy as npprint(np.reshape(digits.data[0], (8,8)))скопировать код

бегать

    [ 0.  0.  5. 13.  9.  1.  0.  0.  0.  0. 13. 15. 10. 15.  5.  0.  0.  3.     15.  2.  0. 11.  8.  0.  0.  4. 12.  0.  0.  8.  8.  0.  0.  5.  8.  0.      0.  9.  8.  0.  0.  4. 11.  0.  1. 12.  7.  0.  0.  2. 14.  5. 10. 12.      0.  0.  0.  0.  6. 13. 10.  0.  0.  0.]    [[ 0.  0.  5. 13.  9.  1.  0.  0.]     [ 0.  0. 13. 15. 10. 15.  5.  0.]     [ 0.  3. 15.  2.  0. 11.  8.  0.]     [ 0.  4. 12.  0.  0.  8.  8.  0.]     [ 0.  5.  8.  0.  0.  9.  8.  0.]     [ 0.  4. 11.  0.  1. 12.  7.  0.]     [ 0.  2. 14.  5. 10. 12.  0.  0.]     [ 0.  0.  6. 13. 10.  0.  0.  0.]]скопировать код

Отображение изображений matplotlib в блокнотах

%matplotlib inlineimport numpy as npimport matplotlib.pyplot as plt#选取数据集前5个数据data = digits.data[0:5]label = digits.target[0:5]#画图尺寸宽20,高4plt.figure(figsize = (20, 4))for idx, (imagedata, label) in enumerate(zip(data, label)):    #画布被切分为一行5个子图。 idx+1表示第idx+1个图    plt.subplot(1, 5, idx+1)    image = np.reshape(imagedata, (8, 8))    #为了方便观看,我们将其灰度显示    plt.imshow(image, cmap = plt.cm.gray)    plt.title('The number of Image is  {}'.format(label))скопировать код
png

1.3 Разделите данные на обучающий набор и тестовый набор

Чтобы уменьшить возможность подгонки модели к данным, способность модели к обобщению повышена. Чтобы убедиться, что наша обученная модель может делать прогнозы на основе новых данных, нам нужно разделить набор данных цифр на обучающий набор и тестовый набор.

from sklearn.model_selection import train_test_split#测试集占总数据中的30%, 设置随机状态,方便后续复现本次的随机切分X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size = 0.3, random_state=100)скопировать код

1.4 Обучение, предсказание, точность

В этой статье мы используем логистическую регрессию. Поскольку набор цифровых данных невелик, мы просто используем решатель по умолчанию.

from sklearn.linear_model import LogisticRegressionlogisticRegre = LogisticRegression()#训练logisticRegre.fit(X_train, y_train)скопировать код

Сделайте прогнозы для новых данных.Обратите внимание, что если вы делаете прогнозы только для одних данных (одномерный массив), вы должны преобразовать одномерный массив в матричную форму.

data.reshape(n_rows, n_columns)

Преобразование данных в матрицу размерности (n_rows, n_columns). Обратите внимание, что если мы не знаем размеров размерности преобразуемой матрицы, мы можем установить это значение равным -1.

#测试集中的第一个数据。#我们知道它是一行,但是如果不知道列是多少,那么设置为-1#实际上,我们知道列是64 #所以下面的写法等同于X_test[0].reshape(1, 64)one_new_image = X_test[0].reshape(1, -1)#预测logisticRegre.predict(one_new_image)скопировать код

бегать

array([9])скопировать код

Делайте прогнозы на основе нескольких данных

predictions = logisticRegre.predict(X_test[0:10])#真实的数字print(y_test[0:10])#预测的数字print(predictions)#准确率score = logisticRegre.score(X_test, y_test)print(score)скопировать код

результат операции

    [9 9 0 2 4 5 7 4 7 2]    [9 3 0 2 4 5 7 4 3 2]    0.9592592592592593скопировать код

вау, это довольно точно.

1.5 Матрица путаницы

Как правило, матрица путаницы часто используется для оценки точности прогноза.Здесь мы используем seaborn и matplotlib для рисования матрицы путаницы.

% matplotlib inlineimport matplotlib.pyplot as pltimport seaborn as snsfrom sklearn.metrics import confusion_matrixpredictions = logisticRegre.predict(X_test)cm = confusion_matrix(y_test, predictions)plt.figure(figsize = (9, 9))sns.heatmap(cm, annot=True, fmt='.3f', linewidth=0.5, square=True, cmap='Blues_r')plt.ylabel('Actual Label')plt.xlabel('Predicted Label')plt.title('Accurate Score: {}'.format(score), size=15)скопировать код
png

2. Набор данных МНИСТ

Набор цифровых данных чрезвычайно мал, и обучение и прогнозирование можно выполнить за считанные секунды. Однако, если набор данных большой, наши требования к скорости обучения становятся актуальными, и необходима настройка параметров модели. Итак, давайте попробуем это с большим набором данных MNIST. Я скачал mnist из Интернета и организовал его в файл csv. Первый столбец — это метка, а следующий столбец — значение пикселя изображения. Всего 785 столбцов. Изображения набора данных MNIST состоят из 28 * 28.

import pandas as pdimport numpy as nptrain = pd.read_csv('mnist_train.csv', header = None)test = pd.read_csv('mnist_test.csv', header = None)y_train = train.loc[:, 0] #pd.series#注意:train.loc[:, 1:]返回的是pd.DataFrame类。#这里我们要将其转化为np.array方便操作X_train = np.array(train.loc[:, 1:]) y_test = test.loc[:, 0]X_test = np.array(test.loc[:, 1:])скопировать код
#我们看看这些MNIST维度print('X_train 维度: {}'.format(X_train.shape))print('y_train 维度: {}'.format(y_train.shape))print('X_test 维度: {}'.format(X_test.shape))print('y_test 维度: {}'.format(y_test.shape))скопировать код

результат операции

    X_train 维度: (60000, 784)    y_train 维度: (60000,)    X_test 维度: (10000, 784)    y_test 维度: (10000,)скопировать код

2.1 Печать изображений и этикеток MNIST

%matplotlib inlineimport numpy as npimport matplotlib.pyplot as plt#只看5张图片数据data = X_train[0:5]label = y_train[0:5]plt.figure(figsize = (20, 4))for idx, (imagedata, label) in enumerate(zip(data, label)):    plt.subplot(1, 5, idx+1)    #MNIST数据集的图片为28*28像素    image = np.reshape(imagedata, (28,28))    plt.imshow(image, cmap=plt.cm.gray)    plt.title('The number of Image is {}'.format(label))скопировать код
png

2.2 Обучение, предсказание, точность

Предыдущий набор данных цифр был всего 1797, а размер каждого изображения был (8, 8). Но набор данных MNIST до 70000, а размер каждого изображения составляет (28, 28). Поэтому, если параметры подобраны не разумно, скорость обучения будет очень медленной.

from sklearn.linear_model import LogisticRegressionimport timedef model(solver='liblinear'):    """    改变LogisticRegression模型的solver参数,计算运行准确率及时间    """    start = time.time()    logisticRegr = LogisticRegression(solver=solver)    logisticRegr.fit(X_train, y_train)    score = logisticRegr.score(X_test, y_test)    end = time.time()    print('准确率:{0}, 耗时: {1}'.format(score, int(end-start)))    return logisticRegrmodel(solver='liblinear')model(solver='lbfgs')скопировать код

результат операции

    准确率:0.9176, 耗时3840    准确率:0.9173, 耗时65скопировать код

После тестирования выяснилось, что в моем macbook air2015 по умолчанию

Solver='liblinear' время обучения 3840 секунд.

Solver='lbfgs' время обучения 65 секунд.

Решатель изменился с liblinear на lbfgs, пожертвовав только точностью 0,0003, но скорость можно увеличить почти в 60 раз. При обучении с помощью машинного обучения параметры алгоритма другие, и скорость обучения сильно различается.Взгляните на следующий рисунок.

2.3 Печать ошибочно предсказанных изображений

Матрица путаницы, используемая набором данных цифр, оценивает точность, но она недостаточно интуитивна. Здесь мы печатаем изображение ошибки предсказания

logistricRegr = model(solver='lbfgs')predictions = logistricRegr.predict(X_test)#预测分类错误图片的索引misclassifiedIndexes = []for idx,(label,predict) in enumerate(zip(y_test, predictions)):    if label != predict:        misclassifiedIndexes.append(idx)print(misclassifiedIndexes)скопировать код
准确率:0.9173, 耗时76[8, 33, 38, 63, 66, 73, 119, 124, 149, 151, 153, 193, 211, 217, 218, 233, 241, 245, 247, 259, 282, 290, 307, 313, 318, 320,  ........    857, 877, 881, 898, 924, 938, 939, 947, 16789808, 9811, 9832, 9835, 9839, 9840, 9855, 9858, 9867, 9874, 9883, 9888, 9892, 9893, 9901, 9905, 9916, 9925, 9926, 9941, 9943, 9944, 9959, 9970, 9975, 9980, 9982, 9986]скопировать код

напечатать неправильное изображение

%matplotlib inlineimport matplotlib.pyplot as pltimport numpy as npplt.figure(figsize = (20, 4))#打印前5个分类错误的图片for plotidx, badidx in enumerate(misclassifiedIndexes[0:5]):    plt.subplot(1, 5, plotidx+1)    img = np.reshape(X_test[badidx], (28, 28))    plt.imshow(img)    predict_label = predictions[badidx]    true_label = y_test[badidx]    plt.title('Predicted: {0}, Actual: {1}'.format(predict_label, true_label))скопировать код
png

Получение данных и кода

Прошлые статьи

Учебные материалы по Python объемом 100 ГБ: от начального уровня до мастера!   

Как студенты колледжа используют свои знания, чтобы зарабатывать 3000 в месяц

Сотни наборов текстовых данных G ждут вас | Бесплатная заявка

Почему вы планируете на 2019, а не на 2018 год?   

15 лучших библиотек Python для науки о данных в 2017 году 

Знакомство с алгоритмом К-средних 

Машинное обучение | Восемь шагов для решения 90% проблем NLP 

Используйте sklearn для обработки естественного языка-1

Сложная вложенная структура данных словаря, обрабатывающая библиотеку-glom

Чтение файлов pdf и docx, профессиональный тест эффективен

Как извлечь информацию об объектах из текста? 

Что может сделать nltk для китайцев 

Каждое слово, оставленное в Интернете, раскрывает вашу личность

Элегантные и лаконичные списки

Получить список советов и рекомендаций

Как сортировать данные различными способами?

[Объяснение видео] Scrapy рекурсивно собирает информацию о пользователях Jianshu

Артефакт сбора информации о торговце Meituan 

Используйте библиотеку chardect для решения проблемы искаженных веб-страниц.

gevent: асинхронная теория и практика 

Легкая и эффективная библиотека асинхронного доступа

Подробное объяснение конфигурации селенового диска

Как использовать артефакт сканера PyQuery

Простое изучение базы данных SQLite3

Функция вызова Python по строке

Библиотека символьных вычислений в круге Python - Sympy

Как использовать библиотеку даты и времени в Python