Policies¶
coax.Policy – A parametrized policy \(\pi_\theta(a|s)\).
coax.EpsilonGreedy – Create an \(\epsilon\)-greedy policy, given a q-function.
coax.BoltzmannPolicy – Derive a Boltzmann policy from a q-function.
coax.RandomPolicy – A simple random policy.
There are generally two distinct ways of constructing a policy \(\pi(a|s)\). One method uses a function approximator to parametrize a state-action value function \(q_\theta(s,a)\) and then derives a policy from this q-function. The other method uses a function approximator to parametrize the policy directly, i.e. \(\pi(a|s)=\pi_\theta(a|s)\). The methods are called value-based methods and policy gradient methods, respectively.
A policy in coax is a function that maps state observations to actions. The example below shows how to use a policy in a simple episode roll-out.
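import gymnasium

env = gymnasium.make(...)
s, info = env.reset()  # gymnasium returns (observation, info)
for t in range(max_episode_steps):
    a = pi(s)
    # gymnasium splits the old 'done' flag into terminated and truncated
    s_next, r, terminated, truncated, info = env.step(a)
    if terminated or truncated:
        break
    s = s_next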
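To make the roll-out loop concrete without installing any environment, here is a minimal, self-contained sketch. `ToyEnv` and the random `pi` are hypothetical stand-ins (they are not part of coax or gymnasium); the sketch only assumes the gymnasium convention that `reset()` returns `(observation, info)` and `step()` returns five values.

```python
import random


class ToyEnv:
    """Hypothetical stand-in that mimics the gymnasium reset/step signatures."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0, {}  # (observation, info)

    def step(self, a):
        self.t += 1
        s_next = float(self.t)
        r = 1.0 if a == 1 else 0.0          # reward 1 for action 1, else 0
        terminated = self.t >= self.horizon  # episode ends after `horizon` steps
        truncated = False
        return s_next, r, terminated, truncated, {}


rng = random.Random(0)


def pi(s):
    """Hypothetical policy: ignores s and picks action 0 or 1 uniformly."""
    return rng.randrange(2)


env = ToyEnv(horizon=5)
s, info = env.reset()
episode_return, num_steps = 0.0, 0
for t in range(100):  # plays the role of max_episode_steps
    a = pi(s)
    s_next, r, terminated, truncated, info = env.step(a)
    episode_return += r
    num_steps += 1
    if terminated or truncated:
        break
    s = s_next
```

Any callable that maps an observation to an action can take the place of `pi` here, which is the property the roll-out loop relies on.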
Some algorithms require us to collect the log-propensities along with the sampled actions. For this reason, policies have the optional return_logp flag:
a, logp = pi(s, return_logp=True)
The log-propensity represents \(\log\pi(a|s)\). For discrete actions this is a non-positive real number: a stochastic policy returns logp<0, whereas a deterministic policy returns logp=0.
As an aside, we note that coax policies have two more methods:
a = pi.mode(s) # same as pi(s), except 'sampling' greedily
dist_params = pi.dist_params(s) # distribution parameters, conditioned on s
print(dist_params) # in this example: categorical dist with n=3
# {'logits': array([-0.5711, 1.0513 , 0.0012])}
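As a rough illustration of how such distribution parameters relate to actions, the sketch below turns the printed logits into probabilities and derives a greedy action and a log-propensity from them. It assumes that the logits parametrize a categorical distribution through a softmax; the helper `softmax` is written here for illustration and is not a coax function.

```python
import math
import random

logits = [-0.5711, 1.0513, 0.0012]  # values copied from the printout above


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


probs = softmax(logits)

# greedy action: the index with the largest probability
mode = max(range(len(probs)), key=probs.__getitem__)

# sampled action and its log-propensity log pi(a|s)
rng = random.Random(0)
a = rng.choices(range(len(probs)), weights=probs)[0]
logp = math.log(probs[a])
```

Because every probability is at most 1, the resulting `logp` can never be positive in this discrete setting.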
Random policy¶
Before we discuss value-based policies and parametrized policies, let’s discuss the simplest possible policy first, namely coax.RandomPolicy. This policy doesn’t require any function approximator. It simply calls env.action_space.sample(). This policy may be useful for creating simple benchmarks.
pi = coax.RandomPolicy(env)
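For intuition, here is a minimal sketch of what a uniform random policy over a discrete action space amounts to. `UniformDiscretePolicy` is a hypothetical stand-in and not the implementation of coax.RandomPolicy; the only fact it relies on is that a uniform distribution over n actions assigns log-probability \(-\log n\) to each action.

```python
import math
import random


class UniformDiscretePolicy:
    """Hypothetical stand-in for illustration; not coax.RandomPolicy itself."""

    def __init__(self, n, random_seed=None):
        self.n = n
        self.rng = random.Random(random_seed)

    def __call__(self, s, return_logp=False):
        a = self.rng.randrange(self.n)  # the observation s is ignored
        logp = -math.log(self.n)        # log(1/n), identical for every action
        return (a, logp) if return_logp else a


pi_random = UniformDiscretePolicy(n=3, random_seed=13)
a, logp = pi_random(None, return_logp=True)
```

A baseline like this gives a lower bound on performance that any learned policy should comfortably beat.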
Value-based policies¶
Value-based policies are defined indirectly, via a q-function. Examples of value-based policies are coax.EpsilonGreedy (see example below) and coax.BoltzmannPolicy.
pi = coax.EpsilonGreedy(q, epsilon=0.1)
pi = coax.BoltzmannPolicy(q, temperature=0.02)
Note that the hyperparameters epsilon and temperature may be updated at any time, e.g.
pi.epsilon *= 0.99  # at the start of each episode
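To see what a multiplicative update like this does over many episodes, the following sketch tracks the value of epsilon on its own, without any policy object. The lower bound `epsilon_min` is an assumption added for illustration (the text above does not mention one); it keeps a small amount of exploration once the decay has run its course.

```python
epsilon_start = 0.1
decay = 0.99
epsilon_min = 0.01  # assumed floor, not part of the original example

epsilon = epsilon_start
history = []
for episode in range(300):
    history.append(epsilon)  # value in effect during this episode
    # decay once per episode, but never drop below the floor
    epsilon = max(epsilon_min, epsilon * decay)
```

Without the floor, epsilon after n episodes is simply `epsilon_start * decay ** n`, so the decay factor directly sets how quickly exploration fades.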
Parametrized policies¶
Now that we’ve discussed value-based policies, let’s start our discussion of parametrized (learnable) policies. We provide three examples:
Discrete actions (categorical dist)
Continuous actions (normal dist)
Composite actions
Discrete actions
A common action space is Discrete. As an example, we’ll take the CartPole environment. To get started, let’s generate some example data so that we know the correct input/output format for our forward-pass function.
env = gymnasium.make('CartPole-v0')
data = coax.Policy.example_data(env)
print(data)
# ExampleData(
#   inputs=Inputs(
#     args=ArgsType2(
#       S=array(shape=(1, 4), dtype=float32)
#       is_training=True)
#     static_argnums=(1,))
#   output={
#     'logits': array(shape=(1, 2), dtype=float32)})
Now, our task is to write a Haiku-style forward-pass function that generates this output given the input. To be clear, our task is not to recreate the exact values; the example data is only there to give us an idea of the structure (shapes, dtypes, etc.).
import haiku as hk
import jax
import jax.numpy as jnp

def func(S, is_training):
    logits = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        # zero-initialized final layer: all actions start out equally likely
        hk.Linear(env.action_space.n, w_init=jnp.zeros)
    ))
    return {'logits': logits(S)}

pi = coax.Policy(func, env)
# example usage
s = env.observation_space.sample()
a = pi(s)
print(a) # 0 or 1
If something goes wrong and you’d like to debug the forward-pass function, here’s an example of what coax.Policy.__init__ runs under the hood:
rngs = hk.PRNGSequence(42)
transformed = hk.transform_with_state(func)
params, function_state = transformed.init(next(rngs), *data.inputs.args)
output, function_state = transformed.apply(params, function_state, next(rngs), *data.inputs.args)
Continuous actions
Besides discrete actions, we might wish to build an agent compatible with continuous actions. Here’s an example of how to create a valid policy function approximator for the Pendulum environment:
import coax
import jax
import jax.numpy as jnp
import haiku as hk
from math import prod

# env is assumed to be a Pendulum environment that was created beforehand

def func(S, is_training):
    # layers shared by the mu and logvar heads
    shared = hk.Sequential((
        hk.Flatten(),
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
    ))
    mu = hk.Sequential((
        shared,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    logvar = hk.Sequential((
        shared,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(prod(env.action_space.shape), w_init=jnp.zeros),
        hk.Reshape(env.action_space.shape),
    ))
    return {'mu': mu(S), 'logvar': logvar(S)}

pi = coax.Policy(func, env)
# example usage
s = env.observation_space.sample()
a = pi(s)
print(a)
# array([0.39267802], dtype=float32)
Note that if you’re ever unsure what the correct input / output format is, you can always generate some example data using the coax.Policy.example_data() helper (see example above).
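To illustrate how a pair of outputs named mu and logvar can be turned into an action, here is a minimal sketch using only the standard library. It assumes that the two outputs parametrize an independent normal distribution per action dimension, with logvar being the log of the variance; `sample_normal` is a hypothetical helper, and any clipping or other post-processing that coax may apply is not covered.

```python
import math
import random


def sample_normal(mu, logvar, rng):
    """Sample a ~ N(mu, exp(logvar)) per dimension and return (a, logp)."""
    a, logp = [], 0.0
    for m, lv in zip(mu, logvar):
        sigma = math.exp(0.5 * lv)  # standard deviation from log-variance
        x = rng.gauss(m, sigma)
        a.append(x)
        # log-density of a univariate normal, summed over independent dims
        logp += -0.5 * (math.log(2 * math.pi) + lv + (x - m) ** 2 / math.exp(lv))
    return a, logp


rng = random.Random(0)
mu, logvar = [0.0], [0.0]  # e.g. what zero-initialized output layers would give
a, logp = sample_normal(mu, logvar, rng)
```

Note that for continuous actions the value returned as `logp` is a log-density, so it is not restricted to non-positive values when the variance is small.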
Composite actions
The coax package supports all action spaces that are supported by the gymnasium.spaces API.
To illustrate the flexibility of the coax framework, here’s an example of a composite action space:
from collections import namedtuple
from gymnasium.spaces import Dict, Tuple, Box, Discrete, MultiDiscrete

DummyEnv = namedtuple('DummyEnv', ('observation_space', 'action_space'))
env = DummyEnv(
    Box(low=0, high=1, shape=(7,)),
    Dict({
        'foo': MultiDiscrete([4, 5]),
        'bar': Tuple((Box(low=0, high=1, shape=(2, 3)),)),
    }))

data = coax.Policy.example_data(env)
print(data.output)
# {'foo': ({'logits': DeviceArray([[-1.29, 0.34, 1.57, 1.88]], dtype=float32)},
# {'logits': DeviceArray([[-0.11, -0.35, -0.57, 2.51, 1.78]], dtype=float32)}),
# 'bar': ({'logvar': DeviceArray([[[-0.11, 1.23, 0.12],
# [-0.35, 0.46, 0.73]]], dtype=float32),
# 'mu': DeviceArray([[[-0.35, -0.37, -0.67],
# [-0.44, -0.71, 0.45]]], dtype=float32)},)}
Thus, if we ensure that our forward-pass function outputs this format, we can sample actions in precisely the same way as we’ve done before. For example, here’s a compatible forward-pass function:
def func(S, is_training):
    return {
        'foo': ({'logits': hk.Linear(4)(S)},
                {'logits': hk.Linear(5)(S)}),
        'bar': ({'mu': hk.Linear(6)(S).reshape(-1, 2, 3),
                 'logvar': hk.Linear(6)(S).reshape(-1, 2, 3)},),
    }

pi = coax.Policy(func, env)
# example usage:
s = env.observation_space.sample()
a, logp = pi(s, return_logp=True)
assert a in env.action_space
print(logp) # -8.647176
print(a)
# {'foo': array([2, 4]),
# 'bar': (array([[0.18, 0.57, 0.38],
# [0.81, 0.21, 0.67]], dtype=float32),)}
Object Reference¶
- class coax.Policy(func, env, observation_preprocessor=None, proba_dist=None, random_seed=None)[source]¶
A parametrized policy \(\pi_\theta(a|s)\).
- Parameters:
func (function) – A Haiku-style function that specifies the forward pass.
env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.
observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).
proba_dist (ProbaDist, optional) – A probability distribution that is used to interpret the output of func. Check out the coax.proba_dists module for available options. If left unspecified, this defaults to:
proba_dist = coax.proba_dists.ProbaDist(action_space)
random_seed (int, optional) – Seed for pseudo-random number generators.
- __call__(s, return_logp=False)[source]¶
Sample an action \(a\sim\pi_\theta(.|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi(a|s)\).
- Returns:
a (action) – A single action \(a\).
logp (float, optional) – The log-propensity \(\log\pi_\theta(a|s)\). This is only returned if we set return_logp=True.
- copy(deep=False)¶
Create a copy of the current instance.
- Parameters:
deep (bool, optional) – Whether the copy should be a deep copy.
- Returns:
copy – A copy of the current instance (a deep copy if deep=True).
- dist_params(s)[source]¶
Get the conditional distribution parameters of \(\pi_\theta(.|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
dist_params (Params) – The distribution parameters of \(\pi_\theta(.|s)\).
- classmethod example_data(env, observation_preprocessor=None, proba_dist=None, batch_size=1, random_seed=None)[source]¶
A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.
- mean(s)[source]¶
Get the mean of the distribution \(\pi_\theta(.|s)\).
Note that if the actions are discrete, this returns the mode instead.
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
a (action) – A single action \(a\).
- mode(s)[source]¶
Sample a greedy action \(a=\arg\max_a\pi_\theta(a|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
a (action) – A single action \(a\).
- soft_update(other, tau)¶
Synchronize the current instance with other through exponential smoothing:
\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]
- Parameters:
other – A separate copy of the current object. This object will hold the new parameters \(\theta_\text{new}\).
tau (float between 0 and 1, optional) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.
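The smoothing rule above can be sketched in a few lines of plain Python. This is only an illustration of the formula, assuming parameters stored as a dict of lists of floats; the free function `soft_update` below is hypothetical and does not reflect how coax stores or updates its parameters internally.

```python
def soft_update(theta, theta_new, tau):
    """Return theta + tau * (theta_new - theta), applied entry by entry."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must lie between 0 and 1")
    return {
        k: [w + tau * (w_new - w) for w, w_new in zip(theta[k], theta_new[k])]
        for k in theta
    }


theta = {'w': [0.0, 2.0], 'b': [1.0]}
theta_new = {'w': [1.0, 0.0], 'b': [3.0]}

smooth = soft_update(theta, theta_new, tau=0.5)  # halfway between old and new
hard = soft_update(theta, theta_new, tau=1.0)    # copies the new parameters
```

With tau=1 the old parameters are discarded entirely, whereas a small tau makes the current instance trail slowly behind the other one.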
- property function¶
The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:
output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
- property function_state¶
The state of the function approximator, see haiku.transform_with_state().
- property mean_func¶
The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)
- property mode_func¶
The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)
- property params¶
The parameters (weights) of the function approximator.
- property sample_func¶
The function that is used for sampling randomly from the underlying proba_dist, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)
- class coax.EpsilonGreedy(q, epsilon=0.1)[source]¶
Create an \(\epsilon\)-greedy policy, given a q-function.
This policy samples actions \(a\sim\pi_q(.|s)\) according to the following rule:
\[\begin{split}u &\sim \text{Uniform}([0, 1]) \\ a_\text{rand} &\sim \text{Uniform}(\text{actions}) \\ a\ &=\ \left\{\begin{matrix} a_\text{rand} & \text{ if } u < \epsilon \\ \arg\max_{a'} q(s,a') & \text{ otherwise } \end{matrix}\right.\end{split}\]
- Parameters:
q (Q) – A state-action value function.
epsilon (float between 0 and 1, optional) – The probability of sampling an action uniformly at random (as opposed to sampling greedily).
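The sampling rule above translates almost line by line into code. The sketch below operates on a plain list of q-values for a single state; the function `epsilon_greedy` is a hypothetical illustration and not the implementation behind coax.EpsilonGreedy.

```python
import random


def epsilon_greedy(q_values, epsilon, rng):
    u = rng.random()                       # u ~ Uniform([0, 1))
    a_rand = rng.randrange(len(q_values))  # a_rand ~ Uniform(actions)
    a_greedy = max(range(len(q_values)), key=q_values.__getitem__)
    return a_rand if u < epsilon else a_greedy


rng = random.Random(0)
q_values = [0.1, 0.7, 0.2]

# epsilon=0 always picks the greedy action, epsilon=1 always picks at random
greedy_only = [epsilon_greedy(q_values, 0.0, rng) for _ in range(100)]
explore = [epsilon_greedy(q_values, 1.0, rng) for _ in range(1000)]
```

Note that the random branch can also return the greedy action, so the greedy action is chosen with probability \(1-\epsilon+\epsilon/n\) for n actions.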
- __call__(s, return_logp=False)¶
Sample an action \(a\sim\pi_q(.|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi_q(a|s)\).
- Returns:
a (action) – A single action \(a\).
logp (float, optional) – The log-propensity \(\log\pi_q(a|s)\). This is only returned if we set return_logp=True.
- dist_params(s)¶
Get the conditional distribution parameters of \(\pi_q(.|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
dist_params (Params) – The distribution parameters of \(\pi_q(.|s)\).
- mean(s)¶
Get the mean of the distribution \(\pi_q(.|s)\).
Note that if the actions are discrete, this returns the mode instead.
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
a (action) – A single action \(a\).
- mode(s)¶
Sample a greedy action \(a=\arg\max_a\pi_q(a|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
a (action) – A single action \(a\).
- property function¶
The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:
output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
- property function_state¶
The state of the function approximator, see haiku.transform_with_state().
- property mean_func¶
The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)
- property mode_func¶
The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)
- property params¶
The parameters (weights) of the function approximator.
- property sample_func¶
The function that is used for sampling randomly from the underlying proba_dist, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)
- class coax.BoltzmannPolicy(q, temperature=0.02)[source]¶
Derive a Boltzmann policy from a q-function.
This policy samples actions \(a\sim\pi_q(.|s)\) according to the following rule:
\[\begin{split}p &= \text{softmax}(q(s,.) / \tau) \\ a &\sim \text{Cat}(p)\end{split}\]
Note that this policy is only well-defined for discrete action spaces. Also, it’s worth noting that if the q-function has a non-trivial value transform \(f(.)\) (e.g. coax.value_transforms.LogTransform), we feed in the transformed estimate as our logits, i.e.
\[p = \text{softmax}(f(q(s,.)) / \tau)\]
- Parameters:
q (Q) – A state-action value function.
temperature (positive float, optional) – The Boltzmann temperature \(\tau>0\) sets the sharpness of the categorical distribution. Picking a small value for \(\tau\) results in nearly greedy sampling, while large values result in nearly uniform sampling.
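The effect of the temperature is easy to check numerically. The sketch below computes \(\text{softmax}(q(s,.)/\tau)\) for a fixed list of q-values at two temperatures; `boltzmann_probs` is a hypothetical helper written for illustration, and the q-values are made up.

```python
import math


def boltzmann_probs(q_values, temperature):
    """Return softmax(q / temperature) as a list of probabilities."""
    logits = [q / temperature for q in q_values]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


q_values = [0.1, 0.7, 0.2]  # made-up values for a single state

sharp = boltzmann_probs(q_values, temperature=0.02)   # close to greedy
flat = boltzmann_probs(q_values, temperature=100.0)   # close to uniform
```

At the low temperature almost all probability mass sits on the action with the largest q-value, while at the high temperature the three probabilities differ only slightly from 1/3.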
- __call__(s, return_logp=False)¶
Sample an action \(a\sim\pi_q(.|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi_q(a|s)\).
- Returns:
a (action) – A single action \(a\).
logp (float, optional) – The log-propensity \(\log\pi_q(a|s)\). This is only returned if we set return_logp=True.
- dist_params(s)¶
Get the conditional distribution parameters of \(\pi_q(.|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
dist_params (Params) – The distribution parameters of \(\pi_q(.|s)\).
- mean(s)¶
Get the mean of the distribution \(\pi_q(.|s)\).
Note that if the actions are discrete, this returns the mode instead.
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
a (action) – A single action \(a\).
- mode(s)¶
Sample a greedy action \(a=\arg\max_a\pi_q(a|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
a (action) – A single action \(a\).
- property function¶
The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:
output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
- property function_state¶
The state of the function approximator, see haiku.transform_with_state().
- property mean_func¶
The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)
- property mode_func¶
The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)
- property params¶
The parameters (weights) of the function approximator.
- property sample_func¶
The function that is used for sampling randomly from the underlying proba_dist, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)
- class coax.RandomPolicy(env, random_seed=None)[source]¶
A simple random policy.
- Parameters:
env (gymnasium.Env) – The gymnasium-style environment. This is only used to get the env.action_space.
random_seed (int, optional) – Sets the random state to get reproducible results.
- __call__(s, return_logp=False)[source]¶
Sample an action \(a\sim\pi_\theta(.|s)\).
- Parameters:
s (state observation) – A single state observation \(s\).
return_logp (bool, optional) – Whether to return the log-propensity \(\log\pi(a|s)\).
- Returns:
a (action) – A single action \(a\).
logp (float, optional) – The log-propensity \(\log\pi_\theta(a|s)\). This is only returned if we set return_logp=True.