
Conversation

@oreomaker oreomaker (Collaborator) commented Jan 20, 2026

Please check Guidelines for Contributing.

Summary by CodeRabbit

  • New Features

    • Added context persistence functionality for model state saving
    • Added greedy token sampling method for text generation
  • Refactor

    • Consolidated configuration structures across runtime components
    • Enhanced tensor input/output preparation and binding in model execution
    • Streamlined model module initialization process


- Add saveContext method to QNNBackend for saving context binary to file
- Implement proper output tensor validation and allocation in graphExecute
- Remove redundant output reordering logic that was causing issues
- Add tensor caching and management improvements in QnnAOTGraph
- Enhance QnnAOTEnv to properly track and retrieve tensors
- Add sub-graph input/output tensor capture in LLM2QnnLoweringPass
- Remove duplicate allocation warning in QNNTensorWrapper::alloc
- Remove redundant temperature parameter from example application
- Replace custom config structs with unified QnnAOTConfig
- Move initialization of QNN backend after argument parsing
- Simplify module construction by removing model path dependency
- Add sampleGreedy method to QnnAOTModule for token sampling
- Update tensor shapes and I/O handling for proper cache management
- Remove unused includes and commented code for cleaner implementation
fix(qnn-aot): add position ID handling in PromptProcessor
@coderabbitai coderabbitai bot (Contributor) commented Jan 20, 2026

📝 Walkthrough

This PR refactors the QNN AOT runtime infrastructure to consolidate configuration management and streamline data flow APIs. It introduces a unified QnnAOTConfig structure replacing scattered per-component config types, refactors input/output handling in the QNN backend to use wrapper-based allocation, adds context persistence capabilities, updates the Runner::generate API to accept Tensor directly, and captures subgraph I/O during lowering.
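To make the new call contract concrete, here is a caller-side sketch (the names runner and input_ids are illustrative, not taken from the PR; the signature matches Runner::generate as reviewed below):

  // input_ids: rank-2 kInt64 Tensor of shape [1, prompt_len], produced by the tokenizer.
  runner.generate(input_ids, /*seq_len=*/128,
                  [](const std::string& piece) { std::cout << piece << std::flush; });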

Changes

Cohort / File(s) Summary
CLI Example
examples/qwen3_qnn_aot/aot_run.cpp
Removed temperature CLI option; moved backend initialization after help check; directly pass tensor sequence to runner.generate instead of constructing prompt tokens manually
QNN Backend Core
mllm/backends/qnn/QNNBackend.cpp, mllm/backends/qnn/QNNBackend.hpp
Added saveContext() method to persist QNN context binary to file; refactored graphExecute input/output handling to use wrapper-based allocation with nil-checks instead of in-place direct handling; removed output reordering logic
Configuration Consolidation
mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp, mllm/backends/qnn/aot_rt/KVCacheManager.*, mllm/backends/qnn/aot_rt/PromptProcessor.*, mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp, mllm/backends/qnn/aot_rt/TokenGenerator.*
Introduced new QnnAOTConfig struct with consolidated model hyperparameters; replaced scattered Config/KVCacheConfig types across KVCacheManager, PromptProcessor, TokenGenerator, and QnnAOTRuntime with unified QnnAOTConfig; changed RunnerConfig to alias of QnnAOTConfig
AOT Module & Runtime
mllm/backends/qnn/aot_rt/QnnAOTModule.*, mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp
Removed model_path parameter from QnnAOTModule constructor; added sampleGreedy() method for greedy sampling; updated Runner::generate to accept Tensor instead of vector<uint64_t>&; refactored token processing to read from tensor data; updated KV cache tensor shapes and indexing
PromptProcessor & TokenGenerator
mllm/backends/qnn/aot_rt/PromptProcessor.cpp, mllm/backends/qnn/aot_rt/TokenGenerator.cpp
Updated constructors to accept QnnAOTConfig; refactored QnnAOTModule instantiation to drop model_path; adjusted KV cache tensor shapes and indexing; updated prepare_io signature to include prompt_pos parameter; changed logits extraction and sampling logic
AOT Graph Management
mllm/backends/qnn/aot/QnnWrappersAPI.cpp, mllm/backends/qnn/aot/QnnWrappersAPI.hpp
Added QnnAOTGraph::addTensor() method for tensor registration; refactored captureQnnAOTNodeTensor() to cache tensors in graph's all_tensors_ map and register via addTensor() before returning
Lowering Pass
mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp
Added capture of subgraph inputs/outputs by iterating region I/O and recording in AOT environment before operation traversal
Utilities
mllm/backends/qnn/QNNUtils.cpp, mllm/mllm.inl
Removed runtime warning for repeated tensor allocations; added kUInt16 dtype case to tensor value printing logic

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • chenghuaWang
  • liang1232018

🐰 Hop, hop! A config rebrand so grand,
Tensors now flow through AOT's land,
Context is saved with a context so bright,
Wrappers and caches aligned just right!

🚥 Pre-merge checks | ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
  • Description check (⚠️ Warning): The PR description is entirely empty except for the repository contribution template boilerplate, with no actual content describing the changes, motivation, or impact. Resolution: provide a detailed description covering the main objectives (QNN AOT features, config refactoring), key changes across the modified files, and any testing or breaking changes introduced.
  • Docstring coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (1 passed)
  • Title check (✅ Passed): The title 'feat(qualcomm): Qnn AOT Runtime' accurately summarizes the main change: introduction of QNN AOT runtime functionality across multiple backend files and supporting utilities.




@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
mllm/backends/qnn/aot_rt/TokenGenerator.cpp (1)

36-55: Validate context_len before using context_len - 1 for V-cache shape.

If context_len is 1 (or misconfigured), the V-cache dimension becomes 0/negative, which may be invalid for QNN tensors. Consider asserting context_len > 1 or handling the edge case.
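A minimal sketch of such a guard, placed before the V-cache shape is computed (MLLM_RT_ASSERT is the assertion macro used elsewhere in this codebase; config_.context_len is assumed to be the unified QnnAOTConfig field):

  // Fail fast: the V-cache dimension is context_len - 1, so context_len must exceed 1.
  MLLM_RT_ASSERT(config_.context_len > 1);
  const int v_cache_len = config_.context_len - 1;  // guaranteed >= 1 after the assert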

mllm/backends/qnn/aot_rt/PromptProcessor.cpp (1)

42-62: Guard against zero-length past-cache tensors when ar_len == context_len.

context_len - ar_len becomes 0 in that case; ensure the backend supports zero-length dimensions or skip creating the past-cache tensors.
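A minimal skip-path sketch, assuming the past-cache length is config_.context_len - config_.ar_len (field names from QnnAOTConfig; where the tensors are actually created is left as a placeholder comment):

  const int past_len = config_.context_len - config_.ar_len;
  if (past_len > 0) {
    // Create the past-K/V cache tensors with past_len as their sequence dimension.
  } else {
    // ar_len == context_len: nothing to cache from previous chunks, so skip these tensors.
  }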

🤖 Fix all issues with AI agents
In `@mllm/backends/qnn/aot_rt/QnnAOTModule.cpp`:
- Around line 17-22: The sampleGreedy implementation assumes logits are uint16_t
which will misread other dtypes; update QnnAOTModule::sampleGreedy to check the
tensor dtype (via logits.dtype() or equivalent) and either assert/throw if
unsupported or dispatch per dtype (handle at least float32, float16 and uint16),
reading the correct element type from logits.ptr<T>() and using std::max_element
on that typed pointer; ensure you reference QnnAOTModule::sampleGreedy, the
logits parameter, logits.ptr<>, and logits.shape().back() when locating and
modifying the code.

In `@mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp`:
- Around line 57-78: The generate() implementation currently stops after prefill
and ignores seq_len because the decode call was commented out; restore the
decode path by un-commenting and using token_generator_->generate with the
Tensor input and correct current position: compute int64_t cur_pos =
prompt_tokens.shape()[1] (or start_pos + prompt length) after prefill, then call
token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback,
false) so the generator consumes the Tensor input and honors seq_len; keep
existing uses of prompt_processor_->prefill and tokenizer_->detokenize and
ensure token_callback is forwarded to token_generator_->generate.
- Around line 59-66: The current assertion before accessing prompt_tokens raw
data doesn't validate the batch dimension, so access via
prompt_tokens.ptr<int64_t>()[i] assumes a single batch; update the checks to
ensure prompt_tokens.shape()[0] == 1 as well as rank==2 and dtype==kInt64
(either by extending the existing MLLM_RT_ASSERT or adding an additional assert)
so the subsequent loop over prompt_tokens.shape()[1] is safe; reference the
existing MLLM_RT_ASSERT and prompt_tokens.shape() usage around where
prompt_tokens_i64 is filled.

In `@mllm/backends/qnn/QNNBackend.cpp`:
- Around line 440-457: QNNBackend::saveContext must perform and act on error
returns and validate file open/size before writing: check the return value of
runtime_->qnnInterface.contextGetBinarySize(context_, &binarySize) and bail/log
on failure; after allocating binaryBuffer, check the return of
runtime_->qnnInterface.contextGetBinary(context_, ..., &writtenSize) and
bail/log on failure; if writtenSize != binarySize treat it as an error and do
not proceed to write the buffer; verify std::ofstream file(contextPath,
std::ios::binary).is_open() before file.write and handle/log/return on failure;
ensure all early exits free resources and log via MLLM_ERROR/MLLM_INFO as appropriate.
🧹 Nitpick comments (9)
mllm/backends/qnn/aot/passes/LLM2QnnLoweringPass.cpp (1)

149-158: Consider logging when non-tensor inputs/outputs are encountered.

The silent skip when cast_<ir::tensor::TensorValue>() returns nullptr could hide unexpected input/output types. If all subgraph I/O is expected to be tensor values, consider adding a debug log or warning when the cast fails for traceability.

♻️ Optional: Add debug logging for skipped non-tensor values
     // Add sub-graph inputs
     for (auto& input : region->inputs()) {
       auto tensor_input = input->cast_<ir::tensor::TensorValue>();
-      if (tensor_input) { aot_env->captureQnnAOTNodeTensor("context.0", subgraph_name, tensor_input); }
+      if (tensor_input) {
+        aot_env->captureQnnAOTNodeTensor("context.0", subgraph_name, tensor_input);
+      } else {
+        MLLM_DEBUG("Skipped non-tensor input in subgraph: {}", subgraph_name);
+      }
     }
     // Add sub-graph outputs
     for (auto& output : region->outputs()) {
       auto tensor_output = output->cast_<ir::tensor::TensorValue>();
-      if (tensor_output) { aot_env->captureQnnAOTNodeTensor("context.0", subgraph_name, tensor_output); }
+      if (tensor_output) {
+        aot_env->captureQnnAOTNodeTensor("context.0", subgraph_name, tensor_output);
+      } else {
+        MLLM_DEBUG("Skipped non-tensor output in subgraph: {}", subgraph_name);
+      }
     }
mllm/backends/qnn/QNNBackend.hpp (1)

92-92: Consider using the QNN_Context_File constant for the default parameter.

The constant QNN_Context_File is defined at line 19 with the same value. Using the constant would improve maintainability and ensure consistency if the default filename ever changes.

♻️ Proposed fix
-  void saveContext(const std::string& contextPath = "qnn_context.bin");
+  void saveContext(const std::string& contextPath = QNN_Context_File);
mllm/backends/qnn/aot_rt/KVCacheManager.hpp (1)

26-26: Consider passing QnnAOTConfig by const reference.

Since QnnAOTConfig is a struct with multiple members, passing by const reference avoids an unnecessary copy during construction.

♻️ Proposed fix
-  explicit KVCacheManager(QnnAOTConfig config);
+  explicit KVCacheManager(const QnnAOTConfig& config);
mllm/backends/qnn/aot_rt/QnnAOTConfig.hpp (1)

10-26: Consider documenting model-specific default values.

The default values appear to be specific to a particular model (e.g., vocab_size = 151936 suggests Qwen). Consider adding a comment indicating which model these defaults target, or using named constants to make the configuration clearer.

♻️ Optional: Add documentation for defaults
 struct QnnAOTConfig {
+  // Default values are configured for Qwen models
   int num_layers = 28;
   int num_heads = 12;
   int head_dim = 128;
-  int vocab_size = 151936;
+  int vocab_size = 151936;  // Qwen vocab size

   int context_len = 4096;
   int ar_len = 128;  // Chunk size for prefill
   int sliding_window = 0;

-  // Derived/Computed
+  // Runtime limits (should be set based on context_len and ar_len)
   int max_ar_len = 128;
   int max_cache_len = 4096;
mllm/backends/qnn/aot_rt/TokenGenerator.cpp (2)

9-13: Avoid extra QnnAOTConfig copy in the constructor.

config is passed by value, so config_(config) performs an extra copy. Prefer moving it into the member to keep construction cheaper.

♻️ Proposed change
-    : tokenizer_(tokenizer), kv_manager_(kv_manager), eos_ids_(std::move(eos_ids)), config_(config) {
+    : tokenizer_(tokenizer), kv_manager_(kv_manager), eos_ids_(std::move(eos_ids)), config_(std::move(config)) {

117-120: Avoid per-token vector copies when invoking the module.

auto module_input = input_tensors_ creates an extra copy each decode step. If QnnAOTModule::operator() can accept a const ref, pass input_tensors_ directly; otherwise consider moving.

♻️ Proposed change
-    auto module_input = input_tensors_;
-    output_tensors_ = (*module_)(module_input);
+    output_tensors_ = (*module_)(input_tensors_);
mllm/backends/qnn/aot_rt/QnnAOTRuntime.hpp (1)

17-25: Document the Tensor input contract for Runner::generate.

generate now accepts a Tensor, but the required rank/dtype/batch shape isn’t stated in the header. Please add a brief doc comment to prevent misuse. As per coding guidelines, please document the expected rank (2), dtype (kInt64), and batch size (1).
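For example, a brief Doxygen-style comment over the declaration could state the contract (sketch only; the signature mirrors the one shown in this PR):

  /// Generates up to seq_len tokens for a single prompt.
  /// \param prompt_tokens  rank-2 tensor of shape [1, prompt_len] (batch size 1), dtype kInt64.
  /// \param seq_len        maximum number of tokens to generate.
  /// \param token_callback invoked with the detokenized text of each emitted token.
  void generate(const Tensor& prompt_tokens, int32_t seq_len,
                const std::function<void(const std::string&)>& token_callback);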

mllm/backends/qnn/aot_rt/PromptProcessor.cpp (2)

122-125: Consider gating duplicate attention-mask init/update calls.

Both overloads are invoked back-to-back. If the sliding-window variant supersedes the non-window one, you can avoid the redundant call (or conditionally call one based on sliding_window) to save work.

Also applies to: 142-144
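A possible gating sketch, assuming the sliding-window overload should take precedence whenever config_.sliding_window > 0 (the call name init_attention_mask below is a placeholder; substitute the actual init/update overloads used in PromptProcessor.cpp):

  if (config_.sliding_window > 0) {
    init_attention_mask(config_.ar_len, config_.sliding_window);  // sliding-window variant only
  } else {
    init_attention_mask(config_.ar_len);  // plain variant only
  }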


135-136: Avoid per-chunk vector copies when invoking the module.

If QnnAOTModule::operator() can accept a const ref, pass input_tensors_ directly to avoid the extra copy per chunk.

♻️ Proposed change
-    auto module_input = input_tensors_;
-    output_tensors_ = (*module_)(module_input);
+    output_tensors_ = (*module_)(input_tensors_);

Comment on lines +17 to +22
int64_t QnnAOTModule::sampleGreedy(mllm::Tensor& logits) {
  auto logits_data = logits.ptr<uint16_t>();
  int vocab_size = logits.shape().back();
  auto max_it = std::max_element(logits_data, logits_data + vocab_size);
  return std::distance(logits_data, max_it);
}

⚠️ Potential issue | 🟠 Major

Hardcoded uint16_t type assumption may cause incorrect sampling results.

The sampleGreedy method unconditionally casts logits data to uint16_t*. If the logits tensor has a different dtype (e.g., float32, float16), this will produce incorrect sampling results without any warning.

Consider either:

  1. Adding a dtype assertion/check before the cast
  2. Dispatching based on the tensor's actual dtype
Suggested fix with dtype validation
 int64_t QnnAOTModule::sampleGreedy(mllm::Tensor& logits) {
+  MLLM_RT_ASSERT_EQ(logits.dtype(), kUInt16);
   auto logits_data = logits.ptr<uint16_t>();
   int vocab_size = logits.shape().back();
   auto max_it = std::max_element(logits_data, logits_data + vocab_size);
   return std::distance(logits_data, max_it);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (original, then suggested):

int64_t QnnAOTModule::sampleGreedy(mllm::Tensor& logits) {
  auto logits_data = logits.ptr<uint16_t>();
  int vocab_size = logits.shape().back();
  auto max_it = std::max_element(logits_data, logits_data + vocab_size);
  return std::distance(logits_data, max_it);
}

int64_t QnnAOTModule::sampleGreedy(mllm::Tensor& logits) {
  MLLM_RT_ASSERT_EQ(logits.dtype(), kUInt16);
  auto logits_data = logits.ptr<uint16_t>();
  int vocab_size = logits.shape().back();
  auto max_it = std::max_element(logits_data, logits_data + vocab_size);
  return std::distance(logits_data, max_it);
}
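If the second option (dispatching on the actual dtype) is preferred, a minimal sketch could look like the following (kFp32 is an assumed enumerator name for float32; the float16 case is omitted because the element type this codebase uses for fp16 is not visible here):

  int64_t QnnAOTModule::sampleGreedy(mllm::Tensor& logits) {
    int vocab_size = logits.shape().back();
    switch (logits.dtype()) {
      case kUInt16: {  // quantized logits, matching the current code path
        auto* p = logits.ptr<uint16_t>();
        return std::distance(p, std::max_element(p, p + vocab_size));
      }
      case kFp32: {  // assumed enumerator name for float32
        auto* p = logits.ptr<float>();
        return std::distance(p, std::max_element(p, p + vocab_size));
      }
      default:
        MLLM_ERROR("sampleGreedy: unsupported logits dtype");
        return -1;
    }
  }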
🤖 Prompt for AI Agents
In `@mllm/backends/qnn/aot_rt/QnnAOTModule.cpp` around lines 17 - 22, The
sampleGreedy implementation assumes logits are uint16_t which will misread other
dtypes; update QnnAOTModule::sampleGreedy to check the tensor dtype (via
logits.dtype() or equivalent) and either assert/throw if unsupported or dispatch
per dtype (handle at least float32, float16 and uint16), reading the correct
element type from logits.ptr<T>() and using std::max_element on that typed
pointer; ensure you reference QnnAOTModule::sampleGreedy, the logits parameter,
logits.ptr<>, and logits.shape().back() when locating and modifying the code.

Comment on lines +57 to 78
void Runner::generate(const Tensor& prompt_tokens, int32_t seq_len,
                      const std::function<void(const std::string&)>& token_callback) {
  MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64);

  int64_t start_pos = 0;

  std::vector<int64_t> prompt_tokens_i64;
  prompt_tokens_i64.reserve(prompt_tokens.size());
  for (auto t : prompt_tokens) prompt_tokens_i64.push_back((int64_t)t);
  prompt_tokens_i64.reserve(prompt_tokens.shape()[1]);
  for (int i = 0; i < prompt_tokens.shape()[1]; i++) { prompt_tokens_i64.push_back(prompt_tokens.ptr<int64_t>()[i]); }

  int64_t next_token = prompt_processor_->prefill(prompt_tokens_i64, start_pos);

  prompt_tokens.push_back((uint64_t)next_token);
  if (token_callback) {
    std::wstring wstr = tokenizer_->detokenize(next_token);
    std::string str = mllm::preprocessor::wideString2Utf8String(wstr);
    token_callback(str);
  }

  int64_t cur_pos = prompt_tokens.size();
  // int64_t cur_pos = prompt_tokens.size();

  token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback, false);
  // token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback, false);
}

⚠️ Potential issue | 🟠 Major

Generation stops after prefill; seq_len is ignored.

The decode phase is commented out, so generate() emits only the first sampled token and never uses seq_len. Please restore the decode path with the new Tensor input.

🐛 Proposed fix
   int64_t next_token = prompt_processor_->prefill(prompt_tokens_i64, start_pos);

   if (token_callback) {
     std::wstring wstr = tokenizer_->detokenize(next_token);
     std::string str = mllm::preprocessor::wideString2Utf8String(wstr);
     token_callback(str);
   }
 
-  // int64_t cur_pos = prompt_tokens.size();
-  // token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback, false);
+  std::vector<uint64_t> tokens;
+  tokens.reserve(prompt_tokens_i64.size() + 1);
+  for (auto t : prompt_tokens_i64) { tokens.push_back(static_cast<uint64_t>(t)); }
+  tokens.push_back(static_cast<uint64_t>(next_token));
+
+  int64_t cur_pos = static_cast<int64_t>(prompt_tokens_i64.size());
+  token_generator_->generate(tokens, cur_pos, seq_len, token_callback, false);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (original, then suggested):

void Runner::generate(const Tensor& prompt_tokens, int32_t seq_len,
                      const std::function<void(const std::string&)>& token_callback) {
  MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64);
  int64_t start_pos = 0;
  std::vector<int64_t> prompt_tokens_i64;
  prompt_tokens_i64.reserve(prompt_tokens.size());
  for (auto t : prompt_tokens) prompt_tokens_i64.push_back((int64_t)t);
  prompt_tokens_i64.reserve(prompt_tokens.shape()[1]);
  for (int i = 0; i < prompt_tokens.shape()[1]; i++) { prompt_tokens_i64.push_back(prompt_tokens.ptr<int64_t>()[i]); }
  int64_t next_token = prompt_processor_->prefill(prompt_tokens_i64, start_pos);
  prompt_tokens.push_back((uint64_t)next_token);
  if (token_callback) {
    std::wstring wstr = tokenizer_->detokenize(next_token);
    std::string str = mllm::preprocessor::wideString2Utf8String(wstr);
    token_callback(str);
  }
  int64_t cur_pos = prompt_tokens.size();
  // int64_t cur_pos = prompt_tokens.size();
  token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback, false);
  // token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback, false);
}

void Runner::generate(const Tensor& prompt_tokens, int32_t seq_len,
                      const std::function<void(const std::string&)>& token_callback) {
  MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64);
  int64_t start_pos = 0;
  std::vector<int64_t> prompt_tokens_i64;
  prompt_tokens_i64.reserve(prompt_tokens.shape()[1]);
  for (int i = 0; i < prompt_tokens.shape()[1]; i++) { prompt_tokens_i64.push_back(prompt_tokens.ptr<int64_t>()[i]); }
  int64_t next_token = prompt_processor_->prefill(prompt_tokens_i64, start_pos);
  if (token_callback) {
    std::wstring wstr = tokenizer_->detokenize(next_token);
    std::string str = mllm::preprocessor::wideString2Utf8String(wstr);
    token_callback(str);
  }
  std::vector<uint64_t> tokens;
  tokens.reserve(prompt_tokens_i64.size() + 1);
  for (auto t : prompt_tokens_i64) { tokens.push_back(static_cast<uint64_t>(t)); }
  tokens.push_back(static_cast<uint64_t>(next_token));
  int64_t cur_pos = static_cast<int64_t>(prompt_tokens_i64.size());
  token_generator_->generate(tokens, cur_pos, seq_len, token_callback, false);
}
🤖 Prompt for AI Agents
In `@mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp` around lines 57 - 78, The
generate() implementation currently stops after prefill and ignores seq_len
because the decode call was commented out; restore the decode path by
un-commenting and using token_generator_->generate with the Tensor input and
correct current position: compute int64_t cur_pos = prompt_tokens.shape()[1] (or
start_pos + prompt length) after prefill, then call
token_generator_->generate(prompt_tokens, cur_pos, seq_len, token_callback,
false) so the generator consumes the Tensor input and honors seq_len; keep
existing uses of prompt_processor_->prefill and tokenizer_->detokenize and
ensure token_callback is forwarded to token_generator_->generate.

Comment on lines +59 to 66
MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64);

int64_t start_pos = 0;

std::vector<int64_t> prompt_tokens_i64;
prompt_tokens_i64.reserve(prompt_tokens.size());
for (auto t : prompt_tokens) prompt_tokens_i64.push_back((int64_t)t);
prompt_tokens_i64.reserve(prompt_tokens.shape()[1]);
for (int i = 0; i < prompt_tokens.shape()[1]; i++) { prompt_tokens_i64.push_back(prompt_tokens.ptr<int64_t>()[i]); }


⚠️ Potential issue | 🟡 Minor

Add batch dimension validation to assert before tensor data access.

The loop reads tensor elements using ptr<int64_t>()[i] for i < shape()[1], which assumes the tensor is shaped [1, N]. Without validating shape()[0] == 1, a larger batch dimension would cause incorrect data access. Add this validation to the existing assertion:

Proposed fix
-  MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64);
+  MLLM_RT_ASSERT(prompt_tokens.rank() == 2 && prompt_tokens.dtype() == kInt64 && prompt_tokens.shape()[0] == 1);
🤖 Prompt for AI Agents
In `@mllm/backends/qnn/aot_rt/QnnAOTRuntime.cpp` around lines 59 - 66, The current
assertion before accessing prompt_tokens raw data doesn't validate the batch
dimension, so access via prompt_tokens.ptr<int64_t>()[i] assumes a single batch;
update the checks to ensure prompt_tokens.shape()[0] == 1 as well as rank==2 and
dtype==kInt64 (either by extending the existing MLLM_RT_ASSERT or adding an
additional assert) so the subsequent loop over prompt_tokens.shape()[1] is safe;
reference the existing MLLM_RT_ASSERT and prompt_tokens.shape() usage around
where prompt_tokens_i64 is filled.

Comment on lines +440 to +457
void QNNBackend::saveContext(const std::string& contextPath) {
  uint64_t binarySize, writtenSize;

  runtime_->qnnInterface.contextGetBinarySize(context_, &binarySize);

  std::unique_ptr<uint8_t[]> binaryBuffer(new uint8_t[binarySize]);

  runtime_->qnnInterface.contextGetBinary(context_, reinterpret_cast<void*>(binaryBuffer.get()), binarySize, &writtenSize);

  if (binarySize < writtenSize) {
    MLLM_ERROR("QNN context binary size mismatch. Written {} bytes, expected {} bytes.", writtenSize, binarySize);
  }
  std::ofstream file(contextPath, std::ios::binary);
  file.write(reinterpret_cast<char*>(binaryBuffer.get()), writtenSize);
  file.close();

  MLLM_INFO("QNN context saved to {} written {} bytes.", contextPath, writtenSize);
}

⚠️ Potential issue | 🟠 Major

Missing error handling in saveContext - may silently fail or write corrupt data.

Several issues compared to the similar QnnAOTEnv::saveContext implementation in QnnWrappersAPI.cpp:

  1. No error check on contextGetBinarySize return value (line 443)
  2. No error check on contextGetBinary return value (line 447)
  3. Size mismatch at line 449 logs an error but continues writing potentially corrupt data
  4. No check if file opened successfully before writing (line 452)
Suggested fix with proper error handling
 void QNNBackend::saveContext(const std::string& contextPath) {
   uint64_t binarySize, writtenSize;

-  runtime_->qnnInterface.contextGetBinarySize(context_, &binarySize);
+  auto status = runtime_->qnnInterface.contextGetBinarySize(context_, &binarySize);
+  if (status != QNN_SUCCESS) {
+    MLLM_ERROR("Failed to get QNN context binary size.");
+    return;
+  }

   std::unique_ptr<uint8_t[]> binaryBuffer(new uint8_t[binarySize]);

-  runtime_->qnnInterface.contextGetBinary(context_, reinterpret_cast<void*>(binaryBuffer.get()), binarySize, &writtenSize);
+  status = runtime_->qnnInterface.contextGetBinary(context_, reinterpret_cast<void*>(binaryBuffer.get()), binarySize, &writtenSize);
+  if (status != QNN_SUCCESS) {
+    MLLM_ERROR("Failed to get QNN context binary.");
+    return;
+  }

   if (binarySize < writtenSize) {
     MLLM_ERROR("QNN context binary size mismatch. Written {}  bytes, expected {} bytes.", writtenSize, binarySize);
+    return;
   }
+
   std::ofstream file(contextPath, std::ios::binary);
+  if (!file.is_open()) {
+    MLLM_ERROR("Failed to open file {} for writing QNN context.", contextPath);
+    return;
+  }
   file.write(reinterpret_cast<char*>(binaryBuffer.get()), writtenSize);
   file.close();

   MLLM_INFO("QNN context saved to {} written {} bytes.", contextPath, writtenSize);
 }
🤖 Prompt for AI Agents
In `@mllm/backends/qnn/QNNBackend.cpp` around lines 440 - 457,
QNNBackend::saveContext must perform and act on error returns and validate file
open/size before writing: check the return value of
runtime_->qnnInterface.contextGetBinarySize(context_, &binarySize) and bail/log
on failure; after allocating binaryBuffer, check the return of
runtime_->qnnInterface.contextGetBinary(context_, ..., &writtenSize) and
bail/log on failure; if writtenSize != binarySize treat it as an error and do
not proceed to write the buffer; verify std::ofstream file(contextPath,
std::ios::binary).is_open() before file.write and handle/log/return on failure;
ensure all early exits free resources and log via MLLM_ERROR/MLLM_INFO as appropriate.

@chenghuaWang chenghuaWang changed the title from "Qnn aot" to "feat(qualcomm): Qnn AOT Runtime" on Jan 20, 2026
@chenghuaWang chenghuaWang (Collaborator) left a comment

LGTM

@chenghuaWang chenghuaWang merged commit bd64a2c into UbiquitousLearning:main Jan 20, 2026
4 checks passed