Skip to content

[Cherry-pick for 2.0.1 verified patch]Harden user PII in analytics (#5512) harden analytics #5604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,19 @@ and this project adheres to
[Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [2.0.1] - 2021-10-13
### Minor Changes
#### com.unity.ml-agents (C#)
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Upgrade to 2.0.1
- Update gRPC native lib to universal for arm64 and x86_64. This change should enable ml-agents usage on mac M1 (#5283, #5519)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Set gym version in gym-unity to gym release 0.20.0
- Set gym version in gym-unity to gym release 0.20.0(#5540)
- Harden user PII protection logic and extend TrainingAnalytics to expose detailed configuration parameters. (#5512)
- Added minimal analytics collection to LL-API (#5511)

### Bug Fixes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)
#### ml-agents / ml-agents-envs / gym-unity (Python)

## [2.0.0] - 2021-09-01
### Minor Changes
#### com.unity.ml-agents (C#)
Expand Down
36 changes: 31 additions & 5 deletions com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,45 @@
using System;
using System.Text;
using System.Security.Cryptography;
using UnityEngine;

namespace Unity.MLAgents.Analytics
{

internal static class AnalyticsUtils
{
/// <summary>
/// Conversion function from byte array to hex string
/// </summary>
/// <param name="array"></param>
/// <returns>A byte array to be hex encoded.</returns>
private static string ToHexString(byte[] array)
{
StringBuilder hex = new StringBuilder(array.Length * 2);
foreach (byte b in array)
{
hex.AppendFormat("{0:x2}", b);
}
return hex.ToString();
}

/// <summary>
/// Hash a string to remove PII or secret info before sending to analytics
/// </summary>
/// <param name="s"></param>
/// <returns>A string containing the Hash128 of the input string.</returns>
public static string Hash(string s)
/// <param name="key"></param>
/// <returns>A string containing the key to be used for HMAC encoding.</returns>
/// <param name="value"></param>
/// <returns>A string containing the value to be encoded.</returns>
public static string Hash(string key, string value)
{
var behaviorNameHash = Hash128.Compute(s);
return behaviorNameHash.ToString();
string hash;
UTF8Encoding encoder = new UTF8Encoding();
using (HMACSHA256 hmac = new HMACSHA256(encoder.GetBytes(key)))
{
Byte[] hmBytes = hmac.ComputeHash(encoder.GetBytes(value));
hash = ToHexString(hmBytes);
}
return hash;
}

internal static bool s_SendEditorAnalytics = true;
Expand Down
2 changes: 2 additions & 0 deletions com.unity.ml-agents/Runtime/Analytics/Events.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ internal struct TrainingEnvironmentInitializedEvent
public string TorchDeviceType;
public int NumEnvironments;
public int NumEnvironmentParameters;
public string RunOptions;
}

[Flags]
Expand Down Expand Up @@ -188,5 +189,6 @@ internal struct TrainingBehaviorInitializedEvent
public string VisualEncoder;
public int NumNetworkLayers;
public int NumNetworkHiddenUnits;
public string Config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ IList<IActuator> actuators
var inferenceEvent = new InferenceEvent();

// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
inferenceEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
inferenceEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);

inferenceEvent.BarracudaModelSource = barracudaModel.IrSource;
inferenceEvent.BarracudaModelVersion = barracudaModel.IrVersion;
Expand Down
20 changes: 16 additions & 4 deletions com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,21 @@ internal static string ParseBehaviorName(string fullyQualifiedBehaviorName)
return fullyQualifiedBehaviorName.Substring(0, lastQuestionIndex);
}

internal static TrainingBehaviorInitializedEvent SanitizeTrainingBehaviorInitializedEvent(TrainingBehaviorInitializedEvent tbiEvent)
{
// Hash the behavior name if the message version is from an older version of ml-agents that doesn't do trainer-side hashing.
// We'll also, for extra safety, verify that the BehaviorName is the size of the expected SHA256 hash.
// Context: The config field was added at the same time as trainer side hashing, so messages including it should already be hashed.
if (tbiEvent.Config.Length == 0 || tbiEvent.BehaviorName.Length != 64)
{
tbiEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, tbiEvent.BehaviorName);
}

return tbiEvent;
}

[Conditional("MLA_UNITY_ANALYTICS_MODULE")]
public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent tbiEvent)
public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent rawTbiEvent)
{
#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE
if (!IsAnalyticsEnabled())
Expand All @@ -202,6 +215,7 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
if (!EnableAnalytics())
return;

var tbiEvent = SanitizeTrainingBehaviorInitializedEvent(rawTbiEvent);
var behaviorName = tbiEvent.BehaviorName;
var added = s_SentTrainingBehaviorInitialized.Add(behaviorName);

Expand All @@ -211,9 +225,7 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
return;
}

// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
tbiEvent.BehaviorName = AnalyticsUtils.Hash(tbiEvent.BehaviorName);

// Note - to debug, use JsonUtility.ToJson on the event.
// Debug.Log(
Expand All @@ -236,7 +248,7 @@ IList<IActuator> actuators
var remotePolicyEvent = new RemotePolicyInitializedEvent();

// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);

remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec);
Expand Down
2 changes: 2 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitial
TorchDeviceType = inputProto.TorchDeviceType,
NumEnvironments = inputProto.NumEnvs,
NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
RunOptions = inputProto.RunOptions,
};
}

