@@ -287,6 +287,45 @@ def test_permute_dims(x, axes):
287
287
out_indices = permuted_indices )
288
288
289
289
290
+ @pytest .mark .min_version ("2023.12" )
291
+ @given (
292
+ x = hh .arrays (dtype = hh .all_dtypes , shape = shared_shapes (min_dims = 1 )),
293
+ kw = hh .kwargs (
294
+ axis = st .none () | shared_shapes (min_dims = 1 ).flatmap (
295
+ lambda s : st .integers (- len (s ), len (s ) - 1 )
296
+ )
297
+ ),
298
+ data = st .data (),
299
+ )
300
+ def test_repeat (x , kw , data ):
301
+ shape = x .shape
302
+ axis = kw .get ("axis" , None )
303
+ dim = math .prod (shape ) if axis is None else shape [axis ]
304
+ repeat_strat = st .integers (1 , 4 )
305
+ repeats = data .draw (repeat_strat
306
+ | hh .arrays (dtype = hh .int_dtypes , elements = repeat_strat ,
307
+ shape = st .sampled_from ([(1 ,), (dim ,)])),
308
+ label = "repeats" )
309
+ if isinstance (repeats , int ):
310
+ n_repitions = dim * repeats
311
+ else :
312
+ if repeats .shape == (1 ,):
313
+ n_repitions = dim * repeats [0 ]
314
+ else :
315
+ n_repitions = int (xp .sum (repeats ))
316
+
317
+ out = xp .repeat (x , repeats , ** kw )
318
+ ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
319
+ if axis is None :
320
+ expected_shape = (n_repitions ,)
321
+ else :
322
+ expected_shape = list (shape )
323
+ expected_shape [axis ] = n_repitions
324
+ expected_shape = tuple (expected_shape )
325
+ ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
326
+ # TODO: values testing
327
+
328
+
290
329
@st .composite
291
330
def reshape_shapes (draw , shape ):
292
331
size = 1 if len (shape ) == 0 else math .prod (shape )
@@ -298,20 +337,6 @@ def reshape_shapes(draw, shape):
298
337
return tuple (rshape )
299
338
300
339
301
- @pytest .mark .min_version ("2023.12" )
302
- @given (
303
- x = hh .arrays (dtype = hh .all_dtypes , shape = hh .shapes (min_dims = 1 )),
304
- repeats = st .integers (1 , 4 ),
305
- )
306
- def test_repeat (x , repeats ):
307
- # TODO: test array repeats and non-None axis, adjust shape and value testing accordingly
308
- out = xp .repeat (x , repeats )
309
- ph .assert_dtype ("repeat" , in_dtype = x .dtype , out_dtype = out .dtype )
310
- expected_shape = (math .prod (x .shape ) * repeats ,)
311
- ph .assert_shape ("repeat" , out_shape = out .shape , expected = expected_shape )
312
- # TODO: values testing
313
-
314
-
315
340
@pytest .mark .unvectorized
316
341
@pytest .mark .skip ("flaky" ) # TODO: fix!
317
342
@given (
0 commit comments