Автор: Кингью и Ланкинг
FlappyBird — мобильная игра, запущенная в 2013 году, которая быстро стала популярной в Интернете благодаря простому игровому процессу, но чрезвычайно сложным настройкам. Благодаря разработке передовых алгоритмов, таких как глубокое обучение (DL) и обучение с подкреплением (RL), мы можем легко обучить агента управлять Flappy Bird с помощью Java.
История начинается с«О чем говорят большие парни на GitHub после приветствия? 》, то сегодня мы рассмотрим, как обучить феникса с помощью Java. Для игрового проекта мы использовали базовую библиотеку классов Java, написанную толькоFlappyBirdигра. Для обучения мы используемDeepJavaLibraryФреймворк глубокого обучения на основе Java для создания обучающих сетей с подкреплением и их обучения. После почти 3 миллионов шагов (четыре часа) обучения птица набрала более 8000 баллов и гибко перемещается между водопроводными трубами.
В этой статье мы начнем с принципа шаг за шагом, чтобы реализовать алгоритм обучения с подкреплением и использовать его для обучения игре. Если в какой-то момент непонятно, как действовать дальше, вы можете обратиться к исходному коду проекта.
адрес проекта:GitHub.com/kingfish пересечение/R…
Архитектура обучения с подкреплением (RL)
В этом разделе мы представим основные алгоритмы и нейронные сети, которые помогут вам лучше понять, как тренироваться. Этот проект иDeepLearningFlappyBirdАналогичный подход был использован для обучения. Общая архитектура алгоритма Q-Learning + Convolutional Neural Network (CNN), в которой хранится состояние каждого кадра игры, то есть действия, предпринимаемые птицей, и эффекты после совершения действий. используются в качестве обучающих данных сверточной нейронной сети.
Краткий обзор обучения CNN
Входные данные CNN - это 4 последовательных кадра изображений. Мы складываем эти изображения как текущее «наблюдение» за птицей, и изображения будут преобразованы в изображения в градациях серого, чтобы уменьшить требуемые ресурсы обучения. Матричная форма хранения изображений(batch size, 4 (frames), 80 (width), 80 (height))
Элементами массива являются значения пикселей текущего кадра, и эти данные будут вводиться в CNN и выводиться(batch size, 2)
Матрица , второе измерение матрицы — доход, соответствующий птице (машет крыльями и не предпринимает действий).
тренировочные данные
После того, как птица сделает ход, получимpreObservation and currentObservation
То есть два последовательных изображения по 4 кадра представляют состояние птицы до и после действия. Тогда мы будемpreObservation, currentObservation, action, reward, terminal
Составленная пятерка сохраняется в replayBuffer как шаг. Это обучающий набор данных ограниченного размера, который динамически обновляет содержимое с учетом последних операций.
public void step(NDList action, boolean training) {
if (action.singletonOrThrow().getInt(1) == 1) {
bird.birdFlap();
}
stepFrame();
NDList preObservation = currentObservation;
currentObservation = createObservation(currentImg);
FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(),
preObservation, currentObservation, action, currentReward, currentTerminal);
if (training) {
replayBuffer.addStep(step);
}
if (gameState == GAME_OVER) {
restartGame();
}
}
три цикла обучения
Обучение разделено на 3 разные эпохи, чтобы лучше генерировать обучающие данные:
- Наблюдайте за циклом: случайным образом генерируйте обучающие данные
- Цикл исследования: обновляйте обучающие данные, комбинируя случайные действия и действия логического вывода.
- Цикл обучения: действие рассуждения приводит к генерации новых данных
Благодаря этому тренировочному режиму мы можем лучше достичь желаемого эффекта.
В цикле исследования мы будем выбирать случайные действия в соответствии с весами или использовать действия, предполагаемые моделью, как действия птицы. В начале обучения вес случайных действий может быть очень большим, потому что решения модели очень неточны (или даже менее случайны). Позже в процессе обучения, по мере того как модель усваивает больше действий, мы продолжаем увеличивать вес действия модели для вывода и в конечном итоге делаем его доминирующим действием. Параметр, который модулирует случайное действие, называется эпсилон и меняется в ходе обучения.
public NDList chooseAction(RlEnv env, boolean training) {
if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) {
return env.getActionSpace().randomAction();
} else return baseAgent.chooseAction(env, training);
}
тренировочная логика
Во-первых, мы будем случайным образом выбирать пакет данных из replayBuffer в качестве обучающего набора. Затем передайте preObservation в нейронную сеть, чтобы получить вознаграждение (Q) за все действия в виде прогнозируемого значения:
NDList QReward = trainer.forward(preInput);
NDList Q = new NDList(QReward.singletonOrThrow()
.mul(actionInput.singletonOrThrow())
.sum(new int[]{1}));
PostObservation также вводится в нейронную сеть, а вознаграждение (targetQ) всех действий рассчитывается как истинное значение в соответствии с марковским процессом принятия решений и функцией ценности Беллмана:
// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。
NDList targetQReward = trainer.forward(postInput);
NDArray[] targetQValue = new NDArray[batchSteps.length];
for (int i = 0; i < batchSteps.length; i++) {
if (batchSteps[i].isTerminal()) {
targetQValue[i] = batchSteps[i].getReward();
} else {
targetQValue[i] = targetQReward.singletonOrThrow().get(i)
.max()
.mul(rewardDiscount)
.add(rewardInput.singletonOrThrow().get(i));
}
}
NDList targetQBatch = new NDList();
Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value)));
NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));
В конце обучения вычисляются значения потерь для Q и targetQ и обновляются веса в CNN.
Модель сверточной нейронной сети (CNN)
Мы используем архитектуру нейронной сети с 3 сверточными слоями, 4 функциями активации relu и 2 полносвязными слоями.
layer | input shape | output shape |
---|---|---|
conv2d | (batchSize, 4, 80, 80) | (batchSize,4,20,20) |
conv2d | (batchSize, 4, 20 ,20) | (batchSize, 32, 9, 9) |
conv2d | (batchSize, 32, 9, 9) | (batchSize, 64, 7, 7) |
linear | (batchSize, 3136) | (batchSize, 512) |
linear | (batchSize, 512) | (batchSize, 2) |
тренировочный процесс
Библиотека RL от DJL предоставляет очень удобный интерфейс для реализации обучения с подкреплением: (RlEnv, RlAgent, ReplayBuffer).
- Реализуйте интерфейс RlAgent для создания обучаемого агента.
- Данные, необходимые для обучения, можно сгенерировать, внедрив интерфейс RlEnv в существующую игровую среду.
- Создайте ReplayBuffer для хранения и динамического обновления данных обучения.
После реализации этих интерфейсов просто вызовитеstepметод:
RlEnv.step(action, training);
Этот метод вводит действия, выбранные RlAgent, в игровую среду для обратной связи. Что мы можем предоставить в RlEnvrunEnviroment
В методе вызывается пошаговый метод, а дальше нужно только повторить выполнениеrunEnvironment
метод непрерывной генерации данных для обучения.
public Step[] runEnvironment(RlAgent agent, boolean training) {
// run the game
NDList action = agent.chooseAction(this, training);
step(action, training);
if (training) {
batchSteps = this.getBatch();
}
return batchSteps;
}
Мы устанавливаем количество шагов, которые ReplayBuffer может хранить, равным 50 000. В течение периода наблюдения мы сначала будем хранить 1000 шагов, сгенерированных случайными действиями, в replayBuffer, чтобы агент мог быстрее учиться на случайных действиях.
Во время циклов исследования и обучения нейронная сеть случайным образом генерирует обучающие наборы из replayBuffer и передает их в модель для обучения. Мы итерируем нейронную сеть, используя оптимизатор Adam и функцию потерь MSE.
Предварительная обработка входных данных нейронной сети
Сначала измените размер изображения на80x80
и преобразованы в оттенки серого, что помогает увеличить скорость обучения без потери информации.
public static NDArray imgPreprocess(BufferedImage observation) {
return NDImageUtils.toTensor(
NDImageUtils.resize(
ImageFactory.getInstance().fromImage(observation)
.toNDArray(NDManager.newBaseManager(),
Image.Flag.GRAYSCALE) ,80,80));
}
Затем мы берем четыре последовательных кадра изображений в качестве входных данных.Чтобы получить непрерывные изображения четырех последовательных кадров, мы поддерживаем глобальную очередь изображений для сохранения изображений в игровом потоке, заменяем самый старый кадр после каждого действия, а затем помещаем The изображения в очереди складываются в один массив NDArray.
public NDList createObservation(BufferedImage currentImg) {
NDArray observation = GameUtil.imgPreprocess(currentImg);
if (imgQueue.isEmpty()) {
for (int i = 0; i < 4; i++) {
imgQueue.offer(observation);
}
return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1));
} else {
imgQueue.remove();
imgQueue.offer(observation);
NDArray[] buf = new NDArray[4];
int i = 0;
for (NDArray nd : imgQueue) {
buf[i++] = nd;
}
return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1));
}
}
Как только вышеперечисленные части будут завершены, мы можем начать обучение. Оптимизация обучения Для повышения эффективности обучения мы отключили графический интерфейс, чтобы ускорить создание образцов. И используйте многопоточность Java для запуска цикла обучения и цикла генерации образцов в отдельных потоках.
List<Callable<Object>> callables = new ArrayList<>(numOfThreads);
callables.add(new GeneratorCallable(game, agent, training));
if(training) {
callables.add(new TrainerCallable(model, agent));
}
Суммировать
Модель обучалась на графическом процессоре NVIDIA T4 около 4 часов с обновлением 3 миллионов шагов. После тренировки птица может гибко управлять своими движениями и гибко челночить между трубами. Обученная модель также загружается в репозиторий для тестирования. В этом проекте DJL предоставляет мощный обучающий API и поддержку библиотеки моделей, что делает его удобным в процессе разработки Java.
Полный код этого проекта:GitHub.com/kingfish пересечение/R…
Wechat найдите общедоступную учетную запись «HelloGitHub», нажмите «Присоединиться», чтобы присоединиться к нам.