Skip to content

Commit 8278515

Browse files
author
Chris Elion
authored
[MLA-1159] Add virtual methods to DecisionRequester (#5223)
1 parent ac4f43c commit 8278515

File tree

3 files changed

+70
-2
lines changed

3 files changed

+70
-2
lines changed

com.unity.ml-agents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ different sizes using the same model. For a summary of the interface changes, pl
4747
depend on the previous behavior, you can explicitly set the Agent's `InferenceDevice` to `InferenceDevice.CPU`. (#5175)
4848
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
4949
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
50+
- `DecisionRequester.ShouldRequestDecision()` and `ShouldRequestAction()`methods were added. These are used to
51+
determine whether `Agent.RequestDecision()` and `Agent.RequestAction()` are called (respectively). (#5223)
5052
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
5153

5254
#### ml-agents / ml-agents-envs / gym-unity (Python)

com.unity.ml-agents/Runtime/DecisionRequester.cs

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ public class DecisionRequester : MonoBehaviour
4242
[NonSerialized]
4343
Agent m_Agent;
4444

45+
/// <summary>
46+
/// Get the Agent attached to the DecisionRequester.
47+
/// </summary>
48+
public Agent Agent
49+
{
50+
get => m_Agent;
51+
}
52+
4553
internal void Awake()
4654
{
4755
m_Agent = gameObject.GetComponent<Agent>();
@@ -57,21 +65,58 @@ void OnDestroy()
5765
}
5866
}
5967

68+
/// <summary>
69+
/// Information about Academy step used to make decisions about whether to request a decision.
70+
/// </summary>
71+
public struct DecisionRequestContext
72+
{
73+
/// <summary>
74+
/// The current step count of the Academy, equivalent to Academy.StepCount.
75+
/// </summary>
76+
public int AcademyStepCount;
77+
}
78+
6079
/// <summary>
6180
/// Method that hooks into the Academy in order inform the Agent on whether or not it should request a
6281
/// decision, and whether or not it should take actions between decisions.
6382
/// </summary>
6483
/// <param name="academyStepCount">The current step count of the academy.</param>
6584
void MakeRequests(int academyStepCount)
6685
{
67-
if (academyStepCount % DecisionPeriod == 0)
86+
var context = new DecisionRequestContext
87+
{
88+
AcademyStepCount = academyStepCount
89+
};
90+
91+
if (ShouldRequestDecision(context))
6892
{
6993
m_Agent?.RequestDecision();
7094
}
71-
if (TakeActionsBetweenDecisions)
95+
96+
if (ShouldRequestAction(context))
7297
{
7398
m_Agent?.RequestAction();
7499
}
75100
}
101+
102+
/// <summary>
103+
/// Whether Agent.RequestDecision should be called on this update step.
104+
/// </summary>
105+
/// <param name="context"></param>
106+
/// <returns></returns>
107+
protected virtual bool ShouldRequestDecision(DecisionRequestContext context)
108+
{
109+
return context.AcademyStepCount % DecisionPeriod == 0;
110+
}
111+
112+
/// <summary>
113+
/// Whether Agent.RequestAction should be called on this update step.
114+
/// </summary>
115+
/// <param name="context"></param>
116+
/// <returns></returns>
117+
protected virtual bool ShouldRequestAction(DecisionRequestContext context)
118+
{
119+
return TakeActionsBetweenDecisions;
120+
}
76121
}
77122
}

com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections.Generic;
22
using Unity.MLAgents.Sensors;
33
using NUnit.Framework;
4+
using Unity.MLAgents;
45
using UnityEngine;
56

67
namespace Unity.MLAgentsExamples
@@ -76,5 +77,25 @@ public void CheckSetupRayPerceptionSensorComponent()
7677
Assert.AreEqual(outputs.RayOutputs.Length, 2*sensorComponent.RaysPerDirection + 1);
7778
}
7879
#endif
80+
81+
/// <summary>
82+
/// Make sure we can inherit from DecisionRequester and override some logic.
83+
/// </summary>
84+
class CustomDecisionRequester : DecisionRequester
85+
{
86+
/// <summary>
87+
/// Example logic. If the killswitch flag is set, the Agent never requests a decision.
88+
/// </summary>
89+
public bool KillswitchEnabled;
90+
91+
public CustomDecisionRequester()
92+
{
93+
}
94+
95+
protected override bool ShouldRequestDecision(DecisionRequestContext context)
96+
{
97+
return !KillswitchEnabled && base.ShouldRequestDecision(context);
98+
}
99+
}
79100
}
80101
}

0 commit comments

Comments
 (0)