
Backprop RGB image losses to mesh shape? #839

Closed
priyasundaresan opened this issue Sep 14, 2021 · 1 comment
Label: how to (How to use PyTorch3D in my project)


@priyasundaresan

I am trying to fit a source mesh (a sphere) to a target mesh (a cow) following the tutorial at https://pytorch3d.org/tutorials/fit_textured_mesh. In particular, I would like to backpropagate losses computed on the rendered RGB images of the current and target meshes to the vertex positions of the mesh being deformed.
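
For context, the optimization pattern in the full script below boils down to a few lines: apply a learnable per-vertex offset, render the deformed mesh differentiably, and let autograd carry the image loss back to the offsets. This is only a distilled sketch (renderer, src_mesh, and target_image refer to objects constructed further down), not extra code to run on its own:

deform_verts = torch.full(src_mesh.verts_packed().shape, 0.0,
                          device=device, requires_grad=True)
optimizer = torch.optim.Adam([deform_verts], lr=0.01)

optimizer.zero_grad()
new_src_mesh = src_mesh.offset_verts(deform_verts)  # differentiable deformation
curr_image = renderer(new_src_mesh)[..., :3]        # differentiable soft rendering
loss = torch.abs(curr_image - target_image).mean()  # L1 image loss
loss.backward()                                     # gradients flow into deform_verts
optimizer.step()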

I am able to achieve the desired result using a 50-50 weighting of an L1 silhouette loss and an L1 RGB loss over the rendered images; the training progression across 200 iterations is shown here:
[Images: side-by-side renders of the deforming mesh vs. the target at iterations 0, 50, 100, 150, and 200]

However, using only the L1 RGB loss, the mesh does not converge to the desired shape, as shown here:
[Images: the same comparison at iterations 0, 50, 100, and 150; the mesh fails to take on the cow shape]

I have tried using an L2 RGB loss and changing the texture color, but the issue persists. Is it possible to use RGB image supervision alone, without silhouettes, and still propagate gradients to the mesh shape? I have looked at a similar issue, but confirmed that my rasterization settings do not have that problem.
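
A quick way to probe whether the RGB-only configuration produces any shape gradient at all is to inspect deform_verts.grad right after a backward pass. This is a hypothetical diagnostic (names reused from the script below), not part of the original repro:

# Run once after sum_loss.backward() with the RGB-only weighting
grad = deform_verts.grad
if grad is None or not torch.any(grad != 0):
    print("no gradient reached deform_verts")
else:
    norms = grad.norm(dim=1)  # per-vertex gradient magnitude, shape (V,)
    print("grad mean %.2e, max %.2e, nonzero %d/%d" % (
        norms.mean().item(), norms.max().item(),
        int((norms > 0).sum().item()), norms.numel()))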

This can be reproduced by running the code below and replacing

losses = {"rgb": {"weight": 0.5, "values": []}, "silhouette": {"weight": 0.5, "values": []}}

with

losses = {"rgb": {"weight": 1.0, "values": []}, "silhouette": {"weight": 0.0, "values": []}}

import os
import sys
sys.path.append(os.path.abspath(''))
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt

from pytorch3d.utils import ico_sphere
from pytorch3d.io import load_objs_as_meshes, save_obj
from plot_image_grid import image_grid
# Rendering components
from pytorch3d.renderer import (
    look_at_view_transform,
    OpenGLPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    TexturesVertex,
)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

DATA_DIR = "./data"
obj_filename = os.path.join(DATA_DIR, "cow_mesh/cow.obj")
# Load the target cow mesh and give it a uniform white per-vertex texture
mesh = load_objs_as_meshes([obj_filename], device=device)
white_tex = torch.ones_like(mesh.verts_packed())  # (V, 3) of ones
white_tex = white_tex.unsqueeze(0)                # add batch dim -> (1, V, 3)
white_tex = TexturesVertex(verts_features=white_tex.to(device))
mesh.textures = white_tex

# Normalize the target mesh to fit in a unit sphere centered at the origin
verts = mesh.verts_packed()
N = verts.shape[0]
center = verts.mean(0)
scale = max((verts - center).abs().max(0)[0])
mesh.offset_verts_(-center)
mesh.scale_verts_(1.0 / float(scale))

# Sample num_views viewpoints on a circle around the target mesh
num_views = 20
elev = torch.linspace(0, 360, num_views)
azim = torch.linspace(-180, 180, num_views)
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
camera = cameras[0]

# Hard rasterization (no blur, one face per pixel); used only to render the
# fixed target images, so no gradients are needed here
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=0.0,
    faces_per_pixel=1,
    perspective_correct=False
)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera, 
        raster_settings=raster_settings
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=camera,
        lights=lights
    )
)

# Soft rasterization settings for the silhouette renderer; blur_radius follows
# Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based 3D
# Reasoning', ICCV 2019
sigma = 1e-4
raster_settings_silhouette = RasterizationSettings(
    image_size=128,
    blur_radius=np.log(1. / 1e-4 - 1.) * sigma,
    faces_per_pixel=50,
    perspective_correct=False
)

# Silhouette renderer 
renderer_silhouette = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera, 
        raster_settings=raster_settings_silhouette
    ),
    shader=SoftSilhouetteShader()
)
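# (SoftSilhouetteShader writes the silhouette into the alpha channel, which is
# why the training loop below reads channel index 3)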

meshes = mesh.extend(num_views)
target_images = renderer(meshes, cameras=cameras, lights=lights)
target_rgb = [target_images[i, ..., :3] for i in range(num_views)]
target_cameras = [OpenGLPerspectiveCameras(device=device, R=R[None, i, ...], 
                                           T=T[None, i, ...]) for i in range(num_views)]
image_grid(target_images.cpu().numpy(), rows=4, cols=5, rgb=True)
plt.savefig('cows_rgb.png')  # save before plt.show() so the figure is not lost
plt.show()
plt.clf()

# Save a side-by-side visualization of the rendered predicted mesh (left)
# and the target rendering (right)
def visualize_pred(curr_image, ref_image, fname):
    visualization = np.hstack((curr_image.detach().cpu().numpy(), ref_image.detach().cpu().numpy()))
    # cv2.imwrite expects BGR channel order, so flip the RGB channels
    cv2.imwrite(fname, visualization[..., ::-1] * 255)

# Plot losses as a function of optimization iteration
def plot_losses(losses):
    fig = plt.figure(figsize=(13, 5))
    ax = fig.gca()
    for k, l in losses.items():
        ax.plot(l['values'], label=k + " loss")
    ax.legend(fontsize="16")
    ax.set_xlabel("Iteration", fontsize="16")
    ax.set_ylabel("Loss", fontsize="16")
    ax.set_title("Loss vs iterations", fontsize="16")

# We initialize the source shape to be a sphere of radius 1.  
src_mesh = ico_sphere(4, device)
white_tex = torch.ones_like(src_mesh.verts_packed())
white_tex = white_tex.unsqueeze(0) 
white_tex = TexturesVertex(verts_features=white_tex.to(device))
src_mesh.textures = white_tex


# Rasterization settings for differentiable rendering, where the blur_radius
# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable 
# Renderer for Image-based 3D Reasoning', ICCV 2019
sigma = 1e-4
raster_settings_soft = RasterizationSettings(
    image_size=128, 
    blur_radius=np.log(1. / 1e-4 - 1.)*sigma, 
    faces_per_pixel=50, 
    perspective_correct=False
)

# Soft Phong renderer for differentiable RGB rendering
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=camera, 
        raster_settings=raster_settings_soft
    ),
    shader=SoftPhongShader(
        device=device, 
        cameras=camera,
        lights=lights
    )
)
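# Note: this re-binds `renderer`, so from here on both the current and the
# reference images are rendered with the soft settings above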

num_views_per_iteration = 5  # 5 randomly sampled views per iteration worked best
Niter = 350
plot_period = 50

losses = {"rgb": {"weight": 0.5, "values": []},
          "silhouette": {"weight": 0.5, "values": []}}

# Per-vertex displacement of the source mesh; the only tensor being optimized
verts_shape = src_mesh.verts_packed().shape
deform_verts = torch.full(verts_shape, 0.0, device=device, requires_grad=True)

# The optimizer
optimizer = torch.optim.Adam([deform_verts], lr=0.01) 

loop = range(Niter)

out_dir = 'out'
os.makedirs(out_dir, exist_ok=True)

for i in loop:
    optimizer.zero_grad()
    # Apply the current per-vertex offsets to the source mesh
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    # Zero the per-iteration loss accumulators
    loss = {k: torch.tensor(0.0, device=device) for k in losses}
    
    # Randomly sample a subset of views for this iteration
    for j in np.random.permutation(num_views).tolist()[:num_views_per_iteration]:
        curr_image = renderer(new_src_mesh, cameras=target_cameras[j], lights=lights)[..., :3]  # RGB channels
        curr_sil = renderer_silhouette(new_src_mesh, cameras=target_cameras[j], lights=lights)[..., 3]  # alpha channel
        ref_image = renderer(meshes[j], cameras=target_cameras[j], lights=lights)[..., :3]
        ref_sil = renderer_silhouette(meshes[j], cameras=target_cameras[j], lights=lights)[..., 3]
        loss_rgb = torch.abs(curr_image - ref_image).mean()        # L1 RGB loss
        loss_silhouette = torch.abs(curr_sil - ref_sil).mean()     # L1 silhouette loss
        loss["silhouette"] += loss_silhouette / num_views_per_iteration
        loss["rgb"] += loss_rgb / num_views_per_iteration
    
    # Weighted sum of the losses
    sum_loss = torch.tensor(0.0, device=device)
    for k, l in loss.items():
        sum_loss += l * losses[k]["weight"]
        losses[k]["values"].append(float(l.detach().cpu()))

    # Print the running loss
    print("iter: %d/%d, total_loss = %.6f" % (i, Niter, sum_loss), end='\r')
    sys.stdout.flush()
    
    # Periodically save a side-by-side comparison of current vs. target render
    if i % plot_period == 0:
        visualize_pred(curr_image[0], ref_image[0], fname="%s/%05d_depth.png" % (out_dir, i))
    # Optimization step
    sum_loss.backward()
    optimizer.step()

# Save the final predicted mesh, undoing the normalization applied to the target
final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)
final_verts = final_verts * scale + center
final_obj = os.path.join('./', 'cow_sil_model.obj')
save_obj(final_obj, final_verts, final_faces)

plot_losses(losses)
plt.savefig('losses.png')  # plot_losses only builds the figure; save it here

The cow mesh data is available by running:

!mkdir -p data/cow_mesh
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png

Thanks in advance for your help!

@gkioxari gkioxari self-assigned this Sep 14, 2021
@gkioxari gkioxari added the "how to" (How to use PyTorch3D in my project) label Sep 14, 2021
@gkioxari
Contributor

@priyasundaresan Thanks for the post. Correct me if I am wrong, but you don't seem to be reporting a specific bug in the PyTorch3D library. I suggest you move this post to the GitHub Discussions page, where other users can weigh in on your problem.

@facebookresearch facebookresearch locked and limited conversation to collaborators Sep 14, 2021

This issue was moved to a discussion.

You can continue the conversation there. Go to discussion →
