Skip to content

Commit c67dbdf

Browse files
committed
Address comments about moving hash to a class function and not nested.
1 parent 433c064 commit c67dbdf

File tree

1 file changed

+16
-28
lines changed

1 file changed

+16
-28
lines changed

ml-agents/mlagents/training_analytics_side_channel.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,11 @@ def __init__(self) -> None:
3333
super().__init__(uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356"))
3434
self.run_options: Optional[RunOptions] = None
3535

36-
@staticmethod
37-
def __hash(key: str, data: str) -> str:
36+
@classmethod
37+
def __hash(cls, data: str) -> str:
3838
res = hmac.new(
39-
key.encode("utf-8"), data.encode("utf-8"), hashlib.sha256
39+
cls.__vendorKey.encode("utf-8"), data.encode("utf-8"), hashlib.sha256
4040
).hexdigest()
41-
print(res)
4241
return res
4342

4443
def on_message_received(self, msg: IncomingMessage) -> None:
@@ -47,36 +46,31 @@ def on_message_received(self, msg: IncomingMessage) -> None:
4746
+ "this should not have happened."
4847
)
4948

50-
@staticmethod
51-
def __sanitize_run_options(config: RunOptions) -> Dict[str, Any]:
49+
@classmethod
50+
def __sanitize_run_options(cls, config: RunOptions) -> Dict[str, Any]:
5251
res = copy.deepcopy(config.as_dict())
5352

54-
def hash(value: str) -> str:
55-
return TrainingAnalyticsSideChannel.__hash(
56-
TrainingAnalyticsSideChannel.__vendorKey, value
57-
)
58-
5953
# Filter potentially PII behavior names
6054
if "behaviors" in res and res["behaviors"]:
61-
res["behaviors"] = {hash(k): v for (k, v) in res["behaviors"].items()}
55+
res["behaviors"] = {cls.__hash(k): v for (k, v) in res["behaviors"].items()}
6256
for (k, v) in res["behaviors"].items():
6357
if "init_path" in v and v["init_path"] is not None:
64-
hashed_path = hash(v["init_path"])
58+
hashed_path = cls.__hash(v["init_path"])
6559
res["behaviors"][k]["init_path"] = hashed_path
6660

6761
# Filter potentially PII curriculum and behavior names from Checkpoint Settings
6862
if "environment_parameters" in res and res["environment_parameters"]:
6963
res["environment_parameters"] = {
70-
hash(k): v for (k, v) in res["environment_parameters"].items()
64+
cls.__hash(k): v for (k, v) in res["environment_parameters"].items()
7165
}
7266
for (curriculumName, curriculum) in res["environment_parameters"].items():
7367
updated_lessons = []
7468
for lesson in curriculum["curriculum"]:
7569
new_lesson = copy.deepcopy(lesson)
7670
if lesson.has_keys("name"):
77-
new_lesson["name"] = hash(lesson["name"])
71+
new_lesson["name"] = cls.__hash(lesson["name"])
7872
if lesson.has_keys("completion_criteria"):
79-
new_lesson["completion_criteria"]["behavior"] = hash(
73+
new_lesson["completion_criteria"]["behavior"] = cls.__hash(
8074
new_lesson["completion_criteria"]["behavior"]
8175
)
8276
updated_lessons.append(new_lesson)
@@ -90,7 +84,7 @@ def hash(value: str) -> str:
9084
"initialize_from" in res["checkpoint_settings"]
9185
and res["checkpoint_settings"]["initialize_from"] is not None
9286
):
93-
res["checkpoint_settings"]["initialize_from"] = hash(
87+
res["checkpoint_settings"]["initialize_from"] = cls.__hash(
9488
res["checkpoint_settings"]["initialize_from"]
9589
)
9690
if (
@@ -123,31 +117,25 @@ def environment_initialized(self, run_options: RunOptions) -> None:
123117
run_options=json.dumps(sanitized_run_options),
124118
)
125119

126-
print(msg)
127-
128120
any_message = Any()
129121
any_message.Pack(msg)
130122

131123
env_init_msg = OutgoingMessage()
132124
env_init_msg.set_raw_bytes(any_message.SerializeToString())
133125
super().queue_message_to_send(env_init_msg)
134126

135-
@staticmethod
136-
def __sanitize_trainer_settings(config: TrainerSettings) -> Dict[str, Any]:
127+
@classmethod
128+
def __sanitize_trainer_settings(cls, config: TrainerSettings) -> Dict[str, Any]:
137129
config_dict = copy.deepcopy(config.as_dict())
138130
if "init_path" in config_dict and config_dict["init_path"] is not None:
139-
hashed_path = TrainingAnalyticsSideChannel.__hash(
140-
TrainingAnalyticsSideChannel.__vendorKey, config_dict["init_path"]
141-
)
131+
hashed_path = cls.__hash(config_dict["init_path"])
142132
config_dict["init_path"] = hashed_path
143133
return config_dict
144134

145135
def training_started(self, behavior_name: str, config: TrainerSettings) -> None:
146-
raw_config = TrainingAnalyticsSideChannel.__sanitize_trainer_settings(config)
136+
raw_config = self.__sanitize_trainer_settings(config)
147137
msg = TrainingBehaviorInitialized(
148-
behavior_name=TrainingAnalyticsSideChannel.__hash(
149-
self.__vendorKey, behavior_name
150-
),
138+
behavior_name=self.__hash(behavior_name),
151139
trainer_type=config.trainer_type.value,
152140
extrinsic_reward_enabled=(
153141
RewardSignalType.EXTRINSIC in config.reward_signals

0 commit comments

Comments
 (0)