Propagate from rendered image into meshes? #276


Closed
magicknight opened this issue Jul 15, 2020 · 5 comments
Labels
how to How to use PyTorch3D in my project

Comments

@magicknight

Hi,

I am trying to propagate gradients from a single rendered image back into the mesh, but it never seems to converge.
My poor cow was deforming into a monster like this:

gif

This is the original cow and the target cow:

target

My code looks like this:

class Model(nn.Module):
    def __init__(self, mesh, tracer, image_ref):
        super().__init__()
        self.mesh = mesh
        self.new_mesh = mesh.clone()
        self.device = tracer.device
        self.tracer = tracer
        
        # Register the reference image passed to the constructor
        self.register_buffer('image_ref', image_ref)

        # Create an optimizable parameter mesh vertices
#         self.vertices = nn.Parameter(self.mesh.verts_packed()).to(self.device)
        
        # We will learn to deform the source mesh by offsetting its vertices
        # The shape of the deform parameters is equal to the total number of vertices in src_mesh
        self.deform_verts = torch.full(self.mesh.verts_packed().shape, 0.0, device=self.device, requires_grad=True)

    def forward(self):
        self.new_mesh = self.mesh.offset_verts(self.deform_verts)
        
        # Based on the new position of the vertices we update mesh, then make new projection using the updated mesh      
        image = renderer(self.new_mesh, lights=lights, materials=materials, cameras=cameras)

        # Calculate the loss
#         loss = 0.001 * torch.sum((image - self.image_ref) ** 2) + torch.sum(self.deform_verts ** 2)
        loss = torch.sum((image - self.image_ref) ** 2)
        return loss, image

# Initialize a model using the renderer, mesh and reference image
model = Model(mesh=cow_mesh, tracer=tracer, image_ref=reference_proj).to(device)

# Create an optimizer. Here we are using Adam and we pass in the vertices of the model
optimizer = torch.optim.Adam([model.deform_verts], lr=0.005, weight_decay=0.5)
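For context, the optimization loop driving this model isn't shown in the post. Below is a minimal, PyTorch-only sketch of such a loop; the PyTorch3D renderer is replaced by a dummy differentiable "render" (identity projection of the offset vertices) so the sketch runs standalone, and the loop count and learning rate are my assumptions, not values from the original code:

```python
import torch

torch.manual_seed(0)

verts = torch.randn(100, 3)    # stand-in for mesh.verts_packed()
target = torch.zeros(100, 3)   # stand-in for the reference image

# Optimizable per-vertex offsets, as in the Model above
deform_verts = torch.full(verts.shape, 0.0, requires_grad=True)
optimizer = torch.optim.Adam([deform_verts], lr=0.05)

for i in range(200):
    optimizer.zero_grad()
    rendered = verts + deform_verts            # stand-in for renderer(new_mesh, ...)
    loss = torch.sum((rendered - target) ** 2)  # same L2 image loss as in forward()
    loss.backward()
    optimizer.step()

print(loss.item())  # final loss after optimization
```

In the real setup, `rendered` would come from the PyTorch3D renderer, and gradients only flow back to `deform_verts` if the rasterization is differentiable across face boundaries (see the blur_radius / faces_per_pixel discussion below in the thread).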
@nikhilaravi
Contributor

@magicknight what are the rasterization settings for the renderer?

@nikhilaravi nikhilaravi self-assigned this Jul 15, 2020
@nikhilaravi nikhilaravi added the how to How to use PyTorch3D in my project label Jul 15, 2020
@magicknight
Author

> @magicknight what are the rasterization settings for the renderer?

Hi @nikhilaravi , thank you for your reply.
Here are the rasterization settings for my renderer:

# Rotate the camera by the azimuth angle
R, T = look_at_view_transform(dist=2.5, elev=10, azim=50)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
lights = PointLights(device=device, location=[[5.0, 5.0, -0.0]])
# blend_setting = BlendParams(2.0, 2.0, (0.0, 0.0, 0.0))
raster_settings = RasterizationSettings(
    image_size=512, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
    bin_size=None,  # this setting controls whether naive or coarse-to-fine rasterization is used
    max_faces_per_bin=None  # this setting is for coarse rasterization
)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=cameras,
        lights=lights,
#         blend_params=blend_setting
    ),
)
# Change specular color to green and change material shininess 
materials = Materials(
    device=device,
    specular_color=[[1.0, 1.0, 1.0]],
    shininess=1.0
)
# Render the mesh, passing in keyword arguments for the modified components.
cow_image = renderer(cow_mesh, lights=lights, materials=materials, cameras=cameras)
defect_image = renderer(defect_mesh, lights=lights, materials=materials, cameras=cameras)

@nikhilaravi
Contributor

@magicknight the blur_radius needs to be > 0.0 and faces_per_pixel needs to be > 1; otherwise the loss from each pixel is only propagated back to one face in the mesh. e.g.

blend_params = BlendParams(sigma=1e-4, gamma=1e-4)  # sigma controls the blur falloff
raster_settings = RasterizationSettings(
    image_size=512, 
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    faces_per_pixel=100, 
)
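The blur_radius expression above depends on blend_params.sigma. Assuming sigma = 1e-4 (the PyTorch3D default), the radius it evaluates to can be checked directly:

```python
import math

# blur_radius = log(1/1e-4 - 1) * sigma, with sigma = 1e-4 assumed here;
# the exact value depends on the BlendParams actually used.
sigma = 1e-4
blur_radius = math.log(1. / 1e-4 - 1.) * sigma
print(blur_radius)  # roughly 9.21e-4
```

A nonzero radius like this gives each pixel soft coverage over nearby faces, which is what lets gradients reach more than one face per pixel.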

@nikhilaravi
Contributor

@magicknight if this resolves your question please close this issue.

@magicknight
Author

Thank you nikhilaravi :D
