Probability Distributions
This is a collection of differentiable probability distributions used throughout the package.
coax.proba_dists.ProbaDist | A composite probability distribution.
coax.proba_dists.CategoricalDist | A differentiable categorical distribution.
coax.proba_dists.NormalDist | A differentiable normal distribution.
coax.proba_dists.DiscretizedIntervalDist | A categorical distribution over a discretized interval.
coax.proba_dists.SquashedNormalDist | A differentiable squashed normal distribution.
Object Reference
- class coax.proba_dists.ProbaDist(space)
A composite probability distribution. This consists of a nested structure whose leaves are either coax.proba_dists.CategoricalDist or coax.proba_dists.NormalDist instances.
- Parameters:
space (gymnasium.Space) – The gymnasium-style space that specifies the domain of the distribution. This may be any space included in the gymnasium.spaces module.
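The composite structure can be illustrated with a small JAX sketch. This is not coax's implementation: it assumes the components of a composite distribution are independent, so that the joint log-probability of a variate is the sum of the log-probabilities of its leaves. The leaf helpers, the dict keys 'discrete' and 'box', and the integer encoding of the discrete variates are all hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def categorical_leaf_log_proba(dist_params, x):
    # Hypothetical leaf: log p_x for integer classes x, where p = softmax(logits).
    logp = jax.nn.log_softmax(dist_params['logits'], axis=-1)
    return jnp.take_along_axis(logp, x[:, None], axis=-1)[:, 0]


def normal_leaf_log_proba(dist_params, x):
    # Hypothetical leaf: diagonal Gaussian log-density, summed over the event dimensions.
    sigma = jnp.exp(0.5 * dist_params['logvar'])
    return jnp.sum(norm.logpdf(x, dist_params['mu'], sigma), axis=-1)


def composite_log_proba(dist_params, X, leaf_fns):
    # Independence assumption: the joint log-probability is the sum over leaves.
    return sum(leaf_fns[key](dist_params[key], X[key]) for key in leaf_fns)


# A batch of two variates, mimicking a Dict space with a Discrete(3) and a 2-dim Box component.
dist_params = {
    'discrete': {'logits': jnp.zeros((2, 3))},
    'box': {'mu': jnp.zeros((2, 2)), 'logvar': jnp.zeros((2, 2))},
}
X = {'discrete': jnp.array([0, 2]), 'box': jnp.zeros((2, 2))}
leaf_fns = {'discrete': categorical_leaf_log_proba, 'box': normal_leaf_log_proba}
logP = composite_log_proba(dist_params, X, leaf_fns)  # shape (2,)
```

With uniform logits and variates at the Gaussian means, each entry of logP equals log(1/3) - log(2π).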
- postprocess_variate(rng, X, index=0, batch_mode=False)
The post-processor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
- Parameters:
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
- Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
- preprocess_variate(rng, X)
The pre-processor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
- property affine_transform
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
- Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) –
A value transform \(f\). If provided, the affine transformation is applied to the untransformed values \(f^{-1}(X)\) and the result is mapped back through \(f\), i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
- Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
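The relation between \(X\) and \(X'\), with and without a value transform, can be checked numerically on plain arrays. The sketch below operates on variates rather than on dist_params (which is what affine_transform actually takes and returns), and the signed-log pair f and f_inv is a hypothetical stand-in for a ValueTransform.

```python
import jax.numpy as jnp


def f(x):
    # Hypothetical value transform: signed logarithmic compression.
    return jnp.sign(x) * jnp.log1p(jnp.abs(x))


def f_inv(y):
    # Exact inverse of f.
    return jnp.sign(y) * jnp.expm1(jnp.abs(y))


def affine_on_values(X, scale, shift, transform=None):
    if transform is None:
        # Plain case: X' = X * scale + shift
        return X * scale + shift
    fwd, inv = transform
    # With a value transform: X' = f(f^{-1}(X) * scale + shift)
    return fwd(inv(X) * scale + shift)


raw = jnp.array([-3.0, 0.0, 10.0])
X = f(raw)  # values stored in the transformed space
X_new = affine_on_values(X, scale=0.9, shift=1.0, transform=(f, f_inv))
```

Undoing the transform on X_new recovers raw * 0.9 + 1.0, which is the sense in which the affine map acts "inside" \(f\).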
- property cross_entropy
JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property default_priors
The default distribution parameters.
- property dist_params_structure
The tree structure of the distribution parameters.
- property entropy
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\mathbb{E}_p \log p\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
H (ndarray of floats) – A batch of entropy values.
- property hyperparams
The distribution hyperparameters.
- property kl_divergence
JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q -\log p\right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property log_proba
JIT-compiled function that evaluates log-probabilities.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
- Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
- property mean
JIT-compiled function that generates differentiable means of the distribution.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property mode
JIT-compiled function that generates differentiable modes of the distribution.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property sample
JIT-compiled function that generates differentiable variates.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property space
The gymnasium-style space that specifies the domain of the distribution.
- class coax.proba_dists.CategoricalDist(space, gumbel_softmax_tau=0.2)
A differentiable categorical distribution.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'logits': array([...])}
which represents the (conditional) distribution parameters. The logits, denoted \(z\in\mathbb{R}^n\), are related to the categorical distribution parameters \(p\in\Delta^n\) via a softmax:
\[p_k\ =\ \text{softmax}_k(z)\ =\ \frac{\text{e}^{z_k}}{\sum_j\text{e}^{z_j}}\]
- Parameters:
space (gymnasium.spaces.Discrete) – The gymnasium-style space that specifies the domain of the distribution.
gumbel_softmax_tau (positive float, optional) – The parameter \(\tau\) specifies the sharpness of the Gumbel-softmax sampling (see the sample() method below). Smaller values of \(\tau\) yield variates closer to one-hot vectors, larger values yield smoother ones; a good value balances the trade-off between getting proper deterministic variates (i.e. one-hot vectors) and getting smooth differentiable variates.
- postprocess_variate(rng, X, index=0, batch_mode=False)
The post-processor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
- Parameters:
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
- Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
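For CategoricalDist, the post-processing step has to turn an almost-one-hot row into an integer in the Discrete space. The sketch below assumes this is done by drawing a class from each row, treated as a probability vector (one plausible use of the rng argument); the function name is hypothetical, and the actual implementation may differ, e.g. by taking an argmax instead.

```python
import jax
import jax.numpy as jnp


def postprocess_categorical(rng, X, index=0, batch_mode=False):
    # Hypothetical sketch: treat each almost-one-hot row of X as class probabilities
    # and draw one integer class per row.
    x = jax.random.categorical(rng, jnp.log(X), axis=-1)
    if batch_mode:
        return x  # batch of clean variates, shape (batch_size,)
    return int(x[index])  # single clean variate, a Python int


X = jnp.array([[0.05, 0.02, 0.86, 0.07],
               [0.90, 0.05, 0.03, 0.02]])
rng = jax.random.PRNGKey(42)
x_single = postprocess_categorical(rng, X, index=0)
x_batch = postprocess_categorical(rng, X, batch_mode=True)
```

Rows that are exactly one-hot always map to their hot index, since zero entries become log-probabilities of minus infinity.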
- preprocess_variate(rng, X)
The pre-processor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
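Conversely, pre-processing maps clean variates back into the layout of sample() and mode(). For a categorical distribution this amounts to one-hot encoding, sketched below with jax.nn.one_hot; the helper name and the omission of the rng argument are simplifications made for illustration.

```python
import jax
import jax.numpy as jnp


def preprocess_categorical(X, num_classes):
    # Hypothetical sketch: integer classes -> exact one-hot vectors of length num_classes,
    # matching the (almost-)one-hot layout of sample() and mode().
    return jax.nn.one_hot(jnp.asarray(X), num_classes)


actions = jnp.array([2, 0, 3])  # e.g. a batch of actions from Discrete(4)
X = preprocess_categorical(actions, num_classes=4)
```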
- property affine_transform
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
- Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) –
A value transform \(f\). If provided, the affine transformation is applied to the untransformed values \(f^{-1}(X)\) and the result is mapped back through \(f\), i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
- Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
- property cross_entropy
JIT-compiled function that computes the cross-entropy of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{CE}[p,q]\ =\ -\sum_k p_k \log q_k\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property default_priors
The default distribution parameters.
- property dist_params_structure
The tree structure of the distribution parameters.
- property entropy
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\sum_k p_k \log p_k\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
H (ndarray of floats) – A batch of entropy values.
- property hyperparams
The distribution hyperparameters.
- property kl_divergence
JIT-compiled function that computes the Kullback-Leibler divergence of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{KL}[p,q]\ =\ -\sum_k p_k \left(\log q_k -\log p_k\right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
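The categorical entropy, cross-entropy and KL formulas translate directly into array code. The following sketch works on raw logits arrays rather than on coax's dist_params dicts, and the function names are hypothetical; it also checks the identity \(\text{KL}[p,q] = \text{CE}[p,q] - H[p]\), which follows from the definitions.

```python
import jax
import jax.numpy as jnp


def cat_entropy(logits):
    # H = -sum_k p_k log p_k
    logp = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.exp(logp) * logp, axis=-1)


def cat_cross_entropy(logits_p, logits_q):
    # CE[p, q] = -sum_k p_k log q_k
    p = jax.nn.softmax(logits_p, axis=-1)
    return -jnp.sum(p * jax.nn.log_softmax(logits_q, axis=-1), axis=-1)


def cat_kl_divergence(logits_p, logits_q):
    # KL[p, q] = -sum_k p_k (log q_k - log p_k)
    logp = jax.nn.log_softmax(logits_p, axis=-1)
    logq = jax.nn.log_softmax(logits_q, axis=-1)
    return -jnp.sum(jnp.exp(logp) * (logq - logp), axis=-1)


logits_p = jnp.array([[1.0, 2.0, 0.5]])
logits_q = jnp.zeros((1, 3))  # uniform distribution over 3 classes
```

Using log_softmax instead of log(softmax(...)) avoids taking the log of probabilities that underflow to zero.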
- property log_proba
JIT-compiled function that evaluates log-probabilities.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
- Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
- property mean
JIT-compiled function that generates differentiable means of the distribution. Strictly speaking, the mean of a categorical variable is not well defined. We opt for returning the raw probabilities: \(\text{mean}_k=p_k\).
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of would-be variates \(x\sim\text{Cat}(p)\). In contrast to the output of other methods, these aren’t true variates because they are not almost-one-hot encoded.
- property mode
JIT-compiled function that generates differentiable modes of the distribution, for which we use a similar trick as in Gumbel-softmax sampling (without the Gumbel noise):
\[\text{mode}_k\ =\ \text{softmax}_k\left( \frac{\log p_k}{\tau} \right)\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of modes of \(\text{Cat}(p)\). In order to ensure differentiability, these are not integers but almost-one-hot encoded versions thereof.
For example, instead of returning \(x=2\) as the mode of a 4-class categorical distribution, this method returns a vector like \(x=(0.05, 0.02, 0.86, 0.07)\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
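The mode formula is a temperature-sharpened softmax of the log-probabilities. A minimal JAX sketch, operating on a logits array (the function name is hypothetical):

```python
import jax
import jax.numpy as jnp


def categorical_mode(logits, tau=0.2):
    # mode_k = softmax_k(log p_k / tau): a sharpened, differentiable stand-in for one_hot(argmax p)
    logp = jax.nn.log_softmax(logits, axis=-1)
    return jax.nn.softmax(logp / tau, axis=-1)


logits = jnp.log(jnp.array([[0.1, 0.2, 0.6, 0.1]]))
X = categorical_mode(logits, tau=0.2)
```

With \(\tau=0.2\), the result is proportional to \(p_k^5\), so almost all of the mass lands on the most likely class while gradients with respect to the logits remain well defined.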
- property sample
JIT-compiled function that generates differentiable variates using Gumbel-softmax sampling. \(x\sim\text{Cat}(p)\) is implemented as
\[\begin{split}u_k\ &\sim\ \text{Unif}(0, 1) \\ g_k\ &=\ -\log(-\log(u_k)) \\ x_k\ &=\ \text{softmax}_k\left( \frac{g_k + \log p_k}{\tau} \right)\end{split}\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
- Returns:
X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.
For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
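The three sampling steps map one-to-one onto JAX calls. The sketch below is illustrative, not coax's code; the function name is hypothetical, and the small minval on the uniform draw is an added guard against log(0).

```python
import jax
import jax.numpy as jnp


def gumbel_softmax_sample(rng, logits, tau=0.2):
    # u_k ~ Unif(0, 1); a small minval keeps log(u_k) finite
    u = jax.random.uniform(rng, logits.shape, minval=1e-10, maxval=1.0)
    # g_k = -log(-log(u_k)) is standard Gumbel noise
    g = -jnp.log(-jnp.log(u))
    # x_k = softmax_k((g_k + log p_k) / tau)
    logp = jax.nn.log_softmax(logits, axis=-1)
    return jax.nn.softmax((g + logp) / tau, axis=-1)


rng = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.05, 0.05, 0.85, 0.05]))
X = gumbel_softmax_sample(rng, logits)  # almost-one-hot vector of length 4
```

Because dividing by \(\tau\) does not change the argmax, the argmax of each sample is distributed exactly as \(\text{Cat}(p)\) (the Gumbel-max trick). JAX also provides jax.random.gumbel for drawing \(g\) directly; the explicit form is kept here to mirror the equations.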
- property space
The gymnasium-style space that specifies the domain of the distribution.
- class coax.proba_dists.NormalDist(space, clip_box=(-256.0, 256.0), clip_reals=(-30.0, 30.0), clip_logvar=(-20.0, 20.0))
A differentiable normal distribution.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'mu': array([...]), 'logvar': array([...])}
which represents the (conditional) distribution parameters. Here, mu is the mean \(\mu\) and logvar is the log-variance \(\log(\sigma^2)\).
- Parameters:
space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution.
clip_box (pair of floats, optional) – The range of values to allow for clean (compact) variates. This is mainly to ensure reasonable values when one or more dimensions of the Box space have very large ranges, while in reality only a small part of that range is occupied.
clip_reals (pair of floats, optional) – The range of values to allow for raw (decompactified) variates, the reals, used internally. This range is set for numeric stability. Namely, the postprocess_variate method compactifies the reals to a closed interval (Box) by applying a logistic sigmoid. Setting a finite range for clip_reals ensures that the sigmoid doesn’t fully saturate.
clip_logvar (pair of floats, optional) – The range of values to allow for the log-variance of the distribution.
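The roles of clip_box and clip_reals can be illustrated with a compactification round trip. This is a sketch under stated assumptions: the exact mapping low + (high - low)·sigmoid(x), the way clip_box narrows the Box bounds, and the helper names are all hypothetical rather than taken from coax.

```python
import jax
import jax.numpy as jnp


def _effective_bounds(low, high, clip_box):
    # Assumption: very wide (or infinite) Box bounds are narrowed to the clip_box range.
    return jnp.maximum(low, clip_box[0]), jnp.minimum(high, clip_box[1])


def compactify(x_real, low, high, clip_box=(-256.0, 256.0), clip_reals=(-30.0, 30.0)):
    # Raw reals -> clean Box values: clip, squash to (0, 1) with a sigmoid, then rescale.
    low, high = _effective_bounds(low, high, clip_box)
    return low + (high - low) * jax.nn.sigmoid(jnp.clip(x_real, *clip_reals))


def decompactify(x_box, low, high, clip_box=(-256.0, 256.0), clip_reals=(-30.0, 30.0)):
    # Clean Box values -> raw reals via the logit (inverse sigmoid), clipped to clip_reals.
    low, high = _effective_bounds(low, high, clip_box)
    u = (x_box - low) / (high - low)
    return jnp.clip(jnp.log(u) - jnp.log1p(-u), *clip_reals)


x = jnp.array([-3.0, 0.0, 3.0])
y = compactify(x, low=-2.0, high=2.0)
```

Without something like clip_box, an unbounded Box would make high - low infinite and the rescaling undefined.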
- postprocess_variate(rng, X, index=0, batch_mode=False)
The post-processor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
- Parameters:
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
- Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
- preprocess_variate(rng, X)
The pre-processor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
- property affine_transform
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
- Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) –
A value transform \(f\). If provided, the affine transformation is applied to the untransformed values \(f^{-1}(X)\) and the result is mapped back through \(f\), i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
- Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
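For a normal distribution the transformed parameters follow in closed form: if \(X\sim\mathcal{N}(\mu,\sigma^2)\), then \(X\times\text{scale}+\text{shift}\sim\mathcal{N}(\mu\times\text{scale}+\text{shift},\ \text{scale}^2\sigma^2)\), so the log-variance shifts by \(\log(\text{scale}^2)\). The sketch below implements only this plain case in the underlying real space (no value_transform, no clipping of logvar); the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def normal_affine_transform(dist_params, scale, shift):
    # X' = X * scale + shift maps N(mu, sigma^2) to N(mu * scale + shift, scale^2 * sigma^2).
    return {
        'mu': dist_params['mu'] * scale + shift,
        'logvar': dist_params['logvar'] + jnp.log(jnp.square(scale)),
    }


params = {'mu': jnp.array([0.5, -1.0]), 'logvar': jnp.log(jnp.array([1.0, 4.0]))}
new_params = normal_affine_transform(params, scale=3.0, shift=1.0)

# Consistency check with the reparametrization x = mu + sigma * eps (here scale > 0).
eps = jax.random.normal(jax.random.PRNGKey(0), (2,))
x = params['mu'] + jnp.exp(0.5 * params['logvar']) * eps
x_new = new_params['mu'] + jnp.exp(0.5 * new_params['logvar']) * eps
```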
- property cross_entropy
JIT-compiled function that computes the cross-entropy of a normal distribution \(q\) relative to another normal distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q \ =\ \frac12\left( \log(2\pi\sigma_q^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} \right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property default_priors
The default distribution parameters.
- property dist_params_structure
The tree structure of the distribution parameters.
- property entropy
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\mathbb{E}_p \log p \ =\ \frac12\left( \log(2\pi\sigma^2) + 1\right)\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
H (ndarray of floats) – A batch of entropy values.
- property hyperparams
The distribution hyperparameters.
- property kl_divergence
JIT-compiled function that computes the Kullback-Leibler divergence of a normal distribution \(q\) relative to another normal distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q -\log p\right) \ =\ \frac12\left( \log(\sigma_q^2) - \log(\sigma_p^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} - 1 \right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
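The closed-form Gaussian expressions for entropy, cross-entropy and KL divergence can be written directly in terms of mu and logvar. A per-dimension sketch with hypothetical names; for a diagonal multivariate normal one would presumably sum these over the last axis, which is an assumption about how coax aggregates.

```python
import jax.numpy as jnp


def normal_entropy(mu, logvar):
    # H = 1/2 (log(2 pi sigma^2) + 1); independent of mu
    return 0.5 * (jnp.log(2 * jnp.pi) + logvar + 1.0)


def normal_cross_entropy(mu_p, logvar_p, mu_q, logvar_q):
    # CE[p, q] = 1/2 (log(2 pi sigma_q^2) + ((mu_p - mu_q)^2 + sigma_p^2) / sigma_q^2)
    return 0.5 * (jnp.log(2 * jnp.pi) + logvar_q
                  + ((mu_p - mu_q) ** 2 + jnp.exp(logvar_p)) / jnp.exp(logvar_q))


def normal_kl_divergence(mu_p, logvar_p, mu_q, logvar_q):
    # KL[p, q] = 1/2 (log sigma_q^2 - log sigma_p^2 + ((mu_p - mu_q)^2 + sigma_p^2) / sigma_q^2 - 1)
    return 0.5 * (logvar_q - logvar_p
                  + ((mu_p - mu_q) ** 2 + jnp.exp(logvar_p)) / jnp.exp(logvar_q) - 1.0)


kl = normal_kl_divergence(mu_p=0.3, logvar_p=-0.5, mu_q=-1.0, logvar_q=0.7)
```

Subtracting the entropy formula from the cross-entropy formula reproduces the KL formula term by term, which is a convenient check on all three.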
- property log_proba
JIT-compiled function that evaluates log-probabilities.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
- Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
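For NormalDist, the log-probability of a variate is the Gaussian log-density. The sketch below assumes X is given in the raw real space (before any compactification) and that the dimensions are independent, so per-dimension log-densities are summed; the function name is hypothetical. It can be cross-checked against jax.scipy.stats.norm.logpdf.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def normal_log_proba(dist_params, X):
    # log N(x; mu, sigma^2) = -1/2 (log(2 pi) + log sigma^2 + (x - mu)^2 / sigma^2),
    # summed over the last axis for a diagonal Gaussian.
    mu, logvar = dist_params['mu'], dist_params['logvar']
    log_p = -0.5 * (jnp.log(2 * jnp.pi) + logvar + (X - mu) ** 2 / jnp.exp(logvar))
    return jnp.sum(log_p, axis=-1)


params = {'mu': jnp.array([[0.0, 1.0]]), 'logvar': jnp.array([[0.0, jnp.log(0.25)]])}
X = jnp.array([[0.5, 0.5]])
logP = normal_log_proba(params, X)
```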
- property mean
JIT-compiled function that generates differentiable means of the distribution, in this case simply \(\mu\).
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property mode
JIT-compiled function that generates differentiable modes of the distribution, which for a normal distribution is the same as the mean.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property sample
JIT-compiled function that generates differentiable variates using the reparametrization trick, i.e. \(x\sim\mathcal{N}(\mu,\sigma^2)\) is implemented as
\[\begin{split}\varepsilon\ &\sim\ \mathcal{N}(0,1) \\ x\ &=\ \mu + \sigma\,\varepsilon\end{split}\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
- Returns:
X (ndarray) – A batch of differentiable variates.
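The point of the reparametrization trick is that, for a fixed rng key, the sample is a deterministic and differentiable function of mu and logvar. A minimal sketch (hypothetical function name) that also differentiates through the sampling step:

```python
import jax
import jax.numpy as jnp


def normal_sample(dist_params, rng):
    # eps ~ N(0, 1), x = mu + sigma * eps with sigma = exp(logvar / 2)
    mu, logvar = dist_params['mu'], dist_params['logvar']
    eps = jax.random.normal(rng, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps


params = {'mu': jnp.zeros((4096, 1)), 'logvar': jnp.full((4096, 1), jnp.log(4.0))}
rng = jax.random.PRNGKey(0)
X = normal_sample(params, rng)  # sample standard deviation close to 2

# Gradients flow through the sample: d mean(X) / d mu_i = 1 / batch_size for each entry.
grads = jax.grad(lambda p: jnp.mean(normal_sample(p, rng)))(params)
```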
- property space
The gymnasium-style space that specifies the domain of the distribution.
- class coax.proba_dists.DiscretizedIntervalDist(space, num_bins=20, gumbel_softmax_tau=0.2)
A categorical distribution over a discretized interval.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'logits': array([...])}
which represents the (conditional) distribution parameters. The logits, denoted \(z\in\mathbb{R}^n\) with \(n\) equal to num_bins, are related to the categorical distribution parameters \(p\in\Delta^n\) via a softmax:
\[p_k\ =\ \text{softmax}_k(z)\ =\ \frac{\text{e}^{z_k}}{\sum_j\text{e}^{z_j}}\]
- Parameters:
space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution. The shape of the Box must satisfy prod(shape) == 1, i.e. it must describe a single interval.
num_bins (int, optional) – The number of equal-sized bins used in the discretization.
gumbel_softmax_tau (positive float, optional) – The parameter \(\tau\) specifies the sharpness of the Gumbel-softmax sampling (see the sample() method below). Smaller values of \(\tau\) yield variates closer to one-hot vectors, larger values yield smoother ones; a good value balances the trade-off between getting proper deterministic variates (i.e. one-hot vectors) and getting smooth differentiable variates.
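The discretization itself can be sketched in a few lines: num_bins equal-sized bins partition the interval, and each value belongs to exactly one bin. The helper names are hypothetical, and whether coax decodes a bin to its center or to some other point in the bin is not specified here.

```python
import jax.numpy as jnp


def bin_centers(low, high, num_bins=20):
    # Split [low, high] into num_bins equal-sized bins and return their midpoints.
    edges = jnp.linspace(low, high, num_bins + 1)
    return 0.5 * (edges[:-1] + edges[1:])


def value_to_bin(x, low, high, num_bins=20):
    # Index of the bin containing x; the upper edge x == high goes into the last bin.
    k = jnp.floor((x - low) / (high - low) * num_bins).astype(jnp.int32)
    return jnp.clip(k, 0, num_bins - 1)


centers = bin_centers(-1.0, 1.0, num_bins=4)
idx = value_to_bin(jnp.array([-1.0, 0.3, 1.0]), -1.0, 1.0, num_bins=4)
```

The bin index is what a one-hot (or almost-one-hot) variate of this distribution would encode.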
- postprocess_variate(rng, X, index=0, batch_mode=False)
The post-processor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
- Parameters:
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
- Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
- preprocess_variate(rng, X)
The pre-processor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
- property affine_transform
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
- Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) –
A value transform \(f\). If provided, the affine transformation is applied to the untransformed values \(f^{-1}(X)\) and the result is mapped back through \(f\), i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
- Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
- property cross_entropy
JIT-compiled function that computes the cross-entropy of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{CE}[p,q]\ =\ -\sum_k p_k \log q_k\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property default_priors
The default distribution parameters.
- property dist_params_structure
The tree structure of the distribution parameters.
- property entropy
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\sum_k p_k \log p_k\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
H (ndarray of floats) – A batch of entropy values.
- property hyperparams
The distribution hyperparameters.
- property kl_divergence
JIT-compiled function that computes the Kullback-Leibler divergence of a categorical distribution \(q\) relative to another categorical distribution \(p\):
\[\text{KL}[p,q]\ =\ -\sum_k p_k \left(\log q_k -\log p_k\right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property log_proba
JIT-compiled function that evaluates log-probabilities.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
- Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
- property mean
JIT-compiled function that generates differentiable means of the distribution. Strictly speaking, the mean of a categorical variable is not well defined. We opt for returning the raw probabilities: \(\text{mean}_k=p_k\).
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of would-be variates \(x\sim\text{Cat}(p)\). In contrast to the output of other methods, these aren’t true variates because they are not almost-one-hot encoded.
- property mode
JIT-compiled function that generates differentiable modes of the distribution, for which we use a similar trick as in Gumbel-softmax sampling (without the Gumbel noise):
\[\text{mode}_k\ =\ \text{softmax}_k\left( \frac{\log p_k}{\tau} \right)\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of modes of \(\text{Cat}(p)\). In order to ensure differentiability, these are not integers but almost-one-hot encoded versions thereof.
For example, instead of returning \(x=2\) as the mode of a 4-class categorical distribution, this method returns a vector like \(x=(0.05, 0.02, 0.86, 0.07)\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
- property sample
JIT-compiled function that generates differentiable variates using Gumbel-softmax sampling. \(x\sim\text{Cat}(p)\) is implemented as
\[\begin{split}u_k\ &\sim\ \text{Unif}(0, 1) \\ g_k\ &=\ -\log(-\log(u_k)) \\ x_k\ &=\ \text{softmax}_k\left( \frac{g_k + \log p_k}{\tau} \right)\end{split}\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
- Returns:
X (ndarray) – A batch of variates \(x\sim\text{Cat}(p)\). In order to ensure differentiability of the variates this is not an integer, but instead an almost-one-hot encoded version thereof.
For example, instead of sampling \(x=2\) from a 4-class categorical distribution, Gumbel-softmax will return a vector like \(x=[0.05, 0.02, 0.86, 0.07]\). The latter representation can be viewed as an almost-one-hot encoded version of the former.
- property space
The gymnasium-style space that specifies the domain of the distribution.
- class coax.proba_dists.EmpiricalQuantileDist(num_quantiles)
- postprocess_variate(rng, X, index=0, batch_mode=False)
The post-processor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
- Parameters:
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
- Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
- preprocess_variate(rng, X)
The pre-processor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
- property affine_transform
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
- Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) –
A value transform \(f\). If provided, the affine transformation is applied to the untransformed values \(f^{-1}(X)\) and the result is mapped back through \(f\), i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
- Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
- property cross_entropy
JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property default_priors
The default distribution parameters.
- property dist_params_structure
The tree structure of the distribution parameters.
- property entropy
JIT-compiled function that computes the entropy of the distribution.
\[H\ =\ -\mathbb{E}_p \log p\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
H (ndarray of floats) – A batch of entropy values.
- property hyperparams
The distribution hyperparameters.
- property kl_divergence
JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q -\log p\right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
- property log_proba
JIT-compiled function that evaluates log-probabilities.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
- Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
- property mean
JIT-compiled function that generates differentiable means of the distribution.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property mode
JIT-compiled function that generates differentiable modes of the distribution.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property sample
JIT-compiled function that generates differentiable variates.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property space
The gymnasium-style space that specifies the domain of the distribution.
- class coax.proba_dists.SquashedNormalDist(space, clip_logvar=None)
A differentiable squashed normal distribution.
The input dist_params to each of the functions is expected to be of the form:
dist_params = {'mu': array([...]), 'logvar': array([...])}
which represents the (conditional) distribution parameters. Here, mu is the mean \(\mu\) and logvar is the log-variance \(\log(\sigma^2)\) of the underlying normal distribution, i.e. before squashing.
- Parameters:
space (gymnasium.spaces.Box) – The gymnasium-style space that specifies the domain of the distribution.
clip_logvar (pair of floats, optional) – The range of values to allow for the log-variance of the distribution.
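"Squashing" means passing a Gaussian variate through a bounded, monotonic map so that every sample lands inside the Box. The sketch below assumes a tanh squashing followed by an affine rescale onto [low, high]; both the choice of tanh and the function names are hypothetical and are not taken from coax.

```python
import jax
import jax.numpy as jnp


def squash(x_real, low, high):
    # Hypothetical squashing map: tanh sends the reals to (-1, 1),
    # then an affine rescale maps (-1, 1) onto (low, high).
    return low + 0.5 * (high - low) * (jnp.tanh(x_real) + 1.0)


def squashed_normal_sample(dist_params, rng, low, high):
    # Reparametrized Gaussian draw in the raw space, then squashed into the Box.
    mu, logvar = dist_params['mu'], dist_params['logvar']
    x_real = mu + jnp.exp(0.5 * logvar) * jax.random.normal(rng, mu.shape)
    return squash(x_real, low, high)


low = jnp.array([-1.0, 0.0])
high = jnp.array([1.0, 10.0])
params = {'mu': jnp.zeros((8, 2)), 'logvar': jnp.zeros((8, 2))}
X = squashed_normal_sample(params, jax.random.PRNGKey(0), low, high)
```

Note that \(\mu\) maps to the midpoint of the Box only when \(\mu=0\); in general the mean of the squashed variable differs from the squashed \(\mu\).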
- postprocess_variate(rng, X, index=0, batch_mode=False)
The post-processor specific to variates drawn from this distribution.
This method provides the interface between differentiable, batched variates, i.e. outputs of sample() and mode(), and the provided gymnasium space.
- Parameters:
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
X (raw variates) – A batch of raw variates, i.e. in the same format as the outputs of sample() and mode().
index (int, optional) – The index to pick out from the batch. Note that this only applies if batch_mode=False.
batch_mode (bool, optional) – Whether to return a batch or a single instance.
- Returns:
x or X (clean variate) – A single clean variate or a batch thereof (if batch_mode=True). A variate is called clean if it is an instance of the gymnasium-style space, i.e. it satisfies x in self.space.
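The interaction of index and batch_mode can be illustrated with a minimal sketch. It is not coax's implementation. It assumes that raw variates lie in \([-1,1]\), the range of \(\tanh\), and that post-processing linearly rescales them onto scalar Box bounds low and high. The function name postprocess_sketch is hypothetical, and rng is omitted because this deterministic mapping does not need it.

```python
def postprocess_sketch(X, low, high, index=0, batch_mode=False):
    """Hypothetical post-processor for a batch of 1-D variates.

    X is a list of raw variates ASSUMED to lie in [-1, 1]; low/high are
    scalar Box bounds. Each variate is mapped linearly onto [low, high].
    """
    clean = [low + (x + 1.0) * 0.5 * (high - low) for x in X]
    if batch_mode:
        return clean       # the whole batch; index is ignored
    return clean[index]    # a single instance, picked out by index


X = [-1.0, 0.0, 1.0]
batch = postprocess_sketch(X, 0.0, 10.0, batch_mode=True)
single = postprocess_sketch(X, 0.0, 10.0, index=1)
```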
- preprocess_variate(rng, X)
The pre-processor to ensure that an instance of the space is processed into the same structure as variates drawn from this distribution, i.e. outputs of sample() and mode().
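Pre-processing can be viewed as the inverse direction: an element of the space is mapped back into the range of sample() and mode() outputs. The sketch below assumes a linear rescaling between scalar Box bounds and \([-1,1]\). This is a hypothetical stand-in and not necessarily what coax does. The function name preprocess_sketch is hypothetical, and rng is again unused.

```python
def preprocess_sketch(x, low, high):
    """Hypothetical pre-processor: map a clean value in [low, high] to [-1, 1]."""
    return 2.0 * (x - low) / (high - low) - 1.0


# Apply elementwise to a small batch of clean values from a Box(-2, 2).
batch = [preprocess_sketch(x, -2.0, 2.0) for x in (-2.0, 0.0, 1.0)]
```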
- property affine_transform
Transform the distribution \(\mathcal{D}\to\mathcal{D}'\) in such a way that its associated variables \(X\sim\mathcal{D}\) and \(X'\sim\mathcal{D}'\) are related via an affine transformation:
\[X' = X\times\text{scale} + \text{shift}\]
- Parameters:
dist_params (pytree with ndarray leaves) – The distribution parameters of the original distribution \(\mathcal{D}\).
scale (float or ndarray) – The multiplicative factor of the affine transformation.
shift (float or ndarray) – The additive shift of the affine transformation.
value_transform (ValueTransform, optional) –
The transform to apply to the values before the affine transform, i.e.
\[X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\]
- Returns:
dist_params (pytree with ndarray leaves) – The distribution parameters of the transformed distribution \(\mathcal{D}'\).
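The relation between variates can be illustrated directly on sample values. The sketch below applies \(X' = f\bigl(f^{-1}(X)\times\text{scale} + \text{shift}\bigr)\) elementwise. Without a transform, this reduces to \(X' = X\times\text{scale} + \text{shift}\). The pair f/f_inv is a hypothetical value transform (a signed log1p), not a specific coax ValueTransform. The sketch acts on variates and does not compute the transformed dist_params.

```python
import math


def f(y):
    """Hypothetical value transform: sign(y) * log(1 + |y|)."""
    return math.copysign(math.log1p(abs(y)), y)


def f_inv(x):
    """Inverse of f: sign(x) * (exp(|x|) - 1)."""
    return math.copysign(math.expm1(abs(x)), x)


def affine_variates(X, scale, shift, transform=None):
    """Apply X' = f(f^{-1}(X) * scale + shift) elementwise.

    transform is an optional (f, f_inv) pair; without it the map is the
    plain affine transformation X' = X * scale + shift.
    """
    if transform is None:
        return [x * scale + shift for x in X]
    fwd, inv = transform
    return [fwd(inv(x) * scale + shift) for x in X]


X = [-1.0, 0.0, 2.0]
plain = affine_variates(X, 2.0, 1.0)
identity = affine_variates(X, 1.0, 0.0, transform=(f, f_inv))
```

With scale 1 and shift 0, the transformed map returns the original variates, because \(f\) and \(f^{-1}\) cancel.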
- property cross_entropy
JIT-compiled function that computes the cross-entropy of a distribution \(q\) relative to another distribution \(p\):
\[\text{CE}[p,q]\ =\ -\mathbb{E}_p \log q \ =\ \frac12\left( \log(2\pi\sigma_q^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} \right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
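The closed-form cross-entropy above can be written as a scalar sketch in terms of mu and logvar, matching dist_params. It uses \(\log(2\pi\sigma_q^2) = \log 2\pi + \log\sigma_q^2\). The function name is hypothetical, and it handles one pair of parameters rather than a batch. Setting \(q=p\) recovers the entropy \(\frac12(\log(2\pi\sigma^2)+1)\), which makes a useful sanity check.

```python
import math


def gaussian_cross_entropy(mu_p, logvar_p, mu_q, logvar_q):
    """CE[p, q] = 1/2 * (log(2 pi sigma_q^2) + ((mu_p - mu_q)^2 + sigma_p^2) / sigma_q^2)."""
    var_p = math.exp(logvar_p)
    var_q = math.exp(logvar_q)
    return 0.5 * (math.log(2.0 * math.pi) + logvar_q
                  + ((mu_p - mu_q) ** 2 + var_p) / var_q)


ce_pp = gaussian_cross_entropy(0.3, 0.5, 0.3, 0.5)
ce_pq = gaussian_cross_entropy(0.0, 0.0, 1.0, 0.0)
```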
- property default_priors
The default distribution parameters.
- property dist_params_structure
The tree structure of the distribution parameters.
- property entropy
JIT-compiled function that computes the entropy of the distribution:
\[H\ =\ -\mathbb{E}_p \log p \ =\ \frac12\left( \log(2\pi\sigma^2) + 1\right)\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
H (ndarray of floats) – A batch of entropy values.
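Evaluated elementwise over a batch, the entropy formula above reduces to a single line. The sketch below takes only log-variances, because the formula does not depend on \(\mu\). The function name is hypothetical, and lists stand in for arrays.

```python
import math


def gaussian_entropy(logvars):
    """H = 1/2 * (log(2 pi sigma^2) + 1), evaluated per batch element."""
    return [0.5 * (math.log(2.0 * math.pi) + lv + 1.0) for lv in logvars]


H = gaussian_entropy([0.0, 2.0])
```

Increasing \(\log\sigma^2\) by 2 raises \(H\) by exactly 1.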
- property hyperparams
The distribution hyperparameters.
- property kl_divergence
JIT-compiled function that computes the Kullback-Leibler divergence of a distribution \(q\) relative to another distribution \(p\):
\[\text{KL}[p,q]\ =\ -\mathbb{E}_p \left(\log q -\log p\right) \ =\ \frac12\left( \log(\sigma_q^2) - \log(\sigma_p^2) + \frac{(\mu_p-\mu_q)^2+\sigma_p^2}{\sigma_q^2} - 1 \right)\]
- Parameters:
dist_params_p (pytree with ndarray leaves) – The distribution parameters of the base distribution \(p\).
dist_params_q (pytree with ndarray leaves) – The distribution parameters of the auxiliary distribution \(q\).
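The closed-form KL divergence above can be written as a scalar sketch, with parameters given as mu and logvar as in dist_params. The function name is hypothetical. For \(p=\mathcal{N}(0,1)\) and \(q=\mathcal{N}(1,1)\), the formula gives \(\frac12\bigl(0 - 0 + \frac{1+1}{1} - 1\bigr) = \frac12\). The divergence of a distribution from itself is zero.

```python
import math


def gaussian_kl(mu_p, logvar_p, mu_q, logvar_q):
    """KL[p, q] = 1/2 * (log sigma_q^2 - log sigma_p^2 + ((mu_p - mu_q)^2 + sigma_p^2) / sigma_q^2 - 1)."""
    var_p = math.exp(logvar_p)
    var_q = math.exp(logvar_q)
    return 0.5 * (logvar_q - logvar_p + ((mu_p - mu_q) ** 2 + var_p) / var_q - 1.0)


kl_pq = gaussian_kl(0.0, 0.0, 1.0, 0.0)
kl_pp = gaussian_kl(0.7, -1.0, 0.7, -1.0)
```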
- property log_proba
JIT-compiled function that evaluates log-probabilities.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
X (ndarray) – A batch of variates, e.g. a batch of actions \(a\) collected from experience.
- Returns:
logP (ndarray of floats) – A batch of log-probabilities associated with the provided variates.
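Consider a squashed variate \(x=\tanh(u)\) with \(u\sim\mathcal{N}(\mu,\sigma^2)\). The standard change-of-variables formula gives \(\log p(x) = \log\mathcal{N}(\operatorname{artanh} x;\mu,\sigma^2) - \log(1-x^2)\). The sketch below implements this formula for a single scalar in \((-1,1)\). It is an assumption that log_proba follows exactly this convention: for example, any extra Jacobian from rescaling to the Box bounds is not included. The function name is hypothetical.

```python
import math


def squashed_normal_log_proba(x, mu, logvar):
    """log p(x) for x = tanh(u), u ~ N(mu, exp(logvar)); requires -1 < x < 1."""
    u = math.atanh(x)
    # log-density of the underlying normal at u
    log_normal = -0.5 * (math.log(2.0 * math.pi) + logvar
                         + (u - mu) ** 2 / math.exp(logvar))
    # Jacobian correction: du/dx = 1 / (1 - x^2)
    return log_normal - math.log(1.0 - x * x)


# Sanity check: the density integrates to about 1 over (-1, 1) (midpoint rule).
n = 20000
h = 2.0 / n
total = sum(
    math.exp(squashed_normal_log_proba(-1.0 + (i + 0.5) * h, 0.3, 0.0)) * h
    for i in range(n)
)
```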
- property mean
JIT-compiled function that generates differentiable means of the distribution, in this case simply \(\tanh(\mu)\).
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
- property mode
JIT-compiled function that generates differentiable modes of the distribution, which for a normal distribution is the same as the mean.
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
- Returns:
X (ndarray) – A batch of differentiable variates.
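As documented above, both mean and mode reduce to squashing the location parameter, \(\tanh(\mu)\). This map is deterministic and differentiable, with \(\frac{d}{d\mu}\tanh(\mu) = 1-\tanh^2(\mu)\). The minimal elementwise sketch below uses hypothetical function names, with lists standing in for arrays. It checks the analytic derivative against a central finite difference.

```python
import math


def squashed_mode(mu):
    """Mean/mode as documented for this distribution: tanh(mu), elementwise."""
    return [math.tanh(m) for m in mu]


def squashed_mode_grad(mu):
    """Elementwise derivative d tanh(mu) / d mu = 1 - tanh(mu)^2."""
    return [1.0 - math.tanh(m) ** 2 for m in mu]


mu = [-2.0, 0.0, 0.5]
eps = 1e-6
# central finite difference of tanh at each mu
fd = [(math.tanh(m + eps) - math.tanh(m - eps)) / (2 * eps) for m in mu]
```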
- property sample
JIT-compiled function that generates differentiable variates using the reparametrization trick, i.e. \(x\sim\tanh(\mathcal{N}(\mu,\sigma^2))\) is implemented as
\[\begin{split}\varepsilon\ &\sim\ \mathcal{N}(0,1) \\ x\ &=\ \tanh(\mu + \sigma\,\varepsilon)\end{split}\]
- Parameters:
dist_params (pytree with ndarray leaves) – A batch of distribution parameters.
rng (PRNGKey) – A key for seeding the pseudo-random number generator.
- Returns:
X (ndarray) – A batch of differentiable variates.
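The reparametrization above can be sketched with Python's random module in place of a JAX PRNGKey. This substitution is an assumption made for portability; the library itself takes rng as a JAX key. The noise \(\varepsilon\) is drawn independently of the parameters, so each sample is a deterministic, differentiable function of mu and logvar, with \(\partial x/\partial\mu = 1-x^2\). The function name is hypothetical.

```python
import math
import random


def sample_squashed_normal(mu, logvar, seed):
    """Draw x = tanh(mu + sigma * eps), eps ~ N(0, 1), elementwise.

    sigma = exp(logvar / 2). Returns (x, eps) so the parameter-free
    noise can be reused, e.g. to check gradients by finite differences.
    """
    gen = random.Random(seed)
    eps = [gen.gauss(0.0, 1.0) for _ in mu]
    x = [math.tanh(m + math.exp(0.5 * lv) * e)
         for m, lv, e in zip(mu, logvar, eps)]
    return x, eps


mu = [0.0, 1.0, -0.5]
logvar = [0.0, -2.0, 1.0]
x, eps = sample_squashed_normal(mu, logvar, seed=42)
```

Because \(\varepsilon\) does not depend on mu, a finite difference in mu with the noise held fixed matches \(1-x^2\).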
- property space
The gymnasium-style space that specifies the domain of the distribution.