diff --git a/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs b/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs
index 6f31ce29b1..5f7c426e39 100644
--- a/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs
+++ b/Project/Assets/ML-Agents/Examples/DungeonEscape/Scripts/PushAgentEscape.cs
@@ -116,7 +116,6 @@ void OnTriggerEnter(Collider col)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
diff --git a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
index 1a8801f0f0..e9e173e56f 100644
--- a/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
@@ -192,9 +192,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var continuousActionsOut = actionsOut.ContinuousActions;
- continuousActionsOut[0] = 0;
- continuousActionsOut[1] = 0;
- continuousActionsOut[2] = 0;
if (Input.GetKey(KeyCode.D))
{
continuousActionsOut[2] = 1;
diff --git a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
index aa7daf1a57..86ed43c533 100644
--- a/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/Hallway/Scripts/HallwayAgent.cs
@@ -100,7 +100,6 @@ void OnCollisionEnter(Collision col)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
diff --git a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
index 7ed48de608..488c96b022 100644
--- a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
+++ b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentBasic.cs
@@ -177,7 +177,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
diff --git a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs
index 5e09e5c20e..fd2fe5fd72 100644
--- a/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs
+++ b/Project/Assets/ML-Agents/Examples/PushBlock/Scripts/PushAgentCollab.cs
@@ -71,7 +71,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
diff --git a/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs b/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
index bc848c48d4..5371a33d84 100644
--- a/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/Pyramids/Scripts/PyramidAgent.cs
@@ -66,7 +66,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut[0] = 0;
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[0] = 3;
diff --git a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
index abddb5485f..a794971162 100644
--- a/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
+++ b/Project/Assets/ML-Agents/Examples/Soccer/Scripts/AgentSoccer.cs
@@ -153,7 +153,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut.Clear();
//forward
if (Input.GetKey(KeyCode.W))
{
diff --git a/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs b/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs
index 96b7a9781c..c74b6e3f19 100644
--- a/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/Sorter/Scripts/SorterAgent.cs
@@ -238,7 +238,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut.Clear();
//forward
if (Input.GetKey(KeyCode.W))
{
diff --git a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
index 0a34632578..d9379c6629 100644
--- a/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
+++ b/Project/Assets/ML-Agents/Examples/WallJump/Scripts/WallJumpAgent.cs
@@ -264,7 +264,6 @@ public override void OnActionReceived(ActionBuffers actionBuffers)
public override void Heuristic(in ActionBuffers actionsOut)
{
var discreteActionsOut = actionsOut.DiscreteActions;
- discreteActionsOut.Clear();
if (Input.GetKey(KeyCode.D))
{
discreteActionsOut[1] = 2;
diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 4c5aafd9a9..a8f1c7dd5e 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -48,6 +48,8 @@ depend on the previous behavior, you can explicitly set the Agent's `InferenceDe
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
+- `ActionBuffers` are now reset to zero before being passed to `Agent.Heuristic()` and
+`IHeuristicProvider.Heuristic()`. (#5227)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)
diff --git a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
index ae7273001c..b18956833a 100644
--- a/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
+++ b/com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs
@@ -46,6 +46,7 @@ public ref readonly ActionBuffers DecideAction()
{
if (!m_Done && m_DecisionRequested)
{
+ m_ActionBuffers.Clear();
m_ActuatorManager.ApplyHeuristic(m_ActionBuffers);
}
m_DecisionRequested = false;
diff --git a/com.unity.ml-agents/Tests/Editor/Policies.meta b/com.unity.ml-agents/Tests/Editor/Policies.meta
new file mode 100644
index 0000000000..be3f189b91
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Policies.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: df271cac120e4d6893b14599fa8eb64d
+timeCreated: 1617813392
\ No newline at end of file
diff --git a/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs
new file mode 100644
index 0000000000..944b7ff907
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs
@@ -0,0 +1,125 @@
+using NUnit.Framework;
+using Unity.MLAgents.Actuators;
+using Unity.MLAgents.Policies;
+using UnityEngine;
+
+namespace Unity.MLAgents.Tests.Policies
+{
+ [TestFixture]
+ public class HeuristicPolicyTest
+ {
+ [SetUp]
+ public void SetUp()
+ {
+ if (Academy.IsInitialized)
+ {
+ Academy.Instance.Dispose();
+ }
+ }
+
+ /// <summary>
+ /// Assert that the action buffers are initialized to zero, and then set them to non-zero values.
+ /// </summary>
+ /// <param name="actionsOut"></param>
+ static void CheckAndSetBuffer(in ActionBuffers actionsOut)
+ {
+ var continuousActions = actionsOut.ContinuousActions;
+ for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++)
+ {
+ Assert.AreEqual(continuousActions[continuousIndex], 0.0f);
+ continuousActions[continuousIndex] = 1.0f;
+ }
+
+ var discreteActions = actionsOut.DiscreteActions;
+ for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++)
+ {
+ Assert.AreEqual(discreteActions[discreteIndex], 0);
+ discreteActions[discreteIndex] = 1;
+ }
+ }
+
+
+ class ActionClearedAgent : Agent
+ {
+ public int HeuristicCalls = 0;
+ public override void Heuristic(in ActionBuffers actionsOut)
+ {
+ CheckAndSetBuffer(actionsOut);
+ HeuristicCalls++;
+ }
+ }
+
+ class ActionClearedActuator : IActuator
+ {
+ public int HeuristicCalls = 0;
+ public ActionClearedActuator(ActionSpec actionSpec)
+ {
+ ActionSpec = actionSpec;
+ Name = GetType().Name;
+ }
+
+ public void OnActionReceived(ActionBuffers actionBuffers)
+ {
+ }
+
+ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
+ {
+ }
+
+ public void Heuristic(in ActionBuffers actionBuffersOut)
+ {
+ CheckAndSetBuffer(actionBuffersOut);
+ HeuristicCalls++;
+ }
+
+ public ActionSpec ActionSpec { get; }
+ public string Name { get; }
+
+ public void ResetData()
+ {
+
+ }
+ }
+
+ class ActionClearedActuatorComponent : ActuatorComponent
+ {
+ public ActionClearedActuator ActionClearedActuator;
+ public ActionClearedActuatorComponent()
+ {
+ ActionSpec = new ActionSpec(2, new[] { 3, 3 });
+ }
+
+ public override IActuator[] CreateActuators()
+ {
+ ActionClearedActuator = new ActionClearedActuator(ActionSpec);
+ return new IActuator[] { ActionClearedActuator };
+ }
+
+ public override ActionSpec ActionSpec { get; }
+ }
+
+ [Test]
+ public void TestActionsCleared()
+ {
+ var gameObj = new GameObject();
+ var agent = gameObj.AddComponent<ActionClearedAgent>();
+ var behaviorParameters = agent.GetComponent<BehaviorParameters>();
+ behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 });
+ behaviorParameters.BrainParameters.VectorObservationSize = 0;
+ behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly;
+
+ var actuatorComponent = gameObj.AddComponent<ActionClearedActuatorComponent>();
+ agent.LazyInitialize();
+
+ const int k_NumSteps = 5;
+ for (var i = 0; i < k_NumSteps; i++)
+ {
+ agent.RequestDecision();
+ Academy.Instance.EnvironmentStep();
+ }
+
+ Assert.AreEqual(agent.HeuristicCalls, k_NumSteps);
+ Assert.AreEqual(actuatorComponent.ActionClearedActuator.HeuristicCalls, k_NumSteps);
+ }
+ }
+}
diff --git a/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta
new file mode 100644
index 0000000000..682a64b746
--- /dev/null
+++ b/com.unity.ml-agents/Tests/Editor/Policies/HeuristicPolicyTest.cs.meta
@@ -0,0 +1,3 @@
+fileFormatVersion: 2
+guid: 5108e92f91a04ddab9d628c9bc57cadb
+timeCreated: 1617813411
\ No newline at end of file