Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix basedpyright and enable run_experiment test #14

Merged
merged 2 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
{
"python.analysis.typeCheckingMode": "standard",
"python.analysis.ignore": [
"basedpyright.analysis.ignore": [
"${workspaceFolder}/.venv"
],
"python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python",
Expand Down
7 changes: 7 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ pip.parse(
python_version = _PYTHON_VERSION,
requirements_lock = "//:requirements_linux_x86_64.txt",
)

# https://github.com/google/orbax/issues/1429
pip.override(
file = "orbax_checkpoint-0.11.1-py3-none-any.whl",
patch_strip = 1,
patches = ["//tools/patches:orbax-remove-BUILD.patch"],
)
use_repo(pip, "pypi")
# END python dependencies

Expand Down
101 changes: 97 additions & 4 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions earl/environment_loop/gymnasium_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,13 @@ def _env_factory() -> GymnasiumEnv:
raise ValueError("On-policy training is not supported in GymnasiumLoop.")

devices = devices or jax.local_devices()
self._inference_device: jax.Device = devices[0]
self._update_device: jax.Device = devices[0]
if len(devices) > 1:
self._inference_device: jax.Device = devices[0]
self._update_device: jax.Device = devices[1]
self._inference_device = devices[0]
self._update_device = devices[1]
if len(devices) > 2:
_logger.warning("Multiple update devices are not supported yet. Using only the first device.")
else:
self._inference_device: jax.Device = devices[0]
self._update_device: jax.Device = devices[0]

def reset_env(self) -> EnvStep:
"""Resets the environment.
Expand Down
33 changes: 33 additions & 0 deletions earl/experiments/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
load("@aspect_rules_py//py:defs.bzl", "py_library")
load("//tools/py_test:py_test.bzl", "py_test")

py_library(
name = "run_experiment",
srcs = [
"config.py",
"run_experiment.py",
],
deps = [
"//earl:core",
"//earl/environment_loop:gymnasium_loop",
"//earl/environment_loop:gymnax_loop",
"@pypi//draccus",
"@pypi//equinox",
"@pypi//gymnasium",
"@pypi//gymnax",
"@pypi//jax",
"@pypi//jax_loop_utils",
"@pypi//jaxtyping",
"@pypi//orbax_checkpoint",
],
)

py_test(
name = "test_run_experiment",
srcs = ["test_run_experiment.py"],
deps = [
":run_experiment",
"//earl/agents/random_agent",
"@pypi//orbax_checkpoint",
],
)
16 changes: 12 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# https://python-poetry.org/docs/pyproject/
[project]
dependencies = [
"draccus",
"equinox>=0.11.11",
"gymnax",
"gymnasium>=1.0.0",
"jax_loop_utils",
"jax>=0.4.0",
"jaxtyping>=0.2.0",
"orbax-checkpoint==0.11.1",
"tensorstore==0.1.68", # transitive of orbax-checkpoint, pinned to avoid bug https://github.com/google/orbax/issues/1429#issuecomment-2543832552
"tqdm>=4.0.0",
]
name = "earl"
Expand Down Expand Up @@ -40,7 +43,10 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build]
exclude = ["test_*.py"]
exclude = ["test_*.py", "BUILD.bazel"]

[tool.hatch.build.targets.wheel]
packages = ["earl"]

[tool.pytest.ini_options]
filterwarnings = [
Expand All @@ -65,9 +71,11 @@ select = [
"UP", # pyupgrade
]

[tool.pyright]
typeCheckingMode = "standard"

[tool.uv.sources]
draccus = { git = "https://github.com/dlwh/draccus", rev = "9b690730ca108930519f48cc5dead72a72fd27cb" }
gymnax = { git = "https://github.com/Astera-org/gymnax", rev = "c52a7dac7b41514297d2e98b1b288d56715a5165" }
jax_loop_utils = { git = "https://github.com/Astera-org/jax_loop_utils", rev = "5cd50bfa0a6e42ccc7438fb556d80e1ec3074932" }

[tool.basedpyright]
include = ["earl"]
typeCheckingMode = "standard"
74 changes: 51 additions & 23 deletions requirements_linux_x86_64.txt
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ cycler==0.12.1 \
--hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \
--hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c
# via matplotlib
draccus @ git+https://github.com/dlwh/draccus@9b690730ca108930519f48cc5dead72a72fd27cb
# via earl (pyproject.toml)
equinox==0.11.11 \
--hash=sha256:49e9674f9bff0cde7ebcfbf2cdf4585c9231eb377eda31168bbf6467f88241e5 \
--hash=sha256:648072c1384adc3528930a3bf089246fd77aa873310a19f1f21c08e7681f95a7
Expand Down Expand Up @@ -418,6 +420,10 @@ mdurl==0.1.2 \
--hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \
--hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba
# via markdown-it-py
mergedeep==1.3.4 \
--hash=sha256:0096d52e9dad9939c3d975a774666af186eda617e6ca84df4c94dec30004f2a8 \
--hash=sha256:70775750742b25c0d8f36c55aed03d24c3384d17c951b3175d898bd778ef0307
# via draccus
ml-dtypes==0.5.1 \
--hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \
--hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \
Expand Down Expand Up @@ -515,6 +521,10 @@ msgpack==1.1.0 \
# via
# flax
# orbax-checkpoint
mypy-extensions==1.0.0 \
--hash=sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d \
--hash=sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782
# via typing-inspect
nest-asyncio==1.6.0 \
--hash=sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe \
--hash=sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c
Expand Down Expand Up @@ -603,7 +613,9 @@ optax==0.2.4 \
orbax-checkpoint==0.11.1 \
--hash=sha256:9cfbe49331ffdef4aedfec337da683eb8fc65221dc43924cf6daea6c27d6dc06 \
--hash=sha256:e303f4f4441ba0367a14c692d6c73070bba954205de89df608f21823dd86eaee
# via flax
# via
# earl (pyproject.toml)
# flax
packaging==24.2 \
--hash=sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759 \
--hash=sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f
Expand Down Expand Up @@ -822,9 +834,15 @@ pyyaml==6.0.2 \
--hash=sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12 \
--hash=sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4
# via
# draccus
# flax
# gymnax
# orbax-checkpoint
# pyyaml-include
pyyaml-include==1.4.1 \
--hash=sha256:1a96e33a99a3e56235f5221273832464025f02ff3d8539309a3bf00dec624471 \
--hash=sha256:323c7f3a19c82fbc4d73abbaab7ef4f793e146a13383866831631b26ccc7fb00
# via draccus
rich==13.9.4 \
--hash=sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098 \
--hash=sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90
Expand Down Expand Up @@ -997,31 +1015,36 @@ six==1.17.0 \
--hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \
--hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81
# via python-dateutil
tensorstore==0.1.71 \
--hash=sha256:0bd87899e1c6049b078e785e8b7871e2579202f5b929e89c3c37340965b922bb \
--hash=sha256:1a6cdcc52e4b841d23e50a2fa28e016e6d9f61d6ea9188d4555ea189b040a0f6 \
--hash=sha256:31e39ed7d374f43e45bff52611bad99315c577b44c099b2f6837b801b3467645 \
--hash=sha256:321d6302e5116b20fda500821240eba7de28477209070728d98edefced97d2b5 \
--hash=sha256:373558b803d8c2c57fc613b11007ae58139f19a3cddd443a0de5d7b5321e5961 \
--hash=sha256:46ff0f41ef3b1dbd1a925d62e6475523a587bcd37b277bf4f633f46f5b7e22bd \
--hash=sha256:52b546f076b2c3bf217c60f05de4124cc1197ce92f8e826e7ec73ae324074a5a \
--hash=sha256:583f0ec143062176ca21fe8dcc3b3b6f94d7f4ea643443b49942d3d1a2fa29b4 \
--hash=sha256:5c37c7b385517b568282a7aedded446216335d0cb41187c93c80b53596c92c96 \
--hash=sha256:6276e279b45eb5d9b95c4df3e7956255f414fd4b128d2de16d8aecde86c36357 \
--hash=sha256:65c3a1a2a35a1b537403f36403d258caab477e564bc0f64109b941cc77b4f203 \
--hash=sha256:75a9ff1f7b6759094cc210baa4e8135c4898472e08a7036476374433d03c6a34 \
--hash=sha256:87a97a34b0475ddc7d2afc40e5dd7f8d12522aa81edfbcccb39628cf591454d5 \
--hash=sha256:95041b55a2ec86d1f6690512d1883581b18f2f4f46c3d97894aeb0ac2db6af7f \
--hash=sha256:b961bbbb7a1c6a48e4c1406a98caebeb400461e2e75a08b6df0c013294037a15 \
--hash=sha256:ced5430bcdfa7fcb3a6bdc44733176158cb877b35bdd233cac82e25b4cc94e92 \
--hash=sha256:d3a24feb6195f1c222162965c0107c9ff56d322cca23e19f0e66636f6eb80f14 \
--hash=sha256:de8843fb3462899de7bcdeeaccb92303a9d61006bc36364deb4a88df46320ba4 \
--hash=sha256:ecf4feb574051f40e81572ea2ff8e5895b2980c5dd3b29fe81c70d25e42d3b6a \
--hash=sha256:f3e62aa7b473c0715706a809da3591763906059e8731a38c0b495337a1dc55ea \
--hash=sha256:f40e73bcdc333dfb3f7fe0fcf023bcbec41533c9856657718ff76ece1a1902e0
tensorstore==0.1.68 \
--hash=sha256:172420ec1c4e925a8ec3c386e31b4f81eae403bdca71b6258e7f775a69c3bfb3 \
--hash=sha256:23dc88d5188267529beb72012f72ce892ee25d40daf9dd533413bfc818b1d030 \
--hash=sha256:3deec3fab13f489493ecf76206114c6ac657fd2fe4e9469d3ad843e916ca7cf8 \
--hash=sha256:425c56cdd7f76af8be0c056933da9bf8b8812c00e4fef08888465e2f126d53eb \
--hash=sha256:4c392ceabd864b8c18546ce690a758030590b3b9416dcef9f3fbbce862ed0ccb \
--hash=sha256:50d6119fd6d158be3f96c04f75484373c998bb11380ca8ac85a3d3d8e85145fa \
--hash=sha256:5902d7c36e6119b761d02260b68646585b315202397e2a6c016e3f5d81d39a43 \
--hash=sha256:62561f9cb29e9e887b646177235b35fcb52b50c14f446768568f7fb3a95a571b \
--hash=sha256:6672b2047df3f772350ac75d6780f31201a82383c5b7c0c1986903b88e6f341a \
--hash=sha256:68e3a84ef65ae5583dda40eecf39af413c905f2a7fcec52ede4df9ab912c3a59 \
--hash=sha256:6e13d3e3c8fb6ed67712835a343821536b38d6bdb517db554d41cebfe5947ab7 \
--hash=sha256:76ebad6762d226c9621d256d8703381963e407d0361cd33f0f89409a31acb57e \
--hash=sha256:889900ee6a9ffba4635f44f663b41f5b43f67b1e74bd507fa4a30f0f02704c80 \
--hash=sha256:9d62c4288e68b4640de878f8393a5779440b2de8e84cf7b717f91a01a4e6b4be \
--hash=sha256:a1348768a5aae514b440212eedb50d246a1a4b39f8e74d275ef0bead688c562b \
--hash=sha256:a93fe05708acb9d9e3813f7f7ecd807c8ff34ec3fa30e2baa37e9270d128dcf0 \
--hash=sha256:b6e51188a82c93563440c805bd501b12f0dc30267667f664091b3a2b8b108017 \
--hash=sha256:c65460ac90f8db49ad35779964ea5983332fe63e60b4d94ba66640c68ef73091 \
--hash=sha256:c9ca5a5dc1e13760f024c3607219e60c3b8338f1b4f7413e1a13115a132ac7d9 \
--hash=sha256:d5fa0e47b42eb58ddea81763cb0de4a92c4ab0da530d2a27f1928539980a781a \
--hash=sha256:d80f9b48b057fda9aea0407e576324354b054aae02fa08fc0a8e6b11acf7ae3a
# via
# earl (pyproject.toml)
# flax
# orbax-checkpoint
toml==0.10.2 \
--hash=sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b \
--hash=sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f
# via draccus
toolz==1.0.0 \
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
Expand All @@ -1040,6 +1063,11 @@ typing-extensions==4.12.2 \
# flax
# gymnasium
# orbax-checkpoint
# typing-inspect
typing-inspect==0.9.0 \
--hash=sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f \
--hash=sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78
# via draccus
tzdata==2025.1 \
--hash=sha256:24894909e88cdb28bd1636c6887801df64cb485bd593f2fd83ef29075a81d694 \
--hash=sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639
Expand Down
Empty file added tools/patches/BUILD.bazel
Empty file.
Loading