Skip to content

shader: add SoftZShader and HardZShader for rendering depth maps #1208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions pytorch3d/renderer/mesh/shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,70 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
)

return images

class HardZShader(ShaderBase):
"""
Renders the Z distances of the closest face for each pixel. If no face is
found it returns the zfar value of the camera.

To use the default values, simply initialize the shader with the desired
device e.g.

.. code-block::

shader = HardZShader(device=torch.device("cuda:0"))
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardZShader"
raise ValueError(msg)

zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))
mask = fragments.pix_to_face < 0

zbuf = fragments.zbuf[..., 0].clone()
zbuf[mask] = zfar
return zbuf


class SoftZShader(ShaderBase):
"""
Renders the Z distances using an aggregate of the distances of each face
based off of the point distance.

To use the default values, simply initialize the shader with the desired
device e.g.

.. code-block::

shader = SoftZShader(device=torch.device("cuda:0"))
"""
def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftZShader"
raise ValueError(msg)

N, H, W, K = fragments.pix_to_face.shape
device = fragments.zbuf.device
mask = fragments.pix_to_face >= 0

zfar = kwargs.get("zfar", getattr(cameras, "zfar", 100.0))

# Sigmoid probability map based on the distance of the pixel to the face.
prob_map = torch.sigmoid(-fragments.dists / self.blend_params.sigma) * mask

# append extra face for zfar
dists = torch.cat((fragments.zbuf, torch.ones((N, H, W, 1), device=device) * zfar), dim=3)
probs = torch.cat((prob_map, torch.ones((N, H, W, 1), device=device)), dim=3)

# compute weighting based off of probabilities using cumsum
probs = probs.cumsum(dim=3)
probs = probs.clamp(max=1)
probs = probs.diff(dim=3, prepend=torch.zeros((N, H, W, 1), device=device))


return (probs * dists).sum(dim=3)
6 changes: 6 additions & 0 deletions tests/test_shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
HardFlatShader,
HardGouraudShader,
HardPhongShader,
HardZShader,
SoftPhongShader,
SoftZShader,
)
from pytorch3d.structures.meshes import Meshes

Expand All @@ -30,7 +32,9 @@ def test_to(self):
HardFlatShader,
HardGouraudShader,
HardPhongShader,
HardZShader,
SoftPhongShader,
SoftZShader,
]

for shader_class in shader_classes:
Expand Down Expand Up @@ -78,7 +82,9 @@ def test_cameras_check(self):
HardFlatShader,
HardGouraudShader,
HardPhongShader,
HardZShader,
SoftPhongShader,
SoftZShader,
]

for shader_class in shader_classes:
Expand Down