diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index 30e67344a..95f4e2ff8 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -102,6 +102,7 @@ def main(
     full_batch_size: Optional[int] = None,
     prompt_len: int = 32,
     ctx_len: int = 128,
+    comp_ctx_lengths: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
     mxfp6: bool = False,
     mxint8: bool = False,
@@ -163,6 +164,7 @@ def main(
         cache_dir=cache_dir,
         hf_token=hf_token,
         full_batch_size=full_batch_size,
+        comp_ctx_lengths=comp_ctx_lengths,
         local_model_dir=local_model_dir,
     )

@@ -257,6 +259,12 @@ def main(
         "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
     )
     parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
+    parser.add_argument(
+        "--comp-ctx-lengths",
+        "--comp_ctx_lengths",
+        type=lambda comp_ctx_lengths: [int(x) for x in comp_ctx_lengths.strip("[]").split(",")],
+        help="Compute-context lengths for text generation (comma-separated), e.g. [512,1024,2048]",
+    )
     parser.add_argument(
         "--mxfp6",
         "--mxfp6_matmul",
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index 570df0cf5..79761e02d 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Create a shape tensor based on comp_ctx_len
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Directly use the shape tensor without validation
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
     """

     @staticmethod
-    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+        return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py
index e4408829d..15b0847aa 100644
--- a/QEfficient/customop/ctx_scatter_gather_cb.py
+++ b/QEfficient/customop/ctx_scatter_gather_cb.py
@@ -97,11 +97,12 @@ def symbolic(

 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB(
-    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
 ) -> onnxscript.FLOAT:
     batch_size = ops.Gather(ops.Shape(batch_index), [0])
     num_heads = ops.Gather(ops.Shape(data), [1])
-    ctx_len = ops.Gather(ops.Shape(data), [2])
+    # Use the compute-context length (CCL) instead of the full context length, so the gather and the attention computation that follows both operate on CCL positions.
+    ctx_len = ops.Reshape(comp_ctx_len, [1])

     # Expanded shape to create indices
     zero = ops.Constant(value_ints=[0])
@@ -119,7 +120,7 @@ def CtxGatherCB(

 class CtxGatherFuncCB(torch.autograd.Function):
     @staticmethod
-    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +130,10 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+    def symbolic(
+        g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+    ) -> torch.Value:
+        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 2dd485a5e..02f11c91e 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -316,6 +316,7 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
+    comp_ctx_lengths: Optional[List[int]] = None,
     enable_debug_logs: bool = False,
     stream: bool = True,
     write_io_dir: Optional[str] = None,
@@ -368,6 +369,7 @@ def cloud_ai_100_exec_kv(
         qpc_path=qpc_path,
         device_id=device_id,
         ctx_len=ctx_len,
+        comp_ctx_lengths=comp_ctx_lengths,
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
@@ -407,12 +409,14 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         is_tlm: Optional[int] = None,
     ) -> None:
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths = comp_ctx_lengths
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm

@@ -724,6 +728,11 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
                 batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
                 inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

+        if self.comp_ctx_lengths is not None:
+            inputs["comp_ctx_lengths"] = np.random.rand(self.comp_ctx_lengths[0])
+            buffers = {"comp_ctx_len_out": np.zeros(1)}
+            self._session.set_buffers(buffers)
+
         for i in range(num_chunks):
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -741,6 +750,18 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
             generation_len,
         )

+    def initialize_ccl(self, decode_inputs):
+        max_ccl_id = len(self.comp_ctx_lengths) - 1
+        max_position_id = np.max(decode_inputs["position_ids"])
+        ccl_id = 1
+        for i in range(1, len(self.comp_ctx_lengths)):
+            if max_position_id < self.comp_ctx_lengths[i]:
+                ccl_id = i
+                break
+        buffers = {"comp_ctx_len_out": np.zeros(1)}
+
+        return buffers, ccl_id, max_ccl_id
+
     def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
         Runs continuous batching decode for the given prompt queue and generation length.
@@ -771,6 +792,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         # Prepare decode inputs inputs.
         decode_inputs = self.prepare_decode_inputs()
+        if self.comp_ctx_lengths is not None:
+            list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
+            buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+            self._session.set_buffers(buffers)
+
         while prompt_queue or current_decode_ongoing.any():
             outputs = self._session.run(decode_inputs)

@@ -808,6 +835,19 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                             batch_id_map[decode_batch_id]
                         ]
+                        if self.comp_ctx_lengths is not None:
+                            ### Recalculate ccl_id based on position_ids ###
+                            # Determine the maximum value of position_ids across all batch elements
+                            max_position_id = np.max(decode_inputs["position_ids"])
+
+                            # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                            ccl_id = 1
+                            for i in range(1, len(self.comp_ctx_lengths)):
+                                if max_position_id < self.comp_ctx_lengths[i]:
+                                    ccl_id = i
+                                    break
+                            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+
                     else:
                         current_decode_ongoing[decode_batch_id] = False
                 else:
@@ -818,6 +858,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                         next_token_id[decode_batch_id, -1]
                     )
+                    if self.comp_ctx_lengths is not None:
+                        # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                        if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1:
+                            ccl_id = min(ccl_id + 1, max_ccl_id)
+                            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+
                     generated_id_current_index[decode_batch_id] += 1

         return decode_pause_time
@@ -842,7 +888,21 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
         self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
+
+        if self.comp_ctx_lengths is not None:
+            list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
+            buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+            self._session.set_buffers(buffers)
+
+        cache_index = np.max(decode_inputs["position_ids"])
         for num_token in range(1, generation_len):
+            if self.comp_ctx_lengths is not None:
+                if cache_index >= self.comp_ctx_lengths[ccl_id] - 1:
+                    # Move to the next CCL bucket once the cache index reaches the current bucket's limit.
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
+
             if streamer:
                 streamer.put(decode_inputs["input_ids"][0])
             outputs = self._session.run(decode_inputs)
@@ -854,6 +914,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = outputs["logits"].argmax(2)
             decode_inputs["position_ids"][:, -1] += 1
+            cache_index += 1
             self.generated_ids[:, num_token] = 
decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id @@ -901,17 +962,27 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, is_tlm: bool = False, ) -> None: self._qaic_model = QEffTextGenerationBase( - tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm + tokenizer, + qpc_path, + full_batch_size, + ctx_len, + comp_ctx_lengths, + device_id, + enable_debug_logs, + write_io_dir, + is_tlm, ) self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer self._ctx_len = ctx_len + self.comp_ctx_lengths = comp_ctx_lengths self._perf_metrics = None self._prompt_queue = None self._text_streamer = None diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index f9d529038..4b1e243e5 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -91,6 +91,8 @@ def read_only(self, layer_idx, cache_kwargs): k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) + comp_ctx_len = cache_kwargs.get("CCL") + ctx_len = k_out.shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) @@ -101,15 +103,19 @@ def read_only(self, layer_idx, cache_kwargs): else: invalid_idx_value = 0 + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + + invalid_mask = invalid_mask[:, :, :comp_ctx_len] + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) - + k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -144,6 +150,7 @@ def update( else: position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs + comp_ctx_len = cache_kwargs.get("CCL") # Scatter if batch_index is not None: @@ -163,26 +170,29 @@ def update( self.value_cache[layer_idx], position_ids, value_states ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] - # Gather - ctx_len = k_out.shape[2] + ctx_len = self.key_cache[layer_idx].shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] 
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: invalid_idx_value = 0 + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + + invalid_mask = invalid_mask[:, :, :comp_ctx_len] + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index dae783361..6e2a563cd 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- +from dataclasses import dataclass from typing import Callable, List, Optional, Tuple, Union import torch @@ -29,6 +30,16 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +@dataclass +class QEffBaseModelOutputWithPast(BaseModelOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + +@dataclass +class QEffCausalLMOutputWithPast(CausalLMOutputWithPast): + comp_ctx_len_out: Optional[torch.LongTensor] = None + + class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -130,6 +141,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -154,8 +166,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -188,6 +208,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, 
batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -204,6 +225,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -241,6 +263,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -294,6 +317,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -315,11 +339,13 @@ def forward( if return_legacy_cache: past_key_values = past_key_values.to_legacy_cache() - output = BaseModelOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + output = QEffBaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, + comp_ctx_len_out=comp_ctx_len_out, ) return output if return_dict else output.to_tuple() @@ -337,6 +363,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -360,6 +387,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -377,10 +405,12 @@ def forward( logits = self.lm_head(hidden_states) logits = logits.float() - return CausalLMOutputWithPast( + comp_ctx_len_out = comp_ctx_lengths[comp_ctx_lengths.shape[-1] - 1 :] if comp_ctx_lengths is not None else None + return QEffCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + comp_ctx_len_out=comp_ctx_len_out, ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f181ee5eb..1f0f8541c 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -1326,6 +1326,7 @@ def __init__( self.continuous_batching = continuous_batching self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed + self.comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None) @property def model_name(self) -> str: @@ -1388,6 +1389,8 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) + comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None) + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: @@ -1404,6 +1407,7 @@ def from_pretrained( model, 
continuous_batching=continuous_batching, qaic_config=qaic_config, + comp_ctx_lengths=comp_ctx_lengths, **kwargs, ) @@ -1447,6 +1451,10 @@ def export(self, export_dir: Optional[str] = None) -> str: "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + if self.comp_ctx_lengths is not None: + example_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d pkv_dynamic_axes = { 0: "full_batch_size" if self.continuous_batching else "batch_size", @@ -1485,6 +1493,7 @@ def build_prefill_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -1495,6 +1504,9 @@ def build_prefill_specialization( "ctx_len": ctx_len, "num_logits_to_keep": 1 if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -1507,6 +1519,7 @@ def build_decode_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -1520,6 +1533,9 @@ def build_decode_specialization( "ctx_len": ctx_len, "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -1598,17 +1614,40 @@ def compile( specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: + ctx_for_specialization = self.comp_ctx_lengths[0] if self.comp_ctx_lengths is not None else None specializations.append( self.build_prefill_specialization( - prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size + prefill_seq_len, ctx_len, ctx_for_specialization, batch_size, kv_cache_batch_size, full_batch_size ) ) if prefill_only is None or not prefill_only: - decode_spec = self.build_decode_specialization( - prefill_seq_len, ctx_len, batch_size, kv_cache_batch_size, full_batch_size, num_speculative_tokens - ) - if decode_spec: - specializations.append(decode_spec) + if self.comp_ctx_lengths is not None: + # Adding elements from self.comp_ctx_lengths to decode_specialization + for i in range(1, len(self.comp_ctx_lengths)): + decode_spec = self.build_decode_specialization( + prefill_seq_len, + ctx_len, + self.comp_ctx_lengths[i], + batch_size, + kv_cache_batch_size, + full_batch_size, + num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) + + else: + decode_spec = self.build_decode_specialization( + prefill_seq_len, + ctx_len, + None, + batch_size, + kv_cache_batch_size, + full_batch_size, + num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) # --- Compilation --- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -1673,6 +1712,7 @@ def generate( tokenizer, self.qpc_path, prompt=prompts, + comp_ctx_lengths=self.comp_ctx_lengths, device_id=device_id, generation_len=generation_len, is_tlm=self.is_tlm, diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py new file mode 100644 index 000000000..7e4d1f304 --- /dev/null +++ 
b/examples/compute_context_length.py
@@ -0,0 +1,47 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+## This example runs a model with static or continuous batching using different Compute-Context-Length (CCL) inputs. ##
+
+from transformers import AutoTokenizer
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+## The optional comp_ctx_lengths argument takes a list of compute-context lengths; with comp_ctx_lengths=None the model runs with the default context length. ##
+## - The first value in the list is the compute-context length used during prefill. ##
+## - During decode, the compute-context length is picked from the list based on the position id / cache index: generation starts from the smallest value that covers the prompt and moves to the next value whenever the cache index passes the current one. ##
+comp_ctx_lengths = [256, 512, 1024]  # None
+
+model_name = "meta-llama/Llama-3.2-1B-Instruct"
+model = QEFFAutoModelForCausalLM.from_pretrained(
+    model_name, continuous_batching=True, comp_ctx_lengths=comp_ctx_lengths
+)
+# model = QEFFAutoModelForCausalLM.from_pretrained(model_name, comp_ctx_lengths=comp_ctx_lengths)
+
+# Compile the model for either continuous or static batching; continuous batching requires full_batch_size.
+model.compile(
+    prefill_seq_len=128,
+    ctx_len=1024,
+    num_cores=16,
+    num_devices=1,
+    full_batch_size=4,
+    mxfp6_matmul=True,
+    mxint8_kv_cache=True,
+)
+# model.compile(prefill_seq_len=128, ctx_len=1024, num_cores=16, num_devices=1,batch_size=4,mxfp6_matmul=True,mxint8_kv_cache=True)
+
+# Create the tokenizer and call model.generate with the input prompts. The comp_ctx_lengths list stored on the model is used during decode to pick the most efficient compute-context length.
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model.generate(
+    prompts=[
+        "What are some healthy foods to include in a balanced diet?",
+        "What is a nutritious meal that can keep you energized throughout the day?",
+        "What are some fun and relaxing activities to do over the weekend?",
+        "What's your favorite hobby?",
+    ],
+    tokenizer=tokenizer,
+)
diff --git a/tests/transformers/test_compute_context_length.py b/tests/transformers/test_compute_context_length.py
new file mode 100644
index 000000000..a68f68141
--- /dev/null
+++ b/tests/transformers/test_compute_context_length.py
@@ -0,0 +1,176 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import os +from time import perf_counter + +import onnx +import pytest +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +configs = [ + # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params + ("gpt2", 256, 2, 4, 128, 512, 127, {}), + ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("falcon", 256, 2, 4, 128, 512, 127, {}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), + ("phi", 256, 2, 4, 128, 512, 127, {}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] +config_ids = [x.model_type for x in configs] + +model_kwargs = {"attn_implementation": "eager"} + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +def test_causal_lm_unsupported(cb): + model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt")) + with pytest.warns(): + QEFFAutoModelForCausalLM(model, cb) + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_init(config, cb): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + with pytest.raises(TypeError): + QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_pretrained(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + model.save_pretrained(tmp_path) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_hash(config, cb): + hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + + assert hash_0_0 == hash_0_1 + + cfg1 = copy.deepcopy(config) + cfg1.num_hidden_layers -= 1 + hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), 
cb).model_hash + cfg2 = copy.deepcopy(config) + cfg2.num_hidden_layers -= 1 + hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash + assert hash_1_0 == hash_1_1 + + assert hash_0_0 != hash_1_0 + + if cb: + hash_0_no_cb = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), False + ).model_hash + assert hash_0_0 != hash_0_no_cb + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_export(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + comp_ctx_lengths = [512, 1024, 2048] + qeff_model.export(comp_ctx_lengths, tmp_path) + model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) + assert model_path.is_dir() + assert qeff_model.onnx_path.is_file() + assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + + # Check if the KV-cache inputs and outputs are created + onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False) + retained_output_names = { + x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") + } + retained_output_names.issubset({x.name for x in onnx_model.graph.input}) + + # Check if there is no re-export + start = perf_counter() + qeff_model.export(tmp_path) + end = perf_counter() + export_time = end - start + assert export_time < 2.0 + + +@pytest.fixture +def tmp_cache(tmp_path, monkeypatch): + monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path) + yield tmp_path + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_compile(config, cb, tmp_cache): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + comp_ctx_lengths = [8, 12, 16] + qeff_model = QEFFAutoModelForCausalLM(model, cb, comp_ctx_lengths=comp_ctx_lengths) + compile_params = {"prefill_seq_len": 8, "ctx_len": 16} + if cb: + compile_params["full_batch_size"] = 32 + compile_params["batch_size"] = 8 + qeff_model.compile(**compile_params) + model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash) + + # Check if ONNX is exported properly + assert model_path.is_dir() + assert qeff_model.onnx_path.is_file() + assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + + # Check if QPC is compiled properly + assert qeff_model.qpc_path.is_dir() + assert (qeff_model.qpc_path / "programqpc.bin").is_file() + assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash + + # Check if there is no re-compilation + start = perf_counter() + qeff_model.compile(**compile_params) + end = perf_counter() + compile_time = end - start + assert compile_time < 2.0 + assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))
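The cache_utils.py and custom-op changes above gather only the first comp_ctx_len positions of the KV cache, and modeling_llama.py slices the attention mask to the same width, so attention is computed over CCL keys instead of the full context. A minimal eager-PyTorch sketch of that gather (illustrative shapes and names, not the library API):

import torch


def gather_kv_window(k_cache: torch.Tensor, position_ids: torch.Tensor, comp_ctx_len: int) -> torch.Tensor:
    """Gather only the first comp_ctx_len positions of a (batch, heads, ctx_len, head_dim) cache."""
    ctx_indices = torch.arange(comp_ctx_len)[None, None, :]               # restrict the gather to CCL positions
    gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # highest position written so far
    invalid_mask = ctx_indices > gather_limit                             # positions not yet filled
    ctx_indices = torch.where(invalid_mask, torch.zeros_like(ctx_indices), ctx_indices)
    batch_idx = torch.arange(k_cache.shape[0]).view(-1, 1, 1)
    head_idx = torch.arange(k_cache.shape[1]).view(1, -1, 1)
    return k_cache[batch_idx, head_idx, ctx_indices]                      # (batch, heads, comp_ctx_len, head_dim)


k_cache = torch.randn(1, 4, 1024, 64)          # cache allocated for the full ctx_len of 1024
position_ids = torch.tensor([[300]])           # decode step at position 300
k_window = gather_kv_window(k_cache, position_ids, comp_ctx_len=512)
assert k_window.shape == (1, 4, 512, 64)       # attention now scores only 512 keys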
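During decode, initialize_ccl and the per-step updates in run_decode / run_continuous_batching_decode pick the smallest CCL bucket that covers the highest position id seen so far, then move to the next bucket once the cache index reaches the current bucket's limit. The same logic, extracted into a standalone sketch (function names are hypothetical; comp_ctx_lengths is assumed sorted ascending with its last entry equal to ctx_len):

from typing import List


def pick_initial_ccl_id(comp_ctx_lengths: List[int], max_position_id: int) -> int:
    """Mirror of initialize_ccl: first bucket whose length still covers the prompt."""
    ccl_id = 1
    for i in range(1, len(comp_ctx_lengths)):
        if max_position_id < comp_ctx_lengths[i]:
            ccl_id = i
            break
    return ccl_id


def advance_ccl_id(comp_ctx_lengths: List[int], ccl_id: int, cache_index: int) -> int:
    """Move to the next bucket once the cache index reaches the current bucket's limit."""
    max_ccl_id = len(comp_ctx_lengths) - 1
    if cache_index >= comp_ctx_lengths[ccl_id] - 1:
        ccl_id = min(ccl_id + 1, max_ccl_id)
    return ccl_id


ccl = [256, 512, 1024]
ccl_id = pick_initial_ccl_id(ccl, max_position_id=300)  # -> 1, i.e. decode starts in the 512 bucket
ccl_id = advance_ccl_id(ccl, ccl_id, cache_index=511)   # -> 2, switch to the 1024 bucket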
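At run time the generation loops above switch buckets simply by binding a comp_ctx_lengths placeholder whose length matches the chosen bucket; the diff only ever feeds zeros or random values, which suggests that only the shape is significant and that it selects the matching decode specialization. A condensed sketch of that selection (helper name and inputs are illustrative):

import numpy as np

comp_ctx_lengths = [256, 512, 1024]

# One placeholder array per bucket; its length is what distinguishes the buckets.
list_of_comp_ctx_lengths = [np.zeros(length) for length in comp_ctx_lengths]


def with_ccl_bucket(decode_inputs: dict, ccl_id: int) -> dict:
    """Return a copy of the decode inputs bound to the given CCL bucket."""
    decode_inputs = dict(decode_inputs)
    decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
    return decode_inputs


decode_inputs = {"input_ids": np.zeros((1, 1), dtype=np.int64), "position_ids": np.array([[511]])}
decode_inputs = with_ccl_bucket(decode_inputs, ccl_id=2)
assert decode_inputs["comp_ctx_lengths"].shape == (1024,)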
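On the compile side, modeling_auto.py keeps a single prefill specialization pinned to the first CCL value and emits one decode specialization per remaining CCL value. A rough sketch of the resulting specialization list for the simple non-continuous-batching case; the real build_prefill_specialization / build_decode_specialization also handle full_batch_size, num_logits_to_keep and related fields:

from typing import Dict, List, Optional, Union


def sketch_specializations(
    prefill_seq_len: int, ctx_len: int, comp_ctx_lengths: Optional[List[int]]
) -> List[Dict[str, Union[int, None]]]:
    if comp_ctx_lengths is None:
        # Without CCL: one prefill and one decode specialization, both at ctx_len.
        return [
            {"batch_size": 1, "seq_len": prefill_seq_len, "ctx_len": ctx_len},
            {"batch_size": 1, "seq_len": 1, "ctx_len": ctx_len},
        ]
    # With CCL: prefill is pinned to the first bucket, decode gets one entry per later bucket.
    specs = [
        {"batch_size": 1, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "comp_ctx_lengths": comp_ctx_lengths[0]}
    ]
    for ccl in comp_ctx_lengths[1:]:
        specs.append({"batch_size": 1, "seq_len": 1, "ctx_len": ctx_len, "comp_ctx_lengths": ccl})
    return specs


# sketch_specializations(128, 1024, [256, 512, 1024]) -> one prefill specialization at
# comp_ctx_lengths=256 plus decode specializations at comp_ctx_lengths=512 and 1024.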