Skip to content

Commit 6b2c127

Browse files
authored
Initialize-from custom checkpoints (#5525)
* init from any checkpoint including older ones * moving init_path logic ahead to learn.py * fixing pytest to take the full path * doc & changelog
1 parent 7603fb7 commit 6b2c127

File tree

8 files changed

+106
-15
lines changed

8 files changed

+106
-15
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ and this project adheres to
1212
#### ml-agents / ml-agents-envs / gym-unity (Python)
1313
### Minor Changes
1414
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
15+
- Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525)
1516
#### ml-agents / ml-agents-envs / gym-unity (Python)
1617
### Bug Fixes
1718
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)

docs/Training-Configuration-File.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ choice of the trainer (which we review on subsequent sections).
3333
| `max_steps` | (default = `500000`) Total number of steps (i.e., observation collected and action taken) that must be taken in the environment (or across all environments if using multiple in parallel) before ending the training process. If you have multiple agents with the same behavior name within your environment, all steps taken by those agents will contribute to the same `max_steps` count. <br><br>Typical range: `5e5` - `1e7` |
3434
| `keep_checkpoints` | (default = `5`) The maximum number of model checkpoints to keep. Checkpoints are saved after the number of steps specified by the checkpoint_interval option. Once the maximum number of checkpoints has been reached, the oldest checkpoint is deleted when saving a new checkpoint. |
3535
| `checkpoint_interval` | (default = `500000`) The number of experiences collected between each checkpoint by the trainer. A maximum of `keep_checkpoints` checkpoints are saved before old ones are deleted. Each checkpoint saves the `.onnx` files in `results/` folder.|
36-
| `init_path` | (default = None) Initialize trainer from a previously saved model. Note that the prior run should have used the same trainer configurations as the current run, and have been saved with the same version of ML-Agents. <br><br>You should provide the full path to the folder where the checkpoints were saved, e.g. `./models/{run-id}/{behavior_name}`. This option is provided in case you want to initialize different behaviors from different runs; in most cases, it is sufficient to use the `--initialize-from` CLI parameter to initialize all models from the same run. |
36+
| `init_path` | (default = None) Initialize trainer from a previously saved model. Note that the prior run should have used the same trainer configurations as the current run, and have been saved with the same version of ML-Agents. <br><br>You can provide either the file name or the full path to the checkpoint, e.g. `{checkpoint_name.pt}` or `./models/{run-id}/{behavior_name}/{checkpoint_name.pt}`. This option is provided in case you want to initialize different behaviors from different runs or initialize from an older checkpoint; in most cases, it is sufficient to use the `--initialize-from` CLI parameter to initialize all models from the same run. |
3737
| `threaded` | (default = `false`) Allow environments to step while updating the model. This might result in a training speedup, especially when using SAC. For best performance, leave setting to `false` when using self-play. |
3838
| `hyperparameters -> learning_rate` | (default = `3e-4`) Initial learning rate for gradient descent. Corresponds to the strength of each gradient descent update step. This should typically be decreased if training is unstable, and the reward does not consistently increase. <br><br>Typical range: `1e-5` - `1e-3` |
3939
| `hyperparameters -> batch_size` | Number of experiences in each iteration of gradient descent. **This should always be multiple times smaller than `buffer_size`**. If you are using continuous actions, this value should be large (on the order of 1000s). If you are using only discrete actions, this value should be smaller (on the order of 10s). <br><br> Typical range: (Continuous - PPO): `512` - `5120`; (Continuous - SAC): `128` - `1024`; (Discrete, PPO & SAC): `32` - `512`. |

ml-agents/mlagents/trainers/directory_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22
from mlagents.trainers.exception import UnityTrainerException
3+
from mlagents.trainers.settings import TrainerSettings
4+
from mlagents.trainers.model_saver.torch_model_saver import DEFAULT_CHECKPOINT_NAME
35

46

