
QAT model drops accuracy after converting with torch.ao.quantization.convert #2138


Open
tranngocduvnvp opened this issue Apr 28, 2025 · 3 comments

@tranngocduvnvp

Hello everyone.

I am implementing QAT for a YOLOv8 model with 4-bit weights and 8-bit activations by setting quant_min and quant_max in the QConfig. The model gives quite good results during training and eval, but after converting with torch.ao.quantization.convert the evaluation results are very bad. Does anyone know how to solve this problem?
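A minimal sketch of the flow I mean, assuming the standard eager-mode QAT API (`model` is the YOLOv8 module; the actual qconfig setup is posted further down in this thread):

```python
import torch
from torch.ao.quantization import prepare_qat, convert

model.train()
prepare_qat(model, inplace=True)   # insert observers + fake-quant modules
# ... QAT training loop: accuracy is good here ...
model.eval()
quantized = convert(model)         # swap to real int ops: accuracy drops here
```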

@supriyar supriyar added the qat label Apr 29, 2025
@supriyar
Contributor

cc @andrewor14

@andrewor14
Contributor

Hi @tranngocduvnvp, can you share your prepare and convert flow? Are you using the APIs documented here? https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md#quantize_-api-recommended. torchao QAT is not expected to work with torch.ao.quantization.convert from the pytorch/pytorch repo.
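For reference, the recommended flow from that README looks roughly like this (paraphrased; exact names depend on your torchao version):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
    from_intx_quantization_aware_training,
)

# Prepare: insert fake-quantize modules for low-bit QAT
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# ... run QAT fine-tuning ...

# Convert: remove the fake quantizers before applying real quantization
quantize_(model, from_intx_quantization_aware_training())
```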

@tranngocduvnvp
Author

tranngocduvnvp commented May 5, 2025

Hi @andrewor14, thanks for your feedback!

I found the cause of the accuracy drop at convert time. In my training loop, at the 3rd epoch I turned off FakeQuantize for some layers while leaving their observers enabled, so the observed scales kept changing and no longer matched the values used when converting the weights to int format.
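In case it helps others, a sketch of the fix: freeze the observers together with fake quant so the scales stop moving (`layers_to_freeze` is a hypothetical set of module names):

```python
import torch.ao.quantization as tq

# Disabling fake quant alone leaves the observers running, so the min/max
# statistics (and therefore the scale) keep drifting during training.
# Freeze both together for the layers you want to skip.
for name, module in model.named_modules():
    if name in layers_to_freeze:  # hypothetical: names of layers to freeze
        module.apply(tq.disable_fake_quant)
        module.apply(tq.disable_observer)
```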

But I have another question: my model still gives quite bad results during training itself. Another 4-bit YOLO QAT repo that uses the brevitas library (pip install brevitas) gives very good results. Can you suggest reasons for the accuracy drop when using the torch.ao library? My quantization configuration for the layers is as follows:

```python
import torch
from torch.ao.quantization import QConfig, FakeQuantize, MovingAverageMinMaxObserver

def config_quant(bit_width_act, bit_width_weight, asym=False):
    my_qconfig = QConfig(
        # Activations: symmetric qint8 range by default, affine quint8 if asym=True
        activation=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-(2 ** (bit_width_act - 1)),
            quant_max=2 ** (bit_width_act - 1) - 1,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,
            averaging_constant=0.01,
        ) if not asym else FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=2 ** bit_width_act - 1,
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,
            averaging_constant=0.01,
        ),
        # Weights: symmetric signed range clipped to bit_width_weight bits
        weight=FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-(2 ** (bit_width_weight - 1)),
            quant_max=2 ** (bit_width_weight - 1) - 1,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,
            # averaging_constant=0.01
        ),
    )
    return my_qconfig
```

```python
model.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)

for name, module in model.named_modules():
    print(name)

    if name == "model.net.p1":
        module.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)
    elif name == "model.net.p2.0":
        module.qconfig = config_quant(bit_width_act=8, bit_width_weight=8)
    elif name == "model.net.p2.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True)
    elif name == "model.net.p3.0":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p3.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True)
    elif name == "model.net.p4.0":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p4.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p5.0":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p5.1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.net.p5.2":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 
    elif name == "model.fpn.h1":
        module.qconfig = config_quant(bit_width_act=4, bit_width_weight=4, asym=True) 

```

Thank you so much!!
