Шаблон проекта глубокого обучения

машинное обучение искусственный интеллект TensorFlow глубокое обучение Архитектура
Шаблон проекта глубокого обучения

Шаблон проекта глубокого обучения(Шаблон проекта DL), который упрощает процесс загрузки данных, построения сетей, моделей обучения и прогнозирования выборок.

Исходный код: https://github.com/SpikeKing/DL-Project-Template

Как пользоваться

Скачать проект

git clone https://github.com/SpikeKing/DL-Project-Template

Создавать и активировать виртуальные среды

virtualenv venv
source venv/bin/activate

Установите зависимости Python

pip install -r requirements.txt

Процесс разработки

  1. Определите свой собственный класс загрузки данных и наследуйте DataLoaderBase;
  2. Определите свой собственный класс сетевой структуры и наследуйте ModelBase;
  3. Определите свой собственный обучающий класс модели и наследуйте TrainerBase;
  4. Определите свой собственный класс прогнозирования выборки и наследуйте InferBase;
  5. Определите свой собственный файл конфигурации и запишите соответствующие параметры эксперимента;

Выполните модель поезда и предскажите выборочные операции.

Пример проекта

ИдентифицироватьMNISTРукописные числа в библиотеке, работыsimple_mnist

тренироваться:

python main_train.py -c configs/simple_mnist_config.json

предсказывать:

python main_test.py -c configs/simple_mnist_config.json -m simple_mnist.weights.10-0.24.hdf5

сетевая структура

网络结构

TensorBoard

TensorBoard

Инженерная архитектура

Рамки

架构

структура папок

├── bases
│   ├── data_loader_base.py             - 数据加载基类
│   ├── infer_base.py                   - 预测样本(推断)基类
│   ├── model_base.py                   - 网络结构(模型)基类
│   ├── trainer_base.py                 - 训练模型基类
├── configs                             - 配置文件夹
│   └── simple_mnist_config.json
├── data_loaders                        - 数据加载文件夹
│   ├── __init__.py
│   ├── simple_mnist_dl.py
├── experiments                         - 实验数据文件夹
│   └── simple_mnist                    - 实验名称
│       ├── checkpoints                 - 存储的模型和参数
│       │   └── simple_mnist.weights.10-0.24.hdf5
│       ├── images                      - 图片
│       │   └── model.png
│       └── logs                        - 日志,如TensorBoard
│           └── events.out.tfevents.1524034653.wang
├── infers                              - 推断文件夹
│   ├── __init__.py
│   ├── simple_mnist_infer.py
├── main_test.py                        - 预测样本入口
├── main_train.py                       - 训练模型入口
├── models                              - 网络结构文件夹
│   ├── __init__.py
│   ├── simple_mnist_model.py
├── requirements.txt                    - 依赖库
├── trainers                            - 训练模型文件夹
│   ├── __init__.py
│   ├── simple_mnist_trainer.py
└── utils                               - 工具文件夹
    ├── __init__.py
    ├── config_utils.py                 - 配置工具类
    ├── np_utils.py                     - NumPy工具类
    ├── utils.py                        - 其他工具类

основные компоненты

DataLoader

Шаги:

  1. Создайте свой собственный класс данных загрузки и наследуйте базовый класс DataLoaderBase;
  2. перезаписыватьget_train_data()иget_test_data(), возвращает данные обучения и тестирования;

Model

Шаги:

  1. Создайте свой собственный класс сетевой структуры и наследуйте базовый класс ModelBase;
  2. перезаписыватьbuild_model(), для создания сетевой структуры;
  3. В конструкторе вызовитеbuild_model();

Уведомление:plot_model()Поддержка рисования сетевых структур;

Trainer

Шаги:

  1. Создайте свой собственный учебный класс и наследуйте базовый класс TrainerBase;
  2. Параметры: модель сетевой структуры, обучающие данные;
  3. перезаписыватьtrain()подгонять данные, обучать структуру сети;

Примечание. Поддерживает вызовы обратных вызовов во время обучения, дополнительно добавляя хранилище моделей, TensorBoard, метрики FPR и т. д.

Infer

Шаги:

  1. Создайте свой собственный класс прогнозирования и наследуйте базовый класс InferBase;
  2. перезаписыватьload_model(), обеспечивающая функцию загрузки модели;
  3. перезаписыватьpredict(), обеспечивая функцию прогнозирования выборки;

Config

Определите параметры, необходимые в процессе обучения модели, в формате JSON, поддержку: такие параметры, как скорость обучения, эпоха, пакет и т. д.

Main

тренироваться:

  1. Создадим конфигурационный файл config;
  2. Создайте класс загрузки данных dl;
  3. Создайте модель класса сетевой структуры;
  4. Создайте тренер класса обучения, параметры - данные обучения и тестирования, модель;
  5. Выполнить train() тренера учебного класса;

предсказывать:

  1. Создадим конфигурационный файл config;
  2. Обработайте предсказанный образец теста;
  3. Создать вывод класса предсказания;
  4. Выполнить прогноз() класса прогнозирования infer;

благодарный

Ссылаться наTensorflow-Project-Templateпроект

By C. L. Wang @ МейтуОблачное бизнес-подразделение