[Фактическое использование] pytorch использует VGG для классификации изображений

PyTorch

После долгих поисков в Интернете, я хочу использовать vgg для классификации изображений, и я обнаружил, что есть один принцип на востоке и один принцип на западе. В этой статье обобщается официальная реализация VGG с использованием фреймворка pytorch для классификации изображений.

Метод, представленный в этой статье, требует подготовки двух дополнительных файлов: Одна из них — модель vgg (посмотрите на доменное имя, чтобы узнать, что оно официально предоставлено pytorch), адрес загрузки:

{
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}

Одним из них является информация об этикетке image-net, в этой статье используется версия json.нажмите на меня, чтобы скачать.Будьте осторожны с расположением и названием файлов

Обратите внимание на то, где хранятся фотографии, которые вы хотите проверить!

не может数据集/直接就是你的图片.jpg/png, должно быть:数据集/不管怎么样,这里一定要有一个文件夹/你的图片.jpg/png

Версия pytorch >=0.4 Никаких других зависимостей, кроме pytorch.

пример

预测的种类是: ['n02793495', 'barn']

预测的种类是: ['n02793495', 'barn']

预测的种类是: ['n02793495', 'barn']

预测的种类是: ['n03776460', 'mobile_home']

预测的种类是: ['n03956157', 'planetarium']

预测的种类是: ['n02793495', 'barn']

预测的种类是: ['n02793495', 'barn']

Чтение кода, связанного с json

// readJson.py
import json

def GetInfo():
    with open("./imageNet.json",'r') as load_f:
        load_dict = json.load(load_f)
        return load_dict

Настоящая тема здесь

import torch
import readJson
from torchvision import transforms
import torchvision.models as models
from torchvision import datasets

tran=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
])

# 准备vgg19 /vgg 16 就准备好相应的文件
print("加载vgg模型中...........")
modelVGG = models.vgg19()
path ="C://Users/liu/.torch/models/vgg19-dcbb9e9d.pth"
pre = torch.load(path)
modelVGG.load_state_dict(pre)
print("加载vgg模型成功")

# 准备cuda
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelVGG = modelVGG.to(device)

# 准备数据集
print("准备数据集中............")
dataset = datasets.ImageFolder('D:/shujuji/你的文件夹',tran)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,)
jsonInfo =  readJson.GetInfo()
print("准备成功")

for i,imgs in enumerate(data_loader):
	# 哪怕一批是8张图,还是会被包在一个list中。形如:list[  [图1,图2,.....图8]  ]。是一个数组外又包了一个数组。
    real = imgs[0]

    if (CUDA):
        real = real.cuda()

    out = modelVGG(real)
    # 求指定维度的最大值,返回最大值以及索引
    predict_value, predict_idx = torch.max(out, 1)
    print(predict_value,predict_idx)
    # 解析每一次的结果
    pos =0
    for p in predict_idx:
        print(p.item(),predict_value[pos].item())
        print("预测的种类是:",jsonInfo[ str(p.item())])
        pos= pos +1