diff --git a/.yamato/gym-interface-test.yml b/.yamato/gym-interface-test.yml
index 8ecb5f9347..01c975bca2 100644
--- a/.yamato/gym-interface-test.yml
+++ b/.yamato/gym-interface-test.yml
@@ -13,6 +13,7 @@ test_gym_interface_{{ editor.version }}:
- |
sudo apt-get update && sudo apt-get install -y python3-venv
python3 -m venv venv && source venv/bin/activate
+ python -m pip install wheel --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
python -u -m ml-agents.tests.yamato.setup_venv
python ml-agents/tests/yamato/scripts/run_gym.py --env=artifacts/testPlayer-Basic
diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs
index 1caea48397..6369271199 100644
--- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs
+++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs
@@ -1,5 +1,6 @@
using System;
using Unity.MLAgents.Actuators;
+using UnityEngine;
namespace Unity.MLAgentsExamples
{
@@ -30,7 +31,7 @@ public override ActionSpec ActionSpec
///
/// Simple actuator that converts the action into a {-1, 0, 1} direction
///
- public class BasicActuator : IActuator
+ public class BasicActuator : IActuator, IHeuristicProvider
{
public BasicController basicController;
ActionSpec m_ActionSpec;
@@ -76,6 +77,19 @@ public void OnActionReceived(ActionBuffers actionBuffers)
basicController.MoveDirection(direction);
}
+ public void Heuristic(in ActionBuffers actionBuffersOut)
+ {
+ var direction = Input.GetAxis("Horizontal");
+ var discreteActions = actionBuffersOut.DiscreteActions;
+ if (Mathf.Approximately(direction, 0.0f))
+ {
+ discreteActions[0] = 0;
+ return;
+ }
+ var sign = Math.Sign(direction);
+ discreteActions[0] = sign < 0 ? 1 : 2;
+ }
+
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab
index ec8b60c12c..adf6bbe161 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab
+++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab
@@ -13,8 +13,8 @@ GameObject:
- component: {fileID: 3508723250470608012}
- component: {fileID: 3508723250470608011}
- component: {fileID: 3508723250470608009}
- - component: {fileID: 3508723250470608013}
- component: {fileID: 3508723250470608014}
+ - component: {fileID: 2112317463290853299}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged
@@ -51,9 +51,13 @@ MonoBehaviour:
m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
+ m_ActionSpec:
+ m_NumContinuousActions: 0
+ BranchSizes:
VectorActionSize:
VectorActionDescriptions: []
VectorActionSpaceType: 0
+ hasUpgradedBrainParametersWithActionSpec: 1
m_Model: {fileID: 11400000, guid: c34da50737a3c4a50918002b20b2b927, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
@@ -81,7 +85,6 @@ MonoBehaviour:
Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
- HeuristicQuality: 0
--- !u!114 &3508723250470608011
MonoBehaviour:
m_ObjectHideFlags: 0
@@ -96,7 +99,6 @@ MonoBehaviour:
m_EditorClassIdentifier:
DebugMoveIndex: -1
CubeSpacing: 1.25
- Board: {fileID: 0}
TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e,
type: 3}
--- !u!114 &3508723250470608009
@@ -119,7 +121,7 @@ MonoBehaviour:
BasicCellPoints: 1
SpecialCell1Points: 2
SpecialCell2Points: 3
---- !u!114 &3508723250470608013
+--- !u!114 &3508723250470608014
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
@@ -128,12 +130,12 @@ MonoBehaviour:
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
+ m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
- ActuatorName: Match3 Actuator
- ForceHeuristic: 1
---- !u!114 &3508723250470608014
+ SensorName: Match3 Sensor
+ ObservationType: 0
+--- !u!114 &2112317463290853299
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
@@ -142,11 +144,12 @@ MonoBehaviour:
m_GameObject: {fileID: 3508723250470608007}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
+ m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3}
m_Name:
m_EditorClassIdentifier:
- SensorName: Match3 Sensor
- ObservationType: 0
+ ActuatorName: Match3 Actuator
+ ForceHeuristic: 1
+ HeuristicQuality: 0
--- !u!1 &3508723250774301855
GameObject:
m_ObjectHideFlags: 0
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab
index 3b4b66024f..4177d3c687 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab
+++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab
@@ -44,8 +44,8 @@ GameObject:
- component: {fileID: 2118285884327540682}
- component: {fileID: 2118285884327540685}
- component: {fileID: 2118285884327540687}
- - component: {fileID: 2118285884327540683}
- component: {fileID: 2118285884327540680}
+ - component: {fileID: 3357012711826686276}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged
@@ -82,9 +82,13 @@ MonoBehaviour:
m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
+ m_ActionSpec:
+ m_NumContinuousActions: 0
+ BranchSizes:
VectorActionSize:
VectorActionDescriptions: []
VectorActionSpaceType: 0
+ hasUpgradedBrainParametersWithActionSpec: 1
m_Model: {fileID: 11400000, guid: 9e89b8e81974148d3b7213530d00589d, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
@@ -112,7 +116,6 @@ MonoBehaviour:
Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
- HeuristicQuality: 0
--- !u!114 &2118285884327540685
MonoBehaviour:
m_ObjectHideFlags: 0
@@ -127,7 +130,6 @@ MonoBehaviour:
m_EditorClassIdentifier:
DebugMoveIndex: -1
CubeSpacing: 1.25
- Board: {fileID: 0}
TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e,
type: 3}
--- !u!114 &2118285884327540687
@@ -150,7 +152,7 @@ MonoBehaviour:
BasicCellPoints: 1
SpecialCell1Points: 2
SpecialCell2Points: 3
---- !u!114 &2118285884327540683
+--- !u!114 &2118285884327540680
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
@@ -159,12 +161,12 @@ MonoBehaviour:
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
+ m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
- ActuatorName: Match3 Actuator
- ForceHeuristic: 0
---- !u!114 &2118285884327540680
+ SensorName: Match3 Sensor
+ ObservationType: 0
+--- !u!114 &3357012711826686276
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
@@ -173,8 +175,9 @@ MonoBehaviour:
m_GameObject: {fileID: 2118285884327540673}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
+ m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3}
m_Name:
m_EditorClassIdentifier:
- SensorName: Match3 Sensor
- ObservationType: 0
+ ActuatorName: Match3 Actuator
+ ForceHeuristic: 0
+ HeuristicQuality: 0
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab
index 28d219acd3..2c78f72948 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab
+++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab
@@ -44,8 +44,8 @@ GameObject:
- component: {fileID: 3019509692332007781}
- component: {fileID: 3019509692332007778}
- component: {fileID: 3019509692332007776}
- - component: {fileID: 3019509692332007780}
- component: {fileID: 3019509692332007783}
+ - component: {fileID: 8270768986451624427}
m_Layer: 0
m_Name: Match3 Agent
m_TagString: Untagged
@@ -82,9 +82,13 @@ MonoBehaviour:
m_BrainParameters:
VectorObservationSize: 0
NumStackedVectorObservations: 1
+ m_ActionSpec:
+ m_NumContinuousActions: 0
+ BranchSizes:
VectorActionSize:
VectorActionDescriptions: []
VectorActionSpaceType: 0
+ hasUpgradedBrainParametersWithActionSpec: 1
m_Model: {fileID: 11400000, guid: 48d14da88fea74d0693c691c6e3f2e34, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
@@ -112,7 +116,6 @@ MonoBehaviour:
Board: {fileID: 0}
MoveTime: 0.25
MaxMoves: 500
- HeuristicQuality: 0
--- !u!114 &3019509692332007778
MonoBehaviour:
m_ObjectHideFlags: 0
@@ -127,7 +130,6 @@ MonoBehaviour:
m_EditorClassIdentifier:
DebugMoveIndex: -1
CubeSpacing: 1.25
- Board: {fileID: 0}
TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e,
type: 3}
--- !u!114 &3019509692332007776
@@ -150,7 +152,7 @@ MonoBehaviour:
BasicCellPoints: 1
SpecialCell1Points: 2
SpecialCell2Points: 3
---- !u!114 &3019509692332007780
+--- !u!114 &3019509692332007783
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
@@ -159,12 +161,12 @@ MonoBehaviour:
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3}
+ m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
m_Name:
m_EditorClassIdentifier:
- ActuatorName: Match3 Actuator
- ForceHeuristic: 0
---- !u!114 &3019509692332007783
+ SensorName: Match3 Sensor
+ ObservationType: 2
+--- !u!114 &8270768986451624427
MonoBehaviour:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
@@ -173,8 +175,9 @@ MonoBehaviour:
m_GameObject: {fileID: 3019509692332007790}
m_Enabled: 1
m_EditorHideFlags: 0
- m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3}
+ m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3}
m_Name:
m_EditorClassIdentifier:
- SensorName: Match3 Sensor
- ObservationType: 2
+ ActuatorName: Match3 Actuator
+ ForceHeuristic: 0
+ HeuristicQuality: 0
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity
index 8383c776d2..f543515eb8 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity
@@ -690,6 +690,11 @@ PrefabInstance:
m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
+ - target: {fileID: 2112317463290853299, guid: 2fafdcd0587684641b03b11f04454f1b,
+ type: 3}
+ propertyPath: HeuristicQuality
+ value: 1
+ objectReference: {fileID: 0}
- target: {fileID: 3508723250470608011, guid: 2fafdcd0587684641b03b11f04454f1b,
type: 3}
propertyPath: cubeSpacing
@@ -1385,6 +1390,11 @@ PrefabInstance:
m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications:
+ - target: {fileID: 2112317463290853299, guid: 2fafdcd0587684641b03b11f04454f1b,
+ type: 3}
+ propertyPath: HeuristicQuality
+ value: 1
+ objectReference: {fileID: 0}
- target: {fileID: 3508723250470608011, guid: 2fafdcd0587684641b03b11f04454f1b,
type: 3}
propertyPath: cubeSpacing
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs
index a890f12c17..7c7192b8f7 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs
@@ -63,20 +63,6 @@ enum State
WaitForMove = 4,
}
- public enum HeuristicQuality
- {
- ///
- /// The heuristic will pick any valid move at random.
- ///
- RandomValidMove,
-
- ///
- /// The heuristic will pick the move that scores the most points.
- /// This only looks at the immediate move, and doesn't consider where cells will fall.
- ///
- Greedy
- }
-
public class Match3Agent : Agent
{
[HideInInspector]
@@ -86,20 +72,14 @@ public class Match3Agent : Agent
public int MaxMoves = 500;
- public HeuristicQuality HeuristicQuality = HeuristicQuality.RandomValidMove;
-
State m_CurrentState = State.WaitForMove;
float m_TimeUntilMove;
private int m_MovesMade;
- private System.Random m_Random;
private const float k_RewardMultiplier = 0.01f;
-
void Awake()
{
Board = GetComponent();
- var seed = Board.RandomSeed == -1 ? gameObject.GetInstanceID() : Board.RandomSeed + 1;
- m_Random = new System.Random(seed);
}
public override void OnEpisodeBegin()
@@ -222,152 +202,6 @@ bool HasValidMoves()
return false;
}
- public override void Heuristic(in ActionBuffers actionsOut)
- {
- var discreteActions = actionsOut.DiscreteActions;
- discreteActions[0] = GreedyMove();
- }
-
- int GreedyMove()
- {
- var pointsByType = new[] { Board.BasicCellPoints, Board.SpecialCell1Points, Board.SpecialCell2Points };
-
- var bestMoveIndex = 0;
- var bestMovePoints = -1;
- var numMovesAtCurrentScore = 0;
-
- foreach (var move in Board.ValidMoves())
- {
- var movePoints = HeuristicQuality == HeuristicQuality.Greedy ? EvalMovePoints(move, pointsByType) : 1;
- if (movePoints < bestMovePoints)
- {
- // Worse, skip
- continue;
- }
-
- if (movePoints > bestMovePoints)
- {
- // Better, keep
- bestMovePoints = movePoints;
- bestMoveIndex = move.MoveIndex;
- numMovesAtCurrentScore = 1;
- }
- else
- {
- // Tied for best - use reservoir sampling to make sure we select from equal moves uniformly.
- // See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
- numMovesAtCurrentScore++;
- var randVal = m_Random.Next(0, numMovesAtCurrentScore);
- if (randVal == 0)
- {
- // Keep the new one
- bestMoveIndex = move.MoveIndex;
- }
- }
- }
-
- return bestMoveIndex;
- }
-
- int EvalMovePoints(Move move, int[] pointsByType)
- {
- // Counts the expected points for making the move.
- var moveVal = Board.GetCellType(move.Row, move.Column);
- var moveSpecial = Board.GetSpecialType(move.Row, move.Column);
- var (otherRow, otherCol) = move.OtherCell();
- var oppositeVal = Board.GetCellType(otherRow, otherCol);
- var oppositeSpecial = Board.GetSpecialType(otherRow, otherCol);
-
-
- int movePoints = EvalHalfMove(
- otherRow, otherCol, moveVal, moveSpecial, move.Direction, pointsByType
- );
- int otherPoints = EvalHalfMove(
- move.Row, move.Column, oppositeVal, oppositeSpecial, move.OtherDirection(), pointsByType
- );
- return movePoints + otherPoints;
- }
-
- int EvalHalfMove(int newRow, int newCol, int newValue, int newSpecial, Direction incomingDirection, int[] pointsByType)
- {
- // This is a essentially a duplicate of AbstractBoard.CheckHalfMove but also counts the points for the move.
- int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0;
- int scoreLeft = 0, scoreRight = 0, scoreUp = 0, scoreDown = 0;
-
- if (incomingDirection != Direction.Right)
- {
- for (var c = newCol - 1; c >= 0; c--)
- {
- if (Board.GetCellType(newRow, c) == newValue)
- {
- matchedLeft++;
- scoreLeft += pointsByType[Board.GetSpecialType(newRow, c)];
- }
- else
- break;
- }
- }
-
- if (incomingDirection != Direction.Left)
- {
- for (var c = newCol + 1; c < Board.Columns; c++)
- {
- if (Board.GetCellType(newRow, c) == newValue)
- {
- matchedRight++;
- scoreRight += pointsByType[Board.GetSpecialType(newRow, c)];
- }
- else
- break;
- }
- }
-
- if (incomingDirection != Direction.Down)
- {
- for (var r = newRow + 1; r < Board.Rows; r++)
- {
- if (Board.GetCellType(r, newCol) == newValue)
- {
- matchedUp++;
- scoreUp += pointsByType[Board.GetSpecialType(r, newCol)];
- }
- else
- break;
- }
- }
-
- if (incomingDirection != Direction.Up)
- {
- for (var r = newRow - 1; r >= 0; r--)
- {
- if (Board.GetCellType(r, newCol) == newValue)
- {
- matchedDown++;
- scoreDown += pointsByType[Board.GetSpecialType(r, newCol)];
- }
- else
- break;
- }
- }
-
- if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2))
- {
- // It's a match. Start from counting the piece being moved
- var totalScore = pointsByType[newSpecial];
- if (matchedUp + matchedDown >= 2)
- {
- totalScore += scoreUp + scoreDown;
- }
-
- if (matchedLeft + matchedRight >= 2)
- {
- totalScore += scoreLeft + scoreRight;
- }
- return totalScore;
- }
-
- return 0;
- }
}
}
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs
index 859ec2b9e8..9aa1d59f57 100644
--- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs
@@ -1,3 +1,4 @@
+using System;
using Unity.MLAgents.Extensions.Match3;
using UnityEngine;
@@ -7,8 +8,6 @@ namespace Unity.MLAgentsExamples
public class Match3Board : AbstractBoard
{
- public int RandomSeed = -1;
-
public const int k_EmptyCell = -1;
[Tooltip("Points earned for clearing a basic cell (cube)")]
public int BasicCellPoints = 1;
@@ -19,6 +18,11 @@ public class Match3Board : AbstractBoard
[Tooltip("Points earned for clearing an extra special cell (plus)")]
public int SpecialCell2Points = 3;
+ ///
+ /// Seed to initialize the object.
+ ///
+ public int RandomSeed;
+
(int, int)[,] m_Cells;
bool[,] m_Matched;
@@ -29,8 +33,11 @@ void Awake()
m_Cells = new (int, int)[Columns, Rows];
m_Matched = new bool[Columns, Rows];
- m_Random = new System.Random(RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed);
+ }
+ void Start()
+ {
+ m_Random = new System.Random(RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed);
InitRandom();
}
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs
new file mode 100644
index 0000000000..c50dc9ff2a
--- /dev/null
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs
@@ -0,0 +1,121 @@
+using Unity.MLAgents;
+using Unity.MLAgents.Extensions.Match3;
+
+namespace Unity.MLAgentsExamples
+{
+ public class Match3ExampleActuator : Match3Actuator
+ {
+ Match3Board Board => (Match3Board)m_Board;
+
+ public Match3ExampleActuator(Match3Board board,
+ bool forceHeuristic,
+ Agent agent,
+ string name,
+ int seed
+ )
+ : base(board, forceHeuristic, seed, agent, name) { }
+
+
+ protected override int EvalMovePoints(Move move)
+ {
+ var pointsByType = new[] { Board.BasicCellPoints, Board.SpecialCell1Points, Board.SpecialCell2Points };
+ // Counts the expected points for making the move.
+ var moveVal = m_Board.GetCellType(move.Row, move.Column);
+ var moveSpecial = m_Board.GetSpecialType(move.Row, move.Column);
+ var (otherRow, otherCol) = move.OtherCell();
+ var oppositeVal = m_Board.GetCellType(otherRow, otherCol);
+ var oppositeSpecial = m_Board.GetSpecialType(otherRow, otherCol);
+
+
+ int movePoints = EvalHalfMove(
+ otherRow, otherCol, moveVal, moveSpecial, move.Direction, pointsByType
+ );
+ int otherPoints = EvalHalfMove(
+ move.Row, move.Column, oppositeVal, oppositeSpecial, move.OtherDirection(), pointsByType
+ );
+ return movePoints + otherPoints;
+ }
+
+ int EvalHalfMove(int newRow, int newCol, int newValue, int newSpecial, Direction incomingDirection, int[] pointsByType)
+ {
+ // This is essentially a duplicate of AbstractBoard.CheckHalfMove but also counts the points for the move.
+ int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0;
+ int scoreLeft = 0, scoreRight = 0, scoreUp = 0, scoreDown = 0;
+
+ if (incomingDirection != Direction.Right)
+ {
+ for (var c = newCol - 1; c >= 0; c--)
+ {
+ if (m_Board.GetCellType(newRow, c) == newValue)
+ {
+ matchedLeft++;
+ scoreLeft += pointsByType[m_Board.GetSpecialType(newRow, c)];
+ }
+ else
+ break;
+ }
+ }
+
+ if (incomingDirection != Direction.Left)
+ {
+ for (var c = newCol + 1; c < m_Board.Columns; c++)
+ {
+ if (m_Board.GetCellType(newRow, c) == newValue)
+ {
+ matchedRight++;
+ scoreRight += pointsByType[m_Board.GetSpecialType(newRow, c)];
+ }
+ else
+ break;
+ }
+ }
+
+ if (incomingDirection != Direction.Down)
+ {
+ for (var r = newRow + 1; r < m_Board.Rows; r++)
+ {
+ if (m_Board.GetCellType(r, newCol) == newValue)
+ {
+ matchedUp++;
+ scoreUp += pointsByType[m_Board.GetSpecialType(r, newCol)];
+ }
+ else
+ break;
+ }
+ }
+
+ if (incomingDirection != Direction.Up)
+ {
+ for (var r = newRow - 1; r >= 0; r--)
+ {
+ if (m_Board.GetCellType(r, newCol) == newValue)
+ {
+ matchedDown++;
+ scoreDown += pointsByType[m_Board.GetSpecialType(r, newCol)];
+ }
+ else
+ break;
+ }
+ }
+
+ if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2))
+ {
+ // It's a match. Start from counting the piece being moved
+ var totalScore = pointsByType[newSpecial];
+ if (matchedUp + matchedDown >= 2)
+ {
+ totalScore += scoreUp + scoreDown;
+ }
+
+ if (matchedLeft + matchedRight >= 2)
+ {
+ totalScore += scoreLeft + scoreRight;
+ }
+ return totalScore;
+ }
+
+ return 0;
+ }
+ }
+
+}
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs.meta b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs.meta
new file mode 100644
index 0000000000..acaf52892c
--- /dev/null
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 9e6fe1a020a04421ab828be4543a655c
+timeCreated: 1610665874
\ No newline at end of file
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs
new file mode 100644
index 0000000000..8f32bf1755
--- /dev/null
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs
@@ -0,0 +1,18 @@
+using Unity.MLAgents;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Extensions.Match3;
+
+namespace Unity.MLAgentsExamples
+{
+ public class Match3ExampleActuatorComponent : Match3ActuatorComponent
+ {
+ ///
+ public override IActuator CreateActuator()
+ {
+ var board = GetComponent();
+ var agent = GetComponentInParent();
+ var seed = RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed + 1;
+ return new Match3ExampleActuator(board, ForceHeuristic, agent, ActuatorName, seed);
+ }
+ }
+}
diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs.meta b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs.meta
new file mode 100644
index 0000000000..e0569da775
--- /dev/null
+++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: b17adcc6c9b241da903aa134f2dac930
+timeCreated: 1610665885
\ No newline at end of file
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
index 67ff107344..bfa2d267d1 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs
@@ -9,12 +9,12 @@ namespace Unity.MLAgents.Extensions.Match3
/// Actuator for a Match3 game. It translates valid moves (defined by AbstractBoard.IsMoveValid())
/// in action masks, and applies the action to the board via AbstractBoard.MakeMove().
///
- public class Match3Actuator : IActuator
+ public class Match3Actuator : IActuator, IHeuristicProvider
{
- private AbstractBoard m_Board;
+ protected AbstractBoard m_Board;
+ protected System.Random m_Random;
private ActionSpec m_ActionSpec;
private bool m_ForceHeuristic;
- private System.Random m_Random;
private Agent m_Agent;
private int m_Rows;
@@ -27,9 +27,14 @@ public class Match3Actuator : IActuator
///
/// Whether the inference action should be ignored and the Agent's Heuristic
/// should be called. This should only be used for generating comparison stats of the Heuristic.
+ /// The seed used to initialize .
///
///
- public Match3Actuator(AbstractBoard board, bool forceHeuristic, Agent agent, string name)
+ public Match3Actuator(AbstractBoard board,
+ bool forceHeuristic,
+ int seed,
+ Agent agent,
+ string name)
{
m_Board = board;
m_Rows = board.Rows;
@@ -42,6 +47,7 @@ public Match3Actuator(AbstractBoard board, bool forceHeuristic, Agent agent, str
var numMoves = Move.NumPotentialMoves(m_Board.Rows, m_Board.Columns);
m_ActionSpec = ActionSpec.MakeDiscrete(numMoves);
+ m_Random = new System.Random(seed);
}
///
@@ -52,7 +58,7 @@ public void OnActionReceived(ActionBuffers actions)
{
if (m_ForceHeuristic)
{
- m_Agent.Heuristic(actions);
+ Heuristic(actions);
}
var moveIndex = actions.DiscreteActions[0];
@@ -116,5 +122,63 @@ IEnumerable InvalidMoveIndices()
yield return move.MoveIndex;
}
}
+
+ public void Heuristic(in ActionBuffers actionsOut)
+ {
+ var discreteActions = actionsOut.DiscreteActions;
+ discreteActions[0] = GreedyMove();
+ }
+
+
+ protected int GreedyMove()
+ {
+
+ var bestMoveIndex = 0;
+ var bestMovePoints = -1;
+ var numMovesAtCurrentScore = 0;
+
+ foreach (var move in m_Board.ValidMoves())
+ {
+ var movePoints = EvalMovePoints(move);
+ if (movePoints < bestMovePoints)
+ {
+ // Worse, skip
+ continue;
+ }
+
+ if (movePoints > bestMovePoints)
+ {
+ // Better, keep
+ bestMovePoints = movePoints;
+ bestMoveIndex = move.MoveIndex;
+ numMovesAtCurrentScore = 1;
+ }
+ else
+ {
+ // Tied for best - use reservoir sampling to make sure we select from equal moves uniformly.
+ // See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
+ numMovesAtCurrentScore++;
+ var randVal = m_Random.Next(0, numMovesAtCurrentScore);
+ if (randVal == 0)
+ {
+ // Keep the new one
+ bestMoveIndex = move.MoveIndex;
+ }
+ }
+ }
+
+ return bestMoveIndex;
+ }
+
+ ///
+ /// Method to be overridden when evaluating how many points a specific move will generate.
+ ///
+ /// The move to evaluate.
+ /// The number of points the move generates.
+ protected virtual int EvalMovePoints(Move move)
+ {
+ return 1;
+ }
+
}
}
diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs
index 9c48d9cf20..7f14fb4ccd 100644
--- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs
@@ -5,7 +5,7 @@
namespace Unity.MLAgents.Extensions.Match3
{
///
- /// Actuator component for a Match 3 game. Generates a Match3Actuator at runtime.
+ /// Actuator component for a Match3 game. Generates a Match3Actuator at runtime.
///
public class Match3ActuatorComponent : ActuatorComponent
{
@@ -15,6 +15,11 @@ public class Match3ActuatorComponent : ActuatorComponent
///
public string ActuatorName = "Match3 Actuator";
+ ///
+ /// A random seed used to generate a board, if needed.
+ ///
+ public int RandomSeed = -1;
+
///
/// Force using the Agent's Heuristic() method to decide the action. This should only be used in testing.
///
@@ -27,7 +32,8 @@ public override IActuator CreateActuator()
{
var board = GetComponent();
var agent = GetComponentInParent();
- return new Match3Actuator(board, ForceHeuristic, agent, ActuatorName);
+ var seed = RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed + 1;
+ return new Match3Actuator(board, ForceHeuristic, seed, agent, ActuatorName);
}
///
diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 5d05c80f4f..24e33d3a17 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -20,6 +20,10 @@ will result in the values being summed (instead of averaged) when written to
TensorBoard. Thanks to @brccabral for the contribution! (#4816)
- The upper limit for the time scale (by setting the `--time-scale` paramater in mlagents-learn) was
removed when training with a player. The Editor still requires it to be clamped to 100. (#4867)
+- Added the IHeuristicProvider interface to allow IActuators as well as Agents to implement the Heuristic function to generate actions.
+ Updated the Basic example and the Match3 Example to use Actuators.
+ Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849)
+
#### ml-agents / ml-agents-envs / gym-unity (Python)
diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
index aa72c4e423..2e4538b93e 100644
--- a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs
@@ -204,6 +204,49 @@ public void WriteActionMask()
}
}
+ ///
+ /// Iterates through all of the IActuators in this list and calls their
+ /// method on them, if implemented, with the appropriate
+ /// s depending on their .
+ ///
+ public void ApplyHeuristic(in ActionBuffers actionBuffersOut)
+ {
+ var continuousStart = 0;
+ var discreteStart = 0;
+ for (var i = 0; i < m_Actuators.Count; i++)
+ {
+ var actuator = m_Actuators[i];
+ var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
+ var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;
+
+ if (numContinuousActions == 0 && numDiscreteActions == 0)
+ {
+ continue;
+ }
+
+ var continuousActions = ActionSegment.Empty;
+ if (numContinuousActions > 0)
+ {
+ continuousActions = new ActionSegment(actionBuffersOut.ContinuousActions.Array,
+ continuousStart,
+ numContinuousActions);
+ }
+
+ var discreteActions = ActionSegment.Empty;
+ if (numDiscreteActions > 0)
+ {
+ discreteActions = new ActionSegment(actionBuffersOut.DiscreteActions.Array,
+ discreteStart,
+ numDiscreteActions);
+ }
+
+ var heuristic = actuator as IHeuristicProvider;
+ heuristic?.Heuristic(new ActionBuffers(continuousActions, discreteActions));
+ continuousStart += numContinuousActions;
+ discreteStart += numDiscreteActions;
+ }
+ }
+
///
/// Iterates through all of the IActuators in this list and calls their
/// method on them with the appropriate
diff --git a/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs
new file mode 100644
index 0000000000..b992361c83
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs
@@ -0,0 +1,18 @@
+namespace Unity.MLAgents.Actuators
+{
+ ///
+ /// Interface that allows objects to fill out an data structure for controlling
+ /// behavior of Agents or Actuators.
+ ///
+ public interface IHeuristicProvider
+ {
+ ///
+ /// Method called on objects which are expected to fill out the data structure.
+ /// Objects that implement this interface should be careful to be consistent in the placement of their actions
+ /// in the data structure.
+ ///
+ /// The data structure to be filled by the
+ /// object implementing this interface.
+ void Heuristic(in ActionBuffers actionBuffersOut);
+ }
+}
diff --git a/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta
new file mode 100644
index 0000000000..ca8338a072
--- /dev/null
+++ b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: be90ffb28f39444a8fb02dfd4a82870c
+timeCreated: 1610057456
\ No newline at end of file
diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
index e4bfad2c65..1b18300f56 100644
--- a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
+++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
@@ -1,11 +1,12 @@
namespace Unity.MLAgents.Actuators
{
///
- /// IActuator implementation that forwards to an .
+ /// IActuator implementation that forwards calls to an and an .
///
- internal class VectorActuator : IActuator
+ internal class VectorActuator : IActuator, IHeuristicProvider
{
IActionReceiver m_ActionReceiver;
+ IHeuristicProvider m_HeuristicProvider;
ActionBuffers m_ActionBuffers;
internal ActionBuffers ActionBuffers
@@ -14,17 +15,34 @@ internal ActionBuffers ActionBuffers
private set => m_ActionBuffers = value;
}
+ ///
+ /// Create a VectorActuator that forwards to the provided IActionReceiver.
+ ///
+ /// The used for OnActionReceived and WriteDiscreteActionMask.
+ /// If this parameter also implements it will be cast and used to forward calls to
+ /// .
+ ///
+ ///
+ public VectorActuator(IActionReceiver actionReceiver,
+ ActionSpec actionSpec,
+ string name = "VectorActuator")
+ : this(actionReceiver, actionReceiver as IHeuristicProvider, actionSpec, name) { }
+
///
/// Create a VectorActuator that forwards to the provided IActionReceiver.
///
/// The used for OnActionReceived and WriteDiscreteActionMask.
+ /// The used to fill the
+ /// for Heuristic Policies.
///
///
public VectorActuator(IActionReceiver actionReceiver,
+ IHeuristicProvider heuristicProvider,
ActionSpec actionSpec,
string name = "VectorActuator")
{
m_ActionReceiver = actionReceiver;
+ m_HeuristicProvider = heuristicProvider;
ActionSpec = actionSpec;
string suffix;
if (actionSpec.NumContinuousActions == 0)
@@ -55,6 +73,11 @@ public void OnActionReceived(ActionBuffers actionBuffers)
m_ActionReceiver.OnActionReceived(ActionBuffers);
}
+ public void Heuristic(in ActionBuffers actionBuffersOut)
+ {
+ m_HeuristicProvider?.Heuristic(actionBuffersOut);
+ }
+
///
public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index 347dc16ec8..29c0dffe73 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -166,7 +166,7 @@ public void CopyActions(ActionBuffers actionBuffers)
"docs/Learning-Environment-Design-Agents.md")]
[Serializable]
[RequireComponent(typeof(BehaviorParameters))]
- public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver
+ public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver, IHeuristicProvider
{
IPolicy m_Brain;
BehaviorParameters m_PolicyFactory;
@@ -312,6 +312,11 @@ internal struct AgentParameters
///
float[] m_LegacyActionCache;
+ ///
+ /// This is used to avoid allocation of a float array during legacy calls to Heuristic.
+ ///
+ float[] m_LegacyHeuristicCache;
+
///
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html
@@ -429,7 +434,7 @@ public void LazyInitialize()
InitializeActuators();
}
- m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic);
+ m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager);
ResetData();
Initialize();
@@ -606,7 +611,7 @@ internal void ReloadPolicy()
return;
}
m_Brain?.Dispose();
- m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic);
+ m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager);
}
///
@@ -826,11 +831,11 @@ void ResetData()
public virtual void Initialize() { }
///
- /// Implement `Heuristic()` to choose an action for this agent using a custom heuristic.
+ /// Implement to choose an action for this agent using a custom heuristic.
///
///
/// Implement this function to provide custom decision making logic or to support manual
- /// control of an agent using keyboard, mouse, or game controller input.
+ /// control of an agent using keyboard, mouse, game controller input, or a script.
///
/// Your heuristic implementation can use any decision making logic you specify. Assign decision
/// values to the and
@@ -894,22 +899,21 @@ public virtual void Heuristic(in ActionBuffers actionsOut)
switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType)
{
case SpaceType.Continuous:
- Heuristic(actionsOut.ContinuousActions.Array);
+ Heuristic(m_LegacyHeuristicCache);
+ Array.Copy(m_LegacyHeuristicCache, actionsOut.ContinuousActions.Array, m_LegacyActionCache.Length);
actionsOut.DiscreteActions.Clear();
break;
case SpaceType.Discrete:
- var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x);
- Heuristic(convertedOut);
+ Heuristic(m_LegacyHeuristicCache);
var discreteActionSegment = actionsOut.DiscreteActions;
for (var i = 0; i < actionsOut.DiscreteActions.Length; i++)
{
- discreteActionSegment[i] = (int)convertedOut[i];
+ discreteActionSegment[i] = (int)m_LegacyHeuristicCache[i];
}
actionsOut.ContinuousActions.Clear();
break;
}
#pragma warning restore CS0618
-
}
///
@@ -993,9 +997,10 @@ void InitializeActuators()
// Support legacy OnActionReceived
// TODO don't set this up if the sizes are 0?
var param = m_PolicyFactory.BrainParameters;
- m_VectorActuator = new VectorActuator(this, param.ActionSpec);
+ m_VectorActuator = new VectorActuator(this, this, param.ActionSpec);
m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1);
m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions()];
+ m_LegacyHeuristicCache = new float[m_VectorActuator.TotalNumberOfActions()];
m_ActuatorManager.Add(m_VectorActuator);
@@ -1178,7 +1183,7 @@ public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
/// three values in ActionBuffers.ContinuousActions array to use as the force components.
/// During training, the agent's policy learns to set those particular elements of
/// the array to maximize the training rewards the agent receives. (Of course,
- /// if you implement a function, it must use the same
+ /// if you implement a function, it must use the same
/// elements of the action array for the same purpose since there is no learning
/// involved.)
///
@@ -1241,11 +1246,16 @@ public virtual void OnActionReceived(ActionBuffers actions)
if (!actions.ContinuousActions.IsEmpty())
{
- m_LegacyActionCache = actions.ContinuousActions.Array;
+ Array.Copy(actions.ContinuousActions.Array,
+ m_LegacyActionCache,
+ actionSpec.NumContinuousActions);
}
else
{
- m_LegacyActionCache = Array.ConvertAll(actions.DiscreteActions.Array, x => (float)x);
+ for (var i = 0; i < m_LegacyActionCache.Length; i++)
+ {
+ m_LegacyActionCache[i] = (float)actions.DiscreteActions[i];
+ }
}
// Disable deprecation warnings so we can call the legacy overload.
#pragma warning disable CS0618
diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
index 100220e320..49e00918c1 100644
--- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
+++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
@@ -196,12 +196,12 @@ public string FullyQualifiedBehaviorName
get { return m_BehaviorName + "?team=" + TeamId; }
}
- internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGenerator heuristic)
+ internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorManager)
{
switch (m_BehaviorType)
{
case BehaviorType.HeuristicOnly:
- return new HeuristicPolicy(heuristic, actionSpec);
+ return new HeuristicPolicy(actuatorManager, actionSpec);
case BehaviorType.InferenceOnly:
{
if (m_Model == null)
@@ -225,10 +225,10 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGen
}
else
{
- return new HeuristicPolicy(heuristic, actionSpec);
+ return new HeuristicPolicy(actuatorManager, actionSpec);
}
default:
- return new HeuristicPolicy(heuristic, actionSpec);
+ return new HeuristicPolicy(actuatorManager, actionSpec);
}
}
@@ -241,6 +241,5 @@ internal void UpdateAgentPolicy()
}
agent.ReloadPolicy();
}
-
}
}
diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
index 217f226857..7163a8face 100644
--- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
@@ -7,14 +7,13 @@
namespace Unity.MLAgents.Policies
{
///
- /// The Heuristic Policy uses a hards coded Heuristic method
+ /// The Heuristic Policy uses a hard-coded Heuristic method
/// to take decisions each time the RequestDecision method is
/// called.
///
internal class HeuristicPolicy : IPolicy
{
- public delegate void ActionGenerator(in ActionBuffers actionBuffers);
- ActionGenerator m_Heuristic;
+ ActuatorManager m_ActuatorManager;
ActionBuffers m_ActionBuffers;
bool m_Done;
bool m_DecisionRequested;
@@ -24,9 +23,9 @@ internal class HeuristicPolicy : IPolicy
///
- public HeuristicPolicy(ActionGenerator heuristic, ActionSpec actionSpec)
+ public HeuristicPolicy(ActuatorManager actuatorManager, ActionSpec actionSpec)
{
- m_Heuristic = heuristic;
+ m_ActuatorManager = actuatorManager;
var numContinuousActions = actionSpec.NumContinuousActions;
var numDiscreteActions = actionSpec.NumDiscreteActions;
var continuousDecision = new ActionSegment(new float[numContinuousActions], 0, numContinuousActions);
@@ -47,7 +46,7 @@ public ref readonly ActionBuffers DecideAction()
{
if (!m_Done && m_DecisionRequested)
{
- m_Heuristic.Invoke(m_ActionBuffers);
+ m_ActuatorManager.ApplyHeuristic(m_ActionBuffers);
}
m_DecisionRequested = false;
return ref m_ActionBuffers;
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
index e10cea84da..e095557ced 100644
--- a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs
@@ -303,5 +303,23 @@ public void TestWriteDiscreteActionMask()
manager.WriteActionMask();
Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask()));
}
+
+ [Test]
+ public void TestHeuristic()
+ {
+ var manager = new ActuatorManager(2);
+ var va1 = new TestActuator(ActionSpec.MakeDiscrete(1, 2, 3), "name");
+ var va2 = new TestActuator(ActionSpec.MakeDiscrete(3, 2, 1, 8), "name1");
+ manager.Add(va1);
+ manager.Add(va2);
+
+ var actionBuf = new ActionBuffers(Array.Empty(), new[] { 0, 0, 0, 0, 0, 0, 0 });
+ manager.ApplyHeuristic(actionBuf);
+
+ Assert.IsTrue(va1.m_HeuristicCalled);
+ Assert.AreEqual(va1.m_DiscreteBufferSize, 3);
+ Assert.IsTrue(va2.m_HeuristicCalled);
+ Assert.AreEqual(va2.m_DiscreteBufferSize, 4);
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
index d0d0c4d2e9..649a643320 100644
--- a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs
@@ -1,10 +1,13 @@
using Unity.MLAgents.Actuators;
namespace Unity.MLAgents.Tests.Actuators
{
- internal class TestActuator : IActuator
+ internal class TestActuator : IActuator, IHeuristicProvider
{
public ActionBuffers LastActionBuffer;
public int[][] Masks;
+ public bool m_HeuristicCalled;
+ public int m_DiscreteBufferSize;
+
public TestActuator(ActionSpec actuatorSpace, string name)
{
ActionSpec = actuatorSpace;
@@ -32,5 +35,11 @@ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
public void ResetData()
{
}
+
+ public void Heuristic(in ActionBuffers actionBuffersOut)
+ {
+ m_HeuristicCalled = true;
+ m_DiscreteBufferSize = actionBuffersOut.DiscreteActions.Length;
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
index 1b9625463c..2b3dcabfef 100644
--- a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs
@@ -1,3 +1,4 @@
+using System;
using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
@@ -9,12 +10,13 @@ namespace Unity.MLAgents.Tests.Actuators
[TestFixture]
public class VectorActuatorTests
{
- class TestActionReceiver : IActionReceiver
+ class TestActionReceiver : IActionReceiver, IHeuristicProvider
{
public ActionBuffers LastActionBuffers;
public int Branch;
public IList Mask;
public ActionSpec ActionSpec { get; }
+ public bool HeuristicCalled;
public void OnActionReceived(ActionBuffers actionBuffers)
{
@@ -25,6 +27,11 @@ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
{
actionMask.WriteMask(Branch, Mask);
}
+
+ public void Heuristic(in ActionBuffers actionBuffersOut)
+ {
+ HeuristicCalled = true;
+ }
}
[Test]
@@ -93,5 +100,15 @@ public void TestWriteDiscreteActionMask()
Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask()));
}
+
+ [Test]
+ public void TestHeuristic()
+ {
+ var ar = new TestActionReceiver();
+ var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
+
+ va.Heuristic(new ActionBuffers(Array.Empty(), va.ActionSpec.BranchSizes));
+ Assert.IsTrue(ar.HeuristicCalled);
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
index f4d46b7c04..3d79558935 100644
--- a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
+++ b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs
@@ -6,9 +6,9 @@
namespace Unity.MLAgents.Tests
{
[TestFixture]
- public class BehaviorParameterTests
+ public class BehaviorParameterTests : IHeuristicProvider
{
- static void DummyHeuristic(in ActionBuffers actionsOut)
+ public void Heuristic(in ActionBuffers actionsOut)
{
// No-op
}
@@ -23,7 +23,7 @@ public void TestNoModelInferenceOnlyThrows()
Assert.Throws(() =>
{
- bp.GeneratePolicy(actionSpec, DummyHeuristic);
+ bp.GeneratePolicy(actionSpec, new ActuatorManager());
});
}
}
diff --git a/docs/Migrating.md b/docs/Migrating.md
index afdc1ce40a..dd0d6f77e2 100644
--- a/docs/Migrating.md
+++ b/docs/Migrating.md
@@ -12,6 +12,14 @@ double-check that the versions are in the same. The versions can be found in
- `UnityEnvironment.API_VERSION` in environment.py
([example](https://github.com/Unity-Technologies/ml-agents/blob/b255661084cb8f701c716b040693069a3fb9a257/ml-agents-envs/mlagents/envs/environment.py#L45))
+
+# Migrating
+## Migrating to Release 13
 +### Implementing IHeuristicProvider in your IActuator implementations
+ - If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator
+handle the generation of actions when an Agent is running in heuristic mode.
+
+
# Migrating
## Migrating to Release 11
### Agent virtual method deprecation