Говоря о TD3: от принципа алгоритма к реализации кода

обучение с подкреплением

Эта статья была впервые опубликована на:Уокер ИИ

Хорошо известно, что в алгоритмах обучения с подкреплением, основанных на ценностном обучении, таких как DQN, ошибки аппроксимации функций являются причиной завышенных значений Q и неоптимальных политик. Мы показываем, что эта проблема все еще существует в рамках AC, и предлагаем новые механизмы для минимизации ее воздействия на акторов (политическая функция) и критиков (функция оценки). Наш алгоритм основан на двойном Q-обучении и ограничивает переоценку Q, выбирая меньшую из двух функций оценки. (из реферата статьи TD3)

1. Что такое ТД3

TD3 — это полное название алгоритма градиентной политики с двойной задержкой и глубокой детерминацией. Алгоритм градиента политики глубокого детерминизма в полном имени TD3 — это полное имя DDPG. Так каково происхождение DDPG и TD3? По сути, проще говоря, TD3 — это оптимизированная версия DDPG.

1.1 Почему был предложен TD3

В обучении с подкреплением обучение дискретным действиям основано на DQN, и DQN передаетсяargMaxQtableargMaxQ_{table}Способ выбора действий часто слишком велик для оценки функции ценности, что приводит к ошибкам. В рамках AC непрерывного управления действиями, если каждый шаг оценивается таким образом, ошибка будет накапливаться шаг за шагом, что приводит к невозможности найти оптимальную стратегию, и в конечном итоге алгоритм не может быть сходим.

1.2 Что делает TD3 на основе DDPG

  • Используйте две сети Critic. Используйте две сети для оценки функции ценности действия (идея Double DQN аналогична). выбрать во время обученияmin(Qθ1(s,a),Qθ2(s,a))min(Q^{\theta1}(s,a),Q^{\theta2}(s,a))в качестве оценки.

  • Используйте метод мягкого обновления. Вместо прямого копирования используйтеθ=тθ'+(1т)θ\тета = \тау\тета^′ + (1 - \тау)\тетаспособ обновления сетевых параметров.

  • Используйте политический шум. Исследовательский шум используется при исследовании с помощью Epsilon-Greedy. (По-прежнему используется шум политики, который используется для сглаживания ожиданий политики при обновлении параметров)

  • Используйте отложенное обучение. Сеть Critic обновляется чаще, чем сеть Actor.

  • Используйте градиентный перехват. Перехватите градиент обновления параметра Актера до определенного диапазона.

2. Идея алгоритма TD3

