diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 0000000..1a45eee --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,4 @@ +target +corpus +artifacts +coverage diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 0000000..6025d0a --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "usbd-ctaphid-fuzz" +version = "0.0.0" +publish = false +edition = "2021" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +ctaphid-dispatch = "0.1.0" +libfuzzer-sys = { version = "0.4", features = ["arbitrary-derive"] } + +[dependencies.usbd-ctaphid] +path = ".." + +[[bin]] +name = "buffer" +path = "fuzz_targets/buffer.rs" +test = false +doc = false +bench = false + +[patch.crates-io] +ctaphid-dispatch = { git = "https://github.com/trussed-dev/ctaphid-dispatch.git", rev = "57cb3317878a8593847595319aa03ef17c29ec5b" } +trussed = { git = "https://github.com/trussed-dev/trussed.git", rev = "51e68500d7601d04f884f5e95567d14b9018a6cb" } diff --git a/fuzz/fuzz_targets/buffer.rs b/fuzz/fuzz_targets/buffer.rs new file mode 100644 index 0000000..0ae8be2 --- /dev/null +++ b/fuzz/fuzz_targets/buffer.rs @@ -0,0 +1,75 @@ +#![no_main] + +use ctaphid_dispatch::types::{Channel, Error, InterchangeResponse, Message, Responder}; +use libfuzzer_sys::{arbitrary::{self, Arbitrary}, fuzz_target}; +use usbd_ctaphid::buffer::{Buffer, BufferState}; + +#[derive(Debug, Arbitrary)] +enum Action { + HandlePacket { + data: [u8; 64], + success: bool, + }, + HandleResponse(bool), + TrySendPacket(bool), + CheckTimeout { + milliseconds: u32, + success: bool, + }, + DidStartProcessing, + SendKeepalive(bool), + GenerateResponse(Vec), +} + +impl Action { + fn run(self, buffer: &mut Buffer<'_, '_>, rp: &mut Responder<'_>) { + match self { + Self::HandlePacket { data, success } => { + let state = buffer.handle_packet(&data); + self.handle_state(buffer, state, success); + } + Self::HandleResponse(success) => { + let state = buffer.handle_response(); + self.handle_state(buffer, state, success); + } + Self::TrySendPacket(success) => { + self.try_send_packet(buffer, success); + } + Self::CheckTimeout { milliseconds, success } => { + let state = buffer.check_timeout(milliseconds); + self.handle_state(buffer, state, success); + } + Self::DidStartProcessing => { + buffer.did_start_processing(); + } + Self::SendKeepalive(waiting) => { + let _ = buffer.send_keepalive(waiting); + } + Self::GenerateResponse(response) => { + if let Ok(_request) = rp.request() { + let response = Message::from_slice(&response).map_err(|_| Error::InvalidLength); + rp.respond(InterchangeResponse(response)).ok(); + } + } + } + } + + fn handle_state(&self, buffer: &mut Buffer, state: BufferState, success: bool) { + if state == BufferState::ResponseQueued { + self.try_send_packet(buffer, success); + } + } + + fn try_send_packet(&self, buffer: &mut Buffer, success: bool) { + buffer.try_send_packet(|_| if success { Ok(()) } else { Err(()) }); + } +} + +fuzz_target!(|actions: Vec| { + let channel = Channel::new(); + let (rq, mut rp) = channel.split().unwrap(); + let mut buffer = Buffer::new(rq, 0, None); + for action in actions { + action.run(&mut buffer, &mut rp); + } +}); diff --git a/src/buffer.rs b/src/buffer.rs new file mode 100644 index 0000000..fc434bb --- /dev/null +++ b/src/buffer.rs @@ -0,0 +1,673 @@ +use crate::{ + constants::{ + // 3072 + MESSAGE_SIZE, + // 64 + PACKET_SIZE, + }, + types::KeepaliveStatus, + Version, +}; +use core::sync::atomic::Ordering; +use ctap_types::Error as AuthenticatorError; +use ctaphid_dispatch::command::Command; +use ctaphid_dispatch::types::{Error as DispatchError, Requester}; +use ref_swap::OptionRefSwap; +use trussed::interrupt::InterruptFlag; + +/// The actual payload of given length is dealt with separately +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +struct Request { + channel: u32, + command: Command, + length: u16, + timestamp: u32, +} + +impl Request { + fn error(self, error: AuthenticatorError) -> PipeError { + PipeError { + channel: self.channel, + error, + keep_state: false, + } + } + + fn error_now(self, error: AuthenticatorError) -> PipeError { + PipeError { + channel: self.channel, + error, + keep_state: true, + } + } +} + +/// The actual payload of given length is dealt with separately +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +struct Response { + channel: u32, + command: Command, + length: u16, +} + +impl Response { + fn from_request_and_size(request: Request, size: usize) -> Self { + Self { + channel: request.channel, + command: request.command, + length: size as u16, + } + } + + fn error_on_channel(channel: u32) -> Self { + Self { + channel, + command: Command::Error, + length: 1, + } + } +} + +struct PipeError { + channel: u32, + error: AuthenticatorError, + keep_state: bool, +} + +impl PipeError { + fn on_channel(channel: u32, error: AuthenticatorError) -> Self { + Self { + channel, + error, + keep_state: false, + } + } +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +struct MessageState { + // sequence number of next continuation packet + next_sequence: u8, + // number of bytes of message payload transmitted so far + transmitted: usize, +} + +impl Default for MessageState { + fn default() -> Self { + Self { + next_sequence: 0, + transmitted: PACKET_SIZE - 7, + } + } +} + +impl MessageState { + // update state due to receiving a full new continuation packet + #[must_use] + fn absorb_packet(mut self) -> Self { + self.next_sequence += 1; + self.transmitted += PACKET_SIZE - 5; + self + } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +enum State { + Idle, + + // if request payload data is larger than one packet + Receiving((Request, MessageState)), + + // Processing(Request), + + // // the request message is ready, need to dispatch to authenticator + // Dispatching((Request, Ctap2Request)), + + // waiting for response from authenticator + WaitingOnAuthenticator(Request), + + WaitingToSend(Response), + + Sending((Response, MessageState)), +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[must_use] +pub enum BufferState { + Idle, + ResponseQueued, + Error(BufferError), +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct BufferError { + channel: u32, + error: u8, +} + +pub struct Buffer<'pipe, 'interrupt> { + state: State, + interchange: Requester<'pipe>, + interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>, + // shared between requests and responses, due to size + buffer: [u8; MESSAGE_SIZE], + // we assign channel IDs one by one, this is the one last assigned + // TODO: move into "app" + last_channel: u32, + // Indicator of implemented commands in INIT response. + implements: u8, + // timestamp that gets used for timing out CID's + last_milliseconds: u32, + // a "read once" indicator if now we're waiting on the application processing + started_processing: bool, + needs_keepalive: bool, + version: Version, +} + +impl<'pipe, 'interrupt> Buffer<'pipe, 'interrupt> { + pub fn new( + interchange: Requester<'pipe>, + initial_milliseconds: u32, + interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>, + ) -> Self { + Self { + state: State::Idle, + interchange, + interrupt, + buffer: [0; MESSAGE_SIZE], + last_channel: 0, + // Default to nothing implemented. + implements: 0x80, + last_milliseconds: initial_milliseconds, + started_processing: false, + needs_keepalive: false, + version: Default::default(), + } + } + + pub fn implements(&self) -> u8 { + self.implements + } + + pub fn set_implements(&mut self, implements: u8) { + self.implements = implements; + } + + pub fn set_version(&mut self, version: Version) { + self.version = version; + } + + fn cancel_ongoing_activity(&mut self) { + if matches!(self.state, State::WaitingOnAuthenticator(_)) { + info_now!("Interrupting request"); + if let Some(Some(i)) = self.interrupt.map(|i| i.load(Ordering::Relaxed)) { + info_now!("Loaded some interrupter"); + i.interrupt(); + } + } + } + + pub fn check_timeout(&mut self, milliseconds: u32) -> BufferState { + // At any point the RP application could crash or something, + // so its up to the device to timeout those transactions. + let last = core::mem::replace(&mut self.last_milliseconds, milliseconds); + if let State::Receiving((request, message_state)) = &self.state { + if milliseconds.saturating_sub(last) > 200 { + // If there's a lapse in `check_timeout(...)` getting called (e.g. due to logging), + // this could lead to inaccurate timestamps on requests. So we'll + // just "forgive" requests temporarily if this happens. + debug!( + "lapse in hid check.. {} {} {}", + request.timestamp, milliseconds, last + ); + let mut request = *request; + request.timestamp = milliseconds; + self.state = State::Receiving((request, *message_state)); + BufferState::Idle + } + // compare keeping in mind of possible overflow in timestamp. + else if (milliseconds > request.timestamp && (milliseconds - request.timestamp) > 550) + || (milliseconds < request.timestamp && milliseconds > 550) + { + debug!( + "Channel timeout. {}, {}, {}", + request.timestamp, milliseconds, last + ); + self.send_error(request.error(AuthenticatorError::Timeout)) + } else { + BufferState::Idle + } + } else { + BufferState::Idle + } + } + + #[must_use] + pub fn send_keepalive(&self, is_waiting_for_user_presence: bool) -> Option> { + if let State::WaitingOnAuthenticator(request) = &self.state { + if !self.needs_keepalive { + // let response go out normally in idle loop + info!("cmd does not need keepalive messages"); + None + } else { + info!("keepalive"); + + let response = Response { + channel: request.channel, + command: Command::KeepAlive, + length: 1, + }; + let status = if is_waiting_for_user_presence { + &(KeepaliveStatus::UpNeeded as u8) + } else { + &(KeepaliveStatus::Processing as u8) + }; + Some(Packet::init(response, core::slice::from_ref(status))) + } + } else { + info!("keepalive done"); + None + } + } + + pub fn try_send_packet) -> Result<(), ()>>(&mut self, f: F) { + if let Some(packet) = self.packet_to_send() { + if f(packet).is_ok() { + self.state = packet.next_state(); + } + } + } + + #[must_use] + fn packet_to_send(&self) -> Option> { + match self.state { + State::WaitingToSend(response) => Some(Packet::init(response, &self.buffer)), + State::Sending((response, message_state)) => { + Some(Packet::cont(response, message_state, &self.buffer)) + } + // nothing to send + _ => None, + } + } + + pub fn handle_packet(&mut self, packet: &[u8; 64]) -> BufferState { + match self.handle_packet_impl(packet) { + Ok(Some(response)) => self.send_response(response), + Ok(None) => BufferState::Idle, + Err(error) => self.send_error(error), + } + } + + fn handle_packet_impl(&mut self, packet: &[u8; 64]) -> Result, PipeError> { + info!(">> "); + info!("{}", hex_str!(&packet[..16])); + + // packet is 64 bytes, reading 4 will not panic + let channel = u32::from_be_bytes(packet[..4].try_into().unwrap()); + // info_now!("channel {}", channel); + + let is_initialization = (packet[4] >> 7) != 0; + // info_now!("is_initialization {}", is_initialization); + + if is_initialization { + // case of initialization packet + info!("init"); + + let command_number = packet[4] & !0x80; + // info_now!("command number {}", command_number); + + let command = match Command::try_from(command_number) { + Ok(command) => command, + // `solo ls` crashes here as it uses command 0x86 + Err(_) => { + info!("Received invalid command."); + return Err(PipeError::on_channel( + channel, + AuthenticatorError::InvalidCommand, + )); + } + }; + + // can't actually fail + let length = u16::from_be_bytes(packet[5..][..2].try_into().unwrap()); + + let timestamp = self.last_milliseconds; + let current_request = Request { + channel, + command, + length, + timestamp, + }; + + if !(self.state == State::Idle) { + let request = match self.state { + State::WaitingOnAuthenticator(request) => request, + State::Receiving((request, _message_state)) => request, + _ => { + info_now!("Ignoring transaction as we're already transmitting."); + return Ok(None); + } + }; + if packet[4] == 0x86 { + info_now!("Resyncing!"); + self.cancel_ongoing_activity(); + } else { + return if channel == request.channel { + if command == Command::Cancel { + info_now!("Cancelling"); + self.cancel_ongoing_activity(); + Ok(None) + } else { + info_now!("Expected seq, {:?}", request.command); + Err(request.error(AuthenticatorError::InvalidSeq)) + } + } else { + info_now!("busy."); + Err(current_request.error_now(AuthenticatorError::ChannelBusy)) + }; + } + } + + if length > MESSAGE_SIZE as u16 { + info!("Error message too big."); + return Err(current_request.error_now(AuthenticatorError::InvalidLength)); + } + + if length > PACKET_SIZE as u16 - 7 { + // store received part of payload, + // prepare for continuation packets + self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]); + self.state = State::Receiving((current_request, { MessageState::default() })); + // we're done... wait for next packet + Ok(None) + } else { + // request fits in one packet + self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]); + self.dispatch_request(current_request) + } + } else { + // case of continuation packet + match self.state { + State::Receiving((request, message_state)) => { + let sequence = packet[4]; + // info_now!("receiving continuation packet {}", sequence); + if sequence != message_state.next_sequence { + // error handling? + // info_now!("wrong sequence for continuation packet, expected {} received {}", + // message_state.next_sequence, sequence); + info!("Error invalid cont pkt"); + return Err(request.error(AuthenticatorError::InvalidSeq)); + } + if channel != request.channel { + // error handling? + // info_now!("wrong channel for continuation packet, expected {} received {}", + // request.channel, channel); + info!("Ignore invalid channel"); + return Ok(None); + } + + let payload_length = request.length as usize; + if message_state.transmitted + (PACKET_SIZE - 5) < payload_length { + // info_now!("transmitted {} + (PACKET_SIZE - 5) < {}", + // message_state.transmitted, payload_length); + // store received part of payload + self.buffer[message_state.transmitted..][..PACKET_SIZE - 5] + .copy_from_slice(&packet[5..]); + let message_state = message_state.absorb_packet(); + self.state = State::Receiving((request, message_state)); + // info_now!("absorbed packet, awaiting next"); + Ok(None) + } else { + let missing = request.length as usize - message_state.transmitted; + self.buffer[message_state.transmitted..payload_length] + .copy_from_slice(&packet[5..][..missing]); + self.dispatch_request(request) + } + } + _ => { + // unexpected continuation packet + info!("Ignore unexpected cont pkt"); + Ok(None) + } + } + } + } + + fn dispatch_request(&mut self, request: Request) -> Result, PipeError> { + info!("Got request: {:?}", request.command); + match request.command { + Command::Init => {} + _ => { + if request.channel == 0xffffffff { + return Err(request.error(AuthenticatorError::InvalidChannel)); + } + } + } + // dispatch request further + match request.command { + Command::Init => { + // info_now!("command INIT!"); + // info_now!("data: {:?}", &self.buffer[..request.length as usize]); + match request.channel { + 0 => { + // this is an error / reserved number + Err(request.error(AuthenticatorError::InvalidChannel)) + } + + // broadcast channel ID - request for assignment + cid => { + if request.length != 8 { + // error + info!("Invalid length for init. ignore."); + Ok(None) + } else { + self.last_channel += 1; + // info_now!( + // "assigned channel {}", self.last_channel); + let _nonce = &self.buffer[..8]; + let response = Response { + channel: cid, + command: request.command, + length: 17, + }; + + self.buffer[8..12].copy_from_slice(&self.last_channel.to_be_bytes()); + // CTAPHID protocol version + self.buffer[12] = 2; + // major device version number + self.buffer[13] = self.version.major; + // minor device version number + self.buffer[14] = self.version.minor; + // build device version number + self.buffer[15] = self.version.build; + // capabilities flags + // 0x1: implements WINK + // 0x4: implements CBOR + // 0x8: does not implement MSG + // self.buffer[16] = 0x01 | 0x08; + self.buffer[16] = self.implements; + Ok(Some(response)) + } + } + } + } + + Command::Ping => { + let response = Response::from_request_and_size(request, request.length as usize); + Ok(Some(response)) + } + + Command::Cancel => { + info!("CTAPHID_CANCEL"); + self.cancel_ongoing_activity(); + Ok(None) + } + + _ => { + self.needs_keepalive = request.command == Command::Cbor; + if self.interchange.state() == interchange::State::Responded { + info!("dumping stale response"); + self.interchange.take_response(); + } + match self.interchange.request(( + request.command, + heapless::Vec::from_slice(&self.buffer[..request.length as usize]).unwrap(), + )) { + Ok(_) => { + self.state = State::WaitingOnAuthenticator(request); + self.started_processing = true; + Ok(None) + } + Err(_) => { + // busy + info_now!("STATE: {:?}", self.interchange.state()); + info!("can't handle more than one authenticator request at a time."); + Err(request.error_now(AuthenticatorError::ChannelBusy)) + } + } + } + } + } + + #[inline(never)] + pub fn handle_response(&mut self) -> BufferState { + if let State::WaitingOnAuthenticator(request) = self.state { + if let Ok(response) = self.interchange.response() { + match &response.0 { + Err(DispatchError::InvalidCommand) => { + info!("Got waiting reply from authenticator??"); + self.send_error(request.error(AuthenticatorError::InvalidCommand)) + } + Err(DispatchError::InvalidLength) => { + info!("Error, payload needed app command."); + self.send_error(request.error(AuthenticatorError::InvalidLength)) + } + Err(DispatchError::NoResponse) => { + info!("Got waiting noresponse from authenticator??"); + BufferState::Idle + } + + Ok(message) => { + if message.len() > self.buffer.len() { + error!( + "Message is longer than buffer ({} > {})", + message.len(), + self.buffer.len(), + ); + self.send_error(request.error(AuthenticatorError::InvalidLength)) + } else { + info!( + "Got {} bytes response from authenticator, starting send", + message.len() + ); + let response = Response::from_request_and_size(request, message.len()); + self.buffer[..message.len()].copy_from_slice(message); + self.send_response(response) + } + } + } + } else { + BufferState::Idle + } + } else { + BufferState::Idle + } + } + + fn send_error(&mut self, error: PipeError) -> BufferState { + let response = Response::error_on_channel(error.channel); + if error.keep_state { + BufferState::Error(BufferError { + channel: error.channel, + error: error.error as u8, + }) + } else { + self.buffer[0] = error.error as u8; + self.state = State::WaitingToSend(response); + BufferState::ResponseQueued + } + } + + fn send_response(&mut self, response: Response) -> BufferState { + self.state = State::WaitingToSend(response); + BufferState::ResponseQueued + } + + pub fn did_start_processing(&mut self) -> bool { + if self.started_processing { + self.started_processing = false; + true + } else { + false + } + } +} + +#[derive(Clone, Copy, Debug)] +pub struct Packet<'a> { + response: Response, + message_state: Option, + buffer: &'a [u8], +} + +impl<'a> Packet<'a> { + fn init(response: Response, buffer: &'a [u8]) -> Self { + Self { + response, + message_state: None, + buffer, + } + } + + fn cont(response: Response, message_state: MessageState, buffer: &'a [u8]) -> Self { + Self { + response, + message_state: Some(message_state), + buffer, + } + } + + fn has_more(&self) -> bool { + if let Some(message_state) = self.message_state { + let remaining = usize::from(self.response.length) - message_state.transmitted; + remaining > PACKET_SIZE - 5 + } else { + usize::from(self.response.length) > PACKET_SIZE - 7 + } + } + + fn next_state(&self) -> State { + if self.has_more() { + let message_state = self + .message_state + .map(MessageState::absorb_packet) + .unwrap_or_default(); + State::Sending((self.response, message_state)) + } else { + State::Idle + } + } + + pub fn serialize(&self, buffer: &mut [u8; PACKET_SIZE]) { + // buffer must be zeroed + buffer[..4].copy_from_slice(&self.response.channel.to_be_bytes()); + if let Some(message_state) = self.message_state { + buffer[4] = message_state.next_sequence; + let remaining = usize::from(self.response.length) - message_state.transmitted; + let n = remaining.min(PACKET_SIZE - 5); + buffer[5..][..n].copy_from_slice(&self.buffer[message_state.transmitted..][..n]); + } else { + buffer[4] = self.response.command.into_u8() | 0x80; + buffer[5..7].copy_from_slice(&self.response.length.to_be_bytes()); + let n = usize::from(self.response.length).min(PACKET_SIZE - 7); + buffer[7..][..n].copy_from_slice(&self.buffer[..n]); + } + } +} + +impl<'a> From<&'a BufferError> for Packet<'a> { + fn from(error: &'a BufferError) -> Self { + let response = Response::error_on_channel(error.channel); + Self::init(response, core::slice::from_ref(&error.error)) + } +} diff --git a/src/class.rs b/src/class.rs index 720a532..79445be 100644 --- a/src/class.rs +++ b/src/class.rs @@ -95,19 +95,19 @@ where /// Indicate in INIT response that Wink command is implemented. pub fn implements_wink(mut self) -> Self { - self.pipe.implements |= 0x01; + self.pipe.set_implements(self.pipe.implements() | 0x01); self } /// Indicate in INIT response that RawMsg command is implemented. pub fn implements_ctap1(mut self) -> Self { - self.pipe.implements &= !0x80; + self.pipe.set_implements(self.pipe.implements() & !0x80); self } /// Indicate in INIT response that Cbor command is implemented. pub fn implements_ctap2(mut self) -> Self { - self.pipe.implements |= 0x04; + self.pipe.set_implements(self.pipe.implements() | 0x04); self } @@ -253,8 +253,7 @@ where #[inline(never)] fn poll(&mut self) { // debug!("state = {:?}", self.pipe().state); - self.pipe.handle_response(); - self.pipe.maybe_write_packet(); + self.pipe.handle_and_write_response(); } // called when endpoint with given address received a packet diff --git a/src/lib.rs b/src/lib.rs index f57ec34..0922edd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ generate_macros!(); // pub mod authenticator; +pub mod buffer; pub mod class; pub mod constants; pub use class::CtapHid; diff --git a/src/pipe.rs b/src/pipe.rs index a4b6834..b887268 100644 --- a/src/pipe.rs +++ b/src/pipe.rs @@ -12,182 +12,40 @@ receive busy errors). No state is maintained between transactions. */ -use core::convert::TryFrom; -use core::convert::TryInto; -use core::sync::atomic::Ordering; -// pub type ContactInterchange = usbd_ccid::types::ApduInterchange; -// pub type ContactlessInterchange = iso14443::types::ApduInterchange; - -use ctaphid_dispatch::command::Command; use ctaphid_dispatch::types::Requester; - -use ctap_types::Error as AuthenticatorError; -use trussed::interrupt::InterruptFlag; - use ref_swap::OptionRefSwap; -// use serde::Serialize; +use trussed::interrupt::InterruptFlag; use usb_device::{ bus::UsbBus, endpoint::{EndpointAddress, EndpointIn, EndpointOut}, UsbError, - // Result as UsbResult, }; use crate::{ - constants::{ - // 3072 - MESSAGE_SIZE, - // 64 - PACKET_SIZE, - }, - types::KeepaliveStatus, + buffer::{Buffer, BufferState, Packet}, + constants::PACKET_SIZE, + Version, }; -/// The actual payload of given length is dealt with separately -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct Request { - channel: u32, - command: Command, - length: u16, - timestamp: u32, -} - -/// The actual payload of given length is dealt with separately -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct Response { - channel: u32, - command: Command, - length: u16, -} - -impl Response { - pub fn from_request_and_size(request: Request, size: usize) -> Self { - Self { - channel: request.channel, - command: request.command, - length: size as u16, - } - } - - pub fn error_from_request(request: Request) -> Self { - Self::error_on_channel(request.channel) - } - - pub fn error_on_channel(channel: u32) -> Self { - Self { - channel, - command: ctaphid_dispatch::command::Command::Error, - length: 1, - } - } -} - -#[derive(Copy, Clone, Debug, Eq, PartialEq)] -pub struct MessageState { - // sequence number of next continuation packet - next_sequence: u8, - // number of bytes of message payload transmitted so far - transmitted: usize, -} - -impl Default for MessageState { - fn default() -> Self { - Self { - next_sequence: 0, - transmitted: PACKET_SIZE - 7, - } - } -} - -impl MessageState { - // update state due to receiving a full new continuation packet - pub fn absorb_packet(&mut self) { - self.next_sequence += 1; - self.transmitted += PACKET_SIZE - 5; - } -} - -#[derive(Clone, Debug, Eq, PartialEq)] -#[allow(unused)] -pub enum State { - Idle, - - // if request payload data is larger than one packet - Receiving((Request, MessageState)), - - // Processing(Request), - - // // the request message is ready, need to dispatch to authenticator - // Dispatching((Request, Ctap2Request)), - - // waiting for response from authenticator - WaitingOnAuthenticator(Request), - - WaitingToSend(Response), - - Sending((Response, MessageState)), -} - pub struct Pipe<'alloc, 'pipe, 'interrupt, Bus: UsbBus> { - read_endpoint: EndpointOut<'alloc, Bus>, - write_endpoint: EndpointIn<'alloc, Bus>, - state: State, - - interchange: Requester<'pipe>, - interrupt: Option<&'interrupt OptionRefSwap<'interrupt, InterruptFlag>>, - - // shared between requests and responses, due to size - buffer: [u8; MESSAGE_SIZE], - - // we assign channel IDs one by one, this is the one last assigned - // TODO: move into "app" - last_channel: u32, - - // Indicator of implemented commands in INIT response. - pub(crate) implements: u8, - - // timestamp that gets used for timing out CID's - pub(crate) last_milliseconds: u32, - - // a "read once" indicator if now we're waiting on the application processing - started_processing: bool, - - needs_keepalive: bool, - - pub(crate) version: crate::Version, + endpoints: Endpoints<'alloc, Bus>, + buffer: Buffer<'pipe, 'interrupt>, } impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus> { - pub(crate) fn new( + pub fn new( read_endpoint: EndpointOut<'alloc, Bus>, write_endpoint: EndpointIn<'alloc, Bus>, interchange: Requester<'pipe>, initial_milliseconds: u32, ) -> Self { Self { - read_endpoint, - write_endpoint, - state: State::Idle, - interchange, - buffer: [0u8; MESSAGE_SIZE], - last_channel: 0, - interrupt: None, - // Default to nothing implemented. - implements: 0x80, - last_milliseconds: initial_milliseconds, - started_processing: false, - needs_keepalive: false, - version: Default::default(), + endpoints: Endpoints::new(read_endpoint, write_endpoint), + buffer: Buffer::new(interchange, initial_milliseconds, None), } } -} - -impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus> { - // pub fn borrow_mut_authenticator(&mut self) -> &mut Authenticator { - // &mut self.authenticator - // } - pub(crate) fn with_interrupt( + pub fn with_interrupt( read_endpoint: EndpointOut<'alloc, Bus>, write_endpoint: EndpointIn<'alloc, Bus>, interchange: Requester<'pipe>, @@ -195,563 +53,150 @@ impl<'alloc, 'pipe, 'interrupt, Bus: UsbBus> Pipe<'alloc, 'pipe, 'interrupt, Bus initial_milliseconds: u32, ) -> Self { Self { - read_endpoint, - write_endpoint, - state: State::Idle, - interchange, - buffer: [0u8; MESSAGE_SIZE], - last_channel: 0, - interrupt, - // Default to nothing implemented. - implements: 0x80, - last_milliseconds: initial_milliseconds, - started_processing: false, - needs_keepalive: false, - version: Default::default(), + endpoints: Endpoints::new(read_endpoint, write_endpoint), + buffer: Buffer::new(interchange, initial_milliseconds, interrupt), } } - pub(crate) fn set_version(&mut self, version: crate::Version) { - self.version = version; + pub fn implements(&self) -> u8 { + self.buffer.implements() + } + + pub fn set_implements(&mut self, implements: u8) { + self.buffer.set_implements(implements); + } + + pub fn set_version(&mut self, version: Version) { + self.buffer.set_version(version); } pub fn read_address(&self) -> EndpointAddress { - self.read_endpoint.address() + self.endpoints.read.address() } pub fn write_address(&self) -> EndpointAddress { - self.write_endpoint.address() + self.endpoints.write.address() } // used to generate the configuration descriptors - pub(crate) fn read_endpoint(&self) -> &EndpointOut<'alloc, Bus> { - &self.read_endpoint + pub fn read_endpoint(&self) -> &EndpointOut<'alloc, Bus> { + &self.endpoints.read } // used to generate the configuration descriptors - pub(crate) fn write_endpoint(&self) -> &EndpointIn<'alloc, Bus> { - &self.write_endpoint - } - - fn cancel_ongoing_activity(&mut self) { - if matches!(self.state, State::WaitingOnAuthenticator(_)) { - info_now!("Interrupting request"); - if let Some(Some(i)) = self.interrupt.map(|i| i.load(Ordering::Relaxed)) { - info_now!("Loaded some interrupter"); - i.interrupt(); - } - } + pub fn write_endpoint(&self) -> &EndpointIn<'alloc, Bus> { + &self.endpoints.write } /// This method handles CTAP packets (64 bytes), until it has assembled /// a CTAP message, with which it then calls `dispatch_message`. /// /// During these calls, we can be in states: Idle, Receiving, Dispatching. - pub(crate) fn read_and_handle_packet(&mut self) { + pub fn read_and_handle_packet(&mut self) { // info_now!("got a packet!"); let mut packet = [0u8; PACKET_SIZE]; - match self.read_endpoint.read(&mut packet) { - Ok(PACKET_SIZE) => {} - Ok(_size) => { - // error handling? - // from spec: "Packets are always fixed size (defined by the endpoint and - // HID report descriptors) and although all bytes may not be needed in a - // particular packet, the full size always has to be sent. - // Unused bytes SHOULD be set to zero." - // !("OK but size {}", size); - info!("error unexpected size {}", _size); - return; - } - // usb-device lists WouldBlock or BufferOverflow as possible errors. - // both should not occur here, and we can't do anything anyway. - // Err(UsbError::WouldBlock) => { return; }, - // Err(UsbError::BufferOverflow) => { return; }, - Err(_error) => { - info!("error no {}", _error as i32); - return; - } - }; - info!(">> "); - info!("{}", hex_str!(&packet[..16])); - - // packet is 64 bytes, reading 4 will not panic - let channel = u32::from_be_bytes(packet[..4].try_into().unwrap()); - // info_now!("channel {}", channel); - - let is_initialization = (packet[4] >> 7) != 0; - // info_now!("is_initialization {}", is_initialization); - - if is_initialization { - // case of initialization packet - info!("init"); - - let command_number = packet[4] & !0x80; - // info_now!("command number {}", command_number); - - let command = match Command::try_from(command_number) { - Ok(command) => command, - // `solo ls` crashes here as it uses command 0x86 - Err(_) => { - info!("Received invalid command."); - self.start_sending_error_on_channel(channel, AuthenticatorError::InvalidCommand); - return; - } - }; - - // can't actually fail - let length = u16::from_be_bytes(packet[5..][..2].try_into().unwrap()); - - let timestamp = self.last_milliseconds; - let current_request = Request { - channel, - command, - length, - timestamp, - }; - - if !(self.state == State::Idle) { - let request = match self.state { - State::WaitingOnAuthenticator(request) => request, - State::Receiving((request, _message_state)) => request, - _ => { - info_now!("Ignoring transaction as we're already transmitting."); - return; - } - }; - if packet[4] == 0x86 { - info_now!("Resyncing!"); - self.cancel_ongoing_activity(); - } else { - if channel == request.channel { - if command == Command::Cancel { - info_now!("Cancelling"); - self.cancel_ongoing_activity(); - } else { - info_now!("Expected seq, {:?}", request.command); - self.start_sending_error(request, AuthenticatorError::InvalidSeq); - } - } else { - info_now!("busy."); - self.send_error_now(current_request, AuthenticatorError::ChannelBusy); - } - - return; - } - } - - if length > MESSAGE_SIZE as u16 { - info!("Error message too big."); - self.send_error_now(current_request, AuthenticatorError::InvalidLength); - return; - } - - if length > PACKET_SIZE as u16 - 7 { - // store received part of payload, - // prepare for continuation packets - self.buffer[..PACKET_SIZE - 7].copy_from_slice(&packet[7..]); - self.state = State::Receiving((current_request, { MessageState::default() })); - // we're done... wait for next packet - } else { - // request fits in one packet - self.buffer[..length as usize].copy_from_slice(&packet[7..][..length as usize]); - self.dispatch_request(current_request); - } - } else { - // case of continuation packet - match self.state { - State::Receiving((request, mut message_state)) => { - let sequence = packet[4]; - // info_now!("receiving continuation packet {}", sequence); - if sequence != message_state.next_sequence { - // error handling? - // info_now!("wrong sequence for continuation packet, expected {} received {}", - // message_state.next_sequence, sequence); - info!("Error invalid cont pkt"); - self.start_sending_error(request, AuthenticatorError::InvalidSeq); - return; - } - if channel != request.channel { - // error handling? - // info_now!("wrong channel for continuation packet, expected {} received {}", - // request.channel, channel); - info!("Ignore invalid channel"); - return; - } - - let payload_length = request.length as usize; - if message_state.transmitted + (PACKET_SIZE - 5) < payload_length { - // info_now!("transmitted {} + (PACKET_SIZE - 5) < {}", - // message_state.transmitted, payload_length); - // store received part of payload - self.buffer[message_state.transmitted..][..PACKET_SIZE - 5] - .copy_from_slice(&packet[5..]); - message_state.absorb_packet(); - self.state = State::Receiving((request, message_state)); - // info_now!("absorbed packet, awaiting next"); - } else { - let missing = request.length as usize - message_state.transmitted; - self.buffer[message_state.transmitted..payload_length] - .copy_from_slice(&packet[5..][..missing]); - self.dispatch_request(request); - } - } - _ => { - // unexpected continuation packet - info!("Ignore unexpected cont pkt"); - } - } + if self.endpoints.read(&mut packet).is_ok() { + let state = self.buffer.handle_packet(&packet); + self.handle(state); } } pub fn check_timeout(&mut self, milliseconds: u32) { - // At any point the RP application could crash or something, - // so its up to the device to timeout those transactions. - let last = self.last_milliseconds; - self.last_milliseconds = milliseconds; - if let State::Receiving((request, _message_state)) = &mut self.state { - if (milliseconds - last) > 200 { - // If there's a lapse in `check_timeout(...)` getting called (e.g. due to logging), - // this could lead to inaccurate timestamps on requests. So we'll - // just "forgive" requests temporarily if this happens. - debug!( - "lapse in hid check.. {} {} {}", - request.timestamp, milliseconds, last - ); - request.timestamp = milliseconds; - } - // compare keeping in mind of possible overflow in timestamp. - else if (milliseconds > request.timestamp && (milliseconds - request.timestamp) > 550) - || (milliseconds < request.timestamp && milliseconds > 550) - { - debug!( - "Channel timeout. {}, {}, {}", - request.timestamp, milliseconds, last - ); - let req = *request; - self.start_sending_error(req, AuthenticatorError::Timeout); - } - } - } - - fn dispatch_request(&mut self, request: Request) { - info!("Got request: {:?}", request.command); - match request.command { - Command::Init => {} - _ => { - if request.channel == 0xffffffff { - self.start_sending_error(request, AuthenticatorError::InvalidChannel); - return; - } - } - } - // dispatch request further - match request.command { - Command::Init => { - // info_now!("command INIT!"); - // info_now!("data: {:?}", &self.buffer[..request.length as usize]); - match request.channel { - 0 => { - // this is an error / reserved number - self.start_sending_error(request, AuthenticatorError::InvalidChannel); - } - - // broadcast channel ID - request for assignment - cid => { - if request.length != 8 { - // error - info!("Invalid length for init. ignore."); - } else { - self.last_channel += 1; - // info_now!( - // "assigned channel {}", self.last_channel); - let _nonce = &self.buffer[..8]; - let response = Response { - channel: cid, - command: request.command, - length: 17, - }; - - self.buffer[8..12].copy_from_slice(&self.last_channel.to_be_bytes()); - // CTAPHID protocol version - self.buffer[12] = 2; - // major device version number - self.buffer[13] = self.version.major; - // minor device version number - self.buffer[14] = self.version.minor; - // build device version number - self.buffer[15] = self.version.build; - // capabilities flags - // 0x1: implements WINK - // 0x4: implements CBOR - // 0x8: does not implement MSG - // self.buffer[16] = 0x01 | 0x08; - self.buffer[16] = self.implements; - self.start_sending(response); - } - } - } - } - - Command::Ping => { - let response = Response::from_request_and_size(request, request.length as usize); - self.start_sending(response); - } - - Command::Cancel => { - info!("CTAPHID_CANCEL"); - self.cancel_ongoing_activity(); - } - - _ => { - if request.command == Command::Cbor { - self.needs_keepalive = true; - } else { - self.needs_keepalive = false; - } - if self.interchange.state() == interchange::State::Responded { - info!("dumping stale response"); - self.interchange.take_response(); - } - match self.interchange.request(( - request.command, - heapless::Vec::from_slice(&self.buffer[..request.length as usize]).unwrap(), - )) { - Ok(_) => { - self.state = State::WaitingOnAuthenticator(request); - self.started_processing = true; - } - Err(_) => { - // busy - info_now!("STATE: {:?}", self.interchange.state()); - info!("can't handle more than one authenticator request at a time."); - self.send_error_now(request, AuthenticatorError::ChannelBusy); - } - } - } - } + let state = self.buffer.check_timeout(milliseconds); + self.handle(state); } pub fn did_start_processing(&mut self) -> bool { - if self.started_processing { - self.started_processing = false; - true - } else { - false - } + self.buffer.did_start_processing() } pub fn send_keepalive(&mut self, is_waiting_for_user_presence: bool) -> bool { - if let State::WaitingOnAuthenticator(request) = &self.state { - if !self.needs_keepalive { - // let response go out normally in idle loop - info!("cmd does not need keepalive messages"); - false - } else { - info!("keepalive"); - - let mut packet = [0u8; PACKET_SIZE]; - - packet[..4].copy_from_slice(&request.channel.to_be_bytes()); - packet[4] = 0x80 | 0x3B; - packet[5..7].copy_from_slice(&1u16.to_be_bytes()); - - if is_waiting_for_user_presence { - packet[7] = KeepaliveStatus::UpNeeded as u8; - } else { - packet[7] = KeepaliveStatus::Processing as u8; - } - - self.write_endpoint.write(&packet).ok(); - - true - } + if let Some(packet) = self.buffer.send_keepalive(is_waiting_for_user_presence) { + self.endpoints.write(packet).ok(); + true } else { - info!("keepalive done"); false } } - #[inline(never)] - pub fn handle_response(&mut self) { - if let State::WaitingOnAuthenticator(request) = self.state { - if let Ok(response) = self.interchange.response() { - match &response.0 { - Err(ctaphid_dispatch::app::Error::InvalidCommand) => { - info!("Got waiting reply from authenticator??"); - self.start_sending_error(request, AuthenticatorError::InvalidCommand); - } - Err(ctaphid_dispatch::app::Error::InvalidLength) => { - info!("Error, payload needed app command."); - self.start_sending_error(request, AuthenticatorError::InvalidLength); - } - Err(ctaphid_dispatch::app::Error::NoResponse) => { - info!("Got waiting noresponse from authenticator??"); - } + pub fn handle_and_write_response(&mut self) { + let state = self.buffer.handle_response(); + self.handle(state); + } - Ok(message) => { - if message.len() > self.buffer.len() { - error!( - "Message is longer than buffer ({} > {})", - message.len(), - self.buffer.len(), - ); - self.start_sending_error(request, AuthenticatorError::InvalidLength); - } else { - info!( - "Got {} bytes response from authenticator, starting send", - message.len() - ); - let response = Response::from_request_and_size(request, message.len()); - self.buffer[..message.len()].copy_from_slice(&message); - self.start_sending(response); - } - } - } + fn handle(&mut self, state: BufferState) { + match state { + BufferState::Idle => (), + BufferState::ResponseQueued => self.maybe_write_packet(), + BufferState::Error(error) => { + // TODO: should we block? + self.endpoints.write(Packet::from(&error)).ok(); } } } - fn start_sending(&mut self, response: Response) { - self.state = State::WaitingToSend(response); - self.maybe_write_packet(); + // called from poll, and when a packet has been sent + #[inline(never)] + pub fn maybe_write_packet(&mut self) { + self.buffer + .try_send_packet(|packet| self.endpoints.write(packet)); } +} - fn start_sending_error(&mut self, request: Request, error: AuthenticatorError) { - self.start_sending_error_on_channel(request.channel, error); - } +struct Endpoints<'a, Bus: UsbBus> { + read: EndpointOut<'a, Bus>, + write: EndpointIn<'a, Bus>, +} - fn start_sending_error_on_channel(&mut self, channel: u32, error: AuthenticatorError) { - self.buffer[0] = error as u8; - let response = Response::error_on_channel(channel); - self.start_sending(response); +impl<'a, Bus: UsbBus> Endpoints<'a, Bus> { + fn new(read: EndpointOut<'a, Bus>, write: EndpointIn<'a, Bus>) -> Self { + Self { read, write } } - fn send_error_now(&mut self, request: Request, error: AuthenticatorError) { - let last_state = core::mem::replace(&mut self.state, State::Idle); - let last_first_byte = self.buffer[0]; - - self.buffer[0] = error as u8; - let response = Response::error_from_request(request); - self.start_sending(response); - self.maybe_write_packet(); - - self.state = last_state; - self.buffer[0] = last_first_byte; + fn read(&mut self, packet: &mut [u8; PACKET_SIZE]) -> Result<(), ()> { + match self.read.read(packet) { + Ok(PACKET_SIZE) => Ok(()), + Ok(_size) => { + // error handling? + // from spec: "Packets are always fixed size (defined by the endpoint and + // HID report descriptors) and although all bytes may not be needed in a + // particular packet, the full size always has to be sent. + // Unused bytes SHOULD be set to zero." + // !("OK but size {}", size); + info!("error unexpected size {}", _size); + Err(()) + } + // usb-device lists WouldBlock or BufferOverflow as possible errors. + // both should not occur here, and we can't do anything anyway. + // Err(UsbError::WouldBlock) => { return; }, + // Err(UsbError::BufferOverflow) => { return; }, + Err(_error) => { + info!("error no {}", _error as i32); + Err(()) + } + } } - // called from poll, and when a packet has been sent - #[inline(never)] - pub(crate) fn maybe_write_packet(&mut self) { - match self.state { - State::WaitingToSend(response) => { - // zeros leftover bytes - let mut packet = [0u8; PACKET_SIZE]; - packet[..4].copy_from_slice(&response.channel.to_be_bytes()); - // packet[4] = response.command.into() | 0x80u8; - packet[4] = response.command.into_u8() | 0x80; - packet[5..7].copy_from_slice(&response.length.to_be_bytes()); - - let fits_in_one_packet = 7 + response.length as usize <= PACKET_SIZE; - if fits_in_one_packet { - packet[7..][..response.length as usize] - .copy_from_slice(&self.buffer[..response.length as usize]); - self.state = State::Idle; - } else { - packet[7..].copy_from_slice(&self.buffer[..PACKET_SIZE - 7]); - } - - // try actually sending - // info_now!("attempting to write init packet {:?}, {:?}", - // &packet[..32], &packet[32..]); - let result = self.write_endpoint.write(&packet); - - match result { - Err(UsbError::WouldBlock) => { - // fine, can't write try later - // this shouldn't happen probably - info!("hid usb WouldBlock"); - } - Err(_) => { - // info_now!("weird USB errrorrr"); - panic!("unexpected error writing packet!"); - } - Ok(PACKET_SIZE) => { - // goodie, this worked - if fits_in_one_packet { - self.state = State::Idle; - // info_now!("StartSent {} bytes, idle again", response.length); - // info_now!("IDLE again"); - } else { - self.state = State::Sending((response, MessageState::default())); - // info_now!( - // "StartSent {} of {} bytes, waiting to send again", - // PACKET_SIZE - 7, response.length); - // info_now!("State: {:?}", &self.state); - } - } - Ok(_) => { - // info_now!("short write"); - panic!("unexpected size writing packet!"); - } - }; + fn write(&mut self, packet: Packet<'_>) -> Result<(), ()> { + // zeros leftover bytes + let mut buffer = [0u8; PACKET_SIZE]; + packet.serialize(&mut buffer); + match self.write.write(&buffer) { + Ok(PACKET_SIZE) => Ok(()), + Ok(_) => { + error!("short write"); + panic!("unexpected size writing packet!"); } - - State::Sending((response, mut message_state)) => { - // info_now!("in StillSending"); - let mut packet = [0u8; PACKET_SIZE]; - packet[..4].copy_from_slice(&response.channel.to_be_bytes()); - packet[4] = message_state.next_sequence; - - let sent = message_state.transmitted; - let remaining = response.length as usize - sent; - let last_packet = 5 + remaining <= PACKET_SIZE; - if last_packet { - packet[5..][..remaining] - .copy_from_slice(&self.buffer[message_state.transmitted..][..remaining]); - } else { - packet[5..].copy_from_slice( - &self.buffer[message_state.transmitted..][..PACKET_SIZE - 5], - ); - } - - // try actually sending - // info_now!("attempting to write cont packet {:?}, {:?}", - // &packet[..32], &packet[32..]); - let result = self.write_endpoint.write(&packet); - - match result { - Err(UsbError::WouldBlock) => { - // fine, can't write try later - // this shouldn't happen probably - // info_now!("can't send seq {}, write endpoint busy", - // message_state.next_sequence); - } - Err(_) => { - // info_now!("weird USB error"); - panic!("unexpected error writing packet!"); - } - Ok(PACKET_SIZE) => { - // goodie, this worked - if last_packet { - self.state = State::Idle; - // info_now!("in IDLE state after {:?}", &message_state); - } else { - message_state.absorb_packet(); - // DANGER! destructuring in the match arm copies out - // message state, so need to update state - // info_now!("sent one more, now {:?}", &message_state); - self.state = State::Sending((response, message_state)); - } - } - Ok(_) => { - debug!("short write"); - panic!("unexpected size writing packet!"); - } - }; + Err(UsbError::WouldBlock) => { + // fine, can't write try later + // this shouldn't happen probably + info!("hid usb WouldBlock"); + Err(()) + } + Err(_) => { + // info_now!("weird USB error"); + panic!("unexpected error writing packet!"); } - - // nothing to send - _ => {} } } }