update AgentState doc string #37

Merged
merged 1 commit into from Mar 3, 2025
45 changes: 35 additions & 10 deletions earl/core.py
@@ -1,4 +1,10 @@
"""Core types."""
"""Core types.

The goal of the Agent and AgentState classes is to be structured such that it is feasible to
implement distributed training following the Sebulba architecture described in
[Podracer architectures for scalable Reinforcement Learning](https://arxiv.org/abs/2104.06272),
and otherwise allowing users flexibility to implement things as they wish.
"""

import abc
from collections.abc import Callable, Mapping, Sequence
@@ -51,27 +57,46 @@ def __str__(self) -> str: ...
 class AgentState(eqx.Module, Generic[_Networks, _OptState, _ExperienceState, _ActorState]):
     """The state of an agent.
 
+    To enable implementation of the Sebulba architecture, Earl requires the user to split their
+    Agent's state into the fields in this class.
+    The Agent method signatures enforce that the different pieces of state are read and written
+    only when appropriate.
+
     All pytree leaves in the fields of any subclass must be one of the following types:
     bool, int, float, jax.Array. This allows them to be saved and restored
     by orbax.
     """
 
     actor: _ActorState
-    """Contains anything that needs to be updated on each act().
-    Typically random keys and RNN states go here."""
+    """Actor state. This is read and written by the actor.
+
+    This is also read by the learner when calculating the loss.
+    In agents that use recurrent networks, this includes the recurrent hidden states.
+    """
     nets: _Networks
-    """neural networks. It must be a PyTree since it will
-    be passed to equinox.combine(). It is updated in update_from_grads() based on gradients.
-    Anything that needs to be optimized via gradient descent should be in nets.
+    """Neural networks.
+
+    This is read by the actor. It is read and written by the learner. Anything that needs a
+    gradient computed must be in the networks.
+
+    The nets must be a PyTree since it will be passed to equinox.combine().
+
     Any objects that need to change their behavior in inference or training mode should
     have a boolean member variable that is named "inference" for that purpose
-    (all equinox.nn built-in modules have this)."""
+    (all equinox.nn built-in modules have this).
+    """
     opt: _OptState
-    """contains anything other than nets that also needs to be updated when optimizing
+    """Optimization state.
+
+    Anything other than nets that also needs to be updated when optimizing
     (i.e. in optimize_from_grads()). This is where optimizer state belongs.
-    Set to optax.OptState if you only have one optimizer, or you can set it to a custom class."""
+
+    Can be optax.OptState if that's all you need, or you can set it to a custom class."""
     experience: _ExperienceState
-    """experience replay buffer state."""
+    """Experience state accumulated from the actor and sent to the learner.
+
+    For agents that use experience replay, this holds the replay buffers.
+    """
 
 
 class ActionAndState(NamedTuple, Generic[_ActorState]):
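
To illustrate how the four fields fit together, here is a minimal sketch of constructing a concrete AgentState. The `Nets` and `ActorState` containers, the MLP sizes, and the learning rate are all hypothetical and not part of this PR; `experience=None` stands in for a real replay buffer state.

```python
import equinox as eqx
import jax
import optax

from earl.core import AgentState  # import path per the file in this diff


class Nets(eqx.Module):
    """Hypothetical networks container; anything needing gradients lives here."""

    policy: eqx.nn.MLP


class ActorState(eqx.Module):
    """Hypothetical per-act() state: a PRNG key (plus RNN state, if recurrent)."""

    key: jax.Array


key = jax.random.PRNGKey(0)
nets = Nets(policy=eqx.nn.MLP(in_size=4, out_size=2, width_size=32, depth=2, key=key))
optimizer = optax.adam(1e-3)

state = AgentState(
    actor=ActorState(key=jax.random.PRNGKey(1)),
    nets=nets,
    # Optimizer state for the trainable leaves of nets (an optax.OptState).
    opt=optimizer.init(eqx.filter(nets, eqx.is_inexact_array)),
    experience=None,  # no replay buffer in this sketch
)
```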
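The "inference" flag convention mentioned in the nets docstring can be used like the sketch below, assuming a recent Equinox version where `eqx.nn.inference_mode` is available:

```python
import equinox as eqx

# Returns a copy of the pytree with every `inference` flag set to True;
# equinox built-ins such as Dropout and BatchNorm switch behavior on this flag.
inference_nets = eqx.nn.inference_mode(state.nets)

# The actor can act with inference_nets while the learner keeps training
# against the original state.nets.
```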