Adding a New RL Method#
This guide walks through the process of implementing a new reinforcement learning method in LeggedGym-Ex. The framework uses a modular architecture where each RL method consists of four core components that must work together seamlessly.
Note
Before starting, review existing implementations in rsl_rl/ to understand the patterns. The Teacher-Student (TS) variant provides a complete reference implementation.
Architecture Overview#
LeggedGym-Ex follows a component-based architecture for RL methods. Each method requires four interconnected components:
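| Component | Purpose | Base Class | Example File |
|---|---|---|---|
| Algorithm | Core RL logic, loss computation, optimization | `BaseAlgorithm` or `PPO` | `rsl_rl/algorithms/ppo_custom.py` |
| Actor-Critic | Neural network policy and value function | `nn.Module` | `rsl_rl/modules/actor_critic_custom.py` |
| Storage | Rollout buffer for experience collection | `RolloutStorage` | `rsl_rl/storage/rollout_storage_custom.py` |
| Runner | Training orchestration, logging, checkpointing | `OnPolicyRunner` | `rsl_rl/runners/custom_runner.py` |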
Important
All four components must be compatible with each other. Using mismatched components (e.g., PPO_TS algorithm with base RolloutStorage) will cause runtime errors.
Step 1: Create Algorithm Class#
The algorithm class implements the core RL logic. Extend BaseAlgorithm or an existing algorithm like PPO for PPO-based methods.
1.1 Basic Structure#
Create a new file in rsl_rl/algorithms/:
# rsl_rl/algorithms/ppo_custom.py
from __future__ import annotations
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.optim as optim
from rsl_rl.algorithms.ppo import PPO
from rsl_rl.modules import ActorCriticCustom
from rsl_rl.storage import RolloutStorageCustom
class PPO_Custom(PPO):
"""Custom PPO variant with specific modifications.
This class extends the base PPO algorithm with custom features.
"""
actor_critic: ActorCriticCustom
storage: Optional[RolloutStorageCustom]
transition: RolloutStorageCustom.Transition
def __init__(
self,
actor_critic: ActorCriticCustom,
num_learning_epochs: int = 1,
num_mini_batches: int = 1,
clip_param: float = 0.2,
gamma: float = 0.998,
lam: float = 0.95,
value_loss_coef: float = 1.0,
entropy_coef: float = 0.0,
learning_rate: float = 1e-3,
max_grad_norm: float = 1.0,
use_clipped_value_loss: bool = True,
schedule: str = "fixed",
desired_kl: Optional[float] = 0.01,
device: str = 'cpu',
# Add custom parameters here
custom_param: float = 0.5,
) -> None:
super().__init__(
actor_critic,
num_learning_epochs,
num_mini_batches,
clip_param,
gamma,
lam,
value_loss_coef,
entropy_coef,
learning_rate,
max_grad_norm,
use_clipped_value_loss,
schedule,
desired_kl,
device=device,
)
self.custom_param = custom_param
# Override components if needed
self.actor_critic = actor_critic
self.actor_critic.to(self.device)
# Custom optimizer setup
self.optimizer = optim.Adam(
self.actor_critic.parameters(),
lr=learning_rate
)
# Initialize transition storage
self.transition = RolloutStorageCustom.Transition()
1.2 Override Core Methods#
At minimum, override these methods:
def init_storage(
    self,
    num_envs: int,
    num_transitions_per_env: int,
    actor_obs_shape: Tuple[int, ...],
    privileged_obs_shape: Tuple[int, ...],
    critic_obs_shape: Tuple[int, ...],
    action_shape: Tuple[int, ...],
) -> None:
    """Allocate the rollout buffer for the custom observation layout.

    Args:
        num_envs: Number of parallel environments.
        num_transitions_per_env: Rollout length per environment per iteration.
        actor_obs_shape: Shape of one actor observation.
        privileged_obs_shape: Shape of one privileged observation.
        critic_obs_shape: Shape of one critic observation.
        action_shape: Shape of one action.
    """
    # Arguments are passed positionally, so this order must match
    # RolloutStorageCustom.__init__ (obs, privileged, critic, actions, device).
    self.storage = RolloutStorageCustom(
        num_envs,
        num_transitions_per_env,
        actor_obs_shape,
        privileged_obs_shape,
        critic_obs_shape,
        action_shape,
        self.device,
    )
def act(
    self,
    obs: torch.Tensor,
    privileged_obs: torch.Tensor,
    critic_obs: torch.Tensor,
) -> torch.Tensor:
    """Compute actions during rollout collection.

    This method must:
    1. Store observations in the transition.
    2. Compute and store actions, values, and log probabilities.
    3. Store the action distribution parameters.
    4. Return actions for environment stepping.

    Args:
        obs: Actor observations.
        privileged_obs: Privileged observations.
        critic_obs: Critic observations used by the value function.

    Returns:
        The sampled actions (detached).
    """
    # Store observations
    self.transition.observations = obs
    self.transition.privileged_observations = privileged_obs
    self.transition.critic_observations = critic_obs

    # Compute actions. ActorCriticCustom.act(observations, **kwargs) accepts a
    # single positional argument, so the privileged observations are passed
    # by keyword; passing them positionally raises TypeError.
    self.transition.actions = self.actor_critic.act(
        obs, privileged_observations=privileged_obs
    ).detach()

    # Compute values and log probabilities
    self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
    self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(
        self.transition.actions
    ).detach()

    # Store action distribution parameters (read back as old_mu / old_sigma
    # during the update).
    self.transition.action_mean = self.actor_critic.action_mean.detach()
    self.transition.action_sigma = self.actor_critic.action_std.detach()
    return self.transition.actions
def update(self) -> Tuple[float, ...]:
    """Update the policy and value function from the collected rollout.

    Runs ``num_learning_epochs`` passes of ``num_mini_batches`` mini-batches.
    It takes one gradient-clipped optimizer step per mini-batch, then clears
    the storage before the next rollout.

    Returns:
        ``(mean_value_loss, mean_surrogate_loss)`` averaged over all
        mini-batch updates, for logging.
    """
    mean_value_loss = 0.0
    mean_surrogate_loss = 0.0

    # NOTE(review): ``_compute_loss`` is not defined in this guide. It is
    # assumed to come from the PPO base class or be written in this subclass,
    # returning ``(loss, surrogate_loss, value_loss)`` — confirm.
    # Overriding update() also bypasses any learning-rate schedule logic the
    # base PPO.update() may implement (e.g. schedule="adaptive" with
    # desired_kl); re-implement it here if you rely on it.
    generator = self.storage.mini_batch_generator(
        self.num_mini_batches,
        self.num_learning_epochs,
    )
    for batch in generator:
        # Field order must match RolloutStorageCustom.mini_batch_generator.
        (
            obs_batch, privileged_obs_batch, critic_obs_batch,
            actions_batch, values_batch, advantages_batch, returns_batch,
            old_log_prob_batch, old_mu_batch, old_sigma_batch,
            hidden_states_batch, masks_batch,
        ) = batch

        # Compute losses
        loss, surrogate_loss, value_loss = self._compute_loss(
            obs_batch, privileged_obs_batch, critic_obs_batch,
            actions_batch, values_batch, advantages_batch, returns_batch,
            old_log_prob_batch, old_mu_batch, old_sigma_batch,
            hidden_states_batch, masks_batch,
        )

        # Gradient step with global-norm clipping
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
        self.optimizer.step()

        mean_value_loss += value_loss.item()
        mean_surrogate_loss += surrogate_loss.item()

    # Clear storage after update so the next rollout starts from step 0.
    self.storage.clear()

    # The generator yields exactly epochs * mini-batches batches.
    num_updates = self.num_learning_epochs * self.num_mini_batches
    return (
        mean_value_loss / num_updates,
        mean_surrogate_loss / num_updates,
    )
1.3 Algorithm Registration#
Add to rsl_rl/algorithms/__init__.py:
from .ppo_custom import PPO_Custom
# NOTE(review): assumes this __init__.py already defines __all__ as a list;
# if it does not, add "PPO_Custom" to the __all__ literal instead.
__all__.append("PPO_Custom")
Step 2: Create Actor-Critic Module#
The actor-critic module defines the neural network architecture for your policy and value function.
2.1 Basic Structure#
Create a new file in rsl_rl/modules/:
# rsl_rl/modules/actor_critic_custom.py
from typing import Sequence

import torch
import torch.nn as nn
from torch.distributions import Normal


class ActorCriticCustom(nn.Module):
    """Custom actor-critic network for PPO_Custom.

    Builds two independent MLPs:
    - an actor mapping ``num_actor_obs`` inputs to ``num_actions`` action means;
    - a critic mapping ``num_critic_obs`` inputs to a single value.

    Action noise is a learned parameter ``std`` of shape ``(num_actions,)``.
    """

    is_recurrent = False  # Set to True for RNN policies

    def __init__(
        self,
        num_actor_obs: int,
        num_actions: int,
        num_critic_obs: int,
        actor_hidden_dims: Sequence[int] = (512, 256, 128),
        critic_hidden_dims: Sequence[int] = (512, 256, 128),
        activation: str = 'elu',
        init_noise_std: float = 1.0,
        **kwargs,
    ):
        """Build the actor and critic networks.

        Args:
            num_actor_obs: Actor input size.
            num_actions: Action dimension (actor output size).
            num_critic_obs: Critic input size.
            actor_hidden_dims: Hidden layer widths of the actor.
            critic_hidden_dims: Hidden layer widths of the critic.
            activation: Activation name (see ``_get_activation``).
            init_noise_std: Initial value of every entry of ``std``.
            **kwargs: Extra config entries; accepted and ignored.
        """
        super().__init__()

        # Get activation function. One instance is shared by all layers; the
        # supported activations have no parameters, so sharing is safe.
        activation_fn = self._get_activation(activation)

        # Actor first, then critic, so parameter initialization order and
        # state_dict keys (actor.0.weight, ...) stay the same as before.
        self.actor = self._build_mlp(num_actor_obs, actor_hidden_dims, num_actions, activation_fn)
        self.critic = self._build_mlp(num_critic_obs, critic_hidden_dims, 1, activation_fn)

        # Action noise parameter
        self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
        self.distribution = None

        # Disable distribution argument validation for speed. This must be a
        # call: assigning to the attribute would replace the method and leave
        # validation enabled. Note this is a global torch.distributions setting.
        Normal.set_default_validate_args(False)

    @staticmethod
    def _build_mlp(
        input_dim: int,
        hidden_dims: Sequence[int],
        output_dim: int,
        activation_fn: nn.Module,
    ) -> nn.Sequential:
        """Return Linear+activation layers over ``hidden_dims`` plus a final Linear (no output activation)."""
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(activation_fn)
        layers.append(nn.Linear(dims[-1], output_dim))
        return nn.Sequential(*layers)

    @staticmethod
    def _get_activation(name: str) -> nn.Module:
        """Get activation function by name; unknown names fall back to ELU."""
        activations = {
            'elu': nn.ELU,
            'relu': nn.ReLU,
            'selu': nn.SELU,
            'tanh': nn.Tanh,
            'leaky_relu': nn.LeakyReLU,
        }
        # Instantiate only the requested activation.
        return activations.get(name, nn.ELU)()
2.2 Implement Required Methods#
def act(
    self,
    observations: torch.Tensor,
    privileged_observations: "torch.Tensor | None" = None,
    **kwargs,
) -> torch.Tensor:
    """Sample actions from the policy.

    Args:
        observations: Actor observations.
        privileged_observations: Accepted so algorithms may pass privileged
            observations (positionally or by keyword); unused by this network.
        **kwargs: Ignored.

    Returns:
        A random sample from the current action distribution.
    """
    self.update_distribution(observations)
    return self.distribution.sample()
def update_distribution(self, observations: torch.Tensor) -> None:
    """Set ``self.distribution`` to a Gaussian over actions for ``observations``.

    The mean comes from the actor network. The standard deviation is the
    learned ``std`` parameter broadcast to the mean's shape.
    """
    mean = self.actor(observations)
    # expand_as broadcasts std without tying it to the mean's values
    # (``mean * 0.0 + std`` would turn NaN/inf means into NaN scales).
    self.distribution = Normal(mean, self.std.expand_as(mean))
def get_actions_log_prob(self, actions: torch.Tensor) -> torch.Tensor:
    """Return the log probability of ``actions`` under the current distribution.

    Per-dimension log densities are summed over the last axis, giving one
    value per action vector (a factorized Gaussian).
    """
    return self.distribution.log_prob(actions).sum(dim=-1)
def evaluate(self, critic_observations: torch.Tensor, **kwargs) -> torch.Tensor:
    """Return the critic's value estimate for ``critic_observations``.

    The output has a trailing dimension of size 1 (the critic's last layer
    has one output). Extra keyword arguments are ignored.
    """
    return self.critic(critic_observations)
@property
def action_mean(self) -> torch.Tensor:
    """Mean of the current action distribution (set by update_distribution)."""
    return self.distribution.mean

@property
def action_std(self) -> torch.Tensor:
    """Standard deviation of the current action distribution."""
    return self.distribution.stddev

@property
def entropy(self) -> torch.Tensor:
    """Entropy of the current action distribution, summed over action dimensions."""
    return self.distribution.entropy().sum(dim=-1)
def reset(self, dones: "torch.Tensor | None" = None) -> None:
    """Reset hidden states for recurrent policies.

    This network is not recurrent (``is_recurrent = False``), so nothing needs
    resetting; the method exists because algorithms may call it uniformly.
    """
    pass  # No-op for non-recurrent
2.3 Module Registration#
Add to rsl_rl/modules/__init__.py:
from .actor_critic_custom import ActorCriticCustom
Step 3: Create Storage Class#
The storage class manages rollout buffers for experience collection. Extend RolloutStorage for most cases.
3.1 Basic Structure#
Create a new file in rsl_rl/storage/:
# rsl_rl/storage/rollout_storage_custom.py
import torch

from .rollout_storage import RolloutStorage


class RolloutStorageCustom(RolloutStorage):
    """Custom rollout storage for PPO_Custom.

    Adds a ``critic_observations`` buffer on top of the buffers allocated by
    ``RolloutStorage``.
    """

    class Transition(RolloutStorage.Transition):
        """Extended transition for custom observations."""

        def __init__(self):
            super().__init__()
            # Add custom observation fields (filled in by PPO_Custom.act).
            self.custom_observations = None
            self.privileged_observations = None
            self.critic_observations = None

    def __init__(
        self,
        num_envs: int,
        num_transitions_per_env: int,
        obs_shape: tuple,
        privileged_obs_shape: tuple,
        critic_obs_shape: tuple,
        actions_shape: tuple,
        device: str = 'cpu',
    ):
        """Allocate base buffers plus the critic-observation buffer.

        Args:
            num_envs: Number of parallel environments.
            num_transitions_per_env: Rollout length per environment.
            obs_shape: Shape of one actor observation.
            privileged_obs_shape: Shape of one privileged observation.
            critic_obs_shape: Shape of one critic observation.
            actions_shape: Shape of one action.
            device: Device the buffers live on.
        """
        super().__init__(
            num_envs,
            num_transitions_per_env,
            obs_shape,
            privileged_obs_shape,  # Used as privileged obs
            actions_shape,
            device,
        )

        # Additional storage for custom observations, laid out as
        # (num_transitions_per_env, num_envs, *critic_obs_shape).
        self.critic_obs_shape = critic_obs_shape
        self.critic_observations = torch.zeros(
            num_transitions_per_env,
            num_envs,
            *critic_obs_shape,
            device=self.device,
        )
3.2 Override Storage Methods#
def add_transitions(self, transition: "Transition") -> None:
    """Copy one environment step into the buffers at index ``self.step``.

    Rewards, dones and log probabilities are reshaped to ``(num_envs, 1)``
    before copying.

    Raises:
        AssertionError: If the buffer already holds
            ``num_transitions_per_env`` steps (call ``clear()`` between rollouts).
    """
    if self.step >= self.num_transitions_per_env:
        raise AssertionError("Rollout buffer overflow")

    step = self.step
    # Observations
    self.observations[step].copy_(transition.observations)
    self.privileged_observations[step].copy_(transition.privileged_observations)
    self.critic_observations[step].copy_(transition.critic_observations)
    # Actions, rewards, termination flags and value estimates
    self.actions[step].copy_(transition.actions)
    self.rewards[step].copy_(transition.rewards.view(-1, 1))
    self.dones[step].copy_(transition.dones.view(-1, 1))
    self.values[step].copy_(transition.values)
    # Quantities of the behaviour policy, needed for the PPO ratio / KL
    self.actions_log_prob[step].copy_(transition.actions_log_prob.view(-1, 1))
    self.mu[step].copy_(transition.action_mean)
    self.sigma[step].copy_(transition.action_sigma)

    self.step += 1
def mini_batch_generator(self, num_mini_batches: int, num_epochs: int = 8):
    """Yield shuffled mini-batches over the flattened rollout.

    The first two buffer dimensions are flattened into one sample dimension.
    A single random permutation is drawn and reused for every epoch. Samples
    beyond ``num_mini_batches * (batch_size // num_mini_batches)`` are dropped.

    Yields:
        12-tuples in the order PPO_Custom.update() unpacks them:
        ``(obs, privileged_obs, critic_obs, actions, values, advantages,
        returns, old_log_prob, old_mu, old_sigma, hidden_states, masks)``.
        The last two are RNN placeholders: ``(None, None)`` and ``None``.
    """
    batch_size = self.num_envs * self.num_transitions_per_env
    mini_batch_size = batch_size // num_mini_batches
    indices = torch.randperm(
        num_mini_batches * mini_batch_size,
        requires_grad=False,
        device=self.device,
    )

    # Flatten all buffers, in the same order they are yielded.
    flat_buffers = tuple(
        buffer.flatten(0, 1)
        for buffer in (
            self.observations,
            self.privileged_observations,
            self.critic_observations,
            self.actions,
            self.values,
            self.advantages,
            self.returns,
            self.actions_log_prob,
            self.mu,
            self.sigma,
        )
    )

    for _epoch in range(num_epochs):
        for i in range(num_mini_batches):
            batch_idx = indices[i * mini_batch_size:(i + 1) * mini_batch_size]
            yield (
                *(buffer[batch_idx] for buffer in flat_buffers),
                (None, None),  # Hidden states for RNN
                None,  # Masks for RNN
            )
3.3 Storage Registration#
Add to rsl_rl/storage/__init__.py:
from .rollout_storage_custom import RolloutStorageCustom
Step 4: Create Runner Class#
The runner orchestrates the training loop, handles logging, and manages checkpoints.
4.1 Basic Structure#
Create a new file in rsl_rl/runners/:
# rsl_rl/runners/custom_runner.py
import statistics
import time
from collections import deque
from typing import Optional, Union, Callable, Dict, Any, List

import torch

from rsl_rl.algorithms import PPO_Custom
from rsl_rl.modules import ActorCriticCustom
from rsl_rl.env import VecEnv
from .on_policy_runner import OnPolicyRunner


class CustomRunner(OnPolicyRunner):
    """Custom runner for PPO_Custom training.

    Overrides the agent/algorithm and storage initialization hooks. The
    actor-critic receives the critic observation size, and the storage gets
    actor, privileged and critic observation buffers.
    NOTE(review): assumes OnPolicyRunner.__init__ calls ``_init_agent_and_algo``
    and ``_init_storage`` — confirm against the base class.
    """

    def __init__(
        self,
        env: VecEnv,
        train_cfg: Dict[str, Any],
        log_dir: Optional[str] = None,
        device: Union[str, torch.device] = "cpu",
    ) -> None:
        super().__init__(env, train_cfg, log_dir, device)

    def _init_agent_and_algo(self) -> None:
        """Initialize actor-critic and algorithm from the class names in the config."""
        # eval() resolves the configured class name against this module's
        # namespace, so the class must be imported above. Never load training
        # configs from untrusted sources: eval executes arbitrary expressions.
        actor_critic_class = eval(self.cfg["policy_class_name"])
        actor_critic: ActorCriticCustom = actor_critic_class(
            self.env.num_obs,
            self.env.num_actions,
            self.env.num_critic_obs,  # Must be defined in env
            **self.policy_cfg,
        ).to(self.device)

        alg_class = eval(self.cfg["algorithm_class_name"])
        self.alg: PPO_Custom = alg_class(
            actor_critic,
            device=self.device,
            **self.alg_cfg,
        )

    def _init_storage(self) -> None:
        """Initialize storage with shapes taken from the environment."""
        # Order must match PPO_Custom.init_storage:
        # (actor obs, privileged obs, critic obs, actions).
        self.alg.init_storage(
            self.env.num_envs,
            self.num_steps_per_env,
            (self.env.num_obs,),
            (self.env.num_privileged_obs,),
            (self.env.num_critic_obs,),
            (self.env.num_actions,),
        )
4.2 Override Training Loop#
def learn(
    self,
    num_learning_iterations: int,
    init_at_random_ep_len: bool = False,
) -> None:
    """Main training loop: alternate rollout collection and policy updates.

    Args:
        num_learning_iterations: Number of collect-then-update iterations,
            continuing from ``self.current_learning_iteration``.
        init_at_random_ep_len: If True, start every environment's
            episode-length counter at a random value in
            ``[0, max_episode_length)``.
    """
    # Pre-learn setup
    if init_at_random_ep_len:
        self.env.episode_length_buf = torch.randint(
            0,
            self.env.max_episode_length,
            (self.env.num_envs,),
            device=self.env.device,
        )

    # Get initial observations
    obs, privileged_obs, critic_obs = self.env.get_observations()
    obs = obs.to(self.device)
    privileged_obs = privileged_obs.to(self.device)
    critic_obs = critic_obs.to(self.device)
    self.alg.actor_critic.train()

    # Episode tracking buffers.
    # NOTE: self.log(locals()) below receives these locals by name; do not
    # rename them without updating log().
    ep_infos: List[Dict[str, Any]] = []
    rewbuffer = deque(maxlen=100)
    lenbuffer = deque(maxlen=100)
    cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
    cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

    # Main training loop
    tot_iter = self.current_learning_iteration + num_learning_iterations
    for it in range(self.current_learning_iteration, tot_iter):
        start = time.time()

        # Rollout collection (no autograd needed)
        with torch.inference_mode():
            for i in range(self.num_steps_per_env):
                actions = self.alg.act(obs, privileged_obs, critic_obs)
                obs, privileged_obs, critic_obs, rewards, dones, infos = self.env.step(actions)
                obs = obs.to(self.device)
                privileged_obs = privileged_obs.to(self.device)
                critic_obs = critic_obs.to(self.device)
                rewards = rewards.to(self.device)
                dones = dones.to(self.device)
                self.alg.process_env_step(rewards, dones, infos)

                # Episode tracking
                if self.log_dir is not None:
                    if 'episode' in infos:
                        ep_infos.append(infos['episode'])
                    cur_reward_sum += rewards
                    cur_episode_length += 1
                    new_ids = (dones > 0).nonzero(as_tuple=False)
                    rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                    lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                    cur_reward_sum[new_ids] = 0
                    cur_episode_length[new_ids] = 0

            collection_time = time.time() - start

            # Learning step. compute_returns stays inside inference mode:
            # critic_obs may be an inference tensor here, and inference
            # tensors cannot be saved for backward when evaluated with
            # autograd enabled.
            start = time.time()
            self.alg.compute_returns(critic_obs)

        mean_value_loss, mean_surrogate_loss = self.alg.update()
        learn_time = time.time() - start

        # Logging and checkpointing only when there is a directory to write
        # to; with log_dir=None the path would become "None/model_<it>.pt".
        if self.log_dir is not None:
            self.log(locals())
            if it % self.save_interval == 0:
                self.save(f"{self.log_dir}/model_{it}.pt")
        ep_infos.clear()

    self.current_learning_iteration += num_learning_iterations
    if self.log_dir is not None:
        self.save(f"{self.log_dir}/model_{self.current_learning_iteration}.pt")
4.3 Implement Inference Policy#
def get_inference_policy(
    self,
    device: Optional[Union[str, torch.device]] = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Return a deterministic policy for inference/deployment.

    Puts the actor-critic in eval mode and optionally moves it to ``device``.
    The returned callable outputs the mean of the action distribution.
    Returning ``actor_critic.act`` instead would sample from the distribution
    and add exploration noise to deployed actions.
    """
    actor_critic = self.alg.actor_critic
    actor_critic.eval()
    if device is not None:
        actor_critic.to(device)

    def inference_policy(observations: torch.Tensor) -> torch.Tensor:
        # Deterministic action: the distribution mean, not a random sample.
        actor_critic.update_distribution(observations)
        return actor_critic.action_mean

    return inference_policy
Step 5: Register Runner#
Register the runner in rsl_rl/runners/__init__.py:
from .custom_runner import CustomRunner
from rsl_rl.utils.runner_registry import runner_registry
# The registered key is the name referenced by runner_class_name in the
# task configuration (see Step 6), so the two strings must match.
runner_registry.register("CustomRunner", CustomRunner)
Step 6: Create Environment Variant (Optional)#
If your method requires special environment modifications, create a variant in legged_gym/envs/.
6.1 Define Configuration#
# legged_gym/envs/go2/go2_custom/go2_custom_config.py
from legged_gym.envs.base.legged_robot_config import (
    LeggedRobotCfg,
    LeggedRobotCfgPPO,
)


class GO2CustomCfg(LeggedRobotCfg):
    """Environment configuration for GO2 with the custom method."""

    class env(LeggedRobotCfg.env):
        pass


class GO2CustomCfgPPO(LeggedRobotCfgPPO):
    """Training configuration (algorithm/runner) for the GO2 custom method."""

    class algorithm(LeggedRobotCfgPPO.algorithm):
        # Presumably forwarded to PPO_Custom(custom_param=...) through the
        # runner's alg_cfg — confirm the mapping in the base runner.
        custom_param = 0.5  # Custom parameter

    # The runner reads policy_class_name / algorithm_class_name from the
    # training config it receives (train_cfg), so this section belongs to the
    # PPO config rather than the environment config.
    class runner(LeggedRobotCfgPPO.runner):
        runner_class_name = "CustomRunner"
        policy_class_name = "ActorCriticCustom"
        algorithm_class_name = "PPO_Custom"
6.2 Register Task#
Add to legged_gym/envs/__init__.py:
from .go2.go2_custom.go2_custom import GO2Custom
from .go2.go2_custom.go2_custom_config import GO2CustomCfg, GO2CustomCfgPPO
# Register the task under the name selected with --task in the train/play
# scripts. NOTE(review): assumes task_registry is already imported in this
# __init__.py and that go2_custom.py defines the GO2Custom env class — confirm.
task_registry.register(
"go2_custom",
GO2Custom,
GO2CustomCfg,
GO2CustomCfgPPO
)
Component Compatibility Requirements#
Understanding component compatibility is critical for implementing new methods.
Storage-Algorithm Matching#
The storage must provide all observations the algorithm requires:
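| Algorithm | Required Storage Fields |
|---|---|
| `PPO` (base `RolloutStorage`) | `observations`, `actions`, `rewards`, `dones`, `values`, `actions_log_prob`, `mu`, `sigma` |
| `PPO_Custom` | All `PPO` fields plus `privileged_observations` and `critic_observations` |
| `PPO_TS` | All `PPO` fields plus `privileged_observations`, `observation_histories`, and `critic_observations` |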
Actor-Critic Interface Matching#
The actor-critic must implement methods the algorithm calls:
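| Method | Called By | Must Return |
|---|---|---|
| `act()` | `Algorithm.act()` | Action tensor |
| `evaluate()` | `Algorithm.act()` | Value tensor |
| `get_actions_log_prob()` | `Algorithm.act()`, `Algorithm.update()` | Log probability tensor |
| `action_mean` / `action_std` | `Algorithm.act()` | Distribution parameters |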
Runner-Environment Matching#
The environment must provide observation shapes the runner expects:
# Environment must define these properties
# (sizes below are illustrative; use your environment's actual dimensions)
self.num_obs = 48 # Actor observation size
self.num_privileged_obs = 187 # Privileged observation size
self.num_critic_obs = 187 # Required by CustomRunner (PPO_Custom) and TS methods
self.num_history_obs = 480 # For TS methods
self.num_latent_dims = 64 # For encoder-based methods
Complete Example: Teacher-Student Implementation#
The Teacher-Student (TS) method demonstrates all components working together:
Algorithm (PPO_TS)#
class PPO_TS(PPO):
    """Sketch of the Teacher-Student PPO algorithm (abbreviated)."""

    actor_critic: ActorCriticTS
    storage: RolloutStorageTS

    def __init__(self, actor_critic, encoder_lr=1e-3, **kwargs):
        # Standard PPO hyper-parameters are forwarded unchanged.
        super().__init__(actor_critic, **kwargs)
        # Separate optimizer for history encoder
        self.history_encoder_optimizer = optim.Adam(
            actor_critic.history_encoder.parameters(),
            lr=encoder_lr,
        )

    def act(self, obs, privileged_obs, obs_history, critic_obs):
        # Store all observation types
        self.transition.observations = obs
        self.transition.privileged_observations = privileged_obs
        self.transition.observation_histories = obs_history
        self.transition.critic_observations = critic_obs
        # Compute actions with privileged encoder (teacher path)
        self.transition.actions = self.actor_critic.act(
            obs, privileged_obs
        ).detach()
        # Abbreviated: a complete act() also records values, log-probs and
        # distribution parameters (compare PPO_Custom.act).
        return self.transition.actions

    def update(self):
        # Standard PPO update
        mean_value_loss, mean_surrogate_loss = self._ppo_update()
        # Additional encoder distillation
        mean_encoder_loss = self._encoder_update()
        return mean_value_loss, mean_surrogate_loss, mean_encoder_loss
Actor-Critic (ActorCriticTS)#
class ActorCriticTS(nn.Module):
    """Sketch of the Teacher-Student actor-critic (layer definitions elided).

    The teacher path encodes privileged observations into a latent. The
    student path encodes an observation history into a latent for
    deployment. Both feed the same actor.
    """

    def __init__(self, num_actor_obs, num_actions,
                 num_privilege_encoder_input, num_history_encoder_input,
                 num_latent_dims, num_critic_obs, **kwargs):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise PyTorch raises "cannot assign module before
        # Module.__init__() call".
        super().__init__()
        # Privilege encoder: privileged_obs -> latent
        self.privilege_encoder = nn.Sequential(...)
        # History encoder: obs_history -> latent (student)
        self.history_encoder = nn.Sequential(...)
        # Actor: [obs, latent] -> action
        self.actor = nn.Sequential(...)
        # Critic: critic_obs -> value
        self.critic = nn.Sequential(...)
        # The action-noise parameter self.std used by act() is also created
        # here (compare ActorCriticCustom.__init__); elided in this sketch.

    def act(self, observations, privilege_observations):
        """Teacher action: sample using the privileged latent."""
        latent = self.privilege_encoder(privilege_observations)
        mean = self.actor(torch.cat([observations, latent], dim=-1))
        self.distribution = Normal(mean, self.std)
        return self.distribution.sample()

    def act_student(self, observations, observation_history):
        """Student action for deployment: deterministic mean from the history latent."""
        # For deployment: use history encoder instead
        latent = self.history_encoder(observation_history)
        mean = self.actor(torch.cat([observations, latent], dim=-1))
        return mean
Storage (RolloutStorageTS)#
class RolloutStorageTS(RolloutStorage):
    """Sketch of the Teacher-Student rollout storage."""

    class Transition(RolloutStorage.Transition):
        def __init__(self):
            super().__init__()
            self.privileged_observations = None
            self.observation_histories = None
            self.critic_observations = None  # Written by PPO_TS.act

    def __init__(self, num_envs, num_transitions_per_env,
                 obs_shape, privileged_obs_shape,
                 obs_history_shape, critic_obs_shape,
                 actions_shape, device='cpu'):
        # Let the base class allocate the standard buffers first (argument
        # order as in RolloutStorageCustom).
        super().__init__(num_envs, num_transitions_per_env, obs_shape,
                         privileged_obs_shape, actions_shape, device)
        # Additional buffers for TS-specific observations, laid out as
        # (num_transitions_per_env, num_envs, *shape).
        self.observation_histories = torch.zeros(
            num_transitions_per_env, num_envs, *obs_history_shape, device=self.device
        )
        self.critic_observations = torch.zeros(
            num_transitions_per_env, num_envs, *critic_obs_shape, device=self.device
        )
Runner (TSRunner)#
class TSRunner(OnPolicyRunner):
    """Sketch of the Teacher-Student runner (training loop elided)."""

    def _init_agent_and_algo(self):
        actor_critic = ActorCriticTS(
            self.env.num_obs,
            self.env.num_actions,
            self.env.num_privileged_obs,
            self.env.num_history_obs,
            self.env.num_latent_dims,
            self.env.num_critic_obs,
            **self.policy_cfg,
        ).to(self.device)
        self.alg = PPO_TS(actor_critic, device=self.device, **self.alg_cfg)

    def learn(self, num_learning_iterations, init_at_random_ep_len=False):
        # Get all observation types
        obs, privileged_obs, obs_history, critic_obs = \
            self.env.get_observations()
        # Training loop with 4 observation types
        ...

    def get_inference_policy(self, device=None):
        # Return the student policy for deployment: it needs only observations
        # and their history, not privileged information.
        self.alg.actor_critic.eval()
        if device is not None:
            self.alg.actor_critic.to(device)
        # Called as policy(observations, observation_history).
        return self.alg.actor_critic.act_student
Testing Your Implementation#
After implementing all components, test with:
# Quick smoke test with few environments (surfaces component mismatches fast)
python -m legged_gym.scripts.train --task go2_custom --num_envs 10 --headless
# Full training
python -m legged_gym.scripts.train --task go2_custom --headless
# Inference test with the trained policy
python -m legged_gym.scripts.play --task go2_custom
Summary#
Implementing a new RL method requires:
- **Algorithm**: Extend `BaseAlgorithm` or `PPO`; implement `init_storage()`, `act()`, and `update()`.
- **Actor-Critic**: Build the neural networks; implement `act()`, `evaluate()`, and the distribution methods.
- **Storage**: Extend `RolloutStorage`; add custom observation buffers.
- **Runner**: Extend `OnPolicyRunner`; customize initialization and the training loop.
- **Registration**: Register the runner in `runners/__init__.py`.
- **Environment**: Create a config variant if needed, and register the task.
Always ensure component compatibility: the algorithm must match the storage format, and the actor-critic must implement the methods the algorithm expects. Reference existing implementations like PPO_TS for complete working examples.