Expand Down Expand Up @@ -530,6 +531,7 @@ internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEv
VisualEncoder = inputProto.VisualEncoder,
NumNetworkLayers = inputProto.NumNetworkLayers,
NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
Config = inputProto.Config,
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,29 @@ static TrainingAnalyticsReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n",
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi2QEKHlRy",
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7gEKHlRy",
"YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz",
"aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w",
"eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK",
"EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK",
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFIq0DChtUcmFpbmlu",
"Z0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoNYmVoYXZpb3JfbmFtZRgBIAEoCRIU",
"Cgx0cmFpbmVyX3R5cGUYAiABKAkSIAoYZXh0cmluc2ljX3Jld2FyZF9lbmFi",
"bGVkGAMgASgIEhsKE2dhaWxfcmV3YXJkX2VuYWJsZWQYBCABKAgSIAoYY3Vy",
"aW9zaXR5X3Jld2FyZF9lbmFibGVkGAUgASgIEhoKEnJuZF9yZXdhcmRfZW5h",
"YmxlZBgGIAEoCBIiChpiZWhhdmlvcmFsX2Nsb25pbmdfZW5hYmxlZBgHIAEo",
"CBIZChFyZWN1cnJlbnRfZW5hYmxlZBgIIAEoCBIWCg52aXN1YWxfZW5jb2Rl",
"chgJIAEoCRIaChJudW1fbmV0d29ya19sYXllcnMYCiABKAUSIAoYbnVtX25l",
"dHdvcmtfaGlkZGVuX3VuaXRzGAsgASgFEhgKEHRyYWluZXJfdGhyZWFkZWQY",
"DCABKAgSGQoRc2VsZl9wbGF5X2VuYWJsZWQYDSABKAgSGgoSY3VycmljdWx1",
"bV9lbmFibGVkGA4gASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0",
"b3JPYmplY3RzYgZwcm90bzM="));
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFEhMKC3J1bl9vcHRp",
"b25zGAggASgJIr0DChtUcmFpbmluZ0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoN",
"YmVoYXZpb3JfbmFtZRgBIAEoCRIUCgx0cmFpbmVyX3R5cGUYAiABKAkSIAoY",
"ZXh0cmluc2ljX3Jld2FyZF9lbmFibGVkGAMgASgIEhsKE2dhaWxfcmV3YXJk",
"X2VuYWJsZWQYBCABKAgSIAoYY3VyaW9zaXR5X3Jld2FyZF9lbmFibGVkGAUg",
"ASgIEhoKEnJuZF9yZXdhcmRfZW5hYmxlZBgGIAEoCBIiChpiZWhhdmlvcmFs",
"X2Nsb25pbmdfZW5hYmxlZBgHIAEoCBIZChFyZWN1cnJlbnRfZW5hYmxlZBgI",
"IAEoCBIWCg52aXN1YWxfZW5jb2RlchgJIAEoCRIaChJudW1fbmV0d29ya19s",
"YXllcnMYCiABKAUSIAoYbnVtX25ldHdvcmtfaGlkZGVuX3VuaXRzGAsgASgF",
"EhgKEHRyYWluZXJfdGhyZWFkZWQYDCABKAgSGQoRc2VsZl9wbGF5X2VuYWJs",
"ZWQYDSABKAgSGgoSY3VycmljdWx1bV9lbmFibGVkGA4gASgIEg4KBmNvbmZp",
"ZxgPIAEoCUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0",
"c2IGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters", "RunOptions" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled", "Config" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -85,6 +86,7 @@ public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : th
torchDeviceType_ = other.torchDeviceType_;
numEnvs_ = other.numEnvs_;
numEnvironmentParameters_ = other.numEnvironmentParameters_;
runOptions_ = other.runOptions_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -170,6 +172,17 @@ public int NumEnvironmentParameters {
}
}

