cherry.distributions

cherry.distributions.Categorical

[Source]

Description

Similar to torch.distributions.Categorical, but reshapes tensors of N samples into (N, 1)-shaped tensors.

Arguments

Identical to torch.distributions.Categorical.

Example
dist = Categorical(logits=torch.randn(bsz, action_size))
actions = dist.sample()  # shape: bsz x 1
log_probs = dist.log_prob(actions)  # shape: bsz x 1
deterministic_action = dist.mode()

mode(self)

Description

Returns the mode of the Categorical distribution (i.e., the argmax over probabilities).

cherry.distributions.Normal

[Source]

Description

Similar to PyTorch's Independent(Normal(loc, std)): when computing log-densities or the entropy, we sum over the last dimension.

This is typically used to compute log-probabilities of N-dimensional actions sampled from a multivariate Gaussian with diagonal covariance.

Arguments

Identical to torch.distributions.Normal.

Example
normal = Normal(torch.zeros(bsz, action_size), torch.ones(bsz, action_size))
actions = normal.sample()
log_probs = normal.log_prob(actions)  # shape: bsz x 1
entropies = normal.entropy()  # shape: bsz x 1
deterministic_action = normal.mode()

mode(self)

Description

Returns the mode of the Normal distribution (i.e., its mean).

cherry.distributions.TanhNormal

[Source]

Description

Implements a Normal distribution followed by a Tanh transformation, often used with Soft Actor-Critic (SAC).

This implementation also exposes sample_and_log_prob and rsample_and_log_prob, which return both samples and log-densities. The log-densities are computed using the pre-activation values for numerical stability.

References
  1. Haarnoja et al. 2018. “Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor.” arXiv [cs.LG].
  2. Haarnoja et al. 2018. “Soft Actor-Critic Algorithms and Applications.” arXiv [cs.LG].
  3. Vitchyr Pong's RLkit.

Example

mean = torch.zeros(5)
std = torch.ones(5)
dist = TanhNormal(mean, std)
samples = dist.rsample()
logprobs = dist.log_prob(samples)  # Numerically unstable :(
samples, logprobs = dist.rsample_and_log_prob()  # Stable :)

__init__(self, normal_mean, normal_std) special

Arguments
  • normal_mean (tensor) - Mean of the Normal distribution.
  • normal_std (tensor) - Standard deviation of the Normal distribution.

mean(self)

Description

Returns the mean of the TanhNormal distribution (i.e., tanh(normal.mean)).

mode(self)

Description

Returns the mode of the TanhNormal distribution (i.e., its mean).
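Example

A minimal sketch of deterministic action selection, reusing dist from the class example above:

deterministic_action = dist.mode()  # i.e., the tanh of the underlying Normal's mean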

rsample_and_log_prob(self)

Description

Similar to sample_and_log_prob but with reparameterized samples.
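Example

A minimal sketch of a SAC-style actor update, where reparameterized samples let gradients flow from the critic into the policy; policy, q_function, state, and alpha are hypothetical placeholders, not part of cherry:

mean, std = policy(state)  # hypothetical policy network outputs
dist = TanhNormal(mean, std)
actions, log_probs = dist.rsample_and_log_prob()
actor_loss = (alpha * log_probs - q_function(state, actions)).mean()
actor_loss.backward()  # gradients reach the policy through the reparameterized samples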

sample_and_log_prob(self)

Description

Samples from the TanhNormal and computes the log-density of the samples in a numerically stable way.

Returns
  • value (tensor) - samples from the TanhNormal.
  • log_prob (tensor) - log-probabilities of the samples.
Example
tanh_normal = TanhNormal(torch.zeros(bsz, action_size), torch.ones(bsz, action_size))
actions, log_probs = tanh_normal.sample_and_log_prob()

cherry.distributions.Reparameterization

[Source]

Description

Unifies the interface between distributions that support rsample and those that do not.

When calling sample(), this class checks whether density provides rsample() and falls back to the density's sample() if it does not.

References
  1. Kingma and Welling. 2013. “Auto-Encoding Variational Bayes.” arXiv [stat.ML].
Example
density = Normal(mean, std)
reparam = Reparameterization(density)
sample = reparam.sample()  # Uses Normal.rsample()
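
As a sketch of the fallback path described above, wrapping a distribution that does not support reparameterized sampling (bsz and action_size are placeholders):

density = Categorical(logits=torch.randn(bsz, action_size))
reparam = Reparameterization(density)
sample = reparam.sample()  # Categorical has no reparameterized sampling, so this falls back to sample()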

__init__(self, density) special

Arguments
  • density (Distribution) - The distribution to wrap.

cherry.distributions.ActionDistribution

[Source]

Description

A helper module that automatically chooses the proper policy distribution based on the Gym environment's action_space.

For Discrete action spaces, it uses a Categorical distribution; otherwise, it uses a Normal with a diagonal covariance matrix.

This class makes it possible to write a single policy body that is compatible with a variety of environments.

Example
env = gym.make('CartPole-v1')
action_dist = ActionDistribution(env)

__init__(self, env, logstd = None, use_probs = False, reparam = False) special

Arguments
  • env (Environment) - Gym environment for which actions will be sampled.
  • logstd (float/tensor, optional, default=0) - The log standard deviation for the Normal distribution.
  • use_probs (bool, optional, default=False) - Whether to use probabilities or logits for the Categorical case.
  • reparam (bool, optional, default=False) - Whether to use reparameterization in the Normal case.
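
Example

A sketch of a continuous-control setup; 'Pendulum-v1' is only an illustrative environment choice, and the call mapping the policy's raw output to a distribution assumes the module builds the distribution in its forward pass (treat it as a sketch, not a documented signature):

env = gym.make('Pendulum-v1')
action_dist = ActionDistribution(env, logstd=0.0, reparam=True)
policy_output = policy(state)  # hypothetical policy network output
density = action_dist(policy_output)  # assumed: builds a Normal for this continuous action space
action = density.sample()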