Общая идея алгоритма TD3, сначала инициализировать 3 сети соответственноQθ1,Qθ2,число ПиϕQ_{\theta1},Q_{\theta2},\pi_\phi, параметрθ1,θ2,ϕ\theta_1,\theta_2,\phi, после инициализации 3 целевых сетей скопируйте 3 сетевых параметра, соответствующих начальной инициализации, в целевую сеть соответственно.θ1'θ1,θ2'θ2,ϕ'ϕ\тета{_1^′}\стрелка влево\тета_1,\тета{_2^′}\стрелка влево\тета_2,\фи_′\стрелка влево\фи. Инициализировать буфер воспроизведенияβ\beta. Затем путем повторения цикла оптимальная стратегия находится снова и снова. На каждой итерации добавляется шум при выборе значения действия, так чтоa число Пиϕ(s)+ϵa~\pi_\phi(s) + \epsilon,ϵN(0,о)\epsilon \sim N(0,\sigma),Потом(s,a,r,s')(с, а, г, с ^ ')положить вβ\beta,когдаβ\betaпри достижении определенного значения. затем случайным образом изβ\betaВ примере мини-пакетные данные генерируются черезa~число Пиϕ'(s')+ϵ\ тильда {а} \ сим \ пи _ {\ фи ^ '} (с ^ ') + \ эпсилон,ϵclip(N(0,о~),c,c)\epsilon \sim clip(N(0,\tilde\sigma),-c,c),Рассчитатьs'с^'Значение соответствующего действия в состоянииa~\tilde a,пройти черезs',a~с ^ ', \ тильда а,РассчитатьtargetQ1,targetQ2targetQ1,targetQ2,Получатьmin(targetQ1,targetQ)min(targetQ1,targetQ),заs'с^'изtargetQtargetQценность.

Рассчитано по уравнению БеллманаssизtargetQtargetQзначение, через две текущие сети в соответствии сs,as,aРассчитать текущийQQзначение, после объединения двух текущих сетевыхQQценность иtargetQtargetQЗначение рассчитывается по MSE Loss, и параметр обновляется. После обновления сети Critic сеть Actor принимает отложенное обновление (как правило, Critic обновляется дважды, а Actor — один раз). Обновите сеть Актера путем градиентного восхождения. Целевая сеть обновляется посредством программного обновления.

  • Зачем добавлять шум при расчете значения Action при обновлении сети Critic, чтобы сгладить добавленный ранее шум.

  • Уравнение Беллмана: для непрерывного процесса MRP (Markov Reward Process) (непрерывный процесс вознаграждения за состояние) состояниеssпереход в следующее состояниеs'с^'Вероятность фиксирована, не зависит от предыдущих раундов состояний. в,vvПредставляет функцию, оценивающую текущее состояние.γ\gammaОбычно приближается к 1, но меньше 1.

3. Реализация кода

Код в основном воспроизводится в соответствии с кодом DDPG и документом TD3 и реализован с использованием Pytorch 1.7.

3.1 Создайте сетевую структуру

Сетевая структура Q1 в основном используется для обновления сети Актера.

class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.f1 = nn.Linear(state_dim, 256)
        self.f2 = nn.Linear(256, 128)
        self.f3 = nn.Linear(128, action_dim)
        self.max_action = max_action
    def forward(self,x):
        x = self.f1(x)
        x = F.relu(x)
        x = self.f2(x)
        x = F.relu(x)
        x = self.f3(x)
        return torch.tanh(x) * self.max_action
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic,self).__init__()
        self.f11 = nn.Linear(state_dim+action_dim, 256)
        self.f12 = nn.Linear(256, 128)
        self.f13 = nn.Linear(128, 1)

        self.f21 = nn.Linear(state_dim + action_dim, 256)
        self.f22 = nn.Linear(256, 128)
        self.f23 = nn.Linear(128, 1)

    def forward(self, state, action):
        sa = torch.cat([state, action], 1)

        x = self.f11(sa)
        x = F.relu(x)
        x = self.f12(x)
        x = F.relu(x)
        Q1 = self.f13(x)

        x = self.f21(sa)
        x = F.relu(x)
        x = self.f22(x)
        x = F.relu(x)
        Q2 = self.f23(x)

        return Q1, Q2

3.2 Определение сети

 self.actor = Actor(self.state_dim, self.action_dim, self.max_action)
        self.target_actor = copy.deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        #定义critic网络
        self.critic = Critic(self.state_dim, self.action_dim)
        self.target_critic = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

3.3 Обновление сети

Обновите принятие сетимягкое обновление,отложенное обновлениеи Т. Д.

 def learn(self):
        self.total_it += 1
        data = self.buffer.smaple(size=128)
        state, action, done, state_next, reward = data
        with torch.no_grad:
            noise = (torch.rand_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.target_actor(state_next) + noise).clamp(-self.max_action, self.max_action)
            target_Q1,target_Q2 = self.target_critic(state_next, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + done * self.discount * target_Q
        current_Q1, current_Q2 = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        critic_loss.backward()
        self.critic_optimizer.step()

        if self.total_it % self.policy_freq == 0:

            q1,q2 = self.critic(state, self.actor(state))
            actor_loss = -torch.min(q1, q2).mean()

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            for param, target_param in zip(self.critic.parameters(), self.target_critic.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.target_actor.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

4. Резюме

TD3 является модернизированной версией DDPG.В решении многих задач эффект намного лучше, чем у DDPG.Существенно улучшена как скорость обучения, так и результаты.

5. Информация

  1. разбирательства.malaysian.press/v80/absmodels…

PS: Для получения дополнительной технической галантереи, пожалуйста, обратите внимание на [Публичный аккаунт | xingzhe_ai] и обсудите с ходоками!