Skip to content

Commit 98da4b1

Browse files
cmardmaryamhonari
andauthored
Update ml-agents/mlagents/trainers/torch/distributions.py
Co-authored-by: Maryam Honari <honari.m94@gmail.com>
1 parent 0431025 commit 98da4b1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ml-agents/mlagents/trainers/torch/distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def sample(self):
124124
return torch.multinomial(self.probs, 1)
125125

126126
def deterministic_sample(self):
127-
return torch.argmax(self.probs).reshape((1, 1))
127+
return torch.argmax(self.probs, dim=1, keepdim=True)
128128

129129
def pdf(self, value):
130130
# This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),

0 commit comments

Comments
 (0)