В этом руководстве мы будем использовать Flask для развертывания модели PyTorch и объясним REST API для вывода модели. В частности, мы развернем предварительно обученную модель DenseNet 121 для обнаружения изображений.
Примечание:
Доступны наGitHubПолучите полный код, использованный в этой статье, на
Это первое руководство из серии руководств по развертыванию моделей PyTorch в рабочей среде. Использование Flask таким образом — это, безусловно, самый простой способ начать обслуживание моделей PyTorch, но он не подходит для случаев использования с высокими требованиями к производительности. следовательно:
- Если вы уже знакомы с TorchScript, вы можете сразу перейти к нашемуLoading a TorchScript Model in C++руководство.
- Если вам нужно сначала освежить в памяти TorchScript, ознакомьтесь с нашимIntro a TorchScriptруководство.
1. Определите API
Мы начнем с определения конечной точки API, типов запросов и ответов. Конечная точка нашего API будет находиться по адресу/ predict
, который принимает изображения сfile
HTTP POST-запрос с параметрами. Ответ будет ответом JSON, содержащим прогноз:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
2. Зависимость (пакет)
Выполните следующую команду, чтобы загрузить необходимые нам зависимости:
$ pip install Flask==1.0.3 torchvision-0.3.0
3. Простой веб-сервер
Ниже приведен простой веб-сервер, взятый из документации Flask.
from flask import Flask
app = Flask(__name__)
@app.route('/')
def hello():
return 'Hello World!'
Сохраните приведенный выше фрагмент кода в файле с именемapp.py
, теперь вы можете запустить сервер разработки Flask, набрав:
$ FLASK_ENV=development FLASK_APP=app.py flask run
при посещении в веб-браузереhttp://localhost:5000/
, вы получите текстHello World
Привет!
Мы внесем некоторые изменения в приведенный выше фрагмент кода, чтобы он соответствовал нашему определению API. Во-первых, мы переименуемpredict
метод. Мы обновляем путь к конечной точке на/predict
. Поскольку файл изображения будет отправлен через HTTP-запрос POST, мы обновим его, чтобы он также принимал только запросы POST:
@app.route('/predict', methods=['POST'])
def predict():
return 'Hello World!'
Мы также изменим тип ответа, чтобы он возвращал ответ JSON, содержащий идентификатор и имя класса ImageNet. после обновленияapp.py
Теперь файл:
from flask import Flask, jsonify
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
4. Рассуждение
В следующей части мы сосредоточимся на написании кода логического вывода. Это будет состоять из двух частей: первая часть — подготовить изображение, чтобы его можно было передать в DenseNet; во второй части мы напишем код, чтобы получить фактические прогнозы от модели.
4.1 Подготовка изображений
Модель DenseNet требует, чтобы изображения были 3-канальными изображениями RGB размером 224 x 224. Мы также нормализуем тензор изображения с желаемыми значениями среднего и стандартного отклонения. вы можете нажатьздесьчтобы узнать больше об этом.
мы будем использоватьtorchvision
библиотекаtransforms
для создания конвейера преобразования, который преобразует изображение по мере необходимости. ты сможешьздесьПодробнее о конверсиях.
import io
import torchvision.transforms as transforms
from PIL import Image
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
Приведенный выше метод берет данные изображения в байтах, применяет ряд преобразований и возвращает тензор. Чтобы проверить описанный выше подход, прочитайте файл изображения в байтовом режиме (сначала замените ../_static/img/sample_file.jpeg фактическим путем к файлу на вашем компьютере) и посмотрите, получится ли у вас тензор:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)
- Выходной результат:
tensor([[[[ 0.4508, 0.4166, 0.3994, ..., -1.3473, -1.3302, -1.3473],
[ 0.5364, 0.4851, 0.4508, ..., -1.2959, -1.3130, -1.3302],
[ 0.7077, 0.6392, 0.6049, ..., -1.2959, -1.3302, -1.3644],
...,
[ 1.3755, 1.3927, 1.4098, ..., 1.1700, 1.3584, 1.6667],
[ 1.8893, 1.7694, 1.4440, ..., 1.2899, 1.4783, 1.5468],
[ 1.6324, 1.8379, 1.8379, ..., 1.4783, 1.7352, 1.4612]],
[[ 0.5728, 0.5378, 0.5203, ..., -1.3704, -1.3529, -1.3529],
[ 0.6604, 0.6078, 0.5728, ..., -1.3004, -1.3179, -1.3354],
[ 0.8529, 0.7654, 0.7304, ..., -1.3004, -1.3354, -1.3704],
...,
[ 1.4657, 1.4657, 1.4832, ..., 1.3256, 1.5357, 1.8508],
[ 2.0084, 1.8683, 1.5182, ..., 1.4657, 1.6583, 1.7283],
[ 1.7458, 1.9384, 1.9209, ..., 1.6583, 1.9209, 1.6408]],
[[ 0.7228, 0.6879, 0.6531, ..., -1.6476, -1.6302, -1.6476],
[ 0.8099, 0.7576, 0.7228, ..., -1.6476, -1.6476, -1.6650],
[ 1.0017, 0.9145, 0.8797, ..., -1.6476, -1.6650, -1.6999],
...,
[ 1.6291, 1.6291, 1.6465, ..., 1.6291, 1.8208, 2.1346],
[ 2.1868, 2.0300, 1.6814, ..., 1.7685, 1.9428, 2.0125],
[ 1.9254, 2.0997, 2.0823, ..., 1.9428, 2.2043, 1.9080]]]])
4.2 Прогноз
Предварительно обученная модель DenseNet 121 теперь будет использоваться для прогнозирования класса изображения. мы будем использоватьtorchvision
Библиотека внутри библиотеки, которая загружает модель и делает вывод. В этом примере мы будем использовать предварительно обученную модель, но вы можете использовать тот же подход со своей собственной моделью. на эторуководствоУзнайте больше о загрузке моделей в файлы .
from torchvision import models
# 确保使用`pretrained`作为`True`来使用预训练的权重:
model = models.densenet121(pretrained=True)
# 由于我们仅将模型用于推理,因此请切换到“eval”模式:
model.eval()
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
return y_hat
Тензорy_hat
Индекс идентификатора класса, который будет содержать предсказание. Однако нам нужно легко читаемое имя класса. Для этого нам нужен идентификатор класса, чтобы назвать карту. будетДокументСкачать какimagenet_class_index.json
и запомните, где он был сохранен (или, если вы точно следовали шагам, описанным в этом руководстве, сохраните его вtutorials/_static
середина). Этот файл содержит сопоставление идентификаторов классов ImageNet с именами классов ImageNet. Мы загрузим этот файл JSON и получим имя класса, который предсказывает индекс.
import json
imagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
используя словарьimagenet_class_index
Прежде сначала мы преобразуем значения тензора в строковые значения, потому что словарьimagenet_class_index
Ключи - это строки. Мы проверим вышеуказанный метод:
with open("../_static/img/sample_file.jpeg", 'rb') as f:
image_bytes = f.read()
print(get_prediction(image_bytes=image_bytes))
- Выходной результат:
['n02124075', 'Egyptian_cat']
Вы получите такой ответ:
['n02124075', 'Egyptian_cat']
Первый элемент массива — это идентификатор класса ImageNet, а второй — удобочитаемое имя.
Примечание. Вы заметили, что переменные модели не
get_prediction
часть метода? Или почему модель является глобальной переменной? С точки зрения памяти и вычислений загрузка модели может быть
дорогая операция. Если вы загрузите модель вget_prediction
метод, модель излишне загружается каждый раз, когда вызывается метод. Поскольку мы создаем веб-сервис
server, так что могут быть тысячи запросов в секунду, поэтому мы не должны тратить время на повторную загрузку модели для каждого вывода. Поэтому мы загружаем модель в память только один раз. В живых
В производственных системах вычисления должны использоваться эффективно, чтобы иметь возможность обрабатывать запросы в масштабе, поэтому модели обычно следует загружать до обработки запросов.
5. Интегрируйте модель в наш сервер API.
В последней части мы добавляем модель на сервер Flask API. Поскольку наш сервер API должен получить файл изображения, мы обновимpredict
способ чтения файла из запроса:
from flask import request
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
# 从请求中获得文件
file = request.files['file']
# 转化为字节
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
app.py
Теперь файл завершен. Вот полная версия; замените путь на путь, по которому вы сохранили файл, и он должен работать следующим образом:
import io
import json
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request
app = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()
def transform_image(image_bytes):
my_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
[0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
image = Image.open(io.BytesIO(image_bytes))
return my_transforms(image).unsqueeze(0)
def get_prediction(image_bytes):
tensor = transform_image(image_bytes=image_bytes)
outputs = model.forward(tensor)
_, y_hat = outputs.max(1)
predicted_idx = str(y_hat.item())
return imagenet_class_index[predicted_idx]
@app.route('/predict', methods=['POST'])
def predict():
if request.method == 'POST':
file = request.files['file']
img_bytes = file.read()
class_id, class_name = get_prediction(image_bytes=img_bytes)
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
Протестируем наш веб-сервер, запустим:
$ FLASK_ENV=development FLASK_APP=app.py flask run
мы можем использоватьrequestsбиблиотека для отправки POST-запроса в наше приложение:
import requests
resp = requests.post("http://localhost:5000/predict",
files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
Распечататьresp.json()
Отображаются следующие результаты:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
6. Следующие шаги
Серверы, которые мы пишем, очень просты и могут не делать всего, что нужно производственному приложению. Итак, вот несколько вещей, которые вы можете сделать, чтобы улучшить его:
- конечная точка
/predict
Предполагается, что в запросе всегда будет файл изображения. Это может не работать для всех запросов. Наши пользователи могут отправлять изображения с дополнительными параметрами или вообще без изображений. - Пользователи также могут отправлять файлы без изображений. Поскольку мы не обрабатываем ошибки, это сломает наш сервер. Добавление явных путей обработки ошибок для создания исключений позволит нам
Лучшая обработка неправильного ввода - Даже если модель может распознавать большое количество категорий изображений, она может не распознавать все изображения. Улучшите реализацию, чтобы обрабатывать случаи, когда модель не может ничего распознать на изображении.
- Мы запускаем сервер Flask в режиме разработки, который не подходит для развертывания в производстве. вы можете просмотретьруководство
для развертывания сервера Flask в производстве. - Вы также можете добавить пользовательский интерфейс, создав страницу с формой, которая принимает изображение и отображает прогнозы. Посмотреть похожиепроектдемо и егоисходный код.
- В этом руководстве мы только показали, как создать службу, которая может возвращать прогнозы для одного изображения за раз. Мы можем изменить сервис, чтобы он мог возвращать прогнозы для нескольких изображений одновременно. также,service-streamer
Библиотека автоматически ставит в очередь запросы к службе и сэмплирует их в минимальные пакеты, доступные модели. вы можете просмотретьэтот учебник. - Наконец, мы рекомендуем вам ознакомиться с другими руководствами по развертыванию моделей PyTorch в верхней части страницы.
Сводная станция блога о технологиях искусственного интеллекта Panchuang: http://docs.panchuang.net/PyTorch, официальная китайская учебная станция: http://pytorch.panchuang.net/OpenCV, официальный китайский документ: http://woshicver.com/