cherry.algorithms¶
cherry.algorithms.arguments.AlgorithmArguments
¶
cherry.algorithms.a2c.A2C
dataclass
¶
Description¶
Helper functions for implementing A2C.
A2C simply computes the gradient of the policy as follows:
$$\nabla_\theta J(\theta) = \mathbb{E}\left[ \nabla_\theta \log \pi_\theta(a_t \vert s_t) \, A(s_t, a_t) \right],$$
where $\pi_\theta$ is the policy and $A(s_t, a_t)$ is the advantage of taking action $a_t$ in state $s_t$.
policy_loss(log_probs, advantages)
staticmethod
¶
Description¶
The policy loss of the Advantage Actor-Critic.
This function simply performs an element-wise multiplication and a mean reduction.
References¶
- Mnih et al. 2016. “Asynchronous Methods for Deep Reinforcement Learning.” arXiv [cs.LG].
Arguments¶
log_probs
(tensor) - Log-density of the selected actions.
advantages
(tensor) - Advantage of the action-state pairs.
Returns¶
- (tensor) - The policy loss for the given arguments.
Example¶
advantages = replay.advantage()
log_probs = replay.log_prob()
loss = a2c.policy_loss(log_probs, advantages)
state_value_loss(values, rewards)
staticmethod
¶
Description¶
The state-value loss of the Advantage Actor-Critic.
This function is equivalent to a MSELoss.
References¶
- Mnih et al. 2016. “Asynchronous Methods for Deep Reinforcement Learning.” arXiv [cs.LG].
Arguments¶
values
(tensor) - Predicted values for some states.
rewards
(tensor) - Observed rewards for those states.
Returns¶
- (tensor) - The value loss for the given arguments.
Example¶
values = replay.value()
rewards = replay.reward()
loss = a2c.state_value_loss(values, rewards)
cherry.algorithms.ddpg.DDPG
dataclass
¶
Description¶
Utilities to implement deep deterministic policy gradient algorithms from [1].
References¶
- Lillicrap et al., "Continuous Control with Deep Reinforcement Learning", ICLR 2016.
state_value_loss(values, next_values, rewards, dones, gamma)
staticmethod
¶
Description¶
The discounted Bellman loss, computed as:
$$\big\Vert v(s_t) - \big( r_t + \gamma \, (1 - \textrm{done}_t) \, v(s_{t+1}) \big) \big\Vert^2.$$
Arguments¶
values
(tensor) - State values for timestep t.
next_values
(tensor) - State values for timestep t+1.
rewards
(tensor) - Vector of rewards for timestep t.
dones
(tensor) - Termination flag.
gamma
(float) - Discount factor.
Returns¶
- (tensor) - The state value loss above.
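Example¶
A minimal usage sketch, assuming qf, target_qf, target_policy, and a sampled batch of transitions are defined elsewhere (these names are placeholders, not part of the API):
from cherry.algorithms.ddpg import DDPG

# Bootstrap with the target networks, as in DDPG.
values = qf(batch.state(), batch.action())
next_values = target_qf(batch.next_state(), target_policy(batch.next_state()))
loss = DDPG.state_value_loss(
    values,
    next_values,
    batch.reward(),
    batch.done(),
    gamma=0.99,
)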
cherry.algorithms.drq.DrQ
dataclass
¶
Description¶
Utilities to implement DrQ from [1].
DrQ (Data-regularized Q) extends SAC to more efficiently train policies and action values from pixels.
References¶
- Kostrikov et al., "Image Augmentation Is All You Need: Regularizing Deep Reinforcement Learning from Pixels", ICLR 2021.
Arguments¶
batch_size
(int, optional, default=512) - Number of samples to get from the replay.
discount
(float, optional, default=0.99) - Discount factor.
use_automatic_entropy_tuning
(bool, optional, default=True) - Whether to optimize the entropy weight alpha.
policy_delay
(int, optional, default=1) - Delay between policy updates.
target_delay
(int, optional, default=1) - Delay between action value updates.
target_polyak_weight
(float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.
__init__(self, batch_size: int = 512, discount: float = 0.99, use_automatic_entropy_tuning: bool = True, policy_delay: int = 2, target_delay: int = 2, target_polyak_weight: float = 0.995) -> None
special
¶
update(self, replay, policy, action_value, target_action_value, features, target_features, log_alpha, target_entropy, policy_optimizer, action_value_optimizer, features_optimizer, alpha_optimizer, update_policy = True, update_target = False, update_value = True, update_entropy = True, augmentation_transform = None, device = None, **kwargs)
¶
Description¶
Implements a single DrQ update.
Arguments¶
replay
(cherry.ExperienceReplay) - Offline replay to sample transitions from.
policy
(cherry.nn.Policy) - Policy to optimize.
action_value
(cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
target_action_value
(cherry.nn.ActionValue) - Target action value.
features
(torch.nn.Module) - Feature extractor for the policy and action value.
target_features
(torch.nn.Module) - Feature extractor for the target action value.
log_alpha
(torch.Tensor) - SAC's (log) entropy weight.
target_entropy
(torch.Tensor) - SAC's target for the policy entropy (typically the negative of the action dimension).
policy_optimizer
(torch.optim.Optimizer) - Optimizer for the policy.
action_value_optimizer
(torch.optim.Optimizer) - Optimizer for the action_value.
features_optimizer
(torch.optim.Optimizer) - Optimizer for the features.
alpha_optimizer
(torch.optim.Optimizer) - Optimizer for log_alpha.
update_policy
(bool, optional, default=True) - Whether to update the policy.
update_target
(bool, optional, default=False) - Whether to update the action value target network.
update_value
(bool, optional, default=True) - Whether to update the action value.
update_entropy
(bool, optional, default=True) - Whether to update the entropy weight.
augmentation_transform
(torch.nn.Module, optional, default=None) - Data augmentation transform for image observations. Defaults to RandomShiftsAug(4), as in the paper.
device
(torch.device) - The device used to compute the update.
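Example¶
A minimal sketch of a DrQ training loop, assuming the policy, twin action value, feature extractor, their target copies, and the replay have been constructed elsewhere (policy, action_value, features, action_size, num_updates, and friends are placeholders, not part of the API):
import torch
from cherry.algorithms.drq import DrQ

drq = DrQ()
log_alpha = torch.zeros(1, requires_grad=True)
target_entropy = -float(action_size)  # standard SAC heuristic
policy_optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
action_value_optimizer = torch.optim.Adam(action_value.parameters(), lr=3e-4)
features_optimizer = torch.optim.Adam(features.parameters(), lr=3e-4)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)
for step in range(num_updates):
    drq.update(
        replay=replay,
        policy=policy,
        action_value=action_value,
        target_action_value=target_action_value,
        features=features,
        target_features=target_features,
        log_alpha=log_alpha,
        target_entropy=target_entropy,
        policy_optimizer=policy_optimizer,
        action_value_optimizer=action_value_optimizer,
        features_optimizer=features_optimizer,
        alpha_optimizer=alpha_optimizer,
        update_target=(step % drq.target_delay == 0),
    )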
cherry.algorithms.drqv2.DrQv2
dataclass
¶
Description¶
Utilities to implement DrQ-v2 from [1].
DrQ-v2 builds on DrQ but replaces the underlying SAC with TD3. It is noticeably faster in terms of wall-clock time and sample complexity.
References¶
- Yarats et al., "Mastering Visual Continuous Control: Improved Data-Augmented Reinforcement Learning", ICLR 2022.
Arguments¶
batch_size
(int, optional, default=512) - Number of samples to get from the replay.
discount
(float, optional, default=0.99) - Discount factor.
policy_delay
(int, optional, default=1) - Delay between policy updates.
target_delay
(int, optional, default=1) - Delay between action value updates.
target_polyak_weight
(float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.
nsteps
(int, optional, default=1) - Number of bootstrapping steps to compute the target values.
std_decay
(float, optional, default=0.0) - Exponential decay rate of the policy's standard deviation. A reasonable value for DMC is 0.99997.
min_std
(float, optional, default=0.1) - Minimum standard deviation for the policy.
__init__(self, batch_size: int = 512, discount: float = 0.99, policy_delay: int = 1, target_delay: int = 1, target_polyak_weight: float = 0.995, nsteps: int = 1, std_decay: float = 0.0, min_std: float = 0.1) -> None
special
¶
update(self, replay, policy, action_value, target_action_value, features, policy_optimizer, action_value_optimizer, features_optimizer, update_policy = True, update_target = True, update_value = True, augmentation_transform = None, device = None, **kwargs)
¶
Description¶
Implements a single DrQ-v2 update.
Arguments¶
replay
(cherry.ExperienceReplay) - Offline replay to sample transitions from.
policy
(cherry.nn.Policy) - Policy to optimize.
action_value
(cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
target_action_value
(cherry.nn.ActionValue) - Target action value.
features
(torch.nn.Module) - Feature extractor for the policy and action value.
policy_optimizer
(torch.optim.Optimizer) - Optimizer for the policy.
action_value_optimizer
(torch.optim.Optimizer) - Optimizer for the action_value.
features_optimizer
(torch.optim.Optimizer) - Optimizer for the features.
update_policy
(bool, optional, default=True) - Whether to update the policy.
update_target
(bool, optional, default=True) - Whether to update the action value target network.
update_value
(bool, optional, default=True) - Whether to update the action value.
augmentation_transform
(torch.nn.Module, optional, default=None) - Data augmentation transform for image observations. Defaults to RandomShiftsAug(4), as in the paper.
device
(torch.device) - The device used to compute the update.
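Example¶
A minimal sketch, analogous to the DrQ example above but without the entropy terms; the networks, optimizers, replay, and num_updates are assumed to exist and the names are placeholders:
from cherry.algorithms.drqv2 import DrQv2

drqv2 = DrQv2()
for step in range(num_updates):
    drqv2.update(
        replay=replay,
        policy=policy,
        action_value=action_value,
        target_action_value=target_action_value,
        features=features,
        policy_optimizer=policy_optimizer,
        action_value_optimizer=action_value_optimizer,
        features_optimizer=features_optimizer,
        update_target=(step % drqv2.target_delay == 0),
    )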
cherry.algorithms.ppo.PPO
dataclass
¶
Description¶
Utilities to implement PPO from [1].
The idea behind PPO is to cheaply approximate TRPO's trust region with the following clipped objective:
$$\mathcal{L}(\theta) = \mathbb{E}\left[ \min\left( \frac{\pi_\theta(a \vert s)}{\pi_{\text{old}}(a \vert s)} A(s, a),\ \operatorname{clip}\left( \frac{\pi_\theta(a \vert s)}{\pi_{\text{old}}(a \vert s)},\ 1 - \epsilon,\ 1 + \epsilon \right) A(s, a) \right) \right],$$
where $\pi_\theta$ is the current policy and $\pi_{\text{old}}$ is the policy used to collect the online replay's data.
References¶
- Schulman et al., “Proximal Policy Optimization Algorithms”, ArXiv 2017.
Arguments¶
Note: the following arguments were optimized for continuous control on PyBullet / MuJoCo.
num_steps
(int, optional, default=320) - Number of PPO gradient steps in a single update.
batch_size
(int, optional, default=64) - Number of samples to get from the replay.
policy_clip
(float, optional, default=0.2) - Clip constant for the policy.
value_clip
(float, optional, default=0.2) - Clip constant for the state value function.
value_weight
(float, optional, default=0.5) - Scaling factor of the state value function penalty.
entropy_weight
(float, optional, default=0.0) - Scaling factor of the entropy penalty.
discount
(float, optional, default=0.99) - Discount factor.
gae_tau
(float, optional, default=0.95) - Bias-variance trade-off for the generalized advantage estimator.
gradient_norm
(float, optional, default=0.5) - Maximum gradient norm.
eps
(float, optional, default=1e-8) - Numerical stability constant.
__init__(self, num_steps: int = 320, batch_size: float = 64, policy_clip: float = 0.2, value_clip: float = 0.2, value_weight: float = 0.5, entropy_weight: float = 0.0, discount: float = 0.99, gae_tau: float = 0.95, gradient_norm: float = 0.5, eps: float = 1e-08) -> None
special
¶
policy_loss(new_log_probs, old_log_probs, advantages, clip = 0.1)
staticmethod
¶
Description¶
The clipped policy loss of Proximal Policy Optimization.
Arguments¶
new_log_probs
(tensor) - The log-density of actions from the target policy.
old_log_probs
(tensor) - The log-density of actions from the behaviour policy.
advantages
(tensor) - Advantage of the actions.
clip
(float, optional, default=0.1) - The clipping coefficient.
Returns¶
- loss (tensor) - The clipped policy loss for the given arguments.
Example¶
advantage = ch.pg.generalized_advantage(
    GAMMA,
    TAU,
    replay.reward(),
    replay.done(),
    replay.value(),
    next_state_value,
)
new_densities = policy(replay.state())
new_logprobs = new_densities.log_prob(replay.action())
loss = policy_loss(
    new_logprobs,
    replay.logprob().detach(),
    advantage.detach(),
    clip=0.2,
)
state_value_loss(new_values, old_values, rewards, clip = 0.1)
staticmethod
¶
Description¶
The clipped state-value loss of Proximal Policy Optimization.
Arguments¶
new_values
(tensor) - State values from the optimized value function.
old_values
(tensor) - State values from the reference value function.
rewards
(tensor) - Observed rewards.
clip
(float, optional, default=0.1) - The clipping coefficient.
Returns¶
- loss (tensor) - The clipped value loss for the given arguments.
Example¶
values = v_function(batch.state())
value_loss = ppo.state_value_loss(
    values,
    batch.value().detach(),
    batch.reward(),
    clip=0.2,
)
update(self, replay, policy, optimizer, state_value, **kwargs)
¶
Description¶
Implements a single PPO update.
Arguments¶
replay
(cherry.ExperienceReplay) - Online replay to sample transitions from.
policy
(cherry.nn.Policy) - Policy to optimize.
state_value
(cherry.nn.StateValue) - State value function V(s).
optimizer
(torch.optim.Optimizer) - Optimizer for the policy.
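Example¶
A minimal sketch, assuming a policy, a state value network, and an online replay collected with that policy are available (the names below are placeholders, not part of the API):
import torch
from cherry.algorithms.ppo import PPO

ppo = PPO()
# A single optimizer over both policy and value parameters.
optimizer = torch.optim.Adam(
    list(policy.parameters()) + list(state_value.parameters()),
    lr=3e-4,
)
ppo.update(replay, policy, optimizer, state_value=state_value)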
cherry.algorithms.td3.TD3
dataclass
¶
Description¶
Utilities to implement TD3 from [1].
The main idea behind TD3 is to extend DDPG with twin action value functions. Namely, the action values are computed with:
$$\min\big( Q_1(s_{t+1}, a_{t+1}),\ Q_2(s_{t+1}, a_{t+1}) \big), \qquad a_{t+1} = \mu(s_{t+1}) + \varepsilon,$$
where $\mu$ is a deterministic policy and $\varepsilon$ is (typically) sampled from a Gaussian distribution. See cherry.nn.Twin to easily implement such twin Q-functions.
The authors also suggest to delay the updates to the policy. This simply boils down to applying 1 policy update every N times the action value function is updated. This implementation also supports delaying updates to the action value and its target network.
References¶
- Fujimoto et al., "Addressing Function Approximation Error in Actor-Critic Methods", ICML 2018.
Arguments¶
batch_size
(int, optional, default=512) - Number of samples to get from the replay.
discount
(float, optional, default=0.99) - Discount factor.
policy_delay
(int, optional, default=1) - Delay between policy updates.
target_delay
(int, optional, default=1) - Delay between action value updates.
target_polyak_weight
(float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.
nsteps
(int, optional, default=1) - Number of bootstrapping steps to compute the target values.
__init__(self, batch_size: int = 512, discount: float = 0.99, policy_delay: int = 1, target_delay: int = 1, target_polyak_weight: float = 0.995, nsteps: int = 1) -> None
special
¶
update(self, replay, policy, action_value, target_action_value, policy_optimizer, action_value_optimizer, update_policy = True, update_target = True, update_value = True, device = None, **kwargs)
¶
Description¶
Implements a single TD3 update.
Arguments¶
replay
(cherry.ExperienceReplay) - Offline replay to sample transitions from.
policy
(cherry.nn.Policy) - Policy to optimize.
action_value
(cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
target_action_value
(cherry.nn.ActionValue) - Target action value.
policy_optimizer
(torch.optim.Optimizer) - Optimizer for the policy.
action_value_optimizer
(torch.optim.Optimizer) - Optimizer for the action_value.
update_policy
(bool, optional, default=True) - Whether to update the policy.
update_target
(bool, optional, default=True) - Whether to update the action value target network.
update_value
(bool, optional, default=True) - Whether to update the action value.
device
(torch.device) - The device used to compute the update.
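Example¶
A minimal sketch of a TD3 training loop, with the delayed policy updates handled through update_policy; the networks, optimizers, replay, and num_updates are assumed to exist and the names are placeholders:
from cherry.algorithms.td3 import TD3

td3 = TD3(policy_delay=2)
for step in range(num_updates):
    td3.update(
        replay=replay,
        policy=policy,
        action_value=action_value,
        target_action_value=target_action_value,
        policy_optimizer=policy_optimizer,
        action_value_optimizer=action_value_optimizer,
        update_policy=(step % td3.policy_delay == 0),
    )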
cherry.algorithms.trpo.TRPO
dataclass
¶
Description¶
Helper functions for implementing Trust-Region Policy Optimization.
Recall that TRPO strives to solve the following constrained objective:
$$\max_\theta \ \mathbb{E}\left[ \frac{\pi_\theta(a \vert s)}{\pi_{\text{old}}(a \vert s)} A(s, a) \right] \quad \text{subject to} \quad \mathbb{E}\big[ \operatorname{KL}(\pi_{\text{old}} \Vert \pi_\theta) \big] \le \delta,$$
where $\pi_\theta$ is the current policy, $\pi_{\text{old}}$ is the data-collecting policy, and $\delta$ is the size of the trust region.
conjugate_gradient(Ax, b, num_iterations = 10, tol = 1e-10, eps = 1e-08)
staticmethod
¶
Description¶
Computes the solution $x$ of $Ax = b$ using the conjugate gradient algorithm.
Credit¶
Adapted from Kai Arulkumaran's implementation, with additions inspired from John Schulman's implementation.
References¶
- Nocedal and Wright. 2006. "Numerical Optimization, 2nd edition". Springer.
- Shewchuk et al. 1994. “An Introduction to the Conjugate Gradient Method without the Agonizing Pain.” CMU.
Arguments¶
Ax
(callable) - Given a vector x, computes A@x.
b
(tensor or list) - The reference vector.
num_iterations
(int, optional, default=10) - Number of conjugate gradient iterations.
tol
(float, optional, default=1e-10) - Tolerance for proposed solution.
eps
(float, optional, default=1e-8) - Numerical stability constant.
Returns¶
x
(tensor or list) - The solution to Ax = b, as a list if b is a list else a tensor.
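Example¶
A small self-contained sketch that solves a symmetric positive-definite 3x3 system:
import torch
from cherry.algorithms.trpo import TRPO

A = torch.tensor([[4.0, 1.0, 0.0],
                  [1.0, 3.0, 0.0],
                  [0.0, 0.0, 2.0]])
b = torch.tensor([1.0, 2.0, 3.0])

def Ax(x):
    return A @ x  # matrix-vector product

x = TRPO.conjugate_gradient(Ax, b, num_iterations=10)
# x approximately satisfies A @ x = b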
hessian_vector_product(loss, parameters, damping = 1e-05)
staticmethod
¶
Description¶
Returns a callable that computes the product of the Hessian of loss (w.r.t. parameters) with another vector, using Pearlmutter's trick.
Note that parameters and the argument of the callable can be tensors or list of tensors.
References¶
- Pearlmutter, B. A. 1994. “Fast Exact Multiplication by the Hessian.” Neural Computation.
Arguments¶
loss
(tensor) - The loss of which to compute the Hessian.
parameters
(tensor or list) - The tensors to take the gradient with respect to.
damping
(float, optional, default=1e-5) - Damping of the Hessian-vector product.
Returns¶
hvp(other)
(callable) - A function to compute the Hessian-vector product, given a vector or list other.
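Example¶
A hedged sketch of the classic TRPO use: build the Hessian-vector product of the mean KL divergence between a fixed reference density and the current policy density, then evaluate it on a vector with the same shapes as the parameters (policy, states, and old_density are assumed to exist and are placeholders):
import torch
from torch.distributions import kl_divergence
from cherry.algorithms.trpo import TRPO

new_density = policy(states)
kl = kl_divergence(old_density, new_density).mean()
hvp = TRPO.hessian_vector_product(kl, list(policy.parameters()))
Fv = hvp([torch.ones_like(p) for p in policy.parameters()])
The resulting hvp callable can also be passed as the Ax argument of conjugate_gradient above to compute natural gradient steps.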
line_search(params_init, params_update, model, stop_criterion, initial_stepsize = 1.0, backtrack_factor = 0.5, max_iterations = 15)
staticmethod
¶
Description¶
Computes line-search for model parameters given a parameter update and a stopping criterion.
Credit¶
Adapted from Kai Arulkumaran's implementation, with additions inspired from John Schulman's implementation.
References¶
- Nocedal and Wright. 2006. "Numerical Optimization, 2nd edition". Springer.
Arguments¶
params_init
(tensor or iterable) - Initial parameter values.
params_update
(tensor or iterable) - Update direction.
model
(Module) - The model to be updated.
stop_criterion
(callable) - Given a model, decides whether to stop the line-search.
initial_stepsize
(float, optional, default=1.0) - Initial stepsize of the search.
backtrack_factor
(float, optional, default=0.5) - Backtracking factor.
max_iterations
(int, optional, default=15) - Max number of backtracking iterations.
Returns¶
new_model
(Module) - The updated model if line-search is successful, else the model with initial parameter values.
Example¶
def ls_criterion(new_policy):
    new_density = new_policy(states)
    new_kl = kl_divergence(old_density, new_density).mean()
    new_loss = -qvalue(new_density.sample()).mean()
    return new_loss < policy_loss and new_kl < max_kl

with torch.no_grad():
    policy = trpo.line_search(
        params_init=policy.parameters(),
        params_update=step,
        model=policy,
        stop_criterion=ls_criterion,
    )
policy_loss(new_log_probs, old_log_probs, advantages)
staticmethod
¶
Description¶
The policy loss for Trust-Region Policy Optimization.
This is also known as the surrogate loss.
References¶
- Schulman et al. 2015. “Trust Region Policy Optimization.” ICML 2015.
Arguments¶
new_log_probs
(tensor) - The log-density of actions from the target policy.
old_log_probs
(tensor) - The log-density of actions from the behaviour policy.
advantages
(tensor) - Advantage of the actions.
Returns¶
- (tensor) - The policy loss for the given arguments.
Example¶
advantage = ch.pg.generalized_advantage(GAMMA,
                                        TAU,
                                        replay.reward(),
                                        replay.done(),
                                        replay.value(),
                                        next_state_value)
new_densities = policy(replay.state())
new_logprobs = new_densities.log_prob(replay.action())
loss = policy_loss(new_logprobs,
                   replay.logprob().detach(),
                   advantage.detach())
cherry.algorithms.sac.SAC
dataclass
¶
Description¶
Utilities to implement SAC from [1].
The update() function updates the function approximators in the following order:
- Entropy weight update.
- Action-value update.
- State-value update. (Optional, cf. below.)
- Policy update.
Note that most recent implementations of SAC omit step 3 above by using the Bellman residual instead of modelling a state-value function.
References¶
- Haarnoja et al., "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor", ICML 2018.
Arguments¶
batch_size
(int, optional, default=512) - Number of samples to get from the replay.
discount
(float, optional, default=0.99) - Discount factor.
use_automatic_entropy_tuning
(bool, optional, default=True) - Whether to optimize the entropy weight alpha.
policy_delay
(int, optional, default=1) - Delay between policy updates.
target_delay
(int, optional, default=1) - Delay between action value updates.
target_polyak_weight
(float, optional, default=0.995) - Weight factor alpha for Polyak averaging; see cherry.models.polyak_average.
__init__(self, batch_size: int = 512, discount: float = 0.99, use_automatic_entropy_tuning: bool = True, policy_delay: int = 2, target_delay: int = 2, target_polyak_weight: float = 0.01) -> None
special
¶
action_value_loss(value, next_value, rewards, dones, gamma)
staticmethod
¶
Description¶
The action-value loss of the Soft Actor-Critic.
value should be the value of the current state-action pair, estimated via the Q-function. next_value is the expected value of the next state; it can be estimated via a V-function, or alternatively by computing the Q-value of the next observed state-action pair. In the latter case, make sure that the action is sampled according to the current policy, not the one used to gather the data.
Arguments¶
value
(tensor) - Action values of the actual transition.
next_value
(tensor) - State values of the resulting state.
rewards
(tensor) - Observed rewards of the transition.
dones
(tensor) - Which states were terminal.
gamma
(float) - Discount factor.
Returns¶
- (tensor) - The action-value loss for the given arguments.
Example¶
value = qf(batch.state(), batch.action().detach())
next_value = target_vf(batch.next_state())
loss = action_value_loss(value,
                         next_value,
                         batch.reward(),
                         batch.done(),
                         gamma=0.99)
policy_loss(log_probs, q_curr, alpha = 1.0)
staticmethod
¶
Description¶
The policy loss of the Soft Actor-Critic.
New actions are sampled from the target policy, and those are used to compute the Q-values. While we should back-propagate through the Q-values to the policy parameters, we shouldn't use that gradient to optimize the Q parameters. This is often avoided by either using a target Q function, or by zero-ing out the gradients of the Q function parameters.
Arguments¶
log_probs
(tensor) - Log-density of the selected actions.
q_curr
(tensor) - Q-values of state-action pairs.
alpha
(float, optional, default=1.0) - Entropy weight.
Returns¶
- (tensor) - The policy loss for the given arguments.
Example¶
densities = policy(batch.state())
actions = densities.sample()
log_probs = densities.log_prob(actions)
q_curr = q_function(batch.state(), actions)
loss = policy_loss(log_probs, q_curr, alpha=0.1)
update(self, replay, policy, action_value, target_action_value, log_alpha, target_entropy, policy_optimizer, features_optimizer, action_value_optimizer, alpha_optimizer, features = None, target_features = None, update_policy = True, update_target = False, update_value = True, update_entropy = True, device = None, **kwargs)
¶
Description¶
Implements a single SAC update.
Arguments¶
replay
(cherry.ExperienceReplay) - Offline replay to sample transitions from.
policy
(cherry.nn.Policy) - Policy to optimize.
action_value
(cherry.nn.ActionValue) - Twin action value to optimize; see cherry.nn.Twin.
target_action_value
(cherry.nn.ActionValue) - Target action value.
log_alpha
(torch.Tensor) - SAC's (log) entropy weight.
target_entropy
(torch.Tensor) - SAC's target for the policy entropy (typically the negative of the action dimension).
policy_optimizer
(torch.optim.Optimizer) - Optimizer for the policy.
action_value_optimizer
(torch.optim.Optimizer) - Optimizer for the action_value.
features_optimizer
(torch.optim.Optimizer) - Optimizer for the features.
alpha_optimizer
(torch.optim.Optimizer) - Optimizer for log_alpha.
features
(torch.nn.Module, optional, default=None) - Feature extractor for the policy and action value.
target_features
(torch.nn.Module, optional, default=None) - Feature extractor for the target action value.
update_policy
(bool, optional, default=True) - Whether to update the policy.
update_target
(bool, optional, default=False) - Whether to update the action value target network.
update_value
(bool, optional, default=True) - Whether to update the action value.
update_entropy
(bool, optional, default=True) - Whether to update the entropy weight.
device
(torch.device) - The device used to compute the update.
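Example¶
A minimal sketch of a SAC training loop; the policy, twin action value, target network, optimizers, replay, action_size, and num_updates are assumed to be constructed elsewhere and their names are placeholders:
import torch
from cherry.algorithms.sac import SAC

sac = SAC(batch_size=256)
log_alpha = torch.zeros(1, requires_grad=True)
target_entropy = -float(action_size)  # standard heuristic: negative action dimension
for step in range(num_updates):
    sac.update(
        replay=replay,
        policy=policy,
        action_value=action_value,
        target_action_value=target_action_value,
        log_alpha=log_alpha,
        target_entropy=target_entropy,
        policy_optimizer=policy_optimizer,
        features_optimizer=features_optimizer,
        action_value_optimizer=action_value_optimizer,
        alpha_optimizer=alpha_optimizer,
        update_target=(step % sac.target_delay == 0),
    )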