Value Functions
- coax.V: A state value function \(v_\theta(s)\).
- coax.Q: A state-action value function \(q_\theta(s,a)\).
- coax.StochasticV: A state-value function \(v(s)\), represented by a stochastic function \(\mathbb{P}_\theta(G_t|S_t=s)\).
- coax.StochasticQ: A q-function \(q(s,a)\), represented by a stochastic function \(\mathbb{P}_\theta(G_t|S_t=s,A_t=a)\).
- coax.SuccessorStateQ: A state-action value function \(q(s,a)=r(s,a)+\gamma\mathop{\mathbb{E}}_{s'\sim p(.|s,a)}v(s')\).
There are two kinds of value functions: state value functions \(v(s)\) and state-action value functions (or q-functions) \(q(s,a)\). The state value function evaluates the expected (discounted) return, defined as:
\[v(s)\ =\ \mathbb{E}_t\left\{G_t\,|\,S_t=s\right\}\]
where \(G_t=R_t+\gamma R_{t+1}+\gamma^2R_{t+2}+\dots\) is the discounted return. The operator \(\mathbb{E}_t\) takes the expectation value over all transitions (indexed by \(t\)). The \(v(s)\) function is implemented by the coax.V class. The state-action value is defined in a similar way:
\[q(s,a)\ =\ \mathbb{E}_t\left\{G_t\,|\,S_t=s,A_t=a\right\}\]
This is implemented by the coax.Q class.
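To make the definition of \(v(s)\) concrete, here is a minimal tabular sketch in plain Python (not coax code): it computes discounted returns backwards via \(G_t=R_t+\gamma G_{t+1}\) and estimates \(v(s)\) by averaging the returns observed after every visit to \(s\) (every-visit Monte Carlo). The helper names and toy episodes are hypothetical, and treating \(R_t\) as the reward received after taking \(A_t\) in \(S_t\) is an assumption.

```python
from collections import defaultdict


def discounted_returns(rewards, gamma):
    """Compute G_t = R_t + gamma * G_{t+1} for every step of one episode."""
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


def monte_carlo_v(episodes, gamma):
    """Estimate v(s) by averaging the returns observed after each visit to s.

    `episodes` is a list of episodes, each a list of (state, reward) pairs.
    """
    totals = defaultdict(float)
    counts = defaultdict(int)
    for episode in episodes:
        states = [s for s, _ in episode]
        rewards = [r for _, r in episode]
        for s, g in zip(states, discounted_returns(rewards, gamma)):
            totals[s] += g
            counts[s] += 1
    return {s: totals[s] / counts[s] for s in totals}


# two hypothetical episodes over toy states "A", "B", "C"
episodes = [
    [("A", 1.0), ("B", 0.0), ("C", 2.0)],
    [("A", 0.0), ("C", 1.0)],
]
print(discounted_returns([1.0, 0.0, 2.0], gamma=0.5))  # [1.5, 1.0, 2.0]
print(monte_carlo_v(episodes, gamma=0.5))  # {'A': 1.0, 'B': 1.0, 'C': 1.5}
```

A function approximator such as coax.V replaces the lookup table with a parametrized model, but the quantity being estimated is the same expected return.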
v(s)
In this example we see how to construct a valid state value function \(v(s)\). We’ll start by creating some example data, which allows us to inspect the correct input/output format.
import coax
import gymnasium
env = gymnasium.make('CartPole-v0')
data = coax.V.example_data(env)
print(data)
# ExampleData(
# inputs=Inputs(
# args=ArgsType2(
# S=array(shape=(1, 4), dtype=float32)
# is_training=True)
# static_argnums=(1,))
# output=array(shape=(1,), dtype=float32))
From this we may define our Haiku-style forward-pass function:
import jax
import jax.numpy as jnp
import haiku as hk
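def func(S, is_training):
    # three hidden layers of 8 units; the zero-initialized output layer makes v(s) start at 0.0
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel  # ravel: (batch, 1) -> (batch,)
    ))
    return seq(S)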
v = coax.V(func, env)
# example usage
s = env.observation_space.sample()
print(v(s)) # 0.0
q(s, a)
In this example we see how to construct a valid state-action value function \(q(s,a)\). Let’s create some example data again.
import coax
import gymnasium
env = gymnasium.make('CartPole-v0')
data = coax.Q.example_data(env)
print(data.type1)
# ExampleData(
# inputs=Inputs(
# args=ArgsType1(
# S=array(shape=(1, 4), dtype=float32)
# A=array(shape=(1, 2), dtype=float32)
# is_training=True)
# static_argnums=(2,))
# output=array(shape=(1,), dtype=float32))
print(data.type2)
# ExampleData(
# inputs=Inputs(
# args=ArgsType2(
# S=array(shape=(1, 4), dtype=float32)
# is_training=True)
# static_argnums=(1,))
# output=array(shape=(1, 2), dtype=float32))
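Note that there are two ways of modeling a q-function:
\[\begin{aligned}(s,a) &\mapsto q(s,a)\in\mathbb{R} &\qquad& (\text{type 1}) \\ s &\mapsto q(s,.)\in\mathbb{R}^n &\qquad& (\text{type 2})\end{aligned}\]
where \(n\) is the number of discrete actions. Type-2 q-functions are only well-defined for discrete action spaces, whereas type-1 q-functions may be defined for any action space.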
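For discrete actions the two types carry the same information. The following plain-Python sketch (hypothetical helper names, not part of coax) recovers \(q(s,a)\) from a type-2 output by a dot product with a one-hot encoded action, and builds \(q(s,.)\) from a type-1 function by evaluating every action. The second direction has to enumerate the action space, which is why type-2 models only make sense for discrete actions.

```python
def one_hot(a, n):
    """Encode a discrete action a in {0, ..., n-1} as a one-hot vector."""
    return [1.0 if i == a else 0.0 for i in range(n)]


def type2_to_type1(q_s, a):
    """Recover q(s, a) from a type-2 output q(s, .) via a dot product with one_hot(a)."""
    n = len(q_s)
    return sum(q * x for q, x in zip(q_s, one_hot(a, n)))


def type1_to_type2(q_sa_func, s, n):
    """Build q(s, .) from a type-1 function by evaluating each of the n discrete actions."""
    return [q_sa_func(s, a) for a in range(n)]


# hypothetical type-1 q-function on a scalar state and 2 discrete actions
def toy_q(s, a):
    return float(s) * (a + 1)


q_s = type1_to_type2(toy_q, s=0.5, n=2)
print(q_s)                       # [0.5, 1.0]
print(type2_to_type1(q_s, a=1))  # 1.0
```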
Let’s first define our type-1 forward-pass function:
import jax
import jax.numpy as jnp
import haiku as hk
def func_type1(S, A, is_training):
    """ (s,a) -> q(s,a) """
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(1, w_init=jnp.zeros), jnp.ravel
    ))
    X = jnp.concatenate((S, A), axis=-1)  # feed state and (preprocessed) action jointly
    return seq(X)
q = coax.Q(func_type1, env)
# example usage
s = env.observation_space.sample()
a = env.action_space.sample()
print(q(s, a)) # 0.0
print(q(s)) # array([0., 0.])
Alternatively, a type-2 forward-pass function might be:
def func_type2(S, is_training):
    """ s -> q(s,.) """
    seq = hk.Sequential((
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(8), jax.nn.relu,
        hk.Linear(env.action_space.n, w_init=jnp.zeros)  # one output per discrete action
    ))
    return seq(S)
q = coax.Q(func_type2, env)
# example usage
s = env.observation_space.sample()
a = env.action_space.sample()
print(q(s, a)) # 0.0
print(q(s)) # array([0., 0.])
If something goes wrong and you’d like to debug the forward-pass function, here’s an example of what coax.Q.__init__ runs under the hood:
rngs = hk.PRNGSequence(42)
transformed = hk.transform_with_state(func_type2)
# initialize the parameters and function state from the example inputs
params, function_state = transformed.init(next(rngs), *data.type2.inputs.args)
# run one forward pass on the same example inputs
output, function_state = transformed.apply(params, function_state, next(rngs), *data.type2.inputs.args)
Object Reference
- class coax.V(func, env, observation_preprocessor=None, value_transform=None, random_seed=None)
A state value function \(v_\theta(s)\).
- Parameters:
func (function) – A Haiku-style function that specifies the forward pass. The function signature must match the examples above.
env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.
observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).
value_transform (ValueTransform or pair of funcs, optional) – If provided, the target for the underlying function approximator is transformed such that:
\[\tilde{v}_\theta(S_t)\ \approx\ f(G_t)\]
This means that calling the function involves undoing this transformation:
\[v(s)\ =\ f^{-1}(\tilde{v}_\theta(s))\]
Here, \(f\) and \(f^{-1}\) are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.
random_seed (int, optional) – Seed for pseudo-random number generators.
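As an illustration of value_transform, here is a hypothetical transform pair: a signed log compression \(f(G)=\mathrm{sign}(G)\log(1+|G|)\) and its exact inverse. It is written with the scalar math module purely to show the round trip \(f^{-1}(f(G))=G\); a pair actually passed as value_transform=(func, inverse_func) would presumably have to operate on JAX arrays (e.g. via jax.numpy equivalents), which is an assumption here.

```python
import math


def transform_func(g):
    """f(G) = sign(G) * log(1 + |G|): compresses large returns (hypothetical choice)."""
    return math.copysign(math.log1p(abs(g)), g)


def inverse_func(y):
    """f^{-1}(y) = sign(y) * (exp(|y|) - 1): exact inverse of transform_func."""
    return math.copysign(math.expm1(abs(y)), y)


# round trip: inverse_func(transform_func(G)) recovers G up to float rounding
for g in [-100.0, -1.0, 0.0, 2.5, 1000.0]:
    y = transform_func(g)
    assert math.isclose(inverse_func(y), g, rel_tol=1e-9, abs_tol=1e-12)
    print(g, round(y, 4))

# hypothetical usage with coax (would need JAX-compatible versions of both functions):
# v = coax.V(func, env, value_transform=(transform_func, inverse_func))
```

The compression keeps the regression target in a narrow range even when returns span several orders of magnitude, while the inverse restores the original scale whenever the value function is called.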
- __call__(s)
Evaluate the value function on a state observation \(s\).
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
v (ndarray, shape: ()) – The estimated expected value associated with the input state observation s.
- copy(deep=False)
Create a copy of the current instance.
- Parameters:
deep (bool, optional) – Whether the copy should be a deep copy.
- Returns:
copy – A copy of the current instance (a deep copy if deep=True).
- classmethod example_data(env, observation_preprocessor=None, batch_size=1, random_seed=None)
A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.
- soft_update(other, tau)
Synchronize the current instance with other through exponential smoothing:
\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]
- Parameters:
other – A separate copy of the current object. This object holds the new parameters \(\theta_\text{new}\).
tau (float between 0 and 1) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.
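The update rule is easy to check on plain numbers. The sketch below applies \(\theta\leftarrow\theta+\tau(\theta_\text{new}-\theta)\) leaf by leaf to a flat dict of parameter lists, a simplified stand-in (assumption) for the nested Haiku parameter structure that coax actually stores; the function and parameter names are hypothetical.

```python
def soft_update(params, new_params, tau):
    """Return theta + tau * (theta_new - theta), applied leaf by leaf.

    Parameters are modeled as a flat dict of name -> list of floats, a
    simplified stand-in for nested Haiku parameter structures.
    """
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must lie in [0, 1]")
    return {
        name: [w + tau * (w_new - w) for w, w_new in zip(params[name], new_params[name])]
        for name in params
    }


target = {"linear/w": [0.0, 0.0], "linear/b": [1.0]}
online = {"linear/w": [1.0, -2.0], "linear/b": [3.0]}

print(soft_update(target, online, tau=0.25))  # {'linear/w': [0.25, -0.5], 'linear/b': [1.5]}
print(soft_update(target, online, tau=1.0) == online)  # True: tau=1 is a hard update
```

With the methods documented here, a target network would typically be created with v_targ = v.copy() and then tracked through repeated calls such as v_targ.soft_update(v, tau=0.01); small values of tau make the target change slowly, which tends to stabilize bootstrapped targets.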
- property function
The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:
output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
- property function_state
The state of the function approximator, see haiku.transform_with_state().
- property params
The parameters (weights) of the function approximator.
- class coax.Q(func, env, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None)
A state-action value function \(q_\theta(s,a)\).
- Parameters:
func (function) – A Haiku-style function that specifies the forward pass. The function signature must match one of the type-1 or type-2 examples above.
env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.
observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).
action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).
value_transform (ValueTransform or pair of funcs, optional) – If provided, the target for the underlying function approximator is transformed such that:
\[\tilde{q}_\theta(S_t, A_t)\ \approx\ f(G_t)\]
This means that calling the function involves undoing this transformation:
\[q(s, a)\ =\ f^{-1}(\tilde{q}_\theta(s, a))\]
Here, \(f\) and \(f^{-1}\) are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.
random_seed (int, optional) – Seed for pseudo-random number generators.
- __call__(s, a=None)
Evaluate the state-action function on a state observation \(s\) or on a state-action pair \((s, a)\).
- Parameters:
s (state observation) – A single state observation \(s\).
a (action, optional) – A single action \(a\).
- Returns:
q_sa or q_s (ndarray) – Depending on whether a is provided, this either returns a scalar representing \(q(s,a)\in\mathbb{R}\) or a vector representing \(q(s,.)\in\mathbb{R}^n\), where \(n\) is the number of discrete actions. The latter (omitting a) is only supported for discrete action spaces.
- copy(deep=False)
Create a copy of the current instance.
- Parameters:
deep (bool, optional) – Whether the copy should be a deep copy.
- Returns:
copy – A copy of the current instance (a deep copy if deep=True).
- classmethod example_data(env, observation_preprocessor=None, action_preprocessor=None, batch_size=1, random_seed=None)
A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.
- soft_update(other, tau)
Synchronize the current instance with other through exponential smoothing:
\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]
- Parameters:
other – A separate copy of the current object. This object holds the new parameters \(\theta_\text{new}\).
tau (float between 0 and 1) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.
- property function
The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:
output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
- property function_state
The state of the function approximator, see haiku.transform_with_state().
- property function_type1
Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.
- property function_type2
Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.
- property modeltype
Specifier for how the q-function is modeled, i.e.
\[\begin{aligned}(s,a) &\mapsto q(s,a)\in\mathbb{R} &\qquad& (\text{modeltype}=1) \\ s &\mapsto q(s,.)\in\mathbb{R}^n &\qquad& (\text{modeltype}=2)\end{aligned}\]
Note that modeltype=2 is only well-defined if the action space is Discrete. Here, \(n\) is the number of discrete actions.
- property params
The parameters (weights) of the function approximator.
- class coax.StochasticV(func, env, value_range, num_bins=51, observation_preprocessor=None, value_transform=None, random_seed=None)
A state-value function \(v(s)\), represented by a stochastic function \(\mathbb{P}_\theta(G_t|S_t=s)\).
- Parameters:
func (function) – A Haiku-style function that specifies the forward pass.
env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.
value_range (tuple of floats) – A pair of floats (min_value, max_value).
num_bins (int, optional) – The space of returns is discretized into num_bins equal-sized bins. The default of 51 follows the setting suggested in the Distributional RL paper.
observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).
value_transform (ValueTransform or pair of funcs, optional) – If provided, the target for the underlying function approximator is transformed:
\[\tilde{G}_t\ =\ f(G_t)\]
This means that calling the function involves undoing this transformation using its inverse \(f^{-1}\). The functions \(f\) and \(f^{-1}\) are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.
random_seed (int, optional) – Seed for pseudo-random number generators.
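A sketch of the categorical representation that value_range and num_bins suggest, in the style of C51: num_bins evenly spaced atoms span (min_value, max_value), and a scalar return is spread over its two nearest atoms so that the distribution's mean equals that return. The helper names are hypothetical, and coax's internal discretization may differ in details such as clipping.

```python
def bin_centers(value_range, num_bins=51):
    """Evenly spaced atoms z_0, ..., z_{num_bins-1} spanning (min_value, max_value)."""
    lo, hi = value_range
    step = (hi - lo) / (num_bins - 1)
    return [lo + i * step for i in range(num_bins)]


def two_hot(g, value_range, num_bins=51):
    """Spread a scalar return g over its two nearest atoms so that the mean is g.

    Values outside value_range are clipped to the nearest edge.
    """
    lo, hi = value_range
    g = min(max(g, lo), hi)
    pos = (g - lo) / (hi - lo) * (num_bins - 1)  # fractional bin index
    i = min(int(pos), num_bins - 2)  # left neighbor (keeps i + 1 in range)
    frac = pos - i
    probs = [0.0] * num_bins
    probs[i] += 1.0 - frac
    probs[i + 1] += frac
    return probs


z = bin_centers((-10.0, 10.0), num_bins=5)
print(z)  # [-10.0, -5.0, 0.0, 5.0, 10.0]
p = two_hot(2.5, (-10.0, 10.0), num_bins=5)
print(p)  # [0.0, 0.0, 0.5, 0.5, 0.0]
```

The choice of value_range matters: returns outside it are clipped onto the edge atoms, so the range should cover the returns the environment can actually produce (after any value_transform).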
- __call__(s, return_logp=False)
Sample a value.
- Parameters:
s (state observation) – A single state observation \(s\).
return_logp (bool, optional) – Whether to return the log-propensity associated with the sampled output value.
- Returns:
value (float) – A single value associated with the state observation \(s\).
logp (non-positive float, optional) – The log-propensity associated with the sampled output value. This is only returned if we set return_logp=True.
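Assuming the value distribution is categorical over bin centers, sampling a value together with its log-propensity can look like the following plain-Python sketch (a hypothetical helper, not coax's implementation). The log-propensity of a discrete draw is \(\log p_i\le 0\), which is why logp is documented as non-positive.

```python
import math
import random


def sample_value(probs, atoms, rng, return_logp=False):
    """Draw one atom from a categorical distribution over the value bins.

    Follows the (value, logp) return convention described above; the
    log-propensity of a discrete sample is log(p_i) <= 0.
    """
    (i,) = rng.choices(range(len(atoms)), weights=probs, k=1)
    value = atoms[i]
    if return_logp:
        return value, math.log(probs[i])
    return value


atoms = [-10.0, -5.0, 0.0, 5.0, 10.0]
probs = [0.1, 0.2, 0.4, 0.2, 0.1]
rng = random.Random(13)  # seeded for reproducibility
value, logp = sample_value(probs, atoms, rng, return_logp=True)
print(value, logp)
```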
- copy(deep=False)
Create a copy of the current instance.
- Parameters:
deep (bool, optional) – Whether the copy should be a deep copy.
- Returns:
copy – A copy of the current instance (a deep copy if deep=True).
- dist_params(s)
Get the parameters of the underlying (conditional) probability distribution.
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
dist_params (dict) – A single dist-params dict describing \(\mathbb{P}_\theta(G_t|S_t=s)\).
- classmethod example_data(env, value_range, num_bins=51, observation_preprocessor=None, value_transform=None, batch_size=1, random_seed=None)
A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.
- mean(s)
Get the mean value.
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
value (float) – A single value associated with the state observation \(s\).
- mode(s)
Get the most probable value.
- Parameters:
s (state observation) – A single state observation \(s\).
- Returns:
value (float) – A single value associated with the state observation \(s\).
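Under the same categorical assumption, mean and mode reduce to a probability-weighted sum over the atoms and an argmax, respectively. This is an illustrative sketch with hypothetical names, not coax's mean_func or mode_func.

```python
def categorical_mean(probs, atoms):
    """E[G] = sum_i p_i * z_i."""
    return sum(p * z for p, z in zip(probs, atoms))


def categorical_mode(probs, atoms):
    """The atom with the highest probability (the first one wins on ties)."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return atoms[best]


atoms = [-10.0, -5.0, 0.0, 5.0, 10.0]
probs = [0.0, 0.1, 0.2, 0.3, 0.4]
print(categorical_mean(probs, atoms))  # approximately 5.0
print(categorical_mode(probs, atoms))  # 10.0
```

For a skewed distribution like this one, mean and mode differ, so the choice between mean() and mode() matters when a single summary value is needed.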
- soft_update(other, tau)
Synchronize the current instance with other through exponential smoothing:
\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]
- Parameters:
other – A separate copy of the current object. This object holds the new parameters \(\theta_\text{new}\).
tau (float between 0 and 1) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.
- property function
The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:
output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
- property function_state
The state of the function approximator, see haiku.transform_with_state().
- property mean_func
The function that is used for getting the mean of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mean_func(obj.params, obj.function_state, obj.rng, *inputs)
- property mode_func
The function that is used for getting the mode of the distribution, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mode_func(obj.params, obj.function_state, obj.rng, *inputs)
- property params
The parameters (weights) of the function approximator.
- property sample_func
The function that is used for drawing random samples from the underlying proba_dist, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.sample_func(obj.params, obj.function_state, obj.rng, *inputs)
- class coax.StochasticQ(func, env, value_range=None, num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, random_seed=None)
A q-function \(q(s,a)\), represented by a stochastic function \(\mathbb{P}_\theta(G_t|S_t=s,A_t=a)\).
- Parameters:
func (function) – A Haiku-style function that specifies the forward pass.
env (gymnasium.Env) – The gymnasium-style environment. This is used to validate the input/output structure of func.
value_range (tuple of floats, optional) – A pair of floats (min_value, max_value). If no value_range is given, num_bins is the number of bins of the quantile function, as in IQN or QR-DQN.
num_bins (int, optional) – If value_range is given, the space of returns is discretized into num_bins equal-sized bins; the default of 51 follows the setting suggested in the Distributional RL paper. Otherwise, num_bins sets the number of fractions of the quantile function of the returns, as in IQN or QR-DQN.
observation_preprocessor (function, optional) – Turns a single observation into a batch of observations in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.observation_space).
action_preprocessor (function, optional) – Turns a single action into a batch of actions in a form that is convenient for feeding into func. If left unspecified, this defaults to default_preprocessor(env.action_space).
value_transform (ValueTransform or pair of funcs, optional) – If provided, the target for the underlying function approximator is transformed:
\[\tilde{G}_t\ =\ f(G_t)\]
This means that calling the function involves undoing this transformation using its inverse \(f^{-1}\). The functions \(f\) and \(f^{-1}\) are given by value_transform.transform_func and value_transform.inverse_func, respectively. Note that a ValueTransform is just a glorified pair of functions, i.e. passing value_transform=(func, inverse_func) works just as well.
random_seed (int, optional) – Seed for pseudo-random number generators.
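When no value_range is given, num_bins counts quantile fractions instead of bins. A minimal sketch, assuming the fixed midpoint fractions \(\tau_i=(2i+1)/(2N)\) used by QR-DQN (IQN instead samples the fractions at random): for a known distribution, here Uniform(0, 10) chosen purely for illustration, the quantile values are \(F^{-1}(\tau_i)\), and their plain average estimates the mean return. The helper names are hypothetical.

```python
def quantile_fractions(num_bins):
    """Midpoint fractions tau_i = (2i + 1) / (2N) with N = num_bins, as in QR-DQN."""
    return [(2 * i + 1) / (2 * num_bins) for i in range(num_bins)]


def uniform_quantile(tau, low, high):
    """Quantile function (inverse CDF) of Uniform(low, high), a toy return distribution."""
    return low + tau * (high - low)


taus = quantile_fractions(4)
print(taus)  # [0.125, 0.375, 0.625, 0.875]

# quantile values a distributional q-function would aim to predict for this toy distribution
qs = [uniform_quantile(tau, 0.0, 10.0) for tau in taus]
print(qs)  # [1.25, 3.75, 6.25, 8.75]

# each fraction carries weight 1/N, so the mean return is the plain average
mean_estimate = sum(qs) / len(qs)
print(mean_estimate)  # 5.0
```

Unlike the binned representation, the quantile representation needs no value_range, because the predicted quantile values themselves are unbounded real numbers.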
- __call__(s, a=None, return_logp=False)
Sample a value.
- Parameters:
s (state observation) – A single state observation \(s\).
a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.
return_logp (bool, optional) – Whether to return the log-propensity associated with the sampled output value.
- Returns:
value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of \(n\) values, one for each discrete action.
logp (non-positive float or list thereof, optional) – The log-propensity associated with the sampled output value. This is only returned if we set return_logp=True. Depending on whether a is provided, this is either a single float or a list of \(n\) floats, one for each discrete action.
- copy(deep=False)
Create a copy of the current instance.
- Parameters:
deep (bool, optional) – Whether the copy should be a deep copy.
- Returns:
copy – A copy of the current instance (a deep copy if deep=True).
- dist_params(s, a=None)
Get the parameters of the underlying (conditional) probability distribution.
- Parameters:
s (state observation) – A single state observation \(s\).
a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.
- Returns:
dist_params (dict or list of dicts) – Depending on whether a is provided, this either returns a single dist-params dict or a list of \(n\) such dicts, one for each discrete action.
- classmethod example_data(env, value_range, num_bins=51, observation_preprocessor=None, action_preprocessor=None, value_transform=None, batch_size=1, random_seed=None)
A small utility function that generates example input and output data. These may be useful for writing and debugging your own custom function approximators.
- mean(s, a=None)
Get the mean value.
- Parameters:
s (state observation) – A single state observation \(s\).
a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.
- Returns:
value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of \(n\) values, one for each discrete action.
- mode(s, a=None)
Get the most probable value.
- Parameters:
s (state observation) – A single state observation \(s\).
a (action, optional) – A single action \(a\). This is required if the action space is non-discrete.
- Returns:
value (float or list thereof) – Depending on whether a is provided, this either returns a single value or a list of \(n\) values, one for each discrete action.
- soft_update(other, tau)
Synchronize the current instance with other through exponential smoothing:
\[\theta\ \leftarrow\ \theta + \tau\, (\theta_\text{new} - \theta)\]
- Parameters:
other – A separate copy of the current object. This object holds the new parameters \(\theta_\text{new}\).
tau (float between 0 and 1) – If we set \(\tau=1\) we do a hard update. If we pick a smaller value, we do a smooth update.
- property function
The function approximator itself, defined as a JIT-compiled pure function. This function may be called directly as:
output, function_state = obj.function(obj.params, obj.function_state, obj.rng, *inputs)
- property function_state
The state of the function approximator, see haiku.transform_with_state().
- property function_type1
Same as function, except that it ensures a type-1 function signature, regardless of the underlying modeltype.
- property function_type2
Same as function, except that it ensures a type-2 function signature, regardless of the underlying modeltype.
- property mean_func_type1
The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mean_func_type1(obj.params, obj.function_state, obj.rng, S, A)
- property mean_func_type2
The function that is used for computing the mean, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mean_func_type2(obj.params, obj.function_state, obj.rng, S)
- property mode_func_type1
The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mode_func_type1(obj.params, obj.function_state, obj.rng, S, A)
- property mode_func_type2
The function that is used for computing the mode, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.mode_func_type2(obj.params, obj.function_state, obj.rng, S)
- property modeltype
Specifier for how the q-function is modeled, i.e.
\[\begin{aligned}(s,a) &\mapsto \mathbb{P}_\theta(G_t|S_t=s,A_t=a) &\qquad& (\text{modeltype}=1) \\ s &\mapsto \mathbb{P}_\theta(G_t|S_t=s,A_t=.) &\qquad& (\text{modeltype}=2)\end{aligned}\]
Note that modeltype=2 is only well-defined if the action space is Discrete. In that case the type-2 output contains one distribution for each of the \(n\) discrete actions.
- property params
The parameters (weights) of the function approximator.
- property sample_func_type1
The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.sample_func_type1(obj.params, obj.function_state, obj.rng, S, A)
- property sample_func_type2
The function that is used for generating random samples, defined as a JIT-compiled pure function. This function may be called directly as:
output = obj.sample_func_type2(obj.params, obj.function_state, obj.rng, S)
- class coax.SuccessorStateQ(v, p, r, gamma=0.9)
A state-action value function \(q(s,a)=r(s,a)+\gamma\mathop{\mathbb{E}}_{s'\sim p(.|s,a)}v(s')\).
Caution: If you use custom observation/action pre-/post-processors, please make sure that all three function approximators v, p and r use the same ones.
- Parameters:
v (V or StochasticV) – A state value function \(v(s)\).
p (TransitionModel or StochasticTransitionModel) – A transition model.
r (RewardFunction or StochasticRewardFunction) – A reward function.
gamma (float between 0 and 1, optional) – The discount factor for future rewards \(\gamma\in[0,1]\).
- __call__(s, a=None)
Evaluate the state-action function on a state observation \(s\) or on a state-action pair \((s, a)\).
- Parameters:
s (state observation) – A single state observation \(s\).
a (action, optional) – A single action \(a\).
- Returns:
q_sa or q_s (ndarray) – Depending on whether a is provided, this either returns a scalar representing \(q(s,a)\in\mathbb{R}\) or a vector representing \(q(s,.)\in\mathbb{R}^n\), where \(n\) is the number of discrete actions. The latter (omitting a) is only supported for discrete action spaces.
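The defining formula of SuccessorStateQ can be evaluated by hand on a small tabular model. The sketch below uses plain dicts for \(v\), \(p\) and \(r\) (a hypothetical toy model, not coax's function approximators) and evaluates \(q(s,a)=r(s,a)+\gamma\sum_{s'}p(s'|s,a)\,v(s')\) for every action, which yields the vector \(q(s,.)\) that is returned when a is omitted.

```python
def successor_state_q(s, a, r, p, v, gamma=0.9):
    """q(s, a) = r(s, a) + gamma * sum_{s'} p(s' | s, a) * v(s')."""
    expected_next_v = sum(prob * v[s_next] for s_next, prob in p[(s, a)].items())
    return r[(s, a)] + gamma * expected_next_v


# hypothetical model: two states, two actions; only state "left" is queried here
v = {"left": 1.0, "right": 3.0}
r = {("left", 0): 0.0, ("left", 1): 1.0}
p = {
    ("left", 0): {"left": 1.0},                # action 0 stays in "left"
    ("left", 1): {"left": 0.5, "right": 0.5},  # action 1 moves right half the time
}

q_s = [successor_state_q("left", a, r, p, v, gamma=0.5) for a in (0, 1)]
print(q_s)  # [0.5, 2.0]
```

Because q is assembled from v, p and r rather than learned directly, any inconsistency between their preprocessors would silently corrupt this sum, which is the reason for the caution above.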