Skip to content

Commit 4c8bc5f

Browse files
committed
fix: bincode + skip_serializing_if incompatibility in LLM types
- Remove skip_serializing_if from LLM SDK types (breaks bincode roundtrip)
- Deserialize LLM requests using SDK types directly in host
- Spawn dedicated thread for reqwest::blocking in async context
- Make --name required in platform-sudo upload
- Add detailed logging to LLM proxy host function
1 parent 476ea06 commit 4c8bc5f

File tree

3 files changed

+133
-49
lines changed

3 files changed

+133
-49
lines changed

bins/platform-sudo/src/main.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ enum Commands {
3737
/// Challenge ID (UUID format)
3838
#[arg(short, long)]
3939
id: String,
40-
/// Challenge name
40+
/// Challenge name (required to avoid UUID-only names)
4141
#[arg(short, long)]
42-
name: Option<String>,
42+
name: String,
4343
},
4444
/// Activate a challenge
4545
Activate {
@@ -199,7 +199,7 @@ impl SudoCli {
199199
&self,
200200
file: &PathBuf,
201201
challenge_id: &str,
202-
name: Option<String>,
202+
name: String,
203203
) -> Result<()> {
204204
let wasm_bytes = std::fs::read(file).context("Failed to read WASM file")?;
205205

@@ -222,7 +222,7 @@ impl SudoCli {
222222
action: "wasm_upload".to_string(),
223223
challenge_id: id_str.clone(),
224224
data: Some(base64::engine::general_purpose::STANDARD.encode(&wasm_bytes)),
225-
name: name.or_else(|| Some(id_str.clone())),
225+
name: Some(name),
226226
signature: hex::encode(&signature),
227227
timestamp,
228228
};
@@ -433,7 +433,7 @@ impl SudoCli {
433433
match parts[0] {
434434
"help" | "?" => {
435435
println!("\nCommands:");
436-
println!(" upload <file> <challenge_id> [name] - Upload WASM module");
436+
println!(" upload <file> <challenge_id> <name> - Upload WASM module");
437437
println!(" activate <challenge_id> - Activate challenge");
438438
println!(
439439
" deactivate <challenge_id> - Deactivate challenge"
@@ -446,10 +446,10 @@ impl SudoCli {
446446
);
447447
println!(" exit | quit - Exit CLI\n");
448448
}
449-
"upload" if parts.len() >= 3 => {
449+
"upload" if parts.len() >= 4 => {
450450
let file = PathBuf::from(parts[1]);
451451
let id = parts[2];
452-
let name = parts.get(3).map(|s| s.to_string());
452+
let name = parts[3].to_string();
453453
if let Err(e) = self.upload_wasm(&file, id, name).await {
454454
println!("Error: {}", e);
455455
}

crates/challenge-sdk-wasm/src/llm_types.rs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,15 @@ use serde::{Deserialize, Serialize};
88
pub struct LlmRequest {
99
pub model: String,
1010
pub messages: Vec<LlmMessage>,
11-
#[serde(skip_serializing_if = "Option::is_none")]
1211
pub max_tokens: Option<u32>,
13-
#[serde(skip_serializing_if = "Option::is_none")]
1412
pub temperature: Option<f32>,
15-
#[serde(skip_serializing_if = "Option::is_none")]
1613
pub top_p: Option<f32>,
17-
#[serde(skip_serializing_if = "Option::is_none")]
1814
pub frequency_penalty: Option<f32>,
19-
#[serde(skip_serializing_if = "Option::is_none")]
2015
pub presence_penalty: Option<f32>,
21-
#[serde(skip_serializing_if = "Option::is_none")]
2216
pub stop: Option<Vec<String>>,
2317
/// OpenAI function calling / tools
24-
#[serde(skip_serializing_if = "Option::is_none")]
2518
pub tools: Option<Vec<Tool>>,
26-
#[serde(skip_serializing_if = "Option::is_none")]
2719
pub tool_choice: Option<ToolChoice>,
28-
#[serde(skip_serializing_if = "Option::is_none")]
2920
pub response_format: Option<ResponseFormat>,
3021
}
3122

@@ -72,11 +63,8 @@ impl LlmRequest {
7263
pub struct LlmMessage {
7364
pub role: String,
7465
pub content: Option<String>,
75-
#[serde(skip_serializing_if = "Option::is_none")]
7666
pub name: Option<String>,
77-
#[serde(skip_serializing_if = "Option::is_none")]
7867
pub tool_calls: Option<Vec<ToolCall>>,
79-
#[serde(skip_serializing_if = "Option::is_none")]
8068
pub tool_call_id: Option<String>,
8169
}
8270

@@ -155,10 +143,8 @@ impl Tool {
155143
#[derive(Clone, Debug, Serialize, Deserialize)]
156144
pub struct FunctionDef {
157145
pub name: String,
158-
#[serde(skip_serializing_if = "Option::is_none")]
159146
pub description: Option<String>,
160147
/// JSON Schema string for the function parameters
161-
#[serde(skip_serializing_if = "Option::is_none")]
162148
pub parameters: Option<String>,
163149
}
164150

crates/wasm-runtime-interface/src/llm.rs

Lines changed: 126 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
use crate::runtime::{HostFunctionRegistrar, RuntimeState, WasmRuntimeError};
1313
use serde::{Deserialize, Serialize};
1414
use std::fmt;
15-
use tracing::warn;
15+
use tracing::{info, warn};
1616
use wasmtime::{Caller, Linker, Memory};
1717

1818
const MAX_CHAT_REQUEST_SIZE: u64 = 4 * 1024 * 1024;
@@ -175,14 +175,17 @@ fn handle_chat_completion(
175175
}
176176

177177
if !policy_available {
178+
warn!("llm proxy: policy not available (disabled or no key)");
178179
return LlmHostStatus::Disabled.to_i32();
179180
}
180181

181182
if requests_made >= max_requests {
183+
warn!(requests_made, max_requests, "llm proxy: rate limited");
182184
return LlmHostStatus::RateLimited.to_i32();
183185
}
184186

185187
if req_ptr < 0 || req_len < 0 || resp_ptr < 0 || resp_len < 0 {
188+
warn!(req_ptr, req_len, resp_ptr, resp_len, "llm proxy: invalid pointers");
186189
return LlmHostStatus::InvalidRequest.to_i32();
187190
}
188191

@@ -195,6 +198,7 @@ fn handle_chat_completion(
195198
};
196199

197200
if request_bytes.len() as u64 > MAX_CHAT_REQUEST_SIZE {
201+
warn!(size = request_bytes.len(), "llm proxy: request too large");
198202
return LlmHostStatus::InvalidRequest.to_i32();
199203
}
200204

@@ -204,56 +208,77 @@ fn handle_chat_completion(
204208
let state = &caller.data().llm_state;
205209
api_key = match &state.policy.api_key {
206210
Some(k) => k.clone(),
207-
None => return LlmHostStatus::Disabled.to_i32(),
211+
None => {
212+
warn!("llm proxy: no API key");
213+
return LlmHostStatus::Disabled.to_i32();
214+
}
208215
};
209216
endpoint = state.policy.endpoint.clone();
210217
}
211218

212-
// Deserialize the SDK request (bincode-encoded LlmRequest from challenge-sdk-wasm)
213-
let sdk_req: SdkRequest = match bincode::deserialize(&request_bytes) {
214-
Ok(r) => r,
215-
Err(_) => return LlmHostStatus::InvalidRequest.to_i32(),
216-
};
219+
// Deserialize using the SDK types directly (same serde attributes including
220+
// skip_serializing_if which bincode respects for serialization layout)
221+
let wasm_req: platform_challenge_sdk_wasm::LlmRequest =
222+
match bincode::deserialize(&request_bytes) {
223+
Ok(r) => r,
224+
Err(e) => {
225+
let preview: Vec<u8> = request_bytes.iter().take(64).copied().collect();
226+
warn!(error = %e, req_len = request_bytes.len(), first_bytes = ?preview, "llm proxy: bincode deserialize failed");
227+
return LlmHostStatus::InvalidRequest.to_i32();
228+
}
229+
};
230+
info!("llm proxy: request decoded, model={}, messages={}", wasm_req.model, wasm_req.messages.len());
217231

218232
// Validate model against allowed list
219233
{
220234
let state = &caller.data().llm_state;
221235
let allowed = &state.policy.allowed_models;
222-
if !allowed.is_empty() && !allowed.contains(&sdk_req.model) {
223-
warn!(model = %sdk_req.model, "llm proxy: model not in allowed list");
236+
if !allowed.is_empty() && !allowed.contains(&wasm_req.model) {
237+
warn!(model = %wasm_req.model, "llm proxy: model not in allowed list");
224238
return LlmHostStatus::InvalidRequest.to_i32();
225239
}
226240
}
227241

228-
// Build OpenAI-compatible JSON, force stream: false
242+
// Convert SDK types to local types for build_openai_request
243+
let sdk_req = convert_sdk_request(&wasm_req);
229244
let openai_json = build_openai_request(&sdk_req);
230245
let json_body = match serde_json::to_vec(&openai_json) {
231246
Ok(b) => b,
232247
Err(_) => return LlmHostStatus::InvalidRequest.to_i32(),
233248
};
234249

235-
// Forward to LLM endpoint
236-
let client = reqwest::blocking::Client::new();
237-
let http_response = match client
238-
.post(&endpoint)
239-
.header("Content-Type", "application/json")
240-
.header("Authorization", format!("Bearer {}", api_key))
241-
.body(json_body)
242-
.timeout(std::time::Duration::from_secs(LLM_REQUEST_TIMEOUT_SECS))
243-
.send()
244-
{
245-
Ok(r) => r,
246-
Err(err) => {
250+
// Forward to LLM endpoint via a dedicated thread to avoid
251+
// reqwest::blocking conflicts with the tokio async runtime.
252+
let (tx, rx) = std::sync::mpsc::channel();
253+
let endpoint_clone = endpoint.clone();
254+
let api_key_clone = api_key.clone();
255+
let json_body_clone = json_body.clone();
256+
std::thread::spawn(move || {
257+
let result = (|| -> Result<Vec<u8>, String> {
258+
let client = reqwest::blocking::Client::new();
259+
let resp = client
260+
.post(&endpoint_clone)
261+
.header("Content-Type", "application/json")
262+
.header("Authorization", format!("Bearer {}", api_key_clone))
263+
.body(json_body_clone)
264+
.timeout(std::time::Duration::from_secs(LLM_REQUEST_TIMEOUT_SECS))
265+
.send()
266+
.map_err(|e| format!("HTTP request failed: {}", e))?;
267+
let bytes = resp.bytes().map_err(|e| format!("read body failed: {}", e))?;
268+
Ok(bytes.to_vec())
269+
})();
270+
let _ = tx.send(result);
271+
});
272+
273+
let response_body = match rx.recv() {
274+
Ok(Ok(body)) => body,
275+
Ok(Err(err)) => {
247276
warn!(error = %err, "llm proxy: HTTP request failed");
248277
return LlmHostStatus::ApiError.to_i32();
249278
}
250-
};
251-
252-
let response_body = match http_response.bytes() {
253-
Ok(b) => b.to_vec(),
254279
Err(err) => {
255-
warn!(error = %err, "llm proxy: failed to read response body");
256-
return LlmHostStatus::ApiError.to_i32();
280+
warn!(error = %err, "llm proxy: thread communication failed");
281+
return LlmHostStatus::InternalError.to_i32();
257282
}
258283
};
259284

@@ -291,23 +316,35 @@ fn handle_chat_completion(
291316
struct SdkRequest {
292317
model: String,
293318
messages: Vec<SdkMessage>,
319+
#[serde(default)]
294320
max_tokens: Option<u32>,
321+
#[serde(default)]
295322
temperature: Option<f32>,
323+
#[serde(default)]
296324
top_p: Option<f32>,
325+
#[serde(default)]
297326
frequency_penalty: Option<f32>,
327+
#[serde(default)]
298328
presence_penalty: Option<f32>,
329+
#[serde(default)]
299330
stop: Option<Vec<String>>,
331+
#[serde(default)]
300332
tools: Option<Vec<SdkTool>>,
333+
#[serde(default)]
301334
tool_choice: Option<SdkToolChoice>,
335+
#[serde(default)]
302336
response_format: Option<SdkResponseFormat>,
303337
}
304338

305339
#[derive(Deserialize)]
306340
struct SdkMessage {
307341
role: String,
308342
content: Option<String>,
343+
#[serde(default)]
309344
name: Option<String>,
345+
#[serde(default)]
310346
tool_calls: Option<Vec<SdkToolCall>>,
347+
#[serde(default)]
311348
tool_call_id: Option<String>,
312349
}
313350

@@ -399,6 +436,49 @@ struct OpenAiMessage {
399436
tool_call_id: Option<String>,
400437
}
401438

439+
fn convert_sdk_request(wasm: &platform_challenge_sdk_wasm::LlmRequest) -> SdkRequest {
440+
use platform_challenge_sdk_wasm::llm_types::ToolChoice as WasmTC;
441+
SdkRequest {
442+
model: wasm.model.clone(),
443+
messages: wasm.messages.iter().map(|m| SdkMessage {
444+
role: m.role.clone(),
445+
content: m.content.clone(),
446+
name: m.name.clone(),
447+
tool_calls: m.tool_calls.as_ref().map(|tcs| tcs.iter().map(|tc| SdkToolCall {
448+
id: tc.id.clone(),
449+
call_type: tc.call_type.clone(),
450+
function: SdkFunctionCall { name: tc.function.name.clone(), arguments: tc.function.arguments.clone() },
451+
}).collect()),
452+
tool_call_id: m.tool_call_id.clone(),
453+
}).collect(),
454+
max_tokens: wasm.max_tokens,
455+
temperature: wasm.temperature,
456+
top_p: wasm.top_p,
457+
frequency_penalty: wasm.frequency_penalty,
458+
presence_penalty: wasm.presence_penalty,
459+
stop: wasm.stop.clone(),
460+
tools: wasm.tools.as_ref().map(|ts| ts.iter().map(|t| SdkTool {
461+
tool_type: "function".to_string(),
462+
function: SdkFunctionDef {
463+
name: t.function.name.clone(),
464+
description: t.function.description.clone(),
465+
parameters: t.function.parameters.clone(),
466+
},
467+
}).collect()),
468+
tool_choice: wasm.tool_choice.as_ref().map(|tc| match tc {
469+
WasmTC::Auto => SdkToolChoice::Auto,
470+
WasmTC::None => SdkToolChoice::None,
471+
WasmTC::Required => SdkToolChoice::Required,
472+
WasmTC::Specific { function } => SdkToolChoice::Specific {
473+
function: SdkToolChoiceFunction { name: function.name.clone() },
474+
},
475+
}),
476+
response_format: wasm.response_format.as_ref().map(|rf| SdkResponseFormat {
477+
format_type: rf.format_type.clone(),
478+
}),
479+
}
480+
}
481+
402482
fn build_openai_request(sdk: &SdkRequest) -> OpenAiRequest {
403483
let messages = sdk
404484
.messages
@@ -740,6 +820,24 @@ mod tests {
740820
assert_eq!(sdk_resp.usage.unwrap().total_tokens, 70);
741821
}
742822

823+
#[test]
824+
fn test_bincode_roundtrip_llm_request() {
825+
use platform_challenge_sdk_wasm::{LlmMessage, LlmRequest};
826+
let req = LlmRequest::simple(
827+
"moonshotai/Kimi-K2.5-TEE",
828+
vec![
829+
LlmMessage::system("test system"),
830+
LlmMessage::user("test user"),
831+
],
832+
2048,
833+
);
834+
let bytes = bincode::serialize(&req).unwrap();
835+
println!("Serialized {} bytes, first 64: {:?}", bytes.len(), &bytes[..64.min(bytes.len())]);
836+
let decoded: LlmRequest = bincode::deserialize(&bytes).unwrap();
837+
assert_eq!(decoded.model, "moonshotai/Kimi-K2.5-TEE");
838+
assert_eq!(decoded.messages.len(), 2);
839+
}
840+
743841
#[test]
744842
fn test_parse_openai_response_text_only() {
745843
let response_json = serde_json::json!({

0 commit comments

Comments (0)