From 8464ddf12cb915a3d5773049d4a7a30dc0f0bd9d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 29 Apr 2025 22:39:13 +0800 Subject: [PATCH] add __repr__ for better printing of configs. --- src/diffusers/quantizers/quantization_config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 0bc433be0ff3..76dc961d0d3b 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -428,6 +428,10 @@ def __init__(self, compute_dtype: Optional["torch.dtype"] = None): if self.compute_dtype is None: self.compute_dtype = torch.float32 + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + @dataclass class TorchAoConfig(QuantizationConfigMixin): @@ -722,3 +726,7 @@ def post_init(self): accepted_weights = ["float8", "int8", "int4", "int2"] if self.weights_dtype not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"