cherry.distributions¶
cherry.distributions.Categorical¶
Description¶
Similar to torch.distributions.Categorical, but reshapes tensors of N samples into (N, 1)-shaped tensors.
Arguments¶
Identical to torch.distributions.Categorical.
Example¶
dist = Categorical(logits=torch.randn(bsz, action_size))
actions = dist.sample() # shape: bsz x 1
log_probs = dist.log_prob(actions) # shape: bsz x 1
deterministic_action = dist.mode()
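For comparison, a minimal sketch with stock PyTorch (plain torch.distributions.Categorical, not cherry's wrapper), showing the flat shapes that cherry reshapes to (N, 1):

import torch
from torch import distributions as td

stock = td.Categorical(logits=torch.randn(bsz, action_size))
a = stock.sample()        # shape: bsz (flat, not bsz x 1)
lp = stock.log_prob(a)    # shape: bsz (flat, not bsz x 1)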
cherry.distributions.Normal¶
Description¶
Similar to PyTorch's Independent(Normal(loc, std)): when computing log-densities or the entropy, we sum over the last dimension.
This is typically used to compute log-probabilities of N-dimensional actions sampled from a multivariate Gaussian with diagonal covariance.
Arguments¶
Identical to torch.distributions.Normal.
Example¶
normal = Normal(torch.zeros(bsz, action_size), torch.ones(bsz, action_size))
actions = normal.sample()
log_probs = normal.log_prob(actions) # shape: bsz x 1
entropies = normal.entropy() # shape: bsz x 1
deterministic_action = normal.mode()
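For reference, the summed log-densities can be reproduced with stock PyTorch by wrapping the base distribution in Independent; a minimal sketch (not cherry's implementation, and the resulting shapes are flat rather than bsz x 1):

import torch
from torch.distributions import Independent, Normal as TorchNormal

base = TorchNormal(torch.zeros(bsz, action_size), torch.ones(bsz, action_size))
ind = Independent(base, 1)    # treat the last dimension as the event: log-densities are summed over it
x = ind.sample()
lp = ind.log_prob(x)          # shape: bsz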
cherry.distributions.TanhNormal¶
Description¶
Implements a Normal distribution followed by a Tanh transform, often used with Soft Actor-Critic.
This implementation also exposes sample_and_log_prob and rsample_and_log_prob, which return both samples and log-densities.
The log-densities are computed using the pre-activation values for numerical stability.
References¶
- Haarnoja et al. 2018. “Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.” arXiv [cs.LG].
- Haarnoja et al. 2018. “Soft Actor-Critic Algorithms and Applications.” arXiv [cs.LG].
- Vitchyr Pong's RLkit.
Example¶
mean = th.zeros(5)
std = th.ones(5)
dist = TanhNormal(mean, std)
samples = dist.rsample()
logprobs = dist.log_prob(samples) # Numerically unstable :(
samples, logprobs = dist.rsample_and_log_prob() # Stable :)
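The stability comes from evaluating the change-of-variables correction on the pre-tanh values. Below is a minimal sketch of one common formulation (as used in many SAC implementations; cherry's exact expression may differ), based on the identity log(1 - tanh(u)^2) = 2 (log 2 - u - softplus(-2u)):

import math
import torch
import torch.nn.functional as F

def tanh_log_prob(normal, pre_tanh):
    # log p(a) = log p(u) - log(1 - tanh(u)^2), evaluated from the pre-activation u.
    # If the base Normal is element-wise, sum over the last dimension afterwards.
    correction = 2.0 * (math.log(2.0) - pre_tanh - F.softplus(-2.0 * pre_tanh))
    return normal.log_prob(pre_tanh) - correction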
__init__(self, normal_mean, normal_std) special¶
Arguments¶
normal_mean (tensor) - Mean of the Normal distribution.
normal_std (tensor) - Standard deviation of the Normal distribution.
rsample_and_log_prob(self)¶
Description¶
Similar to sample_and_log_prob
but with reparameterized samples.
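A minimal usage sketch (shapes are hypothetical), showing that reparameterized samples keep gradients flowing back to the distribution parameters:

mean = torch.zeros(bsz, action_size, requires_grad=True)
dist = TanhNormal(mean, torch.ones(bsz, action_size))
actions, log_probs = dist.rsample_and_log_prob()
log_probs.sum().backward()    # gradients reach mean through the reparameterized samples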
sample_and_log_prob(self)¶
Description¶
Samples from the TanhNormal and computes the log-density of the samples in a numerically stable way.
Returns¶
value (tensor) - Samples from the TanhNormal.
log_prob (tensor) - Log-probabilities of the samples.
Example¶
tanh_normal = TanhNormal(torch.zeros(bsz, action_size), torch.ones(bsz, action_size))
actions, log_probs = tanh_normal.sample_and_log_prob()
cherry.distributions.Reparameterization¶
Description¶
Unifies the interface of distributions that support rsample and those that do not.
When calling sample(), this class checks whether density has an rsample() member, and falls back to sample() if it does not.
References¶
- Kingma and Welling. 2013. “Auto-Encoding Variational Bayes.” arXiv [stat.ML].
Example¶
density = Normal(mean, std)
reparam = Reparameterization(density)
sample = reparam.sample() # Uses Normal.rsample()
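The fallback path can be illustrated with a distribution that lacks rsample(); for instance, Categorical does not implement it, so sampling goes through sample() (a hedged sketch of the behavior described above):

density = Categorical(logits=torch.randn(bsz, action_size))
reparam = Reparameterization(density)
sample = reparam.sample()  # Falls back to Categorical.sample()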
cherry.distributions.ActionDistribution¶
Description¶
A helper module to automatically choose the proper policy distribution based on the Gym environment's action_space.
For Discrete action spaces, it uses a Categorical distribution; otherwise it uses a Normal with a diagonal covariance matrix.
This class makes it possible to write a single policy body that is compatible with a variety of environments.
Example¶
env = gym.make('CartPole-v1')
action_dist = ActionDistribution(env)
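A hedged usage sketch, assuming the module is called on the policy head's output to build the per-state distribution (the policy network below is hypothetical; the exact call signature may differ):

env = gym.make('CartPole-v1')
action_dist = ActionDistribution(env)
logits = policy(state)        # hypothetical policy network output
dist = action_dist(logits)    # Categorical for Discrete action spaces
action = dist.sample()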
__init__(self, env, logstd=None, use_probs=False, reparam=False) special¶
Arguments¶
env (Environment) - Gym environment for which actions will be sampled.
logstd (float/tensor, optional, default=0) - The log standard deviation for the Normal distribution.
use_probs (bool, optional, default=False) - Whether to use probabilities or logits for the Categorical case.
reparam (bool, optional, default=False) - Whether to use reparameterization in the Normal case.
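A hedged sketch for a continuous-control environment, where the Normal case and the logstd/reparam options apply (the chosen values are illustrative):

env = gym.make('Pendulum-v1')
action_dist = ActionDistribution(env, logstd=-0.5, reparam=True)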