From 12e608aa140b4d16fda5bf664efc4dd210eac969 Mon Sep 17 00:00:00 2001 From: crrow Date: Wed, 25 Mar 2026 17:25:35 +0900 Subject: [PATCH 1/6] chore(deps): add md5 and urlencoding crates (#14) Required for aligning media upload flow (MD5 hash) and CDN URL construction (encrypted_query_param encoding) with Python SDK. Closes #14 --- Cargo.lock | 14 ++++++++++++++ Cargo.toml | 2 ++ 2 files changed, 16 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 8888f95..f029cb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -795,6 +795,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.8.0" @@ -1732,6 +1738,12 @@ dependencies = [ "serde", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -1917,6 +1929,7 @@ dependencies = [ "dirs", "ecb", "hex", + "md5", "mime_guess", "qrcode", "rand", @@ -1930,6 +1943,7 @@ dependencies = [ "tokio-test", "tracing", "tracing-subscriber", + "urlencoding", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 7b870d6..c3e9a2a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,8 @@ chrono = "0.4" anyhow = "1" tracing-subscriber = "0.3" base64 = "0.22" +md5 = "0.7" +urlencoding = "2" [dev-dependencies] tokio-test = "0.4" From ce84b3a18c05dc117ff9fe6cfc08dce4e18ee2ae Mon Sep 17 00:00:00 2001 From: crrow Date: Wed, 25 Mar 2026 17:29:30 +0900 Subject: [PATCH 2/6] fix(storage): align paths, sync format, and config with Python SDK (#14) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - State dir resolution: check $OPENCLAW_STATE_DIR → $CLAWDBOT_STATE_DIR → ~/.openclaw - Add normalize_account_id() and derive_raw_account_id() public functions - Change sync buf format from plain text to JSON in accounts/{id}.sync.json - Change get_account_config to read from global {state_dir}/openclaw.json - Add chmod 600 on Unix for save_account_data - Add tests for normalization, sync buf format, and global config parsing Closes #14 Co-Authored-By: Claude Opus 4.6 --- src/storage.rs | 177 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 158 insertions(+), 19 deletions(-) diff --git a/src/storage.rs b/src/storage.rs index 0dbee46..039a60f 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::PathBuf; use serde::{Deserialize, Serialize}; @@ -14,11 +15,36 @@ pub const DEFAULT_BASE_URL: &str = "https://ilinkai.weixin.qq.com"; /// Base URL for downloading encrypted media from the `WeChat` CDN. pub const CDN_BASE_URL: &str = "https://novac2c.cdn.weixin.qq.com/c2c"; +/// Resolves the state directory by checking env vars in priority order: +/// `$OPENCLAW_STATE_DIR` -> `$CLAWDBOT_STATE_DIR` -> `~/.openclaw`. +fn state_dir() -> PathBuf { + std::env::var("OPENCLAW_STATE_DIR") + .or_else(|_| std::env::var("CLAWDBOT_STATE_DIR")) + .map_or_else( + |_| { + dirs::home_dir() + .expect("no home directory") + .join(".openclaw") + }, + PathBuf::from, + ) +} + fn storage_root() -> PathBuf { - dirs::home_dir() - .expect("no home directory") - .join(".openclaw") - .join("openclaw-weixin") + state_dir().join("openclaw-weixin") +} + +/// Normalizes an account ID: trims whitespace, lowercases, and replaces +/// `@` and `.` with `-`. +pub fn normalize_account_id(id: &str) -> String { + id.trim().to_lowercase().replace(['@', '.'], "-") +} + +/// Derives the raw account ID from a normalized one. +/// +/// For now this is an identity function — reverse mapping is best-effort. +pub fn derive_raw_account_id(id: &str) -> String { + id.to_string() } /// Persisted authentication credentials for a single `WeChat` account. @@ -59,7 +85,8 @@ pub fn get_account_ids() -> Result> { /// Persists the given list of account IDs to local storage. pub fn save_account_ids(ids: &[String]) -> Result<()> { let path = storage_root().join("accounts.json"); - std::fs::create_dir_all(path.parent().unwrap()).context(IoSnafu)?; + std::fs::create_dir_all(path.parent().expect("accounts.json must have a parent dir")) + .context(IoSnafu)?; let json = serde_json::to_string_pretty(ids).context(JsonSnafu)?; std::fs::write(&path, json).context(IoSnafu)?; Ok(()) @@ -75,41 +102,101 @@ pub fn get_account_data(account_id: &str) -> Result { } /// Saves credentials for the given account to local storage. +/// +/// On Unix systems, the file is set to mode 600 (owner read/write only). pub fn save_account_data(account_id: &str, data: &AccountData) -> Result<()> { let path = storage_root() .join("accounts") .join(format!("{account_id}.json")); - std::fs::create_dir_all(path.parent().unwrap()).context(IoSnafu)?; + std::fs::create_dir_all(path.parent().expect("account file must have a parent dir")) + .context(IoSnafu)?; let json = serde_json::to_string_pretty(data).context(JsonSnafu)?; - std::fs::write(&path, json).context(IoSnafu)?; + std::fs::write(&path, &json).context(IoSnafu)?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + std::fs::set_permissions(&path, perms).context(IoSnafu)?; + } + Ok(()) } +/// JSON wrapper for the sync buffer, matching the Python SDK format. +#[derive(Debug, Serialize, Deserialize)] +struct SyncBuf { + get_updates_buf: String, +} + /// Returns the saved long-poll continuation buffer for the given account. +/// +/// Reads from `accounts/{id}.sync.json` as a JSON object with a +/// `get_updates_buf` field, matching the Python SDK format. pub fn get_updates_buf(account_id: &str) -> Option { let path = storage_root() - .join("get_updates_buf") - .join(format!("{account_id}.txt")); - std::fs::read_to_string(&path).ok() + .join("accounts") + .join(format!("{account_id}.sync.json")); + let data = std::fs::read_to_string(&path).ok()?; + let buf: SyncBuf = serde_json::from_str(&data).ok()?; + Some(buf.get_updates_buf) } /// Saves the long-poll continuation buffer for the given account. +/// +/// Writes to `accounts/{id}.sync.json` as a JSON object with a +/// `get_updates_buf` field, matching the Python SDK format. pub fn save_updates_buf(account_id: &str, buf: &str) -> Result<()> { let path = storage_root() - .join("get_updates_buf") - .join(format!("{account_id}.txt")); - std::fs::create_dir_all(path.parent().unwrap()).context(IoSnafu)?; - std::fs::write(&path, buf).context(IoSnafu)?; + .join("accounts") + .join(format!("{account_id}.sync.json")); + std::fs::create_dir_all(path.parent().expect("sync file must have a parent dir")) + .context(IoSnafu)?; + let wrapper = SyncBuf { + get_updates_buf: buf.to_string(), + }; + let json = serde_json::to_string_pretty(&wrapper).context(JsonSnafu)?; + std::fs::write(&path, json).context(IoSnafu)?; Ok(()) } -/// Loads the optional per-account configuration, returning `None` if absent. +/// Global config file structure matching the Python SDK's `openclaw.json`. +#[derive(Debug, Deserialize)] +struct GlobalConfig { + #[serde(default)] + channels: HashMap, +} + +/// Channel-level configuration within the global config. +#[derive(Debug, Deserialize)] +struct ChannelConfig { + #[serde(default)] + accounts: HashMap, +} + +/// Per-account settings within a channel config. +#[derive(Debug, Deserialize)] +struct AccountSettingsInConfig { + #[serde(default, rename = "routeTag")] + route_tag: Option, +} + +/// Loads the optional per-account configuration from the global +/// `{state_dir}/openclaw.json` file, reading +/// `.channels["openclaw-weixin"].accounts["{raw_id}"].routeTag`. pub fn get_account_config(account_id: &str) -> Option { - let path = storage_root() - .join("config") - .join(format!("{account_id}.json")); + let path = state_dir().join("openclaw.json"); let data = std::fs::read_to_string(&path).ok()?; - serde_json::from_str(&data).ok() + let config: GlobalConfig = serde_json::from_str(&data).ok()?; + let raw_id = derive_raw_account_id(account_id); + let route_tag = config + .channels + .get("openclaw-weixin")? + .accounts + .get(&raw_id)? + .route_tag + .clone(); + Some(AccountConfig { route_tag }) } #[cfg(test)] @@ -148,4 +235,56 @@ mod tests { let config: AccountConfig = serde_json::from_str(json).unwrap(); assert_eq!(config.route_tag, None); } + + #[test] + fn test_normalize_account_id() { + assert_eq!( + normalize_account_id(" MyBot@Test.Com "), + "mybot-test-com" + ); + assert_eq!(normalize_account_id("simple"), "simple"); + assert_eq!(normalize_account_id("A.B@C"), "a-b-c"); + } + + #[test] + fn test_derive_raw_account_id() { + assert_eq!(derive_raw_account_id("mybot"), "mybot"); + } + + #[test] + fn test_sync_buf_json_format() { + let wrapper = SyncBuf { + get_updates_buf: "some-buf-value".to_string(), + }; + let json = serde_json::to_string(&wrapper).unwrap(); + let deserialized: SyncBuf = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.get_updates_buf, "some-buf-value"); + } + + #[test] + fn test_global_config_parsing() { + let json = r#"{ + "channels": { + "openclaw-weixin": { + "accounts": { + "mybot": { + "routeTag": "tag-123" + } + } + } + } + }"#; + let config: GlobalConfig = serde_json::from_str(json).unwrap(); + let route_tag = config + .channels + .get("openclaw-weixin") + .unwrap() + .accounts + .get("mybot") + .unwrap() + .route_tag + .as_ref() + .unwrap(); + assert_eq!(route_tag, "tag-123"); + } } From b5c56597ae7bde5a6ba706d7f3ab95ccc343f212 Mon Sep 17 00:00:00 2001 From: crrow Date: Wed, 25 Mar 2026 17:35:27 +0900 Subject: [PATCH 3/6] fix(api): align all endpoints with Python SDK (#14) - Add constants: DEFAULT_LONG_POLL_TIMEOUT, DEFAULT_API_TIMEOUT, DEFAULT_CONFIG_TIMEOUT, DEFAULT_ILINK_BOT_TYPE, SDK_VERSION - Fix headers(): base64-encoded random u32 for X-WECHAT-UIN, content-type header, skip auth when token empty - Add base_info with SDK version to all POST payloads - Extract check_response() to check both errcode and ret fields - Add GET helper (get_with_timeout) for QR code endpoints - Change fetch_qr_code to GET with bot_type query param - Change get_qr_code_status to GET with iLink-App-ClientVersion header - Fix get_updates timeout to 35s (long poll) - Merge send_text_message + send_media_message into send_message - Fix send_typing to include typing_status, use 10s timeout - Fix get_upload_url to accept full media metadata - Add get_config endpoint - Update callers in runtime.rs, bot.rs, media.rs for new signatures Closes #14 Co-Authored-By: Claude Opus 4.6 --- src/api.rs | 299 +++++++++++++++++++++++++++++++++---------------- src/bot.rs | 15 +-- src/media.rs | 6 +- src/runtime.rs | 20 ++-- 4 files changed, 225 insertions(+), 115 deletions(-) diff --git a/src/api.rs b/src/api.rs index 8d491d2..1690ab3 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,5 +1,7 @@ use std::time::Duration; +use base64::Engine as _; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::Client; use serde_json::Value; use snafu::ResultExt; @@ -7,6 +9,11 @@ use snafu::ResultExt; use crate::errors::{ApiSnafu, HttpSnafu, Result, SessionExpiredSnafu}; const SESSION_EXPIRED_ERRCODE: i64 = -14; +const DEFAULT_LONG_POLL_TIMEOUT: Duration = Duration::from_secs(35); +const DEFAULT_API_TIMEOUT: Duration = Duration::from_secs(15); +const DEFAULT_CONFIG_TIMEOUT: Duration = Duration::from_secs(10); +const DEFAULT_ILINK_BOT_TYPE: &str = "3"; +const SDK_VERSION: &str = env!("CARGO_PKG_VERSION"); /// HTTP client wrapper for the `WeChat` iLink Bot API. /// @@ -34,34 +41,61 @@ impl WeixinApiClient { /// Replaces the bearer token used for subsequent requests. pub fn set_token(&mut self, token: &str) { self.token = token.to_string(); } - fn headers(&self) -> reqwest::header::HeaderMap { - use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; + fn headers(&self) -> HeaderMap { let mut headers = HeaderMap::new(); headers.insert( - HeaderName::from_static("authorizationtype"), - HeaderValue::from_static("ilink_bot_token"), + reqwest::header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), ); headers.insert( - reqwest::header::AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {}", self.token)).unwrap(), + HeaderName::from_static("authorizationtype"), + HeaderValue::from_static("ilink_bot_token"), ); - let uin: u64 = rand::random::() % 9_000_000_000 + 1_000_000_000; + // base64-encoded random u32 to match Python SDK behaviour + let uin: u32 = rand::random(); + let uin_b64 = base64::engine::general_purpose::STANDARD.encode(uin.to_le_bytes()); headers.insert( HeaderName::from_static("x-wechat-uin"), - HeaderValue::from_str(&uin.to_string()).unwrap(), + HeaderValue::from_str(&uin_b64).expect("valid base64"), ); + // Only add auth if token is non-empty + if !self.token.is_empty() { + headers.insert( + reqwest::header::AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {}", self.token)) + .expect("valid token"), + ); + } if let Some(ref tag) = self.route_tag { headers.insert( HeaderName::from_static("skroutetag"), - HeaderValue::from_str(tag).unwrap(), + HeaderValue::from_str(tag).expect("valid route tag"), ); } headers } - async fn post(&self, path: &str, body: &Value) -> Result { - self.post_with_timeout(path, body, Duration::from_secs(30)) - .await + /// Checks the API response for error codes (`errcode` or `ret`). + #[allow(clippy::unused_self)] + fn check_response(&self, resp: &Value) -> Result<()> { + let code = resp + .get("errcode") + .and_then(Value::as_i64) + .or_else(|| resp.get("ret").and_then(Value::as_i64)); + if let Some(code) = code { + if code == SESSION_EXPIRED_ERRCODE { + return Err(SessionExpiredSnafu.build()); + } + if code != 0 { + let msg = resp + .get("errmsg") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error") + .to_string(); + return Err(ApiSnafu { code, message: msg }.build()); + } + } + Ok(()) } async fn post_with_timeout( @@ -71,11 +105,19 @@ impl WeixinApiClient { timeout: Duration, ) -> Result { let url = format!("{}/{}", self.base_url, path); + // Inject base_info with SDK version into every POST payload + let mut payload = body.clone(); + if let Some(obj) = payload.as_object_mut() { + obj.insert( + "base_info".to_string(), + serde_json::json!({"channel_version": SDK_VERSION}), + ); + } let resp = self .client .post(&url) .headers(self.headers()) - .json(body) + .json(&payload) .timeout(timeout) .send() .await @@ -84,84 +126,62 @@ impl WeixinApiClient { .await .context(HttpSnafu)?; - if let Some(code) = resp.get("errcode").and_then(serde_json::Value::as_i64) { - if code == SESSION_EXPIRED_ERRCODE { - return Err(SessionExpiredSnafu.build()); - } - if code != 0 { - let msg = resp - .get("errmsg") - .and_then(|v| v.as_str()) - .unwrap_or("unknown error") - .to_string(); - return Err(ApiSnafu { code, message: msg }.build()); - } - } + self.check_response(&resp)?; Ok(resp) } - /// Sends a form-encoded POST request and checks the `ret` error field. - /// - /// Login endpoints use form encoding + `ret`/`err_msg` instead of JSON + - /// `errcode`/`errmsg` used by messaging endpoints. - async fn post_form(&self, path: &str, params: &[(&str, &str)]) -> Result { - self.post_form_with_timeout(path, params, Duration::from_secs(30)) - .await - } - - /// Same as [`post_form`](Self::post_form) but with a custom timeout. - async fn post_form_with_timeout( + /// Sends a GET request with query parameters and optional extra headers. + async fn get_with_timeout( &self, path: &str, - params: &[(&str, &str)], + query: &[(&str, &str)], + extra_headers: Option, timeout: Duration, ) -> Result { let url = format!("{}/{}", self.base_url, path); - let resp = self + let mut req = self .client - .post(&url) + .get(&url) .headers(self.headers()) - .form(params) - .timeout(timeout) + .query(query) + .timeout(timeout); + if let Some(h) = extra_headers { + req = req.headers(h); + } + let resp = req .send() .await .context(HttpSnafu)? .json::() .await .context(HttpSnafu)?; - - if let Some(ret) = resp.get("ret").and_then(Value::as_i64) - && ret != 0 - { - let msg = resp - .get("err_msg") - .and_then(|v| v.as_str()) - .unwrap_or("unknown error") - .to_string(); - return Err(ApiSnafu { - code: ret, - message: msg, - } - .build()); - } + self.check_response(&resp)?; Ok(resp) } /// Requests a new login QR code from the API. pub async fn fetch_qr_code(&self) -> Result { - self.post_form("ilink/bot/get_bot_qrcode", &[("bot_type", "3")]) - .await + self.get_with_timeout( + "ilink/bot/get_bot_qrcode", + &[("bot_type", DEFAULT_ILINK_BOT_TYPE)], + None, + DEFAULT_API_TIMEOUT, + ) + .await } - /// Polls the current scan status for the given `qrcode_id`. - /// - /// Uses a longer timeout than the default because this endpoint - /// long-polls until the user scans the QR code. - pub async fn get_qr_code_status(&self, qrcode_id: &str) -> Result { - self.post_form_with_timeout( + /// Polls the current scan status for the given `qrcode`. + pub async fn get_qr_code_status(&self, qrcode: &str) -> Result { + let mut extra = HeaderMap::new(); + extra.insert( + HeaderName::from_static("ilink-app-clientversion"), + HeaderValue::from_static("1"), + ); + self.get_with_timeout( "ilink/bot/get_qrcode_status", - &[("qrcode", qrcode_id), ("bot_type", "3")], - Duration::from_secs(60), + &[("qrcode", qrcode)], + Some(extra), + DEFAULT_LONG_POLL_TIMEOUT, ) .await } @@ -172,65 +192,72 @@ impl WeixinApiClient { if let Some(b) = buf { body["get_updates_buf"] = Value::String(b.to_string()); } - self.post_with_timeout("ilink/bot/getupdates", &body, Duration::from_secs(40)) + self.post_with_timeout("ilink/bot/getupdates", &body, DEFAULT_LONG_POLL_TIMEOUT) .await } - /// Sends a plain-text message to `to_user_id`. - pub async fn send_text_message( + /// Sends a message with the given item list to `to_user_id`. + pub async fn send_message( &self, to_user_id: &str, context_token: &str, - text: &str, + item_list: &[Value], ) -> Result { let body = serde_json::json!({ "to_user_id": to_user_id, "context_token": context_token, - "item_list": [{ - "type": 0, - "body": text - }] + "item_list": item_list, }); - self.post("ilink/bot/sendmessage", &body).await + self.post_with_timeout("ilink/bot/sendmessage", &body, DEFAULT_API_TIMEOUT) + .await } - /// Sends a media message (image, video, or file) to `to_user_id`. - pub async fn send_media_message( + /// Sends a typing indicator to `to_user_id`. + pub async fn send_typing( &self, to_user_id: &str, context_token: &str, - text: Option<&str>, - file_info: &Value, + typing_status: u8, ) -> Result { - let mut item_list = vec![]; - if let Some(t) = text { - item_list.push(serde_json::json!({ "type": 0, "body": t })); - } - item_list.push(file_info.clone()); let body = serde_json::json!({ "to_user_id": to_user_id, "context_token": context_token, - "item_list": item_list + "typing_status": typing_status, }); - self.post("ilink/bot/sendmessage", &body).await + self.post_with_timeout("ilink/bot/sendtyping", &body, DEFAULT_CONFIG_TIMEOUT) + .await } - /// Sends a typing indicator to `to_user_id`. - pub async fn send_typing(&self, to_user_id: &str, context_token: &str) -> Result { + /// Requests a pre-signed upload URL with full media metadata. + #[allow(clippy::too_many_arguments)] + pub async fn get_upload_url( + &self, + filekey: &str, + media_type: u8, + to_user_id: &str, + rawsize: u64, + rawfilemd5: &str, + filesize: u64, + aeskey: &str, + ) -> Result { let body = serde_json::json!({ + "filekey": filekey, + "media_type": media_type, "to_user_id": to_user_id, - "context_token": context_token, + "rawsize": rawsize, + "rawfilemd5": rawfilemd5, + "filesize": filesize, + "no_need_thumb": true, + "aeskey": aeskey, }); - self.post("ilink/bot/sendtyping", &body).await + self.post_with_timeout("ilink/bot/getuploadurl", &body, DEFAULT_API_TIMEOUT) + .await } - /// Requests a pre-signed upload URL for a file of the given name and size. - pub async fn get_upload_url(&self, file_name: &str, file_size: u64) -> Result { - let body = serde_json::json!({ - "file_name": file_name, - "file_size": file_size, - }); - self.post("ilink/bot/getuploadurl", &body).await + /// Fetches the bot configuration (e.g. typing ticket). + pub async fn get_config(&self) -> Result { + self.post_with_timeout("ilink/bot/getconfig", &serde_json::json!({}), DEFAULT_CONFIG_TIMEOUT) + .await } } @@ -253,4 +280,80 @@ mod tests { client.set_token("new_token"); assert_eq!(client.token, "new_token"); } + + #[test] + fn test_headers_contain_base64_uin() { + let client = WeixinApiClient::new("https://example.com", "tok", None); + let headers = client.headers(); + let uin = headers + .get("x-wechat-uin") + .unwrap() + .to_str() + .unwrap(); + assert!(base64::engine::general_purpose::STANDARD.decode(uin).is_ok()); + } + + #[test] + fn test_headers_have_content_type() { + let client = WeixinApiClient::new("https://example.com", "tok", None); + let headers = client.headers(); + assert_eq!( + headers.get("content-type").unwrap().to_str().unwrap(), + "application/json" + ); + } + + #[test] + fn test_headers_skip_auth_when_empty_token() { + let client = WeixinApiClient::new("https://example.com", "", None); + let headers = client.headers(); + assert!(headers.get("authorization").is_none()); + } + + #[test] + fn test_check_response_ok() { + let client = WeixinApiClient::new("https://example.com", "tok", None); + let resp = serde_json::json!({"errcode": 0, "errmsg": "ok"}); + assert!(client.check_response(&resp).is_ok()); + } + + #[test] + fn test_check_response_session_expired() { + let client = WeixinApiClient::new("https://example.com", "tok", None); + let resp = serde_json::json!({"errcode": -14}); + let err = client.check_response(&resp).unwrap_err(); + assert!( + matches!(err, crate::Error::SessionExpired), + "expected SessionExpired, got: {err:?}" + ); + } + + #[test] + fn test_check_response_api_error() { + let client = WeixinApiClient::new("https://example.com", "tok", None); + let resp = serde_json::json!({"errcode": 42, "errmsg": "bad request"}); + let err = client.check_response(&resp).unwrap_err(); + assert!( + matches!(err, crate::Error::Api { code: 42, .. }), + "expected Api error with code 42, got: {err:?}" + ); + } + + #[test] + fn test_check_response_ret_field() { + let client = WeixinApiClient::new("https://example.com", "tok", None); + let resp = serde_json::json!({"ret": -14}); + let err = client.check_response(&resp).unwrap_err(); + assert!( + matches!(err, crate::Error::SessionExpired), + "expected SessionExpired via ret field, got: {err:?}" + ); + } + + #[test] + fn test_check_response_no_code() { + let client = WeixinApiClient::new("https://example.com", "tok", None); + let resp = serde_json::json!({"data": "something"}); + assert!(client.check_response(&resp).is_ok()); + } } diff --git a/src/bot.rs b/src/bot.rs index e313ce3..ea673f3 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -24,12 +24,14 @@ pub async fn login(options: LoginOptions) -> Result { let qrcode_url = qr_resp["qrcode_img_content"] .as_str() .context(LoginFailedSnafu { - reason: "no qrcode_img_content in response", + reason: "no qrcode_url", + })?; + let qrcode = qr_resp["data"]["qrcode"] + .as_str() + .or_else(|| qr_resp["data"]["qrcode_id"].as_str()) + .context(LoginFailedSnafu { + reason: "no qrcode", })?; - let qrcode_id = qr_resp["qrcode"].as_str().context(LoginFailedSnafu { - reason: "no qrcode in response", - })?; - let qr = qrcode::QrCode::new(qrcode_url.as_bytes()).map_err(|e| { LoginFailedSnafu { reason: format!("QR generation failed: {e}"), @@ -46,8 +48,7 @@ pub async fn login(options: LoginOptions) -> Result { loop { tokio::time::sleep(std::time::Duration::from_secs(2)).await; - let status_resp = client.get_qr_code_status(qrcode_id).await?; - // Try top-level field first (v2 API), fall back to nested data.status (v1) + let status_resp = client.get_qr_code_status(qrcode).await?; let status = status_resp["status"] .as_str() .or_else(|| status_resp["data"]["status"].as_str()) diff --git a/src/media.rs b/src/media.rs index 075ad66..768840e 100644 --- a/src/media.rs +++ b/src/media.rs @@ -98,7 +98,11 @@ pub async fn upload_media(api_client: &WeixinApiClient, file_path: &Path) -> cra let aes_key_hex = hex::encode(key); let encrypted = encrypt_aes_ecb(&key, &data); - let upload_info = api_client.get_upload_url(file_name, file_size).await?; + let raw_md5 = format!("{:x}", md5::compute(&data)); + let encrypted_size = encrypted.len() as u64; + let upload_info = api_client + .get_upload_url(file_name, 0, "", file_size, &raw_md5, encrypted_size, &aes_key_hex) + .await?; let upload_url = upload_info["data"]["upload_url"].as_str().ok_or_else(|| { ApiSnafu { code: -1_i64, diff --git a/src/runtime.rs b/src/runtime.rs index f32db59..e8160c9 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -133,10 +133,11 @@ async fn process_message( let text = body_from_item_list(&item_list); if let Some(echo_text) = text.strip_prefix("/echo ") { + let items = vec![serde_json::json!({"type": 0, "body": echo_text})]; api_client .lock() .await - .send_text_message(to_user_id, context_token, echo_text) + .send_message(to_user_id, context_token, &items) .await?; return Ok(()); } @@ -144,7 +145,7 @@ async fn process_message( let _ = api_client .lock() .await - .send_typing(to_user_id, context_token) + .send_typing(to_user_id, context_token, 1) .await; let incoming_media = extract_media_from_items(&item_list).await; @@ -187,19 +188,20 @@ async fn process_message( "body": uploaded, }); + let mut items = vec![]; + if let Some(ref t) = response.text { + items.push(serde_json::json!({"type": 0, "body": t})); + } + items.push(file_info); client - .send_media_message( - to_user_id, - context_token, - response.text.as_deref(), - &file_info, - ) + .send_message(to_user_id, context_token, &items) .await?; drop(client); } else if let Some(text) = &response.text { let plain = markdown_to_plain_text(text); + let items = vec![serde_json::json!({"type": 0, "body": plain})]; client - .send_text_message(to_user_id, context_token, &plain) + .send_message(to_user_id, context_token, &items) .await?; drop(client); } From 1b5cf5089b221f33c8e5ef92a4820c6e8c127214 Mon Sep 17 00:00:00 2001 From: crrow Date: Wed, 25 Mar 2026 17:42:05 +0900 Subject: [PATCH 4/6] fix(media): align download/upload with Python SDK (#14) - Add constants UPLOAD_MEDIA_IMAGE/VIDEO/FILE/VOICE and MAX_UPLOAD_RETRIES - Rewrite parse_aes_key to support base64-encoded keys (raw 16-byte or hex) - Add aes_ecb_padded_size helper - Change download_media to use encrypted_query_param URL pattern with subdir - Add download_media_from_item for per-type field extraction (image/voice/file/video) - Rewrite upload_media with full Python SDK flow: random filekey, CDN POST, x-encrypted-param header extraction, retry logic, UploadResult struct - Update runtime.rs call sites minimally to compile (Task 5 will fix properly) - Add tests for base64 key parsing and padded size calculation Closes #14 Co-Authored-By: Claude Opus 4.6 --- src/media.rs | 379 +++++++++++++++++++++++++++++++++++++++++-------- src/runtime.rs | 55 +++---- 2 files changed, 341 insertions(+), 93 deletions(-) diff --git a/src/media.rs b/src/media.rs index 768840e..1ca01cc 100644 --- a/src/media.rs +++ b/src/media.rs @@ -1,6 +1,7 @@ use std::path::{Path, PathBuf}; use aes::Aes128; +use base64::Engine as _; use block_padding::Pkcs7; use cipher::{BlockDecryptMut as _, BlockEncryptMut as _, KeyInit}; use ecb; @@ -11,6 +12,7 @@ use snafu::ResultExt; use crate::{ api::WeixinApiClient, errors::{ApiSnafu, EncryptionSnafu, HttpSnafu, IoSnafu}, + models::MediaType, storage::CDN_BASE_URL, }; @@ -19,6 +21,29 @@ type Aes128EcbDec = ecb::Decryptor; const MEDIA_DIR: &str = "/tmp/weixin-agent/media"; +/// Upload media type: image. +pub const UPLOAD_MEDIA_IMAGE: u8 = 1; +/// Upload media type: video. +pub const UPLOAD_MEDIA_VIDEO: u8 = 2; +/// Upload media type: file. +pub const UPLOAD_MEDIA_FILE: u8 = 3; +/// Upload media type: voice. +pub const UPLOAD_MEDIA_VOICE: u8 = 4; + +const MAX_UPLOAD_RETRIES: u8 = 3; + +/// Result of uploading a media file to the CDN. +pub struct UploadResult { + /// The encrypted query parameter for constructing download URLs. + pub encrypt_query_param: String, + /// The AES key as base64-encoded hex string. + pub aes_key: String, + /// The original file name. + pub file_name: String, + /// The original file size in bytes. + pub file_size: u64, +} + /// Encrypts `data` using AES-128 in ECB mode with PKCS7 padding. pub fn encrypt_aes_ecb(key: &[u8; 16], data: &[u8]) -> Vec { let enc = Aes128EcbEnc::new(key.into()); @@ -36,32 +61,71 @@ pub fn decrypt_aes_ecb(key: &[u8; 16], data: &[u8]) -> crate::Result> { }) } -/// Parses a hex-encoded AES-128 key string into a 16-byte array. -pub fn parse_aes_key(hex_key: &str) -> crate::Result<[u8; 16]> { - let bytes = hex::decode(hex_key).map_err(|e| { - EncryptionSnafu { - reason: format!("invalid hex key: {e}"), - } - .build() - })?; - bytes.try_into().map_err(|_| { - EncryptionSnafu { - reason: "AES key must be 16 bytes".to_owned(), - } - .build() - }) +/// Parses an AES-128 key from a hex string or base64-encoded string. +/// +/// Supports three formats: +/// - Direct 32-character hex string (decodes to 16 bytes) +/// - Base64-encoded 16 raw bytes +/// - Base64-encoded 32-character hex string (decoded recursively) +pub fn parse_aes_key(key_str: &str) -> crate::Result<[u8; 16]> { + // Try direct hex decode first (32-char hex string) + if key_str.len() == 32 + && let Ok(bytes) = hex::decode(key_str) + && let Ok(arr) = <[u8; 16]>::try_from(bytes.as_slice()) + { + return Ok(arr); + } + // Try base64 decode + let decoded = base64::engine::general_purpose::STANDARD + .decode(key_str) + .map_err(|e| { + EncryptionSnafu { + reason: format!("invalid key encoding: {e}"), + } + .build() + })?; + if decoded.len() == 16 { + return decoded.try_into().map_err(|_| { + EncryptionSnafu { + reason: "AES key must be 16 bytes".to_owned(), + } + .build() + }); + } + if decoded.len() == 32 { + let hex_str = std::str::from_utf8(&decoded).map_err(|e| { + EncryptionSnafu { + reason: format!("invalid hex in base64: {e}"), + } + .build() + })?; + // Recurse to hex-decode the inner string + return parse_aes_key(hex_str); + } + Err(EncryptionSnafu { + reason: format!("unexpected key length: {}", decoded.len()), + } + .build()) +} + +/// Calculates the AES-ECB padded ciphertext size for a given plaintext size. +pub const fn aes_ecb_padded_size(size: u64) -> u64 { + ((size / 16) + 1) * 16 } /// Downloads and decrypts a media file from the `WeChat` CDN. /// +/// Uses the `encrypted_query_param` URL pattern to construct the download URL. /// Returns the local filesystem path where the decrypted file was saved. pub async fn download_media( - file_key: &str, - aes_key_hex: &str, + encrypt_query_param: &str, + aes_key_str: &str, file_name: Option<&str>, + subdir: &str, ) -> crate::Result { - let key = parse_aes_key(aes_key_hex)?; - let url = format!("{CDN_BASE_URL}/{file_key}"); + let key = parse_aes_key(aes_key_str)?; + let encoded = urlencoding::encode(encrypt_query_param); + let url = format!("{CDN_BASE_URL}/download?encrypted_query_param={encoded}"); let client = Client::new(); let encrypted_bytes = client .get(&url) @@ -73,8 +137,8 @@ pub async fn download_media( .context(HttpSnafu)?; let decrypted = decrypt_aes_ecb(&key, &encrypted_bytes)?; - let dir = Path::new(MEDIA_DIR); - std::fs::create_dir_all(dir).context(IoSnafu)?; + let dir = Path::new(MEDIA_DIR).join(subdir); + std::fs::create_dir_all(&dir).context(IoSnafu)?; let name = file_name.unwrap_or("download"); let path = dir.join(format!("{}_{}", uuid::Uuid::new_v4(), name)); @@ -82,57 +146,223 @@ pub async fn download_media( Ok(path) } +/// Extracts the `encrypt_query_param` and `aes_key` from a media sub-item JSON node. +fn extract_media_fields<'a>( + sub_item: &'a Value, + type_name: &str, +) -> crate::Result<(&'a str, &'a str)> { + let eqp = sub_item["media"]["encrypt_query_param"] + .as_str() + .ok_or_else(|| { + ApiSnafu { + code: -1_i64, + message: format!("missing encrypt_query_param for {type_name}"), + } + .build() + })?; + let key = sub_item["media"]["aes_key"].as_str().ok_or_else(|| { + ApiSnafu { + code: -1_i64, + message: format!("missing aes key for {type_name}"), + } + .build() + })?; + Ok((eqp, key)) +} + +/// Downloads media from an incoming message item, handling per-type field structures. +/// +/// Extracts the appropriate fields based on `item_type`: +/// - IMAGE (type=2): `image_item.media.encrypt_query_param` + hex/base64 aes key +/// - VOICE (type=3): `voice_item.media.encrypt_query_param` + `voice_item.media.aes_key` +/// - FILE (type=4): `file_item.media.encrypt_query_param` + `file_item.media.aes_key` + file name +/// - VIDEO (type=5): `video_item.media.encrypt_query_param` + `video_item.media.aes_key` +/// +/// Returns `(path, media_type, mime_type, file_name)`. +pub async fn download_media_from_item( + item: &Value, + item_type: u64, +) -> crate::Result<(PathBuf, MediaType, String, Option)> { + let (encrypt_query_param, aes_key_str, file_name, media_type, subdir) = match item_type { + 2 => { + let image_item = &item["image_item"]; + let eqp = image_item["media"]["encrypt_query_param"] + .as_str() + .ok_or_else(|| { + ApiSnafu { + code: -1_i64, + message: "missing encrypt_query_param for image".to_owned(), + } + .build() + })?; + // Image: try hex key first (aeskey field), then base64 (media.aes_key) + let key = image_item["aeskey"] + .as_str() + .or_else(|| image_item["media"]["aes_key"].as_str()) + .ok_or_else(|| { + ApiSnafu { + code: -1_i64, + message: "missing aes key for image".to_owned(), + } + .build() + })?; + (eqp, key, None, MediaType::Image, "image") + } + 3 => { + let (eqp, key) = extract_media_fields(&item["voice_item"], "voice")?; + (eqp, key, None, MediaType::Audio, "voice") + } + 4 => { + let (eqp, key) = extract_media_fields(&item["file_item"], "file")?; + let fname = item["file_item"]["file_name"].as_str().map(String::from); + (eqp, key, fname, MediaType::File, "file") + } + 5 => { + let (eqp, key) = extract_media_fields(&item["video_item"], "video")?; + (eqp, key, None, MediaType::Video, "video") + } + _ => { + return Err(ApiSnafu { + code: -1_i64, + message: format!("unsupported media item_type: {item_type}"), + } + .build()); + } + }; + + let file_name_ref = file_name.as_deref(); + let path = download_media(encrypt_query_param, aes_key_str, file_name_ref, subdir).await?; + let mime = mime_guess::from_path(&path) + .first_or_octet_stream() + .to_string(); + Ok((path, media_type, mime, file_name)) +} + /// Encrypts and uploads a local file to the `WeChat` CDN. /// -/// Returns a JSON object containing the `filekey`, `aes_key`, and metadata -/// needed to reference the uploaded file in a message. -pub async fn upload_media(api_client: &WeixinApiClient, file_path: &Path) -> crate::Result { +/// Follows the Python SDK upload flow: +/// 1. Read file, compute raw size + MD5 +/// 2. Generate random filekey (16 bytes hex) and AES key (16 bytes) +/// 3. Request a pre-signed upload URL from the API +/// 4. AES-ECB encrypt the file data +/// 5. POST encrypted data to the CDN +/// 6. Extract `x-encrypted-param` header from response +/// 7. Return [`UploadResult`] with the encrypted query param and metadata +/// +/// Retries up to 3 times on non-4xx failures. +pub async fn upload_media( + api_client: &WeixinApiClient, + file_path: &Path, + media_type: u8, + to_user_id: &str, +) -> crate::Result { let file_name = file_path .file_name() .and_then(|n| n.to_str()) .unwrap_or("file"); let data = std::fs::read(file_path).context(IoSnafu)?; - let file_size = data.len() as u64; + let raw_size = data.len() as u64; + let raw_md5 = format!("{:x}", md5::compute(&data)); - let key: [u8; 16] = rand::random(); - let aes_key_hex = hex::encode(key); - let encrypted = encrypt_aes_ecb(&key, &data); + // Generate random filekey and AES key + let filekey_bytes: [u8; 16] = rand::random(); + let filekey = hex::encode(filekey_bytes); + let aes_key: [u8; 16] = rand::random(); + let aes_key_hex = hex::encode(aes_key); + + let file_size = aes_ecb_padded_size(raw_size); - let raw_md5 = format!("{:x}", md5::compute(&data)); - let encrypted_size = encrypted.len() as u64; let upload_info = api_client - .get_upload_url(file_name, 0, "", file_size, &raw_md5, encrypted_size, &aes_key_hex) + .get_upload_url( + &filekey, + media_type, + to_user_id, + raw_size, + &raw_md5, + file_size, + &aes_key_hex, + ) .await?; - let upload_url = upload_info["data"]["upload_url"].as_str().ok_or_else(|| { - ApiSnafu { - code: -1_i64, - message: "no upload_url in response".to_owned(), - } - .build() - })?; - let file_key = upload_info["data"]["file_key"] + let upload_url = upload_info["data"]["upload_url"] .as_str() - .unwrap_or("") - .to_string(); + .ok_or_else(|| { + ApiSnafu { + code: -1_i64, + message: "no upload_url in response".to_owned(), + } + .build() + })?; + let encrypted = encrypt_aes_ecb(&aes_key, &data); let client = Client::new(); - client - .put(upload_url) - .body(encrypted) - .send() - .await - .context(HttpSnafu)?; - let mime = mime_guess::from_path(file_path) - .first_or_octet_stream() - .to_string(); + // Retry loop: up to MAX_UPLOAD_RETRIES attempts, no retry on 4xx + let mut last_err = None; + for _ in 0..MAX_UPLOAD_RETRIES { + let resp = client + .post(upload_url) + .header("Content-Type", "application/octet-stream") + .body(encrypted.clone()) + .send() + .await; + + match resp { + Ok(response) => { + let status = response.status(); + if status.is_client_error() { + return Err(ApiSnafu { + code: i64::from(status.as_u16()), + message: format!("CDN upload failed with {status}"), + } + .build()); + } + if !status.is_success() { + last_err = Some( + ApiSnafu { + code: i64::from(status.as_u16()), + message: format!("CDN upload failed with {status}"), + } + .build(), + ); + continue; + } + // Extract x-encrypted-param header + let encrypt_query_param = response + .headers() + .get("x-encrypted-param") + .and_then(|v| v.to_str().ok()) + .ok_or_else(|| { + ApiSnafu { + code: -1_i64, + message: "missing x-encrypted-param header in CDN response".to_owned(), + } + .build() + })? + .to_string(); + + // Base64-encode the hex key string for the result + let aes_key_b64 = + base64::engine::general_purpose::STANDARD.encode(aes_key_hex.as_bytes()); + + return Ok(UploadResult { + encrypt_query_param, + aes_key: aes_key_b64, + file_name: file_name.to_string(), + file_size: raw_size, + }); + } + Err(e) => { + last_err = Some(crate::Error::Http { source: e }); + } + } + } - Ok(serde_json::json!({ - "filekey": file_key, - "aes_key": aes_key_hex, - "file_name": file_name, - "file_size": file_size, - "mime_type": mime, + Err(last_err.unwrap_or_else(|| { + ApiSnafu { + code: -1_i64, + message: "upload failed after retries".to_owned(), + } + .build() })) } @@ -177,18 +407,39 @@ mod tests { } #[test] - fn test_parse_aes_key_valid() { - let hex_key = "00112233445566778899aabbccddeeff"; - let key = parse_aes_key(hex_key).unwrap(); + fn test_parse_aes_key_direct_hex() { + let key = parse_aes_key("00112233445566778899aabbccddeeff").unwrap(); assert_eq!( key, [ - 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, - 0xee, 0xff + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, + 0xdd, 0xee, 0xff ] ); } + #[test] + fn test_parse_aes_key_base64_hex() { + let b64 = + base64::engine::general_purpose::STANDARD.encode("00112233445566778899aabbccddeeff"); + let key = parse_aes_key(&b64).unwrap(); + assert_eq!( + key, + [ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, + 0xdd, 0xee, 0xff + ] + ); + } + + #[test] + fn test_parse_aes_key_base64_raw() { + let raw_key = [0x42u8; 16]; + let b64 = base64::engine::general_purpose::STANDARD.encode(raw_key); + let key = parse_aes_key(&b64).unwrap(); + assert_eq!(key, raw_key); + } + #[test] fn test_parse_aes_key_invalid_hex() { let result = parse_aes_key("zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"); @@ -224,4 +475,12 @@ mod tests { "expected Encryption error, got: {err}" ); } + + #[test] + fn test_aes_ecb_padded_size() { + assert_eq!(aes_ecb_padded_size(0), 16); + assert_eq!(aes_ecb_padded_size(1), 16); + assert_eq!(aes_ecb_padded_size(16), 32); + assert_eq!(aes_ecb_padded_size(17), 32); + } } diff --git a/src/runtime.rs b/src/runtime.rs index e8160c9..1b9bc9e 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,8 +8,8 @@ use tracing::{error, warn}; use crate::{ api::WeixinApiClient, errors::{HttpSnafu, IoSnafu}, - media::{download_media, upload_media}, - models::{Agent, ChatRequest, IncomingMedia, MediaType}, + media::{download_media_from_item, upload_media}, + models::{Agent, ChatRequest, IncomingMedia}, storage, }; @@ -175,17 +175,23 @@ async fn process_message( let tmp_path = tmp_dir.join(format!("{}_{file_name}", uuid::Uuid::new_v4())); std::fs::write(&tmp_path, &media_bytes).context(IoSnafu)?; - let uploaded = upload_media(&client, &tmp_path).await?; - let media_type_id = match media.media_type { crate::models::OutgoingMediaType::Image => 1, crate::models::OutgoingMediaType::Video => 2, crate::models::OutgoingMediaType::File => 3, }; + // TODO(Task 5): properly wire media_type and to_user_id + let uploaded = upload_media(&client, &tmp_path, media_type_id, to_user_id).await?; + let file_info = serde_json::json!({ "type": media_type_id, - "body": uploaded, + "body": { + "encrypt_query_param": uploaded.encrypt_query_param, + "aes_key": uploaded.aes_key, + "file_name": uploaded.file_name, + "file_size": uploaded.file_size, + }, }); let mut items = vec![]; @@ -212,34 +218,17 @@ async fn process_message( async fn extract_media_from_items(item_list: &[Value]) -> Option { for item in item_list { let item_type = item["type"].as_u64().unwrap_or(0); - if matches!(item_type, 1..=5) { - let file_key = item["body"]["filekey"] - .as_str() - .or_else(|| item["filekey"].as_str())?; - let aes_key = item["body"]["aes_key"] - .as_str() - .or_else(|| item["aes_key"].as_str())?; - let file_name = item["body"]["file_name"] - .as_str() - .or_else(|| item["file_name"].as_str()); - - if let Ok(path) = download_media(file_key, aes_key, file_name).await { - let media_type = match item_type { - 1 => MediaType::Image, - 2 => MediaType::Video, - 4 | 5 => MediaType::Audio, - _ => MediaType::File, - }; - let mime = mime_guess::from_path(&path) - .first_or_octet_stream() - .to_string(); - return Some(IncomingMedia { - media_type, - file_path: path.to_string_lossy().to_string(), - mime_type: mime, - file_name: file_name.map(String::from), - }); - } + // Types 2-5 are media items (image, voice, file, video) + if matches!(item_type, 2..=5) + && let Ok((path, media_type, mime, file_name)) = + download_media_from_item(item, item_type).await + { + return Some(IncomingMedia { + media_type, + file_path: path.to_string_lossy().to_string(), + mime_type: mime, + file_name, + }); } } None From 3b7f0575eb0ce05a169202498b2e3605b5b5adea Mon Sep 17 00:00:00 2001 From: crrow Date: Wed, 25 Mar 2026 17:48:04 +0900 Subject: [PATCH 5/6] fix(runtime): align message processing (#14) - Switch to 1-based message type constants - Rewrite body_from_item_list for ref_msg - Add find_media_item with priority order - Add build_media_send_item per-type format - Replace to_user_id with from_user_id - Split text and media into separate sends - Forward error text to user on failure - Add typing cancel after response - Sleep 1 hour on session expiry - Extract send_outgoing_media helper - Update all tests for new type values Closes #14 Co-Authored-By: Claude Opus 4.6 --- src/runtime.rs | 379 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 284 insertions(+), 95 deletions(-) diff --git a/src/runtime.rs b/src/runtime.rs index 1b9bc9e..53b8bfc 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,4 +1,4 @@ -use std::{path::Path, sync::Arc}; +use std::sync::Arc; use serde_json::Value; use snafu::ResultExt; @@ -8,11 +8,26 @@ use tracing::{error, warn}; use crate::{ api::WeixinApiClient, errors::{HttpSnafu, IoSnafu}, - media::{download_media_from_item, upload_media}, - models::{Agent, ChatRequest, IncomingMedia}, + media, + models::{Agent, ChatRequest, IncomingMedia, OutgoingMediaType}, storage, }; +/// Message item type: text (1-based, aligned with Python SDK). +const MESSAGE_ITEM_TEXT: u64 = 1; +/// Message item type: image. +const MESSAGE_ITEM_IMAGE: u64 = 2; +/// Message item type: voice. +const MESSAGE_ITEM_VOICE: u64 = 3; +/// Message item type: file. +const MESSAGE_ITEM_FILE: u64 = 4; +/// Message item type: video. +const MESSAGE_ITEM_VIDEO: u64 = 5; +/// Typing indicator status: currently typing. +const TYPING_STATUS_TYPING: u8 = 1; +/// Typing indicator status: cancel typing. +const TYPING_STATUS_CANCEL: u8 = 2; + /// Strips `Markdown` formatting from text, returning a plain-text /// approximation. pub fn markdown_to_plain_text(text: &str) -> String { @@ -37,35 +52,97 @@ pub fn markdown_to_plain_text(text: &str) -> String { } /// Extracts the text body from a `WeChat` `item_list` JSON array. +/// +/// Handles text items (type=1) with optional quoted messages in `ref_msg`, +/// and voice items (type=3) via transcription. pub fn body_from_item_list(item_list: &[Value]) -> String { let mut parts = vec![]; for item in item_list { let item_type = item["type"].as_u64().unwrap_or(0); match item_type { - 0 => { + MESSAGE_ITEM_TEXT => { if let Some(body) = item["body"].as_str() { parts.push(body.to_string()); } + if let Some(ref_msg) = item.get("ref_msg") + && let Some(ref_items) = ref_msg["item_list"].as_array() + { + let ref_text = body_from_item_list(ref_items); + if !ref_text.is_empty() { + parts.push(format!("[Quoted: {ref_text}]")); + } + } } - 5 => { + MESSAGE_ITEM_VOICE => { if let Some(trans) = item["voice_transcription_body"].as_str() { parts.push(trans.to_string()); } } - 7 => { - if let Some(ref_list) = item["ref_item_list"].as_array() { - let ref_text = body_from_item_list(ref_list); - if !ref_text.is_empty() { - parts.push(format!("> {ref_text}")); - } - } - } _ => {} } } parts.join("\n") } +/// Finds the highest-priority media item in the given item list. +/// +/// Priority: IMAGE > VIDEO > FILE > VOICE (only when no text) > `ref_msg` media. +/// Returns the item JSON and its type code. +fn find_media_item(item_list: &[Value], has_text: bool) -> Option<(Value, u64)> { + // Check for image, video, file in priority order + for &target in &[MESSAGE_ITEM_IMAGE, MESSAGE_ITEM_VIDEO, MESSAGE_ITEM_FILE] { + for item in item_list { + if item["type"].as_u64() == Some(target) { + return Some((item.clone(), target)); + } + } + } + // Voice only when there is no text body + if !has_text { + for item in item_list { + if item["type"].as_u64() == Some(MESSAGE_ITEM_VOICE) { + return Some((item.clone(), MESSAGE_ITEM_VOICE)); + } + } + } + // Recurse into quoted/referenced messages + for item in item_list { + if let Some(ref_msg) = item.get("ref_msg") + && let Some(ref_items) = ref_msg["item_list"].as_array() + && let Some(found) = find_media_item(ref_items, has_text) + { + return Some(found); + } + } + None +} + +/// Builds a per-type outgoing media item JSON for `send_message`. +fn build_media_send_item( + upload: &media::UploadResult, + outgoing_type: OutgoingMediaType, +) -> Value { + let media_obj = serde_json::json!({ + "encrypt_query_param": upload.encrypt_query_param, + "aes_key": upload.aes_key, + "encrypt_type": 1, + }); + match outgoing_type { + OutgoingMediaType::Video => serde_json::json!({ + "type": MESSAGE_ITEM_VIDEO, + "video_item": {"media": media_obj, "video_size": upload.file_size} + }), + OutgoingMediaType::Image => serde_json::json!({ + "type": MESSAGE_ITEM_IMAGE, + "image_item": {"media": media_obj, "mid_size": upload.file_size} + }), + OutgoingMediaType::File => serde_json::json!({ + "type": MESSAGE_ITEM_FILE, + "file_item": {"media": media_obj, "file_name": upload.file_name, "len": upload.file_size} + }), + } +} + /// Runs the long-polling message loop, dispatching each message to `agent`. pub async fn monitor_weixin( api_client: Arc>, @@ -102,7 +179,8 @@ pub async fn monitor_weixin( } } Err(crate::Error::SessionExpired) => { - warn!("Session expired, need to re-login"); + warn!("Session expired, sleeping 1 hour before exit"); + tokio::time::sleep(std::time::Duration::from_secs(3600)).await; break; } Err(crate::Error::Http { ref source }) if source.is_timeout() => {} @@ -121,119 +199,143 @@ pub async fn monitor_weixin( } } +/// Downloads an outgoing media file from its URL, uploads to CDN, and sends it. +async fn send_outgoing_media( + client: &WeixinApiClient, + outgoing_media: &crate::models::OutgoingMedia, + to_user_id: &str, + context_token: &str, +) -> crate::Result<()> { + let http_client = reqwest::Client::new(); + let media_bytes = http_client + .get(&outgoing_media.url) + .send() + .await + .context(HttpSnafu)? + .bytes() + .await + .context(HttpSnafu)?; + let tmp_dir = std::path::Path::new("/tmp/weixin-agent/media/upload"); + std::fs::create_dir_all(tmp_dir).context(IoSnafu)?; + let file_name = outgoing_media.file_name.as_deref().unwrap_or("file"); + let tmp_path = tmp_dir.join(format!("{}_{file_name}", uuid::Uuid::new_v4())); + std::fs::write(&tmp_path, &media_bytes).context(IoSnafu)?; + let upload_media_type = match outgoing_media.media_type { + OutgoingMediaType::Image => media::UPLOAD_MEDIA_IMAGE, + OutgoingMediaType::Video => media::UPLOAD_MEDIA_VIDEO, + OutgoingMediaType::File => media::UPLOAD_MEDIA_FILE, + }; + let uploaded = media::upload_media(client, &tmp_path, upload_media_type, to_user_id).await?; + let media_item = build_media_send_item(&uploaded, outgoing_media.media_type); + client + .send_message(to_user_id, context_token, &[media_item]) + .await?; + Ok(()) +} + async fn process_message( api_client: Arc>, agent: Arc, msg: &Value, ) -> crate::Result<()> { let item_list = msg["item_list"].as_array().cloned().unwrap_or_default(); - let to_user_id = msg["to_user_id"].as_str().unwrap_or(""); + let from_user_id = msg["from_user_id"].as_str().unwrap_or(""); let context_token = msg["context_token"].as_str().unwrap_or(""); - let text = body_from_item_list(&item_list); + // Slash commands if let Some(echo_text) = text.strip_prefix("/echo ") { - let items = vec![serde_json::json!({"type": 0, "body": echo_text})]; + let item = serde_json::json!({"type": MESSAGE_ITEM_TEXT, "body": echo_text}); api_client .lock() .await - .send_message(to_user_id, context_token, &items) + .send_message(from_user_id, context_token, &[item]) .await?; return Ok(()); } + // Typing indicator let _ = api_client .lock() .await - .send_typing(to_user_id, context_token, 1) + .send_typing(from_user_id, context_token, TYPING_STATUS_TYPING) .await; - let incoming_media = extract_media_from_items(&item_list).await; + // Extract and download media + let has_text = !text.is_empty(); + let incoming_media = + if let Some((media_item, media_type)) = find_media_item(&item_list, has_text) { + match media::download_media_from_item(&media_item, media_type).await { + Ok((path, mt, mime, fname)) => Some(IncomingMedia { + media_type: mt, + file_path: path.to_string_lossy().to_string(), + mime_type: mime, + file_name: fname, + }), + Err(e) => { + warn!("Failed to download media: {e}"); + None + } + } + } else { + None + }; let request = ChatRequest { - conversation_id: to_user_id.to_string(), + conversation_id: from_user_id.to_string(), text, media: incoming_media, }; - let response = agent.chat(request).await?; + // Call agent -- on error, send error text to user then propagate + let response = match agent.chat(request).await { + Ok(resp) => resp, + Err(e) => { + error!("Agent error: {e}"); + let client = api_client.lock().await; + let err_item = serde_json::json!({ + "type": MESSAGE_ITEM_TEXT, + "body": format!("Error: {e}") + }); + let _ = client + .send_message(from_user_id, context_token, &[err_item]) + .await; + let _ = client + .send_typing(from_user_id, context_token, TYPING_STATUS_CANCEL) + .await; + drop(client); + return Err(e); + } + }; let client = api_client.lock().await; - if let Some(ref media) = response.media { - let http_client = reqwest::Client::new(); - let media_bytes = http_client - .get(&media.url) - .send() - .await - .context(HttpSnafu)? - .bytes() - .await - .context(HttpSnafu)?; - let tmp_dir = Path::new("/tmp/weixin-agent/media"); - std::fs::create_dir_all(tmp_dir).context(IoSnafu)?; - let file_name = media.file_name.as_deref().unwrap_or("file"); - let tmp_path = tmp_dir.join(format!("{}_{file_name}", uuid::Uuid::new_v4())); - std::fs::write(&tmp_path, &media_bytes).context(IoSnafu)?; - - let media_type_id = match media.media_type { - crate::models::OutgoingMediaType::Image => 1, - crate::models::OutgoingMediaType::Video => 2, - crate::models::OutgoingMediaType::File => 3, - }; - // TODO(Task 5): properly wire media_type and to_user_id - let uploaded = upload_media(&client, &tmp_path, media_type_id, to_user_id).await?; - - let file_info = serde_json::json!({ - "type": media_type_id, - "body": { - "encrypt_query_param": uploaded.encrypt_query_param, - "aes_key": uploaded.aes_key, - "file_name": uploaded.file_name, - "file_size": uploaded.file_size, - }, - }); - - let mut items = vec![]; - if let Some(ref t) = response.text { - items.push(serde_json::json!({"type": 0, "body": t})); + // Send text and media as SEPARATE messages (aligned with Python SDK) + if let Some(ref outgoing_media) = response.media { + if let Some(ref resp_text) = response.text { + let plain = markdown_to_plain_text(resp_text); + let text_item = serde_json::json!({"type": MESSAGE_ITEM_TEXT, "body": plain}); + let _ = client + .send_message(from_user_id, context_token, &[text_item]) + .await; } - items.push(file_info); - client - .send_message(to_user_id, context_token, &items) - .await?; - drop(client); - } else if let Some(text) = &response.text { - let plain = markdown_to_plain_text(text); - let items = vec![serde_json::json!({"type": 0, "body": plain})]; + send_outgoing_media(&client, outgoing_media, from_user_id, context_token).await?; + } else if let Some(ref resp_text) = response.text { + let plain = markdown_to_plain_text(resp_text); + let text_item = serde_json::json!({"type": MESSAGE_ITEM_TEXT, "body": plain}); client - .send_message(to_user_id, context_token, &items) + .send_message(from_user_id, context_token, &[text_item]) .await?; - drop(client); } + // Cancel typing + let _ = client + .send_typing(from_user_id, context_token, TYPING_STATUS_CANCEL) + .await; + drop(client); Ok(()) } -async fn extract_media_from_items(item_list: &[Value]) -> Option { - for item in item_list { - let item_type = item["type"].as_u64().unwrap_or(0); - // Types 2-5 are media items (image, voice, file, video) - if matches!(item_type, 2..=5) - && let Ok((path, media_type, mime, file_name)) = - download_media_from_item(item, item_type).await - { - return Some(IncomingMedia { - media_type, - file_path: path.to_string_lossy().to_string(), - mime_type: mime, - file_name, - }); - } - } - None -} - #[cfg(test)] mod tests { use super::*; @@ -321,7 +423,7 @@ mod tests { #[test] fn test_text_item() { - let items = vec![serde_json::json!({"type": 0, "body": "hello world"})]; + let items = vec![serde_json::json!({"type": 1, "body": "hello world"})]; let result = body_from_item_list(&items); assert_eq!(result, "hello world"); } @@ -329,7 +431,7 @@ mod tests { #[test] fn test_voice_transcription() { let items = vec![serde_json::json!({ - "type": 5, + "type": 3, "voice_transcription_body": "transcribed text" })]; let result = body_from_item_list(&items); @@ -339,18 +441,19 @@ mod tests { #[test] fn test_quoted_message() { let items = vec![serde_json::json!({ - "type": 7, - "ref_item_list": [{"type": 0, "body": "original message"}] + "type": 1, + "body": "reply", + "ref_msg": {"item_list": [{"type": 1, "body": "original message"}]} })]; let result = body_from_item_list(&items); - assert_eq!(result, "> original message"); + assert_eq!(result, "reply\n[Quoted: original message]"); } #[test] fn test_multiple_items() { let items = vec![ - serde_json::json!({"type": 0, "body": "first"}), - serde_json::json!({"type": 0, "body": "second"}), + serde_json::json!({"type": 1, "body": "first"}), + serde_json::json!({"type": 1, "body": "second"}), ]; let result = body_from_item_list(&items); assert_eq!(result, "first\nsecond"); @@ -369,4 +472,90 @@ mod tests { let result = body_from_item_list(&items); assert_eq!(result, ""); } + + // -- find_media_item tests -- + + #[test] + fn test_find_media_image_priority() { + let items = vec![ + serde_json::json!({"type": MESSAGE_ITEM_FILE, "file_item": {}}), + serde_json::json!({"type": MESSAGE_ITEM_IMAGE, "image_item": {}}), + ]; + let (_, t) = find_media_item(&items, false).unwrap(); + assert_eq!(t, MESSAGE_ITEM_IMAGE); + } + + #[test] + fn test_find_media_voice_skipped_when_text() { + let items = vec![serde_json::json!({"type": MESSAGE_ITEM_VOICE})]; + assert!(find_media_item(&items, true).is_none()); + } + + #[test] + fn test_find_media_voice_when_no_text() { + let items = vec![serde_json::json!({"type": MESSAGE_ITEM_VOICE})]; + let (_, t) = find_media_item(&items, false).unwrap(); + assert_eq!(t, MESSAGE_ITEM_VOICE); + } + + #[test] + fn test_find_media_in_ref_msg() { + let items = vec![serde_json::json!({ + "type": MESSAGE_ITEM_TEXT, + "body": "look at this", + "ref_msg": { + "item_list": [{"type": MESSAGE_ITEM_IMAGE, "image_item": {}}] + } + })]; + let (_, t) = find_media_item(&items, true).unwrap(); + assert_eq!(t, MESSAGE_ITEM_IMAGE); + } + + #[test] + fn test_find_media_none() { + let items = vec![serde_json::json!({"type": MESSAGE_ITEM_TEXT, "body": "hi"})]; + assert!(find_media_item(&items, false).is_none()); + } + + // -- build_media_send_item tests -- + + #[test] + fn test_build_image_send_item() { + let upload = media::UploadResult { + encrypt_query_param: "eqp".to_string(), + aes_key: "key".to_string(), + file_name: "img.png".to_string(), + file_size: 1024, + }; + let item = build_media_send_item(&upload, OutgoingMediaType::Image); + assert_eq!(item["type"], MESSAGE_ITEM_IMAGE); + assert!(item["image_item"]["media"]["encrypt_query_param"].is_string()); + } + + #[test] + fn test_build_video_send_item() { + let upload = media::UploadResult { + encrypt_query_param: "eqp".to_string(), + aes_key: "key".to_string(), + file_name: "vid.mp4".to_string(), + file_size: 2048, + }; + let item = build_media_send_item(&upload, OutgoingMediaType::Video); + assert_eq!(item["type"], MESSAGE_ITEM_VIDEO); + assert!(item["video_item"]["media"]["encrypt_query_param"].is_string()); + } + + #[test] + fn test_build_file_send_item() { + let upload = media::UploadResult { + encrypt_query_param: "eqp".to_string(), + aes_key: "key".to_string(), + file_name: "doc.pdf".to_string(), + file_size: 4096, + }; + let item = build_media_send_item(&upload, OutgoingMediaType::File); + assert_eq!(item["type"], MESSAGE_ITEM_FILE); + assert_eq!(item["file_item"]["file_name"], "doc.pdf"); + assert_eq!(item["file_item"]["len"], 4096); + } } From ce5459970bf505996c236b367d6bc72cd81d95a7 Mon Sep 17 00:00:00 2001 From: crrow Date: Wed, 25 Mar 2026 20:17:28 +0900 Subject: [PATCH 6/6] fix(bot): align login flow with Python SDK (#14) - Add QR code retry loop: outer loop refreshes QR (up to 3 times), inner loop polls status, with 480s global deadline - Handle response fields at root or nested under "data" - Accept both "scanned" and "scaned" status variants - Use storage::normalize_account_id() for account ID normalization Closes #14 --- src/bot.rs | 177 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 103 insertions(+), 74 deletions(-) diff --git a/src/bot.rs b/src/bot.rs index ea673f3..739892f 100644 --- a/src/bot.rs +++ b/src/bot.rs @@ -12,96 +12,125 @@ use crate::{ storage::{self, DEFAULT_BASE_URL}, }; +/// Maximum number of QR code refreshes before giving up. +const MAX_QR_REFRESHES: u8 = 3; + +/// Total deadline for the entire login flow (seconds). +const LOGIN_DEADLINE_SECS: u64 = 480; + /// Performs an interactive QR-code login and persists the resulting /// credentials. /// +/// Uses a two-level loop matching the Python SDK: the outer loop refreshes the +/// QR code (up to 3 times), the inner loop polls scan status, and a global +/// 480-second deadline aborts the whole flow. +/// /// Returns the account ID on success. pub async fn login(options: LoginOptions) -> Result { let base_url = options.base_url.as_deref().unwrap_or(DEFAULT_BASE_URL); let client = WeixinApiClient::new(base_url, "", None); - - let qr_resp = client.fetch_qr_code().await?; - let qrcode_url = qr_resp["qrcode_img_content"] - .as_str() - .context(LoginFailedSnafu { - reason: "no qrcode_url", - })?; - let qrcode = qr_resp["data"]["qrcode"] - .as_str() - .or_else(|| qr_resp["data"]["qrcode_id"].as_str()) - .context(LoginFailedSnafu { - reason: "no qrcode", - })?; - let qr = qrcode::QrCode::new(qrcode_url.as_bytes()).map_err(|e| { - LoginFailedSnafu { - reason: format!("QR generation failed: {e}"), - } - .build() - })?; - let image = qr - .render::() - .quiet_zone(true) - .module_dimensions(2, 1) - .build(); - println!("{image}"); - println!("Scan the QR code above with WeChat to login"); + let deadline = + tokio::time::Instant::now() + std::time::Duration::from_secs(LOGIN_DEADLINE_SECS); + let mut refresh_count = 0u8; loop { - tokio::time::sleep(std::time::Duration::from_secs(2)).await; - let status_resp = client.get_qr_code_status(qrcode).await?; - let status = status_resp["status"] + let qr_resp = client.fetch_qr_code().await?; + // Response fields may be at root or nested under "data" + let qrcode = qr_resp["qrcode"] + .as_str() + .or_else(|| qr_resp["data"]["qrcode"].as_str()) + .context(LoginFailedSnafu { + reason: "no qrcode", + })?; + let qrcode_url = qr_resp["qrcode_url"] .as_str() - .or_else(|| status_resp["data"]["status"].as_str()) - .unwrap_or("unknown"); + .or_else(|| qr_resp["data"]["qrcode_url"].as_str()) + .unwrap_or(qrcode); - match status { - "wait" => {} - "scaned" => { - info!("QR code scanned, waiting for confirmation..."); + let qr = qrcode::QrCode::new(qrcode_url.as_bytes()).map_err(|e| { + LoginFailedSnafu { + reason: format!("QR generation failed: {e}"), } - "expired" => { + .build() + })?; + let image = qr + .render::() + .quiet_zone(true) + .module_dimensions(2, 1) + .build(); + println!("{image}"); + println!("Scan the QR code above with WeChat to login"); + + loop { + if tokio::time::Instant::now() > deadline { return Err(QrCodeExpiredSnafu.build()); } - "confirmed" => { - // v2 API returns credentials at top level; v1 nests under data - let data = if status_resp.get("bot_token").is_some() { - &status_resp - } else { - &status_resp["data"] - }; - let token = data["bot_token"].as_str().context(LoginFailedSnafu { - reason: "no bot_token", - })?; - let bot_id = data["ilink_bot_id"].as_str().context(LoginFailedSnafu { - reason: "no ilink_bot_id", - })?; - let base = data["baseurl"].as_str().unwrap_or(base_url); - let user_id = data["ilink_user_id"].as_str().unwrap_or(""); - - let account_id = bot_id - .strip_prefix("ilink_bot_") - .unwrap_or(bot_id) - .to_string(); - - let account_data = storage::AccountData { - token: token.to_string(), - saved_at: chrono::Utc::now().to_rfc3339(), - base_url: base.to_string(), - user_id: user_id.to_string(), - }; - storage::save_account_data(&account_id, &account_data)?; - - let mut ids = storage::get_account_ids().unwrap_or_default(); - if !ids.contains(&account_id) { - ids.push(account_id.clone()); - storage::save_account_ids(&ids)?; + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + + let status_resp = client.get_qr_code_status(qrcode).await?; + // Status field may be at root or nested under "data" + let status = status_resp["status"] + .as_str() + .or_else(|| status_resp["data"]["status"].as_str()) + .unwrap_or("unknown"); + + match status { + "wait" => {} + "scanned" | "scaned" => { + info!("QR code scanned, waiting for confirmation..."); + } + "expired" => { + refresh_count += 1; + if refresh_count >= MAX_QR_REFRESHES { + return Err(QrCodeExpiredSnafu.build()); + } + warn!( + "QR code expired, refreshing ({refresh_count}/{MAX_QR_REFRESHES})..." + ); + break; // break inner loop to refresh QR in outer loop } + "confirmed" => { + // Confirmed fields may be at root or nested under "data" + let data = if status_resp.get("bot_token").is_some() { + &status_resp + } else { + &status_resp["data"] + }; + let token = data["bot_token"].as_str().context(LoginFailedSnafu { + reason: "no bot_token", + })?; + let bot_id = + data["ilink_bot_id"] + .as_str() + .context(LoginFailedSnafu { + reason: "no ilink_bot_id", + })?; + let base = data["baseurl"].as_str().unwrap_or(base_url); + let user_id = data["ilink_user_id"].as_str().unwrap_or(""); - info!("Login successful! Account ID: {account_id}"); - return Ok(account_id); - } - other => { - warn!("Unknown QR status: {other}"); + let raw_id = bot_id.strip_prefix("ilink_bot_").unwrap_or(bot_id); + let account_id = storage::normalize_account_id(raw_id); + + let account_data = storage::AccountData { + token: token.to_string(), + saved_at: chrono::Utc::now().to_rfc3339(), + base_url: base.to_string(), + user_id: user_id.to_string(), + }; + storage::save_account_data(&account_id, &account_data)?; + + let mut ids = storage::get_account_ids().unwrap_or_default(); + if !ids.contains(&account_id) { + ids.push(account_id.clone()); + storage::save_account_ids(&ids)?; + } + + info!("Login successful! Account ID: {account_id}"); + return Ok(account_id); + } + other => { + warn!("Unknown QR status: {other}"); + } } } }