Skip to content

[Bug Fix]: qpc sdk config dump issue fixing #379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
May 8, 2025
124 changes: 73 additions & 51 deletions QEfficient/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,27 +521,57 @@ def __repr__(self):
def dump_qconfig(func):
def wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
create_and_dump_qconfigs(
self.qpc_path,
self.onnx_path,
self.get_model_config,
[cls.__name__ for cls in self._pytorch_transforms],
[cls.__name__ for cls in self._onnx_transforms],
kwargs.get("specializations"),
kwargs.get("mdp_ts_num_devices", 1),
kwargs.get("num_speculative_tokens"),
**{
k: v
for k, v in kwargs.items()
if k
not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"]
},
)
try:
create_and_dump_qconfigs(
self.qpc_path,
self.onnx_path,
self.get_model_config,
[cls.__name__ for cls in self._pytorch_transforms],
[cls.__name__ for cls in self._onnx_transforms],
kwargs.get("specializations"),
kwargs.get("mdp_ts_num_devices", 1),
kwargs.get("num_speculative_tokens"),
**{
k: v
for k, v in kwargs.items()
if k
not in ["specializations", "mdp_ts_num_devices", "num_speculative_tokens", "custom_io", "onnx_path"]
},
)
except Exception as e:
print(f"An unexpected error occurred while dumping the qconfig: {e}")
return result

return wrapper


def get_qaic_sdk_version(qaic_sdk_xml_path: str) -> Optional[str]:
"""
Extracts the QAIC SDK version from the given SDK XML file.

Args:
qaic_sdk_xml_path (str): Path to the SDK XML file.
Returns:
The SDK version as a string if found, otherwise None.
"""
qaic_sdk_version = None

# Check and extract version from the given SDK XML file
if os.path.exists(qaic_sdk_xml_path):
try:
tree = ET.parse(qaic_sdk_xml_path)
root = tree.getroot()
base_version_element = root.find(".//base_version")
if base_version_element is not None:
qaic_sdk_version = base_version_element.text
except ET.ParseError as e:
print(f"Error parsing XML file {qaic_sdk_xml_path}: {e}")
except Exception as e:
print(f"An unexpected error occurred while processing {qaic_sdk_xml_path}: {e}")

return qaic_sdk_version


def create_and_dump_qconfigs(
qpc_path,
onnx_path,
Expand All @@ -558,29 +588,12 @@ def create_and_dump_qconfigs(
Such as huggingface configs, QEff transforms, QAIC sdk version, QNN sdk, compilation dir, qpc dir and
many other compilation options.
"""
qnn_config = compiler_options["qnn_config"] if "qnn_config" in compiler_options else None
enable_qnn = True if "qnn_config" in compiler_options else None

enable_qnn = compiler_options.get("enable_qnn", False)
qnn_config_path = compiler_options.get("qnn_config", None)
qconfig_file_path = os.path.join(os.path.dirname(qpc_path), "qconfig.json")
onnx_path = str(onnx_path)
specializations_file_path = str(os.path.join(os.path.dirname(qpc_path), "specializations.json"))
compile_dir = str(os.path.dirname(qpc_path))
qnn_config_path = (
(qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json") if enable_qnn else None
)

# Extract QAIC SDK Apps Version from SDK XML file
tree = ET.parse(Constants.SDK_APPS_XML)
root = tree.getroot()
qaic_version = root.find(".//base_version").text

# Extract QNN SDK details from YAML file if the environment variable is set
qnn_sdk_details = None
qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
if enable_qnn and qnn_sdk_path:
qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
with open(qnn_sdk_yaml_path, "r") as file:
qnn_sdk_details = yaml.safe_load(file)

# Ensure all objects in the configs dictionary are JSON serializable
def make_serializable(obj):
Expand All @@ -602,29 +615,38 @@ def make_serializable(obj):
"onnx_transforms": make_serializable(onnx_transforms),
"onnx_path": onnx_path,
},
"compiler_config": {
"enable_qnn": enable_qnn,
"compile_dir": compile_dir,
"specializations_file_path": specializations_file_path,
"specializations": make_serializable(specializations),
"mdp_ts_num_devices": mdp_ts_num_devices,
"num_speculative_tokens": num_speculative_tokens,
**compiler_options,
},
"aic_sdk_config": {
"qaic_apps_version": get_qaic_sdk_version(Constants.SDK_APPS_XML),
"qaic_platform_version": get_qaic_sdk_version(Constants.SDK_PLATFORM_XML),
},
},
}

aic_compiler_config = {
"apps_sdk_version": qaic_version,
"compile_dir": compile_dir,
"specializations_file_path": specializations_file_path,
"specializations": make_serializable(specializations),
"mdp_ts_num_devices": mdp_ts_num_devices,
"num_speculative_tokens": num_speculative_tokens,
**compiler_options,
}
qnn_config = {
"enable_qnn": enable_qnn,
"qnn_config_path": qnn_config_path,
}
# Put AIC or qnn details.
if enable_qnn:
qnn_sdk_path = os.getenv(QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME)
if not qnn_sdk_path:
raise EnvironmentError(
f"QNN_SDK_PATH {qnn_sdk_path} is not set. Please set {QnnConstants.QNN_SDK_PATH_ENV_VAR_NAME}"
)
qnn_sdk_yaml_path = os.path.join(qnn_sdk_path, QnnConstants.QNN_SDK_YAML)
qnn_sdk_details = load_yaml(
qnn_sdk_yaml_path
) # Extract QNN SDK details from YAML file if the environment variable is set
qnn_config = {
"qnn_config_path": qnn_config_path,
}
qconfigs["qpc_config"]["qnn_config"] = qnn_config
if qnn_sdk_details:
qconfigs["qpc_config"]["qnn_config"].update(qnn_sdk_details)
else:
qconfigs["qpc_config"]["aic_compiler_config"] = aic_compiler_config

create_json(qconfig_file_path, qconfigs)

Expand Down
5 changes: 4 additions & 1 deletion QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ class Constants:
MAX_QPC_LIMIT = 30
MAX_RETRIES = 10 # This constant will be used set the maximum number of retry attempts for downloading a model using huggingface_hub snapshot_download
NUM_SPECULATIVE_TOKENS = 2
SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK version.
SDK_APPS_XML = "/opt/qti-aic/versions/apps.xml" # This xml file is parsed to find out the SDK apps version.
SDK_PLATFORM_XML = (
"/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version.
)


@dataclass
Expand Down
Loading