Чтение этого руководства требует определенной основы работы с mxnet и gluon. В этой статье используется обучающий код процессора. Адрес личного блога:zmkwjx.github.ioАдрес этой статьи на гитхабе:GitHub.com/Такой высокомерный и кричащий/G Luo…Адрес официального сайта GluonTS:gluon-ts.mxnet.io
1. Окружающая среда и установка
1.1 Среда разработки этой статьи: Ubuntu16.04TS, Python3.71.2 Быстрая установка
pip install matplotlib numpy pandas pathlib
pip install mxnet mxnet-mkl gluon gluonts
2. Программа обучения
#Third-party imports
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from gluonts.model import deepar
from gluonts.dataset import common
from gluonts.dataset.util import to_pandas
from gluonts.model.predictor import Predictor
2.1 Загрузка обучающих данных Twitter_volume_AMZN.csv
url = "./data/Twitter_volume_AMZN.csv"
df = pd.read_csv(url, header=0, index_col=0)
data = common.ListDataset([{"start": df.index[0],
"target": df.value[:"2015-04-23 00:00:00"]}], freq="H")
- pd.read_csvПрочитайте файл csv и преобразуйте его во фрейм данных.
- common.ListDatasetЗагрузить тренировочные данные
2.2 ИнтерпретацияListDataset
class gluonts.dataset.common.ListDataset(data_iter: Iterable[Dict[str, Any]], freq: str, one_dim_target: bool = True)
- data_iter:Итерируемый объект дает все элементы в наборе данных. Каждый элемент должен быть словарем, отображающим строки в значения. Например: {"начало": "2014-09-07", "цель": [0.1, 0.2]}
- частота:Частота наблюдений во временном ряду.
- one_dim_target:Следует ли принимать только одномерные целевые временные ряды.
2.3 Обучение существующей моделиGluonTS поставляется со многими готовыми моделями. Все, что нужно сделать пользователю, это настроить некоторые гиперпараметры. Существующие модели сосредоточены на вероятностных прогнозах (но не ограничиваются ими). Вероятностные прогнозы — это предсказания, сделанные в форме вероятностных распределений, а не простых одноточечных оценок.
estimator = deepar.DeepAREstimator(freq="H", prediction_length=24)
predictor = estimator.train(training_data=data)
- Создайте сеть DeepAR и обучите ее
- предсказание_длина:продолжительность прогноза
- training_data:тренировочные данные
2.4 Предварительный просмотр результатов обучения
for test_entry, forecast in zip(train_data, predictor.predict(train_data)):
to_pandas(test_entry)[-60:].plot(linewidth=2)
forecast.plot(color='g', prediction_intervals=[50.0, 90.0])
plt.grid(which='both')
plt.show()
- результат прогноза
2.5 Вывод результатов обучения
prediction = next(predictor.predict(train_data))
print(prediction.mean)
prediction.plot(output_file='graph.png')
- OUT
2.5 Сохраните обучающую модель
predictor.serialize(Path("此处填入Model文件夹的绝对路径"))
2.6 Использование обученной модели
predictor = Predictor.deserialize(Path("此处填入Model文件夹的绝对路径"))
- пример
import pandas as pd
from pathlib import Path
from gluonts.dataset import common
from gluonts.dataset.util import to_pandas
from gluonts.model.predictor import Predictor
url = "./data/Twitter_volume_AMZN.csv"
df = pd.read_csv(url, header=0, index_col=0)
train_data = common.ListDataset([{"start": df.index[0],
"target": df.value[:"2015-04-23 00:00:00"]}],freq="H")
predictor = Predictor.deserialize(Path("此处填入Model文件夹的绝对路径"))
prediction = next(predictor.predict(train_data))
print(prediction.mean)
prediction.plot(output_file='graph.png')