Skip to content

Commit 85fa5c2

Browse files
btabapsc-g
authored and committed
Re-implement self._add for OutOfGraphPrioritizedReplayBuffer and check at runtime (raising ValueError on mismatch) that the number of args passed to _add equals the length of the output of self.get_add_args_signature()
PiperOrigin-RevId: 258364421
1 parent f5f971f commit 85fa5c2

File tree

3 files changed

+48
-13
lines changed

3 files changed

+48
-13
lines changed

dopamine/replay_memory/circular_replay_buffer.py

+28-7
Original file line number | Diff line number | Diff line change
@@ -265,17 +265,40 @@ def _add(self, *args):
265265
Args:
266266
*args: All the elements in a transition.
267267
"""
268-
cursor = self.cursor()
268+
self._check_args_length(*args)
269+
transition = {e.name: args[idx]
270+
for idx, e in enumerate(self.get_add_args_signature())}
271+
self._add_transition(transition)
272+
273+
def _add_transition(self, transition):
274+
"""Internal add method to add transition dictionary to storage arrays.
269275
270-
arg_names = [e.name for e in self.get_add_args_signature()]
271-
for arg_name, arg in zip(arg_names, args):
272-
self._store[arg_name][cursor] = arg
276+
Args:
277+
transition: The dictionary of names and values of the transition
278+
to add to the storage.
279+
"""
280+
cursor = self.cursor()
281+
for arg_name in transition:
282+
self._store[arg_name][cursor] = transition[arg_name]
273283

274284
self.add_count += 1
275285
self.invalid_range = invalid_range(
276286
self.cursor(), self._replay_capacity, self._stack_size,
277287
self._update_horizon)
278288

289+
def _check_args_length(self, *args):
290+
"""Checks that args passed to the add method match the signature length.
291+
292+
Args:
293+
*args: Args for elements used in storage.
294+
295+
Raises:
296+
ValueError: If args have wrong length.
297+
"""
298+
if len(args) != len(self.get_add_args_signature()):
299+
raise ValueError('Add expects {} elements, received {}'.format(
300+
len(self.get_add_args_signature()), len(args)))
301+
279302
def _check_add_types(self, *args):
280303
"""Checks if args passed to the add method match those of the storage.
281304
@@ -285,9 +308,7 @@ def _check_add_types(self, *args):
285308
Raises:
286309
ValueError: If args have wrong shape or dtype.
287310
"""
288-
if len(args) != len(self.get_add_args_signature()):
289-
raise ValueError('Add expects {} elements, received {}'.format(
290-
len(self.get_add_args_signature()), len(args)))
311+
self._check_args_length(*args)
291312
for arg_element, store_element in zip(args, self.get_add_args_signature()):
292313
if isinstance(arg_element, np.ndarray):
293314
arg_shape = arg_element.shape

dopamine/replay_memory/prioritized_replay_buffer.py

+6-6
Original file line number | Diff line number | Diff line change
@@ -124,20 +124,20 @@ def _add(self, *args):
124124
Args:
125125
*args: All the elements in a transition.
126126
"""
127+
self._check_args_length(*args)
128+
127129
# Use Schaul et al.'s (2015) scheme of setting the priority of new elements
128130
# to the maximum priority so far.
129-
parent_add_args = []
130-
# Picks out 'priority' from arguments and passes the other arguments to the
131-
# parent method.
131+
# Separates 'priority' (set in the sum_tree) from the elements to store.
132+
transition = {}
132133
for i, element in enumerate(self.get_add_args_signature()):
133134
if element.name == 'priority':
134135
priority = args[i]
135136
else:
136-
parent_add_args.append(args[i])
137+
transition[element.name] = args[i]
137138

138139
self.sum_tree.set(self.cursor(), priority)
139-
140-
super(OutOfGraphPrioritizedReplayBuffer, self)._add(*parent_add_args)
140+
super(OutOfGraphPrioritizedReplayBuffer, self)._add_transition(transition)
141141

142142
def sample_index_batch(self, batch_size):
143143
"""Returns a batch of valid indices sampled as in Schaul et al. (2015).

tests/dopamine/replay_memory/prioritized_replay_buffer_test.py

+14
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,20 @@ def add_blank(self, memory, action=0, reward=0.0, terminal=0, priority=1.0):
6060
index = (memory.cursor() - 1) % REPLAY_CAPACITY
6161
return index
6262

63+
def testAddWithAndWithoutPriority(self):
64+
memory = self.create_default_memory()
65+
self.assertEqual(memory.cursor(), 0)
66+
zeros = np.zeros(SCREEN_SIZE)
67+
68+
self.add_blank(memory)
69+
self.assertEqual(memory.cursor(), STACK_SIZE)
70+
self.assertEqual(memory.add_count, STACK_SIZE)
71+
72+
# Check that the prioritized replay buffer expects an additional argument
73+
# for priority.
74+
with self.assertRaisesRegexp(ValueError, 'Add expects'):
75+
memory.add(zeros, 0, 0, 0)
76+
6377
def testDummyScreensAddedToNewMemory(self):
6478
memory = self.create_default_memory()
6579
index = self.add_blank(memory)

0 commit comments

Comments
 (0)