cherry.algorithms

cherry.algorithms.arguments.AlgorithmArguments

[Source]

Description

Utility functions to work with dataclass algorithms.

Example
import dataclasses

from cherry.algorithms.arguments import AlgorithmArguments

@dataclasses.dataclass
class MyNewAlgorithm(AlgorithmArguments):

    my_arg1: float = 0.0

    def update(self, my_arg1, **kwargs):
        pass

cherry.algorithms.a2c.A2C dataclass

[Source]

Description

Helper functions for implementing A2C.

A2C simply computes the gradient of the policy as follows:

$$ \nabla_\theta J(\theta) = \mathbb{E}\left[ \nabla_\theta \log \pi_\theta(a_t \mid s_t) \cdot A_t \right] $$

policy_loss(log_probs, advantages) staticmethod

Description

The policy loss of the Advantage Actor-Critic.

This function simply performs an element-wise multiplication and a mean reduction.

References
  1. Mnih et al. 2016. “Asynchronous Methods for Deep Reinforcement Learning.” arXiv [cs.LG].
Arguments
  • log_probs (tensor) - Log-density of the selected actions.
  • advantages (tensor) - Advantage of the action-state pairs.
Returns
  • (tensor) - The policy loss for the given arguments.
Example
advantages = replay.advantage()
log_probs = replay.log_prob()
loss = a2c.policy_loss(log_probs, advantages)

state_value_loss(values, rewards) staticmethod

Description

The state-value loss of the Advantage Actor-Critic.

This function is equivalent to a MSELoss.

References
  1. Mnih et al. 2016. “Asynchronous Methods for Deep Reinforcement Learning.” arXiv [cs.LG].
Arguments
  • values (tensor) - Predicted values for some states.
  • rewards (tensor) - Observed rewards for those states.
Returns
  • (tensor) - The value loss for the given arguments.
Example
values = replay.value()
rewards = replay.reward()
loss = a2c.state_value_loss(values, rewards)

cherry.algorithms.ddpg.DDPG dataclass

[Source]

Description

Utilities to implement deep deterministic policy gradient algorithms from [1].

References
  1. Lillicrap et al., "Continuous Control with Deep Reinforcement Learning", ICLR 2016.

state_value_loss(values, next_values, rewards, dones, gamma) staticmethod

Description

The discounted Bellman loss, computed as:

$$ \mathcal{L} = \big( r_t + \gamma\,(1 - d_t)\, v_{t+1} - v_t \big)^2 $$

Arguments
  • values (tensor) - State values for timestep t.
  • next_values (tensor) - State values for timestep t+1.
  • rewards (tensor) - Vector of rewards for timestep t.
  • dones (tensor) - Termination flag.
  • gamma (float) - Discount factor.
Returns
  • (tensor) - The state value loss above.
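
For illustration, here is a minimal, hedged sketch of wiring this loss to a sampled batch; batch, value_net, and target_value_net are placeholders you would define yourself, and only DDPG.state_value_loss comes from cherry.

import torch
from cherry.algorithms.ddpg import DDPG

# `batch`, `value_net`, and `target_value_net` are hypothetical placeholders.
values = value_net(batch.state())
with torch.no_grad():
    next_values = target_value_net(batch.next_state())
loss = DDPG.state_value_loss(
    values,
    next_values,
    batch.reward(),
    batch.done(),
    gamma=0.99,
)
loss.backward()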

cherry.algorithms.drq.DrQ dataclass

[Source]

Description

Utilities to implement DrQ from [1].

DrQ (Data-regularized Q) extends SAC to more efficiently train policies and action values from pixels.

References
  1. Kostrikov et al., "Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels", ICLR 2021.
Arguments
  • batch_size (int, optional, default=512) - Number of samples to get from the replay.
  • discount (float, optional, default=0.99) - Discount factor.
  • use_automatic_entropy_tuning (bool, optional, default=True) - Whether to optimize the entropy weight alpha.
  • policy_delay (int, optional, default=1) - Delay between policy updates.
  • target_delay (int, optional, default=1) - Delay between action value updates.
  • target_polyak_weight (float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.

__init__(self, batch_size: int = 512, discount: float = 0.99, use_automatic_entropy_tuning: bool = True, policy_delay: int = 2, target_delay: int = 2, target_polyak_weight: float = 0.995) -> None special

update(self, replay, policy, action_value, target_action_value, features, target_features, log_alpha, target_entropy, policy_optimizer, action_value_optimizer, features_optimizer, alpha_optimizer, update_policy = True, update_target = False, update_value = True, update_entropy = True, augmentation_transform = None, device = None, **kwargs)

Description

Implements a single DrQ update.

Arguments
  • replay (cherry.ExperienceReplay) - Offline replay to sample transitions from.
  • policy (cherry.nn.Policy) - Policy to optimize.
  • action_value (cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
  • target_action_value (cherry.nn.ActionValue) - Target action value.
  • features (torch.nn.Module) - Feature extractor for the policy and action value.
  • target_features (torch.nn.Module) - Feature extractor for the target action value.
  • log_alpha (torch.Tensor) - SAC's (log) entropy weight.
  • target_entropy (torch.Tensor) - SAC's target for the policy entropy (typically the negative of the action dimensionality, -|A|).
  • policy_optimizer (torch.optim.Optimizer) - Optimizer for the policy.
  • action_value_optimizer (torch.optim.Optimizer) - Optimizer for the action_value.
  • features_optimizer (torch.optim.Optimizer) - Optimizer for the features.
  • alpha_optimizer (torch.optim.Optimizer) - Optimizer for log_alpha.
  • update_policy (bool, optional, default=True) - Whether to update the policy.
  • update_target (bool, optional, default=False) - Whether to update the action value target network.
  • update_value (bool, optional, default=True) - Whether to update the action value.
  • update_entropy (bool, optional, default=True) - Whether to update the entropy weight.
  • augmentation_transform (torch.nn.Module, optional, default=None) - Data augmentation transform to augment image observations. Defaults to RandomShiftsAug(4) (as in the paper).
  • device (torch.device) - The device used to compute the update.
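
As a rough usage sketch (not taken from cherry's documentation), a training loop could call update() as below. Every name besides the DrQ dataclass (replay, networks, optimizers, entropy variables, num_updates, device) is a placeholder you would construct for your own agent, and driving update_target from target_delay is just one possible way to use the flag.

from cherry.algorithms.drq import DrQ

drq = DrQ(batch_size=256, target_polyak_weight=0.995)
for step in range(num_updates):  # all components below are placeholders
    drq.update(
        replay=replay,
        policy=policy,
        action_value=action_value,
        target_action_value=target_action_value,
        features=features,
        target_features=target_features,
        log_alpha=log_alpha,
        target_entropy=target_entropy,
        policy_optimizer=policy_optimizer,
        action_value_optimizer=action_value_optimizer,
        features_optimizer=features_optimizer,
        alpha_optimizer=alpha_optimizer,
        update_target=(step % drq.target_delay == 0),
        device=device,
    )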

cherry.algorithms.drqv2.DrQv2 dataclass

[Source]

Description

Utilities to implement DrQ-v2 from [1].

DrQ-v2 builds on DrQ but replaces the underlying SAC with TD3. It is noticeably faster in terms of wall-clock time and sample complexity.

References
  1. Yarats et al., "Mastering Visual Continuous Control: Improved Data-Augmented Reinforcement Learning", ICLR 2022.
Arguments
  • batch_size (int, optional, default=512) - Number of samples to get from the replay.
  • discount (float, optional, default=0.99) - Discount factor.
  • policy_delay (int, optional, default=1) - Delay between policy updates.
  • target_delay (int, optional, default=1) - Delay between action value updates.
  • target_polyak_weight (float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.
  • nsteps (int, optional, default=1) - Number of bootstrapping steps to compute the target values.
  • std_decay (float, optional, default=0.0) - Exponential decay rate of the policy's standard deviation. A reasonable value for DMC is 0.99997.
  • min_std (float, optional, default=0.1) - Minimum standard deviation for the policy.

__init__(self, batch_size: int = 512, discount: float = 0.99, policy_delay: int = 1, target_delay: int = 1, target_polyak_weight: float = 0.995, nsteps: int = 1, std_decay: float = 0.0, min_std: float = 0.1) -> None special

update(self, replay, policy, action_value, target_action_value, features, policy_optimizer, action_value_optimizer, features_optimizer, update_policy = True, update_target = True, update_value = True, augmentation_transform = None, device = None, **kwargs)

Description

Implements a single DrQ-v2 update.

Arguments
  • replay (cherry.ExperienceReplay) - Offline replay to sample transitions from.
  • policy (cherry.nn.Policy) - Policy to optimize.
  • action_value (cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
  • target_action_value (cherry.nn.ActionValue) - Target action value.
  • features (torch.nn.Module) - Feature extractor for the policy and action value.
  • policy_optimizer (torch.optim.Optimizer) - Optimizer for the policy.
  • features_optimizer (torch.optim.Optimizer) - Optimizer for the features.
  • update_policy (bool, optional, default=True) - Whether to update the policy.
  • update_target (bool, optional, default=True) - Whether to update the action value target network.
  • update_value (bool, optional, default=True) - Whether to update the action value.
  • augmentation_transform (torch.nn.Module, optional, default=None) - Data augmentation transform to augment image observations. Defaults to RandomShiftsAug(4) (as in the paper).
  • device (torch.device) - The device used to compute the update.
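
A similarly hedged sketch for DrQ-v2, highlighting its distinct knobs (n-step targets and the decaying exploration noise); all names besides the DrQv2 dataclass are placeholders.

from cherry.algorithms.drqv2 import DrQv2

drqv2 = DrQv2(nsteps=3, std_decay=0.99997, min_std=0.1)
drqv2.update(
    replay=replay,                      # placeholder cherry.ExperienceReplay
    policy=policy,
    action_value=action_value,
    target_action_value=target_action_value,
    features=features,
    policy_optimizer=policy_optimizer,
    action_value_optimizer=action_value_optimizer,
    features_optimizer=features_optimizer,
    device=device,
)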

cherry.algorithms.ppo.PPO dataclass

[Source]

Description

Utilities to implement PPO from [1].

The idea behind PPO is to cheaply approximate TRPO's trust-region with the following objective:

$$ \mathcal{L}(\theta) = \mathbb{E}\left[ \min\left( \frac{\pi_\theta(a \mid s)}{\pi_{\text{old}}(a \mid s)}\, A,\; \operatorname{clip}\left( \frac{\pi_\theta(a \mid s)}{\pi_{\text{old}}(a \mid s)},\, 1 - \epsilon,\, 1 + \epsilon \right) A \right) \right] $$

where $\pi_\theta$ is the current policy and $\pi_{\text{old}}$ is the policy used to collect the online replay's data.

References
  1. Schulman et al., “Proximal Policy Optimization Algorithms”, ArXiv 2017.
Arguments

Note: the following arguments were optimized for continuous control on PyBullet / MuJoCo.

  • num_steps (int, optional, default=320) - Number of PPO gradient steps in a single update.
  • batch_size (int, optional, default=512) - Number of samples to get from the replay.
  • policy_clip (float, optional, default=0.2) - Clip constant for the policy.
  • value_clip (float, optional, default=0.2) - Clip constant for state value function.
  • value_weight (float, optional, default=0.5) - Scaling factor of the state value function penalty.
  • entropy_weight (float, optional, default=0.0) - Scaling factor of the entropy penalty.
  • discount (float, optional, default=0.99) - Discount factor.
  • gae_tau (float, optional, default=0.95) - Bias-variance trade-off for the generalized advantage estimator.
  • gradient_norm (float, optional, default=0.5) - Maximum gradient norm.
  • eps (float, optional, default=1e-8) - Numerical stability constant.

__init__(self, num_steps: int = 320, batch_size: float = 64, policy_clip: float = 0.2, value_clip: float = 0.2, value_weight: float = 0.5, entropy_weight: float = 0.0, discount: float = 0.99, gae_tau: float = 0.95, gradient_norm: float = 0.5, eps: float = 1e-08) -> None special

policy_loss(new_log_probs, old_log_probs, advantages, clip = 0.1) staticmethod

Description

The clipped policy loss of Proximal Policy Optimization.

Arguments
  • new_log_probs (tensor) - The log-density of actions from the target policy.
  • old_log_probs (tensor) - The log-density of actions from the behaviour policy.
  • advantages (tensor) - Advantage of the actions.
  • clip (float, optional, default=0.1) - The clipping coefficient.
Returns
  • loss (tensor) - The clipped policy loss for the given arguments.

Example

advantage = ch.pg.generalized_advantage(
    GAMMA,
    TAU,
    replay.reward(),
    replay.done(),
    replay.value(),
    next_state_value,
)
new_densities = policy(replay.state())
new_logprobs = new_densities.log_prob(replay.action())
loss = policy_loss(
    new_logprobs,
    replay.logprob().detach(),
    advantage.detach(),
    clip=0.2,
)

state_value_loss(new_values, old_values, rewards, clip = 0.1) staticmethod

Description

The clipped state-value loss of Proximal Policy Optimization.

Arguments
  • new_values (tensor) - State values from the optimized value function.
  • old_values (tensor) - State values from the reference value function.
  • rewards (tensor) - Observed rewards.
  • clip (float, optional, default=0.1) - The clipping coefficient.
Returns
  • loss (tensor) - The clipped value loss for the given arguments.
Example
values = v_function(batch.state())
value_loss = ppo.state_value_loss(
    values,
    batch.value().detach(),
    batch.reward(),
    clip=0.2,
)

update(self, replay, policy, optimizer, state_value, **kwargs)

Description

Implements a single PPO update.

Arguments
  • replay (cherry.ExperienceReplay) - Online replay to sample transitions from.
  • policy (cherry.nn.Policy) - Policy to optimize.
  • state_value (cherry.nn.StateValue) - State value function V(s).
  • optimizer (torch.optim.Optimizer) - Optimizer for the policy.
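
A hedged sketch of how a single call might look; collect_rollout is a hypothetical helper that returns a cherry.ExperienceReplay of freshly collected on-policy transitions, and the networks and optimizer are placeholders.

from cherry.algorithms.ppo import PPO

ppo = PPO(policy_clip=0.2, value_clip=0.2)
replay = collect_rollout(env, policy)   # hypothetical helper returning on-policy data
ppo.update(
    replay=replay,
    policy=policy,
    optimizer=optimizer,
    state_value=state_value,
)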

cherry.algorithms.td3.TD3 dataclass

[Source]

Description

Utilities to implement TD3 from [1].

The main idea behind TD3 is to extend DDPG with twin action value functions. Namely, the action values are computed with:

$$ \hat{Q}(s, a) = \min\big( Q_1(s, \mu(s) + \epsilon),\; Q_2(s, \mu(s) + \epsilon) \big) $$

where $\mu$ is a deterministic policy and $\epsilon$ is (typically) sampled from a Gaussian distribution. See cherry.nn.Twin to easily implement such twin Q-functions.

The authors also suggest delaying the updates to the policy: this simply boils down to applying one policy update every N times the action value function is updated. This implementation also supports delaying updates to the action value and its target network.

References
  1. Fujimoto et al., "Addressing Function Approximation Error in Actor-Critic Methods", ICML 2018.
Arguments
  • batch_size (int, optional, default=512) - Number of samples to get from the replay.
  • discount (float, optional, default=0.99) - Discount factor.
  • policy_delay (int, optional, default=1) - Delay between policy updates.
  • target_delay (int, optional, default=1) - Delay between action value updates.
  • target_polyak_weight (float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.
  • nsteps (int, optional, default=1) - Number of bootstrapping steps to compute the target values.

__init__(self, batch_size: int = 512, discount: float = 0.99, policy_delay: int = 1, target_delay: int = 1, target_polyak_weight: float = 0.995, nsteps: int = 1) -> None special

update(self, replay, policy, action_value, target_action_value, policy_optimizer, action_value_optimizer, update_policy = True, update_target = True, update_value = True, device = None, **kwargs)

Description

Implements a single TD3 update.

Arguments
  • replay (cherry.ExperienceReplay) - Offline replay to sample transitions from.
  • policy (cherry.nn.Policy) - Policy to optimize.
  • action_value (cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
  • target_action_value (cherry.nn.ActionValue) - Target action value.
  • policy_optimizer (torch.optim.Optimizer) - Optimizer for the policy.
  • action_value_optimizer (torch.optim.Optimizer) - Optimizer for the action_value.
  • update_policy (bool, optional, default=True) - Whether to update the policy.
  • update_target (bool, optional, default=True) - Whether to update the action value target network.
  • update_value (bool, optional, default=True) - Whether to update the action value.
  • device (torch.device) - The device used to compute the update.
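
As a hedged sketch, the update_* flags can be driven externally to realize the delays described above; everything except the TD3 dataclass is a placeholder, and using the step counter this way is just one possible scheme.

from cherry.algorithms.td3 import TD3

td3 = TD3(policy_delay=2, target_delay=2)
for step in range(num_updates):  # placeholders throughout
    td3.update(
        replay=replay,
        policy=policy,
        action_value=action_value,
        target_action_value=target_action_value,
        policy_optimizer=policy_optimizer,
        action_value_optimizer=action_value_optimizer,
        update_policy=(step % td3.policy_delay == 0),
        update_target=(step % td3.target_delay == 0),
        device=device,
    )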

cherry.algorithms.trpo.TRPO dataclass

Description

Helper functions for implementing Trust-Region Policy Optimization.

Recall that TRPO strives to solve the following objective:

$$ \max_\theta \; \mathbb{E}\left[ \frac{\pi_\theta(a \mid s)}{\pi_{\text{old}}(a \mid s)}\, A(s, a) \right] \quad \text{s.t.} \quad \mathbb{E}\big[ \operatorname{KL}\big( \pi_{\text{old}}(\cdot \mid s) \,\Vert\, \pi_\theta(\cdot \mid s) \big) \big] \leq \delta $$

conjugate_gradient(Ax, b, num_iterations = 10, tol = 1e-10, eps = 1e-08) staticmethod

[Source]

Description

Computes the solution x to the linear system Ax = b using the conjugate gradient algorithm.

Credit

Adapted from Kai Arulkumaran's implementation, with additions inspired by John Schulman's implementation.

References
  1. Nocedal and Wright. 2006. "Numerical Optimization, 2nd edition". Springer.
  2. Shewchuk et al. 1994. “An Introduction to the Conjugate Gradient Method without the Agonizing Pain.” CMU.
Arguments
  • Ax (callable) - Given a vector x, computes A@x.
  • b (tensor or list) - The reference vector.
  • num_iterations (int, optional, default=10) - Number of conjugate gradient iterations.
  • tol (float, optional, default=1e-10) - Tolerance for proposed solution.
  • eps (float, optional, default=1e-8) - Numerical stability constant.
Returns
  • x (tensor or list) - The solution to Ax = b, as a list if b is a list else a tensor.
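
A small self-contained sketch (not from cherry's documentation) that solves a random symmetric positive-definite system, to illustrate the callable-based interface; it assumes TRPO is importable from its documented module path.

import torch
from cherry.algorithms.trpo import TRPO

A = torch.randn(5, 5)
A = A @ A.t() + 5.0 * torch.eye(5)   # symmetric positive-definite matrix

def Ax(x):
    return A @ x                     # matrix-vector product exposed as a callable

b = torch.randn(5)
x = TRPO.conjugate_gradient(Ax, b, num_iterations=20)
print((A @ x - b).norm())            # residual should be close to zero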

hessian_vector_product(loss, parameters, damping = 1e-05) staticmethod

[Source]

Description

Returns a callable that computes the product of the Hessian of loss (w.r.t. parameters) with another vector, using Pearlmutter's trick.

Note that parameters and the argument of the callable can be tensors or lists of tensors.

References
  1. Pearlmutter, B. A. 1994. “Fast Exact Multiplication by the Hessian.” Neural Computation.
Arguments
  • loss (tensor) - The loss of which to compute the Hessian.
  • parameters (tensor or list) - The tensors to take the gradient with respect to.
  • damping (float, optional, default=1e-5) - Damping of the Hessian-vector product.
Returns
  • hvp(other) (callable) - A function to compute the Hessian-vector product, given a vector or list other.
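
A minimal sketch using a quadratic loss whose Hessian is known in closed form, to illustrate the returned callable; again this is illustrative, not taken from cherry's documentation.

import torch
from cherry.algorithms.trpo import TRPO

w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()                             # Hessian of this loss is 2 * I
hvp = TRPO.hessian_vector_product(loss, w, damping=0.0)
v = torch.ones(3)
print(hvp(v))                                     # approximately 2 * v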

line_search(params_init, params_update, model, stop_criterion, initial_stepsize = 1.0, backtrack_factor = 0.5, max_iterations = 15) staticmethod

[Source]

Description

Computes line-search for model parameters given a parameter update and a stopping criterion.

Credit

Adapted from Kai Arulkumaran's implementation, with additions inspired by John Schulman's implementation.

References
  1. Nocedal and Wright. 2006. "Numerical Optimization, 2nd edition". Springer.
Arguments
  • params_init (tensor or iterable) - Initial parameter values.
  • params_update (tensor or iterable) - Update direction.
  • model (Module) - The model to be updated.
  • stop_criterion (callable) - Given a model, decides whether to stop the line-search.
  • initial_stepsize (float, optional, default=1.0) - Initial stepsize of search.
  • backtrack_factor (float, optional, default=0.5) - Backtracking factor.
  • max_iterations (int, optional, default=15) - Max number of backtracking iterations.
Returns
  • new_model (Module) - The updated model if line-search is successful, else the model with initial parameter values.
Example
def ls_criterion(new_policy):
    new_density = new_policy(states)
    new_kl = kl_divergence(old_density, new_density).mean()
    new_loss = - qvalue(new_density.sample()).mean()
    return new_loss < policy_loss and new_kl < max_kl

with torch.no_grad():
    policy = trpo.line_search(
        params_init=policy.parameters(),
        params_update=step,
        model=policy,
        stop_criterion=ls_criterion,
    )

policy_loss(new_log_probs, old_log_probs, advantages) staticmethod

[Source]

Description

The policy loss for Trust-Region Policy Optimization.

This is also known as the surrogate loss.

References
  1. Schulman et al. 2015. “Trust Region Policy Optimization.” ICML 2015.
Arguments
  • new_log_probs (tensor) - The log-density of actions from the target policy.
  • old_log_probs (tensor) - The log-density of actions from the behaviour policy.
  • advantages (tensor) - Advantage of the actions.
Returns
  • (tensor) - The policy loss for the given arguments.
Example
advantage = ch.pg.generalized_advantage(GAMMA,
                                        TAU,
                                        replay.reward(),
                                        replay.done(),
                                        replay.value(),
                                        next_state_value)
new_densities = policy(replay.state())
new_logprobs = new_densities.log_prob(replay.action())
loss = policy_loss(new_logprobs,
                   replay.logprob().detach(),
                   advantage.detach())

cherry.algorithms.sac.SAC dataclass

[Source]

Description

Utilities to implement SAC from [1].

The update() function updates the function approximators in the following order:

  1. Entropy weight update.
  2. Action-value update.
  3. State-value update. (Optional, cf. below)
  4. Policy update.

Note that most recent implementations of SAC omit step 3. above by using the Bellman residual instead of modelling a state-value function. For an example of such an implementation, refer to this link.

References
  1. Haarnoja et al., "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor", ICML 2018.
Arguments
  • batch_size (int, optional, default=512) - Number of samples to get from the replay.
  • discount (float, optional, default=0.99) - Discount factor.
  • use_automatic_entropy_tuning (bool, optional, default=True) - Whether to optimize the entropy weight alpha.
  • policy_delay (int, optional, default=1) - Delay between policy updates.
  • target_delay (int, optional, default=1) - Delay between action value updates.
  • target_polyak_weight (float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.

__init__(self, batch_size: int = 512, discount: float = 0.99, use_automatic_entropy_tuning: bool = True, policy_delay: int = 2, target_delay: int = 2, target_polyak_weight: float = 0.01) -> None special

action_value_loss(value, next_value, rewards, dones, gamma) staticmethod

Description

The action-value loss of the Soft Actor-Critic.

value should be the value of the current state-action pair, estimated via the Q-function. next_value is the expected value of the next state; it can be estimated via a V-function, or alternatively by computing the Q-value of the next observed state-action pair. In the latter case, make sure that the action is sampled according to the current policy, not the one used to gather the data.

Arguments
  • value (tensor) - Action values of the actual transition.
  • next_value (tensor) - State values of the resulting state.
  • rewards (tensor) - Observed rewards of the transition.
  • dones (tensor) - Which states were terminal.
  • gamma (float) - Discount factor.
Returns
  • (tensor) - The action-value loss for the given arguments.
Example
value = qf(batch.state(), batch.action().detach())
next_value = target_vf(batch.next_state())
loss = action_value_loss(value,
                         next_value,
                         batch.reward(),
                         batch.done(),
                         gamma=0.99)

policy_loss(log_probs, q_curr, alpha = 1.0) staticmethod

Description

The policy loss of the Soft Actor-Critic.

New actions are sampled from the target policy, and those are used to compute the Q-values. While we should back-propagate through the Q-values to the policy parameters, we shouldn't use that gradient to optimize the Q parameters. This is often avoided by either using a target Q function, or by zero-ing out the gradients of the Q function parameters.

Arguments
  • log_probs (tensor) - Log-density of the selected actions.
  • q_curr (tensor) - Q-values of state-action pairs.
  • alpha (float, optional, default=1.0) - Entropy weight.
Returns
  • (tensor) - The policy loss for the given arguments.
Example
densities = policy(batch.state())
actions = densities.sample()
log_probs = densities.log_prob(actions)
q_curr = q_function(batch.state(), actions)
loss = policy_loss(log_probs, q_curr, alpha=0.1)

update(self, replay, policy, action_value, target_action_value, log_alpha, target_entropy, policy_optimizer, features_optimizer, action_value_optimizer, alpha_optimizer, features = None, target_features = None, update_policy = True, update_target = False, update_value = True, update_entropy = True, device = None, **kwargs)

Description

Implements a single SAC update.

Arguments
  • replay (cherry.ExperienceReplay) - Offline replay to sample transitions from.
  • policy (cherry.nn.Policy) - Policy to optimize.
  • action_value (cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
  • target_action_value (cherry.nn.ActionValue) - Target action value.
  • log_alpha (torch.Tensor) - SAC's (log) entropy weight.
  • target_entropy (torch.Tensor) - SAC's target for the policy entropy (typically the negative of the action dimensionality, -|A|).
  • policy_optimizer (torch.optim.Optimizer) - Optimizer for the policy.
  • action_value_optimizer (torch.optim.Optimizer) - Optimizer for the action_value.
  • features_optimizer (torch.optim.Optimizer) - Optimizer for the features.
  • alpha_optimizer (torch.optim.Optimizer) - Optimizer for log_alpha.
  • features (torch.nn.Module, optional, default=None) - Feature extractor for the policy and action value.
  • target_features (torch.nn.Module, optional, default=None) - Feature extractor for the target action value.
  • update_policy (bool, optional, default=True) - Whether to update the policy.
  • update_target (bool, optional, default=False) - Whether to update the action value target network.
  • update_value (bool, optional, default=True) - Whether to update the action value.
  • update_entropy (bool, optional, default=True) - Whether to update the entropy weight.
  • device (torch.device) - The device used to compute the update.
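
A hedged sketch of a single call to update(); the replay, networks, optimizers, action_size, and device below are placeholders you would create for your own agent, and the target entropy uses the common -|A| heuristic mentioned above.

import torch
from cherry.algorithms.sac import SAC

sac = SAC(batch_size=256)
log_alpha = torch.zeros(1, requires_grad=True)          # learnable log entropy weight
target_entropy = -torch.tensor(float(action_size))      # common heuristic: -|A|
sac.update(
    replay=replay,
    policy=policy,
    action_value=action_value,
    target_action_value=target_action_value,
    log_alpha=log_alpha,
    target_entropy=target_entropy,
    policy_optimizer=policy_optimizer,
    features_optimizer=features_optimizer,
    action_value_optimizer=action_value_optimizer,
    alpha_optimizer=alpha_optimizer,
    device=device,
)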