Обучение с подкреплением 8 - Реализация кода DQN Tensorflow 2.0

обучение с подкреплением

в предыдущей статьеОбучение с подкреплением — Введение в DQNВ этой статье мы подробно представили источник DQN и два метода обработки, предложенные алгоритмом DQN для проблемы, связанной с трудностью сходимости обучения с подкреплением: воспроизведение опыта и фиксированное целевое значение. В этой статье мы будем использовать код для реализации алгоритма DQN.

1. Знакомство с окружающей средой

1. Введение в тренажерный зал

Этот алгоритм и алгоритмы, которые будут представлены в следующих статьях, будут использоватьсяOpenAIOpenAIРелизGymGymсреда моделирования,GymGymЭто платформа моделирования для исследования и разработки алгоритмов, связанных с обучением с подкреплением, и предоставляет интерфейсы для многих задач и сред (или игр), и пользователям не нужно слишком много знать о внутренней реализации игры.Его можно использовать для тестирования и моделирование простым вызовом, и совместим с общими библиотеками числовых операций, такими какTensorFlowTensorFlow.

import gym
env = gym.make('CartPole-v1')
env.reset()
for _ in range(1000):
    env.render()
    env.step(env.action_space.sample()) # take a random action
env.close()

Результаты приведены ниже:

aMXZ7Q.gif

Как видно из приведенного выше кода,gymОсновной интерфейсEnv. В качестве единого интерфейса среды,EnvСодержит следующие основные методы:

  • reset(self): Сбросить состояние окружения, вернуться к наблюдению. Если раунд заканчивается, эта функция вызывается для сброса информации об окружении.
  • step(self, action): выполнить действиеactionпродвинуться на один временной шаг, вернутьсяobservation, reward, done, info.
    • observationпредставляет наблюдения за окружающей средой, т.state
    • rewardУказывает полученную награду
    • doneУказывает, закончился ли текущий возврат
    • infoВозвращает некоторую диагностическую информацию, обычно не используемую часто
  • render(self, mode=‘human’, close=False): перерисовывает один кадр окружения.
  • close(self): Закройте среду и очистите память.

Приведенный выше код импортируется первымgymбиблиотека, созданная в строке 2CartPole-v01среды и сбросьте состояние среды в строке 3. в цикле for1000Контроль каждого временного шага (*timestep), пятая строка обновляет экран среды для каждого временного шага, шестая строка выполняет случайное действие (0 или 1) для текущего состояния среды, и, наконец, седьмая строка закрывает среду моделирования. после окончания цикла. .

2. Введение в среду CartPole-v1

CartPole - это базовая среда, предоставляемая тренажерным залом, то есть игра с автомобильным столбом. В игре есть автомобиль, на котором стоит столб. Начальное состояние после каждого сброса будет другим. Автомобиль должен двигаться влево и вправо, чтобы держать шест в вертикальном положении.Чтобы игра продолжалась, должны быть выполнены следующие два условия:

  • Угол, на который наклонена штангаθ\thetaне может быть больше 15°
  • Движущееся положение x тележки должно находиться в определенном диапазоне (2,4 единицы длины от середины до обеих сторон).

заCartPole-v1Среда, действия которой представляют собой два дискретных действия, движение влево (0) и движение вправо (1) Среда включает четыре переменные: положение автомобиля, скорость автомобиля, угол между полюсом и скорость изменения угла. Как показано в следующем коде:

import gym
env = gym.make('CartPole-v0')
print(env.action_space)  # Discrete(2)
observation = env.reset()
print(observation)  # [-0.0390601  -0.04725411  0.0466889   0.02129675]

Ниже сCartPole-v1Возьмите среду в качестве примера, чтобы представить реализацию DQN.

Во-вторых, реализация кода

