Skip to content

Commit 59972b1

Browse files
Alex Greenefacebook-github-bot
Alex Greene
authored andcommitted
flexible background color for point compositing
Summary: Modified the compositor background color tests to account for either a 3rd or 4th channel. Also replaced hard coding of channel value with C. Implemented changes to alpha channel appending logic, and cleaned up extraneous warnings and checks, per task instructions. Fixes #1048 Reviewed By: bottler Differential Revision: D34305312 fbshipit-source-id: 2176c3bdd897d1a2ba6ff4c6fa801fea889e4f02
1 parent c8f3d6b commit 59972b1

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

pytorch3d/renderer/points/compositor.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
3535

3636
# images are of shape (N, C, H, W)
3737
# check for background color & feature size C (C=4 indicates rgba)
38-
if background_color is not None and images.shape[1] == 4:
38+
if background_color is not None:
3939
return _add_background_color_to_images(fragments, images, background_color)
4040
return images
4141

@@ -57,7 +57,7 @@ def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
5757

5858
# images are of shape (N, C, H, W)
5959
# check for background color & feature size C (C=4 indicates rgba)
60-
if background_color is not None and images.shape[1] == 4:
60+
if background_color is not None:
6161
return _add_background_color_to_images(fragments, images, background_color)
6262
return images
6363

@@ -85,22 +85,22 @@ def _add_background_color_to_images(pix_idxs, images, background_color):
8585
if not torch.is_tensor(background_color):
8686
background_color = images.new_tensor(background_color)
8787

88-
background_shape = background_color.shape
89-
90-
if len(background_shape) != 1 or background_shape[0] not in (3, 4):
91-
warnings.warn(
92-
"Background color should be size (3) or (4), but is size %s instead"
93-
% (background_shape,)
94-
)
95-
return images
88+
if len(background_color.shape) != 1:
89+
raise ValueError("Wrong shape of background_color")
9690

9791
background_color = background_color.to(images)
9892

9993
# add alpha channel
100-
if background_shape[0] == 3:
94+
if background_color.shape[0] == 3 and images.shape[1] == 4:
95+
# special case to allow giving RGB background for RGBA
10196
alpha = images.new_ones(1)
10297
background_color = torch.cat([background_color, alpha])
10398

99+
if images.shape[1] != background_color.shape[0]:
100+
raise ValueError(
101+
f"background color has {background_color.shape[0] } channels not {images.shape[1]}"
102+
)
103+
104104
num_background_pixels = background_mask.sum()
105105

106106
# permute so that features are the last dimension for masked_scatter to work

tests/test_render_points.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def test_simple_sphere_batched(self):
326326
)
327327
self.assertClose(rgb, image_ref)
328328

329-
def test_compositor_background_color(self):
329+
def test_compositor_background_color_rgba(self):
330330

331331
N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
332332
ptclds = torch.randn((C, P))
@@ -357,7 +357,7 @@ def test_compositor_background_color(self):
357357
torch.masked_select(images, is_foreground[:, None]),
358358
)
359359

360-
is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
360+
is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)
361361

362362
# permute masked_images to correctly get rgb values
363363
masked_images = masked_images.permute(0, 2, 3, 1)
@@ -367,12 +367,58 @@ def test_compositor_background_color(self):
367367
# check if background colors are properly changed
368368
self.assertTrue(
369369
masked_images[is_background]
370-
.view(-1, 4)[..., i]
370+
.view(-1, C)[..., i]
371371
.eq(channel_color)
372372
.all()
373373
)
374374

375375
# check background color alpha values
376376
self.assertTrue(
377-
masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
377+
masked_images[is_background].view(-1, C)[..., 3].eq(1).all()
378378
)
379+
380+
def test_compositor_background_color_rgb(self):
381+
382+
N, H, W, K, C, P = 1, 15, 15, 20, 3, 225
383+
ptclds = torch.randn((C, P))
384+
alphas = torch.rand((N, K, H, W))
385+
pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1
386+
background_color = [0.5, 0, 1]
387+
388+
compositor_funcs = [
389+
(NormWeightedCompositor, norm_weighted_sum),
390+
(AlphaCompositor, alpha_composite),
391+
]
392+
393+
for (compositor_class, composite_func) in compositor_funcs:
394+
395+
compositor = compositor_class(background_color)
396+
397+
# run the forward method to generate masked images
398+
masked_images = compositor.forward(pix_idxs, alphas, ptclds)
399+
400+
# generate unmasked images for testing purposes
401+
images = composite_func(pix_idxs, alphas, ptclds)
402+
403+
is_foreground = pix_idxs[:, 0] >= 0
404+
405+
# make sure foreground values are unchanged
406+
self.assertClose(
407+
torch.masked_select(masked_images, is_foreground[:, None]),
408+
torch.masked_select(images, is_foreground[:, None]),
409+
)
410+
411+
is_background = ~is_foreground[..., None].expand(-1, -1, -1, C)
412+
413+
# permute masked_images to correctly get rgb values
414+
masked_images = masked_images.permute(0, 2, 3, 1)
415+
for i in range(3):
416+
channel_color = background_color[i]
417+
418+
# check if background colors are properly changed
419+
self.assertTrue(
420+
masked_images[is_background]
421+
.view(-1, C)[..., i]
422+
.eq(channel_color)
423+
.all()
424+
)

0 commit comments

Comments
 (0)