57
def validate_existing_directories(
@@ -13,6 +15,7 @@ def validate_existing_directories(
1315
:param summary_path: The summary path to be used.
1416
:param resume: Whether or not the --resume flag was passed.
1517
:param force: Whether or not the --force flag was passed.
18+
:param init_path: Path to run-id dir to initialize from
1619
"""
1720

1821
output_path_exists = os.path.isdir(output_path)
@@ -40,3 +43,34 @@ def validate_existing_directories(
4043
init_path
4144
)
4245
)
46+
47+
48+
def setup_init_path(
49+
behaviors: TrainerSettings.DefaultTrainerDict, init_dir: str
50+
) -> None:
51+
"""
52+
For each behavior, setup full init_path to checkpoint file to initialize policy from
53+
:param behaviors: mapping from behavior_name to TrainerSettings
54+
:param init_dir: Path to run-id dir to initialize from
55+
"""
56+
for behavior_name, ts in behaviors.items():
57+
if ts.init_path is None:
58+
# set default if None
59+
ts.init_path = os.path.join(
60+
init_dir, behavior_name, DEFAULT_CHECKPOINT_NAME
61+
)
62+
elif not os.path.dirname(ts.init_path):
63+
# update to full path if just the file name
64+
ts.init_path = os.path.join(init_dir, behavior_name, ts.init_path)
65+
_validate_init_full_path(ts.init_path)
66+
67+
68+
def _validate_init_full_path(init_file: str) -> None:
69+
"""
70+
Validate initialization path to be a .pt file
71+
:param init_file: full path to initialization checkpoint file
72+
"""
73+
if not (os.path.isfile(init_file) and init_file.endswith(".pt")):
74+
raise UnityTrainerException(
75+
f"Could not initialize from {init_file}. file does not exists or is not a `.pt` file"
76+
)

ml-agents/mlagents/trainers/learn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
from mlagents.trainers.trainer_controller import TrainerController
1414
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
1515
from mlagents.trainers.trainer import TrainerFactory
16-
from mlagents.trainers.directory_utils import validate_existing_directories
16+
from mlagents.trainers.directory_utils import (
17+
validate_existing_directories,
18+
setup_init_path,
19+
)
1720
from mlagents.trainers.stats import StatsReporter
1821
from mlagents.trainers.cli_utils import parser
1922
from mlagents_envs.environment import UnityEnvironment
@@ -72,11 +75,14 @@ def run_training(run_seed: int, options: RunOptions) -> None:
7275
)
7376
# Make run logs directory
7477
os.makedirs(run_logs_dir, exist_ok=True)
75-
# Load any needed states
78+
# Load any needed states in case of resume
7679
if checkpoint_settings.resume:
7780
GlobalTrainingStatus.load_state(
7881
os.path.join(run_logs_dir, "training_status.json")
7982
)
83+
# In case of initialization, set full init_path for all behaviors
84+
elif checkpoint_settings.maybe_init_path is not None:
85+
setup_init_path(options.behaviors, checkpoint_settings.maybe_init_path)
8086

8187
# Configure Tensorboard Writers and StatsReporter
8288
stats_writers = register_stats_writer_plugins(options)

ml-agents/mlagents/trainers/model_saver/torch_model_saver.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313

1414
logger = get_logger(__name__)
15+
DEFAULT_CHECKPOINT_NAME = "checkpoint.pt"
1516

1617

1718
class TorchModelSaver(BaseModelSaver):
@@ -55,7 +56,7 @@ def save_checkpoint(self, behavior_name: str, step: int) -> Tuple[str, List[str]
5556
pytorch_ckpt_path = f"{checkpoint_path}.pt"
5657
export_ckpt_path = f"{checkpoint_path}.onnx"
5758
torch.save(state_dict, f"{checkpoint_path}.pt")
58-
torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
59+
torch.save(state_dict, os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME))
5960
self.export(checkpoint_path, behavior_name)
6061
return export_ckpt_path, [pytorch_ckpt_path]
6162

@@ -75,16 +76,19 @@ def initialize_or_load(self, policy: Optional[TorchPolicy] = None) -> None:
7576
)
7677
elif self.load:
7778
logger.info(f"Resuming from {self.model_path}.")
78-
self._load_model(self.model_path, policy, reset_global_steps=reset_steps)
79+
self._load_model(
80+
os.path.join(self.model_path, DEFAULT_CHECKPOINT_NAME),
81+
policy,
82+
reset_global_steps=reset_steps,
83+
)
7984

8085
def _load_model(
8186
self,
8287
load_path: str,
8388
policy: Optional[TorchPolicy] = None,
8489
reset_global_steps: bool = False,
8590
) -> None:
86-
model_path = os.path.join(load_path, "checkpoint.pt")
87-
saved_state_dict = torch.load(model_path)
91+
saved_state_dict = torch.load(load_path)
8892
if policy is None:
8993
modules = self.modules
9094
policy = self.policy

ml-agents/mlagents/trainers/tests/test_trainer_util.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import io
33
import os
4+
import yaml
45
from unittest.mock import patch
56

67
from mlagents.trainers.trainer import TrainerFactory
@@ -10,7 +11,10 @@
1011
from mlagents.trainers.settings import RunOptions
1112
from mlagents.trainers.tests.dummy_config import ppo_dummy_config
1213
from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager
13-
from mlagents.trainers.directory_utils import validate_existing_directories
14+
from mlagents.trainers.directory_utils import (
15+
validate_existing_directories,
16+
setup_init_path,
17+
)
1418

1519

1620
@pytest.fixture
@@ -137,3 +141,47 @@ def test_existing_directories(tmp_path):
137141
os.mkdir(init_path)
138142
# Should pass since the directory exists now.
139143
validate_existing_directories(output_path, False, True, init_path)
144+
145+
146+
@pytest.mark.parametrize("dir_exists", [True, False])
147+
def test_setup_init_path(tmpdir, dir_exists):
148+
"""
149+
150+
:return:
151+
"""
152+
test_yaml = """
153+
behaviors:
154+
BigWallJump:
155+
init_path: BigWallJump-6540981.pt #full path
156+
trainer_type: ppo
157+
MediumWallJump:
158+
init_path: {}/test_setup_init_path_results/test_run_id/MediumWallJump/checkpoint.pt
159+
trainer_type: ppo
160+
SmallWallJump:
161+
trainer_type: ppo
162+
checkpoint_settings:
163+
run_id: test_run_id
164+
initialize_from: test_run_id
165+
""".format(
166+
tmpdir
167+
)
168+
run_options = RunOptions.from_dict(yaml.safe_load(test_yaml))
169+
if dir_exists:
170+
init_path = tmpdir.mkdir("test_setup_init_path_results").mkdir("test_run_id")
171+
big = init_path.mkdir("BigWallJump").join("BigWallJump-6540981.pt")
172+
big.write("content")
173+
med = init_path.mkdir("MediumWallJump").join("checkpoint.pt")
174+
med.write("content")
175+
small = init_path.mkdir("SmallWallJump").join("checkpoint.pt")
176+
small.write("content")
177+
178+
setup_init_path(run_options.behaviors, init_path)
179+
assert run_options.behaviors["BigWallJump"].init_path == big
180+
assert run_options.behaviors["MediumWallJump"].init_path == med
181+
assert run_options.behaviors["SmallWallJump"].init_path == small
182+
else:
183+
# don't make dirs and fail
184+
with pytest.raises(UnityTrainerException):
185+
setup_init_path(
186+
run_options.behaviors, run_options.checkpoint_settings.maybe_init_path
187+
)

ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
1010
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
1111
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
12-
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
12+
from mlagents.trainers.model_saver.torch_model_saver import (
13+
TorchModelSaver,
14+
DEFAULT_CHECKPOINT_NAME,
15+
)
1316
from mlagents.trainers.settings import (
1417
TrainerSettings,
1518
NetworkSettings,
@@ -62,7 +65,7 @@ def test_load_save_policy(tmp_path):
6265
assert policy2.get_current_step() == 2000
6366

6467
# Try initialize from path 1
65-
trainer_params.init_path = path1
68+
trainer_params.init_path = os.path.join(path1, DEFAULT_CHECKPOINT_NAME)
6669
model_saver3 = TorchModelSaver(trainer_params, path2)
6770
policy3 = create_policy_mock(trainer_params)
6871
model_saver3.register(policy3)

ml-agents/mlagents/trainers/trainer/trainer_factory.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def generate(self, behavior_name: str) -> Trainer:
6666
self.ghost_controller,
6767
self.seed,
6868
self.param_manager,
69-
self.init_path,
7069
self.multi_gpu,
7170
)
7271

@@ -80,7 +79,6 @@ def _initialize_trainer(
8079
ghost_controller: GhostController,
8180
seed: int,
8281
param_manager: EnvironmentParameterManager,
83-
init_path: str = None,
8482
multi_gpu: bool = False,
8583
) -> Trainer:
8684
"""
@@ -96,12 +94,9 @@ def _initialize_trainer(
9694
:param ghost_controller: The object that coordinates ghost trainers
9795
:param seed: The random seed to use
9896
:param param_manager: EnvironmentParameterManager, used to determine a reward buffer length for PPOTrainer
99-
:param init_path: Path from which to load model, if different from model_path.
10097
:return:
10198
"""
10299
trainer_artifact_path = os.path.join(output_path, brain_name)
103-
if init_path is not None:
104-
trainer_settings.init_path = os.path.join(init_path, brain_name)
105100

106101
min_lesson_length = param_manager.get_minimum_reward_buffer_size(brain_name)
107102

0 commit comments

Comments
 (0)