Skip to content

Set torch device from commandline #4888

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ removed when training with a player. The Editor still requires it to be clamped
Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)

#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added a `--torch-device` commandline option to `mlagents-learn`, which sets the default
[`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device) used for training. (#4888)
- The `--cpu` commandline option had no effect and was removed. Use `--torch-device=cpu` to force CPU training. (#4888)

### Bug Fixes
#### com.unity.ml-agents (C#)
Expand Down
10 changes: 9 additions & 1 deletion docs/Training-ML-Agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ using the help utility:
mlagents-learn --help
```

These additional CLI arguments are grouped into environment, engine and checkpoint. The available settings and example values are shown below.
These additional CLI arguments are grouped into environment, engine, checkpoint and torch.
The available settings and example values are shown below.

#### Environment settings

Expand Down Expand Up @@ -227,6 +228,13 @@ checkpoint_settings:
inference: false
```

#### Torch settings:

```yaml
torch_settings:
device: cpu
```

### Behavior Configurations

The primary section of the trainer config file is a
Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/torch_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mlagents.torch_utils.torch import torch as torch # noqa
from mlagents.torch_utils.torch import nn # noqa
from mlagents.torch_utils.torch import set_torch_config # noqa
from mlagents.torch_utils.torch import default_device # noqa
37 changes: 30 additions & 7 deletions ml-agents/mlagents/torch_utils/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from distutils.version import LooseVersion
import pkg_resources
from mlagents.torch_utils import cpu_utils
from mlagents.trainers.settings import TorchSettings
from mlagents_envs.logging_util import get_logger


logger = get_logger(__name__)


def assert_torch_installed():
Expand Down Expand Up @@ -32,14 +37,32 @@ def assert_torch_installed():
torch.set_num_threads(cpu_utils.get_num_threads_to_use())
os.environ["KMP_BLOCKTIME"] = "0"

if torch.cuda.is_available():
torch.set_default_tensor_type(torch.cuda.FloatTensor)
device = torch.device("cuda")
else:
torch.set_default_tensor_type(torch.FloatTensor)
device = torch.device("cpu")

_device = torch.device("cpu")


def set_torch_config(torch_settings: TorchSettings) -> None:
global _device

if torch_settings.device is None:
device_str = "cuda" if torch.cuda.is_available() else "cpu"
else:
device_str = torch_settings.device

_device = torch.device(device_str)

if _device.type == "cuda":
torch.set_default_tensor_type(torch.cuda.FloatTensor)
else:
torch.set_default_tensor_type(torch.FloatTensor)
logger.info(f"default Torch device: {_device}")


# Initialize to default settings
set_torch_config(TorchSettings(device=None))

nn = torch.nn


def default_device():
return device
return _device
15 changes: 9 additions & 6 deletions ml-agents/mlagents/trainers/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,6 @@ def _create_parser() -> argparse.ArgumentParser:
"passed to the executable.",
action=DetectDefault,
)
argparser.add_argument(
"--cpu",
default=False,
action=DetectDefaultStoreTrue,
help="Forces training using CPU only",
)
argparser.add_argument(
"--torch",
default=False,
Expand Down Expand Up @@ -252,6 +246,15 @@ def _create_parser() -> argparse.ArgumentParser:
help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
"the graphics driver. Use this only if your agents don't use visual observations.",
)

torch_conf = argparser.add_argument_group(title="Torch Configuration")
torch_conf.add_argument(
"--torch-device",
default=None,
dest="device",
action=DetectDefault,
help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
)
return argparser


Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
:param run_options: Command line arguments for training.
"""
with hierarchical_timer("run_training.setup"):
torch_utils.set_torch_config(options.torch_settings)
checkpoint_settings = options.checkpoint_settings
env_settings = options.env_settings
engine_settings = options.engine_settings
Expand Down
9 changes: 9 additions & 0 deletions ml-agents/mlagents/trainers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,11 @@ class EngineSettings:
no_graphics: bool = parser.get_default("no_graphics")


@attr.s(auto_attribs=True)
class TorchSettings:
device: Optional[str] = parser.get_default("torch_device")


@attr.s(auto_attribs=True)
class RunOptions(ExportableSettings):
default_settings: Optional[TrainerSettings] = None
Expand All @@ -743,6 +748,7 @@ class RunOptions(ExportableSettings):
engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
torch_settings: TorchSettings = attr.ib(factory=TorchSettings)

# These are options that are relevant to the run itself, and not the engine or environment.
# They will be left here.
Expand Down Expand Up @@ -784,6 +790,7 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
"checkpoint_settings": {},
"env_settings": {},
"engine_settings": {},
"torch_settings": {},
}
if config_path is not None:
configured_dict.update(load_config(config_path))
Expand All @@ -808,6 +815,8 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
configured_dict["env_settings"][key] = val
elif key in attr.fields_dict(EngineSettings):
configured_dict["engine_settings"][key] = val
elif key in attr.fields_dict(TorchSettings):
configured_dict["torch_settings"][key] = val
else: # Base options
configured_dict[key] = val

Expand Down
41 changes: 41 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_torch_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
from unittest import mock

import torch # noqa I201

from mlagents.torch_utils import set_torch_config, default_device
from mlagents.trainers.settings import TorchSettings


@pytest.mark.parametrize(
"device_str, expected_type, expected_index, expected_tensor_type",
[
("cpu", "cpu", None, torch.FloatTensor),
("cuda", "cuda", None, torch.cuda.FloatTensor),
("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
("opengl", "opengl", None, torch.FloatTensor),
],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
mock_set_default_tensor_type,
device_str,
expected_type,
expected_index,
expected_tensor_type,
):
try:
torch_settings = TorchSettings(device=device_str)
set_torch_config(torch_settings)
assert default_device().type == expected_type
if expected_index is None:
assert default_device().index is None
else:
assert default_device().index == expected_index
mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
except Exception:
raise
finally:
# restore the defaults
torch_settings = TorchSettings(device=None)
set_torch_config(torch_settings)