После долгих поисков в Интернете, я хочу использовать vgg для классификации изображений, и я обнаружил, что есть один принцип на востоке и один принцип на западе. В этой статье обобщается официальная реализация VGG с использованием фреймворка pytorch для классификации изображений.
Метод, представленный в этой статье, требует подготовки двух дополнительных файлов: Одна из них — модель vgg (посмотрите на доменное имя, чтобы узнать, что оно официально предоставлено pytorch), адрес загрузки:
{
'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}
Одним из них является информация об этикетке image-net, в этой статье используется версия json.нажмите на меня, чтобы скачать.Будьте осторожны с расположением и названием файлов
Обратите внимание на то, где хранятся фотографии, которые вы хотите проверить!
не может数据集/直接就是你的图片.jpg/png
, должно быть:数据集/不管怎么样,这里一定要有一个文件夹/你的图片.jpg/png
Версия pytorch >=0.4 Никаких других зависимостей, кроме pytorch.
пример
预测的种类是: ['n02793495', 'barn']
预测的种类是: ['n02793495', 'barn']
预测的种类是: ['n02793495', 'barn']
预测的种类是: ['n03776460', 'mobile_home']
预测的种类是: ['n03956157', 'planetarium']
预测的种类是: ['n02793495', 'barn']
预测的种类是: ['n02793495', 'barn']
Чтение кода, связанного с json
// readJson.py
import json
def GetInfo():
with open("./imageNet.json",'r') as load_f:
load_dict = json.load(load_f)
return load_dict
Настоящая тема здесь
import torch
import readJson
from torchvision import transforms
import torchvision.models as models
from torchvision import datasets
tran=transforms.Compose([
transforms.Resize((224,224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 准备vgg19 /vgg 16 就准备好相应的文件
print("加载vgg模型中...........")
modelVGG = models.vgg19()
path ="C://Users/liu/.torch/models/vgg19-dcbb9e9d.pth"
pre = torch.load(path)
modelVGG.load_state_dict(pre)
print("加载vgg模型成功")
# 准备cuda
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modelVGG = modelVGG.to(device)
# 准备数据集
print("准备数据集中............")
dataset = datasets.ImageFolder('D:/shujuji/你的文件夹',tran)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True,)
jsonInfo = readJson.GetInfo()
print("准备成功")
for i,imgs in enumerate(data_loader):
# 哪怕一批是8张图,还是会被包在一个list中。形如:list[ [图1,图2,.....图8] ]。是一个数组外又包了一个数组。
real = imgs[0]
if (CUDA):
real = real.cuda()
out = modelVGG(real)
# 求指定维度的最大值,返回最大值以及索引
predict_value, predict_idx = torch.max(out, 1)
print(predict_value,predict_idx)
# 解析每一次的结果
pos =0
for p in predict_idx:
print(p.item(),predict_value[pos].item())
print("预测的种类是:",jsonInfo[ str(p.item())])
pos= pos +1