Deep Deterministic Policy Gradients (DDPG)
The Deep Deterministic Policy Gradients (DDPG) algorithm is a little different from other policy-objective methods: it learns the policy directly from a (type-I) q-function. The objective that is maximized is:

\[ J(\theta; s)\ =\ q_\varphi\bigl(s, a_\theta(s)\bigr) \]

Here \(a_\theta(s)\) is the mode of the underlying conditional probability distribution
\(\pi_\theta(.|s)\); see e.g. the mode method of coax.proba_dists.NormalDist.
In other words, we evaluate the policy according to the current estimate of its best-case
performance. This is implemented by the coax.policy_objectives.DeterministicPG updater class.
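To make this concrete, here is a minimal, self-contained JAX sketch of the deterministic policy gradient \(\nabla_\theta\, q_\varphi(s, a_\theta(s))\) for a toy linear policy and quadratic critic; the function names and shapes are illustrative assumptions, not part of the coax API.

import jax
import jax.numpy as jnp

def policy_mode(theta, s):
    # a_theta(s): mode of a toy linear-Gaussian policy (illustrative only)
    return jnp.dot(theta, s)

def q_func(varphi, s, a):
    # q_varphi(s, a): toy quadratic critic (illustrative only)
    w_s, w_a = varphi
    return jnp.dot(w_s, s) - jnp.sum((a - jnp.dot(w_a, s)) ** 2)

def ddpg_objective(theta, varphi, s):
    # J(theta; s) = q_varphi(s, a_theta(s))
    return q_func(varphi, s, policy_mode(theta, s))

s = jnp.array([0.1, -0.3, 0.5])
theta = jnp.zeros((2, 3))
varphi = (jnp.ones(3), jnp.zeros((2, 3)))

# gradient-ascent direction on the critic's estimate of best-case performance
grad_theta = jax.grad(ddpg_objective)(theta, varphi, s)

Conceptually, this is the gradient that the DeterministicPG updater ascends, in batched form.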
Since the policy objective uses a q-function \(q_\varphi(s,a)\), we also need to learn that. At the moment of writing, there are two ways to learn \(q_\varphi(s,a)\) in coax.
Option 1: SARSA.
The first option is to use SARSA updates, whose \(n\)-step bootstrapped target is constructed as:

\[ G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!\left(S_{t+n}, A_{t+n}\right) \]

where \(A_{t+n}\) is sampled from experience and

\[ R^{(n)}_t\ =\ \sum_{k=0}^{n-1}\gamma^k R_{t+k}\ ,\qquad I^{(n)}_t\ =\ \begin{cases} 0 & \text{if } S_{t+n} \text{ is a terminal state} \\ \gamma^n & \text{otherwise} \end{cases} \]

This is implemented by the coax.td_learning.Sarsa updater class.
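If you prefer SARSA updates in the full example at the bottom of this page, the critic updater could be constructed roughly as follows, reusing the q, q_targ and optax names defined in that stub; this is a sketch, not a verbatim part of the example:

# hypothetical drop-in replacement for the q-learning updater in the stub below:
# SARSA bootstraps from the action A_{t+n} that was actually taken
sarsa = coax.td_learning.Sarsa(q, q_targ, optimizer=optax.adam(0.002))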
Option 2: Q-Learning.
The second option is to use q-learning updates, whose \(n\)-step bootstrapped target is instead constructed as:

\[ G^{(n)}_t\ =\ R^{(n)}_t + I^{(n)}_t\,q_{\varphi_\text{targ}}\!\left(S_{t+n}, a_{\theta_\text{targ}}\!\left(S_{t+n}\right)\right) \]

Here, \(a_{\theta_\text{targ}}\!(s)\) is the mode introduced above, evaluated on the
target-model weights \(\theta_\text{targ}\). The reason why we call this q-learning is that we
construct the TD-target as though the next action \(A_{t+n}\) had been the greedy action.
This is implemented by the coax.td_learning.QLearningMode updater class.
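To see the difference between the two targets at a glance, here is a tiny self-contained sketch for a single one-step (\(n=1\)) transition; the toy critic and policy below are placeholder functions, not coax objects:

import jax.numpy as jnp

# stand-in critic and (target) policy mode, for illustration only
q_targ = lambda s, a: jnp.dot(s, jnp.ones_like(s)) - jnp.sum(a ** 2)
pi_targ_mode = lambda s: 0.1 * s   # a_{theta_targ}(s)

r, gamma = 1.0, 0.9
s_next = jnp.array([0.2, -0.4])
a_next = jnp.array([0.3, 0.1])     # action actually taken (from experience)

G_sarsa = r + gamma * q_targ(s_next, a_next)                    # Option 1: SARSA
G_qlearning = r + gamma * q_targ(s_next, pi_targ_mode(s_next))  # Option 2: q-learning
# if s_next were terminal, the bootstrap term would be dropped (I_t = 0)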
For more details, have a look at the Spinning Up page on DDPG here, which includes links to the original papers.
import gymnasium
import coax
import optax
import haiku as hk
import jax.numpy as jnp
# pick environment
env = gymnasium.make(...)
env = coax.wrappers.TrainMonitor(env)
def func_pi(S, is_training):
    # custom haiku function (for continuous actions in this example)
    mu = hk.Sequential([...])(S)  # mu.shape: (batch_size, *action_space.shape)
    return {'mu': mu, 'logvar': jnp.full_like(mu, -10)}  # deterministic policy
def func_q(S, A, is_training):
    # custom haiku function; a type-I q-function uses both the state S and the action A
    value = hk.Sequential([...])
    X = jnp.concatenate((S, A), axis=-1)  # combine state and action inputs
    return value(X)  # output shape: (batch_size,)
# define function approximator
pi = coax.Policy(func_pi, env)
q = coax.Q(func_q, env, action_preprocessor=pi.proba_dist.preprocess_variate)
# target networks
pi_targ = pi.copy()
q_targ = q.copy()
# specify how to update policy and value function
determ_pg = coax.policy_objectives.DeterministicPG(pi, q, optimizer=optax.adam(0.001))
qlearning = coax.td_learning.QLearningMode(q, pi_targ, q_targ, optimizer=optax.adam(0.002))
# specify how to trace the transitions
tracer = coax.reward_tracing.NStep(n=1, gamma=0.9)
buffer = coax.experience_replay.SimpleReplayBuffer(capacity=1000000)
# action noise
noise = coax.utils.OrnsteinUhlenbeckNoise(mu=0., sigma=0.2, theta=0.15)
for ep in range(100):
    s, info = env.reset()
    noise.reset()
    noise.sigma *= 0.99  # slowly decrease noise scale

    for t in range(env.spec.max_episode_steps):
        a = noise(pi(s))
        s_next, r, done, truncated, info = env.step(a)

        # add transition to buffer
        tracer.add(s, a, r, done)
        while tracer:
            buffer.add(tracer.pop())

        # update
        transition_batch = buffer.sample(batch_size=32)
        metrics_q = qlearning.update(transition_batch)
        metrics_pi = determ_pg.update(transition_batch)
        env.record_metrics(metrics_q)
        env.record_metrics(metrics_pi)

        # periodically sync target models
        if ep % 10 == 0:
            pi_targ.soft_update(pi, tau=1.0)
            q_targ.soft_update(q, tau=1.0)

        if done or truncated:
            break

        s = s_next