
Commit ac4f43c

andrewcoh, vincentpierre, and Ervin T. authored

Load individual elements if state dict load fails (#5213)

Co-authored-by: Vincent-Pierre BERGES <vincentpierre@unity3d.com>
Co-authored-by: Ervin T. <ervin@unity3d.com>

1 parent 30fde2d commit ac4f43c

File tree

5 files changed: +147 -3 lines changed


com.unity.ml-agents/CHANGELOG.md

Lines changed: 4 additions & 0 deletions

@@ -31,6 +31,10 @@ sizes and will need to be retrained. (#5181)
   different sizes using the same model. For a summary of the interface changes, please see the Migration Guide. (#5189)
 
 #### ml-agents / ml-agents-envs / gym-unity (Python)
+- The `--resume` flag now supports resuming experiments with additional reward providers or
+  loading partial models if the network architecture has changed. See
+  [here](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Training-ML-Agents.md#loading-an-existing-model)
+  for more details. (#5213)
 
 ### Minor Changes
 #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)

docs/Training-ML-Agents.md

Lines changed: 7 additions & 0 deletions

@@ -117,6 +117,13 @@ Python by using both the `--resume` and `--inference` flags. Note that if you
 want to run inference in Unity, you should use the
 [Unity Inference Engine](Getting-Started.md#running-a-pre-trained-model).
 
+Additionally, if the network architecture changes, you may still load an existing model,
+but ML-Agents will load only the parts of the model it can and ignore all others. For
+instance, if you add a new reward signal, the existing model will load, but the new
+reward signal will be initialized from scratch. If you have a model with a visual
+encoder (CNN) but change the `hidden_units`, the CNN will be loaded but the body of
+the network will be initialized from scratch.
+
 Alternatively, you might want to start a new training run but _initialize_ it
 using an already-trained model. You may want to do this, for instance, if your
 environment changed and you want a new model, but the old behavior is still
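As an illustration of the workflow these docs describe (not part of the diff; the run id and config path below are placeholders), a run whose config gained a reward signal or changed `hidden_units` is resumed with the same command as any other run:

    # Edit reward_signals or hidden_units in the trainer config, then resume as usual;
    # layers that no longer match the checkpoint are reinitialized with a warning.
    mlagents-learn config/ppo/3DBall.yaml --run-id=MyRun --resume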

ml-agents/mlagents/trainers/model_saver/torch_model_saver.py

Lines changed: 28 additions & 1 deletion

@@ -89,7 +89,34 @@ def _load_model(
         policy = cast(TorchPolicy, policy)
 
         for name, mod in modules.items():
-            mod.load_state_dict(saved_state_dict[name])
+            try:
+                if isinstance(mod, torch.nn.Module):
+                    missing_keys, unexpected_keys = mod.load_state_dict(
+                        saved_state_dict[name], strict=False
+                    )
+                    if missing_keys:
+                        logger.warning(
+                            f"Did not find these keys {missing_keys} in checkpoint. Initializing."
+                        )
+                    if unexpected_keys:
+                        logger.warning(
+                            f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
+                        )
+                else:
+                    # If the module is not an nn.Module, try to load it in one piece.
+                    mod.load_state_dict(saved_state_dict[name])
+
+            # KeyError is raised if the module was not present in the last run but is
+            # being accessed in the saved_state_dict.
+            # ValueError is raised by the optimizer's load_state_dict if the parameters
+            # have changed. Note: the optimizer uses a completely different
+            # load_state_dict function because it is not an nn.Module.
+            # RuntimeError is raised by PyTorch if there is a size mismatch between
+            # modules of the same name. This will still partially assign values to
+            # those layers that have not changed shape.
+            except (KeyError, ValueError, RuntimeError) as err:
+                logger.warning(f"Failed to load for module {name}. Initializing")
+                logger.debug(f"Module loading error : {err}")
 
         if reset_global_steps:
             policy.set_step(0)
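For reference, a minimal standalone sketch (not from the commit) of the PyTorch behavior the new code relies on: with strict=False, load_state_dict reports mismatched key sets instead of raising, while a tensor shape mismatch still raises RuntimeError after copying the entries that do match.

    import torch

    src = torch.nn.Sequential(torch.nn.Linear(4, 12), torch.nn.Linear(12, 2))
    dst = torch.nn.Sequential(torch.nn.Linear(4, 12))  # second layer removed

    # strict=False: extra/missing keys are returned, not raised.
    result = dst.load_state_dict(src.state_dict(), strict=False)
    print(result.missing_keys)     # []
    print(result.unexpected_keys)  # ['1.weight', '1.bias']

    # A shape mismatch still raises RuntimeError, even with strict=False,
    # which is why _load_model catches RuntimeError and logs a warning.
    dst2 = torch.nn.Sequential(torch.nn.Linear(4, 10), torch.nn.Linear(10, 2))
    try:
        dst2.load_state_dict(src.state_dict(), strict=False)
    except RuntimeError as err:
        print(f"size mismatch: {err}")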

ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py

Lines changed: 46 additions & 0 deletions

@@ -12,6 +12,8 @@
 from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
 from mlagents.trainers.settings import (
     TrainerSettings,
+    NetworkSettings,
+    EncoderType,
     PPOSettings,
     SACSettings,
     POCASettings,
@@ -70,6 +72,50 @@ def test_load_save_policy(tmp_path):
     assert policy3.get_current_step() == 0
 
 
+@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
+def test_load_policy_different_hidden_units(tmp_path, vis_encode_type):
+    path1 = os.path.join(tmp_path, "runid1")
+    trainer_params = TrainerSettings()
+    trainer_params.network_settings = NetworkSettings(
+        hidden_units=12, vis_encode_type=EncoderType(vis_encode_type)
+    )
+    policy = create_policy_mock(trainer_params, use_visual=True)
+    conv_params = [mod for mod in policy.actor.parameters() if len(mod.shape) > 2]
+
+    model_saver = TorchModelSaver(trainer_params, path1)
+    model_saver.register(policy)
+    model_saver.initialize_or_load(policy)
+    policy.set_step(2000)
+
+    mock_brain_name = "MockBrain"
+    model_saver.save_checkpoint(mock_brain_name, 2000)
+
+    # Try to load from this path
+    trainer_params2 = TrainerSettings()
+    trainer_params2.network_settings = NetworkSettings(
+        hidden_units=10, vis_encode_type=EncoderType(vis_encode_type)
+    )
+    model_saver2 = TorchModelSaver(trainer_params2, path1, load=True)
+    policy2 = create_policy_mock(trainer_params2, use_visual=True)
+    conv_params2 = [mod for mod in policy2.actor.parameters() if len(mod.shape) > 2]
+    # assert that the convolutions have different parameters before the load
+    for conv1, conv2 in zip(conv_params, conv_params2):
+        assert not torch.equal(conv1, conv2)
+    # assert that the layers still have different dimensions
+    for mod1, mod2 in zip(policy.actor.parameters(), policy2.actor.parameters()):
+        if mod1.shape[0] == 12:
+            assert mod2.shape[0] == 10
+    model_saver2.register(policy2)
+    model_saver2.initialize_or_load(policy2)
+    # assert that the convolutions have the same parameters after the load
+    for conv1, conv2 in zip(conv_params, conv_params2):
+        assert torch.equal(conv1, conv2)
+    # assert that the layers still have different dimensions
+    for mod1, mod2 in zip(policy.actor.parameters(), policy2.actor.parameters()):
+        if mod1.shape[0] == 12:
+            assert mod2.shape[0] == 10
+
+
 @pytest.mark.parametrize(
     "optimizer",
     [
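The `len(mod.shape) > 2` filter in the test above isolates convolution weights: conv kernels are 4-D (out_channels, in_channels, kH, kW), while linear weights are 2-D and biases 1-D. A quick standalone check (illustrative, not from the commit):

    import torch

    conv = torch.nn.Conv2d(3, 16, kernel_size=3)
    fc = torch.nn.Linear(16, 12)
    print(conv.weight.shape)  # torch.Size([16, 3, 3, 3]) -> 4-D, kept by the filter
    print(fc.weight.shape)    # torch.Size([12, 16])      -> 2-D, filtered out
    print(fc.bias.shape)      # torch.Size([12])          -> 1-D, filtered out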

ml-agents/mlagents/trainers/tests/torch/saver/test_saver_reward_providers.py

Lines changed: 62 additions & 2 deletions

@@ -3,8 +3,10 @@
 
 import numpy as np
 
+from mlagents_envs.logging_util import WARNING
 from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
 from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
+from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
 from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
 from mlagents.trainers.settings import (
     TrainerSettings,
@@ -14,12 +16,14 @@
     RNDSettings,
     PPOSettings,
     SACSettings,
+    POCASettings,
 )
 from mlagents.trainers.tests.torch.test_policy import create_policy_mock
 from mlagents.trainers.tests.torch.test_reward_providers.utils import (
     create_agent_buffer,
 )
 
+
 DEMO_PATH = (
     os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
     + "/test.demo"
@@ -28,8 +32,12 @@
 
 @pytest.mark.parametrize(
     "optimizer",
-    [(TorchPPOOptimizer, PPOSettings), (TorchSACOptimizer, SACSettings)],
-    ids=["ppo", "sac"],
+    [
+        (TorchPPOOptimizer, PPOSettings),
+        (TorchSACOptimizer, SACSettings),
+        (TorchPOCAOptimizer, POCASettings),
+    ],
+    ids=["ppo", "sac", "poca"],
 )
 def test_reward_provider_save(tmp_path, optimizer):
     OptimizerClass, HyperparametersClass = optimizer
@@ -87,3 +95,55 @@ def test_reward_provider_save(tmp_path, optimizer):
     rp_1 = optimizer.reward_signals[reward_name]
     rp_2 = optimizer2.reward_signals[reward_name]
     assert np.array_equal(rp_1.evaluate(data), rp_2.evaluate(data))
+
+
+@pytest.mark.parametrize(
+    "optimizer",
+    [
+        (TorchPPOOptimizer, PPOSettings),
+        (TorchSACOptimizer, SACSettings),
+        (TorchPOCAOptimizer, POCASettings),
+    ],
+    ids=["ppo", "sac", "poca"],
+)
+def test_load_different_reward_provider(caplog, tmp_path, optimizer):
+    OptimizerClass, HyperparametersClass = optimizer
+
+    trainer_settings = TrainerSettings()
+    trainer_settings.hyperparameters = HyperparametersClass()
+    trainer_settings.reward_signals = {
+        RewardSignalType.CURIOSITY: CuriositySettings(),
+        RewardSignalType.RND: RNDSettings(),
+    }
+
+    policy = create_policy_mock(trainer_settings, use_discrete=False)
+    optimizer = OptimizerClass(policy, trainer_settings)
+
+    # save at path 1
+    path1 = os.path.join(tmp_path, "runid1")
+    model_saver = TorchModelSaver(trainer_settings, path1)
+    model_saver.register(policy)
+    model_saver.register(optimizer)
+    model_saver.initialize_or_load()
+    assert len(optimizer.critic.value_heads.stream_names) == 2
+    policy.set_step(2000)
+    model_saver.save_checkpoint("MockBrain", 2000)
+
+    trainer_settings2 = TrainerSettings()
+    trainer_settings2.hyperparameters = HyperparametersClass()
+    trainer_settings2.reward_signals = {
+        RewardSignalType.GAIL: GAILSettings(demo_path=DEMO_PATH)
+    }
+
+    # create a new optimizer and policy
+    policy2 = create_policy_mock(trainer_settings2, use_discrete=False)
+    optimizer2 = OptimizerClass(policy2, trainer_settings2)
+
+    # load the weights
+    model_saver2 = TorchModelSaver(trainer_settings2, path1, load=True)
+    model_saver2.register(policy2)
+    model_saver2.register(optimizer2)
+    assert len(optimizer2.critic.value_heads.stream_names) == 1
+    model_saver2.initialize_or_load()  # This is to load the optimizers
+    messages = [rec.message for rec in caplog.records if rec.levelno == WARNING]
+    assert len(messages) > 0
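The test above relies on pytest's built-in caplog fixture, which collects emitted log records so the warning from the partial load can be asserted on. A minimal standalone sketch of that pattern (the logger name and helper function below are hypothetical stand-ins, not ml-agents code):

    import logging

    logger = logging.getLogger("example")

    def do_partial_load():
        # Hypothetical stand-in for model_saver2.initialize_or_load().
        logger.warning("Failed to load for module gail. Initializing")

    def test_warns_on_partial_load(caplog):
        with caplog.at_level(logging.WARNING):
            do_partial_load()
        warnings = [r.message for r in caplog.records if r.levelno == logging.WARNING]
        assert len(warnings) > 0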
