cherry.distributions

Description

A set of common distributions.

Reparameterization

Reparameterization(density)

Description

Unifies the interface for distributions that support rsample() and those that do not.

When sample() is called, this class checks whether density has an rsample() member and, if it does not, falls back to density.sample().
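As a rough illustration of this fallback, here is a pure-Python sketch. ReparameterizationSketch, ToyNormal and ToyCategorical are hypothetical names that stand in for Reparameterization and torch.distributions objects so the code runs without PyTorch; this is not cherry's source. One assumption worth flagging: every torch.distributions.Distribution subclass defines an rsample() method (it raises NotImplementedError when unsupported), so the sketch tests the has_rsample flag rather than the mere presence of the method. ToyNormal.rsample() uses the reparameterization trick from reference 1.

```python
import random


class ReparameterizationSketch:
    """Hypothetical stand-in for Reparameterization; not cherry's source."""

    def __init__(self, density):
        self.density = density

    def sample(self, *args, **kwargs):
        # torch.distributions objects advertise reparameterized sampling via
        # the has_rsample flag; prefer rsample() when it is available.
        if getattr(self.density, "has_rsample", False):
            return self.density.rsample(*args, **kwargs)
        # Otherwise fall back to the plain (non-differentiable) sampler.
        return self.density.sample(*args, **kwargs)


class ToyNormal:
    """Hypothetical scalar Gaussian exposing both samplers."""
    has_rsample = True

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def rsample(self):
        # Reparameterization trick: the randomness lives in eps ~ N(0, 1),
        # so the sample is a deterministic function of mean and std.
        eps = random.gauss(0.0, 1.0)
        return self.mean + self.std * eps

    def sample(self):
        return random.gauss(self.mean, self.std)


class ToyCategorical:
    """Hypothetical discrete distribution with no reparameterized sampler."""
    has_rsample = False

    def __init__(self, probs):
        self.probs = probs

    def sample(self):
        # Draw an index with probability proportional to its weight.
        return random.choices(range(len(self.probs)), weights=self.probs)[0]
```

Because rsample() writes the sample as mean + std * eps, it stays differentiable with respect to mean and std once those are tensors, which is what allows gradients to flow through sampled values.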

References

  1. Kingma and Welling. 2013. “Auto-Encoding Variational Bayes.” arXiv [stat.ML].

Arguments

density (Distribution): The distribution to wrap.

Example

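import torch as th
from torch.distributions import Normal
from cherry.distributions import Reparameterization
density = Normal(th.zeros(3), th.ones(3))
reparam = Reparameterization(density)
sample = reparam.sample()  # Uses Normal.rsample()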

ActionDistribution

ActionDistribution(env, logstd=None, use_probs=False, reparam=False)

Description

A helper module that automatically chooses the appropriate policy distribution based on the action_space of a Gym environment.

For Discrete action spaces, it uses a Categorical distribution; otherwise, it uses a Normal distribution with a diagonal covariance matrix.

This makes it possible to write a single policy body that is compatible with a variety of environments.
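The two cases can be illustrated, very roughly, by the log-density formulas involved. The pure-Python sketch below is hypothetical (action_log_prob, is_discrete and params are illustrative names, not cherry's API) and works on plain lists rather than torch tensors: a Categorical over logits uses a log-softmax, and a Normal with diagonal covariance factorizes into a sum of independent per-dimension Gaussian log-densities, parameterized here by a log standard deviation.

```python
import math


def categorical_log_prob(logits, action):
    # Log-softmax with max subtraction so large logits do not overflow exp().
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return logits[action] - log_z


def diagonal_normal_log_prob(mean, logstd, action):
    # A diagonal covariance means independent dimensions, so the joint
    # log-density is the sum of per-dimension Gaussian log-densities.
    total = 0.0
    for mu, log_sigma, a in zip(mean, logstd, action):
        z = (a - mu) / math.exp(log_sigma)
        total += -0.5 * z * z - log_sigma - 0.5 * math.log(2 * math.pi)
    return total


def action_log_prob(is_discrete, params, action):
    """Hypothetical dispatch mirroring the Discrete / continuous choice."""
    if is_discrete:
        # params: a list of logits; action: an integer index.
        return categorical_log_prob(params, action)
    # params: (mean, logstd) lists; action: a list of floats.
    mean, logstd = params
    return diagonal_normal_log_prob(mean, logstd, action)
```

A policy body can then produce either logits or a mean vector, and the dispatch decides how actions are scored; ActionDistribution plays that role with real torch distributions.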

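Arguments

env (gym.Env): Gym environment whose action_space determines which distribution is used.
logstd (optional): Log standard deviation of the Normal distribution used for continuous action spaces.
use_probs (bool, optional, default=False): Whether the Categorical distribution is parameterized by probabilities instead of logits.
reparam (bool, optional, default=False): Whether samples from the Normal distribution are reparameterized.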

Example

import gym
from cherry.distributions import ActionDistribution
env = gym.make('CartPole-v1')
action_dist = ActionDistribution(env)

TanhNormal

TanhNormal(normal_mean, normal_std)

Description

Implements a Normal distribution followed by a tanh transformation, as often used with Soft Actor-Critic.

This implementation also exposes sample_and_log_prob() and rsample_and_log_prob(), which return both samples and their log-densities. The log-densities are computed from the pre-activation (pre-tanh) values for numerical stability.
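To make the stability remark concrete, here is a scalar pure-Python sketch of the change-of-variables formula log p(y) = log N(u; mean, std) - log(1 - tanh(u)^2) for y = tanh(u). All function names are hypothetical and this is not cherry's implementation. The naive route must recover u = atanh(y) from the squashed sample, which fails once tanh saturates to exactly ±1.0 in floating point; keeping u avoids this. The sketch additionally rewrites log(1 - tanh(u)^2) as 2(log 2 - u - softplus(-2u)), an exact algebraic identity; whether cherry uses this form or, for example, a small additive epsilon is left open here.

```python
import math


def normal_log_prob(u, mean, std):
    # Log-density of a scalar Gaussian N(mean, std^2) at u.
    z = (u - mean) / std
    return -0.5 * z * z - math.log(std) - 0.5 * math.log(2 * math.pi)


def softplus(x):
    # Overflow-safe log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def log_det_tanh(u):
    # log(1 - tanh(u)^2) rewritten as 2 * (log 2 - u - softplus(-2u)); it never
    # forms 1 - tanh(u)^2, so it cannot hit log(0) for large |u|.
    return 2.0 * (math.log(2.0) - u - softplus(-2.0 * u))


def tanh_normal_log_prob_from_preactivation(u, mean, std):
    # Change of variables for y = tanh(u): log p(y) = log N(u) - log|dy/du|.
    return normal_log_prob(u, mean, std) - log_det_tanh(u)


def tanh_normal_log_prob_from_sample(y, mean, std):
    # Naive route: recover u = atanh(y) first. When tanh has saturated to
    # exactly +/-1.0, math.atanh raises ValueError (math domain error).
    u = math.atanh(y)
    return normal_log_prob(u, mean, std) - math.log(1.0 - y * y)
```

For example, at u = 30 the pre-activation route returns a finite value, while math.tanh(30.0) already evaluates to exactly 1.0, so the naive route raises a ValueError from math.atanh.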

Credit

Adapted from Vitchyr Pong's RLkit: https://github.com/vitchyr/rlkit/blob/master/rlkit/torch/distributions.py

References

  1. Haarnoja et al. 2018. “Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.” arXiv [cs.LG].
  2. Haarnoja et al. 2018. “Soft Actor-Critic Algorithms and Applications.” arXiv [cs.LG].

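Arguments

normal_mean (tensor): Mean of the Normal distribution before the tanh.
normal_std (tensor): Standard deviation of the Normal distribution before the tanh.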

Example

import torch as th
from cherry.distributions import TanhNormal
mean = th.zeros(5)
std = th.ones(5)
dist = TanhNormal(mean, std)
samples = dist.rsample()
logprobs = dist.log_prob(samples)  # Numerically unstable :(
samples, logprobs = dist.rsample_and_log_prob()  # Stable :)