diff --git a/.yamato/gym-interface-test.yml b/.yamato/gym-interface-test.yml index 8ecb5f9347..01c975bca2 100644 --- a/.yamato/gym-interface-test.yml +++ b/.yamato/gym-interface-test.yml @@ -13,6 +13,7 @@ test_gym_interface_{{ editor.version }}: - | sudo apt-get update && sudo apt-get install -y python3-venv python3 -m venv venv && source venv/bin/activate + python -m pip install wheel --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -m pip install pyyaml --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple python -u -m ml-agents.tests.yamato.setup_venv python ml-agents/tests/yamato/scripts/run_gym.py --env=artifacts/testPlayer-Basic diff --git a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs index 1caea48397..6369271199 100644 --- a/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs +++ b/Project/Assets/ML-Agents/Examples/Basic/Scripts/BasicActuatorComponent.cs @@ -1,5 +1,6 @@ using System; using Unity.MLAgents.Actuators; +using UnityEngine; namespace Unity.MLAgentsExamples { @@ -30,7 +31,7 @@ public override ActionSpec ActionSpec /// /// Simple actuator that converts the action into a {-1, 0, 1} direction /// - public class BasicActuator : IActuator + public class BasicActuator : IActuator, IHeuristicProvider { public BasicController basicController; ActionSpec m_ActionSpec; @@ -76,6 +77,19 @@ public void OnActionReceived(ActionBuffers actionBuffers) basicController.MoveDirection(direction); } + public void Heuristic(in ActionBuffers actionBuffersOut) + { + var direction = Input.GetAxis("Horizontal"); + var discreteActions = actionBuffersOut.DiscreteActions; + if (Mathf.Approximately(direction, 0.0f)) + { + discreteActions[0] = 0; + return; + } + var sign = Math.Sign(direction); + discreteActions[0] = sign < 0 ? 
1 : 2; + } + public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab index ec8b60c12c..adf6bbe161 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab +++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3Heuristic.prefab @@ -13,8 +13,8 @@ GameObject: - component: {fileID: 3508723250470608012} - component: {fileID: 3508723250470608011} - component: {fileID: 3508723250470608009} - - component: {fileID: 3508723250470608013} - component: {fileID: 3508723250470608014} + - component: {fileID: 2112317463290853299} m_Layer: 0 m_Name: Match3 Agent m_TagString: Untagged @@ -51,9 +51,13 @@ MonoBehaviour: m_BrainParameters: VectorObservationSize: 0 NumStackedVectorObservations: 1 + m_ActionSpec: + m_NumContinuousActions: 0 + BranchSizes: VectorActionSize: VectorActionDescriptions: [] VectorActionSpaceType: 0 + hasUpgradedBrainParametersWithActionSpec: 1 m_Model: {fileID: 11400000, guid: c34da50737a3c4a50918002b20b2b927, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -81,7 +85,6 @@ MonoBehaviour: Board: {fileID: 0} MoveTime: 0.25 MaxMoves: 500 - HeuristicQuality: 0 --- !u!114 &3508723250470608011 MonoBehaviour: m_ObjectHideFlags: 0 @@ -96,7 +99,6 @@ MonoBehaviour: m_EditorClassIdentifier: DebugMoveIndex: -1 CubeSpacing: 1.25 - Board: {fileID: 0} TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e, type: 3} --- !u!114 &3508723250470608009 @@ -119,7 +121,7 @@ MonoBehaviour: BasicCellPoints: 1 SpecialCell1Points: 2 SpecialCell2Points: 3 ---- !u!114 &3508723250470608013 +--- !u!114 &3508723250470608014 MonoBehaviour: m_ObjectHideFlags: 0 m_CorrespondingSourceObject: {fileID: 0} @@ -128,12 +130,12 @@ MonoBehaviour: m_GameObject: {fileID: 3508723250470608007} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 
08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3} + m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} m_Name: m_EditorClassIdentifier: - ActuatorName: Match3 Actuator - ForceHeuristic: 1 ---- !u!114 &3508723250470608014 + SensorName: Match3 Sensor + ObservationType: 0 +--- !u!114 &2112317463290853299 MonoBehaviour: m_ObjectHideFlags: 0 m_CorrespondingSourceObject: {fileID: 0} @@ -142,11 +144,12 @@ MonoBehaviour: m_GameObject: {fileID: 3508723250470608007} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} + m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3} m_Name: m_EditorClassIdentifier: - SensorName: Match3 Sensor - ObservationType: 0 + ActuatorName: Match3 Actuator + ForceHeuristic: 1 + HeuristicQuality: 0 --- !u!1 &3508723250774301855 GameObject: m_ObjectHideFlags: 0 diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab index 3b4b66024f..4177d3c687 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab +++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VectorObs.prefab @@ -44,8 +44,8 @@ GameObject: - component: {fileID: 2118285884327540682} - component: {fileID: 2118285884327540685} - component: {fileID: 2118285884327540687} - - component: {fileID: 2118285884327540683} - component: {fileID: 2118285884327540680} + - component: {fileID: 3357012711826686276} m_Layer: 0 m_Name: Match3 Agent m_TagString: Untagged @@ -82,9 +82,13 @@ MonoBehaviour: m_BrainParameters: VectorObservationSize: 0 NumStackedVectorObservations: 1 + m_ActionSpec: + m_NumContinuousActions: 0 + BranchSizes: VectorActionSize: VectorActionDescriptions: [] VectorActionSpaceType: 0 + hasUpgradedBrainParametersWithActionSpec: 1 m_Model: {fileID: 11400000, guid: 9e89b8e81974148d3b7213530d00589d, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ 
-112,7 +116,6 @@ MonoBehaviour: Board: {fileID: 0} MoveTime: 0.25 MaxMoves: 500 - HeuristicQuality: 0 --- !u!114 &2118285884327540685 MonoBehaviour: m_ObjectHideFlags: 0 @@ -127,7 +130,6 @@ MonoBehaviour: m_EditorClassIdentifier: DebugMoveIndex: -1 CubeSpacing: 1.25 - Board: {fileID: 0} TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e, type: 3} --- !u!114 &2118285884327540687 @@ -150,7 +152,7 @@ MonoBehaviour: BasicCellPoints: 1 SpecialCell1Points: 2 SpecialCell2Points: 3 ---- !u!114 &2118285884327540683 +--- !u!114 &2118285884327540680 MonoBehaviour: m_ObjectHideFlags: 0 m_CorrespondingSourceObject: {fileID: 0} @@ -159,12 +161,12 @@ MonoBehaviour: m_GameObject: {fileID: 2118285884327540673} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3} + m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} m_Name: m_EditorClassIdentifier: - ActuatorName: Match3 Actuator - ForceHeuristic: 0 ---- !u!114 &2118285884327540680 + SensorName: Match3 Sensor + ObservationType: 0 +--- !u!114 &3357012711826686276 MonoBehaviour: m_ObjectHideFlags: 0 m_CorrespondingSourceObject: {fileID: 0} @@ -173,8 +175,9 @@ MonoBehaviour: m_GameObject: {fileID: 2118285884327540673} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} + m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3} m_Name: m_EditorClassIdentifier: - SensorName: Match3 Sensor - ObservationType: 0 + ActuatorName: Match3 Actuator + ForceHeuristic: 0 + HeuristicQuality: 0 diff --git a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab index 28d219acd3..2c78f72948 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab +++ b/Project/Assets/ML-Agents/Examples/Match3/Prefabs/Match3VisualObs.prefab @@ -44,8 +44,8 
@@ GameObject: - component: {fileID: 3019509692332007781} - component: {fileID: 3019509692332007778} - component: {fileID: 3019509692332007776} - - component: {fileID: 3019509692332007780} - component: {fileID: 3019509692332007783} + - component: {fileID: 8270768986451624427} m_Layer: 0 m_Name: Match3 Agent m_TagString: Untagged @@ -82,9 +82,13 @@ MonoBehaviour: m_BrainParameters: VectorObservationSize: 0 NumStackedVectorObservations: 1 + m_ActionSpec: + m_NumContinuousActions: 0 + BranchSizes: VectorActionSize: VectorActionDescriptions: [] VectorActionSpaceType: 0 + hasUpgradedBrainParametersWithActionSpec: 1 m_Model: {fileID: 11400000, guid: 48d14da88fea74d0693c691c6e3f2e34, type: 3} m_InferenceDevice: 0 m_BehaviorType: 0 @@ -112,7 +116,6 @@ MonoBehaviour: Board: {fileID: 0} MoveTime: 0.25 MaxMoves: 500 - HeuristicQuality: 0 --- !u!114 &3019509692332007778 MonoBehaviour: m_ObjectHideFlags: 0 @@ -127,7 +130,6 @@ MonoBehaviour: m_EditorClassIdentifier: DebugMoveIndex: -1 CubeSpacing: 1.25 - Board: {fileID: 0} TilePrefab: {fileID: 4007900521885639951, guid: faee4e805953b49e688bd00b45c55f2e, type: 3} --- !u!114 &3019509692332007776 @@ -150,7 +152,7 @@ MonoBehaviour: BasicCellPoints: 1 SpecialCell1Points: 2 SpecialCell2Points: 3 ---- !u!114 &3019509692332007780 +--- !u!114 &3019509692332007783 MonoBehaviour: m_ObjectHideFlags: 0 m_CorrespondingSourceObject: {fileID: 0} @@ -159,12 +161,12 @@ MonoBehaviour: m_GameObject: {fileID: 3019509692332007790} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 08e4b0da54cb4d56bfcbae22dd49ab8d, type: 3} + m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} m_Name: m_EditorClassIdentifier: - ActuatorName: Match3 Actuator - ForceHeuristic: 0 ---- !u!114 &3019509692332007783 + SensorName: Match3 Sensor + ObservationType: 2 +--- !u!114 &8270768986451624427 MonoBehaviour: m_ObjectHideFlags: 0 m_CorrespondingSourceObject: {fileID: 0} @@ -173,8 +175,9 @@ MonoBehaviour: m_GameObject: {fileID: 
3019509692332007790} m_Enabled: 1 m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: 530d2f105aa145bd8a00e021bdd925fd, type: 3} + m_Script: {fileID: 11500000, guid: b17adcc6c9b241da903aa134f2dac930, type: 3} m_Name: m_EditorClassIdentifier: - SensorName: Match3 Sensor - ObservationType: 2 + ActuatorName: Match3 Actuator + ForceHeuristic: 0 + HeuristicQuality: 0 diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity index 8383c776d2..f543515eb8 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity +++ b/Project/Assets/ML-Agents/Examples/Match3/Scenes/Match3.unity @@ -690,6 +690,11 @@ PrefabInstance: m_Modification: m_TransformParent: {fileID: 0} m_Modifications: + - target: {fileID: 2112317463290853299, guid: 2fafdcd0587684641b03b11f04454f1b, + type: 3} + propertyPath: HeuristicQuality + value: 1 + objectReference: {fileID: 0} - target: {fileID: 3508723250470608011, guid: 2fafdcd0587684641b03b11f04454f1b, type: 3} propertyPath: cubeSpacing @@ -1385,6 +1390,11 @@ PrefabInstance: m_Modification: m_TransformParent: {fileID: 0} m_Modifications: + - target: {fileID: 2112317463290853299, guid: 2fafdcd0587684641b03b11f04454f1b, + type: 3} + propertyPath: HeuristicQuality + value: 1 + objectReference: {fileID: 0} - target: {fileID: 3508723250470608011, guid: 2fafdcd0587684641b03b11f04454f1b, type: 3} propertyPath: cubeSpacing diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs index a890f12c17..7c7192b8f7 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Agent.cs @@ -63,20 +63,6 @@ enum State WaitForMove = 4, } - public enum HeuristicQuality - { - /// - /// The heuristic will pick any valid move at random. 
- /// - RandomValidMove, - - /// - /// The heuristic will pick the move that scores the most points. - /// This only looks at the immediate move, and doesn't consider where cells will fall. - /// - Greedy - } - public class Match3Agent : Agent { [HideInInspector] @@ -86,20 +72,14 @@ public class Match3Agent : Agent public int MaxMoves = 500; - public HeuristicQuality HeuristicQuality = HeuristicQuality.RandomValidMove; - State m_CurrentState = State.WaitForMove; float m_TimeUntilMove; private int m_MovesMade; - private System.Random m_Random; private const float k_RewardMultiplier = 0.01f; - void Awake() { Board = GetComponent(); - var seed = Board.RandomSeed == -1 ? gameObject.GetInstanceID() : Board.RandomSeed + 1; - m_Random = new System.Random(seed); } public override void OnEpisodeBegin() @@ -222,152 +202,6 @@ bool HasValidMoves() return false; } - public override void Heuristic(in ActionBuffers actionsOut) - { - var discreteActions = actionsOut.DiscreteActions; - discreteActions[0] = GreedyMove(); - } - - int GreedyMove() - { - var pointsByType = new[] { Board.BasicCellPoints, Board.SpecialCell1Points, Board.SpecialCell2Points }; - - var bestMoveIndex = 0; - var bestMovePoints = -1; - var numMovesAtCurrentScore = 0; - - foreach (var move in Board.ValidMoves()) - { - var movePoints = HeuristicQuality == HeuristicQuality.Greedy ? EvalMovePoints(move, pointsByType) : 1; - if (movePoints < bestMovePoints) - { - // Worse, skip - continue; - } - - if (movePoints > bestMovePoints) - { - // Better, keep - bestMovePoints = movePoints; - bestMoveIndex = move.MoveIndex; - numMovesAtCurrentScore = 1; - } - else - { - // Tied for best - use reservoir sampling to make sure we select from equal moves uniformly. 
- // See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm - numMovesAtCurrentScore++; - var randVal = m_Random.Next(0, numMovesAtCurrentScore); - if (randVal == 0) - { - // Keep the new one - bestMoveIndex = move.MoveIndex; - } - } - } - - return bestMoveIndex; - } - - int EvalMovePoints(Move move, int[] pointsByType) - { - // Counts the expected points for making the move. - var moveVal = Board.GetCellType(move.Row, move.Column); - var moveSpecial = Board.GetSpecialType(move.Row, move.Column); - var (otherRow, otherCol) = move.OtherCell(); - var oppositeVal = Board.GetCellType(otherRow, otherCol); - var oppositeSpecial = Board.GetSpecialType(otherRow, otherCol); - - - int movePoints = EvalHalfMove( - otherRow, otherCol, moveVal, moveSpecial, move.Direction, pointsByType - ); - int otherPoints = EvalHalfMove( - move.Row, move.Column, oppositeVal, oppositeSpecial, move.OtherDirection(), pointsByType - ); - return movePoints + otherPoints; - } - - int EvalHalfMove(int newRow, int newCol, int newValue, int newSpecial, Direction incomingDirection, int[] pointsByType) - { - // This is a essentially a duplicate of AbstractBoard.CheckHalfMove but also counts the points for the move. 
- int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0; - int scoreLeft = 0, scoreRight = 0, scoreUp = 0, scoreDown = 0; - - if (incomingDirection != Direction.Right) - { - for (var c = newCol - 1; c >= 0; c--) - { - if (Board.GetCellType(newRow, c) == newValue) - { - matchedLeft++; - scoreLeft += pointsByType[Board.GetSpecialType(newRow, c)]; - } - else - break; - } - } - - if (incomingDirection != Direction.Left) - { - for (var c = newCol + 1; c < Board.Columns; c++) - { - if (Board.GetCellType(newRow, c) == newValue) - { - matchedRight++; - scoreRight += pointsByType[Board.GetSpecialType(newRow, c)]; - } - else - break; - } - } - - if (incomingDirection != Direction.Down) - { - for (var r = newRow + 1; r < Board.Rows; r++) - { - if (Board.GetCellType(r, newCol) == newValue) - { - matchedUp++; - scoreUp += pointsByType[Board.GetSpecialType(r, newCol)]; - } - else - break; - } - } - - if (incomingDirection != Direction.Up) - { - for (var r = newRow - 1; r >= 0; r--) - { - if (Board.GetCellType(r, newCol) == newValue) - { - matchedDown++; - scoreDown += pointsByType[Board.GetSpecialType(r, newCol)]; - } - else - break; - } - } - - if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2)) - { - // It's a match. 
Start from counting the piece being moved - var totalScore = pointsByType[newSpecial]; - if (matchedUp + matchedDown >= 2) - { - totalScore += scoreUp + scoreDown; - } - - if (matchedLeft + matchedRight >= 2) - { - totalScore += scoreLeft + scoreRight; - } - return totalScore; - } - - return 0; - } } } diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs index 859ec2b9e8..9aa1d59f57 100644 --- a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3Board.cs @@ -1,3 +1,4 @@ +using System; using Unity.MLAgents.Extensions.Match3; using UnityEngine; @@ -7,8 +8,6 @@ namespace Unity.MLAgentsExamples public class Match3Board : AbstractBoard { - public int RandomSeed = -1; - public const int k_EmptyCell = -1; [Tooltip("Points earned for clearing a basic cell (cube)")] public int BasicCellPoints = 1; @@ -19,6 +18,11 @@ public class Match3Board : AbstractBoard [Tooltip("Points earned for clearing an extra special cell (plus)")] public int SpecialCell2Points = 3; + /// + /// Seed to initialize the object. + /// + public int RandomSeed; + (int, int)[,] m_Cells; bool[,] m_Matched; @@ -29,8 +33,11 @@ void Awake() m_Cells = new (int, int)[Columns, Rows]; m_Matched = new bool[Columns, Rows]; - m_Random = new System.Random(RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed); + } + void Start() + { + m_Random = new System.Random(RandomSeed == -1 ? 
gameObject.GetInstanceID() : RandomSeed); InitRandom(); } diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs new file mode 100644 index 0000000000..c50dc9ff2a --- /dev/null +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs @@ -0,0 +1,121 @@ +using Unity.MLAgents; +using Unity.MLAgents.Extensions.Match3; + +namespace Unity.MLAgentsExamples +{ + public class Match3ExampleActuator : Match3Actuator + { + Match3Board Board => (Match3Board)m_Board; + + public Match3ExampleActuator(Match3Board board, + bool forceHeuristic, + Agent agent, + string name, + int seed + ) + : base(board, forceHeuristic, seed, agent, name) { } + + + protected override int EvalMovePoints(Move move) + { + var pointsByType = new[] { Board.BasicCellPoints, Board.SpecialCell1Points, Board.SpecialCell2Points }; + // Counts the expected points for making the move. + var moveVal = m_Board.GetCellType(move.Row, move.Column); + var moveSpecial = m_Board.GetSpecialType(move.Row, move.Column); + var (otherRow, otherCol) = move.OtherCell(); + var oppositeVal = m_Board.GetCellType(otherRow, otherCol); + var oppositeSpecial = m_Board.GetSpecialType(otherRow, otherCol); + + + int movePoints = EvalHalfMove( + otherRow, otherCol, moveVal, moveSpecial, move.Direction, pointsByType + ); + int otherPoints = EvalHalfMove( + move.Row, move.Column, oppositeVal, oppositeSpecial, move.OtherDirection(), pointsByType + ); + return movePoints + otherPoints; + } + + int EvalHalfMove(int newRow, int newCol, int newValue, int newSpecial, Direction incomingDirection, int[] pointsByType) + { + // This is a essentially a duplicate of AbstractBoard.CheckHalfMove but also counts the points for the move. 
+ int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0; + int scoreLeft = 0, scoreRight = 0, scoreUp = 0, scoreDown = 0; + + if (incomingDirection != Direction.Right) + { + for (var c = newCol - 1; c >= 0; c--) + { + if (m_Board.GetCellType(newRow, c) == newValue) + { + matchedLeft++; + scoreLeft += pointsByType[m_Board.GetSpecialType(newRow, c)]; + } + else + break; + } + } + + if (incomingDirection != Direction.Left) + { + for (var c = newCol + 1; c < m_Board.Columns; c++) + { + if (m_Board.GetCellType(newRow, c) == newValue) + { + matchedRight++; + scoreRight += pointsByType[m_Board.GetSpecialType(newRow, c)]; + } + else + break; + } + } + + if (incomingDirection != Direction.Down) + { + for (var r = newRow + 1; r < m_Board.Rows; r++) + { + if (m_Board.GetCellType(r, newCol) == newValue) + { + matchedUp++; + scoreUp += pointsByType[m_Board.GetSpecialType(r, newCol)]; + } + else + break; + } + } + + if (incomingDirection != Direction.Up) + { + for (var r = newRow - 1; r >= 0; r--) + { + if (m_Board.GetCellType(r, newCol) == newValue) + { + matchedDown++; + scoreDown += pointsByType[m_Board.GetSpecialType(r, newCol)]; + } + else + break; + } + } + + if ((matchedUp + matchedDown >= 2) || (matchedLeft + matchedRight >= 2)) + { + // It's a match. 
Start from counting the piece being moved + var totalScore = pointsByType[newSpecial]; + if (matchedUp + matchedDown >= 2) + { + totalScore += scoreUp + scoreDown; + } + + if (matchedLeft + matchedRight >= 2) + { + totalScore += scoreLeft + scoreRight; + } + return totalScore; + } + + return 0; + } + } + +} diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs.meta b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs.meta new file mode 100644 index 0000000000..acaf52892c --- /dev/null +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuator.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: 9e6fe1a020a04421ab828be4543a655c +timeCreated: 1610665874 \ No newline at end of file diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs new file mode 100644 index 0000000000..8f32bf1755 --- /dev/null +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs @@ -0,0 +1,18 @@ +using Unity.MLAgents; +using Unity.MLAgents.Actuators; +using Unity.MLAgents.Extensions.Match3; + +namespace Unity.MLAgentsExamples +{ + public class Match3ExampleActuatorComponent : Match3ActuatorComponent + { + /// + public override IActuator CreateActuator() + { + var board = GetComponent(); + var agent = GetComponentInParent(); + var seed = RandomSeed == -1 ? 
gameObject.GetInstanceID() : RandomSeed + 1; + return new Match3ExampleActuator(board, ForceHeuristic, agent, ActuatorName, seed); + } + } +} diff --git a/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs.meta b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs.meta new file mode 100644 index 0000000000..e0569da775 --- /dev/null +++ b/Project/Assets/ML-Agents/Examples/Match3/Scripts/Match3ExampleActuatorComponent.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: b17adcc6c9b241da903aa134f2dac930 +timeCreated: 1610665885 \ No newline at end of file diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs index 67ff107344..bfa2d267d1 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3Actuator.cs @@ -9,12 +9,12 @@ namespace Unity.MLAgents.Extensions.Match3 /// Actuator for a Match3 game. It translates valid moves (defined by AbstractBoard.IsMoveValid()) /// in action masks, and applies the action to the board via AbstractBoard.MakeMove(). /// - public class Match3Actuator : IActuator + public class Match3Actuator : IActuator, IHeuristicProvider { - private AbstractBoard m_Board; + protected AbstractBoard m_Board; + protected System.Random m_Random; private ActionSpec m_ActionSpec; private bool m_ForceHeuristic; - private System.Random m_Random; private Agent m_Agent; private int m_Rows; @@ -27,9 +27,14 @@ public class Match3Actuator : IActuator /// /// Whether the inference action should be ignored and the Agent's Heuristic /// should be called. This should only be used for generating comparison stats of the Heuristic. + /// The seed used to initialize . 
/// /// - public Match3Actuator(AbstractBoard board, bool forceHeuristic, Agent agent, string name) + public Match3Actuator(AbstractBoard board, + bool forceHeuristic, + int seed, + Agent agent, + string name) { m_Board = board; m_Rows = board.Rows; @@ -42,6 +47,7 @@ public Match3Actuator(AbstractBoard board, bool forceHeuristic, Agent agent, str var numMoves = Move.NumPotentialMoves(m_Board.Rows, m_Board.Columns); m_ActionSpec = ActionSpec.MakeDiscrete(numMoves); + m_Random = new System.Random(seed); } /// @@ -52,7 +58,7 @@ public void OnActionReceived(ActionBuffers actions) { if (m_ForceHeuristic) { - m_Agent.Heuristic(actions); + Heuristic(actions); } var moveIndex = actions.DiscreteActions[0]; @@ -116,5 +122,63 @@ IEnumerable InvalidMoveIndices() yield return move.MoveIndex; } } + + public void Heuristic(in ActionBuffers actionsOut) + { + var discreteActions = actionsOut.DiscreteActions; + discreteActions[0] = GreedyMove(); + } + + + protected int GreedyMove() + { + + var bestMoveIndex = 0; + var bestMovePoints = -1; + var numMovesAtCurrentScore = 0; + + foreach (var move in m_Board.ValidMoves()) + { + var movePoints = EvalMovePoints(move); + if (movePoints < bestMovePoints) + { + // Worse, skip + continue; + } + + if (movePoints > bestMovePoints) + { + // Better, keep + bestMovePoints = movePoints; + bestMoveIndex = move.MoveIndex; + numMovesAtCurrentScore = 1; + } + else + { + // Tied for best - use reservoir sampling to make sure we select from equal moves uniformly. + // See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm + numMovesAtCurrentScore++; + var randVal = m_Random.Next(0, numMovesAtCurrentScore); + if (randVal == 0) + { + // Keep the new one + bestMoveIndex = move.MoveIndex; + } + } + } + + return bestMoveIndex; + } + + /// + /// Method to be overridden when evaluating how many points a specific move will generate. + /// + /// The move to evaluate. + /// The number of points the move generates. 
+ protected virtual int EvalMovePoints(Move move) + { + return 1; + } + } } diff --git a/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs b/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs index 9c48d9cf20..7f14fb4ccd 100644 --- a/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs +++ b/com.unity.ml-agents.extensions/Runtime/Match3/Match3ActuatorComponent.cs @@ -5,7 +5,7 @@ namespace Unity.MLAgents.Extensions.Match3 { /// - /// Actuator component for a Match 3 game. Generates a Match3Actuator at runtime. + /// Actuator component for a Match3 game. Generates a Match3Actuator at runtime. /// public class Match3ActuatorComponent : ActuatorComponent { @@ -15,6 +15,11 @@ public class Match3ActuatorComponent : ActuatorComponent /// public string ActuatorName = "Match3 Actuator"; + /// + /// A random seed used to generate a board, if needed. + /// + public int RandomSeed = -1; + /// /// Force using the Agent's Heuristic() method to decide the action. This should only be used in testing. /// @@ -27,7 +32,8 @@ public override IActuator CreateActuator() { var board = GetComponent(); var agent = GetComponentInParent(); - return new Match3Actuator(board, ForceHeuristic, agent, ActuatorName); + var seed = RandomSeed == -1 ? gameObject.GetInstanceID() : RandomSeed + 1; + return new Match3Actuator(board, ForceHeuristic, seed, agent, ActuatorName); } /// diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index 5d05c80f4f..24e33d3a17 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -20,6 +20,10 @@ will result in the values being summed (instead of averaged) when written to TensorBoard. Thanks to @brccabral for the contribution! (#4816) - The upper limit for the time scale (by setting the `--time-scale` paramater in mlagents-learn) was removed when training with a player. The Editor still requires it to be clamped to 100. 
(#4867) +- Added the IHeuristicProvider interface to allow IActuators as well as Agent implement the Heuristic function to generate actions. + Updated the Basic example and the Match3 Example to use Actuators. + Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849) + #### ml-agents / ml-agents-envs / gym-unity (Python) diff --git a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs index aa72c4e423..2e4538b93e 100644 --- a/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs +++ b/com.unity.ml-agents/Runtime/Actuators/ActuatorManager.cs @@ -204,6 +204,49 @@ public void WriteActionMask() } } + /// + /// Iterates through all of the IActuators in this list and calls their + /// method on them, if implemented, with the appropriate + /// s depending on their . + /// + public void ApplyHeuristic(in ActionBuffers actionBuffersOut) + { + var continuousStart = 0; + var discreteStart = 0; + for (var i = 0; i < m_Actuators.Count; i++) + { + var actuator = m_Actuators[i]; + var numContinuousActions = actuator.ActionSpec.NumContinuousActions; + var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions; + + if (numContinuousActions == 0 && numDiscreteActions == 0) + { + continue; + } + + var continuousActions = ActionSegment.Empty; + if (numContinuousActions > 0) + { + continuousActions = new ActionSegment(actionBuffersOut.ContinuousActions.Array, + continuousStart, + numContinuousActions); + } + + var discreteActions = ActionSegment.Empty; + if (numDiscreteActions > 0) + { + discreteActions = new ActionSegment(actionBuffersOut.DiscreteActions.Array, + discreteStart, + numDiscreteActions); + } + + var heuristic = actuator as IHeuristicProvider; + heuristic?.Heuristic(new ActionBuffers(continuousActions, discreteActions)); + continuousStart += numContinuousActions; + discreteStart += numDiscreteActions; + } + } + /// /// Iterates through all of the IActuators in this 
list and calls their /// method on them with the appropriate diff --git a/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs new file mode 100644 index 0000000000..b992361c83 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs @@ -0,0 +1,18 @@ +namespace Unity.MLAgents.Actuators +{ + /// + /// Interface that allows objects to fill out an data structure for controlling + /// behavior of Agents or Actuators. + /// + public interface IHeuristicProvider + { + /// + /// Method called on objects which are expected to fill out the data structure. + /// Object that implement this interface should be careful to be consistent in the placement of their actions + /// in the data structure. + /// + /// The data structure to be filled by the + /// object implementing this interface. + void Heuristic(in ActionBuffers actionBuffersOut); + } +} diff --git a/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta new file mode 100644 index 0000000000..ca8338a072 --- /dev/null +++ b/com.unity.ml-agents/Runtime/Actuators/IHeuristicProvider.cs.meta @@ -0,0 +1,3 @@ +fileFormatVersion: 2 +guid: be90ffb28f39444a8fb02dfd4a82870c +timeCreated: 1610057456 \ No newline at end of file diff --git a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs index e4bfad2c65..1b18300f56 100644 --- a/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs +++ b/com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs @@ -1,11 +1,12 @@ namespace Unity.MLAgents.Actuators { /// - /// IActuator implementation that forwards to an . + /// IActuator implementation that forwards calls to an and an . 
/// - internal class VectorActuator : IActuator + internal class VectorActuator : IActuator, IHeuristicProvider { IActionReceiver m_ActionReceiver; + IHeuristicProvider m_HeuristicProvider; ActionBuffers m_ActionBuffers; internal ActionBuffers ActionBuffers @@ -14,17 +15,34 @@ internal ActionBuffers ActionBuffers private set => m_ActionBuffers = value; } + /// + /// Create a VectorActuator that forwards to the provided IActionReceiver. + /// + /// The used for OnActionReceived and WriteDiscreteActionMask. + /// If this parameter also implements it will be cast and used to forward calls to + /// . + /// + /// + public VectorActuator(IActionReceiver actionReceiver, + ActionSpec actionSpec, + string name = "VectorActuator") + : this(actionReceiver, actionReceiver as IHeuristicProvider, actionSpec, name) { } + /// /// Create a VectorActuator that forwards to the provided IActionReceiver. /// /// The used for OnActionReceived and WriteDiscreteActionMask. + /// The used to fill the + /// for Heuristic Policies. 
/// /// public VectorActuator(IActionReceiver actionReceiver, + IHeuristicProvider heuristicProvider, ActionSpec actionSpec, string name = "VectorActuator") { m_ActionReceiver = actionReceiver; + m_HeuristicProvider = heuristicProvider; ActionSpec = actionSpec; string suffix; if (actionSpec.NumContinuousActions == 0) @@ -55,6 +73,11 @@ public void OnActionReceived(ActionBuffers actionBuffers) m_ActionReceiver.OnActionReceived(ActionBuffers); } + public void Heuristic(in ActionBuffers actionBuffersOut) + { + m_HeuristicProvider?.Heuristic(actionBuffersOut); + } + /// public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 347dc16ec8..29c0dffe73 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -166,7 +166,7 @@ public void CopyActions(ActionBuffers actionBuffers) "docs/Learning-Environment-Design-Agents.md")] [Serializable] [RequireComponent(typeof(BehaviorParameters))] - public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver + public partial class Agent : MonoBehaviour, ISerializationCallbackReceiver, IActionReceiver, IHeuristicProvider { IPolicy m_Brain; BehaviorParameters m_PolicyFactory; @@ -312,6 +312,11 @@ internal struct AgentParameters /// float[] m_LegacyActionCache; + /// + /// This is used to avoid allocation of a float array during legacy calls to Heuristic. + /// + float[] m_LegacyHeuristicCache; + /// /// Called when the attached [GameObject] becomes enabled and active. 
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html @@ -429,7 +434,7 @@ public void LazyInitialize() InitializeActuators(); } - m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic); + m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager); ResetData(); Initialize(); @@ -606,7 +611,7 @@ internal void ReloadPolicy() return; } m_Brain?.Dispose(); - m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), Heuristic); + m_Brain = m_PolicyFactory.GeneratePolicy(m_ActuatorManager.GetCombinedActionSpec(), m_ActuatorManager); } /// @@ -826,11 +831,11 @@ void ResetData() public virtual void Initialize() { } /// - /// Implement `Heuristic()` to choose an action for this agent using a custom heuristic. + /// Implement to choose an action for this agent using a custom heuristic. /// /// /// Implement this function to provide custom decision making logic or to support manual - /// control of an agent using keyboard, mouse, or game controller input. + /// control of an agent using keyboard, mouse, game controller input, or a script. /// /// Your heuristic implementation can use any decision making logic you specify. 
Assign decision /// values to the and @@ -894,22 +899,21 @@ public virtual void Heuristic(in ActionBuffers actionsOut) switch (m_PolicyFactory.BrainParameters.VectorActionSpaceType) { case SpaceType.Continuous: - Heuristic(actionsOut.ContinuousActions.Array); + Heuristic(m_LegacyHeuristicCache); + Array.Copy(m_LegacyHeuristicCache, actionsOut.ContinuousActions.Array, m_LegacyActionCache.Length); actionsOut.DiscreteActions.Clear(); break; case SpaceType.Discrete: - var convertedOut = Array.ConvertAll(actionsOut.DiscreteActions.Array, x => (float)x); - Heuristic(convertedOut); + Heuristic(m_LegacyHeuristicCache); var discreteActionSegment = actionsOut.DiscreteActions; for (var i = 0; i < actionsOut.DiscreteActions.Length; i++) { - discreteActionSegment[i] = (int)convertedOut[i]; + discreteActionSegment[i] = (int)m_LegacyHeuristicCache[i]; } actionsOut.ContinuousActions.Clear(); break; } #pragma warning restore CS0618 - } /// @@ -993,9 +997,10 @@ void InitializeActuators() // Support legacy OnActionReceived // TODO don't set this up if the sizes are 0? var param = m_PolicyFactory.BrainParameters; - m_VectorActuator = new VectorActuator(this, param.ActionSpec); + m_VectorActuator = new VectorActuator(this, this, param.ActionSpec); m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1); m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions()]; + m_LegacyHeuristicCache = new float[m_VectorActuator.TotalNumberOfActions()]; m_ActuatorManager.Add(m_VectorActuator); @@ -1178,7 +1183,7 @@ public virtual void WriteDiscreteActionMask(IDiscreteActionMask actionMask) /// three values in ActionBuffers.ContinuousActions array to use as the force components. /// During training, the agent's policy learns to set those particular elements of /// the array to maximize the training rewards the agent receives. 
(Of course, - /// if you implement a function, it must use the same + /// if you implement a function, it must use the same /// elements of the action array for the same purpose since there is no learning /// involved.) /// @@ -1241,11 +1246,16 @@ public virtual void OnActionReceived(ActionBuffers actions) if (!actions.ContinuousActions.IsEmpty()) { - m_LegacyActionCache = actions.ContinuousActions.Array; + Array.Copy(actions.ContinuousActions.Array, + m_LegacyActionCache, + actionSpec.NumContinuousActions); } else { - m_LegacyActionCache = Array.ConvertAll(actions.DiscreteActions.Array, x => (float)x); + for (var i = 0; i < m_LegacyActionCache.Length; i++) + { + m_LegacyActionCache[i] = (float)actions.DiscreteActions[i]; + } } // Disable deprecation warnings so we can call the legacy overload. #pragma warning disable CS0618 diff --git a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs index 100220e320..49e00918c1 100644 --- a/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs +++ b/com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs @@ -196,12 +196,12 @@ public string FullyQualifiedBehaviorName get { return m_BehaviorName + "?team=" + TeamId; } } - internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGenerator heuristic) + internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorManager) { switch (m_BehaviorType) { case BehaviorType.HeuristicOnly: - return new HeuristicPolicy(heuristic, actionSpec); + return new HeuristicPolicy(actuatorManager, actionSpec); case BehaviorType.InferenceOnly: { if (m_Model == null) @@ -225,10 +225,10 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, HeuristicPolicy.ActionGen } else { - return new HeuristicPolicy(heuristic, actionSpec); + return new HeuristicPolicy(actuatorManager, actionSpec); } default: - return new HeuristicPolicy(heuristic, actionSpec); + return new 
HeuristicPolicy(actuatorManager, actionSpec); } } @@ -241,6 +241,5 @@ internal void UpdateAgentPolicy() } agent.ReloadPolicy(); } - } } diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs index 217f226857..7163a8face 100644 --- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs +++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs @@ -7,14 +7,13 @@ namespace Unity.MLAgents.Policies { /// - /// The Heuristic Policy uses a hards coded Heuristic method + /// The Heuristic Policy uses a hard-coded Heuristic method /// to take decisions each time the RequestDecision method is /// called. /// internal class HeuristicPolicy : IPolicy { - public delegate void ActionGenerator(in ActionBuffers actionBuffers); - ActionGenerator m_Heuristic; + ActuatorManager m_ActuatorManager; ActionBuffers m_ActionBuffers; bool m_Done; bool m_DecisionRequested; @@ -24,9 +23,9 @@ internal class HeuristicPolicy : IPolicy /// - public HeuristicPolicy(ActionGenerator heuristic, ActionSpec actionSpec) + public HeuristicPolicy(ActuatorManager actuatorManager, ActionSpec actionSpec) { - m_Heuristic = heuristic; + m_ActuatorManager = actuatorManager; var numContinuousActions = actionSpec.NumContinuousActions; var numDiscreteActions = actionSpec.NumDiscreteActions; var continuousDecision = new ActionSegment(new float[numContinuousActions], 0, numContinuousActions); @@ -47,7 +46,7 @@ public ref readonly ActionBuffers DecideAction() { if (!m_Done && m_DecisionRequested) { - m_Heuristic.Invoke(m_ActionBuffers); + m_ActuatorManager.ApplyHeuristic(m_ActionBuffers); } m_DecisionRequested = false; return ref m_ActionBuffers; diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs index e10cea84da..e095557ced 100644 --- a/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs +++ 
b/com.unity.ml-agents/Tests/Editor/Actuators/ActuatorManagerTests.cs @@ -303,5 +303,23 @@ public void TestWriteDiscreteActionMask() manager.WriteActionMask(); Assert.IsTrue(groundTruthMask.SequenceEqual(manager.DiscreteActionMask.GetMask())); } + + [Test] + public void TestHeuristic() + { + var manager = new ActuatorManager(2); + var va1 = new TestActuator(ActionSpec.MakeDiscrete(1, 2, 3), "name"); + var va2 = new TestActuator(ActionSpec.MakeDiscrete(3, 2, 1, 8), "name1"); + manager.Add(va1); + manager.Add(va2); + + var actionBuf = new ActionBuffers(Array.Empty(), new[] { 0, 0, 0, 0, 0, 0, 0 }); + manager.ApplyHeuristic(actionBuf); + + Assert.IsTrue(va1.m_HeuristicCalled); + Assert.AreEqual(va1.m_DiscreteBufferSize, 3); + Assert.IsTrue(va2.m_HeuristicCalled); + Assert.AreEqual(va2.m_DiscreteBufferSize, 4); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs index d0d0c4d2e9..649a643320 100644 --- a/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs +++ b/com.unity.ml-agents/Tests/Editor/Actuators/TestActuator.cs @@ -1,10 +1,13 @@ using Unity.MLAgents.Actuators; namespace Unity.MLAgents.Tests.Actuators { - internal class TestActuator : IActuator + internal class TestActuator : IActuator, IHeuristicProvider { public ActionBuffers LastActionBuffer; public int[][] Masks; + public bool m_HeuristicCalled; + public int m_DiscreteBufferSize; + public TestActuator(ActionSpec actuatorSpace, string name) { ActionSpec = actuatorSpace; @@ -32,5 +35,11 @@ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) public void ResetData() { } + + public void Heuristic(in ActionBuffers actionBuffersOut) + { + m_HeuristicCalled = true; + m_DiscreteBufferSize = actionBuffersOut.DiscreteActions.Length; + } } } diff --git a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs index 
1b9625463c..2b3dcabfef 100644 --- a/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs +++ b/com.unity.ml-agents/Tests/Editor/Actuators/VectorActuatorTests.cs @@ -1,3 +1,4 @@ +using System; using System.Collections.Generic; using System.Linq; using NUnit.Framework; @@ -9,12 +10,13 @@ namespace Unity.MLAgents.Tests.Actuators [TestFixture] public class VectorActuatorTests { - class TestActionReceiver : IActionReceiver + class TestActionReceiver : IActionReceiver, IHeuristicProvider { public ActionBuffers LastActionBuffers; public int Branch; public IList Mask; public ActionSpec ActionSpec { get; } + public bool HeuristicCalled; public void OnActionReceived(ActionBuffers actionBuffers) { @@ -25,6 +27,11 @@ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask) { actionMask.WriteMask(Branch, Mask); } + + public void Heuristic(in ActionBuffers actionBuffersOut) + { + HeuristicCalled = true; + } } [Test] @@ -93,5 +100,15 @@ public void TestWriteDiscreteActionMask() Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask())); } + + [Test] + public void TestHeuristic() + { + var ar = new TestActionReceiver(); + var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name"); + + va.Heuristic(new ActionBuffers(Array.Empty(), va.ActionSpec.BranchSizes)); + Assert.IsTrue(ar.HeuristicCalled); + } } } diff --git a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs index f4d46b7c04..3d79558935 100644 --- a/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs +++ b/com.unity.ml-agents/Tests/Editor/BehaviorParameterTests.cs @@ -6,9 +6,9 @@ namespace Unity.MLAgents.Tests { [TestFixture] - public class BehaviorParameterTests + public class BehaviorParameterTests : IHeuristicProvider { - static void DummyHeuristic(in ActionBuffers actionsOut) + public void Heuristic(in ActionBuffers actionsOut) { // No-op } @@ -23,7 +23,7 @@ public void TestNoModelInferenceOnlyThrows() 
Assert.Throws(() => { - bp.GeneratePolicy(actionSpec, DummyHeuristic); + bp.GeneratePolicy(actionSpec, new ActuatorManager()); }); } } diff --git a/docs/Migrating.md b/docs/Migrating.md index afdc1ce40a..dd0d6f77e2 100644 --- a/docs/Migrating.md +++ b/docs/Migrating.md @@ -12,6 +12,14 @@ double-check that the versions are in the same. The versions can be found in - `UnityEnvironment.API_VERSION` in environment.py ([example](https://github.com/Unity-Technologies/ml-agents/blob/b255661084cb8f701c716b040693069a3fb9a257/ml-agents-envs/mlagents/envs/environment.py#L45)) + +# Migrating ## Migrating to Release 13 ### Implementing IHeuristicProvider in your IActuator implementations - If you have any custom actuators, you can now implement the `IHeuristicProvider` interface to have your actuator +handle the generation of actions when an Agent is running in heuristic mode. + + # Migrating ## Migrating to Release 11 ### Agent virtual method deprecation