/// <summary>Field number for the "run_options" field.</summary>
public const int RunOptionsFieldNumber = 8;
private string runOptions_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string RunOptions {
get { return runOptions_; }
set {
runOptions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as TrainingEnvironmentInitialized);
Expand All @@ -190,6 +203,7 @@ public bool Equals(TrainingEnvironmentInitialized other) {
if (TorchDeviceType != other.TorchDeviceType) return false;
if (NumEnvs != other.NumEnvs) return false;
if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false;
if (RunOptions != other.RunOptions) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -203,6 +217,7 @@ public override int GetHashCode() {
if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode();
if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode();
if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode();
if (RunOptions.Length != 0) hash ^= RunOptions.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand Down Expand Up @@ -244,6 +259,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(56);
output.WriteInt32(NumEnvironmentParameters);
}
if (RunOptions.Length != 0) {
output.WriteRawTag(66);
output.WriteString(RunOptions);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand Down Expand Up @@ -273,6 +292,9 @@ public int CalculateSize() {
if (NumEnvironmentParameters != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters);
}
if (RunOptions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(RunOptions);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand Down Expand Up @@ -305,6 +327,9 @@ public void MergeFrom(TrainingEnvironmentInitialized other) {
if (other.NumEnvironmentParameters != 0) {
NumEnvironmentParameters = other.NumEnvironmentParameters;
}
if (other.RunOptions.Length != 0) {
RunOptions = other.RunOptions;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand Down Expand Up @@ -344,6 +369,10 @@ public void MergeFrom(pb::CodedInputStream input) {
NumEnvironmentParameters = input.ReadInt32();
break;
}
case 66: {
RunOptions = input.ReadString();
break;
}
}
}
}
Expand Down Expand Up @@ -389,6 +418,7 @@ public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() {
trainerThreaded_ = other.trainerThreaded_;
selfPlayEnabled_ = other.selfPlayEnabled_;
curriculumEnabled_ = other.curriculumEnabled_;
config_ = other.config_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -551,6 +581,17 @@ public bool CurriculumEnabled {
}
}

/// <summary>Field number for the "config" field.</summary>
public const int ConfigFieldNumber = 15;
private string config_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string Config {
get { return config_; }
set {
config_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as TrainingBehaviorInitialized);
Expand Down Expand Up @@ -578,6 +619,7 @@ public bool Equals(TrainingBehaviorInitialized other) {
if (TrainerThreaded != other.TrainerThreaded) return false;
if (SelfPlayEnabled != other.SelfPlayEnabled) return false;
if (CurriculumEnabled != other.CurriculumEnabled) return false;
if (Config != other.Config) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -598,6 +640,7 @@ public override int GetHashCode() {
if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode();
if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode();
if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode();
if (Config.Length != 0) hash ^= Config.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand Down Expand Up @@ -667,6 +710,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(112);
output.WriteBool(CurriculumEnabled);
}
if (Config.Length != 0) {
output.WriteRawTag(122);
output.WriteString(Config);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand Down Expand Up @@ -717,6 +764,9 @@ public int CalculateSize() {
if (CurriculumEnabled != false) {
size += 1 + 1;
}
if (Config.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Config);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand Down Expand Up @@ -770,6 +820,9 @@ public void MergeFrom(TrainingBehaviorInitialized other) {
if (other.CurriculumEnabled != false) {
CurriculumEnabled = other.CurriculumEnabled;
}
if (other.Config.Length != 0) {
Config = other.Config;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand Down Expand Up @@ -837,6 +890,10 @@ public void MergeFrom(pb::CodedInputStream input) {
CurriculumEnabled = input.ReadBool();
break;
}
case 122: {
Config = input.ReadString();
break;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ public void TestRemotePolicy()
Academy.Instance.Dispose();
}

[TestCase("a name we expect to hash", ExpectedResult = "d084a8b6da6a6a1c097cdc9ffea95e1546da4647352113ed77cbe7b4192e6d73")]
[TestCase("another_name", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
[TestCase("0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
public string TestTrainingBehaviorInitialized(string stringToMaybeHash)
{
var tbiEvent = new TrainingBehaviorInitializedEvent();
tbiEvent.BehaviorName = stringToMaybeHash;
tbiEvent.Config = "{}";

var sanitizedEvent = TrainingAnalytics.SanitizeTrainingBehaviorInitializedEvent(tbiEvent);
return sanitizedEvent.BehaviorName;
}

[Test]
public void TestEnableAnalytics()
{
Expand Down
Loading