
R2D2 #34

Merged
merged 77 commits into from
Mar 3, 2025
Commits
0543ae4
R2D2 intial commit, not tested
garymm Feb 19, 2025
1055fbe
more WIP. Trying to get Atari frames in
garymm Feb 20, 2025
f9f60f3
atari input at least doesn't crash
garymm Feb 20, 2025
8ea67e1
add test for atari input and fix q value size
garymm Feb 20, 2025
a236ce0
fix static fields in r2d2 networks
garymm Feb 20, 2025
21327c0
gymnasium_loop: fix mixed up argument order
garymm Feb 20, 2025
260cbf9
r2d2 WIP
garymm Feb 20, 2025
ab84e4a
r2d2: some fixes. still not done
garymm Feb 20, 2025
844b796
r2d2: add value rescaling
garymm Feb 20, 2025
940d216
port code over from acme/agents/jax/r2d2
garymm Feb 21, 2025
66507b3
update experience state in loss
garymm Feb 21, 2025
afb5a99
gymnasium_loop: fix bug. only copy one net replica
garymm Feb 21, 2025
1ba65bc
sharding: fix docstring
garymm Feb 21, 2025
9a892ac
r2d2: runs for 2 cycles!
garymm Feb 21, 2025
5647447
change grad means metric names to make Mlflow happy
garymm Feb 22, 2025
ad743ba
scripts to train r2d2 on atari
garymm Feb 22, 2025
db4d7c5
gymnasium_loop: fix buffer donation bug
garymm Feb 22, 2025
40749b2
r2d2: experiment runner files
garymm Feb 22, 2025
7a82127
fix bug in _sample_from_experience
garymm Feb 24, 2025
6489689
use better import path
garymm Feb 24, 2025
18bd558
fix buffer update
garymm Feb 24, 2025
d444d6b
minor cleanup
garymm Feb 24, 2025
7ab4825
r2d2: epslion-greedy and remove incremental updates
garymm Feb 26, 2025
961cb01
suppress warnings
garymm Feb 26, 2025
9e2525b
epsilon greedy schedule and more debugging
garymm Feb 26, 2025
a48c3cb
make lstm optional
garymm Feb 26, 2025
d6d001f
add TODO about stop grad after burn in
garymm Feb 26, 2025
76e6f2f
use optax.incremental_update
garymm Feb 27, 2025
ac6bcde
r2d2: tests passing!
garymm Feb 27, 2025
68f52d8
set adam eps value to what it was in r2d2 paper
garymm Feb 27, 2025
49e90af
test_learns_cartpole passing with LSTM
garymm Feb 27, 2025
cfaec6b
use jax-loop-utils from PyPi
garymm Feb 27, 2025
059475a
remove unused filter_incremental_update
garymm Feb 27, 2025
5a1fcee
fix test_sample_from_experince
garymm Feb 27, 2025
cee2658
make cartpole test easier for now
garymm Feb 27, 2025
25c1621
implement prioritized replay
garymm Feb 27, 2025
1ea3bd1
support sticky actions for exploration
garymm Feb 28, 2025
773a36f
use tensorboard instead of mlflow
garymm Feb 28, 2025
3540374
some updates to run_gymnax_cartpole
garymm Feb 28, 2025
e840390
ignore logs dir
garymm Feb 28, 2025
50d8c55
upgrade basedpyright
garymm Feb 28, 2025
3bcbe79
new notebook for asterix
garymm Feb 28, 2025
c3bb26d
remove commented out code
garymm Feb 28, 2025
073584e
remove unneeded warning suppression
garymm Feb 28, 2025
86a3aa5
better hyperparemeters for test_learns_cartpole
garymm Feb 28, 2025
206084d
require python 3.11, not 3.12, for Colab
garymm Feb 28, 2025
c4401c9
WIP: asterix notebook
garymm Feb 28, 2025
293f3c5
fix sharding.pytree_get_index_0
garymm Mar 1, 2025
5106531
env_info_from_gymnasium: support vecenv
garymm Mar 1, 2025
750ab7b
WIP: support envpool
garymm Mar 1, 2025
c4d7f7a
delete unused runners
garymm Mar 1, 2025
219916e
fix shard agent state
garymm Mar 1, 2025
80381fe
add render atari observe cycle
garymm Mar 1, 2025
f750cb0
WIP: asterix atari
garymm Mar 1, 2025
0f837eb
run_atari: assert num envs per learner even
garymm Mar 2, 2025
1981640
r2d2: support replaying larger batches
garymm Mar 2, 2025
8336410
ignore warning triggered by envpool
garymm Mar 2, 2025
5e6f4f2
test_r2d2: use env_factory
garymm Mar 2, 2025
b881812
fix metrickey import
garymm Mar 2, 2025
a8d2c08
restore test_learns_cartpole
garymm Mar 2, 2025
db2ac3c
set priority to 1 for new experience
garymm Mar 2, 2025
2faf2b0
double replay batch size
garymm Mar 2, 2025
84b81ef
improve error message
garymm Mar 2, 2025
0e12c9c
fix dtype support in resnet
garymm Mar 2, 2025
cb5ceca
gymnasium_loop: fix bug when len(learner_devices) > 1
garymm Mar 2, 2025
f0258a7
run minasterix longer
garymm Mar 3, 2025
5b519b4
vs code setting: ignore git limit warning
garymm Mar 3, 2025
14ce38d
start to fix bazel test
garymm Mar 3, 2025
727603f
fix gymnasium tests
garymm Mar 3, 2025
152145c
fix run_experiment for env_factory
garymm Mar 3, 2025
a8fb7d7
suppress false pyright error
garymm Mar 3, 2025
f3a52c3
rename and delete runners
garymm Mar 3, 2025
4ad97b3
fix some broken stuff
garymm Mar 3, 2025
587a9df
set long timeout for slow github runner
garymm Mar 3, 2025
9781362
shard test_run_experiment
garymm Mar 3, 2025
36891f1
split learns cartpole to separate test
garymm Mar 3, 2025
a18c16f
shorten test
garymm Mar 3, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
bazel-*
__pycache__
uv.lock
# mlflow
mlruns
logs
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -24,5 +24,6 @@
"bazel.buildifierExecutable": "@buildifier_prebuilt//:buildifier",
"python.testing.pytestArgs": [
"earl"
]
],
"git.ignoreLimitWarning": true
}
4 changes: 3 additions & 1 deletion BUILD.bazel
@@ -8,8 +8,10 @@ pip_compile(
"--emit-index-url",
"--no-strip-extras",
"--extra=test",
"--extra=agent-r2d2",
"--index=https://download.pytorch.org/whl/cpu",
],
python_platform = "x86_64-unknown-linux-gnu",
python_platform = "x86_64-manylinux_2_28", # envpool needs at least 2_24
requirements_in = "//:pyproject.toml",
requirements_txt = "requirements_linux_x86_64.txt",
)
4 changes: 2 additions & 2 deletions MODULE.bazel
@@ -6,15 +6,15 @@ module(
)

# BEGIN python toolchain
_PYTHON_VERSION = "3.12"
_PYTHON_VERSION = "3.11" # latest that supports envpool

bazel_dep(name = "rules_python", version = "1.1.0", dev_dependency = True)

python = use_extension("@rules_python//python/extensions:python.bzl", "python")
python.toolchain(python_version = _PYTHON_VERSION)
# END python toolchain

# BEGIN python dependencies
# BEGIN python dependenciesP
bazel_dep(name = "rules_uv", version = "0.53.0", dev_dependency = True)

pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
1,219 changes: 871 additions & 348 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions earl/agents/r2d2/BUILD.bazel
@@ -0,0 +1,67 @@
load("@aspect_rules_py//py:defs.bzl", "py_library")
load("//tools/py_test:py_test.bzl", "py_test")

py_library(
name = "r2d2",
srcs = [
"networks.py",
"r2d2.py",
"utils.py",
],
visibility = ["//visibility:public"],
deps = [
"//earl:core",
"@pypi//chex",
"@pypi//distrax",
"@pypi//equinox",
"@pypi//jax",
"@pypi//jaxtyping",
"@pypi//optax",
],
)

py_test(
name = "test_r2d2",
timeout = "long",
srcs = ["test_r2d2.py"],
filterwarnings = [
"ignore:jax.interpreters.xla.pytype_aval_mappings is deprecated.:DeprecationWarning",
"ignore:Shape is deprecated; use StableHLO instead.:DeprecationWarning",
],
shard_count = 2,
deps = [
":r2d2",
"//earl:core",
"//earl/environment_loop:gymnasium_loop",
"//earl/environment_loop:gymnax_loop",
"@pypi//envpool",
"@pypi//gymnasium",
"@pypi//gymnax",
"@pypi//jax",
"@pypi//jax_loop_utils",
"@pypi//numpy",
"@pypi//optax",
"@pypi//pytest",
],
)

py_test(
name = "test_r2d2_learns",
timeout = "long",
srcs = ["test_r2d2_learns.py"],
filterwarnings = [
"ignore:jax.interpreters.xla.pytype_aval_mappings is deprecated.:DeprecationWarning",
"ignore:Shape is deprecated; use StableHLO instead.:DeprecationWarning",
],
tags = ["manual"],
deps = [
":r2d2",
"//earl:core",
"//earl/environment_loop:gymnax_loop",
"@pypi//gymnax",
"@pypi//jax",
"@pypi//jax_loop_utils",
"@pypi//numpy",
"@pypi//pytest",
],
)
380 changes: 380 additions & 0 deletions earl/agents/r2d2/networks.py
@@ -0,0 +1,380 @@
import copy
import math
from collections.abc import Callable, Sequence

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping


class ResidualBlock(eqx.Module):
"""Residual block of operations, e.g. convolutional."""

inner_op1: eqx.nn.Conv2d
inner_op2: eqx.nn.Conv2d
layernorm1: eqx.nn.LayerNorm | None
layernorm2: eqx.nn.LayerNorm | None
use_layer_norm: bool

def __init__(
self,
num_channels: int,
use_layer_norm: bool = False,
dtype: jnp.dtype = jnp.float32,
*,
key: jaxtyping.PRNGKeyArray,
):
keys = jax.random.split(key, 2)
self.inner_op1 = eqx.nn.Conv2d(
num_channels, num_channels, kernel_size=3, padding=1, key=keys[0], dtype=dtype
)
self.inner_op2 = eqx.nn.Conv2d(
num_channels, num_channels, kernel_size=3, padding=1, key=keys[1], dtype=dtype
)
self.use_layer_norm = use_layer_norm

if use_layer_norm:
# Shape is (channels) for the normalization, we'll vmap over H,W
self.layernorm1 = eqx.nn.LayerNorm(num_channels, eps=1e-6, dtype=dtype)
self.layernorm2 = eqx.nn.LayerNorm(num_channels, eps=1e-6, dtype=dtype)
else:
self.layernorm1 = None
self.layernorm2 = None

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
output = x

# First layer in residual block
if self.layernorm1 is not None:
# Transpose to (H, W, C) for LayerNorm
output = jnp.transpose(output, (1, 2, 0))
# Apply LayerNorm to channel dimension at each spatial location
output = jax.vmap(jax.vmap(self.layernorm1))(output)
# Transpose back to (C, H, W)
output = jnp.transpose(output, (2, 0, 1))

output = jax.nn.relu(output)
output = self.inner_op1(output)

# Second layer in residual block
if self.layernorm2 is not None:
# Same normalization pattern as above
output = jnp.transpose(output, (1, 2, 0))
output = jax.vmap(jax.vmap(self.layernorm2))(output)
output = jnp.transpose(output, (2, 0, 1))

output = jax.nn.relu(output)
output = self.inner_op2(output)
return x + output


def make_downsampling_layer(
in_channels: int, out_channels: int, *, key: jaxtyping.PRNGKeyArray, dtype: jnp.dtype
) -> tuple[eqx.nn.Conv2d, eqx.nn.MaxPool2d]:
"""Returns conv + maxpool layers for downsampling."""
conv = eqx.nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
key=key,
dtype=dtype,
)
maxpool = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
return conv, maxpool


class ResNetTorso(eqx.Module):
"""ResNetTorso for visual inputs, inspired by the IMPALA paper."""

in_channels: int = eqx.field(static=True)
channels_per_group: Sequence[int] = eqx.field(static=True)
blocks_per_group: Sequence[int] = eqx.field(static=True)
use_layer_norm: bool = eqx.field(static=True)
downsampling_layers: list[tuple[eqx.nn.Conv2d, eqx.nn.MaxPool2d]]
residual_blocks: list[list[ResidualBlock]]

def __init__(
self,
in_channels: int = 4, # 4 stacked frames
channels_per_group: Sequence[int] = (16, 32, 32),
blocks_per_group: Sequence[int] = (2, 2, 2),
use_layer_norm: bool = False,
dtype: jnp.dtype = jnp.float32,
*,
key: jaxtyping.PRNGKeyArray,
):
self.in_channels = in_channels
self.channels_per_group = channels_per_group
self.blocks_per_group = blocks_per_group
self.use_layer_norm = use_layer_norm

if len(channels_per_group) != len(blocks_per_group):
raise ValueError(
"Length of channels_per_group and blocks_per_group must be equal. "
f"Got channels_per_group={channels_per_group}, "
f"blocks_per_group={blocks_per_group}"
)

# Create keys for all layers
num_groups = len(channels_per_group)
total_blocks = sum(blocks_per_group)
keys = jax.random.split(key, num_groups + total_blocks)

# Create downsampling layers
downsample_keys = keys[:num_groups]
prev_channels = in_channels
self.downsampling_layers = []
for channels, k in zip(channels_per_group, downsample_keys, strict=False):
layer = make_downsampling_layer(
in_channels=prev_channels, out_channels=channels, key=k, dtype=dtype
)
self.downsampling_layers.append(layer)
prev_channels = channels

# Create residual blocks
block_keys = keys[num_groups:]
key_idx = 0
self.residual_blocks = []
for channels, num_blocks in zip(channels_per_group, blocks_per_group, strict=False):
group_blocks = []
for _ in range(num_blocks):
block = ResidualBlock(
num_channels=channels,
use_layer_norm=use_layer_norm,
key=block_keys[key_idx],
dtype=dtype,
)
group_blocks.append(block)
key_idx += 1
self.residual_blocks.append(group_blocks)

def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
output = inputs

for (conv, maxpool), res_blocks in zip(
self.downsampling_layers, self.residual_blocks, strict=False
):
# Downsampling
output = conv(output)
output = maxpool(output)

# Residual blocks
for block in res_blocks:
output = block(output)

return output


class DeepAtariTorso(eqx.Module):
"""Deep torso for Atari, from the IMPALA paper.
Based on
https://github.com/google-deepmind/acme/blob/eedf63ca039856876ff85be472fa9186cf29b073/acme/jax/networks/atari.py
"""

resnet: ResNetTorso
mlp_head: eqx.nn.MLP
use_layer_norm: bool = eqx.field(static=True)
channel_last: bool = eqx.field(static=True)
dtype: jnp.dtype = eqx.field(static=True)

def __init__(
self,
channels_per_group: Sequence[int] = (16, 32, 32),
blocks_per_group: Sequence[int] = (2, 2, 2),
hidden_sizes: Sequence[int] = (512,),
use_layer_norm: bool = True,
in_channels: int = 4,
input_size: tuple[int, int] = (84, 84),
dtype: jnp.dtype = jnp.float32,
channel_last: bool = True,
*,
key: jaxtyping.PRNGKeyArray,
):
keys = jax.random.split(key, 2)
self.use_layer_norm = use_layer_norm
self.channel_last = channel_last
self.dtype = dtype
self.resnet = ResNetTorso(
channels_per_group=channels_per_group,
blocks_per_group=blocks_per_group,
use_layer_norm=use_layer_norm,
in_channels=in_channels,
dtype=dtype,
key=keys[0],
)

# Calculate resnet output size
sample_input = jax.ShapeDtypeStruct((in_channels, *input_size), dtype=dtype)
sample_output = eqx.filter_eval_shape(self.resnet, sample_input)
in_size = int(math.prod(sample_output.shape))

self.mlp_head = eqx.nn.MLP(
in_size=in_size,
out_size=hidden_sizes[-1],
width_size=hidden_sizes[0],
depth=len(hidden_sizes),
activation=jax.nn.relu,
final_activation=jax.nn.relu,
dtype=dtype,
key=keys[1],
)

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
was_uint8 = x.dtype == jnp.uint8
x = x.astype(self.dtype)
if was_uint8:
x = x / 255.0
if self.channel_last:
x = jnp.swapaxes(x, 0, self.resnet.in_channels - 1)
output = self.resnet(x)
output = jax.nn.relu(output)
# Flatten all dimensions into a single vector
output = output.reshape(-1)
output = self.mlp_head(output)
return output


class OAREmbedding(eqx.Module):
"""Module for embedding (observation, action, reward) inputs together.
Based on
https://github.com/google-deepmind/acme/blob/eedf63ca039856876ff85be472fa9186cf29b073/acme/jax/networks/embedding.py
"""

torso: Callable[[jax.Array], jax.Array]
num_actions: int

def __call__(self, observation: jax.Array, action: jax.Array, reward: jax.Array) -> jnp.ndarray:
"""Embed each of the (observation, action, reward) inputs & concatenate."""
features = self.torso(observation) # [D]

action = jnp.squeeze(jax.nn.one_hot(action, num_classes=self.num_actions)) # [A]
# Map rewards -> [-1, 1].
reward = jnp.tanh(reward)
# Add dummy trailing dimensions to rewards if necessary.
while reward.ndim < action.ndim:
reward = jnp.expand_dims(reward, axis=-1)

# Concatenate on final dimension.
embedding = jnp.concatenate([features, action, reward], axis=-1) # [D+A+1]
return embedding


class R2D2Network(eqx.Module):
"""The R2D2 network: a convolutional feature extractor, an LSTMCell, and a dueling head."""

embed: Callable[
[jax.Array, jax.Array, jax.Array], jax.Array
] # Section 2.3: convolutional feature extractor.
lstm_cell: eqx.nn.LSTMCell | None # Section 2.3 & 3: recurrent cell.
dueling_value: eqx.nn.Linear # Section 2.3: value branch.
dueling_advantage: eqx.nn.Linear # Section 2.3: advantage branch.

def __init__(
self,
torso: Callable[[jax.Array], jax.Array],
lstm_cell: eqx.nn.LSTMCell | None,
dueling_value: eqx.nn.Linear,
dueling_advantage: eqx.nn.Linear,
num_actions: int,
):
self.embed = OAREmbedding(torso=torso, num_actions=num_actions)
self.lstm_cell = lstm_cell
self.dueling_value = dueling_value
self.dueling_advantage = dueling_advantage

def __call__(
self,
observation: jax.Array,
action: jax.Array,
reward: jax.Array,
hidden: tuple[jax.Array, jax.Array],
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
features = self.embed(observation, action, reward)
if self.lstm_cell is not None:
h, c = self.lstm_cell(features, hidden)
else:
h = features
c = jnp.zeros_like(features)
value = self.dueling_value(h)
advantage = self.dueling_advantage(h)
q_values = value + (advantage - jnp.mean(advantage, axis=-1, keepdims=True))
return q_values, (h, c)


class R2D2Networks(eqx.Module):
online: R2D2Network
target: R2D2Network


def make_networks_resnet(
num_actions: int,
in_channels: int,
input_size: tuple[int, int] = (84, 84),
dtype: jnp.dtype = jnp.float32,
hidden_size: int = 512,
channel_last: bool = False,
*,
key: jaxtyping.PRNGKeyArray,
) -> R2D2Networks:
torso_key, lstm_key, dueling_value_key, dueling_advantage_key = jax.random.split(key, 4)

online = R2D2Network(
torso=DeepAtariTorso(
in_channels=in_channels,
input_size=input_size,
# output will be concatenated with action and reward, so we subtract them from the hidden size
hidden_sizes=(hidden_size - num_actions - 1,),
channel_last=channel_last,
dtype=dtype,
key=torso_key,
),
lstm_cell=eqx.nn.LSTMCell(hidden_size, hidden_size, dtype=dtype, key=lstm_key),
dueling_value=eqx.nn.Linear(hidden_size, 1, dtype=dtype, key=dueling_value_key),
dueling_advantage=eqx.nn.Linear(
hidden_size, num_actions, dtype=dtype, key=dueling_advantage_key
),
num_actions=num_actions,
)
target = copy.deepcopy(online)
return R2D2Networks(online=online, target=target)


def make_networks_mlp(
num_actions: int,
input_size: int,
dtype: jnp.dtype = jnp.float32,
hidden_size: int = 512,
use_lstm: bool = True,
*,
key: jaxtyping.PRNGKeyArray,
) -> R2D2Networks:
torso_key_0, torso_key_1, lstm_key, dueling_value_key, dueling_advantage_key = jax.random.split(
key, 5
)

online = R2D2Network(
torso=eqx.nn.Sequential(
(
eqx.nn.Lambda(jnp.ravel),
eqx.nn.Linear(input_size, hidden_size, dtype=dtype, key=torso_key_0),
eqx.nn.Lambda(jax.nn.relu),
eqx.nn.Linear(hidden_size, hidden_size - num_actions - 1, dtype=dtype, key=torso_key_1),
eqx.nn.Lambda(jax.nn.relu),
)
),
lstm_cell=eqx.nn.LSTMCell(hidden_size, hidden_size, dtype=dtype, key=lstm_key)
if use_lstm
else None,
dueling_value=eqx.nn.Linear(hidden_size, 1, dtype=dtype, key=dueling_value_key),
dueling_advantage=eqx.nn.Linear(
hidden_size, num_actions, dtype=dtype, key=dueling_advantage_key
),
num_actions=num_actions,
)
target = copy.deepcopy(online)
return R2D2Networks(online=online, target=target)
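The dueling head in `R2D2Network.__call__` above combines the value and advantage branches with a mean-subtracted aggregation, which makes the value/advantage decomposition identifiable. A minimal standalone sketch of just that aggregation (the helper name here is hypothetical, not part of this PR):

```python
import jax.numpy as jnp


def dueling_q_values(value: jnp.ndarray, advantage: jnp.ndarray) -> jnp.ndarray:
  # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); subtracting the mean
  # advantage pins down the otherwise-unidentifiable split between V and A.
  return value + (advantage - jnp.mean(advantage, axis=-1, keepdims=True))


value = jnp.array([1.0])  # V(s), shape (1,), broadcasts over actions
advantage = jnp.array([2.0, 0.0])  # A(s, a), shape (num_actions,)
q = dueling_q_values(value, advantage)  # mean advantage is 1.0, so q = [2.0, 0.0]
```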
588 changes: 588 additions & 0 deletions earl/agents/r2d2/r2d2.py

Large diffs are not rendered by default.

235 changes: 235 additions & 0 deletions earl/agents/r2d2/test_r2d2.py
@@ -0,0 +1,235 @@
import dataclasses
import functools
import io

import chex
import envpool
import gymnasium
import jax
import jax.numpy as jnp
import pytest
from gymnax.environments.spaces import Box, Discrete
from jax_loop_utils.metric_writers import MemoryWriter
from jax_loop_utils.metric_writers._audio_video import encode_video_to_gif

import earl.agents.r2d2.networks as r2d2_networks
from earl.agents.r2d2.r2d2 import R2D2, R2D2Config
from earl.agents.r2d2.utils import render_atari_cycle, update_buffer_batch
from earl.core import EnvInfo, env_info_from_gymnasium
from earl.environment_loop.gymnasium_loop import GymnasiumLoop


# triggered by envpool
def test_train_atari():
stack_num = 4
num_envs = 6
input_size = (84, 84)
env_factory = functools.partial(
envpool.make_gymnasium,
"Asterix-v5",
num_envs=num_envs,
stack_num=stack_num,
img_height=input_size[0],
img_width=input_size[1],
)
with env_factory() as env:
assert isinstance(env.action_space, gymnasium.spaces.Discrete)
num_actions = int(env.action_space.n)
devices = jax.local_devices()
if len(devices) > 1:
actor_devices = devices[:1]
learner_devices = devices[1:]
else:
actor_devices = devices
learner_devices = devices
env_info = env_info_from_gymnasium(env, num_envs)
hidden_size = 512
key = jax.random.PRNGKey(0)
networks_key, loop_key, agent_key = jax.random.split(key, 3)
networks = r2d2_networks.make_networks_resnet(
num_actions=num_actions,
in_channels=stack_num,
hidden_size=hidden_size,
input_size=input_size,
key=networks_key,
)
steps_per_cycle = 10

config = R2D2Config(
epsilon_greedy_schedule_args=dict(
init_value=0.9, end_value=0.01, transition_steps=steps_per_cycle * 1_000
),
num_envs_per_learner=num_envs,
replay_seq_length=steps_per_cycle,
buffer_capacity=steps_per_cycle * 10,
burn_in=5,
learning_rate_schedule_name="cosine_onecycle_schedule",
learning_rate_schedule_args=dict(
transition_steps=steps_per_cycle * 2_500,
# NOTE: more devices effectively means a larger batch size, so we
# scale the learning rate up to train faster!
peak_value=5e-5 * len(devices),
),
target_update_step_size=0.00,
target_update_period=500,
)
agent = R2D2(env_info, config)
loop_state = agent.new_state(networks, agent_key)
metric_writer = MemoryWriter()
train_loop = GymnasiumLoop(
env_factory,
agent,
num_envs,
loop_key,
observe_cycle=render_atari_cycle,
metric_writer=metric_writer,
actor_devices=actor_devices,
learner_devices=learner_devices,
vectorization_mode="none",
)
# just one cycle, make sure it runs
loop_state = train_loop.run(loop_state, 1, steps_per_cycle)
train_loop.close()
# make sure we can render the video
video_buf = io.BytesIO()
video_array = next(iter(metric_writer.videos.values()))["video"]
encode_video_to_gif(video_array, video_buf)
# for manual inspection, uncomment
# with open("asterix_initial.gif", "wb") as f:
# f.write(video_buf.getvalue())


_dummy_env_info = EnvInfo(
num_envs=2,
observation_space=Box(low=0, high=1, shape=(4,)),
action_space=Discrete(num_categories=2),
name="dummy",
)


@pytest.fixture
def mlp_agent_and_networks():
key = jax.random.PRNGKey(0)
env_info = _dummy_env_info
config = R2D2Config(
epsilon_greedy_schedule_args=dict(init_value=0.5, end_value=0.0001, transition_steps=10000),
discount=0.99,
q_learning_n_steps=3,
burn_in=2,
importance_sampling_priority_exponent=0.9,
target_update_period=10,
buffer_capacity=8, # must be divisible by replay_seq_length
replay_seq_length=4,
store_hidden_states=True,
num_envs_per_learner=env_info.num_envs,
)
# Create dummy networks using the MLP builder
networks = r2d2_networks.make_networks_mlp(
num_actions=2, input_size=4, dtype=jnp.float32, hidden_size=32, key=key
)
agent = R2D2(env_info, config)
return agent, networks


def test_slice_for_replay(mlp_agent_and_networks):
agent, _ = mlp_agent_and_networks
B = agent.env_info.num_envs
T = 8
dummy_data = jnp.arange(B * T).reshape(B, T, 1) # shape (B, T, 1)
start_idx = jnp.array([0, 4]) # one index per environment
assert start_idx.shape == (B,)
length = 4
sliced = agent._slice_for_replay(dummy_data, start_idx, length)
assert sliced.shape == (B, length, 1)
# Verify a couple of values
chex.assert_equal(sliced[0, 0, 0], dummy_data[0, 0, 0])
chex.assert_equal(sliced[1, 0, 0], dummy_data[1, 4, 0])


def test_sample_from_experience(mlp_agent_and_networks):
agent, networks = mlp_agent_and_networks
agent = dataclasses.replace(
agent, config=dataclasses.replace(agent.config, replay_batch_size=2 * agent.env_info.num_envs)
)
key = jax.random.PRNGKey(0)
exp_state = agent._new_experience_state(networks, key)
outputs = agent._sample_from_experience(networks, key, exp_state)
sampled_seq_idx, obs_time, action_time, reward_time, dones_time, hidden_h_pre, hidden_c_pre = (
outputs
)
assert sampled_seq_idx.shape == (agent.config.replay_batch_size,)
assert obs_time.shape == (
agent.config.replay_seq_length,
agent.config.replay_batch_size,
agent.env_info.observation_space.shape[0],
)
assert action_time.shape == (agent.config.replay_seq_length, agent.config.replay_batch_size)
assert reward_time.shape == (agent.config.replay_seq_length, agent.config.replay_batch_size)
assert dones_time.shape == (agent.config.replay_seq_length, agent.config.replay_batch_size)
# Hidden state shapes
assert hidden_h_pre.shape == (
agent.config.replay_batch_size,
networks.online.lstm_cell.hidden_size,
)
assert hidden_c_pre.shape == (
agent.config.replay_batch_size,
networks.online.lstm_cell.hidden_size,
)


def test_update_buffer_batch():
"""Test the update_buffer_batch function with various scenarios."""
# Set up parameters
seq_length = 4
buffer_capacity = 8
num_envs = 2

# Test with pointer at beginning
buffer = jnp.zeros((num_envs, buffer_capacity), dtype=jnp.bool)
pointer = jnp.array(0, dtype=jnp.uint32)
data = jnp.array(
[[True, False, True, False], [True, False, True, False]]
) # Shape is (num_envs, seq_length)
updated_buffer = update_buffer_batch(buffer, pointer, data, debug=True)

# Check buffer has been updated correctly
expected_buffer = jnp.zeros((num_envs, buffer_capacity), dtype=jnp.bool)
expected_buffer = expected_buffer.at[:, 0:4].set(data)
chex.assert_trees_all_close(updated_buffer, expected_buffer)

# Test with pointer in middle
buffer = jnp.zeros((num_envs, buffer_capacity), dtype=jnp.bool)
pointer = jnp.array(2, dtype=jnp.uint32)
updated_buffer = update_buffer_batch(buffer, pointer, data, debug=True)

# Check buffer has been updated correctly
expected_buffer = jnp.zeros((num_envs, buffer_capacity), dtype=jnp.bool)
expected_buffer = expected_buffer.at[:, 2:6].set(data)
chex.assert_trees_all_close(updated_buffer, expected_buffer)

# Test with nested dimensions
hidden_size = 3
buffer = jnp.zeros((num_envs, buffer_capacity, hidden_size), dtype=jnp.float32)
data = jnp.ones((num_envs, seq_length, hidden_size), dtype=jnp.float32)
data = data * jnp.arange(1, num_envs + 1).reshape(num_envs, 1, 1) # Different values per env

pointer = jnp.array(1, dtype=jnp.uint32)
updated_buffer = update_buffer_batch(buffer, pointer, data, debug=True)

# Check buffer shape and updated values
chex.assert_shape(updated_buffer, (num_envs, buffer_capacity, hidden_size))

# Verify values in updated region
expected_buffer = jnp.zeros((num_envs, buffer_capacity, hidden_size), dtype=jnp.float32)
expected_buffer = expected_buffer.at[:, 1:5].set(data)
chex.assert_trees_all_close(updated_buffer, expected_buffer)

# Also verify specific values to ensure environment-specific data was preserved
for env_idx in range(num_envs):
# Check that the updated region has values equal to env_idx + 1
chex.assert_trees_all_close(
updated_buffer[env_idx, 1:5], jnp.ones((4, hidden_size)) * (env_idx + 1)
)
# Check that areas outside the updated region remain zeros
chex.assert_trees_all_close(updated_buffer[env_idx, 0], jnp.zeros(hidden_size))
chex.assert_trees_all_close(updated_buffer[env_idx, 5:], jnp.zeros((3, hidden_size)))
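`update_buffer_batch` itself lives in `earl/agents/r2d2/utils.py` and its body is not shown in this diff. A minimal sketch consistent with the tests above — write `(num_envs, seq_length, ...)` data into a `(num_envs, capacity, ...)` buffer at a shared pointer, ignoring wrap-around and the `debug` flag — might look like:

```python
import jax
import jax.numpy as jnp


def update_buffer_batch(buffer: jax.Array, pointer: jax.Array, data: jax.Array) -> jax.Array:
  # Hypothetical reconstruction: for each environment, overwrite
  # buffer[env, pointer : pointer + seq_length, ...] with data[env].
  def write_one(buf: jax.Array, d: jax.Array) -> jax.Array:
    return jax.lax.dynamic_update_slice_in_dim(buf, d, pointer, axis=0)

  return jax.vmap(write_one)(buffer, data)


buffer = jnp.zeros((2, 8), dtype=jnp.bool_)
data = jnp.array([[True, False, True, False]] * 2)
out = update_buffer_batch(buffer, jnp.array(2, dtype=jnp.int32), data)
# out[:, 2:6] now holds data; everything else stays False.
```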
93 changes: 93 additions & 0 deletions earl/agents/r2d2/test_r2d2_learns.py
@@ -0,0 +1,93 @@
import math

import jax
import jax.numpy as jnp
import numpy as np
from gymnax.environments.classic_control import CartPole
from gymnax.environments.spaces import Box, Discrete
from jax_loop_utils.metric_writers import MemoryWriter

import earl.agents.r2d2.networks as r2d2_networks
from earl.agents.r2d2.r2d2 import R2D2, R2D2Config
from earl.core import env_info_from_gymnax
from earl.environment_loop.gymnax_loop import GymnaxLoop
from earl.metric_key import MetricKey


def test_learns_cartpole():
# NOTE: this test is pretty fragile. I had to search for good hyperparameters to get it to learn.
# I think if it were run for many more cycles it would learn, but that's not appropriate for CI.
env = CartPole()
env_params = env.default_params
action_space = env.action_space(env_params)
assert isinstance(action_space, Discrete), action_space
num_actions = int(action_space.n)
observation_space = env.observation_space(env_params)
assert isinstance(observation_space, Box), observation_space

hidden_size = 32
key = jax.random.PRNGKey(1)
networks_key, loop_key, agent_key = jax.random.split(key, 3)
networks = r2d2_networks.make_networks_mlp(
num_actions=num_actions,
input_size=int(math.prod(observation_space.shape)),
dtype=jnp.float32,
hidden_size=hidden_size,
use_lstm=True,
key=networks_key,
)

num_envs = 512
burn_in = 10
steps_per_cycle = 80 + burn_in
num_cycles = 1000
env_info = env_info_from_gymnax(env, env_params, num_envs)
num_off_policy_optims_per_cycle = 1
config = R2D2Config(
epsilon_greedy_schedule_args=dict(
init_value=0.5, end_value=0.01, transition_steps=steps_per_cycle * num_cycles
),
q_learning_n_steps=2,
debug=False,
buffer_capacity=steps_per_cycle * 10,
num_envs_per_learner=num_envs,
replay_seq_length=steps_per_cycle,
burn_in=burn_in,
value_rescaling_epsilon=0.0,
num_off_policy_optims_per_cycle=num_off_policy_optims_per_cycle,
gradient_clipping_max_delta=1.0,
learning_rate_schedule_name="cosine_onecycle_schedule",
learning_rate_schedule_args=dict(
transition_steps=steps_per_cycle * num_cycles // 2,
peak_value=2e-4,
),
target_update_step_size=0.00,
target_update_period=100,
)
agent = R2D2(env_info, config)
memory_writer = MemoryWriter()
# For tweaking of hyperparameters, you can use the mlflow writer
# and view metrics with `uv run --with mlflow mlflow server`
# metric_writer = MultiWriter(
# (memory_writer, MlflowMetricWriter(experiment_name=env.name)),
# )
metric_writer = memory_writer
loop = GymnaxLoop(env, env.default_params, agent, num_envs, loop_key, metric_writer=metric_writer)
agent_state = agent.new_state(networks, agent_key)
_ = loop.run(agent_state, num_cycles, steps_per_cycle)
del agent_state
metrics = memory_writer.scalars
metric_writer.close()

episode_lengths = np.array(
[step_metrics[MetricKey.COMPLETE_EPISODE_LENGTH_MEAN] for step_metrics in metrics.values()]
)

assert len(metrics) == num_cycles
# Due to auto-resets, the reward is always constant, but it's a survival task
# so longer episodes are better.
mean_over_cycles = 20
first_mean = float(np.mean(episode_lengths[:mean_over_cycles]))
assert first_mean > 0
last_mean = float(np.mean(episode_lengths[-mean_over_cycles:]))
assert last_mean > 1.4 * first_mean
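The `value_rescaling_epsilon` knob in the config above suggests the invertible value rescaling from the R2D2 paper, h(x) = sign(x)(√(|x| + 1) − 1) + εx. A sketch of that transform and its closed-form inverse — an assumption about what the agent implements, not code from this PR:

```python
import jax.numpy as jnp


def value_rescale(x: jnp.ndarray, eps: float = 1e-3) -> jnp.ndarray:
  # h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x, from the R2D2 paper.
  return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1.0) - 1.0) + eps * x


def inverse_value_rescale(x: jnp.ndarray, eps: float = 1e-3) -> jnp.ndarray:
  # Closed-form inverse of h. With eps == 0 (as in the cartpole test above,
  # value_rescaling_epsilon=0.0) the inverse simplifies considerably.
  if eps == 0.0:
    return jnp.sign(x) * (jnp.square(jnp.abs(x) + 1.0) - 1.0)
  return jnp.sign(x) * (
    jnp.square((jnp.sqrt(1.0 + 4.0 * eps * (jnp.abs(x) + 1.0 + eps)) - 1.0) / (2.0 * eps)) - 1.0
  )
```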
672 changes: 672 additions & 0 deletions earl/agents/r2d2/train_r2d2_asterix.ipynb

Large diffs are not rendered by default.

335 changes: 335 additions & 0 deletions earl/agents/r2d2/utils.py
@@ -0,0 +1,335 @@
import functools
from collections.abc import Callable
from typing import Any, NamedTuple

import chex
import jax
import jax.numpy as jnp
import numpy as np

from earl.core import EnvStep, Metrics, Video

Array = jax.Array


# BEGIN copied from rlax
# Below copied from rlax because its package is unusable with Bazel
# due to https://github.com/google-deepmind/rlax/issues/133.
def double_q_learning(
q_tm1: Array,
a_tm1: Array,
r_t: Array,
discount_t: Array,
q_t_value: Array,
q_t_selector: Array,
stop_target_gradients: bool = True,
) -> Array:
"""Calculates the double Q-learning temporal difference error.
See "Double Q-learning" by van Hasselt.
(https://papers.nips.cc/paper/3964-double-q-learning.pdf).
Args:
q_tm1: Q-values at time t-1.
a_tm1: action index at time t-1.
r_t: reward at time t.
discount_t: discount at time t.
q_t_value: Q-values at time t.
q_t_selector: selector Q-values at time t.
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
Returns:
Double Q-learning temporal difference error.
"""
chex.assert_rank([q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector], [1, 0, 0, 0, 1, 1])
chex.assert_type(
[q_tm1, a_tm1, r_t, discount_t, q_t_value, q_t_selector],
[float, int, float, float, float, float],
)

target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
target_tm1 = jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(target_tm1), target_tm1)
return target_tm1 - q_tm1[a_tm1]
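A quick sanity check of the arithmetic above, using NumPy stand-ins for the jnp calls and toy values (hypothetical, not from this PR): the online network selects the bootstrap action via `q_t_selector.argmax()`, while the target network evaluates it.

```python
import numpy as np

# Toy single-transition values (hypothetical):
q_tm1 = np.array([1.0, 2.0, 3.0])         # online Q-values at t-1
a_tm1 = 1                                 # action taken at t-1
r_t, discount_t = 0.5, 0.9
q_t_value = np.array([1.0, 2.0, 4.0])     # target-network Q-values at t
q_t_selector = np.array([3.0, 1.0, 2.0])  # online Q-values pick the action

# Same arithmetic as double_q_learning above.
target_tm1 = r_t + discount_t * q_t_value[q_t_selector.argmax()]
td_error = target_tm1 - q_tm1[a_tm1]
# selector argmax is action 0, so target = 0.5 + 0.9 * 1.0 = 1.4
# and td_error = 1.4 - q_tm1[1] = -0.6
```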


def n_step_bootstrapped_returns(
r_t: Array,
discount_t: Array,
v_t: Array,
n: int,
lambda_t: Array | float = 1.0,
stop_target_gradients: bool = False,
) -> Array:
"""Computes strided n-step bootstrapped return targets over a sequence.
The returns are computed according to the below equation iterated `n` times:
Gₜ = rₜ₊₁ + γₜ₊₁ [(1 - λₜ₊₁) vₜ₊₁ + λₜ₊₁ Gₜ₊₁].
When lambda_t == 1. (default), this reduces to
Gₜ = rₜ₊₁ + γₜ₊₁ * (rₜ₊₂ + γₜ₊₂ * (... * (rₜ₊ₙ + γₜ₊ₙ * vₜ₊ₙ ))).
Args:
r_t: rewards at times [1, ..., T].
discount_t: discounts at times [1, ..., T].
v_t: state or state-action values to bootstrap from at times [1, ..., T].
n: number of steps over which to accumulate reward before bootstrapping.
lambda_t: lambdas at times [1, ..., T]. Shape is [], or [T-1].
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
Returns:
estimated bootstrapped returns at times [0, ..., T-1]
"""
chex.assert_rank([r_t, discount_t, v_t, lambda_t], [1, 1, 1, {0, 1}])
chex.assert_type([r_t, discount_t, v_t, lambda_t], float)
chex.assert_equal_shape([r_t, discount_t, v_t])
seq_len = r_t.shape[0]

# Maybe change scalar lambda to an array.
lambda_t = jnp.ones_like(discount_t) * lambda_t

# Shift bootstrap values by n and pad end of sequence with last value v_t[-1].
pad_size = min(n - 1, seq_len)
targets = jnp.concatenate([v_t[n - 1 :], jnp.array([v_t[-1]] * pad_size)])

# Pad sequences. Shape is now (T + n - 1,).
r_t = jnp.concatenate([r_t, jnp.zeros(n - 1)])
discount_t = jnp.concatenate([discount_t, jnp.ones(n - 1)])
lambda_t = jnp.concatenate([lambda_t, jnp.ones(n - 1)])
v_t = jnp.concatenate([v_t, jnp.array([v_t[-1]] * (n - 1))])

# Work backwards to compute n-step returns.
for i in reversed(range(n)):
r_ = r_t[i : i + seq_len]
discount_ = discount_t[i : i + seq_len]
lambda_ = lambda_t[i : i + seq_len]
v_ = v_t[i : i + seq_len]
targets = r_ + discount_ * ((1.0 - lambda_) * v_ + lambda_ * targets)

return jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(targets), targets)
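For `lambda_t == 1` and a toy sequence, the backward loop above reduces to plain n-step returns; here is the hand-unrolled version for `n = 2` (NumPy, hypothetical values). Note the final step has fewer than `n` rewards left, so it bootstraps from the last value, which is what the end-of-sequence padding implements.

```python
import numpy as np

r_t = np.array([1.0, 1.0])
discount_t = np.array([0.9, 0.9])
v_t = np.array([10.0, 20.0])

# G_0 accumulates two rewards, then bootstraps from v_t[1]:
g0 = r_t[0] + discount_t[0] * (r_t[1] + discount_t[1] * v_t[1])
# G_1 has only one step left, so it bootstraps from the final value:
g1 = r_t[1] + discount_t[1] * v_t[1]
# giving 18.1 and 19.0 (up to float rounding)
```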


def transform_values(build_targets, *value_argnums):
"""Decorator to convert targets to use transformed value function."""

@functools.wraps(build_targets)
def wrapped_build_targets(tx_pair, *args, **kwargs):
tx_args = list(args)
for index in value_argnums:
tx_args[index] = tx_pair.apply_inv(tx_args[index])

targets = build_targets(*tx_args, **kwargs)
return tx_pair.apply(targets)

return wrapped_build_targets


transformed_n_step_returns = transform_values(n_step_bootstrapped_returns, 2)
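The decorator un-transforms the value arguments, builds targets in the untransformed space, then re-transforms the result. A minimal self-contained sketch of that mechanism (the `TxPair` and target builder here are illustrative, not from the PR):

```python
import functools
from collections import namedtuple

import numpy as np

TxPair = namedtuple("TxPair", ["apply", "apply_inv"])


def transform_values(build_targets, *value_argnums):
    @functools.wraps(build_targets)
    def wrapped(tx_pair, *args, **kwargs):
        tx_args = list(args)
        for index in value_argnums:
            # Map transformed values back to the raw space.
            tx_args[index] = tx_pair.apply_inv(tx_args[index])
        # Build targets in raw space, then re-transform.
        return tx_pair.apply(build_targets(*tx_args, **kwargs))
    return wrapped


# Toy target builder: one-step bootstrap from a value estimate.
def one_step_target(r, v):
    return r + 0.9 * v


# Hypothetical transform pair for illustration: log1p / expm1.
log_pair = TxPair(apply=np.log1p, apply_inv=np.expm1)
transformed = transform_values(one_step_target, 1)
out = transformed(log_pair, 1.0, np.log1p(10.0))
# v is un-transformed to 10.0, target = 1 + 0.9 * 10 = 10, then re-transformed,
# so out == log1p(10.0)
```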


class TxPair(NamedTuple):
  apply: Callable[[Array], Array]
  apply_inv: Callable[[Array], Array]


def signed_hyperbolic(x: Array, eps: float = 1e-3) -> Array:
"""Signed hyperbolic transform, inverse of signed_parabolic."""
chex.assert_type(x, float)
return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1) - 1) + eps * x


def signed_parabolic(x: Array, eps: float = 1e-3) -> Array:
"""Signed parabolic transform, inverse of signed_hyperbolic."""
chex.assert_type(x, float)
z = jnp.sqrt(1 + 4 * eps * (eps + 1 + jnp.abs(x))) / 2 / eps - 1 / 2 / eps
return jnp.sign(x) * (jnp.square(z) - 1)
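These are the R2D2 value-rescaling transform and its inverse; a numerical round-trip check with NumPy versions of the two formulas confirms they invert each other:

```python
import numpy as np

EPS = 1e-3


def signed_hyperbolic_np(x, eps=EPS):
    # Same formula as signed_hyperbolic above.
    return np.sign(x) * (np.sqrt(np.abs(x) + 1) - 1) + eps * x


def signed_parabolic_np(x, eps=EPS):
    # Same formula as signed_parabolic above.
    z = np.sqrt(1 + 4 * eps * (eps + 1 + np.abs(x))) / 2 / eps - 1 / 2 / eps
    return np.sign(x) * (np.square(z) - 1)


x = np.linspace(-100.0, 100.0, 11)
roundtrip = signed_parabolic_np(signed_hyperbolic_np(x))
max_err = np.max(np.abs(roundtrip - x))  # near zero, up to float error
```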


IDENTITY_PAIR = TxPair(lambda x: x, lambda x: x)


def batched_index(values: Array, indices: Array, keepdims: bool = False) -> Array:
"""Index into the last dimension of a tensor, preserving all others dims.
Args:
values: a tensor of shape [..., D],
indices: indices of shape [...].
keepdims: whether to keep the final dimension.
Returns:
a tensor of shape [...] or [..., 1].
"""
indexed = jnp.take_along_axis(values, indices[..., None], axis=-1)
if not keepdims:
indexed = jnp.squeeze(indexed, axis=-1)
return indexed
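`batched_index` is a thin wrapper over `take_along_axis`; the NumPy equivalent on a toy batch shows the shape contract (values `[..., D]`, indices `[...]` → result `[...]`):

```python
import numpy as np

values = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
indices = np.array([2, 0])  # one action index per batch row

# Same call as batched_index with keepdims=False.
indexed = np.take_along_axis(values, indices[..., None], axis=-1).squeeze(-1)
# picks values[0, 2] and values[1, 0] -> [3.0, 4.0]
```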


def transformed_n_step_q_learning(
q_tm1: Array,
a_tm1: Array,
target_q_t: Array,
a_t: Array,
r_t: Array,
discount_t: Array,
n: int,
stop_target_gradients: bool = True,
tx_pair: TxPair = IDENTITY_PAIR,
) -> Array:
"""Calculates transformed n-step TD errors.
See "Recurrent Experience Replay in Distributed Reinforcement Learning" by
Kapturowski et al. (https://openreview.net/pdf?id=r1lyTjAqYX).
Args:
q_tm1: Q-values at times [0, ..., T - 1].
a_tm1: action index at times [0, ..., T - 1].
target_q_t: target Q-values at time [1, ..., T].
a_t: action index at times [1, ..., T] used to select target q-values to
bootstrap from; max(target_q_t) for normal Q-learning, max(q_t) for double
Q-learning.
r_t: reward at times [1, ..., T].
discount_t: discount at times [1, ..., T].
n: number of steps over which to accumulate reward before bootstrapping.
stop_target_gradients: bool indicating whether or not to apply stop gradient
to targets.
tx_pair: TxPair of value function transformation and its inverse.
Returns:
Transformed N-step TD error.
"""
chex.assert_rank([q_tm1, target_q_t, a_tm1, a_t, r_t, discount_t], [2, 2, 1, 1, 1, 1])
chex.assert_type(
[q_tm1, target_q_t, a_tm1, a_t, r_t, discount_t], [float, float, int, int, float, float]
)

v_t = batched_index(target_q_t, a_t)
target_tm1 = transformed_n_step_returns(
tx_pair, r_t, discount_t, v_t, n, stop_target_gradients=stop_target_gradients
)
q_a_tm1 = batched_index(q_tm1, a_tm1)
return target_tm1 - q_a_tm1


# END copied from rlax


def update_buffer_batch(buffer, pointer, data, debug=False):
"""
Update buffer with data for all environments at once.
Args:
buffer: Array of shape (num_envs, buffer_capacity, ...)
pointer: Scalar uint32 index where to insert data (same for all environments)
data: Array of shape (num_envs, seq_length, ...)
debug: Whether to print debug information
Returns:
Updated buffer of shape (num_envs, buffer_capacity, ...)
"""
if debug:
jax.debug.print(
"update_buffer_batch: buffer.shape: {}, data.shape: {}", buffer.shape, data.shape
)

# Start indices for the update:
# - First dimension: start at the first environment (index 0)
# - Second dimension: start at the pointer
# - Additional dimensions: start at index 0
start_indices = (jnp.array(0, dtype=jnp.uint32), pointer) + tuple(
jnp.array(0, dtype=jnp.uint32) for _ in range(len(buffer.shape) - 2)
)

if debug:
jax.debug.print("start_indices: {}", start_indices)

return jax.lax.dynamic_update_slice(buffer, data, start_indices)
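The `dynamic_update_slice` call writes a `(num_envs, seq_length, ...)` block into `buffer[:, pointer:pointer + seq_length]`. The NumPy equivalent on toy shapes (note `dynamic_update_slice`, unlike this slicing, clamps out-of-bounds start indices rather than wrapping):

```python
import numpy as np

num_envs, capacity, seq_len = 2, 8, 3
buffer = np.zeros((num_envs, capacity))
data = np.arange(1, num_envs * seq_len + 1, dtype=float).reshape(num_envs, seq_len)
pointer = 2

# NumPy equivalent of the jax.lax.dynamic_update_slice call above.
buffer[:, pointer : pointer + seq_len] = data
# buffer[0] is now [0, 0, 1, 2, 3, 0, 0, 0]
```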


def render_minatar(obs: Array) -> np.ndarray:
n_channels = obs.shape[-1]
numerical_state = np.amax(obs * np.reshape(np.arange(n_channels) + 1, (1, 1, -1)), 2) + 0.5

# Create a simple color map (similar to cubehelix)
# Add black as the first color (for value 0)
colors = np.zeros((n_channels + 1, 3))

# Generate colors for each channel (1 to n_channels)
for i in range(1, n_channels + 1):
# Create colors with increasing intensity and some variation
# This is a simplified version of cubehelix - adjust as needed
hue = (i / n_channels) * 0.8 + 0.1 # Hue varies from 0.1 to 0.9
saturation = 0.7
value = 0.5 + i / (2 * n_channels) # Value increases with channel index

# Simple HSV to RGB conversion
h = hue * 6
c = value * saturation
x = c * (1 - abs(h % 2 - 1))
m = value - c

if h < 1:
r, g, b = c, x, 0
elif h < 2:
r, g, b = x, c, 0
elif h < 3:
r, g, b = 0, c, x
elif h < 4:
r, g, b = 0, x, c
elif h < 5:
r, g, b = x, 0, c
else:
r, g, b = c, 0, x

colors[i] = np.array([r + m, g + m, b + m])

# Vectorized mapping of numerical_state to RGB colors
# Convert numerical_state to integer indices and clip to valid range
indices = np.clip(numerical_state.astype(np.int32), 0, n_channels).reshape(-1)

# Use the indices to look up colors
rgb_values = colors[indices]

# Reshape back to image dimensions with RGB channels
rgb_image = rgb_values.reshape(numerical_state.shape + (3,))

# Convert from float (0-1) to uint8 (0-255) for PIL compatibility
rgb_image_uint8 = (rgb_image * 255).astype(np.uint8)

# Resize to 64x64 using nearest neighbor interpolation
height, width = rgb_image_uint8.shape[:2]
scale_h = 64 // height
scale_w = 64 // width

# Use numpy's repeat for simple nearest-neighbor upscaling
# First, repeat rows
upscaled = np.repeat(rgb_image_uint8, scale_h, axis=0)
# Then, repeat columns
upscaled = np.repeat(upscaled, scale_w, axis=1)

return upscaled
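The two `np.repeat` calls at the end implement nearest-neighbor upscaling: each pixel is duplicated `scale_h` times along rows and `scale_w` times along columns. On a tiny image:

```python
import numpy as np

img = np.array([[1, 2],
                [3, 4]], dtype=np.uint8)
# Repeat rows, then columns, as in render_minatar above (scale factor 2).
up = np.repeat(np.repeat(img, 2, axis=0), 2, axis=1)
# up is the 4x4 image
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
```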


def render_minatar_cycle(trajectory: EnvStep, step_infos: dict[Any, Any]) -> Metrics:
obs = trajectory.obs
if len(obs.shape) != 5:
raise ValueError(f"Expected trajectory.obs to have shape (B, T, H, W, C), got {obs.shape}")
obs = obs[0]
img_array = np.stack([render_minatar(obs[i]) for i in range(obs.shape[0])])

return {"video": Video(img_array)} # pyright: ignore


def render_atari_cycle(trajectory: EnvStep, step_infos: dict[Any, Any]) -> Metrics:
obs = trajectory.obs
if len(obs.shape) != 5:
raise ValueError(
f"Expected trajectory.obs to have shape (B, T, stack_size, H, W),got {obs.shape}"
)
obs = obs[0, :, 0, :, :] # batch index 0, stack index 0
obs = np.expand_dims(obs, axis=3) # add channel dimension

return {"video": Video(obs)} # pyright: ignore
9 changes: 3 additions & 6 deletions earl/agents/random_agent/random_agent.py
@@ -5,7 +5,7 @@
import jax.numpy as jnp
from jaxtyping import PRNGKeyArray, PyTree, Scalar

from earl.core import ActionAndState, Agent, EnvStep, Metrics
from earl.core import ActionAndState, Agent, EnvStep
from earl.core import AgentState as CoreAgentState


@@ -30,7 +30,6 @@ class RandomAgent(Agent[None, OptState, None, ActorState]):

_sample_action_space: Callable
_num_off_policy_updates: int
_opt_count_metric_key: str = "opt_count"

def _new_actor_state(self, nets: None, key: PRNGKeyArray) -> ActorState:
return ActorState(key, jnp.zeros((1,), dtype=jnp.uint32))
@@ -53,10 +52,8 @@ def _act(
def _partition_for_grad(self, nets: None) -> tuple[None, None]:
return None, nets

def _loss(
self, nets: None, opt_state: OptState, experience_state: None
) -> tuple[Scalar, Metrics]:
return jnp.array(0.0), {self._opt_count_metric_key: opt_state.opt_count}
def _loss(self, nets: None, opt_state: OptState, experience_state: None) -> tuple[Scalar, None]:
return jnp.array(0.0), None

def _optimize_from_grads(
self, nets: None, opt_state: OptState, nets_grads: PyTree
6 changes: 3 additions & 3 deletions earl/agents/simple_policy_gradient/simple_policy_gradient.py
@@ -7,7 +7,7 @@
import optax
from jaxtyping import PRNGKeyArray, PyTree, Scalar

from earl.core import ActionAndState, Agent, EnvStep, Metrics
from earl.core import ActionAndState, Agent, EnvStep
from earl.core import AgentState as CoreAgentState
from earl.utils.sharding import shard_along_axis_0

@@ -127,7 +127,7 @@ def _optimize_from_grads(

def _loss(
self, nets: eqx.nn.Sequential, opt_state: optax.OptState, experience_state: ExperienceState
) -> tuple[Scalar, Metrics]:
) -> tuple[Scalar, ExperienceState]:
def discounted_returns(carry, x):
carry = x + self.config.discount * carry
return carry, carry
@@ -138,7 +138,7 @@ def vmap_discounted_returns(rewards):
return ys

returns = vmap_discounted_returns(experience_state.rewards)
return -jnp.mean(returns * experience_state.chosen_action_log_probs), {}
return -jnp.mean(returns * experience_state.chosen_action_log_probs), experience_state

def num_off_policy_optims_per_cycle(self) -> int:
return 0
@@ -10,7 +10,8 @@
make_networks,
)
from earl.core import env_info_from_gymnax
from earl.environment_loop.gymnax_loop import GymnaxLoop, MetricKey
from earl.environment_loop.gymnax_loop import GymnaxLoop
from earl.metric_key import MetricKey


def test_learns_cart_pole():
33 changes: 20 additions & 13 deletions earl/core.py
@@ -10,6 +10,7 @@
import jax
import jax.numpy as jnp
from gymnasium.core import Env as GymnasiumEnv
from gymnasium.vector import VectorEnv
from gymnax.environments.environment import Environment as GymnaxEnv
from gymnax.environments.spaces import Space
from jaxtyping import PRNGKeyArray, PyTree, Scalar
@@ -111,9 +112,15 @@ def env_info_from_gymnax(env: GymnaxEnv, params: Any, num_envs: int) -> EnvInfo:
return EnvInfo(num_envs, env.observation_space(params), env.action_space(params), env.name)


def env_info_from_gymnasium(env: GymnasiumEnv, num_envs: int) -> EnvInfo:
observation_space = _convert_gymnasium_space_to_gymnax_space(env.observation_space)
action_space = _convert_gymnasium_space_to_gymnax_space(env.action_space)
def env_info_from_gymnasium(env: GymnasiumEnv | VectorEnv, num_envs: int) -> EnvInfo:
if isinstance(env, VectorEnv):
env_obs_space = env.single_observation_space
env_action_space = env.single_action_space
else:
env_obs_space = env.observation_space
env_action_space = env.action_space
observation_space = _convert_gymnasium_space_to_gymnax_space(env_obs_space)
action_space = _convert_gymnasium_space_to_gymnax_space(env_action_space)
return EnvInfo(num_envs, observation_space, action_space, str(env))


@@ -160,13 +167,13 @@ def _optimize_from_grads_jit(
return optimize_from_grads(nets, opt_state, nets_grads)


@eqx.filter_jit()
@eqx.filter_jit(donate="warn-except-first")
def _loss_jit(
nets: _Networks,
opt_state: _OptState,
non_donated: tuple[_Networks, _OptState],
experience_state: _ExperienceState,
loss: Callable[[_Networks, _OptState, _ExperienceState], tuple[Scalar, Metrics]],
loss: Callable[[_Networks, _OptState, _ExperienceState], tuple[Scalar, _ExperienceState]],
):
nets, opt_state = non_donated
return loss(nets, opt_state, experience_state)


@@ -466,20 +473,20 @@ def _partition_for_grad(self, nets: _Networks) -> tuple[_Networks, _Networks]:

def loss(
self, nets: _Networks, opt_state: _OptState, experience_state: _ExperienceState
) -> tuple[Scalar, Metrics]:
"""Returns loss and metrics. Called after some number of environment steps.
) -> tuple[Scalar, _ExperienceState]:
"""Returns loss and experience state. Called after some number of environment steps.
Sub-classes should override _loss. This method is a wrapper that adds jit-compilation.
Note: the returned metrics should not have any keys that conflict with gymnax_loop.MetricKey.
The experience_state arg is donated, meaning callers should not access it after calling.
"""
return _loss_jit(nets, opt_state, experience_state, self._loss)
return _loss_jit((nets, opt_state), experience_state, self._loss)

@abc.abstractmethod
def _loss(
self, nets: _Networks, opt_state: _OptState, experience_state: _ExperienceState
) -> tuple[Scalar, Metrics]:
"""Returns loss and metrics. Called after some number of environment steps.
) -> tuple[Scalar, _ExperienceState]:
"""Returns loss and experience state. Called after some number of environment steps.
Must be jit-compatible.
"""
4 changes: 1 addition & 3 deletions earl/environment_loop/_common.py
@@ -90,9 +90,7 @@ def traverse(obj, current_path=""):

if isinstance(obj, tuple | list):
return {
k: v
for i, item in enumerate(obj)
for k, v in traverse(item, f"{current_path}[{i}]").items()
k: v for i, item in enumerate(obj) for k, v in traverse(item, f"{current_path}_{i}").items()
}

if hasattr(obj, "__dict__"):
133 changes: 76 additions & 57 deletions earl/environment_loop/gymnasium_loop.py
@@ -35,6 +35,11 @@ def nvtx_annotate(message: str):
yield


try:
import envpool.python.envpool as envpool
except ImportError:
envpool = None

from earl.core import (
Agent,
AgentState,
@@ -66,6 +71,13 @@ def nvtx_annotate(message: str):
_logger = logging.getLogger(__name__)


def _some_array_leaf(x: PyTree) -> jax.Array:
leaves = jax.tree.leaves(x, is_leaf=eqx.is_array)
assert leaves
assert isinstance(leaves[0], jax.Array)
return leaves[0]


class _ActorExperience(typing.NamedTuple, typing.Generic[_ActorState]):
actor_state_pre: _ActorState
cycle_result: CycleResult
@@ -80,29 +92,19 @@ def _stack_leaves(pytree_list: list[PyTree]) -> PyTree:
return pytree


@eqx.filter_jit
def _copy_pytree(pytree: PyTree) -> PyTree:
return jax.tree.map(lambda x: x.copy(), pytree)


@eqx.filter_grad(has_aux=True)
@eqx.filter_value_and_grad(has_aux=True)
def _loss_for_cycle_grad(
nets_yes_grad,
nets_no_grad,
opt_state,
experience_state,
agent: Agent,
check_metrics: typing.Callable[[Mapping], None],
) -> tuple[Scalar, ArrayMetrics]:
experience_state: _ExperienceState,
agent: Agent[_Networks, _OptState, _ExperienceState, _ActorState],
) -> tuple[Scalar, _ExperienceState]:
# this is a free function so we don't have to pass self as first arg, since filter_grad
# takes gradient with respect to the first arg.
nets = eqx.combine(nets_yes_grad, nets_no_grad)
loss, metrics = agent.loss(nets, opt_state, experience_state)
check_metrics(metrics)
# inside jit, return values are guaranteed to be arrays
mutable_metrics: ArrayMetrics = typing.cast(ArrayMetrics, dict(metrics))
mutable_metrics[MetricKey.LOSS] = loss
return loss, mutable_metrics
loss, experience_state = agent.loss(nets, opt_state, experience_state)
return loss, experience_state


def _filter_device_put(x: PyTree[typing.Any], device: jax.Device | None):
@@ -131,7 +133,7 @@ class _ActorThread(threading.Thread):
def __init__(
self,
loop: "GymnasiumLoop",
env: VectorEnv,
env: VectorEnv | GymnasiumEnv,
actor_state: typing.Any,
env_step: EnvStep,
num_steps: int,
@@ -156,8 +158,10 @@ def __init__(
super().__init__()
self._loop = loop
self._env = env
# if _filter_device_put doesn't copy, then we need to copy the actor state
# because the actor state will be donated in each actor thread.
if (
learner_devices == [device]
_some_array_leaf(actor_state).device == device
or learner_devices[0].platform == "cpu"
and device.platform == "cpu"
):
@@ -178,9 +182,9 @@ def run(self):
if nets is self.STOP_SIGNAL:
_logger.debug("stopping actor on device %s", self._device)
break
self._actor_state = self._loop._agent.prepare_for_actor_cycle(self._actor_state)
actor_state_pre = copy.deepcopy(self._actor_state)
with jax.default_device(self._device):
self._actor_state = self._loop._agent.prepare_for_actor_cycle(self._actor_state)
actor_state_pre = copy.deepcopy(self._actor_state)
cycle_result = self._loop._actor_cycle(
self._env, self._actor_state, nets, self._env_step, self._num_steps, self._key
)
@@ -244,6 +248,28 @@ def _jax_platform_cpu():
os.environ["JAX_PLATFORMS"] = prev_value


def _make_vec_env(
vectorization_mode: typing.Literal["sync", "async", "none"],
env_factory: typing.Callable[[], GymnasiumEnv],
num_envs: int,
) -> GymnasiumEnv | VectorEnv:
if vectorization_mode == "none":
env = env_factory()
if (
envpool is not None and isinstance(env, envpool.EnvPoolMixin) and len(env) != num_envs # pyright: ignore[reportArgumentType]
):
raise ValueError("envpool env must have the same number of environments as num_envs")
return env
if vectorization_mode == "sync":
return SyncVectorEnv([lambda: Autoreset(env_factory()) for _ in range(num_envs)])
assert vectorization_mode == "async"
# context="spawn" because others are unsafe with multithreaded JAX
with _jax_platform_cpu():
return AsyncVectorEnv(
[lambda: Autoreset(env_factory()) for _ in range(num_envs)], context="spawn"
)


class GymnasiumLoop:
"""Runs an Agent in a Gymnasium environment.
@@ -261,22 +287,22 @@ class GymnasiumLoop:

def __init__(
self,
env: GymnasiumEnv,
env_factory: typing.Callable[[], GymnasiumEnv],
agent: Agent,
num_envs: int,
key: PRNGKeyArray,
metric_writer: MetricWriter,
observe_cycle: ObserveCycle = no_op_observe_cycle,
actor_only: bool = False,
assert_no_recompile: bool = True,
vectorization_mode: typing.Literal["sync", "async"] = "async",
vectorization_mode: typing.Literal["sync", "async", "none"] = "async",
actor_devices: typing.Sequence[jax.Device] | None = None,
learner_devices: typing.Sequence[jax.Device] | None = None,
):
"""Initializes the GymnasiumLoop.
Args:
env: The environment.
env_factory: A function that returns a GymnasiumEnv.
agent: The agent.
num_envs: The number of environments to run in parallel per actor.
Note there will be 2 actor threads per actor device, so the total number of
@@ -289,7 +315,8 @@ def __init__(
actor_only: If True, agent.optimize_from_grads() will not be called.
assert_no_recompile: Whether to fail if the inner loop gets compiled more than once.
vectorization_mode: Whether to create a synchronous or asynchronous vectorized environment
from the provided Gymnasium environment.
from the provided Gymnasium environment. If "none", the env_factory must return
a vectorized environment with num_envs environments.
actor_devices: The devices to use for acting. If None, uses jax.local_devices()[0].
learner_devices: The devices to use for learning. If None, uses jax.local_devices()[0].
"""
@@ -305,26 +332,14 @@ def __init__(
self._networks_for_actor_lock = threading.Lock()
self._actor_threads: list[_ActorThread] = []

env = Autoreset(env) # run() assumes autoreset.

def _env_factory() -> GymnasiumEnv:
return copy.deepcopy(env)

self._env_for_actor_thread: list[VectorEnv] = []
self._env_for_actor_thread: list[GymnasiumEnv | VectorEnv] = []
for _ in range(self._NUM_ACTOR_THREADS * len(self._actor_devices)):
if vectorization_mode == "sync":
self._env_for_actor_thread.append(SyncVectorEnv([_env_factory for _ in range(num_envs)]))
else:
assert vectorization_mode == "async"
# context="spawn" because others are unsafe with multithreaded JAX
with _jax_platform_cpu():
self._env_for_actor_thread.append(
AsyncVectorEnv([_env_factory for _ in range(num_envs)], context="spawn")
)
self._env_for_actor_thread.append(_make_vec_env(vectorization_mode, env_factory, num_envs))

sample_key, key = jax.random.split(key)
sample_key = jax.random.split(sample_key, num_envs)
env_info = env_info_from_gymnasium(env, num_envs)
env_0 = self._env_for_actor_thread[0]
env_info = env_info_from_gymnasium(env_0, num_envs)
self._action_space = env_info.action_space
self._example_action = jax.vmap(self._action_space.sample)(sample_key)
self._agent = agent
@@ -464,11 +479,6 @@ def run(
except queue.Empty:
break

agent_state = dataclasses.replace(
agent_state,
# Not ideal. See comment above where we pass actor state to the actor thread.
actor=pytree_get_index_0(actor_experiences[-1].cycle_result.agent_state),
)
cycle_result = actor_experiences[-1].cycle_result
if self._actor_only:
learn_duration = 0
@@ -477,9 +487,9 @@ def run(
actor_experience = actor_experiences.pop()
experience_state = self._agent_update_experience(
agent_state.experience,
actor_experience.cycle_result.trajectory,
actor_experience.actor_state_pre,
actor_experience.cycle_result.agent_state,
actor_experience.cycle_result.trajectory,
)
agent_state = dataclasses.replace(agent_state, experience=experience_state)
del actor_experiences
@@ -501,7 +511,9 @@ def run(
)
with self._networks_for_actor_lock:
for device in self._networks_for_actor_device:
self._networks_for_actor_device[device] = _filter_device_put(agent_state.nets, device)
self._networks_for_actor_device[device] = _filter_device_put(
pytree_get_index_0(agent_state.nets), device
)

step_infos_device_0 = pytree_get_index_0(cycle_result.step_infos)
trajectory_device_0 = pytree_get_index_0(cycle_result.trajectory)
@@ -524,6 +536,11 @@ def run(

self._stop_actor_threads()
env_steps = [at._env_step for at in self._actor_threads]
agent_state = dataclasses.replace(
agent_state,
# Not ideal. See comment above where we pass actor state to the actor thread.
actor=self._actor_threads[-1]._actor_state,
)
return Result(
agent_state,
None,
@@ -536,14 +553,14 @@ def _off_policy_optim(
) -> tuple[tuple[_Networks, _OptState, _ExperienceState], ArrayMetrics]:
nets, opt_state, experience_state = states
nets_yes_grad, nets_no_grad = self._agent.partition_for_grad(nets)
grad, metrics = _loss_for_cycle_grad(
(loss, experience_state), grad = _loss_for_cycle_grad(
nets_yes_grad,
nets_no_grad,
opt_state,
experience_state,
self._agent,
self._raise_if_metric_conflicts,
)
metrics: ArrayMetrics = {MetricKey.LOSS: loss}
grad = jax.lax.pmean(grad, axis_name=self._PMAP_AXIS_NAME)
grad_means = pytree_leaf_means(grad, "grad_mean")
metrics.update(grad_means)
@@ -571,7 +588,7 @@ def _learn(

def _actor_cycle(
self,
env: VectorEnv,
env: VectorEnv | GymnasiumEnv,
actor_state: typing.Any,
nets: typing.Any,
env_step: EnvStep,
@@ -591,7 +608,11 @@ def _actor_cycle(

with nvtx_annotate("env.step"):
obs, reward, done, trunc, info = env.step(np.array(action_and_state.action))
env_step = EnvStep(done | trunc, obs, action_and_state.action, reward)
# the type ignore is needed because the type returned by a regular GymnasiumEnv
# is a scalar, but we actually will have either a vector env or an envpool env.
# envpool envs are subclasses of GymnasiumEnv, but don't respect its annotated
# return types :-(.
env_step = EnvStep(done | trunc, obs, action_and_state.action, reward) # pyright: ignore[reportArgumentType]

trajectory.append(env_step)
step_infos.append(info)
@@ -686,12 +707,10 @@ def replicate(self, agent_state: AgentState) -> AgentState:
"""Replicates the agent state for distributed training."""
# Don't require the caller to replicate the agent state.
actor_state = agent_state.actor
agent_state_leaves = jax.tree.leaves(
dataclasses.replace(agent_state, actor=None), is_leaf=eqx.is_array
)
assert agent_state_leaves
assert isinstance(agent_state_leaves[0], jax.Array)
if isinstance(agent_state_leaves[0].sharding, jax.sharding.SingleDeviceSharding):
agent_state = dataclasses.replace(agent_state, actor=None)
agent_state_leaf = _some_array_leaf(agent_state)

if isinstance(agent_state_leaf.sharding, jax.sharding.SingleDeviceSharding):
agent_state_arrays, agent_state_static = eqx.partition(agent_state, eqx.is_array)
agent_state_arrays = jax.device_put_replicated(agent_state_arrays, self._learner_devices)
agent_state = eqx.combine(agent_state_arrays, agent_state_static)
31 changes: 13 additions & 18 deletions earl/environment_loop/gymnax_loop.py
@@ -43,19 +43,15 @@
from earl.utils.sharding import pytree_get_index_0


@eqx.filter_grad(has_aux=True)
@eqx.filter_value_and_grad(has_aux=True)
def _loss_for_cycle_grad(
nets_yes_grad, nets_no_grad, opt_state, experience_state, agent: Agent
) -> tuple[Scalar, ArrayMetrics]:
nets_yes_grad, nets_no_grad, opt_state, experience_state: _ExperienceState, agent: Agent
) -> tuple[Scalar, _ExperienceState]:
# this is a free function so we don't have to pass self as first arg, since filter_grad
# takes gradient with respect to the first arg.
nets = eqx.combine(nets_yes_grad, nets_no_grad)
loss, metrics = agent.loss(nets, opt_state, experience_state)
raise_if_metric_conflicts(metrics)
# inside jit, return values are guaranteed to be arrays
mutable_metrics: ArrayMetrics = typing.cast(ArrayMetrics, dict(metrics))
mutable_metrics[MetricKey.LOSS] = loss
return loss, mutable_metrics
loss, experience_state = agent.loss(nets, opt_state, experience_state)
return loss, experience_state


class StepCarry(eqx.Module, typing.Generic[_ActorState]):
@@ -270,11 +266,12 @@ def _off_policy_optim(
self, agent_state: AgentState[_Networks, _OptState, _ExperienceState, _ActorState], _
) -> tuple[AgentState[_Networks, _OptState, _ExperienceState, _ActorState], ArrayMetrics]:
nets_yes_grad, nets_no_grad = self._agent.partition_for_grad(agent_state.nets)
nets_grad, metrics = _loss_for_cycle_grad(
(loss, experience_state), nets_grad = _loss_for_cycle_grad(
nets_yes_grad, nets_no_grad, agent_state.opt, agent_state.experience, self._agent
)
nets_grad = jax.lax.pmean(nets_grad, axis_name=self._PMAP_AXIS_NAME)
grad_means = pytree_leaf_means(nets_grad, "grad_mean")
metrics: ArrayMetrics = {MetricKey.LOSS: loss}
metrics.update(grad_means)
nets, opt_state = self._agent.optimize_from_grads(agent_state.nets, agent_state.opt, nets_grad)
return dataclasses.replace(agent_state, nets=nets, opt=opt_state), metrics
@@ -313,15 +310,13 @@ def _act_and_loss_grad(
cycle_result.trajectory,
)
agent_state = dataclasses.replace(cycle_result.agent_state, experience=experience_state)
loss, metrics = self._agent.loss(agent_state.nets, agent_state.opt, agent_state.experience)
raise_if_metric_conflicts(metrics)
# inside jit, return values are guaranteed to be arrays
mutable_metrics: ArrayMetrics = typing.cast(ArrayMetrics, dict(metrics))
mutable_metrics[MetricKey.LOSS] = loss
mutable_metrics.update(cycle_result.metrics)
return loss, dataclasses.replace(
cycle_result, metrics=mutable_metrics, agent_state=agent_state
loss, experience_state = self._agent.loss(
agent_state.nets, agent_state.opt, agent_state.experience
)
agent_state = dataclasses.replace(agent_state, experience=experience_state)
metrics: ArrayMetrics = {MetricKey.LOSS: loss}
metrics.update(cycle_result.metrics)
return loss, dataclasses.replace(cycle_result, metrics=metrics, agent_state=agent_state)

if not self._actor_only and not self._agent.num_off_policy_optims_per_cycle():
# On-policy update. Calculate the gradient through the entire cycle.
55 changes: 38 additions & 17 deletions earl/environment_loop/test_gymnasium_loop.py
@@ -1,5 +1,6 @@
import dataclasses
import time
import typing
from typing import Any

import gymnax
import gymnax.environments.spaces
@@ -8,6 +9,7 @@
import pytest
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
from gymnasium.envs.classic_control.pendulum import PendulumEnv
from gymnasium.vector import VectorEnv
from gymnax.environments.spaces import Box, Discrete
from jax_loop_utils.metric_writers.memory_writer import MemoryWriter

@@ -32,6 +34,7 @@
def test_gymnasium_loop(inference: bool, num_off_policy_updates: int):
num_envs = 2
env = CartPoleEnv()
env_factory = CartPoleEnv
env_info = env_info_from_gymnasium(env, num_envs)
networks = None
key_gen = keygen(jax.random.PRNGKey(0))
@@ -40,12 +43,17 @@ def test_gymnasium_loop(inference: bool, num_off_policy_updates: int):
if not inference and not num_off_policy_updates:
with pytest.raises(ValueError, match="On-policy training is not supported in GymnasiumLoop."):
loop = GymnasiumLoop(
env, agent, num_envs, next(key_gen), metric_writer=metric_writer, actor_only=inference
env_factory,
agent,
num_envs,
next(key_gen),
metric_writer=metric_writer,
actor_only=inference,
)
return

loop = GymnasiumLoop(
env,
env_factory,
agent,
num_envs,
next(key_gen),
@@ -76,34 +84,35 @@ def test_gymnasium_loop(inference: bool, num_off_policy_updates: int):
assert action_count_sum > 0
if not inference:
assert MetricKey.LOSS in metrics_for_step
assert agent._opt_count_metric_key in metrics_for_step
if inference:
assert result.agent_state.opt.opt_count == 0
else:
assert first_step_num is not None
assert last_step_num is not None
assert (
metrics[first_step_num][agent._opt_count_metric_key]
!= metrics[last_step_num][agent._opt_count_metric_key]
)
expected_opt_count = num_cycles * (num_off_policy_updates or 1)
assert result.agent_state.opt.opt_count == expected_opt_count

assert isinstance(env_info.action_space, Discrete)
assert env_info.action_space.n > 0
assert all(not env.closed for env in loop._env_for_actor_thread)
assert all(not typing.cast(VectorEnv, env).closed for env in loop._env_for_actor_thread)
loop.close()
assert all(env.closed for env in loop._env_for_actor_thread)
assert all(typing.cast(VectorEnv, env).closed for env in loop._env_for_actor_thread)


def test_bad_args():
num_envs = 2
env = CartPoleEnv()
env_factory = CartPoleEnv
env_info = env_info_from_gymnasium(env, num_envs)
agent = RandomAgent(env_info, env_info.action_space.sample, 0)
metric_writer = MemoryWriter()
loop = GymnasiumLoop(
env, agent, num_envs, jax.random.PRNGKey(0), metric_writer=metric_writer, actor_only=True
env_factory,
agent,
num_envs,
jax.random.PRNGKey(0),
metric_writer=metric_writer,
actor_only=True,
)
agent_state = agent.new_state(None, jax.random.PRNGKey(0))
with pytest.raises(ValueError, match="num_cycles"):
@@ -116,14 +125,23 @@ def test_bad_metric_key():
networks = None
num_envs = 2
env = CartPoleEnv()
env_factory = CartPoleEnv
env_info = env_info_from_gymnasium(env, num_envs)
key_gen = keygen(jax.random.PRNGKey(0))
# make the agent return a metric with a key that conflicts with a built-in metric.
agent = RandomAgent(env_info, env_info.action_space.sample, 1)
agent = dataclasses.replace(agent, _opt_count_metric_key=MetricKey.DURATION_SEC)

def observe_cycle(trajectory: EnvStep, step_infos: dict[Any, Any]) -> Metrics:
return {MetricKey.DURATION_SEC: 1}

metric_writer = MemoryWriter()
loop = GymnasiumLoop(env, agent, num_envs, next(key_gen), metric_writer=metric_writer)
loop = GymnasiumLoop(
env_factory,
agent,
num_envs,
next(key_gen),
metric_writer=metric_writer,
observe_cycle=observe_cycle,
)
num_cycles = 1
steps_per_cycle = 1
agent_state = agent.new_state(networks, jax.random.PRNGKey(0))
@@ -134,6 +152,7 @@ def test_bad_metric_key():
def test_continuous_action_space():
num_envs = 2
env = PendulumEnv()
env_factory = PendulumEnv
env_info = env_info_from_gymnasium(env, num_envs)
networks = None
key_gen = keygen(jax.random.PRNGKey(0))
@@ -144,7 +163,7 @@ def test_continuous_action_space():
agent = RandomAgent(env_info, action_space.sample, 0)
metric_writer = MemoryWriter()
loop = GymnasiumLoop(
env, agent, num_envs, next(key_gen), metric_writer=metric_writer, actor_only=True
env_factory, agent, num_envs, next(key_gen), metric_writer=metric_writer, actor_only=True
)
num_cycles = 1
steps_per_cycle = 1
@@ -158,6 +177,7 @@ def test_continuous_action_space():
def test_observe_cycle():
num_envs = 2
env = PendulumEnv()
env_factory = PendulumEnv
env_info = env_info_from_gymnasium(env, num_envs)
networks = None
key_gen = keygen(jax.random.PRNGKey(0))
@@ -172,7 +192,7 @@ def observe_cycle(trajectory: EnvStep, step_infos: dict) -> Metrics:
return {"ran": True}

loop = GymnasiumLoop(
env,
env_factory,
agent,
num_envs,
next(key_gen),
@@ -189,6 +209,7 @@ def observe_cycle(trajectory: EnvStep, step_infos: dict) -> Metrics:
def test_benchmark_gymnasium_inference():
num_envs = 16
env = CartPoleEnv()
env_factory = CartPoleEnv
env_info = env_info_from_gymnasium(env, num_envs)
networks = None
key_gen = keygen(jax.random.PRNGKey(0))
@@ -202,7 +223,7 @@ def test_benchmark_gymnasium_inference():
)
metric_writer = MemoryWriter()
loop = GymnasiumLoop(
env,
env_factory,
agent,
num_envs,
next(key_gen),
3 changes: 2 additions & 1 deletion earl/environment_loop/test_gymnasium_loop_multi_device.py
Original file line number Diff line number Diff line change
@@ -51,13 +51,14 @@ def test_actor_learner_different_devices():

num_envs = 2
env = NoOpEnv()
env_factory = NoOpEnv
env_info = env_info_from_gymnasium(env, num_envs)
key_gen = keygen(jax.random.PRNGKey(0))
agent = RandomAgent(env_info, env_info.action_space.sample, 1)
metric_writer = MemoryWriter()

loop = GymnasiumLoop(
env,
env_factory,
agent,
num_envs,
next(key_gen),
26 changes: 16 additions & 10 deletions earl/environment_loop/test_gymnax_loop.py
@@ -1,4 +1,4 @@
import dataclasses
from typing import Any

import gymnax
import gymnax.environments.spaces
@@ -9,7 +9,8 @@

from earl.agents.random_agent.random_agent import RandomAgent
from earl.core import ConflictingMetricError, EnvStep, Metrics, env_info_from_gymnax
from earl.environment_loop.gymnax_loop import GymnaxLoop, MetricKey
from earl.environment_loop.gymnax_loop import GymnaxLoop
from earl.metric_key import MetricKey
from earl.utils.prng import keygen


@@ -64,16 +65,11 @@ def test_gymnax_loop(inference: bool, num_off_policy_updates: int):
assert action_count_sum > 0
if not inference:
assert MetricKey.LOSS in metrics_for_step
assert agent._opt_count_metric_key in metrics_for_step
if inference:
assert result.agent_state.opt.opt_count == 0
else:
assert first_step_num is not None
assert last_step_num is not None
assert (
metrics[first_step_num][agent._opt_count_metric_key]
!= metrics[last_step_num][agent._opt_count_metric_key]
)
expected_opt_count = num_cycles * (num_off_policy_updates or 1)
assert result.agent_state.opt.opt_count == expected_opt_count

@@ -101,11 +97,21 @@ def test_bad_metric_key():
num_envs = 2
env_info = env_info_from_gymnax(env, env_params, num_envs)
key_gen = keygen(jax.random.PRNGKey(0))
# make the agent return a metric with a key that conflicts with a built-in metric.

def observe_cycle(trajectory: EnvStep, step_infos: dict[Any, Any]) -> Metrics:
return {MetricKey.DURATION_SEC: 1}

agent = RandomAgent(env_info, env.action_space().sample, 0)
agent = dataclasses.replace(agent, _opt_count_metric_key=MetricKey.DURATION_SEC)

loop = GymnaxLoop(env, env_params, agent, num_envs, next(key_gen), metric_writer=NoOpWriter())
loop = GymnaxLoop(
env,
env_params,
agent,
num_envs,
next(key_gen),
metric_writer=NoOpWriter(),
observe_cycle=observe_cycle,
)
num_cycles = 1
steps_per_cycle = 1
agent_state = agent.new_state(networks, jax.random.PRNGKey(0))
1 change: 1 addition & 0 deletions earl/experiments/BUILD.bazel
@@ -25,6 +25,7 @@ py_library(
py_test(
name = "test_run_experiment",
srcs = ["test_run_experiment.py"],
shard_count = 2,
deps = [
":run_experiment",
"//earl/agents/random_agent",
12 changes: 6 additions & 6 deletions earl/experiments/run_experiment.py
@@ -1,4 +1,5 @@
import dataclasses
import functools
import pathlib
from collections.abc import Callable, Iterable
from typing import Any
@@ -160,7 +161,7 @@ def _new_checkpoint_manager(


def _new_gymnasium_loop(
env: GymnasiumEnv,
env_factory: Callable[[], GymnasiumEnv],
env_params: Any,
agent: Agent,
num_envs: int,
@@ -173,7 +174,7 @@ def _new_gymnasium_loop(
learner_devices: list[jax.Device] | None = None,
) -> GymnasiumLoop:
return GymnasiumLoop(
env,
env_factory,
agent,
num_envs,
key,
@@ -248,13 +249,13 @@ def run_experiment(config: ExperimentConfig) -> LoopResult:
num_envs = int(num_envs)

if isinstance(env, GymnasiumEnv):
loop_factory = _new_gymnasium_loop
loop_factory = functools.partial(_new_gymnasium_loop, config.new_env) # pyright: ignore[reportArgumentType]
env.close()
else:
assert isinstance(env, gymnax.environments.environment.Environment)
loop_factory = _new_gymnax_loop
loop_factory = functools.partial(_new_gymnax_loop, env) # pyright: ignore[reportArgumentType]

train_loop: GymnasiumLoop | GymnaxLoop = loop_factory(
env, # pyright: ignore[reportArgumentType]
env_params,
agent,
config.num_envs,
@@ -300,7 +301,6 @@ def run_experiment(config: ExperimentConfig) -> LoopResult:
train_cycles_per_eval = config.num_train_cycles // config.num_eval_cycles
eval_key, key = jax.random.split(key)
eval_loop = loop_factory(
env, # pyright: ignore[reportArgumentType]
env_params,
agent,
config.num_envs,
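The `functools.partial` change above pre-binds the first argument of each loop constructor (the environment factory for Gymnasium, the environment itself for Gymnax), so both the train and eval call sites can invoke one uniform `loop_factory` signature without passing `env` explicitly. A dependency-free sketch of the pattern — all names and signatures below are illustrative stand-ins, not Earl's actual API:

```python
import functools

# Stand-ins for the two loop constructors; only the first parameter differs.
def new_gymnasium_loop(env_factory, env_params, agent, num_envs):
    return ("gymnasium", env_factory(), env_params, agent, num_envs)

def new_gymnax_loop(env, env_params, agent, num_envs):
    return ("gymnax", env, env_params, agent, num_envs)

def make_loop_factory(env, env_factory, use_gymnasium):
    # Pre-bind the differing first argument so every caller sees
    # the same (env_params, agent, num_envs) signature.
    if use_gymnasium:
        return functools.partial(new_gymnasium_loop, env_factory)
    return functools.partial(new_gymnax_loop, env)

factory = make_loop_factory(
    env="cartpole", env_factory=lambda: "cartpole", use_gymnasium=True
)
loop = factory("params", "agent", 2)  # no env argument at the call site
```

Binding the variant-specific argument once keeps the two later call sites (train loop and eval loop) identical, which is exactly what the diff achieves by dropping the `env` argument from both.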
7 changes: 5 additions & 2 deletions earl/utils/sharding.py
@@ -30,6 +30,9 @@ def pytree_get_index_0(pytree: typing.Any) -> typing.Any:
Otherwise, return the leaf unchanged.
Args:
pytree: The pytree to shard.
pytree: The pytree to get the 0th index of.
Returns:
A pytree with the same structure as the input, but with each array replaced with its 0th index.
"""
return jax.tree.map(lambda x: x[0] if isinstance(x, jax.Array) else x, pytree)
return jax.tree.map(lambda x: x[0] if isinstance(x, jax.Array) and x.ndim else x, pytree)
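The one-line fix above adds an `x.ndim` guard: a 0-D (scalar) array has no index 0, so indexing it raises, and the guard lets scalars pass through unchanged. A self-contained sketch of the same logic using small stand-ins for `jax.Array` and `jax.tree.map` (so it runs without JAX installed; `FakeArray` and `tree_map` are hypothetical helpers for illustration):

```python
class FakeArray:
    """Minimal stand-in for jax.Array: a list is 1-D, a bare scalar is 0-D."""
    def __init__(self, data):
        self.data = data
        self.ndim = 1 if isinstance(data, list) else 0

    def __getitem__(self, i):
        if self.ndim == 0:
            # Mirrors JAX: indexing a 0-D array is an error.
            raise IndexError("too many indices for 0-d array")
        return self.data[i]

def tree_map(fn, tree):
    # Minimal stand-in for jax.tree.map over nested dicts.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)

def pytree_get_index_0(pytree):
    # The `and x.ndim` guard passes 0-D leaves through untouched
    # instead of failing on x[0].
    return tree_map(
        lambda x: x[0] if isinstance(x, FakeArray) and x.ndim else x, pytree
    )

tree = {"a": FakeArray([10, 11]), "b": {"c": FakeArray([20, 21])},
        "scalar": FakeArray(7)}
out = pytree_get_index_0(tree)
```

Without the guard, the `"scalar"` leaf would raise `IndexError`; with it, the 0-D leaf is returned as-is, matching the new `test_pytree_get_index_0` case below.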
7 changes: 6 additions & 1 deletion earl/utils/test_sharding.py
@@ -9,7 +9,7 @@
import jax.numpy as jnp
import pytest

from earl.utils.sharding import shard_along_axis_0
from earl.utils.sharding import pytree_get_index_0, shard_along_axis_0


def test_shard_along_axis_0_correct_shape_and_contents():
@@ -40,3 +40,8 @@ def test_shard_along_axis_0_not_divisible_raises_error():
arr = jnp.arange((n_devices * 4) + 1) # intentionally off by one
with pytest.raises(ValueError):
shard_along_axis_0(arr, devices)


def test_pytree_get_index_0():
pytree = {"a": jnp.arange(4), "b": {"c": jnp.arange(4)}, "scalar": jnp.array(1)}
assert pytree_get_index_0(pytree) == {"a": 0, "b": {"c": 0}, "scalar": jnp.array(1)}
42 changes: 38 additions & 4 deletions pyproject.toml
@@ -5,7 +5,7 @@ dependencies = [
"equinox>=0.11.11",
"gymnax",
"gymnasium>=1.0.0",
"jax_loop_utils",
"jax-loop-utils>=0.0.13",
"jax>=0.4.0",
"jaxtyping>=0.2.0",
"orbax-checkpoint==0.11.1",
@@ -16,7 +16,7 @@ name = "earl"
version = "0.0.0"
description = "Reinforcement learning with Equinox."
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.11"
license = { text = "MIT" }
authors = [{ name = "Gary Miguel", email = "garymm@garymm.org" }]
classifiers = [
@@ -30,6 +30,16 @@ classifiers = [
keywords = ["Equinox", "JAX", "Reinforcement Learning"]

[project.optional-dependencies]
agent-r2d2 = [
"chex>=0.1.88",
"distrax>=0.1.5",
"envpool>=0.8.4",
"jax-loop-utils[audio-video]>=0.0.14",
"optax>=0.2.4",
"numpy<2.0.0", # https://github.com/sail-sg/envpool/issues/321
"tensorboard>=2.15.0",
"torch>=2.0.0",
]
cuda = ["jax[cuda12]>=0.4.0", "nvtx>=0.2.10"]
jupyter = ["ipykernel>=6.0.0"]
test = ["coverage>=7.0.0", "pytest>=8.0.0", "pytest-repeat==0.9.3"]
@@ -39,7 +49,7 @@ test = ["coverage>=7.0.0", "pytest>=8.0.0", "pytest-repeat==0.9.3"]
Homepage = "http://github.com/garymm/earl"

[dependency-groups]
dev = ["basedpyright==1.26.0", "ruff>=0.9.1"]
dev = ["basedpyright==1.28.0", "ruff>=0.9.1"]

[build-system]
requires = ["hatchling"]
@@ -51,6 +61,18 @@ exclude = ["test_*.py", "BUILD.bazel"]
[tool.hatch.build.targets.wheel]
packages = ["earl"]

[tool.pytest.ini_options]
filterwarnings = [
# Please only ignore warnings that come from a transitive dependency that we
# can't easily avoid.
# See https://docs.pytest.org/en/stable/how-to/capture-warnings.html#controlling-warnings
# action:message:category:module:line
"error",
"ignore:jax.interpreters.xla.pytype_aval_mappings is deprecated.:DeprecationWarning",
# triggered by envpool
"ignore:Shape is deprecated; use StableHLO instead.:DeprecationWarning",
]

[tool.ruff]
line-length = 100
indent-width = 2
@@ -65,11 +87,23 @@ select = [
"SIM", # flake8-simplify
"UP", # pyupgrade
]
[tool.uv]
# restrict to platforms we care about so that version resolution is faster and more likely to succeed
# (e.g. don't fail if a package isn't built for windows)
environments = [
"sys_platform == 'linux' and platform_machine == 'x86_64' and python_version>='3.11'",
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[tool.uv.sources]
# we only need torch for the tensorboard writer
torch = { index = "pytorch-cpu" }
draccus = { git = "https://github.com/dlwh/draccus", rev = "9b690730ca108930519f48cc5dead72a72fd27cb" }
gymnax = { git = "https://github.com/Astera-org/gymnax", rev = "c52a7dac7b41514297d2e98b1b288d56715a5165" }
jax_loop_utils = { git = "https://github.com/Astera-org/jax_loop_utils", rev = "5cd50bfa0a6e42ccc7438fb556d80e1ec3074932" }

[tool.basedpyright]
include = ["earl"]
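The new `filterwarnings` block turns every warning into an error while carving out the two listed transitive-dependency warnings; pytest's filter strings build on the stdlib `warnings` filter, where a later-registered `ignore` takes precedence over an earlier blanket `error`. An illustrative sketch of that precedence using the stdlib directly (the warning messages are the real ones from the config; the `raises` helper is ours):

```python
import warnings

warnings.resetwarnings()
# Blanket rule first: escalate every warning to an exception...
warnings.filterwarnings("error")
# ...then a narrower ignore; filterwarnings prepends, so this is checked first.
warnings.filterwarnings(
    "ignore", message="Shape is deprecated; use StableHLO instead."
)

def raises(msg):
    """Return True if warning `msg` is escalated to an exception."""
    try:
        warnings.warn(msg, DeprecationWarning)
    except DeprecationWarning:
        return True
    return False

escalated = raises("some other deprecation")  # hits the "error" rule
ignored = raises("Shape is deprecated; use StableHLO instead.")  # carved out
```

This mirrors the TOML: the `"error"` entry makes the test suite fail on any unexpected warning, while the `ignore:...` entries exempt only the two warnings that envpool and JAX emit and the project cannot easily avoid.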
591 changes: 442 additions & 149 deletions requirements_linux_x86_64.txt

Large diffs are not rendered by default.