1. Внедрение пула воспроизведения опыта

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = int((self.position + 1) % self.capacity)

    def sample(self, batch_size = args.batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

Сначала определите пул воспроизведения опыта емкостью 10 000, функцияpushЭто добавление информации о том, что агент взаимодействует со средой, в пул опыта.Реализация циклической очереди, используемая здесь, обратите внимание наpositionОперации с указателями. Когда алгоритм необходимо обновить данными, используйтеsampleСлучайным образом выберите один из очереди опытаbatch_size, используйте функцию zip, чтобы упаковать каждый фрагмент данных вместе:

zip: a=[1,2], b=[2,3], zip(a,b) => [(1, 2), (2, 3)]

Затем используйте функцию стека, чтобы преобразовать каждый столбец данных в список и вернуть

2. Структура сети

Код этой серии обучения с подкреплением используетсяtensorlayer, верноtensorflowСделал некоторую инкапсуляцию, чтобы упростить использование, суть в том, чтобы такжеСпециально для обучения с подкреплениемЕсть несколько встроенных интерфейсов, следующиеОфициальный сайтвводить:

TensorLayer — это библиотека глубокого обучения и обучения с подкреплением, основанная на Google TensorFlow, предназначенная для исследователей и инженеров. Он предоставляет высокоуровневый (Higher-Level) API глубокого обучения, который может не только ускорить эксперименты исследователей, но и уменьшить повторяющуюся работу инженеров в реальной разработке. TensorLayer очень легко модифицировать и расширять, что делает его подходящим как для исследований в области машинного обучения, так и для приложений.

Определите сетевую модель:

def create_model(input_state_shape):
    input_layer = tl.layers.Input(input_state_shape)
    layer_1 = tl.layers.Dense(n_units=32, act=tf.nn.relu)(input_layer)
    layer_2 = tl.layers.Dense(n_units=16, act=tf.nn.relu)(layer_1)
    output_layer = tl.layers.Dense(n_units=self.action_dim)(layer_2)
    return tl.models.Model(inputs=input_layer, outputs=output_layer)

self.model = create_model([None, self.state_dim])
self.target_model = create_model([None, self.state_dim])
self.model.train()
self.target_model.eval()

можно увидетьtensorlayerиспользовать сtensorflowПримерно так же, пока естьtensorflowОсновы можно понять с первого взгляда.В приведенном выше коде мы определяем функцию для создания модели сети. затем создайте текущую сетьmodelи целевая сетьtarget_model, мы знаем, что целевая сеть в DQN действует как «цель» для оценки текущего целевого значения, поэтому мы устанавливаем ее в режим оценки и вызываемeval()функция. иmodelСеть — это сеть, которую мы хотим обучить, вызвав функциюtrain()Установите режим тренировки.

3. Процесс управления алгоритмом

for episode in range(train_episodes):
    total_reward, done = 0, False
    while not done:
        action = self.choose_action(state)
        next_state, reward, done, _ = self.env.step(action)
        self.buffer.push(state, action, reward, next_state, done)
        total_reward += reward
        state = next_state
        # self.render()
    if len(self.buffer.buffer) > args.batch_size:
        self.replay()
        self.target_update()

Процесс взаимодействия с окружением был описан выше, здесь мы сосредоточимся на операторе if в строке 10. Когда длина пула опыта больше единицыbatch_size, начать звонитьreplay()функция обновления сетиmodelпараметры сети, затем вызовитеtarget_update()функция поставитьmodelСкопируйте сетевые параметры вtarget_modelИнтернет.

4. Обновление сетевых параметров

def replay(self):
    for _ in range(10):
        states, actions, rewards, next_states, done = self.buffer.sample()
        # compute the target value for the sample tuple
        # target [batch_size, action_dim]
        # target represents the current fitting level
        target = self.target_model(states).numpy()
        next_q_value = tf.reduce_max(self.target_model(next_states), axis=1)
        target_q = rewards + (1 - done) * args.gamma * next_q_value
        target[range(args.batch_size), actions] = target_q

        with tf.GradientTape() as tape:
            q_pred = self.model(states)
            loss = tf.losses.mean_squared_error(target, q_pred)
        grads = tape.gradient(loss, self.model.trainable_weights)
        self.model_optim.apply_gradients(zip(grads, self.model.trainable_weights))

Эта часть должна быть основным кодом DQN.replay()В функции мы циклически обновляем текущую сеть десять раз, цель состоит в том, чтобы изменить частоту обновления двух сетей, что способствует конвергенции сети.

Конкретная часть обновления: мы знаем, что DQN заменяет таблицу Q в Q-Learning нейронной сетью. Между ними много общего. Мы можем сравнить метод обновления Q-Learning. Для табличной формы Q мы получаем значение действия Q определенного состояния непосредственно через индекс, затем в нейронной сети состояние нужно ввести в нейронную сеть и получить путем прямого вычисления.

Δw=альфа(r+γ  maxa'  Q^(s',a',w)Q^(s,a,w))wQ^(s,a,w)\Delta w = \alpha (r + \gamma\;max_{a'}\; \hat{Q}(s', a', w) - \hat{Q}{(s, a, w)})\cdot \nabla_w\hat{Q}{(s, a, w)}

Третья строка сначала получаетbatch_sizeданные, этот процесс называетсяsample. В строке 7 мы сначала получаем текущее значение действия, а цель представляет собой значение действия, рассчитанное в соответствии с текущими параметрами сети. Тогда строка 8 сначала получает все действия следующего состояния при текущих параметрах сети, а затем используетreduce_max()Функция находит максимальное значение действия. Затем в строках 9 и 10 для вычисления используется наибольшее значение действия следующего состояния.target_q, это,r+γ  maxa'  Q^(s',a',w)r + \gamma\;max_{a'}\; \hat{Q}(s', a', w)раздел, затем обновитьtarget. Обратите внимание, что мы использовали приведенное выше при расчете целиtarget_modelсеть, целевая сеть используется только при оценке состояния сети.

Затем мы используемq_pred = self.model(states)Сеть получает текущее состояние сети, то есть по формулеQ^(s,a,w)\hat{Q}{(s, a, w)}, используйте функцию MSE для вычисления функции потерь и, наконец, обновитеmodelИнтернет.

Пожалуйста, обратитесь к полному кодуОбучение с подкреплением — адрес кода DQNПожалуйста, также дайтеstar, Спасибо

3. Резюме DQN

Хотя два решения, предложенные DQN, хороши, все еще есть проблемы, требующие решения, такие как:

  • Правильно ли рассчитано целевое значение Q (Q Target )? пройти всеmax  Qmax\;QЕсть ли проблема с расчетом?
  • Значение Q представляет собой значение действия, поэтому будет ли простая оценка значения действия неточной?

Улучшение, соответствующее первой проблеме, называется Double DQN, а улучшение второй проблемы — Dueling DQN. Все они являются улучшенными версиями DQN, о которых мы расскажем в следующей статье.