Обучите «феникса» на Java

машинное обучение Java
Обучите «феникса» на Java

Автор: Кингью и Ланкинг

FlappyBird — мобильная игра, запущенная в 2013 году, которая быстро стала популярной в Интернете благодаря простому игровому процессу, но чрезвычайно сложным настройкам. Благодаря разработке передовых алгоритмов, таких как глубокое обучение (DL) и обучение с подкреплением (RL), мы можем легко обучить агента управлять Flappy Bird с помощью Java.

История начинается с«О чем говорят большие парни на GitHub после приветствия? 》, то сегодня мы рассмотрим, как обучить феникса с помощью Java. Для игрового проекта мы использовали базовую библиотеку классов Java, написанную толькоFlappyBirdигра. Для обучения мы используемDeepJavaLibraryФреймворк глубокого обучения на основе Java для создания обучающих сетей с подкреплением и их обучения. После почти 3 миллионов шагов (четыре часа) обучения птица набрала более 8000 баллов и гибко перемещается между водопроводными трубами.

В этой статье мы начнем с принципа шаг за шагом, чтобы реализовать алгоритм обучения с подкреплением и использовать его для обучения игре. Если в какой-то момент непонятно, как действовать дальше, вы можете обратиться к исходному коду проекта.

адрес проекта:GitHub.com/kingfish пересечение/R…

Архитектура обучения с подкреплением (RL)

В этом разделе мы представим основные алгоритмы и нейронные сети, которые помогут вам лучше понять, как тренироваться. Этот проект иDeepLearningFlappyBirdАналогичный подход был использован для обучения. Общая архитектура алгоритма Q-Learning + Convolutional Neural Network (CNN), в которой хранится состояние каждого кадра игры, то есть действия, предпринимаемые птицей, и эффекты после совершения действий. используются в качестве обучающих данных сверточной нейронной сети.

Краткий обзор обучения CNN

Входные данные CNN - это 4 последовательных кадра изображений. Мы складываем эти изображения как текущее «наблюдение» за птицей, и изображения будут преобразованы в изображения в градациях серого, чтобы уменьшить требуемые ресурсы обучения. Матричная форма хранения изображений(batch size, 4 (frames), 80 (width), 80 (height))Элементами массива являются значения пикселей текущего кадра, и эти данные будут вводиться в CNN и выводиться(batch size, 2)Матрица , второе измерение матрицы — доход, соответствующий птице (машет крыльями и не предпринимает действий).

тренировочные данные

После того, как птица сделает ход, получимpreObservation and currentObservationТо есть два последовательных изображения по 4 кадра представляют состояние птицы до и после действия. Тогда мы будемpreObservation, currentObservation, action, reward, terminalСоставленная пятерка сохраняется в replayBuffer как шаг. Это обучающий набор данных ограниченного размера, который динамически обновляет содержимое с учетом последних операций.

public void step(NDList action, boolean training) {
    if (action.singletonOrThrow().getInt(1) == 1) {
        bird.birdFlap();
    }
    stepFrame();
    NDList preObservation = currentObservation;
    currentObservation = createObservation(currentImg);
    FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
            preObservation, currentObservation, action, currentReward, currentTerminal);
    if (training) {
        replayBuffer.addStep(step);
    }
    if (gameState == GAME_OVER) {
        restartGame();
    }
}

три цикла обучения

Обучение разделено на 3 разные эпохи, чтобы лучше генерировать обучающие данные:

  • Наблюдайте за циклом: случайным образом генерируйте обучающие данные
  • Цикл исследования: обновляйте обучающие данные, комбинируя случайные действия и действия логического вывода.
  • Цикл обучения: действие рассуждения приводит к генерации новых данных

Благодаря этому тренировочному режиму мы можем лучше достичь желаемого эффекта.

В цикле исследования мы будем выбирать случайные действия в соответствии с весами или использовать действия, предполагаемые моделью, как действия птицы. В начале обучения вес случайных действий может быть очень большим, потому что решения модели очень неточны (или даже менее случайны). Позже в процессе обучения, по мере того как модель усваивает больше действий, мы продолжаем увеличивать вес действия модели для вывода и в конечном итоге делаем его доминирующим действием. Параметр, который модулирует случайное действие, называется эпсилон и меняется в ходе обучения.

public NDList chooseAction(RlEnv env, boolean training) {
    if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
        return env.getActionSpace().randomAction();
    } else return baseAgent.chooseAction(env, training);
}

тренировочная логика

Во-первых, мы будем случайным образом выбирать пакет данных из replayBuffer в качестве обучающего набора. Затем передайте preObservation в нейронную сеть, чтобы получить вознаграждение (Q) за все действия в виде прогнозируемого значения:

NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
        .mul(actionInput.singletonOrThrow())
        .sum(new int[]{1}));

PostObservation также вводится в нейронную сеть, а вознаграждение (targetQ) всех действий рассчитывается как истинное значение в соответствии с марковским процессом принятия решений и функцией ценности Беллмана:

// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length]; 
for (int i = 0; i < batchSteps.length; i++) {
    if (batchSteps[i].isTerminal()) {
        targetQValue[i] = batchSteps[i].getReward();
    } else {
        targetQValue[i] = targetQReward.singletonOrThrow().get(i)
                .max()
                .mul(rewardDiscount)
                .add(rewardInput.singletonOrThrow().get(i));
    }
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

В конце обучения вычисляются значения потерь для Q и targetQ и обновляются веса в CNN.

Модель сверточной нейронной сети (CNN)

Мы используем архитектуру нейронной сети с 3 сверточными слоями, 4 функциями активации relu и 2 полносвязными слоями.

layer input shape output shape
conv2d (batchSize, 4, 80, 80) (batchSize,4,20,20)
conv2d (batchSize, 4, 20 ,20) (batchSize, 32, 9, 9)
conv2d (batchSize, 32, 9, 9) (batchSize, 64, 7, 7)
linear (batchSize, 3136) (batchSize, 512)
linear (batchSize, 512) (batchSize, 2)

тренировочный процесс

Библиотека RL от DJL предоставляет очень удобный интерфейс для реализации обучения с подкреплением: (RlEnv, RlAgent, ReplayBuffer).

  • Реализуйте интерфейс RlAgent для создания обучаемого агента.
  • Данные, необходимые для обучения, можно сгенерировать, внедрив интерфейс RlEnv в существующую игровую среду.
  • Создайте ReplayBuffer для хранения и динамического обновления данных обучения.

После реализации этих интерфейсов просто вызовитеstepметод:

RlEnv.step(action, training);

Этот метод вводит действия, выбранные RlAgent, в игровую среду для обратной связи. Что мы можем предоставить в RlEnvrunEnviromentВ методе вызывается пошаговый метод, а дальше нужно только повторить выполнениеrunEnvironmentметод непрерывной генерации данных для обучения.

public Step[] runEnvironment(RlAgent agent, boolean training) {
    // run the game
    NDList action = agent.chooseAction(this, training);
    step(action, training);
    if (training) {
        batchSteps = this.getBatch();
    }
    return batchSteps;
}

Мы устанавливаем количество шагов, которые ReplayBuffer может хранить, равным 50 000. В течение периода наблюдения мы сначала будем хранить 1000 шагов, сгенерированных случайными действиями, в replayBuffer, чтобы агент мог быстрее учиться на случайных действиях.

Во время циклов исследования и обучения нейронная сеть случайным образом генерирует обучающие наборы из replayBuffer и передает их в модель для обучения. Мы итерируем нейронную сеть, используя оптимизатор Adam и функцию потерь MSE.

Предварительная обработка входных данных нейронной сети

Сначала измените размер изображения на80x80и преобразованы в оттенки серого, что помогает увеличить скорость обучения без потери информации.

public static NDArray imgPreprocess(BufferedImage observation) {
    return NDImageUtils.toTensor(
            NDImageUtils.resize(
                    ImageFactory.getInstance().fromImage(observation)
                    .toNDArray(NDManager.newBaseManager(),
                     Image.Flag.GRAYSCALE) ,80,80));
}

Затем мы берем четыре последовательных кадра изображений в качестве входных данных.Чтобы получить непрерывные изображения четырех последовательных кадров, мы поддерживаем глобальную очередь изображений для сохранения изображений в игровом потоке, заменяем самый старый кадр после каждого действия, а затем помещаем The изображения в очереди складываются в один массив NDArray.

public NDList createObservation(BufferedImage currentImg) {
    NDArray observation = GameUtil.imgPreprocess(currentImg);
    if (imgQueue.isEmpty()) {
        for (int i = 0; i < 4; i++) {
            imgQueue.offer(observation);
        }
        return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
    } else {
        imgQueue.remove();
        imgQueue.offer(observation);
        NDArray[] buf = new NDArray[4];
        int i = 0;
        for (NDArray nd : imgQueue) {
            buf[i++] = nd;
        }
        return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
    }
}

Как только вышеперечисленные части будут завершены, мы можем начать обучение. Оптимизация обучения Для повышения эффективности обучения мы отключили графический интерфейс, чтобы ускорить создание образцов. И используйте многопоточность Java для запуска цикла обучения и цикла генерации образцов в отдельных потоках.

List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
    callables.add(new TrainerCallable(model, agent));
}

Суммировать

Модель обучалась на графическом процессоре NVIDIA T4 около 4 часов с обновлением 3 миллионов шагов. После тренировки птица может гибко управлять своими движениями и гибко челночить между трубами. Обученная модель также загружается в репозиторий для тестирования. В этом проекте DJL предоставляет мощный обучающий API и поддержку библиотеки моделей, что делает его удобным в процессе разработки Java.

Полный код этого проекта:GitHub.com/kingfish пересечение/R…

Wechat найдите общедоступную учетную запись «HelloGitHub», нажмите «Присоединиться», чтобы присоединиться к нам.