
failure of SigLIP2 FP32 to FP16 #4373

Open
yijun02 opened this issue Mar 4, 2025 · 6 comments
Labels
Module:Accuracy (Output mismatch between TensorRT and other frameworks), triaged (Issue has been triaged by maintainers)

Comments

yijun02 commented Mar 4, 2025

I am trying to convert a SigLIP2 model to TensorRT with FP16, but the cosine similarity between the ONNX and TensorRT outputs is only 0.6463.

I used the following code to convert the model to ONNX.

import subprocess
from urllib.request import urlopen

import numpy as np
import torch
import torch.nn as nn
from open_clip import create_model_from_pretrained
from PIL import Image

model_path = "model"

# Load the pretrained SigLIP2 model and its preprocessing transform.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = create_model_from_pretrained('hf-hub:timm/ViT-B-16-SigLIP2-256', device=device)
model.eval()

# Wrap the image encoder so normalization to [-1, 1] and the NHWC -> NCHW
# permute are baked into the exported graph.
class ImageEncoder(nn.Module):
    def __init__(self, model) -> None:
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(self, image):
        image = (image - 127.5) / 127.5
        image = image.permute(0, 3, 1, 2)
        image_features = self.model.encode_image(image)
        return image_features

image_encoder = ImageEncoder(model)

# Build a dummy NHWC float32 input from a sample image.
dummy_img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
dummy_img = np.array(dummy_img.resize((256, 256)).convert('RGB')).astype(np.float32)
dummy_img = torch.from_numpy(dummy_img).unsqueeze(0).to(device)

# Export the wrapped encoder to ONNX at opset 16.
torch.onnx.export(image_encoder,
                  (dummy_img,),
                  f"{model_path}/img_en_ori.onnx",
                  export_params=True,
                  opset_version=16,
                  do_constant_folding=True,
                  input_names=['img'],
                  output_names=['image_feature'])

# Simplify the exported graph in place with onnx-simplifier.
subprocess.run(["onnxsim", f"{model_path}/img_en_ori.onnx", f"{model_path}/img_en_ori.onnx"], check=True)

and the following command to build the FP16 TensorRT engine:

/usr/src/tensorrt/bin/trtexec --onnx=model/img_en_ori.onnx --saveEngine=model/img_en_ori.engine --fp16
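
For reference, a minimal sketch of how the ONNX-vs-TRT cosine similarity can be computed. The reference output comes from ONNX Runtime; trt_out is assumed to be the FP16 engine's output on the same input (obtained e.g. via the TensorRT Python API or polygraphy) and is not shown here:

import numpy as np
import onnxruntime as ort

# Reference output from ONNX Runtime on the same dummy input as above.
sess = ort.InferenceSession("model/img_en_ori.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"img": dummy_img.cpu().numpy()})[0]

def cosine_similarity(a, b):
    a, b = a.ravel().astype(np.float32), b.ravel().astype(np.float32)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# trt_out: the engine's output on the same input (assumed, not shown here).
# print(cosine_similarity(onnx_out, trt_out))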

Environment

AGX with dustynv/l4t-pytorch:r36.4.0

NX with dustynv/l4t-pytorch:2.2-r35.4.1

Ubuntu 22.04, RTX 3090 with nvcr.io/nvidia/pytorch:25.01-py3

kevinch-nv added the triaged and Module:Accuracy labels on Mar 7, 2025
kevinch-nv (Collaborator) commented

Are the FP32 results good? Can you try exporting a model without simplifying it (i.e. remove the subprocess.run(["onnxsim", f"{model_path}/img_en_ori.onnx", f"{model_path}/img_en_ori.onnx"], check=True) line) and check whether the results are the same?

yijun02 (Author) commented Mar 10, 2025

Yes, the FP32 results are good.
I computed the cosine similarity between ONNX and TRT; the result is as follows.

[screenshot: cosine similarity results]

And the following shows the difference in the ONNX model between opset_version 16 (left) and 17 (right).

[screenshots: ONNX graphs exported with opset_version 16 (left) and 17 (right)]

I think dustynv/l4t-pytorch:2.2-r35.4.1 and nvcr.io/nvidia/pytorch:25.01-py3 have the same problem with the Layer Normalization op at opset_version 16.
I think dustynv/l4t-pytorch:r36.4.0 has the same problem as #4333, and it also shares the opset_version 16 layer-normalization problem of dustynv/l4t-pytorch:2.2-r35.4.1.
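
For context: at opset 16 the exporter decomposes LayerNorm into a chain of reduction and elementwise nodes, while opset 17 keeps it as a single LayerNormalization node, so re-exporting only needs the opset bumped (a sketch reusing the export code above; the output file name here is made up):

# Same export as above, just targeting opset 17 so LayerNorm stays one node.
torch.onnx.export(image_encoder,
                  (dummy_img,),
                  f"{model_path}/img_en_op17.onnx",  # hypothetical file name
                  export_params=True,
                  opset_version=17,
                  do_constant_folding=True,
                  input_names=['img'],
                  output_names=['image_feature'])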

lix19937 commented

Can you run /usr/src/tensorrt/bin/trtexec --onnx=model/img_en_ori.onnx --saveEngine=model/img_en_ori.engine --fp16 --verbose and upload the log here?

yijun02 (Author) commented Mar 21, 2025

The log of NX with opset_version 16:
nx_log.txt

The log of AGX with opset_version 16:
agx_log.txt

The log of RTX 3090 with opset_version 16:
RTX3090_log.txt

lix19937 commented

The logs show no problem. You can export the ONNX with opset 17 and build with --noTF32.
BTW, use polygraphy to profile which layer's output begins to diverge; for example, the command sketched below.
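
A sketch of such a per-layer comparison (the --fp16 flag is an assumption here, added so the TensorRT side is built in the same precision as the failing engine):

polygraphy run model/img_en_ori.onnx --trt --fp16 --onnxrt --onnx-outputs mark all --trt-outputs mark all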

yijun02 (Author) commented Mar 21, 2025

The log of AGX with opset_version 17; the cosine similarity is 0.6441.
agx_log_noTF32.txt

polygraphy run model/img_en_ori.onnx --trt --onnxrt --onnx-outputs mark all --trt-outputs mark all > comparison_results.txt
comparison_results.txt
