Source code for agent.agents.ppo_agent

#!/usr/local/bin/python
# coding: utf-8
from itertools import count
from typing import Any, Tuple

import numpy

from agent.architectures import DDPGActorArchitecture, DDPGCriticArchitecture
from agent.interfaces.partials.agents.torch_agents.actor_critic_agent import ActorCriticAgent
from agent.interfaces.specifications import ValuedTransition
from agent.interfaces.specifications.generalised_delayed_construction_specification import GDCS
from agent.training.procedures import batched_training
from agent.training.train_agent import parallelised_training, train_agent
from agent.utilities import to_tensor
from warg.named_ordered_dictionary import NOD

__author__ = 'cnheider'

import torch
from torch import nn
from tqdm import tqdm
import torch.nn.functional as F

from agent import utilities as U


class PPOAgent(ActorCriticAgent):
  '''
  PPO, Proximal Policy Optimization method

  See method __defaults__ for default parameters
  '''

  # region Private

  def __defaults__(self) -> None:
    self._steps = 10

    self._discount_factor = 0.99
    self._gae_tau = 0.95
    # self._reached_horizon_penalty = -10.

    self._actor_lr = 3e-4
    self._critic_lr = 3e-3
    self._entropy_reg_coef = 3e-3
    self._value_reg_coef = 5e-1
    self._batch_size = 64
    self._mini_batch_size = 10
    self._initial_observation_period = 0
    self._target_update_tau = 1.0
    self._update_target_interval = 1000
    self._max_grad_norm = None
    self._solved_threshold = -200
    self._test_interval = 1000
    self._early_stop = False
    self._rollouts = 10000

    self._ppo_epochs = 4

    self._current_kl_beta = 1.00

    self._state_type = torch.float
    self._value_type = torch.float
    self._action_type = torch.long

    # params for epsilon greedy
    self._exploration_epsilon_start = 0.99
    self._exploration_epsilon_end = 0.05
    self._exploration_epsilon_decay = 10000

    self._use_cuda = False

    self._surrogate_clipping_value = 0.2

    self._optimiser_spec = GDCS(torch.optim.Adam, {})

    self._actor_arch_spec = GDCS(DDPGActorArchitecture,
                                 kwargs=NOD({'input_shape':       None,  # Obtain from environment
                                             'hidden_layers':     None,
                                             'output_activation': None,
                                             'output_shape':      None,  # Obtain from environment
                                             }))

    self._critic_arch_spec = GDCS(DDPGCriticArchitecture,
                                  kwargs=NOD({'input_shape':       None,  # Obtain from environment
                                              'hidden_layers':     None,
                                              'output_activation': None,
                                              'output_shape':      None,  # Obtain from environment
                                              }))

    self._optimiser = None

    self._update_early_stopping = None
    # self._update_early_stopping = self.kl_target_stop
  def kl_target_stop(self,
                     old_log_probs,
                     new_log_probs,
                     kl_target=0.03,
                     beta_max=20,
                     beta_min=1 / 20):
    '''
    Adaptive KL-penalty early stopping; kl_target is typically 0.01-0.03.

    For reference, the TRPO-style penalised objective reads:

      negloss = -tf.reduce_mean(self.advantages_ph * tf.exp(self.logp - self.prev_logp))
      negloss += tf.reduce_mean(self.beta_ph * self.kl_divergence)
      negloss += tf.reduce_mean(self.ksi_ph * tf.square(tf.maximum(0.0, self.kl_divergence - 2 * self.kl_target)))
      self.ksi = 10

    :param old_log_probs:
    :param new_log_probs:
    :param kl_target:
    :param beta_max:
    :param beta_min:
    :return:
    '''
    # Approximate KL divergence between the old and new policy, computed from
    # the log probabilities passed in by inner_ppo_update
    # (torch.distributions.kl_divergence expects Distribution instances).
    kl_now = (old_log_probs - new_log_probs).mean()

    if kl_now > 4 * kl_target:
      return True

    if kl_now < kl_target / 1.5:
      self._current_kl_beta /= 2
    elif kl_now > kl_target * 1.5:
      self._current_kl_beta *= 2

    self._current_kl_beta = numpy.clip(self._current_kl_beta, beta_min, beta_max)

    return False
  # endregion

  # region Protected

  def _optimise(self, cost, **kwargs):
    self._optimiser.zero_grad()
    cost.backward()
    if self._max_grad_norm is not None:
      nn.utils.clip_grad_norm_(self._actor.parameters(), self._max_grad_norm)
      nn.utils.clip_grad_norm_(self._critic.parameters(), self._max_grad_norm)
    self._optimiser.step()

  def _sample_model(self, state, *args, **kwargs):
    '''
    Samples an action from a normal distribution whose mean and standard deviation
    come from the policy (actor) network, shape [batch, action_size].

    :param state:
    :type state:
    :param kwargs:
    :type kwargs:
    :return:
    :rtype:
    '''

    model_input = U.to_tensor(state, device=self._device, dtype=self._state_type)

    mean, std = self._actor(model_input)
    value_estimate = self._critic(model_input)

    distribution = torch.distributions.Normal(mean, std)

    with torch.no_grad():
      action = distribution.sample()

    action_log_prob = distribution.log_prob(action)

    return action.detach().to('cpu').numpy(), action_log_prob, value_estimate, distribution

  # endregion

  # region Public
  def take_n_steps(self,
                   initial_state,
                   environment,
                   n=100,
                   render=False,
                   render_frequency=100):
    state = initial_state

    accumulated_signal = 0

    transitions = []
    terminated = False
    T = tqdm(range(1, n + 1),
             f'Step #{self._step_i} - {0}/{n}',
             leave=False,
             disable=not render)
    for t in T:
      # T.set_description(f'Step #{self._step_i} - {t}/{n}')
      self._step_i += 1

      dist, value_estimates, *_ = self.sample_action(state)
      action = dist.sample()
      action_prob = dist.log_prob(action)

      next_state, signal, terminated, _ = environment.react(action)

      if render and self._rollout_i % render_frequency == 0:
        environment.render()

      successor_state = None
      if not terminated:  # If the environment terminated then there is no successor state
        successor_state = next_state

      transitions.append(ValuedTransition(state,
                                          action,
                                          action_prob,
                                          value_estimates,
                                          signal,
                                          successor_state,
                                          not terminated,
                                          ))

      state = next_state

      accumulated_signal += signal

      if terminated:
        state = environment.reset()
        self._rollout_i += 1

    return transitions, accumulated_signal, terminated, state
  def react(self):
    pass
  def back_trace_advantages(self, transitions):
    n_step_summary = ValuedTransition(*zip(*transitions))

    advantages = U.advantage_estimate(n_step_summary.signal,
                                      n_step_summary.non_terminal,
                                      n_step_summary.value_estimate,
                                      discount_factor=self._discount_factor,
                                      tau=self._gae_tau,
                                      device=self._device
                                      )

    value_estimates = U.to_tensor(n_step_summary.value_estimate, device=self._device)

    discounted_returns = value_estimates + advantages

    advantage_memories = []
    for i, step in enumerate(zip(*n_step_summary)):
      step = ValuedTransition(*step)
      advantage_memories.append(ValuedTransition(step.state,
                                                 step.action,
                                                 discounted_returns[i],
                                                 step.successor_state,
                                                 step.terminal,
                                                 step.action_prob,
                                                 advantages[i]
                                                 ))

    return advantage_memories
  def evaluate3(self, batch, discrete=False, **kwargs):
    # region Tensorise

    states = U.to_tensor(batch.state, device=self._device).view(-1, self._input_shape[0])

    value_estimates = U.to_tensor(batch.value_estimate, device=self._device)

    advantages = U.to_tensor(batch.advantage, device=self._device)

    discounted_returns = U.to_tensor(batch.discounted_return, device=self._device)

    action_probs_old = U.to_tensor(batch.action_prob, device=self._device).view(-1, self._output_shape[0])

    # endregion

    advantage = (advantages - advantages.mean()) / (advantages.std() + self._divide_by_zero_safety)

    _, action_probs_new, _, distribution = self._sample_model(states)

    if discrete:
      actions = U.to_tensor(batch.action, device=self._device).view(-1, self._output_shape[0])
      action_probs_old = action_probs_old.gather(1, actions)
      action_probs_new = action_probs_new.gather(1, actions)

    ratio = (action_probs_new - action_probs_old).exp()
    # Ratio of action probabilities under the new and the old policy.
    # Values in (0, 1) mean the action is less likely under the new policy,
    # while values > 1 mean it is more likely now (see the standalone
    # clipped-surrogate sketch after this class).

    surrogate = ratio * advantage

    clamped_ratio = torch.clamp(ratio,
                                min=1. - self._surrogate_clipping_value,
                                max=1. + self._surrogate_clipping_value)
    surrogate_clipped = clamped_ratio * advantage  # (L^CLIP)

    policy_loss = -torch.min(surrogate, surrogate_clipped).mean()
    entropy_loss = distribution.entropy().mean()
    # value_error = (value_estimates - discounted_returns).pow(2).mean()
    value_error = F.mse_loss(value_estimates, discounted_returns)

    collective_cost = policy_loss + value_error * self._value_reg_coef - entropy_loss * self._entropy_reg_coef

    return collective_cost, policy_loss, value_error
  def evaluate2(self,
                *,
                states,
                actions,
                log_probs,
                returns,
                advantage,
                **kwargs):
    action_out, action_log_prob, value_estimate, distribution = self._sample_model(states)

    old_log_probs = log_probs
    new_log_probs = distribution.log_prob(actions)

    ratio = (new_log_probs - old_log_probs).exp()
    surrogate = ratio * advantage
    surrogate_clipped = (torch.clamp(ratio,
                                     1.0 - self._surrogate_clipping_value,
                                     1.0 + self._surrogate_clipping_value) * advantage)

    actor_loss = -torch.min(surrogate, surrogate_clipped).mean()
    # critic_loss = (value_estimate - returns).pow(2).mean()
    critic_loss = F.mse_loss(value_estimate, returns)
    entropy = distribution.entropy().mean()

    # Entropy bonus encourages exploration, scaled by its regularisation coefficient
    loss = self._value_reg_coef * critic_loss + actor_loss - entropy * self._entropy_reg_coef

    return loss, new_log_probs, old_log_probs
  def update_targets(self, *args, **kwargs) -> None:
    self.update_target(target_model=self._target_actor,
                       source_model=self._actor,
                       target_update_tau=self._target_update_tau)

    self.update_target(target_model=self._target_critic,
                       source_model=self._critic,
                       target_update_tau=self._target_update_tau)
  def evaluate(self, batch, _last_value_estimate, discrete=False, **kwargs):
    # Generalised advantage estimation over the rollout (see the gae_sketch
    # illustration at the end of this module)
    returns_ = U.compute_gae(_last_value_estimate,
                             batch.signal,
                             batch.non_terminal,
                             batch.value_estimate,
                             discount_factor=self._discount_factor,
                             tau=self._gae_tau)

    returns = torch.cat(returns_).detach()
    log_probs = torch.cat(batch.action_prob).detach()
    values = torch.cat(batch.value_estimate).detach()
    states = torch.cat(batch.state).view(-1, self._input_shape[0])
    actions = to_tensor(batch.action).view(-1, self._output_shape[0])

    advantage = returns - values

    self.inner_ppo_update(states, actions, log_probs, returns, advantage)

    if self._step_i % self._update_target_interval == 0:
      self.update_targets()

    return returns, log_probs, values, states, actions, advantage
  def update_models(self, *, stat_writer=None, **kwargs) -> None:
    pass

    '''
    batch = U.AdvantageMemory(*zip(*self._memory_buffer.sample()))
    collective_cost, actor_loss, critic_loss = self.evaluate(batch)
    self._optimise_wrt(collective_cost)
    '''
  def inner_ppo_update(self,
                       states,
                       actions,
                       log_probs,
                       returns,
                       advantages):
    mini_batch_gen = self.ppo_mini_batch_iter(self._mini_batch_size,
                                              states,
                                              actions,
                                              log_probs,
                                              returns,
                                              advantages)
    for _ in range(self._ppo_epochs):
      try:
        batch = next(mini_batch_gen)
      except StopIteration:
        return

      loss, new_log_probs, old_log_probs = self.evaluate2(**batch.as_dict())

      self._actor_optimiser.zero_grad()
      self._critic_optimiser.zero_grad()
      loss.backward()
      self._actor_optimiser.step()
      self._critic_optimiser.step()

      if self._update_early_stopping:
        if self._update_early_stopping(old_log_probs, new_log_probs):
          break
  @staticmethod
  def ppo_mini_batch_iter(mini_batch_size: int,
                          states: Any,
                          actions: Any,
                          log_probs: Any,
                          returns: Any,
                          advantage: Any) -> iter:

    batch_size = actions.size(0)
    for _ in range(batch_size // mini_batch_size):
      rand_ids = numpy.random.randint(0, batch_size, mini_batch_size)
      yield NOD(states=states[rand_ids, :],
                actions=actions[rand_ids, :],
                log_probs=log_probs[rand_ids, :],
                returns=returns[rand_ids, :],
                advantage=advantage[rand_ids, :])
  # endregion
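
# region Illustration

# A minimal, self-contained sketch of the clipped surrogate objective (L^CLIP)
# that evaluate2/evaluate3 compute above, included for clarity only. It is not
# part of the agent API; the names clipped_surrogate_loss, new_log_probs,
# old_log_probs, advantages and clip_value are hypothetical.
def clipped_surrogate_loss(new_log_probs,
                           old_log_probs,
                           advantages,
                           clip_value=0.2):
  ratio = (new_log_probs - old_log_probs).exp()  # pi_new(a|s) / pi_old(a|s)
  surrogate = ratio * advantages
  surrogate_clipped = torch.clamp(ratio,
                                  1.0 - clip_value,
                                  1.0 + clip_value) * advantages
  # Pessimistic bound: element-wise minimum of the unclipped and clipped terms
  return -torch.min(surrogate, surrogate_clipped).mean()

# endregion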
# region Test


def ppo_test(rollouts=None, skip=True):
  import agent.configs.agent_test_configs.ppo_test_config as C

  if rollouts:
    C.ROLLOUTS = rollouts

  train_agent(PPOAgent,
              C,
              training_procedure=parallelised_training(training_procedure=batched_training,
                                                       auto_reset_on_terminal_state=True),
              parse_args=False,
              skip_confirmation=skip)
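
# A hedged usage sketch: the call below assumes the bundled test configuration
# in agent.configs.agent_test_configs.ppo_test_config is importable and a
# matching environment is available; the argument values are illustrative only.
#
#   ppo_test(rollouts=1000, skip=True)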
if __name__ == '__main__':
  ppo_test()

# endregion
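
# region Appendix

# An illustrative sketch of generalised advantage estimation (GAE), as
# U.compute_gae in evaluate() is assumed to behave; the actual implementation
# lives in agent.utilities and is not shown on this page. All names below
# (gae_sketch and its parameters) are hypothetical.
def gae_sketch(last_value_estimate,
               signals,
               non_terminals,
               value_estimates,
               discount_factor=0.99,
               tau=0.95):
  values = list(value_estimates) + [last_value_estimate]
  gae = 0
  returns = []
  for step in reversed(range(len(signals))):
    # One-step temporal-difference error
    delta = (signals[step]
             + discount_factor * values[step + 1] * non_terminals[step]
             - values[step])
    # Exponentially weighted accumulation of TD errors
    gae = delta + discount_factor * tau * non_terminals[step] * gae
    # compute_gae is assumed to return value targets (advantage + value estimate)
    returns.insert(0, gae + values[step])
  return returns

# endregion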