AssertionError During Quantization of torch.empty_like, torch.ones_like, and torch.randn_like #2146

Open
defaultd661 opened this issue Apr 28, 2025 · 2 comments
Labels
pt2e_quant (pt2 export quantization), triaged

Comments

defaultd661 commented Apr 28, 2025

🐛 Describe the bug

Similar to #146621, when quantizing a model containing torch.empty_like(), torch.ones_like(), or torch.randn_like() operations with PT2E (prepare_pt2e), the process fails with an AssertionError inside _maybe_insert_input_observers_for_node. The root cause is that these ops carry kwargs in the exported graph, while the current code assumes that most aten ops (except a few explicitly listed ones) have no kwargs.
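
For reference, the offending kwargs can be seen directly on the exported graph. A minimal inspection sketch, using the EmptyLikeModule from the first repro below (the exact kwargs, e.g. pin_memory, may vary by PyTorch version):

import torch

# Print every call_function node that still carries kwargs after export;
# the aten *_like ops show up here, which is what trips the assertion.
ep = torch.export.export(EmptyLikeModule(), (torch.randn(10),))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function" and node.kwargs:
        print(node.target, node.kwargs)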

torch.empty_like

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec

class EmptyLikeModule(torch.nn.Module):

    def forward(self, t: torch.Tensor):
        return torch.empty_like(t)

class TestQuantizer(Quantizer):

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        qspec = QuantizationSpec(
            torch.int8, HistogramObserver, qscheme=torch.per_tensor_symmetric
        )
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
    
def test_bug():
    model = EmptyLikeModule()
    exported_model = torch.export.export(model, (torch.randn(10),))
    prepared_model = prepare_pt2e(exported_model.graph_module, TestQuantizer())

if __name__ == '__main__':
    test_bug()

torch.ones_like

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec

class OnesLikeModule(torch.nn.Module):

    def forward(self, t: torch.Tensor):
        return torch.ones_like(t)

class TestQuantizer(Quantizer):

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        qspec = QuantizationSpec(
            torch.int8, HistogramObserver, qscheme=torch.per_tensor_symmetric
        )
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

def test_bug():
    model = OnesLikeModule()
    exported_model = torch.export.export(model, (torch.randn(10),))
    prepared_model = prepare_pt2e(exported_model.graph_module, TestQuantizer())

if __name__ == '__main__':
    test_bug()

torch.randn_like

import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer, QuantizationSpec
from torch.ao.quantization.observer import HistogramObserver
from torch.ao.quantization.quantizer.utils import _annotate_input_qspec_map, _annotate_output_qspec

class RandnLikeModule(torch.nn.Module):

    def forward(self, t: torch.Tensor):
        return torch.randn_like(t)

class TestQuantizer(Quantizer):

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        qspec = QuantizationSpec(
            torch.int8, HistogramObserver, qscheme=torch.per_tensor_symmetric
        )
        for node in model.graph.nodes:
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass

def test_bug():
    model = RandnLikeModule()
    exported_model = torch.export.export(model, (torch.randn(10),))
    prepared_model = prepare_pt2e(exported_model.graph_module, TestQuantizer())

if __name__ == '__main__':
    test_bug()
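
Until the assumption is relaxed, a possible user-side workaround is to leave kwarg-carrying nodes unannotated so prepare_pt2e never has to insert observers for them. A hedged sketch (untested; assumes the *_like outputs do not need quantization):

    # Variant of TestQuantizer.annotate that skips ops with kwargs, e.g.
    # aten.empty_like/ones_like/randn_like, leaving them unannotated.
    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        qspec = QuantizationSpec(
            torch.int8, HistogramObserver, qscheme=torch.per_tensor_symmetric
        )
        for node in model.graph.nodes:
            if node.op == "call_function" and node.kwargs:
                continue  # skip kwarg-carrying aten ops
            for input_node in node.all_input_nodes:
                _annotate_input_qspec_map(node, input_node, qspec)
            _annotate_output_qspec(node, qspec)
        return model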

Versions

PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @msaroufim

@jerryzh168
Contributor

Thanks for reporting the issue. We are moving PT2E to torchao, so let's move the issue there.

@jerryzh168 jerryzh168 transferred this issue from pytorch/pytorch Apr 29, 2025
@jerryzh168 jerryzh168 added the pt2e_quant (pt2 export quantization) label Apr 29, 2025
@jerryzh168
Contributor

This looks like a limitation of the current code; we should be able to relax it, although it might take us some time to get there. Please feel free to contribute if you need it earlier!
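
For anyone picking this up, one plausible direction (a sketch of the idea only, not the actual torchao code) is to relax the no-kwargs assertion so it only rejects kwargs that are themselves graph nodes, since the *_like kwargs (dtype, layout, device, pin_memory, memory_format) are plain values that need no observers:

import torch

# Hypothetical helper: kwargs are harmless if none of them is an fx.Node,
# i.e. no kwarg is a tensor that could require an observer.
def _kwargs_need_no_observers(node: torch.fx.Node) -> bool:
    return not any(isinstance(v, torch.fx.Node) for v in node.kwargs.values())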
