@@ -326,7 +326,7 @@ def test_simple_sphere_batched(self):
326
326
)
327
327
self .assertClose (rgb , image_ref )
328
328
329
- def test_compositor_background_color (self ):
329
+ def test_compositor_background_color_rgba (self ):
330
330
331
331
N , H , W , K , C , P = 1 , 15 , 15 , 20 , 4 , 225
332
332
ptclds = torch .randn ((C , P ))
@@ -357,7 +357,7 @@ def test_compositor_background_color(self):
357
357
torch .masked_select (images , is_foreground [:, None ]),
358
358
)
359
359
360
- is_background = ~ is_foreground [..., None ].expand (- 1 , - 1 , - 1 , 4 )
360
+ is_background = ~ is_foreground [..., None ].expand (- 1 , - 1 , - 1 , C )
361
361
362
362
# permute masked_images to correctly get rgb values
363
363
masked_images = masked_images .permute (0 , 2 , 3 , 1 )
@@ -367,12 +367,58 @@ def test_compositor_background_color(self):
367
367
# check if background colors are properly changed
368
368
self .assertTrue (
369
369
masked_images [is_background ]
370
- .view (- 1 , 4 )[..., i ]
370
+ .view (- 1 , C )[..., i ]
371
371
.eq (channel_color )
372
372
.all ()
373
373
)
374
374
375
375
# check background color alpha values
376
376
self .assertTrue (
377
- masked_images [is_background ].view (- 1 , 4 )[..., 3 ].eq (1 ).all ()
377
+ masked_images [is_background ].view (- 1 , C )[..., 3 ].eq (1 ).all ()
378
378
)
379
+
380
+ def test_compositor_background_color_rgb (self ):
381
+
382
+ N , H , W , K , C , P = 1 , 15 , 15 , 20 , 3 , 225
383
+ ptclds = torch .randn ((C , P ))
384
+ alphas = torch .rand ((N , K , H , W ))
385
+ pix_idxs = torch .randint (- 1 , 20 , (N , K , H , W )) # 20 < P, large amount of -1
386
+ background_color = [0.5 , 0 , 1 ]
387
+
388
+ compositor_funcs = [
389
+ (NormWeightedCompositor , norm_weighted_sum ),
390
+ (AlphaCompositor , alpha_composite ),
391
+ ]
392
+
393
+ for (compositor_class , composite_func ) in compositor_funcs :
394
+
395
+ compositor = compositor_class (background_color )
396
+
397
+ # run the forward method to generate masked images
398
+ masked_images = compositor .forward (pix_idxs , alphas , ptclds )
399
+
400
+ # generate unmasked images for testing purposes
401
+ images = composite_func (pix_idxs , alphas , ptclds )
402
+
403
+ is_foreground = pix_idxs [:, 0 ] >= 0
404
+
405
+ # make sure foreground values are unchanged
406
+ self .assertClose (
407
+ torch .masked_select (masked_images , is_foreground [:, None ]),
408
+ torch .masked_select (images , is_foreground [:, None ]),
409
+ )
410
+
411
+ is_background = ~ is_foreground [..., None ].expand (- 1 , - 1 , - 1 , C )
412
+
413
+ # permute masked_images to correctly get rgb values
414
+ masked_images = masked_images .permute (0 , 2 , 3 , 1 )
415
+ for i in range (3 ):
416
+ channel_color = background_color [i ]
417
+
418
+ # check if background colors are properly changed
419
+ self .assertTrue (
420
+ masked_images [is_background ]
421
+ .view (- 1 , C )[..., i ]
422
+ .eq (channel_color )
423
+ .all ()
424
+ )
0 commit comments