Соси кошек вместе с кодом! Эта статья участвует【Эссе "Мяу Звезды"】.
Идентификация кошек на основе PaddlePaddle - чей ты котенок?
Одноклассник Сяомин: Котёнок из семьи Эргоузи снова пришёл его украсть Сяохуамао голодает, что мне делать? ? ?
что делать? ? ? Идентификация мелких домашних животных может быть проведена только исследовательской группой университета Что мне делать, как воспитаннику детского сада? ? ?
Одноклассник Сяомин в отчаянии почесал голову и издал болезненный звук «ааааа». . . . . .
Сяо Мин обнаружил новостной репортаж «Технология распознавания лиц обезьяны» уже здесь! 》new.QQ.com/О, красота/2021022…
Не волнуйтесь, детский сад Classmate Xiaoming, я научу вас использовать Paddlehub летающей весла, чтобы реализовать распознавание котенков, а котята, которые вы не знаете, не разрешается войти.
1. Сбор данных
Все видео с кошками собраны из общедоступных видео, а фото лица котенка получено через скриншот видео, поэтому нет необходимости делать отдельное фото.
!unzip -q data/data71411/cat.zip
1.1 python вызывает openCV, чтобы делать снимок из видео каждую 1 секунду и сохраняет его после нумерации.
import cv2
import os
for i in range(1,5):
# 创建图片目录
print(i)
mp4_file=str(i)+'.mp4'
dir_path=os.path.join('dataset',str(i))
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# 每秒存一次图片
vidcap = cv2.VideoCapture(mp4_file)
success,image = vidcap.read()
fps = int(vidcap.get(cv2.CAP_PROP_FPS))
count = 0
while success:
if count % fps == 0:
cv2.imwrite("{}/{}.jpg".format(dir_path, int(count / fps)), image)
print('Process %dth seconds: ' % int(count / fps), success)
success,image = vidcap.read()
count += 1
1.2 Генерация изображений для обработки
Удалите ненормальные изображения, такие как кредиты
Вручную...................................
import matplotlib.pyplot as plt
%matplotlib inline
import cv2 as cv
import numpy as np
# jupyter notebook显示
def visualize_images():
img = cv.imread('dataset/1/1.jpg')
plt.imshow(img)
plt.show()
visualize_images()
1.3 Представление набора данных
4 разных котенка
1.4 генерация списка
Для пользовательского набора данных сначала создайте список изображений, разделите пользовательские изображения на тестовый набор и обучающий набор и пометьте их. Следующая программа может быть запущена одна, если передан путь к папке большой категории, программа будет перебирать каждую подкатегорию в ней, чтобы сгенерировать список в фиксированном формате.Например, мы помещаем корневой каталог категория лица Передайте ее в ./dataset. Наконец, в указанном каталоге будут созданы три файла: readme.json, train.list и test.list.
import os
import json
# 设置要生成文件的路径
data_root_path = '/home/aistudio/dataset'
# 所有类别的信息
class_detail = []
# 获取所有类别保存的文件夹名称,这里是['1', '2', '3','4']
class_dirs = os.listdir(data_root_path)
# 类别标签
class_label = 0
# 获取总类别的名称
father_paths = data_root_path.split('/') #['', 'home', 'aistudio', 'dataset']
while True:
if father_paths[father_paths.__len__() - 1] == '':
del father_paths[father_paths.__len__() - 1]
else:
break
father_path = father_paths[father_paths.__len__() - 1]
# 把生产的数据列表都放在自己的总类别文件夹中
data_list_path = '/home/aistudio/%s/' % father_path
# 如果不存在这个文件夹,就创建
isexist = os.path.exists(data_list_path)
if not isexist:
os.makedirs(data_list_path)
# 清空原来的数据
with open(data_list_path + "test.txt", 'w') as f:
pass
with open(data_list_path + "trainer.txt", 'w') as f:
pass
# 总的图像数量
all_class_images = 0
# 读取每个类别
for class_dir in class_dirs:
# 每个类别的信息
class_detail_list = {}
test_sum = 0
trainer_sum = 0
# 统计每个类别有多少张图片
class_sum = 0
# 获取类别路径
path = data_root_path + "/" + class_dir
# 获取所有图片
img_paths = os.listdir(path)
for img_path in img_paths: # 遍历文件夹下的每个图片
name_path = path + '/' + img_path # 每张图片的路径
if class_sum % 10 == 0: # 每10张图片取一个做测试数据
test_sum += 1 #test_sum测试数据的数目
with open(data_list_path + "test.txt", 'a') as f:
f.write(name_path + "\t%d" % class_label + "\n") #class_label 标签:0,1,2
else:
trainer_sum += 1 #trainer_sum测试数据的数目
with open(data_list_path + "trainer.txt", 'a') as f:
f.write(name_path + "\t%d" % class_label + "\n")#class_label 标签:0,1,2
class_sum += 1 #每类图片的数目
all_class_images += 1 #所有类图片的数目
# 说明的json文件的class_detail数据
class_detail_list['class_name'] = class_dir #类别名称,如jiangwen
class_detail_list['class_label'] = class_label #类别标签,0,1,2
class_detail_list['class_test_images'] = test_sum #该类数据的测试集数目
class_detail_list['class_trainer_images'] = trainer_sum #该类数据的训练集数目
class_detail.append(class_detail_list)
class_label += 1 #class_label 标签:0,1,2
# 获取类别数量
all_class_sum = class_dirs.__len__()
# 说明的json文件信息
readjson = {}
readjson['all_class_name'] = father_path #文件父目录
readjson['all_class_sum'] = all_class_sum #
readjson['all_class_images'] = all_class_images
readjson['class_detail'] = class_detail
jsons = json.dumps(readjson, sort_keys=True, indent=4, separators=(',', ': '))
with open(data_list_path + "readme.json",'w') as f:
f.write(jsons)
print ('生成数据列表完成!')
生成数据列表完成!
1.5 Создание набора данных
import paddle
import paddle.vision.transforms as T
import numpy as np
from PIL import Image
class MiaoMiaoDataset(paddle.io.Dataset):
"""
2类Bee数据集类的定义
"""
def __init__(self,mode='train'):
"""
初始化函数
"""
self.data = []
with open('dataset/{}.txt'.format(mode)) as f:
for line in f.readlines():
info = line.strip().split('\t')
if len(info) > 0:
self.data.append([info[0].strip(), info[1].strip()])
if mode == 'train':
self.transforms = T.Compose([
T.Resize((224,224)),
T.RandomHorizontalFlip(0.5), # 随机水平翻转
T.ToTensor(), # 数据的格式转换和标准化 HWC => CHW
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 图像归一化
])
else:
self.transforms = T.Compose([
T.Resize((224,224)), # 图像大小修改
# T.RandomCrop(IMAGE_SIZE), # 随机裁剪
T.ToTensor(), # 数据的格式转换和标准化 HWC => CHW
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 图像归一化
])
def get_origin_data(self):
return self.data
def __getitem__(self, index):
"""
根据索引获取单个样本
"""
image_file, label = self.data[index]
image = Image.open(image_file)
if image.mode != 'RGB':
image = image.convert('RGB')
image = self.transforms(image)
return image, np.array(label, dtype='int64')
def __len__(self):
"""
获取样本总数
"""
return len(self.data)
train_dataset=MiaoMiaoDataset(mode='trainer')
test_dataset=MiaoMiaoDataset(mode='test')
print('train_data len: {}, test_data len:{}'.format(train_dataset.__len__(), test_dataset.__len__()))
train_data len: 45, test_data len:7
2. Определение модели и обучение
В настоящее время данные разделены на обучающие и тестовые наборы данных, а также на количество классификаций.
Далее мы определим модель и снова рассмотрим сеть resnet50.
import paddle
from paddle import Model
# 定义网络
network=paddle.vision.models.resnet50(num_classes=4, pretrained=True)
model = paddle.Model(network)
model.summary((-1, 3, 224 , 224))
100%|██████████| 151272/151272 [00:02<00:00, 72148.01it/s]
-------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===============================================================================
Conv2D-1 [[1, 3, 224, 224]] [1, 64, 112, 112] 9,408
BatchNorm2D-1 [[1, 64, 112, 112]] [1, 64, 112, 112] 256
ReLU-1 [[1, 64, 112, 112]] [1, 64, 112, 112] 0
MaxPool2D-1 [[1, 64, 112, 112]] [1, 64, 56, 56] 0
Conv2D-3 [[1, 64, 56, 56]] [1, 64, 56, 56] 4,096
BatchNorm2D-3 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-2 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-4 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-4 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-5 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-5 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
Conv2D-2 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-2 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
BottleneckBlock-1 [[1, 64, 56, 56]] [1, 256, 56, 56] 0
Conv2D-6 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,384
BatchNorm2D-6 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-3 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-7 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-7 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-8 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-8 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
BottleneckBlock-2 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-9 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,384
BatchNorm2D-9 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
ReLU-4 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-10 [[1, 64, 56, 56]] [1, 64, 56, 56] 36,864
BatchNorm2D-10 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-11 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,384
BatchNorm2D-11 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
BottleneckBlock-3 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-13 [[1, 256, 56, 56]] [1, 128, 56, 56] 32,768
BatchNorm2D-13 [[1, 128, 56, 56]] [1, 128, 56, 56] 512
ReLU-5 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-14 [[1, 128, 56, 56]] [1, 128, 28, 28] 147,456
BatchNorm2D-14 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-15 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-15 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
Conv2D-12 [[1, 256, 56, 56]] [1, 512, 28, 28] 131,072
BatchNorm2D-12 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-4 [[1, 256, 56, 56]] [1, 512, 28, 28] 0
Conv2D-16 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-16 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-6 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-17 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-17 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-18 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-18 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-5 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-19 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-19 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-7 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-20 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-20 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-21 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-21 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-6 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-22 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,536
BatchNorm2D-22 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
ReLU-8 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-23 [[1, 128, 28, 28]] [1, 128, 28, 28] 147,456
BatchNorm2D-23 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-24 [[1, 128, 28, 28]] [1, 512, 28, 28] 65,536
BatchNorm2D-24 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
BottleneckBlock-7 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-26 [[1, 512, 28, 28]] [1, 256, 28, 28] 131,072
BatchNorm2D-26 [[1, 256, 28, 28]] [1, 256, 28, 28] 1,024
ReLU-9 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-27 [[1, 256, 28, 28]] [1, 256, 14, 14] 589,824
BatchNorm2D-27 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-28 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-28 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
Conv2D-25 [[1, 512, 28, 28]] [1, 1024, 14, 14] 524,288
BatchNorm2D-25 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-8 [[1, 512, 28, 28]] [1, 1024, 14, 14] 0
Conv2D-29 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-29 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-10 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-31 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-31 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-9 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-32 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-32 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-11 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-33 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-33 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-34 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-34 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-10 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-35 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-35 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-12 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-36 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-36 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-37 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-37 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-11 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-38 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-38 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-13 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-39 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-39 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-40 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-40 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-12 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-41 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,144
BatchNorm2D-41 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
ReLU-14 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-42 [[1, 256, 14, 14]] [1, 256, 14, 14] 589,824
BatchNorm2D-42 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-43 [[1, 256, 14, 14]] [1, 1024, 14, 14] 262,144
BatchNorm2D-43 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
BottleneckBlock-13 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-45 [[1, 1024, 14, 14]] [1, 512, 14, 14] 524,288
BatchNorm2D-45 [[1, 512, 14, 14]] [1, 512, 14, 14] 2,048
ReLU-15 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-46 [[1, 512, 14, 14]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-46 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-47 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-47 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
Conv2D-44 [[1, 1024, 14, 14]] [1, 2048, 7, 7] 2,097,152
BatchNorm2D-44 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
BottleneckBlock-14 [[1, 1024, 14, 14]] [1, 2048, 7, 7] 0
Conv2D-48 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,048,576
BatchNorm2D-48 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
ReLU-16 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-49 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-49 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-50 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-50 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
BottleneckBlock-15 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-51 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,048,576
BatchNorm2D-51 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
ReLU-17 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-52 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,359,296
BatchNorm2D-52 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-53 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,048,576
BatchNorm2D-53 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
BottleneckBlock-16 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
AdaptiveAvgPool2D-1 [[1, 2048, 7, 7]] [1, 2048, 1, 1] 0
Linear-1 [[1, 2048]] [1, 4] 8,196
===============================================================================
Total params: 23,569,348
Trainable params: 23,463,108
Non-trainable params: 106,240
-------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 261.48
Params size (MB): 89.91
Estimated Total Size (MB): 351.96
-------------------------------------------------------------------------------
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1263: UserWarning: Skip loading for fc.weight. fc.weight receives a shape [2048, 1000], but the expected shape is [2048, 4].
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1263: UserWarning: Skip loading for fc.bias. fc.bias receives a shape [1000], but the expected shape is [4].
warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
{'total_params': 23569348, 'trainable_params': 23463108}
# 模型训练配置
model.prepare(optimizer=paddle.optimizer.Adam(learning_rate=0.000005,parameters=model.parameters()),# 优化器
loss=paddle.nn.CrossEntropyLoss(), # 损失函数
metrics=paddle.metric.Accuracy()) # 评估指标
# 训练可视化VisualDL工具的回调函数
visualdl = paddle.callbacks.VisualDL(log_dir='visualdl_log')
# 启动模型全流程训练
model.fit(train_dataset, # 训练数据集
# test_dataset, # 评估数据集
epochs=20, # 总的训练轮次
batch_size=256, # 批次计算的样本量大小
shuffle=True, # 是否打乱样本集
verbose=1, # 日志展示格式
save_dir='./chk_points/', # 分阶段的训练模型存储路径
callbacks=[visualdl]) # 回调函数使用
model.save('model_save')
3. Оценка и тестирование модели
# plot the evaluate
model.evaluate(test_dataset,verbose=1)
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
step 7/7 [==============================] - loss: 0.0000e+00 - acc: 0.7143 - 30ms/step
Eval samples: 7
{'loss': [0.0], 'acc': 0.7142857142857143}
предсказывать
Делайте прогнозы на основе данных test_dataset
print('测试数据集样本量:{}'.format(len(test_dataset)))
测试数据集样本量:7
# 执行预测
result = model.predict(test_dataset)
Predict begin...
step 7/7 [==============================] - 32ms/step
Predict samples: 7
# 打印前10条看看结果
for idx in range(7):
predict_label = str(np.argmax(result[0][idx]))
real_label = str(test_dataset.__getitem__(idx)[1])
print('样本ID:{}, 真实标签:{}, 预测值:{}'.format(idx, real_label, predict_label))
样本ID:0, 真实标签:0, 预测值:0
样本ID:1, 真实标签:0, 预测值:0
样本ID:2, 真实标签:2, 预测值:2
样本ID:3, 真实标签:3, 预测值:3
样本ID:4, 真实标签:3, 预测值:3
样本ID:5, 真实标签:4, 预测值:0
样本ID:6, 真实标签:4, 预测值:1
# 定义画图方法
from PIL import Image
import matplotlib.font_manager as font_manager
import matplotlib.pyplot as plt
%matplotlib inline
fontpath = 'MINGHEI_R.TTF'
font = font_manager.FontProperties(fname=fontpath, size=10)
def show_img(img, predict):
plt.figure()
plt.title(predict, FontProperties=font)
plt.imshow(img, cmap=plt.cm.binary)
plt.show()
# 抽样展示
origin_data=test_dataset.get_origin_data()
for i in range(7):
img_path=origin_data[i][0]
real_label=str(origin_data[i][1])
predict_label= str(np.argmax(result[0][i]))
img=Image.open(img_path)
title='样本ID:{}, 真实标签:{}, 预测值:{}'.format(idx, real_label, predict_label)
show_img(img, title)