
Using TorchRL to Implement Multi‑Agent PPO for MARL

This tutorial explains how to set up a multi‑agent reinforcement learning (MARL) environment with VMAS, install required dependencies, configure PPO hyper‑parameters, build policy and critic networks, collect data with TorchRL, and run a training loop to train agents for coordinated navigation tasks.


As multi‑agent systems become more common, reinforcement learning (RL) grows increasingly complex, and dedicated libraries such as TorchRL provide a robust framework for developing and experimenting with multi‑agent RL (MARL) algorithms. This tutorial focuses on proximal policy optimization (PPO) in multi‑agent settings.

We use the VMAS simulator, a multi‑robot environment that runs on GPU, where several robots must navigate to individual goals while avoiding collisions.

Dependencies

!pip3 install torchrl
!pip3 install vmas
!pip3 install tqdm

Understanding Proximal Policy Optimization (PPO)

PPO is an on‑policy policy‑gradient method: it alternates between a sampling phase, in which trajectories are collected from the environment with the current policy, and a training phase, in which those samples are used immediately to update that same policy.

On‑Policy Learning

In PPO, a critic estimates the expected return from each state; the gap between this estimate and the observed returns (the advantage) guides the policy update. In multi‑agent setups, each agent can have its own policy conditioned on local observations, while the critic can be either centralized (MAPPO) or decentralized (IPPO).

MAPPO: Centralized critic that receives global or concatenated observations, useful when full state information is available during training.

IPPO: Decentralized critic that relies only on local observations, supporting fully distributed training.
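Both variants rely on PPO's clipped surrogate objective, which bounds how far a single update can move the policy away from the one that collected the data. As a rough illustration (plain Python on scalars, not the tensorized TorchRL computation), the clipping works like this:

```python
def clipped_surrogate(ratio, advantage, clip_epsilon=0.2):
    # ratio = pi_new(a|s) / pi_old(a|s); clipping bounds the policy update
    clipped_ratio = max(1.0 - clip_epsilon, min(ratio, 1.0 + clip_epsilon))
    # take the pessimistic (smaller) of the clipped and unclipped terms
    return min(ratio * advantage, clipped_ratio * advantage)
```

Note that the pessimistic min clips gains in both directions: a ratio of 1.5 with positive advantage is credited only as 1.2, and a ratio of 0.5 with negative advantage is still penalized as if it were 0.8.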

TorchRL Overview

TorchRL is a PyTorch‑based RL library designed for researchers and developers. It offers deep integration with PyTorch, modular components, GPU acceleration, support for many environments (Gym, DeepMind Control Suite, etc.), and implementations of popular algorithms such as DQN, PPO, and SAC.

1. Hyper‑parameters

import torch
from torch import multiprocessing

# Device selection
is_fork = multiprocessing.get_start_method() == "fork"
device = torch.device(0) if torch.cuda.is_available() and not is_fork else torch.device("cpu")
vmas_device = device

# Sampling and training parameters
frames_per_batch = 6000
n_iters = 10
total_frames = frames_per_batch * n_iters

# Training details
num_epochs = 30
minibatch_size = 400
lr = 3e-4
max_grad_norm = 1.0

# PPO specific
clip_epsilon = 0.2
gamma = 0.99
lmbda = 0.9
entropy_eps = 1e-4
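The gamma and lmbda values above parameterize generalized advantage estimation (GAE), which TorchRL computes internally during training. A minimal pure‑Python sketch of the recursion (ignoring batching and done flags) looks like this:

```python
def gae_advantages(rewards, values, bootstrap_value, gamma=0.99, lmbda=0.9):
    # Backward recursion: A_t = delta_t + gamma * lmbda * A_{t+1},
    # where delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    values = list(values) + [bootstrap_value]
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lmbda * gae
        advantages[t] = gae
    return advantages
```

Setting lmbda close to 1 averages advantage estimates over longer horizons (lower bias, higher variance); lower values lean on the critic's bootstrap (higher bias, lower variance).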

2. Environment Creation

from torchrl.envs.libs.vmas import VmasEnv

max_steps = 100
num_vmas_envs = frames_per_batch // max_steps
scenario_name = "navigation"
n_agents = 3

env = VmasEnv(
    scenario=scenario_name,
    num_envs=num_vmas_envs,
    continuous_actions=True,
    max_steps=max_steps,
    device=vmas_device,
    n_agents=n_agents,
)

3. Policy Network Design

from torch.nn import Sequential, Tanh
from tensordict.nn import TensorDictModule
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal
from tensordict.nn.distributions import NormalParamExtractor

share_parameters_policy = True

policy_net = Sequential(
    MultiAgentMLP(
        n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
        n_agent_outputs=2 * env.action_spec.shape[-1],
        n_agents=env.n_agents,
        centralised=False,
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=Tanh,
    ),
    NormalParamExtractor(),
)

policy_module = TensorDictModule(
    policy_net,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "loc"), ("agents", "scale")],
)

policy = ProbabilisticActor(
    module=policy_module,
    spec=env.unbatched_action_spec,
    in_keys=[("agents", "loc"), ("agents", "scale")],
    out_keys=[env.action_key],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.unbatched_action_spec[env.action_key].space.low,
        "high": env.unbatched_action_spec[env.action_key].space.high,
    },
    return_log_prob=True,
    log_prob_key=("agents", "sample_log_prob"),
)
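The NormalParamExtractor at the end of the policy network splits the MLP's last dimension in half: the first half becomes the Gaussian mean (loc) and the second half is mapped through a positive function to become the standard deviation (scale); TorchRL's default uses a biased softplus, which this simplified plain‑Python sketch replaces with an ordinary softplus:

```python
import math

def split_loc_scale(net_output):
    # First half of the output vector -> mean (loc),
    # second half -> std (scale), mapped through softplus to stay positive
    half = len(net_output) // 2
    loc = net_output[:half]
    scale = [math.log1p(math.exp(x)) for x in net_output[half:]]
    return loc, scale
```

This is why the MLP above outputs 2 * env.action_spec.shape[-1] values per agent: one loc and one scale per action dimension.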

4. Critic Network Design

share_parameters_critic = True
mappo = True  # Set to False to use IPPO

critic_net = MultiAgentMLP(
    n_agent_inputs=env.observation_spec["agents", "observation"].shape[-1],
    n_agent_outputs=1,
    n_agents=env.n_agents,
    centralised=mappo,
    share_params=share_parameters_critic,
    device=device,
    depth=2,
    num_cells=256,
    activation_class=Tanh,
)

critic = TensorDictModule(
    module=critic_net,
    in_keys=[("agents", "observation")],
    out_keys=[("agents", "state_value")],
)
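The mappo flag is the only difference between the two critic variants: with centralised=True, each agent's value estimate is computed from the joint observation of all agents; with centralised=False, only from that agent's own observation. Conceptually (a plain‑Python illustration, not the MultiAgentMLP internals):

```python
def critic_input_per_agent(agent_obs, centralised):
    # agent_obs: one observation list per agent
    if centralised:
        # MAPPO: every critic sees the concatenation of all agents' observations
        joint = [x for obs in agent_obs for x in obs]
        return [joint] * len(agent_obs)
    # IPPO: each critic sees only its own agent's observation
    return [list(obs) for obs in agent_obs]
```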

5. Data Collection

from torchrl.collectors import SyncDataCollector

collector = SyncDataCollector(
    env,
    policy,
    device=vmas_device,
    storing_device=device,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
)
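The collector repeatedly runs the policy in the environment and yields fixed‑size batches of transitions until the total_frames budget is spent. Schematically (a toy stand‑in for illustration, not the SyncDataCollector API):

```python
def collect_batches(step_fn, frames_per_batch, total_frames):
    # Yield lists of transitions of size frames_per_batch until the budget is spent
    collected = 0
    while collected < total_frames:
        yield [step_fn() for _ in range(frames_per_batch)]
        collected += frames_per_batch

# With the tutorial's settings, 60,000 total frames arrive as 10 batches of 6,000
batches = list(collect_batches(lambda: 0, frames_per_batch=6000, total_frames=60000))
```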

6. Training Loop

from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from tqdm import tqdm

# Replay buffer
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(size=frames_per_batch, device=device),
    sampler=SamplerWithoutReplacement(),
)

# PPO loss; gamma/lmbda are configured on the value estimator below,
# not in the ClipPPOLoss constructor
ppo_loss = ClipPPOLoss(
    actor_network=policy,
    critic_network=critic,
    clip_epsilon=clip_epsilon,
    entropy_coef=entropy_eps,
    normalize_advantage=False,
)
ppo_loss.set_keys(
    reward=env.reward_key,
    action=env.action_key,
    sample_log_prob=("agents", "sample_log_prob"),
    value=("agents", "state_value"),
    done=("agents", "done"),
    terminated=("agents", "terminated"),
)
ppo_loss.make_value_estimator(ValueEstimators.GAE, gamma=gamma, lmbda=lmbda)
optimizer = torch.optim.Adam(ppo_loss.parameters(), lr)

# Training iterations
for tensordict_data in tqdm(collector, total=n_iters):
    # Expand done/terminated flags to the per-agent shape expected by GAE
    for key in ("done", "terminated"):
        tensordict_data.set(
            ("next", "agents", key),
            tensordict_data.get(("next", key))
            .unsqueeze(-1)
            .expand(tensordict_data.get_item_shape(("next", env.reward_key))),
        )
    with torch.no_grad():  # compute GAE advantages before storing the batch
        ppo_loss.value_estimator(
            tensordict_data,
            params=ppo_loss.critic_network_params,
            target_params=ppo_loss.target_critic_network_params,
        )
    replay_buffer.extend(tensordict_data.reshape(-1))
    for _ in range(num_epochs):
        for _ in range(frames_per_batch // minibatch_size):
            minibatch = replay_buffer.sample(minibatch_size)
            loss_vals = ppo_loss(minibatch)
            loss = loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(ppo_loss.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
    collector.update_policy_weights_()
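With the hyper‑parameters above, each iteration performs num_epochs passes over the batch, each pass consisting of frames_per_batch // minibatch_size gradient steps. A quick sanity check of the totals:

```python
def sgd_step_count(total_frames, frames_per_batch, num_epochs, minibatch_size):
    # Iterations = collected batches; SGD steps = epochs * minibatches per epoch
    n_iters = total_frames // frames_per_batch
    steps_per_iter = num_epochs * (frames_per_batch // minibatch_size)
    return n_iters, n_iters * steps_per_iter
```

With the values in this tutorial (60,000 frames, batches of 6,000, 30 epochs, minibatches of 400), that amounts to 10 iterations and 4,500 gradient steps in total.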

The complete script demonstrates how TorchRL simplifies the development of MARL solutions, leveraging GPU acceleration and parallelism to train agents that navigate complex tasks while offering flexibility to choose centralized or decentralized critics.

Conclusion

This guide walked through implementing a MARL solution with TorchRL and PPO: setting up a GPU‑accelerated VMAS environment, building policy and critic networks, collecting data, and running the training loop. The same skeleton allows experimentation with centralized (MAPPO) or decentralized (IPPO) critics to address the challenges of multi‑agent reinforcement learning.

Tags: Python, deep learning, Multi-Agent Reinforcement Learning, PPO, TorchRL, VMAS
Written by

Python Programming Learning Circle

A global community of Chinese Python developers offering technical articles, columns, original video tutorials, and problem sets. Topics include web full‑stack development, web scraping, data analysis, natural language processing, image processing, machine learning, automated testing, DevOps automation, and big data.
