@@ -265,17 +265,40 @@ def _add(self, *args):
265
265
Args:
266
266
*args: All the elements in a transition.
267
267
"""
268
- cursor = self .cursor ()
268
+ self ._check_args_length (* args )
269
+ transition = {e .name : args [idx ]
270
+ for idx , e in enumerate (self .get_add_args_signature ())}
271
+ self ._add_transition (transition )
272
+
273
+ def _add_transition (self , transition ):
274
+ """Internal add method to add transition dictionary to storage arrays.
269
275
270
- arg_names = [e .name for e in self .get_add_args_signature ()]
271
- for arg_name , arg in zip (arg_names , args ):
272
- self ._store [arg_name ][cursor ] = arg
276
+ Args:
277
+ transition: The dictionary of names and values of the transition
278
+ to add to the storage.
279
+ """
280
+ cursor = self .cursor ()
281
+ for arg_name in transition :
282
+ self ._store [arg_name ][cursor ] = transition [arg_name ]
273
283
274
284
self .add_count += 1
275
285
self .invalid_range = invalid_range (
276
286
self .cursor (), self ._replay_capacity , self ._stack_size ,
277
287
self ._update_horizon )
278
288
289
+ def _check_args_length (self , * args ):
290
+ """Check if args passed to the add method have the same length as storage.
291
+
292
+ Args:
293
+ *args: Args for elements used in storage.
294
+
295
+ Raises:
296
+ ValueError: If args have wrong length.
297
+ """
298
+ if len (args ) != len (self .get_add_args_signature ()):
299
+ raise ValueError ('Add expects {} elements, received {}' .format (
300
+ len (self .get_add_args_signature ()), len (args )))
301
+
279
302
def _check_add_types (self , * args ):
280
303
"""Checks if args passed to the add method match those of the storage.
281
304
@@ -285,9 +308,7 @@ def _check_add_types(self, *args):
285
308
Raises:
286
309
ValueError: If args have wrong shape or dtype.
287
310
"""
288
- if len (args ) != len (self .get_add_args_signature ()):
289
- raise ValueError ('Add expects {} elements, received {}' .format (
290
- len (self .get_add_args_signature ()), len (args )))
311
+ self ._check_args_length (* args )
291
312
for arg_element , store_element in zip (args , self .get_add_args_signature ()):
292
313
if isinstance (arg_element , np .ndarray ):
293
314
arg_shape = arg_element .shape
0 commit comments