diff --git a/smb-core/src/nt_status.rs b/smb-core/src/nt_status.rs index 6eb985e..1656183 100644 --- a/smb-core/src/nt_status.rs +++ b/smb-core/src/nt_status.rs @@ -19,6 +19,10 @@ pub enum NTStatus { UserSessionDeleted = 0xC0000203, NetworkSessionExpired = 0xC000035C, FileNotAvailable = 0xC0000467, + FileClosed = 0xC0000128, + EndOfFile = 0xC0000011, + InvalidInfoClass = 0xC0000003, + InvalidDeviceRequest = 0xC0000010, UnknownError = 0xFFFFFFFF, } diff --git a/smb/src/main.rs b/smb/src/main.rs index 63f35ae..a7980e3 100644 --- a/smb/src/main.rs +++ b/smb/src/main.rs @@ -40,7 +40,7 @@ async fn main() -> SMBResult<()> { .unencrypted_access(true) .require_message_signing(false) .encrypt_data(false) - .add_fs_share("test".into(), "".into(), file_allowed, get_file_perms) + .add_fs_share("test".into(), std::env::var("SMB_SHARE_PATH").unwrap_or_default(), file_allowed, get_file_perms) .add_ipc_share() .auth_provider(NTLMAuthProvider::new(vec![ User::new("tejasmehta", "password"), diff --git a/smb/src/protocol/body/close/mod.rs b/smb/src/protocol/body/close/mod.rs index 9f278f6..dab14fa 100644 --- a/smb/src/protocol/body/close/mod.rs +++ b/smb/src/protocol/body/close/mod.rs @@ -9,7 +9,7 @@ use crate::protocol::body::create::file_attributes::SMBFileAttributes; use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::filetime::FileTime; -mod flags; +pub mod flags; #[derive( Debug, @@ -32,6 +32,16 @@ pub struct SMBCloseRequest { file_id: SMBFileId, } +impl SMBCloseRequest { + pub fn flags(&self) -> SMBCloseFlags { + self.flags + } + + pub fn file_id(&self) -> &SMBFileId { + &self.file_id + } +} + #[derive( Debug, PartialEq, @@ -63,4 +73,116 @@ pub struct SMBCloseResponse { end_of_file: u64, #[smb_direct(start(fixed = 56))] file_attributes: SMBFileAttributes, +} + +impl SMBCloseResponse { + pub fn from_metadata(metadata: &crate::server::share::SMBFileMetadata, attributes: SMBFileAttributes) -> Self { + Self { + flags: SMBCloseFlags::POSTQUERY_ATTRIB, + reserved: PhantomData, + creation_time: metadata.creation_time.clone(), + last_access_time: metadata.last_access_time.clone(), + last_write_time: metadata.last_write_time.clone(), + change_time: metadata.last_modification_time.clone(), + allocation_size: metadata.allocated_size, + end_of_file: metadata.actual_size, + file_attributes: attributes, + } + } + + pub fn empty() -> Self { + Self { + flags: SMBCloseFlags::empty(), + reserved: PhantomData, + creation_time: FileTime::zero(), + last_access_time: FileTime::zero(), + last_write_time: FileTime::zero(), + change_time: FileTime::zero(), + allocation_size: 0, + end_of_file: 0, + file_attributes: SMBFileAttributes::empty(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBToBytes, SMBFromBytes}; + + #[test] + fn close_response_empty_has_zero_fields() { + let resp = SMBCloseResponse::empty(); + assert_eq!(resp.flags, SMBCloseFlags::empty()); + assert_eq!(resp.allocation_size, 0); + assert_eq!(resp.end_of_file, 0); + assert_eq!(resp.file_attributes, SMBFileAttributes::empty()); + } + + #[test] + fn close_response_empty_serialization_round_trip() { + let resp = SMBCloseResponse::empty(); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBCloseResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn close_response_from_metadata_sets_postquery_flag() { + use crate::server::share::SMBFileMetadata; + let metadata = SMBFileMetadata { + creation_time: FileTime::from_unix(1700000000), + last_access_time: FileTime::from_unix(1700000100), + last_write_time: FileTime::from_unix(1700000200), + last_modification_time: FileTime::from_unix(1700000300), + allocated_size: 4096, + actual_size: 1024, + }; + let resp = SMBCloseResponse::from_metadata(&metadata, SMBFileAttributes::NORMAL); + assert!(resp.flags.contains(SMBCloseFlags::POSTQUERY_ATTRIB)); + assert_eq!(resp.allocation_size, 4096); + assert_eq!(resp.end_of_file, 1024); + assert_eq!(resp.file_attributes, SMBFileAttributes::NORMAL); + } + + #[test] + fn close_response_from_metadata_serialization_round_trip() { + use crate::server::share::SMBFileMetadata; + let metadata = SMBFileMetadata { + creation_time: FileTime::from_unix(1700000000), + last_access_time: FileTime::from_unix(1700000100), + last_write_time: FileTime::from_unix(1700000200), + last_modification_time: FileTime::from_unix(1700000300), + allocated_size: 8192, + actual_size: 2048, + }; + let resp = SMBCloseResponse::from_metadata(&metadata, SMBFileAttributes::ARCHIVE); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBCloseResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn close_request_accessors() { + let file_id = SMBFileId { persistent: 42, volatile: 99 }; + let bytes = { + let mut buf = Vec::new(); + // struct_size (u16) = 24 + buf.extend_from_slice(&24u16.to_le_bytes()); + // flags (u16) = POSTQUERY_ATTRIB = 0x0001 + buf.extend_from_slice(&1u16.to_le_bytes()); + // reserved (4 bytes) + buf.extend_from_slice(&[0u8; 4]); + // file_id: persistent (u64) + volatile (u64) + buf.extend_from_slice(&42u64.to_le_bytes()); + buf.extend_from_slice(&99u64.to_le_bytes()); + buf + }; + let (_, req) = SMBCloseRequest::smb_from_bytes(&bytes).unwrap(); + assert_eq!(req.file_id().persistent, file_id.persistent); + assert_eq!(req.file_id().volatile, file_id.volatile); + assert!(req.flags().contains(SMBCloseFlags::POSTQUERY_ATTRIB)); + } } \ No newline at end of file diff --git a/smb/src/protocol/body/create/mod.rs b/smb/src/protocol/body/create/mod.rs index 5281546..a0d4a0a 100644 --- a/smb/src/protocol/body/create/mod.rs +++ b/smb/src/protocol/body/create/mod.rs @@ -66,7 +66,7 @@ pub struct SMBCreateRequest { create_disposition: SMBCreateDisposition, #[smb_direct(start(fixed = 40))] create_options: SMBCreateOptions, - #[smb_string(order = 0, start(inner(start = 44, num_type = "u16", subtract = 68)), length(inner(start = 46, num_type = "u16")), underlying = "u16")] + #[smb_string(order = 0, start(inner(start = 44, num_type = "u16", subtract = 64)), length(inner(start = 46, num_type = "u16")), underlying = "u16")] file_name: String, #[smb_vector(order = 1, align = 8, length(inner(start = 52, num_type = "u32")), offset(inner(start = 48, num_type = "u32", subtract = 64)))] contexts: Vec, diff --git a/smb/src/protocol/body/file_info/access.rs b/smb/src/protocol/body/file_info/access.rs new file mode 100644 index 0000000..7c8a135 --- /dev/null +++ b/smb/src/protocol/body/file_info/access.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_ACCESS_INFORMATION (MS-FSCC 2.4.1) — 4 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileAccessInformation { + #[smb_direct(start(fixed = 0))] + pub access_flags: u32, +} diff --git a/smb/src/protocol/body/file_info/alignment.rs b/smb/src/protocol/body/file_info/alignment.rs new file mode 100644 index 0000000..ca5b046 --- /dev/null +++ b/smb/src/protocol/body/file_info/alignment.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_ALIGNMENT_INFORMATION (MS-FSCC 2.4.3) — 4 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileAlignmentInformation { + #[smb_direct(start(fixed = 0))] + pub alignment_requirement: u32, +} diff --git a/smb/src/protocol/body/file_info/basic.rs b/smb/src/protocol/body/file_info/basic.rs new file mode 100644 index 0000000..f592547 --- /dev/null +++ b/smb/src/protocol/body/file_info/basic.rs @@ -0,0 +1,23 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +use crate::protocol::body::create::file_attributes::SMBFileAttributes; +use crate::protocol::body::filetime::FileTime; + +/// FILE_BASIC_INFORMATION (MS-FSCC 2.4.7) — 40 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileBasicInformation { + #[smb_direct(start(fixed = 0))] + pub creation_time: FileTime, + #[smb_direct(start(fixed = 8))] + pub last_access_time: FileTime, + #[smb_direct(start(fixed = 16))] + pub last_write_time: FileTime, + #[smb_direct(start(fixed = 24))] + pub change_time: FileTime, + #[smb_direct(start(fixed = 32))] + pub file_attributes: SMBFileAttributes, + #[smb_direct(start(fixed = 36))] + pub reserved: u32, +} diff --git a/smb/src/protocol/body/file_info/ea.rs b/smb/src/protocol/body/file_info/ea.rs new file mode 100644 index 0000000..f530012 --- /dev/null +++ b/smb/src/protocol/body/file_info/ea.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_EA_INFORMATION (MS-FSCC 2.4.12) — 4 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileEaInformation { + #[smb_direct(start(fixed = 0))] + pub ea_size: u32, +} diff --git a/smb/src/protocol/body/file_info/internal.rs b/smb/src/protocol/body/file_info/internal.rs new file mode 100644 index 0000000..c09ac8b --- /dev/null +++ b/smb/src/protocol/body/file_info/internal.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_INTERNAL_INFORMATION (MS-FSCC 2.4.20) — 8 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileInternalInformation { + #[smb_direct(start(fixed = 0))] + pub index_number: u64, +} diff --git a/smb/src/protocol/body/file_info/mod.rs b/smb/src/protocol/body/file_info/mod.rs new file mode 100644 index 0000000..433d67d --- /dev/null +++ b/smb/src/protocol/body/file_info/mod.rs @@ -0,0 +1,279 @@ +//! MS-FSCC File Information Classes +//! +//! Typed representations of the file information structures defined in +//! [MS-FSCC] sections 2.4.x, used in QueryInfo / SetInfo responses. + +mod access; +mod alignment; +mod basic; +mod ea; +mod internal; +mod mode; +mod name; +mod network_open; +mod position; +mod standard; + +pub use access::FileAccessInformation; +pub use alignment::FileAlignmentInformation; +pub use basic::FileBasicInformation; +pub use ea::FileEaInformation; +pub use internal::FileInternalInformation; +pub use mode::FileModeInformation; +pub use name::FileNameInformation; +pub use network_open::FileNetworkOpenInformation; +pub use position::FilePositionInformation; +pub use standard::FileStandardInformation; + +/// FILE_ALL_INFORMATION (MS-FSCC 2.4.2) — composite structure +/// +/// This is a concatenation of the sub-structures above. +/// We serialize it by concatenating each sub-struct's bytes rather than +/// using the derive macro, because the derive macro doesn't support +/// nested struct composition at variable offsets. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct FileAllInformation { + pub basic: FileBasicInformation, + pub standard: FileStandardInformation, + pub internal: FileInternalInformation, + pub ea: FileEaInformation, + pub access: FileAccessInformation, + pub position: FilePositionInformation, + pub mode: FileModeInformation, + pub alignment: FileAlignmentInformation, + pub name: FileNameInformation, +} + +impl FileAllInformation { + pub fn to_bytes(&self) -> Vec { + use smb_core::SMBToBytes; + let mut buf = Vec::with_capacity(104); + buf.extend_from_slice(&self.basic.smb_to_bytes()); + buf.extend_from_slice(&self.standard.smb_to_bytes()); + buf.extend_from_slice(&self.internal.smb_to_bytes()); + buf.extend_from_slice(&self.ea.smb_to_bytes()); + buf.extend_from_slice(&self.access.smb_to_bytes()); + buf.extend_from_slice(&self.position.smb_to_bytes()); + buf.extend_from_slice(&self.mode.smb_to_bytes()); + buf.extend_from_slice(&self.alignment.smb_to_bytes()); + buf.extend_from_slice(&self.name.smb_to_bytes()); + buf + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBFromBytes, SMBToBytes}; + use crate::protocol::body::create::file_attributes::SMBFileAttributes; + use crate::protocol::body::filetime::FileTime; + + #[test] + fn file_basic_information_size_is_40() { + let info = FileBasicInformation { + creation_time: FileTime::zero(), + last_access_time: FileTime::zero(), + last_write_time: FileTime::zero(), + change_time: FileTime::zero(), + file_attributes: SMBFileAttributes::NORMAL, + reserved: 0, + }; + assert_eq!(info.smb_byte_size(), 40); + } + + #[test] + fn file_basic_information_round_trip() { + let info = FileBasicInformation { + creation_time: FileTime::now(), + last_access_time: FileTime::now(), + last_write_time: FileTime::now(), + change_time: FileTime::now(), + file_attributes: SMBFileAttributes::ARCHIVE | SMBFileAttributes::READONLY, + reserved: 0, + }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 40); + let (_, parsed) = FileBasicInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_standard_information_size_is_24() { + let info = FileStandardInformation { + allocation_size: 4096, + end_of_file: 1024, + number_of_links: 1, + delete_pending: 0, + directory: 0, + reserved: 0, + }; + assert_eq!(info.smb_byte_size(), 24); + } + + #[test] + fn file_standard_information_round_trip() { + let info = FileStandardInformation { + allocation_size: 8192, + end_of_file: 2048, + number_of_links: 3, + delete_pending: 1, + directory: 0, + reserved: 0, + }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 24); + let (_, parsed) = FileStandardInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_internal_information_round_trip() { + let info = FileInternalInformation { index_number: 42 }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 8); + let (_, parsed) = FileInternalInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_ea_information_round_trip() { + let info = FileEaInformation { ea_size: 0 }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileEaInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_access_information_round_trip() { + let info = FileAccessInformation { access_flags: 0x001f01ff }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileAccessInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_position_information_round_trip() { + let info = FilePositionInformation { current_byte_offset: 512 }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 8); + let (_, parsed) = FilePositionInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_mode_information_round_trip() { + let info = FileModeInformation { mode: 0 }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileModeInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_alignment_information_round_trip() { + let info = FileAlignmentInformation { alignment_requirement: 0 }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 4); + let (_, parsed) = FileAlignmentInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_network_open_information_size_is_56() { + let info = FileNetworkOpenInformation { + creation_time: FileTime::zero(), + last_access_time: FileTime::zero(), + last_write_time: FileTime::zero(), + change_time: FileTime::zero(), + allocation_size: 0, + end_of_file: 0, + file_attributes: SMBFileAttributes::NORMAL, + reserved: 0, + }; + assert_eq!(info.smb_byte_size(), 56); + } + + #[test] + fn file_network_open_information_round_trip() { + let info = FileNetworkOpenInformation { + creation_time: FileTime::now(), + last_access_time: FileTime::now(), + last_write_time: FileTime::now(), + change_time: FileTime::now(), + allocation_size: 4096, + end_of_file: 1024, + file_attributes: SMBFileAttributes::ARCHIVE, + reserved: 0, + }; + let bytes = info.smb_to_bytes(); + assert_eq!(bytes.len(), 56); + let (_, parsed) = FileNetworkOpenInformation::smb_from_bytes(&bytes).unwrap(); + assert_eq!(info, parsed); + } + + #[test] + fn file_all_information_contains_all_sub_structs() { + let all = FileAllInformation { + basic: FileBasicInformation { + creation_time: FileTime::zero(), + last_access_time: FileTime::zero(), + last_write_time: FileTime::zero(), + change_time: FileTime::zero(), + file_attributes: SMBFileAttributes::NORMAL, + reserved: 0, + }, + standard: FileStandardInformation { + allocation_size: 4096, + end_of_file: 21, + number_of_links: 1, + delete_pending: 0, + directory: 0, + reserved: 0, + }, + internal: FileInternalInformation { index_number: 0 }, + ea: FileEaInformation { ea_size: 0 }, + access: FileAccessInformation { access_flags: 0x001f01ff }, + position: FilePositionInformation { current_byte_offset: 0 }, + mode: FileModeInformation { mode: 0 }, + alignment: FileAlignmentInformation { alignment_requirement: 0 }, + name: FileNameInformation { + file_name_length: 24, + file_name: "testfile.txt".into(), + }, + }; + let bytes = all.to_bytes(); + // 40 + 24 + 8 + 4 + 4 + 8 + 4 + 4 + (4 + 24) = 124 + assert_eq!(bytes.len(), 124); + } + + #[test] + fn file_all_information_basic_segment_matches_standalone() { + let basic = FileBasicInformation { + creation_time: FileTime::now(), + last_access_time: FileTime::now(), + last_write_time: FileTime::now(), + change_time: FileTime::now(), + file_attributes: SMBFileAttributes::ARCHIVE, + reserved: 0, + }; + let all = FileAllInformation { + basic: basic.clone(), + standard: FileStandardInformation { + allocation_size: 0, end_of_file: 0, number_of_links: 1, + delete_pending: 0, directory: 0, reserved: 0, + }, + internal: FileInternalInformation { index_number: 0 }, + ea: FileEaInformation { ea_size: 0 }, + access: FileAccessInformation { access_flags: 0 }, + position: FilePositionInformation { current_byte_offset: 0 }, + mode: FileModeInformation { mode: 0 }, + alignment: FileAlignmentInformation { alignment_requirement: 0 }, + name: FileNameInformation { file_name_length: 0, file_name: String::new() }, + }; + let all_bytes = all.to_bytes(); + let basic_bytes = basic.smb_to_bytes(); + assert_eq!(&all_bytes[..40], &basic_bytes[..]); + } +} diff --git a/smb/src/protocol/body/file_info/mode.rs b/smb/src/protocol/body/file_info/mode.rs new file mode 100644 index 0000000..af299dd --- /dev/null +++ b/smb/src/protocol/body/file_info/mode.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_MODE_INFORMATION (MS-FSCC 2.4.26) — 4 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileModeInformation { + #[smb_direct(start(fixed = 0))] + pub mode: u32, +} diff --git a/smb/src/protocol/body/file_info/name.rs b/smb/src/protocol/body/file_info/name.rs new file mode 100644 index 0000000..b58bb3e --- /dev/null +++ b/smb/src/protocol/body/file_info/name.rs @@ -0,0 +1,12 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_NAME_INFORMATION (MS-FSCC 2.4.28) — variable length +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileNameInformation { + #[smb_direct(start(fixed = 0))] + pub file_name_length: u32, + #[smb_string(order = 0, start(fixed = 4), length(inner(start = 0, num_type = "u32")), underlying = "u16")] + pub file_name: String, +} diff --git a/smb/src/protocol/body/file_info/network_open.rs b/smb/src/protocol/body/file_info/network_open.rs new file mode 100644 index 0000000..a52c537 --- /dev/null +++ b/smb/src/protocol/body/file_info/network_open.rs @@ -0,0 +1,27 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +use crate::protocol::body::create::file_attributes::SMBFileAttributes; +use crate::protocol::body::filetime::FileTime; + +/// FILE_NETWORK_OPEN_INFORMATION (MS-FSCC 2.4.29) — 56 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileNetworkOpenInformation { + #[smb_direct(start(fixed = 0))] + pub creation_time: FileTime, + #[smb_direct(start(fixed = 8))] + pub last_access_time: FileTime, + #[smb_direct(start(fixed = 16))] + pub last_write_time: FileTime, + #[smb_direct(start(fixed = 24))] + pub change_time: FileTime, + #[smb_direct(start(fixed = 32))] + pub allocation_size: u64, + #[smb_direct(start(fixed = 40))] + pub end_of_file: u64, + #[smb_direct(start(fixed = 48))] + pub file_attributes: SMBFileAttributes, + #[smb_direct(start(fixed = 52))] + pub reserved: u32, +} diff --git a/smb/src/protocol/body/file_info/position.rs b/smb/src/protocol/body/file_info/position.rs new file mode 100644 index 0000000..aba5efa --- /dev/null +++ b/smb/src/protocol/body/file_info/position.rs @@ -0,0 +1,10 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_POSITION_INFORMATION (MS-FSCC 2.4.35) — 8 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FilePositionInformation { + #[smb_direct(start(fixed = 0))] + pub current_byte_offset: u64, +} diff --git a/smb/src/protocol/body/file_info/standard.rs b/smb/src/protocol/body/file_info/standard.rs new file mode 100644 index 0000000..e03aeaa --- /dev/null +++ b/smb/src/protocol/body/file_info/standard.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + +use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; + +/// FILE_STANDARD_INFORMATION (MS-FSCC 2.4.41) — 24 bytes +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, SMBByteSize, SMBFromBytes, SMBToBytes)] +pub struct FileStandardInformation { + #[smb_direct(start(fixed = 0))] + pub allocation_size: u64, + #[smb_direct(start(fixed = 8))] + pub end_of_file: u64, + #[smb_direct(start(fixed = 16))] + pub number_of_links: u32, + #[smb_direct(start(fixed = 20))] + pub delete_pending: u8, + #[smb_direct(start(fixed = 21))] + pub directory: u8, + #[smb_direct(start(fixed = 22))] + pub reserved: u16, +} diff --git a/smb/src/protocol/body/mod.rs b/smb/src/protocol/body/mod.rs index ab0021e..577dfbe 100644 --- a/smb/src/protocol/body/mod.rs +++ b/smb/src/protocol/body/mod.rs @@ -68,6 +68,7 @@ pub mod query_info; pub mod ioctl; pub mod set_info; pub mod oplock_break; +pub mod file_info; pub trait Body: SMBEnumFromBytes + SMBToBytes { fn parse_with_cc(bytes: &[u8], command_code: S::CommandCode) -> SMBParseResult<&[u8], Self> where Self: Sized; diff --git a/smb/src/protocol/body/query_info/info_type.rs b/smb/src/protocol/body/query_info/info_type.rs index 3550893..9c4b6d5 100644 --- a/smb/src/protocol/body/query_info/info_type.rs +++ b/smb/src/protocol/body/query_info/info_type.rs @@ -6,8 +6,8 @@ use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u8)] #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive, SMBToBytes, SMBFromBytes, SMBByteSize, Serialize, Deserialize)] pub enum SMBInfoType { - File, - Filesystem, - Security, - Quota, + File = 1, + Filesystem = 2, + Security = 3, + Quota = 4, } \ No newline at end of file diff --git a/smb/src/protocol/body/query_info/mod.rs b/smb/src/protocol/body/query_info/mod.rs index b4dd070..7474a8c 100644 --- a/smb/src/protocol/body/query_info/mod.rs +++ b/smb/src/protocol/body/query_info/mod.rs @@ -10,7 +10,7 @@ use crate::protocol::body::query_info::info_type::SMBInfoType; use crate::protocol::body::query_info::security_information::SMBSecurityInformation; mod flags; -mod info_type; +pub mod info_type; mod security_information; #[derive( @@ -44,6 +44,24 @@ pub struct SMBQueryInfoRequest { buffer: Vec, } +impl SMBQueryInfoRequest { + pub fn info_type(&self) -> SMBInfoType { + self.info_type + } + + pub fn file_info_class(&self) -> u8 { + self.file_info_class + } + + pub fn output_buffer_length(&self) -> u32 { + self.output_buffer_length + } + + pub fn file_id(&self) -> &SMBFileId { + &self.file_id + } +} + #[derive( Debug, PartialEq, @@ -55,11 +73,85 @@ pub struct SMBQueryInfoRequest { Deserialize, Clone )] -#[smb_byte_tag(value = 17)] +#[smb_byte_tag(value = 9)] pub struct SMBQueryInfoResponse { #[smb_skip(start = 2, length = 6)] reserved: PhantomData>, // TODO make this a struct: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/3b1b3598-a898-44ca-bfac-2dcae065247f #[smb_buffer(order = 0, offset(inner(start = 2, num_type = "u16", subtract = 64)), length(inner(start = 4, num_type = "u32")))] data: Vec, +} + +impl SMBQueryInfoResponse { + pub fn new(data: Vec) -> Self { + Self { + reserved: PhantomData, + data, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBToBytes, SMBFromBytes}; + + #[test] + fn query_info_response_new_sets_data() { + let data = vec![1, 2, 3, 4, 5, 6, 7, 8]; + let resp = SMBQueryInfoResponse::new(data.clone()); + assert_eq!(resp.data, data); + } + + #[test] + fn query_info_response_serialization_round_trip() { + let resp = SMBQueryInfoResponse::new(vec![0xAA; 40]); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBQueryInfoResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn query_info_response_empty_data_round_trip() { + let resp = SMBQueryInfoResponse::new(vec![]); + let bytes = resp.smb_to_bytes(); + let (_, parsed) = SMBQueryInfoResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn query_info_request_accessors() { + let bytes = { + let mut buf = Vec::new(); + // struct_size (u16) = 41 + buf.extend_from_slice(&41u16.to_le_bytes()); + // info_type (u8) = 1 (File) per MS-SMB2 + buf.push(1); + // file_info_class (u8) = 4 (FileBasicInformation) + buf.push(4); + // output_buffer_length (u32) = 4096 + buf.extend_from_slice(&4096u32.to_le_bytes()); + // input_buffer_offset (u16) = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // reserved (u16) = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + // input_buffer_length (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // additional_information (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // flags (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // file_id: persistent (u64) + volatile (u64) + buf.extend_from_slice(&55u64.to_le_bytes()); + buf.extend_from_slice(&77u64.to_le_bytes()); + buf + }; + let (_, req) = SMBQueryInfoRequest::smb_from_bytes(&bytes).unwrap(); + assert_eq!(req.info_type(), SMBInfoType::File); + assert_eq!(req.file_info_class(), 4); + assert_eq!(req.output_buffer_length(), 4096); + assert_eq!(req.file_id().persistent, 55); + assert_eq!(req.file_id().volatile, 77); + } } \ No newline at end of file diff --git a/smb/src/protocol/body/read/mod.rs b/smb/src/protocol/body/read/mod.rs index 2e9a64d..2a1cca5 100644 --- a/smb/src/protocol/body/read/mod.rs +++ b/smb/src/protocol/body/read/mod.rs @@ -42,6 +42,24 @@ pub struct SMBReadRequest { channel_information: Vec, } +impl SMBReadRequest { + pub fn file_id(&self) -> &SMBFileId { + &self.file_id + } + + pub fn read_length(&self) -> u32 { + self.read_length + } + + pub fn read_offset(&self) -> u64 { + self.read_offset + } + + pub fn minimum_count(&self) -> u32 { + self.minimum_count + } +} + #[derive( Debug, PartialEq, @@ -63,4 +81,83 @@ pub struct SMBReadResponse { flags: SMBReadResponseFlags, #[smb_buffer(order = 0, offset(inner(start = 2, num_type = "u8", subtract = 64)), length(inner(start = 4, num_type = "u32")))] data: Vec, +} + +impl SMBReadResponse { + pub fn new(data: Vec, data_remaining: u32) -> Self { + Self { + reserved: PhantomData, + data_remaining, + flags: SMBReadResponseFlags::None, + data, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use smb_core::{SMBByteSize, SMBToBytes, SMBFromBytes}; + + #[test] + fn read_response_new_sets_fields() { + let data = vec![0xDE, 0xAD, 0xBE, 0xEF]; + let resp = SMBReadResponse::new(data.clone(), 100); + assert_eq!(resp.data, data); + assert_eq!(resp.data_remaining, 100); + assert_eq!(resp.flags, SMBReadResponseFlags::None); + } + + #[test] + fn read_response_serialization_round_trip() { + let resp = SMBReadResponse::new(vec![1, 2, 3, 4, 5], 0); + let bytes = resp.smb_to_bytes(); + assert_eq!(bytes.len(), resp.smb_byte_size()); + let (_, parsed) = SMBReadResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn read_response_empty_data() { + let resp = SMBReadResponse::new(vec![], 0); + let bytes = resp.smb_to_bytes(); + let (_, parsed) = SMBReadResponse::smb_from_bytes(&bytes).unwrap(); + assert_eq!(resp, parsed); + } + + #[test] + fn read_request_accessors() { + let bytes = { + let mut buf = Vec::new(); + // struct_size (u16) = 49 + buf.extend_from_slice(&49u16.to_le_bytes()); + // padding (u8) + buf.push(0); + // flags (u8) = 0 + buf.push(0); + // read_length (u32) = 1024 + buf.extend_from_slice(&1024u32.to_le_bytes()); + // read_offset (u64) = 512 + buf.extend_from_slice(&512u64.to_le_bytes()); + // file_id: persistent (u64) + volatile (u64) + buf.extend_from_slice(&10u64.to_le_bytes()); + buf.extend_from_slice(&20u64.to_le_bytes()); + // minimum_count (u32) = 256 + buf.extend_from_slice(&256u32.to_le_bytes()); + // channel (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // remaining_bytes (u32) = 0 + buf.extend_from_slice(&0u32.to_le_bytes()); + // channel_info_offset (u16) = 0, channel_info_length (u16) = 0 + buf.extend_from_slice(&0u16.to_le_bytes()); + buf.extend_from_slice(&0u16.to_le_bytes()); + buf + }; + let (_, req) = SMBReadRequest::smb_from_bytes(&bytes).unwrap(); + assert_eq!(req.read_length(), 1024); + assert_eq!(req.read_offset(), 512); + assert_eq!(req.minimum_count(), 256); + assert_eq!(req.file_id().persistent, 10); + assert_eq!(req.file_id().volatile, 20); + } } \ No newline at end of file diff --git a/smb/src/server/mod.rs b/smb/src/server/mod.rs index a76e1a8..79d8b27 100644 --- a/smb/src/server/mod.rs +++ b/smb/src/server/mod.rs @@ -52,6 +52,7 @@ pub trait Server: Send + Sync { fn shares(&self) -> &HashMap>; fn opens(&self) -> &HashMap>>; fn add_open(&mut self, open: Arc>) -> impl Future; + fn remove_open(&mut self, global_id: u32); fn sessions(&self) -> &HashMap>>; fn sessions_mut(&mut self) -> &mut HashMap>>; fn guid(&self) -> Uuid; @@ -198,6 +199,10 @@ impl, Auth: AuthProvider, Share: 0 } + fn remove_open(&mut self, global_id: u32) { + self.open_table.remove(&global_id); + } + fn sessions(&self) -> &HashMap>> { &self.session_table } diff --git a/smb/src/server/open.rs b/smb/src/server/open.rs index cf402aa..63e5209 100644 --- a/smb/src/server/open.rs +++ b/smb/src/server/open.rs @@ -1,5 +1,4 @@ -use std::fmt::{Debug, Formatter, Pointer}; -use std::future::Future; +use std::fmt::{Debug, Formatter}; use std::sync::Arc; use uuid::Uuid; @@ -28,6 +27,7 @@ pub trait Open: Send + Sync { fn file_attributes(&self) -> SMBFileAttributes; fn file_id(&self) -> SMBFileId; fn file_metadata(&self) -> SMBResult; + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult>; } pub struct SMBOpen { @@ -143,7 +143,11 @@ impl Open for SMBOpen { } fn file_metadata(&self) -> SMBResult { - return self.underlying.metadata() + self.underlying.metadata() + } + + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult> { + self.underlying.read_data(offset, length) } } // TODO: From MS-FSCC section 2.6 @@ -210,8 +214,8 @@ impl Debug for SMBOpen where S: Debug, S::Session: Debug, S::Handl impl SMBLockedMessageHandlerBase for Arc> { type Inner = (); - async fn inner(&self, message: &SMBMessageType) -> Option { - todo!() + async fn inner(&self, _message: &SMBMessageType) -> Option { + None } } diff --git a/smb/src/server/session.rs b/smb/src/server/session.rs index 1ea8b02..be371fe 100644 --- a/smb/src/server/session.rs +++ b/smb/src/server/session.rs @@ -57,6 +57,7 @@ pub trait Session: Send + Sync { fn provider(&self) -> &Arc; fn encrypt_data(&self) -> bool; fn open_table(&self) -> &HashMap>>; + fn open_table_mut(&mut self) -> &mut HashMap>>; fn add_open(&mut self, open: Arc>) -> impl Future; fn set_previous_file_id(&mut self, file_id: SMBFileId); fn signing_key(&self) -> &[u8]; @@ -242,7 +243,7 @@ impl>> SMBLockedMessageHandlerBase for Arc::get_next_map_id(&self_rd.tree_connect_table); let tree_connect = SMBTreeConnect::init(tree_id, Arc::downgrade(self), share.clone(), response.access_mask().clone()); - let header = SMBSyncHeader::create_response_header(&header, 0, self_rd.id(), 1); + let header = SMBSyncHeader::create_response_header(&header, 0, self_rd.id(), tree_id); drop(self_rd); let mut self_wr = self.write().await; self_wr.tree_connect_table.insert(tree_id, Arc::new(tree_connect)); @@ -342,6 +343,10 @@ impl> Session f &self.open_table } + fn open_table_mut(&mut self) -> &mut HashMap>> { + &mut self.open_table + } + async fn add_open(&mut self, open: Arc>) { let id = Self::get_next_map_id(&self.open_table); let mut open_wr = open.write().await; diff --git a/smb/src/server/share/file_system.rs b/smb/src/server/share/file_system.rs index 25d4417..29caf7a 100644 --- a/smb/src/server/share/file_system.rs +++ b/smb/src/server/share/file_system.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::fmt::{Debug, Formatter}; use std::fs; use std::fs::{File, OpenOptions, ReadDir}; +use std::io::{Read, Seek, SeekFrom}; use std::marker::PhantomData; use std::time::{SystemTime, UNIX_EPOCH}; @@ -86,6 +87,21 @@ impl ResourceHandle for SMBFileSystemHandle { actual_size: metadata.len(), }) } + + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult> { + match &mut self.resource { + SMBFileSystemResourceHandle::File(file) => { + file.seek(SeekFrom::Start(offset)).map_err(SMBError::io_error)?; + let mut buf = vec![0u8; length as usize]; + let bytes_read = file.read(&mut buf).map_err(SMBError::io_error)?; + buf.truncate(bytes_read); + Ok(buf) + } + SMBFileSystemResourceHandle::Directory(_) => { + Err(SMBError::response_error(smb_core::nt_status::NTStatus::InvalidDeviceRequest)) + } + } + } } impl SMBFileSystemResourceHandle { @@ -166,7 +182,10 @@ impl + ResourceHandle + } fn handle_create(&self, path: &str, disposition: SMBCreateDisposition, directory: bool) -> SMBResult { - let path = format!("{}/{}", self.local_path, path); + // Sanitize: strip NUL terminators from UTF-16LE wire encoding, + // convert Windows backslashes to forward slashes + let sanitized = path.trim_end_matches('\0').replace('\\', "/"); + let path = format!("{}/{}", self.local_path, sanitized); let resource = match directory { true => SMBFileSystemResourceHandle::directory(&path), false => SMBFileSystemResourceHandle::file(&path, disposition) diff --git a/smb/src/server/share/ipc.rs b/smb/src/server/share/ipc.rs index b7ef3cc..77a9fbf 100644 --- a/smb/src/server/share/ipc.rs +++ b/smb/src/server/share/ipc.rs @@ -44,6 +44,10 @@ impl ResourceHandle for SMBIPCHandle { actual_size: 0, }) } + + fn read_data(&mut self, _offset: u64, _length: u32) -> SMBResult> { + Err(SMBError::response_error(smb_core::nt_status::NTStatus::InvalidDeviceRequest)) + } } impl From for Box { diff --git a/smb/src/server/share/mod.rs b/smb/src/server/share/mod.rs index 8856db6..0ce87ab 100644 --- a/smb/src/server/share/mod.rs +++ b/smb/src/server/share/mod.rs @@ -24,6 +24,7 @@ pub trait ResourceHandle: Send + Sync { fn is_directory(&self) -> bool; fn path(&self) -> &str; fn metadata(&self) -> SMBResult; + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult>; } pub struct SMBFileMetadata { @@ -55,6 +56,10 @@ impl ResourceHandle for Box { fn metadata(&self) -> SMBResult { H::metadata(self) } + + fn read_data(&mut self, offset: u64, length: u32) -> SMBResult> { + H::read_data(self, offset, length) + } } pub trait SharedResource: Send + Sync { diff --git a/smb/src/server/tree_connect.rs b/smb/src/server/tree_connect.rs index 46bb0bd..232a353 100644 --- a/smb/src/server/tree_connect.rs +++ b/smb/src/server/tree_connect.rs @@ -4,13 +4,25 @@ use std::sync::{Arc, Weak}; use tokio::sync::RwLock; -use smb_core::{SMBByteSize, SMBResult}; +use smb_core::{SMBByteSize, SMBResult, SMBToBytes}; use smb_core::error::SMBError; use smb_core::logging::{debug, trace}; +use smb_core::nt_status::NTStatus; +use crate::protocol::body::close::{SMBCloseRequest, SMBCloseResponse}; use crate::protocol::body::create::{SMBCreateRequest, SMBCreateResponse}; +use crate::protocol::body::create::file_attributes::SMBFileAttributes; use crate::protocol::body::create::file_id::SMBFileId; +use crate::protocol::body::file_info::{ + FileBasicInformation, FileStandardInformation, FileNetworkOpenInformation, + FileAllInformation, FileInternalInformation, FileEaInformation, + FileAccessInformation, FilePositionInformation, FileModeInformation, + FileAlignmentInformation, FileNameInformation, +}; use crate::protocol::body::filetime::FileTime; +use crate::protocol::body::query_info::{SMBQueryInfoRequest, SMBQueryInfoResponse}; +use crate::protocol::body::query_info::info_type::SMBInfoType; +use crate::protocol::body::read::{SMBReadRequest, SMBReadResponse}; use crate::protocol::body::SMBBody; use crate::protocol::body::tree_connect::access_mask::SMBAccessMask; use crate::protocol::header::SMBSyncHeader; @@ -48,10 +60,84 @@ impl SMBTreeConnect { } } +impl SMBTreeConnect { + fn get_session(&self) -> SMBResult>> { + self.session.upgrade() + .ok_or(SMBError::server_error("No Session Found")) + } + + async fn find_open(&self, file_id: &SMBFileId) -> SMBResult>> { + let session = self.get_session()?; + let session_rd = session.read().await; + session_rd.open_table() + .get(&file_id.volatile) + .cloned() + .ok_or(SMBError::response_error(NTStatus::FileClosed)) + } + + fn build_basic_info(open: &S::Open) -> SMBResult { + let metadata = open.file_metadata()?; + Ok(FileBasicInformation { + creation_time: metadata.creation_time, + last_access_time: metadata.last_access_time, + last_write_time: metadata.last_write_time, + change_time: metadata.last_modification_time, + file_attributes: open.file_attributes(), + reserved: 0, + }) + } + + fn build_standard_info(open: &S::Open) -> SMBResult { + let metadata = open.file_metadata()?; + let is_dir = open.file_attributes().contains(SMBFileAttributes::DIRECTORY); + Ok(FileStandardInformation { + allocation_size: metadata.allocated_size, + end_of_file: metadata.actual_size, + number_of_links: 1, + delete_pending: 0, + directory: if is_dir { 1 } else { 0 }, + reserved: 0, + }) + } + + fn build_network_open_info(open: &S::Open) -> SMBResult { + let metadata = open.file_metadata()?; + Ok(FileNetworkOpenInformation { + creation_time: metadata.creation_time, + last_access_time: metadata.last_access_time, + last_write_time: metadata.last_write_time, + change_time: metadata.last_modification_time, + allocation_size: metadata.allocated_size, + end_of_file: metadata.actual_size, + file_attributes: open.file_attributes(), + reserved: 0, + }) + } + + fn build_all_info(open: &S::Open) -> SMBResult { + let name = open.file_name(); + let name_byte_len = (name.encode_utf16().count() * 2) as u32; + Ok(FileAllInformation { + basic: Self::build_basic_info(open)?, + standard: Self::build_standard_info(open)?, + internal: FileInternalInformation { index_number: 0 }, + ea: FileEaInformation { ea_size: 0 }, + access: FileAccessInformation { access_flags: 0x001f01ff }, + position: FilePositionInformation { current_byte_offset: 0 }, + mode: FileModeInformation { mode: 0 }, + alignment: FileAlignmentInformation { alignment_requirement: 0 }, + name: FileNameInformation { + file_name_length: name_byte_len, + file_name: name.into(), + }, + }) + } +} + impl SMBLockedMessageHandlerBase for Arc> { type Inner = Arc>; - async fn inner(&self, message: &SMBMessageType) -> Option { + async fn inner(&self, _message: &SMBMessageType) -> Option { None } @@ -61,14 +147,14 @@ impl SMBLockedMessageHandlerBase for Arc> { let open_raw = Open::init(handle, message); let response = SMBBody::CreateResponse(SMBCreateResponse::for_open::(&open_raw)?); let open = Arc::new(RwLock::new(open_raw)); - let session = self.session.upgrade() - .ok_or(SMBError::server_error("No Session Found"))?; - session.write().await.add_open(open.clone()).await; + let session = self.get_session()?; + // Register with server first (outermost), then session (inner) let server = session.upper().await? .upper().await?; { server.write().await.add_open(open.clone()).await; } + session.write().await.add_open(open.clone()).await; { let file_id = open.read().await.file_id(); session.write().await.set_previous_file_id(file_id); @@ -78,6 +164,96 @@ impl SMBLockedMessageHandlerBase for Arc> { trace!(response_size = response.smb_byte_size(), "create response built"); Ok(SMBHandlerState::Finished(SMBMessage::new(header, response))) } + + async fn handle_close(&mut self, header: &SMBSyncHeader, message: &SMBCloseRequest) -> SMBResult> { + debug!(file_id = ?message.file_id(), "handling close request"); + + // Phase 1: Read open data (session_rd → open_rd, outer before inner) + let session = self.get_session()?; + let open = { + let session_rd = session.read().await; + session_rd.open_table() + .get(&message.file_id().volatile) + .cloned() + .ok_or(SMBError::response_error(NTStatus::FileClosed))? + }; + let (response, file_id) = { + let open_rd = open.read().await; + let response = if message.flags().contains(crate::protocol::body::close::flags::SMBCloseFlags::POSTQUERY_ATTRIB) { + let metadata = open_rd.file_metadata()?; + SMBCloseResponse::from_metadata(&metadata, open_rd.file_attributes()) + } else { + SMBCloseResponse::empty() + }; + (response, open_rd.file_id()) + }; + + // Phase 2: Cleanup — acquire locks outer to inner (server_wr, then session_wr) + // Server write first (outermost) + if let Ok(conn) = session.upper().await { + if let Ok(server) = conn.upper().await { + server.write().await.remove_open(file_id.volatile as u32); + } + } + // Session write second (inner relative to server) + { + let mut session_wr = session.write().await; + session_wr.open_table_mut().remove(&file_id.volatile); + } + + debug!(file_id = ?file_id, "close completed"); + let header = header.create_response_header(0, header.session_id, header.tree_id); + Ok(SMBHandlerState::Finished(SMBMessage::new(header, SMBBody::CloseResponse(response)))) + } + + async fn handle_read(&mut self, header: &SMBSyncHeader, message: &SMBReadRequest) -> SMBResult> { + debug!(file_id = ?message.file_id(), offset = message.read_offset(), length = message.read_length(), "handling read request"); + let open = self.find_open(message.file_id()).await?; + let mut open_wr = open.write().await; + let data = open_wr.read_data(message.read_offset(), message.read_length())?; + drop(open_wr); + + if data.len() < message.minimum_count() as usize { + return Err(SMBError::response_error(NTStatus::EndOfFile)); + } + + debug!(bytes_read = data.len(), "read completed"); + trace!(data_len = data.len(), "read response data"); + let response = SMBReadResponse::new(data, 0); + let header = header.create_response_header(0, header.session_id, header.tree_id); + Ok(SMBHandlerState::Finished(SMBMessage::new(header, SMBBody::ReadResponse(response)))) + } + + async fn handle_query_info(&mut self, header: &SMBSyncHeader, message: &SMBQueryInfoRequest) -> SMBResult> { + debug!(file_id = ?message.file_id(), info_type = ?message.info_type(), class = message.file_info_class(), "handling query_info request"); + let open = self.find_open(message.file_id()).await?; + let open_rd = open.read().await; + + let data = match message.info_type() { + SMBInfoType::File => { + // MS-FSCC file information classes + match message.file_info_class() { + 4 => SMBTreeConnect::::build_basic_info(&*open_rd)?.smb_to_bytes(), + 5 => SMBTreeConnect::::build_standard_info(&*open_rd)?.smb_to_bytes(), + 18 => SMBTreeConnect::::build_all_info(&*open_rd)?.to_bytes(), + 34 => SMBTreeConnect::::build_network_open_info(&*open_rd)?.smb_to_bytes(), + _ => { + debug!(class = message.file_info_class(), "unsupported file info class"); + return Err(SMBError::response_error(NTStatus::InvalidInfoClass)); + } + } + } + _ => { + debug!(info_type = ?message.info_type(), "unsupported info type"); + return Err(SMBError::response_error(NTStatus::InvalidInfoClass)); + } + }; + + debug!(data_len = data.len(), "query_info completed"); + let response = SMBQueryInfoResponse::new(data); + let header = header.create_response_header(0, header.session_id, header.tree_id); + Ok(SMBHandlerState::Finished(SMBMessage::new(header, SMBBody::QueryInfoResponse(response)))) + } } impl SMBLockedMessageHandler for Arc> {} \ No newline at end of file diff --git a/smb/tests/smbclient.rs b/smb/tests/smbclient.rs index e55b3e7..7a07094 100644 --- a/smb/tests/smbclient.rs +++ b/smb/tests/smbclient.rs @@ -250,6 +250,194 @@ fn tree_connect_nonexistent_share() { server.kill().ok(); } +// --------------------------------------------------------------------------- +// File Read Tests +// --------------------------------------------------------------------------- + +/// Verify that smbclient can read a file from the share. +/// +/// Expected: The server handles Create, Read, QueryInfo, and Close +/// without crashing. smbclient should be able to retrieve file contents. +#[test] +#[ignore] +fn file_read_does_not_crash_server() { + use std::io::Write; + + let port = free_port(); + + // Create a temp file in the server's working directory for the share to serve + let tmp_dir = std::env::temp_dir().join(format!("smb_test_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + let test_file = tmp_dir.join("testfile.txt"); + { + let mut f = std::fs::File::create(&test_file).expect("Failed to create test file"); + f.write_all(b"hello from smb server").expect("Failed to write test file"); + } + + // Start server with the share path pointing to our temp dir + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + // Wait for server to start + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let download_path = tmp_dir.join("downloaded.txt"); + let port_str = port.to_string(); + let download_str = download_path.to_str().unwrap().to_string(); + let get_cmd = format!("get testfile.txt {}", download_str); + let (success, stdout, stderr) = run_smbclient(&[ + &format!("//127.0.0.1/test"), + "-p", &port_str, + "-U", "tejasmehta%password", + "-m", "SMB2", + "-c", &get_cmd, + ]); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after file read. stdout: {} stderr: {}", + stdout, stderr + ); + + // Verify the file was downloaded and contents match + assert!(success, "smbclient get should succeed. stdout: {} stderr: {}", stdout, stderr); + let downloaded = std::fs::read(&download_path) + .expect("Downloaded file should exist"); + assert_eq!( + downloaded, + b"hello from smb server", + "Downloaded file contents should match the original" + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + +/// Verify that smbclient can list files (which triggers QueryInfo). +#[test] +#[ignore] +fn directory_listing_does_not_crash_server() { + use std::io::Write; + + let port = free_port(); + + let tmp_dir = std::env::temp_dir().join(format!("smb_test_ls_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + let test_file = tmp_dir.join("listing_test.txt"); + { + let mut f = std::fs::File::create(&test_file).expect("Failed to create test file"); + f.write_all(b"test content").expect("Failed to write test file"); + } + + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let port_str = port.to_string(); + let (_success, _stdout, stderr) = run_smbclient(&[ + &format!("//127.0.0.1/test"), + "-p", &port_str, + "-U", "tejasmehta%password", + "-m", "SMB2", + "-c", "ls", + ]); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after directory listing. stderr: {}", + stderr + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + +/// Verify that reading a nonexistent file returns an error without crashing. +#[test] +#[ignore] +fn read_nonexistent_file_returns_error() { + let port = free_port(); + + let tmp_dir = std::env::temp_dir().join(format!("smb_test_nofile_{}", port)); + std::fs::create_dir_all(&tmp_dir).expect("Failed to create temp dir"); + + let server_bin = env!("CARGO_BIN_EXE_spin_server_up"); + let mut server = std::process::Command::new(server_bin) + .env("SMB_PORT", port.to_string()) + .env("SMB_SHARE_PATH", tmp_dir.to_str().unwrap()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn SMB server binary"); + + let addr = format!("127.0.0.1:{}", port); + for _ in 0..50 { + if std::net::TcpStream::connect(&addr).is_ok() { + break; + } + std::thread::sleep(Duration::from_millis(100)); + } + + let port_str = port.to_string(); + let (success, stdout, stderr) = run_smbclient(&[ + &format!("//127.0.0.1/test"), + "-p", &port_str, + "-U", "tejasmehta%password", + "-m", "SMB2", + "-c", "get nonexistent_file.txt /dev/null", + ]); + + // Should fail (file doesn't exist) + assert!( + !success || stdout.contains("NT_STATUS_") || stderr.contains("NT_STATUS_"), + "Reading nonexistent file should fail. stdout: {} stderr: {}", + stdout, stderr + ); + + // Server should not crash + std::thread::sleep(Duration::from_millis(200)); + let status = server.try_wait().expect("Failed to check server status"); + assert!( + status.is_none(), + "Server should still be running after failed file read. stderr: {}", + stderr + ); + + server.kill().ok(); + let _ = std::fs::remove_dir_all(&tmp_dir); +} + // --------------------------------------------------------------------------- // Echo Tests // ---------------------------------------------------------------------------