diff --git a/bhwi-async/src/transport/jade/mod.rs b/bhwi-async/src/transport/jade/mod.rs index e748cf5..0ef91de 100644 --- a/bhwi-async/src/transport/jade/mod.rs +++ b/bhwi-async/src/transport/jade/mod.rs @@ -1,4 +1,34 @@ +use async_trait::async_trait; pub use bhwi::jade::JADE_DEVICE_IDS; +use serde_cbor::Value; #[cfg(feature = "emulators")] pub mod tcp; + +#[async_trait(?Send)] +pub trait CborStream { + async fn write_all(&mut self, command: &[u8]) -> Result<(), std::io::Error>; + async fn read(&mut self, buf: &mut [u8]) -> Result; + + /// Reads from the client until a complete CBOR value is received. + async fn read_cbor_message(&mut self) -> Result, std::io::Error> { + let mut buf = Vec::new(); + let mut chunk = [0u8; 1024]; + + loop { + let n = self.read(&mut chunk).await?; + if n == 0 { + return Err(std::io::Error::other( + "stream ended before complete CBOR message", + )); + } + buf.extend_from_slice(&chunk[..n]); + let mut cursor = std::io::Cursor::new(&buf); + match serde_cbor::from_reader::(&mut cursor) { + Ok(_) => return Ok(buf), + Err(e) if e.is_io() || e.is_eof() => continue, + Err(e) => return Err(std::io::Error::other(e)), + } + } + } +} diff --git a/bhwi-async/src/transport/jade/tcp.rs b/bhwi-async/src/transport/jade/tcp.rs index d0aac42..6478df6 100644 --- a/bhwi-async/src/transport/jade/tcp.rs +++ b/bhwi-async/src/transport/jade/tcp.rs @@ -1,7 +1,6 @@ use async_trait::async_trait; -use serde_cbor::Value; -use crate::Transport; +use crate::{Transport, transport::jade::CborStream}; pub struct TcpTransport { pub client: C, @@ -14,38 +13,11 @@ impl TcpTransport { } #[async_trait(?Send)] -pub trait TcpClient { - async fn write_all(&mut self, command: &[u8]) -> Result<(), std::io::Error>; - async fn read(&mut self, buf: &mut [u8]) -> Result; -} - -#[async_trait(?Send)] -impl Transport for TcpTransport { +impl Transport for TcpTransport { type Error = std::io::Error; async fn exchange(&mut self, command: &[u8], _encrypted: bool) -> Result, Self::Error> { self.client.write_all(command).await?; - - let mut buf = Vec::new(); - let mut temp = [0u8; 1024]; - - // HACK: i don't know a better way right now! - loop { - let n = self.client.read(&mut temp).await?; - // XXX: what happens when n is 0? - buf.extend_from_slice(&temp[..n]); - let mut cursor = std::io::Cursor::new(&buf); - match serde_cbor::from_reader::(&mut cursor) { - Ok(_) => { - return Ok(buf); - } - Err(e) if e.is_io() => { - continue; // read more bytes - } - Err(e) => { - return Err(std::io::Error::other(e)); - } - } - } + self.client.read_cbor_message().await } } diff --git a/bhwi-cli/src/jade.rs b/bhwi-cli/src/jade.rs index bd7bec6..0fa1e35 100644 --- a/bhwi-cli/src/jade.rs +++ b/bhwi-cli/src/jade.rs @@ -4,10 +4,7 @@ use anyhow::Result; use async_trait::async_trait; use bhwi_async::{ HttpClient, Jade, Transport, - transport::jade::{ - JADE_DEVICE_IDS, - tcp::{TcpClient as TcpClientTrait, TcpTransport}, - }, + transport::jade::{CborStream, JADE_DEVICE_IDS, tcp::TcpTransport}, }; use bitcoin::Network; use futures::{TryStreamExt, stream::iter}; @@ -51,11 +48,20 @@ impl SerialTransport { impl Transport for SerialTransport { type Error = std::io::Error; async fn exchange(&mut self, command: &[u8], _encrypted: bool) -> Result, Self::Error> { + self.write_all(command).await?; + self.read_cbor_message().await + } +} + +#[async_trait(?Send)] +impl CborStream for SerialTransport { + async fn write_all(&mut self, command: &[u8]) -> Result<(), std::io::Error> { + let mut stream = self.stream.lock().await; + Ok(stream.write_all(command).await?) + } + async fn read(&mut self, buf: &mut [u8]) -> Result { let mut stream = self.stream.lock().await; - stream.write_all(command).await?; - let mut buf = vec![]; - stream.read_to_end(&mut buf).await?; - Ok(buf) + Ok(stream.read(buf).await?) } } @@ -153,7 +159,7 @@ impl HttpClient for PinServerClient { Ok(self .inner .post(url) - .header("Content-Type", "application/octet-stream") + .header("Content-Type", "application/json") .body(request.to_vec()) .send() .await? @@ -174,7 +180,7 @@ impl TcpClient { } #[async_trait(?Send)] -impl TcpClientTrait for TcpClient { +impl CborStream for TcpClient { async fn write_all(&mut self, command: &[u8]) -> Result<(), std::io::Error> { Ok(self.stream.write_all(command).await?) }