diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index 3caa60b..a5d8412 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -30,3 +30,14 @@ jobs: - name: clippy run: cargo clippy --workspace --features server -- -D warnings + + fmt: + name: Rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + with: + components: rustfmt + - name: rustfmt + run: cargo fmt --all --check diff --git a/.gitignore b/.gitignore index f40642a..2252e6d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ **/.vscode/ IMPLEMENTATION_PLAN.md AUDIT.md +.claude/ diff --git a/smb-core/src/error.rs b/smb-core/src/error.rs index 3598d8e..e9e161c 100644 --- a/smb-core/src/error.rs +++ b/smb-core/src/error.rs @@ -53,7 +53,7 @@ pub struct SMBParseError { impl>> From for SMBParseError { fn from(value: T) -> Self { Self { - error: value.into() + error: value.into(), } } } @@ -72,12 +72,11 @@ pub struct SMBCryptoError { impl>> From for SMBCryptoError { fn from(value: T) -> Self { Self { - message: value.into() + message: value.into(), } } } - impl Display for SMBCryptoError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "Crypto operation failed with error: {}", self.message) @@ -92,14 +91,18 @@ pub struct SMBPreconditionFailedError { impl> From for SMBPreconditionFailedError { fn from(value: T) -> Self { Self { - message: value.into() + message: value.into(), } } } impl Display for SMBPreconditionFailedError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Operation failed with unmet precondition: {}", self.message) + write!( + f, + "Operation failed with unmet precondition: {}", + self.message + ) } } @@ -111,7 +114,7 @@ pub struct SMBIOError { impl> From for SMBIOError { fn from(value: T) -> Self { Self { - error: value.into() + error: value.into(), } } } @@ -136,7 +139,7 @@ impl SMBResponseError { impl> From for SMBResponseError { fn from(value: T) -> Self { Self { - status: value.into() + status: value.into(), } } } @@ -164,7 +167,11 @@ impl, U: Into> From<(T, U)> for SMBPayloadTooSmallError { impl Display for SMBPayloadTooSmallError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "Expected {} bytes, was actually {} bytes", self.expected, self.actual) + write!( + f, + "Expected {} bytes, was actually {} bytes", + self.expected, self.actual + ) } } @@ -176,7 +183,7 @@ pub struct SMBServerError { impl>> From for SMBServerError { fn from(value: T) -> Self { Self { - error: value.into() + error: value.into(), } } } @@ -196,7 +203,7 @@ impl Display for SMBError { Self::IOError(x) => write!(f, "{}", x), Self::ResponseError(x) => write!(f, "{}", x), Self::PayloadTooSmall(x) => write!(f, "{}", x), - Self::ServerError(x) => write!(f, "{}", x) + Self::ServerError(x) => write!(f, "{}", x), } } } @@ -211,7 +218,11 @@ mod tests { fn server_error_display_says_server() { let err = SMBError::server_error("something broke"); let msg = format!("{}", err); - assert!(msg.contains("Server operation failed"), "ServerError Display should say 'Server operation failed', got: {}", msg); + assert!( + msg.contains("Server operation failed"), + "ServerError Display should say 'Server operation failed', got: {}", + msg + ); assert!(msg.contains("something broke")); } @@ -219,14 +230,22 @@ mod tests { fn parse_error_display_says_parse() { let err = SMBError::parse_error("bad bytes"); let msg = format!("{}", err); - assert!(msg.contains("Parse failed"), "ParseError Display should say 'Parse failed', got: {}", msg); + assert!( + msg.contains("Parse failed"), + "ParseError Display should say 'Parse failed', got: {}", + msg + ); } #[test] fn crypto_error_display_says_crypto() { let err = SMBError::crypto_error("bad key"); let msg = format!("{}", err); - assert!(msg.contains("Crypto operation failed"), "CryptoError Display should say 'Crypto operation failed', got: {}", msg); + assert!( + msg.contains("Crypto operation failed"), + "CryptoError Display should say 'Crypto operation failed', got: {}", + msg + ); } #[test] @@ -241,6 +260,9 @@ mod tests { fn response_error_display() { let err = SMBError::response_error(NTStatus::AccessDenied); let msg = format!("{}", err); - assert!(msg.contains("AccessDenied"), "should mention the NTStatus variant"); + assert!( + msg.contains("AccessDenied"), + "should mention the NTStatus variant" + ); } -} \ No newline at end of file +} diff --git a/smb-core/src/lib.rs b/smb-core/src/lib.rs index d415362..badaec3 100644 --- a/smb-core/src/lib.rs +++ b/smb-core/src/lib.rs @@ -17,7 +17,9 @@ pub trait SMBByteSize { } pub trait SMBFromBytes: SMBByteSize { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized; + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized; } pub trait SMBToBytes: SMBByteSize { @@ -44,7 +46,10 @@ impl SMBVecByteSize for Vec { } impl SMBFromBytes for PhantomData { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { let (remaining, _) = T::smb_from_bytes(input)?; Ok((remaining, PhantomData)) } @@ -63,10 +68,17 @@ impl SMBByteSize for PhantomData { } impl SMBVecFromBytesCnt for String { - fn smb_from_bytes_vec_cnt(input: &[u8], align: usize, count: usize) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes_vec_cnt( + input: &[u8], + align: usize, + count: usize, + ) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { let (remaining, vec) = >::smb_from_bytes_vec_cnt(input, align, count)?; - let str = String::from_utf8(vec) - .map_err(|_e| SMBError::parse_error("Invalid byte slice"))?; + let str = + String::from_utf8(vec).map_err(|_e| SMBError::parse_error("Invalid byte slice"))?; Ok((remaining, str)) } } @@ -78,19 +90,40 @@ impl SMBVecByteSize for String { } pub trait SMBVecFromBytesCnt { - fn smb_from_bytes_vec_cnt(input: &[u8], align: usize, count: usize) -> SMBParseResult<&[u8], Self> where Self: Sized; + fn smb_from_bytes_vec_cnt( + input: &[u8], + align: usize, + count: usize, + ) -> SMBParseResult<&[u8], Self> + where + Self: Sized; } pub trait SMBVecFromBytesLen { - fn smb_from_bytes_vec_len(input: &[u8], align: usize, len: usize) -> SMBParseResult<&[u8], Self> where Self: Sized; + fn smb_from_bytes_vec_len( + input: &[u8], + align: usize, + len: usize, + ) -> SMBParseResult<&[u8], Self> + where + Self: Sized; } pub trait SMBEnumFromBytes { - fn smb_enum_from_bytes(input: &[u8], discriminator: u64) -> SMBParseResult<&[u8], Self> where Self: Sized; + fn smb_enum_from_bytes(input: &[u8], discriminator: u64) -> SMBParseResult<&[u8], Self> + where + Self: Sized; } impl SMBVecFromBytesCnt for Vec { - fn smb_from_bytes_vec_cnt(input: &[u8], align: usize, count: usize) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes_vec_cnt( + input: &[u8], + align: usize, + count: usize, + ) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { let mut remaining = input; let mut done_cnt = 0; let mut msg_vec = Vec::::new(); @@ -115,7 +148,10 @@ impl SMBVecFromBytesCnt for Vec { } impl SMBVecFromBytesLen for Vec { - fn smb_from_bytes_vec_len(input: &[u8], align: usize, len: usize) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes_vec_len(input: &[u8], align: usize, len: usize) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { let mut remaining = input; let mut msg_vec = Vec::::new(); let mut pos = 0; @@ -124,7 +160,7 @@ impl SMBVecFromBytesLen for Vec { remaining = &remaining[extra..]; let (_, val) = T::smb_from_bytes(remaining)?; let size = T::smb_byte_size(&val); - pos += size; + pos += size; extra = if align > 0 && !pos.is_multiple_of(align) { align - (pos % align) } else { @@ -139,9 +175,12 @@ impl SMBVecFromBytesLen for Vec { } impl SMBFromBytes for Uuid { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { if 16 > input.len() { - return Err(SMBError::payload_too_small(16usize, input.len())) + return Err(SMBError::payload_too_small(16usize, input.len())); } let uuid = Uuid::from_slice(&input[0..16]) .map_err(|_e| SMBError::parse_error("Invalid byte slice"))?; @@ -170,7 +209,7 @@ macro_rules! impl_parse_fixed_slice { let res = <[u8; $size]>::try_from(&$input[0..$size]) .map_err(|_e| SMBError::parse_error("Invalid byte slice"))?; Ok((&$input[$size..], res)) - }} + }}; } macro_rules! impl_smb_byte_size_for_slice {( @@ -278,4 +317,4 @@ impl_smb_from_bytes_unsigned_type! { impl_smb_to_bytes_unsigned_type! { u8 u16 u32 u64 u128 -} \ No newline at end of file +} diff --git a/smb-core/src/logging.rs b/smb-core/src/logging.rs index 71c49ed..b2d4813 100644 --- a/smb-core/src/logging.rs +++ b/smb-core/src/logging.rs @@ -4,55 +4,71 @@ /// macros from the `tracing` crate. When disabled, they compile to no-ops. #[cfg(feature = "tracing")] -pub use tracing::{trace, debug, info, warn, error, info_span, debug_span, trace_span}; +pub use tracing::{debug, debug_span, error, info, info_span, trace, trace_span, warn}; #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! trace { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! debug { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! info { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! warn { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! error { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! info_span { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! debug_span { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] #[macro_export] macro_rules! trace_span { - ($($t:tt)*) => {()}; + ($($t:tt)*) => { + () + }; } #[cfg(not(feature = "tracing"))] -pub use crate::{trace, debug, info, warn, error, info_span, debug_span, trace_span}; +pub use crate::{debug, debug_span, error, info, info_span, trace, trace_span, warn}; diff --git a/smb-core/src/nt_status.rs b/smb-core/src/nt_status.rs index 6eb985e..9c6e60d 100644 --- a/smb-core/src/nt_status.rs +++ b/smb-core/src/nt_status.rs @@ -1,8 +1,8 @@ use num_enum::TryFromPrimitive; use serde::{Deserialize, Serialize}; -use crate::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use crate::error::SMBError; +use crate::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; #[repr(u32)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, TryFromPrimitive, Copy)] @@ -29,13 +29,14 @@ impl SMBByteSize for NTStatus { } impl SMBFromBytes for NTStatus { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { - u32::smb_from_bytes(input) - .map(|(remaining, underlying)| { - let res = Self::try_from_primitive(underlying) - .map_err(SMBError::parse_error)?; - Ok((remaining, res)) - })? + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { + u32::smb_from_bytes(input).map(|(remaining, underlying)| { + let res = Self::try_from_primitive(underlying).map_err(SMBError::parse_error)?; + Ok((remaining, res)) + })? } } @@ -43,4 +44,4 @@ impl SMBToBytes for NTStatus { fn smb_to_bytes(&self) -> Vec { (*self as u32).smb_to_bytes() } -} \ No newline at end of file +} diff --git a/smb-derive/src/attrs.rs b/smb-derive/src/attrs.rs index 1e654ec..87217cf 100644 --- a/smb-derive/src/attrs.rs +++ b/smb-derive/src/attrs.rs @@ -1,12 +1,12 @@ use std::default::Default; -use darling::{FromAttributes, FromDeriveInput, FromField, FromMeta}; use darling::ast::NestedMeta; +use darling::{FromAttributes, FromDeriveInput, FromField, FromMeta}; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, quote_spanned}; -use syn::{Attribute, DeriveInput, Expr, Lit, Meta, Path, Token, Type, TypePath}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; +use syn::{Attribute, DeriveInput, Expr, Lit, Meta, Path, Token, Type, TypePath}; /// Construct a [`syn::Type`] from a primitive type name string (e.g. `"u16"`, /// `"usize"`), using the span of `spanned` for error reporting. @@ -71,7 +71,12 @@ impl DirectInner { /// Generate a token stream that serializes a value back into the output /// buffer at `start`, adding back the `subtract` offset and clamping to /// `min_val`. - fn smb_to_bytes(&self, name: &str, spanned: &T, name_val: Option) -> TokenStream { + fn smb_to_bytes( + &self, + name: &str, + spanned: &T, + name_val: Option, + ) -> TokenStream { let start = self.start; let subtract = self.subtract; let name = format_ident!("{}", name); @@ -129,7 +134,8 @@ impl DirectInner { pub enum AttributeInfo { Fixed(usize), Inner(DirectInner), - #[default] CurrentPos, + #[default] + CurrentPos, NullTerminated(String), } @@ -141,24 +147,30 @@ impl FromMeta for AttributeInfo { && let Expr::Lit(lit) = &meta.value && let Lit::Int(int) = &lit.lit { - return Ok(AttributeInfo::Fixed(int.base10_parse::()?)) + return Ok(AttributeInfo::Fixed(int.base10_parse::()?)); } } else if let NestedMeta::Meta(Meta::List(list)) = item { if list.path.is_ident("inner") { - return Ok(AttributeInfo::Inner(DirectInner::from_nested_meta(item)?)) + return Ok(AttributeInfo::Inner(DirectInner::from_nested_meta(item)?)); } else if list.path.is_ident("null_terminated") { - return Ok(AttributeInfo::NullTerminated(String::from_nested_meta(item)?)) + return Ok(AttributeInfo::NullTerminated(String::from_nested_meta( + item, + )?)); } } } - Err(darling::Error::missing_field("fixed | current_pos | inner | null_terminated")) + Err(darling::Error::missing_field( + "fixed | current_pos | inner | null_terminated", + )) } fn from_string(value: &str) -> darling::Result { match value.to_lowercase().trim().replace([' ', '_'], "").as_str() { "currentpos" => Ok(AttributeInfo::CurrentPos), "nullterminated" => Ok(AttributeInfo::NullTerminated("u8".into())), - _ => Err(darling::Error::missing_field("fixed | current_pos | inner | null_terminated")) + _ => Err(darling::Error::missing_field( + "fixed | current_pos | inner | null_terminated", + )), } } } @@ -192,7 +204,12 @@ impl AttributeInfo { } } - pub(crate) fn smb_to_bytes(&self, spanned: &T, name: &str, name_val: Option) -> TokenStream { + pub(crate) fn smb_to_bytes( + &self, + spanned: &T, + name: &str, + name_val: Option, + ) -> TokenStream { let name_ident = format_ident!("{}", name); match self { Self::CurrentPos => quote! { let #name_ident = current_pos; }, @@ -206,7 +223,7 @@ impl AttributeInfo { match self { Self::CurrentPos | Self::NullTerminated(_) => 0, Self::Fixed(pos) => *pos, - Self::Inner(inner) => inner.start + Self::Inner(inner) => inner.start, } } @@ -251,7 +268,12 @@ pub struct Direct { } impl Direct { - pub(crate) fn smb_from_bytes(&self, spanned: &T, name: &Ident, ty: &Type) -> TokenStream { + pub(crate) fn smb_from_bytes( + &self, + spanned: &T, + name: &Ident, + ty: &Type, + ) -> TokenStream { let start = self.start.smb_from_bytes(spanned, "item_start"); quote_spanned! { spanned.span() => #start @@ -274,7 +296,9 @@ impl Direct { } } - pub(crate) fn attr_byte_size(&self) -> usize { 0 } + pub(crate) fn attr_byte_size(&self) -> usize { + 0 + } } /// `#[smb_buffer(offset(…), length(…))]` — a variable-length byte buffer. @@ -314,9 +338,13 @@ impl Buffer { pub(crate) fn smb_to_bytes(&self, spanned: &T, token: &TokenStream) -> TokenStream { let offset_info = self.offset.smb_to_bytes(spanned, "offset", None); - let length_info = self.length.smb_to_bytes(spanned, "length", Some(quote! { - bytes.len() - })); + let length_info = self.length.smb_to_bytes( + spanned, + "length", + Some(quote! { + bytes.len() + }), + ); quote_spanned! {spanned.span()=> let bytes = #token; @@ -330,7 +358,9 @@ impl Buffer { } } - pub(crate) fn attr_byte_size(&self) -> usize { 0 } + pub(crate) fn attr_byte_size(&self) -> usize { + 0 + } } /// `#[smb_vector(count(…) | length(…), offset(…), align = N)]` — a vector of @@ -361,13 +391,22 @@ impl Vector { pub(crate) fn validate_attrs(self) -> darling::Result { let default = AttributeInfo::default(); if self.count == default && self.length == default { - return Err(darling::Error::custom("count or length must be specified for smb_vector types")); + return Err(darling::Error::custom( + "count or length must be specified for smb_vector types", + )); } else if self.count != default && self.length != default { - return Err(darling::Error::custom("only one of count or length can be specified for smb_vector types")); + return Err(darling::Error::custom( + "only one of count or length can be specified for smb_vector types", + )); } Ok(self) } - pub(crate) fn smb_from_bytes(&self, spanned: &T, name: &Ident, ty: &Type) -> TokenStream { + pub(crate) fn smb_from_bytes( + &self, + spanned: &T, + name: &Ident, + ty: &Type, + ) -> TokenStream { let vec_count_or_len = if self.count == AttributeInfo::default() { self.length.smb_from_bytes(spanned, "item_length") } else { @@ -401,20 +440,32 @@ impl Vector { } } - pub(crate) fn smb_to_bytes(&self, spanned: &T, raw_token: &TokenStream) -> TokenStream { + pub(crate) fn smb_to_bytes( + &self, + spanned: &T, + raw_token: &TokenStream, + ) -> TokenStream { let count_info = if self.count == AttributeInfo::default() { quote! {} } else { - self.count.smb_to_bytes(spanned, "item_count", Some(quote! { - #raw_token.len() - })) + self.count.smb_to_bytes( + spanned, + "item_count", + Some(quote! { + #raw_token.len() + }), + ) }; let len_info = if self.length == AttributeInfo::default() { quote! {} } else { - self.length.smb_to_bytes(spanned, "item_length", Some(quote! { - byte_size - })) + self.length.smb_to_bytes( + spanned, + "item_length", + Some(quote! { + byte_size + }), + ) }; let offset_info = self.offset.smb_to_bytes(spanned, "item_offset", None); let align = self.align; @@ -442,7 +493,9 @@ impl Vector { } } - pub(crate) fn attr_byte_size(&self) -> usize { 0 } + pub(crate) fn attr_byte_size(&self) -> usize { + 0 + } } /// `#[smb_string(length(…), underlying = "u16", …)]` — a UTF-8 or UTF-16LE @@ -472,7 +525,7 @@ impl SMBString { order, start, mut length, - underlying + underlying, } = self; if let AttributeInfo::NullTerminated(_) = &length { length = AttributeInfo::NullTerminated(underlying.clone()) @@ -495,7 +548,7 @@ impl SMBString { "u16" => quote! { let #name = String::from_utf16(&#vec_name).map_err(|e| ::smb_core::error::SMBError::parse_error("Invalid UTF-16 string"))?; }, - _ => quote! {} + _ => quote! {}, }; let num_type = get_type(&self.underlying, spanned); @@ -513,7 +566,11 @@ impl SMBString { } } - pub(crate) fn smb_to_bytes(&self, spanned: &T, raw_token: &TokenStream) -> TokenStream { + pub(crate) fn smb_to_bytes( + &self, + spanned: &T, + raw_token: &TokenStream, + ) -> TokenStream { let (count_expr, string_to_bytes) = match self.underlying.as_str() { "u8" => ( quote! { #raw_token.len() }, @@ -525,7 +582,9 @@ impl SMBString { ), _ => (quote! { 0 }, quote! {}), }; - let count_info = self.length.smb_to_bytes(spanned, "item_count", Some(count_expr)); + let count_info = self + .length + .smb_to_bytes(spanned, "item_count", Some(count_expr)); let offset_info = self.start.smb_to_bytes(spanned, "item_offset", None); quote_spanned! { spanned.span()=> #count_info @@ -540,7 +599,9 @@ impl SMBString { } } - pub(crate) fn attr_byte_size(&self) -> usize { 0 } + pub(crate) fn attr_byte_size(&self) -> usize { + 0 + } } /// `#[smb_discriminator(value = 0x…)]` — marks a discriminated enum variant @@ -565,7 +626,8 @@ pub struct Discriminator { /// access mask type from a combined flags field). #[derive(Debug, Default, PartialEq, Eq, FromMeta)] pub enum SMBAttributeModifier { - #[default] None, + #[default] + None, And(u64), Or(u64), RightShift(u64), @@ -573,7 +635,12 @@ pub enum SMBAttributeModifier { } impl SMBAttributeModifier { - pub(crate) fn smb_from_bytes(&self, spanned: &T, name: &Ident, name_ty: &Type) -> TokenStream { + pub(crate) fn smb_from_bytes( + &self, + spanned: &T, + name: &Ident, + name_ty: &Type, + ) -> TokenStream { match self { SMBAttributeModifier::None => quote! {}, SMBAttributeModifier::And(value) => quote_spanned! {spanned.span()=> @@ -612,7 +679,7 @@ pub struct SMBEnum { #[darling(multiple, default, rename = "modifier")] pub modifiers: Vec, #[darling(default = "SMBEnum::default_should_write")] - pub should_write: bool + pub should_write: bool, } impl SMBEnum { @@ -620,14 +687,18 @@ impl SMBEnum { true } pub(crate) fn smb_from_bytes(&self, spanned: &T, name: &Ident) -> TokenStream { - let discriminator_info = self.discriminator.smb_from_bytes(spanned, "item_discriminator"); + let discriminator_info = self + .discriminator + .smb_from_bytes(spanned, "item_discriminator"); let start_info = self.start.smb_from_bytes(spanned, "item_start"); let discrim_type = match &self.discriminator { AttributeInfo::Inner(inner) => get_type(&inner.num_type, spanned), - _ => get_type("usize", spanned) + _ => get_type("usize", spanned), }; let discrim_ident = format_ident!("item_discriminator"); - let all_modifier_ops: Vec = self.modifiers.iter() + let all_modifier_ops: Vec = self + .modifiers + .iter() .map(|modifier| modifier.smb_from_bytes(spanned, &discrim_ident, &discrim_type)) .collect(); let modifier_info = quote_spanned! {spanned.span()=> @@ -660,7 +731,9 @@ impl SMBEnum { } } - pub(crate) fn attr_byte_size(&self) -> usize { 0 } + pub(crate) fn attr_byte_size(&self) -> usize { + 0 + } } /// `#[smb_byte_tag(value = 0xNN)]` — a single-byte sentinel/tag. @@ -700,7 +773,9 @@ impl ByteTag { } } - pub(crate) fn attr_byte_size(&self) -> usize { 1 } + pub(crate) fn attr_byte_size(&self) -> usize { + 1 + } } /// `#[smb_string_tag(value = "SMB")]` — a multi-byte string sentinel/tag. @@ -750,7 +825,9 @@ impl StringTag { } } - pub(crate) fn attr_byte_size(&self) -> usize { self.value.len() } + pub(crate) fn attr_byte_size(&self) -> usize { + self.value.len() + } } /// Extracts the `#[repr(uN)]` type from an enum's attributes. @@ -779,9 +856,18 @@ pub struct Skip { impl Skip { pub(crate) fn new(start: usize, length: usize) -> Self { - Self { start, length, value: Vec::new() } + Self { + start, + length, + value: Vec::new(), + } } - pub(crate) fn smb_from_bytes(&self, spanned: &T, name: &Ident, ty: &Type) -> TokenStream { + pub(crate) fn smb_from_bytes( + &self, + spanned: &T, + name: &Ident, + ty: &Type, + ) -> TokenStream { let start = self.start; let length = self.length; @@ -808,7 +894,9 @@ impl Skip { } } - pub(crate) fn attr_byte_size(&self) -> usize { 0 } + pub(crate) fn attr_byte_size(&self) -> usize { + 0 + } } impl FromDeriveInput for Repr { @@ -821,14 +909,15 @@ impl FromAttributes for Repr { fn from_attributes(attrs: &[Attribute]) -> darling::Result { for attr in attrs.iter() { if attr.path().is_ident("repr") { - let nested = attr.parse_args_with(Punctuated::::parse_terminated)?; + let nested = + attr.parse_args_with(Punctuated::::parse_terminated)?; for meta in nested { if let Meta::Path(p) = meta && let Some(ident) = p.get_ident() { return Ok(Self { - ident: ident.clone() - }) + ident: ident.clone(), + }); } } } @@ -865,7 +954,13 @@ mod tests { #[derive(Debug)] }; let struct_buffer: AttrsTestStruct = syn::parse2(struct_stream).unwrap(); - assert_eq!(Repr::from_attributes(&struct_buffer.attrs).unwrap().ident.to_string(), "u8"); + assert_eq!( + Repr::from_attributes(&struct_buffer.attrs) + .unwrap() + .ident + .to_string(), + "u8" + ); } #[test] @@ -874,4 +969,3 @@ mod tests { // let skip_to_bytes= skip.smb_to_bytes(); } } - diff --git a/smb-derive/src/field.rs b/smb-derive/src/field.rs index 60b43df..76e42d2 100644 --- a/smb-derive/src/field.rs +++ b/smb-derive/src/field.rs @@ -1,14 +1,16 @@ -use std::cmp::{max, Ordering}; +use std::cmp::{Ordering, max}; use std::fmt::Debug; use darling::FromAttributes; use proc_macro2::{Delimiter, Group, Ident, TokenStream, TokenTree}; use quote::{format_ident, quote, quote_spanned}; -use syn::{Attribute, Field, Type}; use syn::spanned::Spanned; +use syn::{Attribute, Field, Type}; -use crate::attrs::{AttributeInfo, Buffer, ByteTag, Direct, Skip, SMBEnum, SMBString, StringTag, Vector}; use crate::SMBDeriveError; +use crate::attrs::{ + AttributeInfo, Buffer, ByteTag, Direct, SMBEnum, SMBString, Skip, StringTag, Vector, +}; /// A single field within an SMB struct or enum variant, together with its /// parsed attribute metadata. @@ -64,7 +66,10 @@ impl<'a, T: Spanned> SMBField<'a, T> { let field = self.spanned; let ty = &self.ty; let _name_str = name.to_string(); - let all_bytes = self.val_type.iter().map(|field_ty| field_ty.smb_from_bytes(name, field, ty)); + let all_bytes = self + .val_type + .iter() + .map(|field_ty| field_ty.smb_from_bytes(name, field, ty)); quote! { #(#all_bytes)* } @@ -84,7 +89,10 @@ impl<'a, T: Spanned> SMBField<'a, T> { }; let field = self.spanned; let _ty = &self.ty; - let all_bytes = self.val_type.iter().map(|field_ty| field_ty.smb_to_bytes(&name_token_adj, &raw_token, field)); + let all_bytes = self + .val_type + .iter() + .map(|field_ty| field_ty.smb_to_bytes(&name_token_adj, &raw_token, field)); quote! { #(#all_bytes)* } @@ -98,7 +106,10 @@ impl<'a, T: Spanned> SMBField<'a, T> { &#group }; let field = self.spanned; - let all_bytes = self.val_type.iter().map(|field_ty| field_ty.smb_to_bytes(&token_adj, &raw_token, field)); + let all_bytes = self + .val_type + .iter() + .map(|field_ty| field_ty.smb_to_bytes(&token_adj, &raw_token, field)); quote! { #(#all_bytes)* } @@ -132,12 +143,14 @@ impl SMBField<'_, T> { } pub(crate) fn get_named_token(&self) -> TokenStream { - format!("&self.{}", &self.name.to_string()).parse() + format!("&self.{}", &self.name.to_string()) + .parse() .unwrap_or_else(|_e| Self::error(self.spanned)) } pub(crate) fn get_unnamed_token(&self, idx: usize) -> TokenStream { - format!("&self.{}", idx).parse() + format!("&self.{}", idx) + .parse() .unwrap_or_else(|_e| Self::error(self.spanned)) } @@ -149,7 +162,9 @@ impl SMBField<'_, T> { } pub(crate) fn get_disc_enum_token(&self) -> TokenStream { - format!("Self::{}", &self.name.to_string()).parse().unwrap_or_else(|_e| Self::error(self.spanned)) + format!("Self::{}", &self.name.to_string()) + .parse() + .unwrap_or_else(|_e| Self::error(self.spanned)) } pub(crate) fn get_smb_message_size(&self, size_tokens: TokenStream) -> TokenStream { @@ -170,11 +185,7 @@ impl SMBField<'_, T> { let align = if let SMBFieldType::Vector(vec) = ty { if vec.align > 0 { vec.align } else { 1 } } else if let SMBFieldType::String(str) = ty { - if str.underlying == "u8" { - 1 - } else { - 2 - } + if str.underlying == "u8" { 1 } else { 2 } } else { 1 }; @@ -201,16 +212,18 @@ impl SMBField<'_, T> { } else { (l.get_pos(), l.get_type(&self.spanned.span())) } - }, + } (Some(o), None) => (o.get_pos(), o.get_type(&self.spanned.span())), (None, Some(l)) => (l.get_pos(), l.get_type(&self.spanned.span())), - _ => (0, None) + _ => (0, None), }; let buffer_min_pos = offset.map(AttributeInfo::get_min_val).unwrap_or(0); let attr_start_ty = match attr_ty { - Some(ty) => quote! { ::std::cmp::max(#buffer_min_pos, #attr_start + std::mem::size_of::<#ty>())}, + Some(ty) => { + quote! { ::std::cmp::max(#buffer_min_pos, #attr_start + std::mem::size_of::<#ty>())} + } None => quote! { ::std::cmp::max(#attr_start, #buffer_min_pos) }, }; @@ -227,21 +240,25 @@ impl SMBField<'_, T> { } impl<'a> SMBField<'a, Field> { - pub(crate) fn from_iter>(fields: U) -> Result, SMBDeriveError> { - fields.enumerate().map(|(idx, field)| { - let val_types = field.attrs.iter().map(|attr| get_field_types(field, std::slice::from_ref(attr))).collect::, SMBDeriveError>>()?; - let name = if let Some(x) = &field.ident { - x.clone() - } else { - format_ident!("val_{}", idx) - }; - Ok(SMBField::new( - field, - name, - field.ty.clone(), - val_types, - )) - }).collect::, SMBDeriveError>>>() + pub(crate) fn from_iter>( + fields: U, + ) -> Result, SMBDeriveError> { + fields + .enumerate() + .map(|(idx, field)| { + let val_types = field + .attrs + .iter() + .map(|attr| get_field_types(field, std::slice::from_ref(attr))) + .collect::, SMBDeriveError>>()?; + let name = if let Some(x) = &field.ident { + x.clone() + } else { + format_ident!("val_{}", idx) + }; + Ok(SMBField::new(field, name, field.ty.clone(), val_types)) + }) + .collect::, SMBDeriveError>>>() .into_iter() .collect::>, SMBDeriveError>>() } @@ -260,7 +277,12 @@ impl SMBFieldType { SMBFieldType::StringTag(string_tag) => string_tag.smb_from_bytes(field), } } - fn smb_to_bytes(&self, token: &TokenStream, raw_token: &TokenStream, field: &T) -> TokenStream { + fn smb_to_bytes( + &self, + token: &TokenStream, + raw_token: &TokenStream, + field: &T, + ) -> TokenStream { match self { SMBFieldType::Direct(direct) => direct.smb_to_bytes(field, token), SMBFieldType::Buffer(buffer) => buffer.smb_to_bytes(field, token), @@ -287,11 +309,15 @@ impl SMBFieldType { } impl PartialOrd for SMBFieldType { - fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } } impl PartialOrd for SMBField<'_, T> { - fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } } impl Ord for SMBField<'_, T> { @@ -315,16 +341,16 @@ impl SMBFieldType { match self { Self::Direct(x) => match x.start { AttributeInfo::Fixed(idx) => idx, - AttributeInfo::CurrentPos | - AttributeInfo::Inner(_) | - AttributeInfo::NullTerminated(_) => x.order + AttributeInfo::CurrentPos + | AttributeInfo::Inner(_) + | AttributeInfo::NullTerminated(_) => x.order, }, Self::Enum(x) => match x.start { AttributeInfo::Fixed(idx) => idx, - AttributeInfo::CurrentPos | - AttributeInfo::Inner(_) | - AttributeInfo::NullTerminated(_) => x.order - } + AttributeInfo::CurrentPos + | AttributeInfo::Inner(_) + | AttributeInfo::NullTerminated(_) => x.order, + }, Self::Buffer(x) => x.order, Self::Vector(x) => x.order, Self::String(x) => x.order, @@ -366,7 +392,10 @@ impl FromAttributes for SMBFieldType { } } -fn get_field_types(field: &Field, attrs: &[Attribute]) -> Result> { +fn get_field_types( + field: &Field, + attrs: &[Attribute], +) -> Result> { SMBFieldType::from_attributes(attrs) .map_err(|_e| SMBDeriveError::TypeError(Box::new(field.clone()))) -} \ No newline at end of file +} diff --git a/smb-derive/src/field_mapping.rs b/smb-derive/src/field_mapping.rs index ba9bd16..865b5be 100644 --- a/smb-derive/src/field_mapping.rs +++ b/smb-derive/src/field_mapping.rs @@ -3,14 +3,17 @@ use std::fmt::Debug; use darling::FromAttributes; use proc_macro2::Ident; use quote::{format_ident, quote, quote_spanned}; -use syn::{AngleBracketedGenericArguments, Attribute, DataEnum, DeriveInput, Field, Fields, GenericArgument, Path, PathArguments, PathSegment, Token, Type, TypePath}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::token::PathSep; +use syn::{ + AngleBracketedGenericArguments, Attribute, DataEnum, DeriveInput, Field, Fields, + GenericArgument, Path, PathArguments, PathSegment, Token, Type, TypePath, +}; +use crate::SMBDeriveError; use crate::attrs::{AttributeInfo, Direct, Discriminator, Repr}; use crate::field::{SMBField, SMBFieldType}; -use crate::SMBDeriveError; /// Maps a single struct or enum variant to its parent-level attributes and /// ordered child fields. @@ -24,7 +27,7 @@ pub struct SMBFieldMapping<'a, T: Spanned + PartialEq + Eq, U: Spanned + Partial fields: Vec>, mapping_type: SMBFieldMappingType, discriminators: Vec, - variant_ident: Option + variant_ident: Option, } /// Classifies the shape of the type being derived so that code generation can @@ -45,43 +48,64 @@ impl SMBFieldM let parent_size = self.parent.attr_byte_size(); let variant = self.variant_ident.is_some(); let size = match &self.mapping_type { - SMBFieldMappingType::NamedStruct => self.fields.iter().map(|f| { - let token = match variant { - true => f.get_name(), - false => f.get_named_token(), - }; - f.get_smb_message_size(token.clone()) - }).collect(), - SMBFieldMappingType::UnnamedStruct => self.fields.iter().enumerate().map(|(idx, f)| { - let token = match variant { - true => f.get_name(), - false => f.get_unnamed_token(idx), - }; - f.get_smb_message_size(token) - }).collect(), - SMBFieldMappingType::NumEnum => self.fields.iter().map(|f| { - let token = match variant { - true => f.get_name(), - false => f.get_num_enum_token(), - }; - f.get_smb_message_size(token) - }).collect(), - SMBFieldMappingType::DiscriminatedEnum => self.fields.iter().map(|f| { - let token = match variant { - true => f.get_name(), - false => f.get_disc_enum_token(), - }; - f.get_smb_message_size(token) - }).collect(), - SMBFieldMappingType::Unit => vec![quote! { - - }] + SMBFieldMappingType::NamedStruct => self + .fields + .iter() + .map(|f| { + let token = match variant { + true => f.get_name(), + false => f.get_named_token(), + }; + f.get_smb_message_size(token.clone()) + }) + .collect(), + SMBFieldMappingType::UnnamedStruct => self + .fields + .iter() + .enumerate() + .map(|(idx, f)| { + let token = match variant { + true => f.get_name(), + false => f.get_unnamed_token(idx), + }; + f.get_smb_message_size(token) + }) + .collect(), + SMBFieldMappingType::NumEnum => self + .fields + .iter() + .map(|f| { + let token = match variant { + true => f.get_name(), + false => f.get_num_enum_token(), + }; + f.get_smb_message_size(token) + }) + .collect(), + SMBFieldMappingType::DiscriminatedEnum => self + .fields + .iter() + .map(|f| { + let token = match variant { + true => f.get_name(), + false => f.get_disc_enum_token(), + }; + f.get_smb_message_size(token) + }) + .collect(), + SMBFieldMappingType::Unit => vec![quote! {}], }; let names = self.fields.iter().map(|field| field.get_name()); - let key = self.variant_ident.clone().map(|variant| quote! { - Self::#variant(#(#names,)*) - }).unwrap_or(quote! {_}); + let key = self + .variant_ident + .clone() + .map(|variant| { + quote! { + Self::#variant(#(#names,)*) + } + }) + .unwrap_or(quote! {_}); quote! { #key => { @@ -104,7 +128,11 @@ pub(crate) fn enum_repr_type(attrs: &[Attribute]) -> darling::Result { /// /// The entire enum is treated as a single `Direct` field at offset 0 with the /// repr type. Parsing reads the raw integer and converts via `TryFrom`. -pub(crate) fn get_num_enum_mapping(input: &DeriveInput, parent_attrs: Vec, repr_type: Repr) -> Result, SMBDeriveError> { +pub(crate) fn get_num_enum_mapping( + input: &DeriveInput, + parent_attrs: Vec, + repr_type: Repr, +) -> Result, SMBDeriveError> { let identity = &repr_type.ident; let ty = Type::Path(TypePath { qself: None, @@ -133,12 +161,23 @@ pub(crate) fn get_num_enum_mapping(input: &DeriveInput, parent_attrs: Vec Result>, SMBDeriveError> { - info.variants.iter().map(|variant| { - let discriminators = Discriminator::from_attributes(&variant.attrs).map(|d| d.values.iter().map(|val| val | d.flag).collect()) - .map_err(|_e| SMBDeriveError::MissingField)?; - get_struct_field_mapping(&variant.fields, vec![SMBFieldType::from_attributes(&variant.attrs).unwrap()], discriminators, Some(variant.ident.clone())) - }).collect() +pub(crate) fn get_desc_enum_mapping( + info: &DataEnum, +) -> Result>, SMBDeriveError> { + info.variants + .iter() + .map(|variant| { + let discriminators = Discriminator::from_attributes(&variant.attrs) + .map(|d| d.values.iter().map(|val| val | d.flag).collect()) + .map_err(|_e| SMBDeriveError::MissingField)?; + get_struct_field_mapping( + &variant.fields, + vec![SMBFieldType::from_attributes(&variant.attrs).unwrap()], + discriminators, + Some(variant.ident.clone()), + ) + }) + .collect() } /// Build the field mapping for a struct (or a single enum variant's fields). @@ -147,17 +186,27 @@ pub(crate) fn get_desc_enum_mapping(info: &DataEnum) -> Result, discriminators: Vec, variant_ident: Option) -> Result, SMBDeriveError> { +pub(crate) fn get_struct_field_mapping( + struct_fields: &Fields, + parent_attrs: Vec, + discriminators: Vec, + variant_ident: Option, +) -> Result, SMBDeriveError> { if struct_fields.len() == 1 { - let field = struct_fields.iter().next() + let field = struct_fields + .iter() + .next() .ok_or(SMBDeriveError::InvalidType)?; let (field, val_types) = if !parent_attrs.is_empty() { (field, parent_attrs) } else { - (field, vec![SMBFieldType::Direct(Direct { - start: AttributeInfo::Fixed(0), - order: 0, - })]) + ( + field, + vec![SMBFieldType::Direct(Direct { + start: AttributeInfo::Fixed(0), + order: 0, + })], + ) }; let name = if let Some(x) = &field.ident { @@ -166,10 +215,14 @@ pub(crate) fn get_struct_field_mapping(struct_fields: &Fields, parent_attrs: Vec format_ident!("val_0") }; - let fields = vec![SMBField::new(field, name, field.ty.clone(), val_types)]; - let parent = SMBField::new(struct_fields, format_ident!("single_base"), field.ty.clone(), vec![]); + let parent = SMBField::new( + struct_fields, + format_ident!("single_base"), + field.ty.clone(), + vec![], + ); return if field.ident.is_some() { Ok(SMBFieldMapping { @@ -177,7 +230,7 @@ pub(crate) fn get_struct_field_mapping(struct_fields: &Fields, parent_attrs: Vec fields, mapping_type: SMBFieldMappingType::NamedStruct, discriminators, - variant_ident + variant_ident, }) } else { Ok(SMBFieldMapping { @@ -185,7 +238,7 @@ pub(crate) fn get_struct_field_mapping(struct_fields: &Fields, parent_attrs: Vec fields, mapping_type: SMBFieldMappingType::UnnamedStruct, discriminators, - variant_ident + variant_ident, }) }; } @@ -200,7 +253,7 @@ pub(crate) fn get_struct_field_mapping(struct_fields: &Fields, parent_attrs: Vec let mapping_type = match struct_fields { Fields::Named(_) => SMBFieldMappingType::NamedStruct, Fields::Unnamed(_) => SMBFieldMappingType::UnnamedStruct, - Fields::Unit => SMBFieldMappingType::Unit + Fields::Unit => SMBFieldMappingType::Unit, }; let spanned_field = struct_fields; @@ -240,24 +293,30 @@ pub(crate) fn get_struct_field_mapping(struct_fields: &Fields, parent_attrs: Vec path: phantom_data_path, }); - let parent = SMBField::new(spanned_field, format_ident!("structure_base"), phantom_ty, parent_attrs); + let parent = SMBField::new( + spanned_field, + format_ident!("structure_base"), + phantom_ty, + parent_attrs, + ); Ok(SMBFieldMapping { parent, fields: mapped_fields, mapping_type, discriminators, - variant_ident + variant_ident, }) } - /// Generate the body of `SMBFromBytes::smb_from_bytes` for a single mapping. /// /// Emits code that initializes `current_pos = 0`, processes parent attributes /// (tags), then parses each field in order and constructs the final /// `Ok((remaining, Self { … }))` return value. -pub(crate) fn smb_from_bytes(mapping: &SMBFieldMapping) -> proc_macro2::TokenStream { +pub(crate) fn smb_from_bytes( + mapping: &SMBFieldMapping, +) -> proc_macro2::TokenStream { let vector = &mapping.fields; let recurse = vector.iter().map(SMBField::smb_from_bytes); let parent = mapping.parent.smb_from_bytes(); @@ -287,7 +346,7 @@ pub(crate) fn smb_from_bytes { quote! {} } @@ -311,7 +370,9 @@ pub(crate) fn smb_from_bytes(mapping: &SMBFieldMapping) -> proc_macro2::TokenStream { +pub(crate) fn smb_enum_from_bytes( + mapping: &SMBFieldMapping, +) -> proc_macro2::TokenStream { let vector = &mapping.fields; let recurse = vector.iter().map(SMBField::smb_from_bytes); let parent = mapping.parent.smb_from_bytes(); @@ -319,7 +380,7 @@ pub(crate) fn smb_enum_from_bytes std::compile_error!("No variant identifier provided for enum field") - } + }; } let variant_ident = &mapping.variant_ident.clone().unwrap(); @@ -331,7 +392,7 @@ pub(crate) fn smb_enum_from_bytes { quote! { #(#recurse)* @@ -339,8 +400,10 @@ pub(crate) fn smb_enum_from_bytes panic!("Only enums with associated types can be used to derive SMBEnumFromBytes, please use SMBFromBytes for other types") + } + _ => panic!( + "Only enums with associated types can be used to derive SMBEnumFromBytes, please use SMBFromBytes for other types" + ), }; let tokens = quote! { @@ -352,8 +415,10 @@ pub(crate) fn smb_enum_from_bytes #tokens, + let recursive_mapping = mapping.discriminators.iter().map(|discriminator| { + quote! { + #discriminator => #tokens, + } }); quote! { @@ -365,24 +430,38 @@ pub(crate) fn smb_enum_from_bytes` of the correct size, writes parent attributes /// (tags), then serializes each field into its wire position. -pub(crate) fn smb_to_bytes(mapping: &SMBFieldMapping) -> proc_macro2::TokenStream { +pub(crate) fn smb_to_bytes( + mapping: &SMBFieldMapping, +) -> proc_macro2::TokenStream { let vector = &mapping.fields; let variant = mapping.variant_ident.is_some(); let parent = match mapping.mapping_type { SMBFieldMappingType::NumEnum => mapping.parent.smb_to_bytes_enum(), - _ => mapping.parent.smb_to_bytes_struct(variant) + _ => mapping.parent.smb_to_bytes_struct(variant), }; let recurse = match mapping.mapping_type { - SMBFieldMappingType::NumEnum => vector.iter().map(SMBField::smb_to_bytes_enum).collect::>(), - _ => vector.iter().map(|field| field.smb_to_bytes_struct(variant)).collect() + SMBFieldMappingType::NumEnum => vector + .iter() + .map(SMBField::smb_to_bytes_enum) + .collect::>(), + _ => vector + .iter() + .map(|field| field.smb_to_bytes_struct(variant)) + .collect(), }; let names = mapping.fields.iter().map(|field| field.get_name()); - let key = mapping.variant_ident.clone().map(|variant| quote! { - Self::#variant(#(#names,)*) - }).unwrap_or(quote! {_}); + let key = mapping + .variant_ident + .clone() + .map(|variant| { + quote! { + Self::#variant(#(#names,)*) + } + }) + .unwrap_or(quote! {_}); quote! { #key => { let mut current_pos = 0; @@ -392,4 +471,4 @@ pub(crate) fn smb_to_bytes TokenStream { let input: DeriveInput = parse_macro_input!(input); @@ -122,7 +136,20 @@ pub fn smb_from_bytes(input: TokenStream) -> TokenStream { /// /// The generated `smb_enum_from_bytes(input, discriminator)` matches the /// discriminator and delegates to the per-variant parser. -#[proc_macro_derive(SMBEnumFromBytes, attributes(smb_direct, smb_buffer, smb_vector, smb_string, smb_enum, smb_skip, smb_byte_tag, smb_string_tag, smb_discriminator))] +#[proc_macro_derive( + SMBEnumFromBytes, + attributes( + smb_direct, + smb_buffer, + smb_vector, + smb_string, + smb_enum, + smb_skip, + smb_byte_tag, + smb_string_tag, + smb_discriminator + ) +)] pub fn smb_enum_from_bytes(input: TokenStream) -> TokenStream { let input: DeriveInput = parse_macro_input!(input); @@ -137,7 +164,19 @@ pub fn smb_enum_from_bytes(input: TokenStream) -> TokenStream { /// Allocates a `Vec` of the correct size (via `SMBByteSize`) and writes /// each field into its wire-format position. Field ordering and placement is /// controlled by the same `smb_*` attributes used for parsing. -#[proc_macro_derive(SMBToBytes, attributes(smb_direct, smb_buffer, smb_vector, smb_string, smb_enum, smb_skip, smb_byte_tag, smb_string_tag))] +#[proc_macro_derive( + SMBToBytes, + attributes( + smb_direct, + smb_buffer, + smb_vector, + smb_string, + smb_enum, + smb_skip, + smb_byte_tag, + smb_string_tag + ) +)] pub fn smb_to_bytes(input: TokenStream) -> TokenStream { let input: DeriveInput = parse_macro_input!(input); @@ -152,7 +191,19 @@ pub fn smb_to_bytes(input: TokenStream) -> TokenStream { /// Computes the total on-wire byte size by summing fixed-field sizes, skip /// regions, tag bytes, and the dynamic sizes of any buffer/vector/string /// fields. -#[proc_macro_derive(SMBByteSize, attributes(smb_direct, smb_buffer, smb_vector, smb_string, smb_enum, smb_skip, smb_byte_tag, smb_string_tag))] +#[proc_macro_derive( + SMBByteSize, + attributes( + smb_direct, + smb_buffer, + smb_vector, + smb_string, + smb_enum, + smb_skip, + smb_byte_tag, + smb_string_tag + ) +)] pub fn smb_byte_size(input: TokenStream) -> TokenStream { let input: DeriveInput = parse_macro_input!(input); @@ -161,7 +212,6 @@ pub fn smb_byte_size(input: TokenStream) -> TokenStream { parse_token.into() } - /// Central dispatch that maps a [`DeriveInput`] (struct or enum) into the /// appropriate [`SMBFieldMapping`] and then delegates to the supplied /// [`CreatorFn`] to produce the final trait implementation. @@ -187,39 +237,38 @@ fn derive_impl_creator(input: DeriveInput, creator: impl CreatorFn) -> proc_macr SMBDeriveError::TypeError(f) => quote_spanned! {f.span()=>::std::compile_error!("Invalid field for SMB message parsing")}, _ => invalid_token }) - }, - Data::Enum(enum_info) => { - match enum_repr_type(&input.attrs) { - Ok(repr) => { - let mapping = get_num_enum_mapping(&input, parent_attrs, repr) - .map(|r| vec![r]); - creator.call(mapping, name) - .unwrap_or_else(|_e| quote_spanned! {input.span()=> - ::std::compile_error!("Invalid enum for SMB message parsing") - }) - }, - Err(_) => { - let mapping = get_desc_enum_mapping(enum_info); - creator.call(mapping, name) + } + Data::Enum(enum_info) => match enum_repr_type(&input.attrs) { + Ok(repr) => { + let mapping = get_num_enum_mapping(&input, parent_attrs, repr).map(|r| vec![r]); + creator.call(mapping, name).unwrap_or_else(|_e| { + quote_spanned! {input.span()=> + ::std::compile_error!("Invalid enum for SMB message parsing") + } + }) + } + Err(_) => { + let mapping = get_desc_enum_mapping(enum_info); + creator.call(mapping, name) .unwrap_or_else(|e| match e { SMBDeriveError::TypeError(f) => quote_spanned! {f.span()=>::std::compile_error!("Invalid field for SMB message parsing")}, _ => invalid_token }) - } } }, - _ => invalid_token + _ => invalid_token, } } - /// Extracts any struct-level / enum-level `smb_*` attributes (e.g. /// `#[smb_byte_tag(…)]`, `#[smb_string_tag(…)]`) from the top-level /// `DeriveInput` and returns them as a sorted list of [`SMBFieldType`]s. fn parent_attrs(input: &DeriveInput) -> Vec { - input.attrs.iter().filter_map(|attr| { - SMBFieldType::from_attributes(std::slice::from_ref(attr)).ok() - }).collect() + input + .attrs + .iter() + .filter_map(|attr| SMBFieldType::from_attributes(std::slice::from_ref(attr)).ok()) + .collect() } /// Trait object interface for the four code-generation backends. @@ -228,7 +277,11 @@ fn parent_attrs(input: &DeriveInput) -> Vec { /// [`EnumFromBytesCreator`]) implements this trait so that /// [`derive_impl_creator`] can dispatch generically. trait CreatorFn { - fn call(self, mapping: Result>, SMBDeriveError>, name: &Ident) -> Result>; + fn call( + self, + mapping: Result>, SMBDeriveError>, + name: &Ident, + ) -> Result>; } /// Errors that can occur during derive-macro expansion. @@ -242,11 +295,15 @@ enum SMBDeriveError { impl Display for SMBDeriveError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::TypeError(span) => write!(f, "No type annotation for spannable ${:?} (must be buffer or direct)", span), + Self::TypeError(span) => write!( + f, + "No type annotation for spannable ${:?} (must be buffer or direct)", + span + ), Self::MissingField => write!(f, "Needed attribute for field missing"), Self::InvalidType => write!(f, "Unsupported or invalid type"), } } } -impl std::error::Error for SMBDeriveError {} \ No newline at end of file +impl std::error::Error for SMBDeriveError {} diff --git a/smb-derive/src/smb_byte_size.rs b/smb-derive/src/smb_byte_size.rs index 2632404..3ebaca8 100644 --- a/smb-derive/src/smb_byte_size.rs +++ b/smb-derive/src/smb_byte_size.rs @@ -4,8 +4,8 @@ use proc_macro2::Ident; use quote::quote; use syn::spanned::Spanned; -use crate::{CreatorFn, SMBDeriveError}; use crate::field_mapping::SMBFieldMapping; +use crate::{CreatorFn, SMBDeriveError}; /// Code-generation backend for [`SMBByteSize`]. /// @@ -14,12 +14,19 @@ use crate::field_mapping::SMBFieldMapping; pub(crate) struct ByteSizeCreator {} impl CreatorFn for ByteSizeCreator { - fn call(self, mappings: Result>, SMBDeriveError>, name: &Ident) -> Result> { + fn call( + self, + mappings: Result>, SMBDeriveError>, + name: &Ident, + ) -> Result> { create_byte_size_impl(mappings, name) } } -fn create_byte_size_impl(mappings: Result>, SMBDeriveError>, name: &Ident) -> Result> { +fn create_byte_size_impl( + mappings: Result>, SMBDeriveError>, + name: &Ident, +) -> Result> { let mappings = mappings?; let size = mappings.iter().map(|mapping| smb_byte_size_impl(mapping)); Ok(quote! { @@ -34,6 +41,8 @@ fn create_byte_size_impl(mapping: &SMBFieldMapping) -> proc_macro2::TokenStream { +fn smb_byte_size_impl( + mapping: &SMBFieldMapping, +) -> proc_macro2::TokenStream { mapping.get_mapping_size() -} \ No newline at end of file +} diff --git a/smb-derive/src/smb_enum_from_bytes.rs b/smb-derive/src/smb_enum_from_bytes.rs index 01ef0e2..2ece9ac 100644 --- a/smb-derive/src/smb_enum_from_bytes.rs +++ b/smb-derive/src/smb_enum_from_bytes.rs @@ -4,8 +4,8 @@ use proc_macro2::Ident; use quote::quote; use syn::spanned::Spanned; +use crate::field_mapping::{SMBFieldMapping, smb_enum_from_bytes}; use crate::{CreatorFn, SMBDeriveError}; -use crate::field_mapping::{smb_enum_from_bytes, SMBFieldMapping}; /// Code-generation backend for [`SMBEnumFromBytes`]. /// @@ -14,12 +14,19 @@ use crate::field_mapping::{smb_enum_from_bytes, SMBFieldMapping}; pub(crate) struct EnumFromBytesCreator {} impl CreatorFn for EnumFromBytesCreator { - fn call(self, mapping: Result>, SMBDeriveError>, name: &Ident) -> Result> { + fn call( + self, + mapping: Result>, SMBDeriveError>, + name: &Ident, + ) -> Result> { enum_from_bytes_parser_impl(mapping, name) } } -fn enum_from_bytes_parser_impl(mappings: Result>, SMBDeriveError>, name: &Ident) -> Result> { +fn enum_from_bytes_parser_impl( + mappings: Result>, SMBDeriveError>, + name: &Ident, +) -> Result> { let mappings = mappings?; let parser = mappings.iter().map(|mapping| smb_enum_from_bytes(mapping)); Ok(quote! { @@ -33,4 +40,4 @@ fn enum_from_bytes_parser_impl(self, mappings: Result>, SMBDeriveError>, name: &Ident) -> Result> { + fn call( + self, + mappings: Result>, SMBDeriveError>, + name: &Ident, + ) -> Result> { create_parser_impl(mappings, name) } } -fn create_parser_impl(mapping: Result>, SMBDeriveError>, name: &Ident) -> Result> { +fn create_parser_impl( + mapping: Result>, SMBDeriveError>, + name: &Ident, +) -> Result> { let mapping = mapping?; let parser = smb_from_bytes(&mapping[0]); @@ -32,4 +39,4 @@ fn create_parser_impl(self, mappings: Result>, SMBDeriveError>, name: &Ident) -> Result> { + fn call( + self, + mappings: Result>, SMBDeriveError>, + name: &Ident, + ) -> Result> { to_bytes_parser_impl(mappings, name) } } -fn to_bytes_parser_impl(mappings: Result>, SMBDeriveError>, name: &Ident) -> Result> { +fn to_bytes_parser_impl( + mappings: Result>, SMBDeriveError>, + name: &Ident, +) -> Result> { let mappings = mappings?; let to_bytes = mappings.iter().map(|mapping| smb_to_bytes(mapping)); @@ -34,4 +41,4 @@ fn to_bytes_parser_impl = vec![0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0xFF, 0xFF]; let (remaining, parsed) = TwoFields::smb_from_bytes(&bytes).unwrap(); - assert_eq!(parsed, TwoFields { field_a: 1, field_b: 2 }); + assert_eq!( + parsed, + TwoFields { + field_a: 1, + field_b: 2 + } + ); assert_eq!(remaining, &[0xFF, 0xFF]); } @@ -126,7 +138,10 @@ struct WithByteTag { #[test] fn byte_tag_to_bytes() { - let val = WithByteTag { flags: 0x0001, extra: 0 }; + let val = WithByteTag { + flags: 0x0001, + extra: 0, + }; let bytes = val.smb_to_bytes(); // First byte should be the tag value (9) assert_eq!(bytes[0], 9); @@ -187,8 +202,14 @@ fn buffer_from_bytes() { // --------------------------------------------------------------------------- #[derive( - Debug, PartialEq, Eq, Clone, Copy, - SMBFromBytes, SMBToBytes, SMBByteSize, + Debug, + PartialEq, + Eq, + Clone, + Copy, + SMBFromBytes, + SMBToBytes, + SMBByteSize, num_enum::TryFromPrimitive, )] #[repr(u16)] @@ -324,7 +345,10 @@ struct Gapped { #[test] fn gapped_roundtrip() { - let original = Gapped { first: 0xAA, second: 0x11223344 }; + let original = Gapped { + first: 0xAA, + second: 0x11223344, + }; let bytes = original.smb_to_bytes(); // Byte 0 = 0xAA, bytes 1-3 = 0 (gap), bytes 4-7 = LE 0x11223344 assert_eq!(bytes[0], 0xAA); @@ -368,8 +392,14 @@ fn wrapper_roundtrip() { // --------------------------------------------------------------------------- #[derive( - Debug, PartialEq, Eq, Clone, Copy, - SMBFromBytes, SMBToBytes, SMBByteSize, + Debug, + PartialEq, + Eq, + Clone, + Copy, + SMBFromBytes, + SMBToBytes, + SMBByteSize, num_enum::TryFromPrimitive, )] #[repr(u8)] @@ -381,11 +411,7 @@ enum SmallEnum { #[test] fn small_enum_roundtrip() { - for (variant, expected_byte) in [ - (SmallEnum::A, 0u8), - (SmallEnum::B, 1), - (SmallEnum::C, 255), - ] { + for (variant, expected_byte) in [(SmallEnum::A, 0u8), (SmallEnum::B, 1), (SmallEnum::C, 255)] { let bytes = variant.smb_to_bytes(); assert_eq!(bytes, vec![expected_byte]); let (_rem, parsed) = SmallEnum::smb_from_bytes(&bytes).unwrap(); @@ -409,7 +435,10 @@ struct HeaderLike { #[test] fn header_like_to_bytes() { - let val = HeaderLike { value: 0x0040, extra: 0 }; + let val = HeaderLike { + value: 0x0040, + extra: 0, + }; let bytes = val.smb_to_bytes(); assert_eq!(bytes[0], 0xFE); assert_eq!(&bytes[1..4], b"SMB"); @@ -468,4 +497,4 @@ fn inner_offset_from_bytes() { let (_remaining, parsed) = WithInnerOffset::smb_from_bytes(&bytes).unwrap(); assert_eq!(parsed.flags, 1); assert_eq!(parsed.buffer, vec![0xAA, 0xBB, 0xCC]); -} \ No newline at end of file +} diff --git a/smb/src/byte_helper.rs b/smb/src/byte_helper.rs index d30eaea..a2c7ad1 100644 --- a/smb/src/byte_helper.rs +++ b/smb/src/byte_helper.rs @@ -3,10 +3,10 @@ pub(crate) fn u16_to_bytes(num: u16) -> [u8; 2] { } pub(crate) fn bytes_to_u32(bytes: &[u8]) -> u32 { - (bytes[0] as u32) | - ((bytes[1] as u32) << 8) | - ((bytes[2] as u32) << 16) | - ((bytes[3] as u32) << 24) + (bytes[0] as u32) + | ((bytes[1] as u32) << 8) + | ((bytes[2] as u32) << 16) + | ((bytes[3] as u32) << 24) } pub(crate) fn u32_to_bytes(num: u32) -> [u8; 4] { @@ -19,14 +19,14 @@ pub(crate) fn u32_to_bytes(num: u32) -> [u8; 4] { } pub(crate) fn bytes_to_u64(bytes: &[u8]) -> u64 { - (bytes[0] as u64) | - ((bytes[1] as u64) << 8) | - ((bytes[2] as u64) << 16) | - ((bytes[3] as u64) << 24) | - ((bytes[4] as u64) << 32) | - ((bytes[5] as u64) << 40) | - ((bytes[6] as u64) << 48) | - ((bytes[7] as u64) << 56) + (bytes[0] as u64) + | ((bytes[1] as u64) << 8) + | ((bytes[2] as u64) << 16) + | ((bytes[3] as u64) << 24) + | ((bytes[4] as u64) << 32) + | ((bytes[5] as u64) << 40) + | ((bytes[6] as u64) << 48) + | ((bytes[7] as u64) << 56) } pub(crate) fn u64_to_bytes(num: u64) -> [u8; 8] { @@ -76,7 +76,11 @@ mod tests { fn u64_max_value_round_trip() { let val: u64 = u64::MAX; let bytes = u64_to_bytes(val); - assert_eq!(bytes_to_u64(&bytes), val, "u64::MAX should round-trip correctly"); + assert_eq!( + bytes_to_u64(&bytes), + val, + "u64::MAX should round-trip correctly" + ); } #[test] @@ -84,6 +88,10 @@ mod tests { let val: u64 = 0xFF00_0000_0000_0000; let bytes = u64_to_bytes(val); assert_eq!(bytes[7], 0xFF, "High byte should be 0xFF"); - assert_eq!(bytes_to_u64(&bytes), val, "High-byte-only u64 should round-trip"); + assert_eq!( + bytes_to_u64(&bytes), + val, + "High-byte-only u64 should round-trip" + ); } -} \ No newline at end of file +} diff --git a/smb/src/lib.rs b/smb/src/lib.rs index ad019a8..baf9d7b 100644 --- a/smb/src/lib.rs +++ b/smb/src/lib.rs @@ -41,12 +41,12 @@ extern crate core; +mod byte_helper; /// SMB2/3 wire-format protocol types: headers, bodies, and message framing. pub mod protocol; -/// Utility modules: authentication, cryptography, byte helpers, and flag macros. -pub mod util; /// SMB server implementation: connection, session, tree-connect, and open management. pub mod server; /// Socket abstractions for SMB message transport (TCP listener, read/write streams). pub mod socket; -mod byte_helper; +/// Utility modules: authentication, cryptography, byte helpers, and flag macros. +pub mod util; diff --git a/smb/src/main.rs b/smb/src/main.rs index 63f35ae..0a2b652 100644 --- a/smb/src/main.rs +++ b/smb/src/main.rs @@ -6,10 +6,12 @@ use tokio::net::TcpListener; use smb_core::SMBResult; use smb_core::logging::info; -use smb_reader::protocol::body::tree_connect::access_mask::{SMBAccessMask, SMBDirectoryAccessMask}; +use smb_reader::protocol::body::tree_connect::access_mask::{ + SMBAccessMask, SMBDirectoryAccessMask, +}; use smb_reader::server::{DefaultShare, SMBServerBuilder, StartSMBServer}; -use smb_reader::util::auth::ntlm::NTLMAuthProvider; use smb_reader::util::auth::User; +use smb_reader::util::auth::ntlm::NTLMAuthProvider; const NTLM_ID: [u8; 10] = [0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x02, 0x02, 0x0a]; const SPNEGO_ID: [u8; 6] = [0x2b, 0x06, 0x01, 0x05, 0x05, 0x02]; @@ -24,8 +26,7 @@ async fn main() -> SMBResult<()> { use tracing_subscriber::EnvFilter; tracing_subscriber::fmt() .with_env_filter( - EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info")), + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), ) .init(); } @@ -35,18 +36,28 @@ async fn main() -> SMBResult<()> { .and_then(|p| p.parse().ok()) .unwrap_or(50122); let addr = std::net::SocketAddr::from(([127, 0, 0, 1], port)); - let builder = SMBServerBuilder::<_, TcpListener, NTLMAuthProvider, DefaultShare, _>::default() - .anonymous_access(true) - .unencrypted_access(true) - .require_message_signing(false) - .encrypt_data(false) - .add_fs_share("test".into(), "".into(), file_allowed, get_file_perms) - .add_ipc_share() - .auth_provider(NTLMAuthProvider::new(vec![ + let builder = SMBServerBuilder::< + _, + TcpListener, + NTLMAuthProvider, + DefaultShare, + _, + >::default() + .anonymous_access(true) + .unencrypted_access(true) + .require_message_signing(false) + .encrypt_data(false) + .add_fs_share("test".into(), "".into(), file_allowed, get_file_perms) + .add_ipc_share() + .auth_provider(NTLMAuthProvider::new( + vec![ User::new("tejasmehta", "password"), User::new("tejas2", "password"), - ], false)) - .listener_address(addr).await?; + ], + false, + )) + .listener_address(addr) + .await?; let server = builder.build()?; info!(port, "SMB server starting"); server.start().await @@ -71,4 +82,4 @@ fn file_allowed(test: &String) -> bool { fn get_file_perms(test: &String) -> SMBAccessMask { SMBAccessMask::Directory(SMBDirectoryAccessMask::GENERIC_ALL) -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/cancel/mod.rs b/smb/src/protocol/body/cancel/mod.rs index e5a4399..a83e688 100644 --- a/smb/src/protocol/body/cancel/mod.rs +++ b/smb/src/protocol/body/cancel/mod.rs @@ -1,3 +1,3 @@ use crate::protocol::body::empty::SMBEmpty; -pub type SMBCancelRequest = SMBEmpty; \ No newline at end of file +pub type SMBCancelRequest = SMBEmpty; diff --git a/smb/src/protocol/body/capabilities.rs b/smb/src/protocol/body/capabilities.rs index 5167c6e..c0f9390 100644 --- a/smb/src/protocol/body/capabilities.rs +++ b/smb/src/protocol/body/capabilities.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] @@ -45,4 +47,4 @@ mod tests { let (_, parsed) = Capabilities::smb_from_bytes(&bytes).unwrap(); assert_eq!(parsed, caps); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/change_notify/completion_filter.rs b/smb/src/protocol/body/change_notify/completion_filter.rs index bc4d6aa..03ac98c 100644 --- a/smb/src/protocol/body/change_notify/completion_filter.rs +++ b/smb/src/protocol/body/change_notify/completion_filter.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] diff --git a/smb/src/protocol/body/change_notify/flags.rs b/smb/src/protocol/body/change_notify/flags.rs index 12c1749..e3d7f3e 100644 --- a/smb/src/protocol/body/change_notify/flags.rs +++ b/smb/src/protocol/body/change_notify/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] diff --git a/smb/src/protocol/body/change_notify/mod.rs b/smb/src/protocol/body/change_notify/mod.rs index 103216c..d403ee5 100644 --- a/smb/src/protocol/body/change_notify/mod.rs +++ b/smb/src/protocol/body/change_notify/mod.rs @@ -8,19 +8,11 @@ use crate::protocol::body::change_notify::completion_filter::SMBCompletionFilter use crate::protocol::body::change_notify::flags::SMBChangeNotifyFlags; use crate::protocol::body::create::file_id::SMBFileId; -mod flags; mod completion_filter; +mod flags; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 32)] pub struct SMBChangeNotifyRequest { @@ -37,21 +29,17 @@ pub struct SMBChangeNotifyRequest { } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 17)] pub struct SMBChangeNotifyResponse { #[smb_skip(start = 2, length = 6)] reserved: PhantomData>, // TODO make this into a vector of FILE_NOTIFY_INFO structs: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/14f9d050-27b2-49df-b009-54e08e8bf7b5 - #[smb_buffer(order = 0, offset(inner(start = 2, num_type = "u16", subtract = 64)), length(inner(start = 4, num_type = "u32")))] + #[smb_buffer( + order = 0, + offset(inner(start = 2, num_type = "u16", subtract = 64)), + length(inner(start = 4, num_type = "u32")) + )] data: Vec, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/close/flags.rs b/smb/src/protocol/body/close/flags.rs index fa20c2b..26beb9f 100644 --- a/smb/src/protocol/body/close/flags.rs +++ b/smb/src/protocol/body/close/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] @@ -12,4 +14,4 @@ bitflags! { impl_smb_byte_size_for_bitflag! { SMBCloseFlags } impl_smb_to_bytes_for_bitflag! { SMBCloseFlags } -impl_smb_from_bytes_for_bitflag! { SMBCloseFlags } \ No newline at end of file +impl_smb_from_bytes_for_bitflag! { SMBCloseFlags } diff --git a/smb/src/protocol/body/close/mod.rs b/smb/src/protocol/body/close/mod.rs index 9f278f6..0b08bb6 100644 --- a/smb/src/protocol/body/close/mod.rs +++ b/smb/src/protocol/body/close/mod.rs @@ -12,15 +12,7 @@ use crate::protocol::body::filetime::FileTime; mod flags; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 24)] pub struct SMBCloseRequest { @@ -33,15 +25,7 @@ pub struct SMBCloseRequest { } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 60)] pub struct SMBCloseResponse { @@ -63,4 +47,4 @@ pub struct SMBCloseResponse { end_of_file: u64, #[smb_direct(start(fixed = 56))] file_attributes: SMBFileAttributes, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/create/action.rs b/smb/src/protocol/body/create/action.rs index ecee160..904ca52 100644 --- a/smb/src/protocol/body/create/action.rs +++ b/smb/src/protocol/body/create/action.rs @@ -4,10 +4,22 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u32)] -#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, SMBFromBytes, SMBToBytes, SMBByteSize, TryFromPrimitive, Copy, Clone)] +#[derive( + Debug, + PartialEq, + Eq, + Serialize, + Deserialize, + SMBFromBytes, + SMBToBytes, + SMBByteSize, + TryFromPrimitive, + Copy, + Clone, +)] pub enum SMBCreateAction { Superseded = 0x0, Opened = 0x1, Created = 0x2, Overwritten = 0x3, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/create/context_helper.rs b/smb/src/protocol/body/create/context_helper.rs index 55fe252..a9d8135 100644 --- a/smb/src/protocol/body/create/context_helper.rs +++ b/smb/src/protocol/body/create/context_helper.rs @@ -28,39 +28,45 @@ macro_rules! create_ctx_smb_byte_size { } macro_rules! create_ctx_smb_from_bytes { - ($enumType: expr, $bodyType: expr, $data: expr) => { - { - let (_, body) = $bodyType($data)?; - Ok($enumType(body)) - } - }; + ($enumType: expr, $bodyType: expr, $data: expr) => {{ + let (_, body) = $bodyType($data)?; + Ok($enumType(body)) + }}; } macro_rules! create_ctx_smb_to_bytes { - ($body: expr, $tag: expr) => { - { - let bytes = $body.smb_to_bytes(); - let wrapper = CreateContextWrapper{ - data: bytes, - reserved: PhantomData, - name: $tag.to_vec(), - }; - wrapper.smb_to_bytes() - } - }; + ($body: expr, $tag: expr) => {{ + let bytes = $body.smb_to_bytes(); + let wrapper = CreateContextWrapper { + data: bytes, + reserved: PhantomData, + name: $tag.to_vec(), + }; + wrapper.smb_to_bytes() + }}; } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct CreateContextWrapper { #[smb_skip(start = 8, length = 4)] pub reserved: PhantomData>, - #[smb_buffer(offset(inner(start = 4, num_type = "u16")), length(inner(start = 6, num_type = "u16")), order = 0)] + #[smb_buffer( + offset(inner(start = 4, num_type = "u16")), + length(inner(start = 6, num_type = "u16")), + order = 0 + )] pub name: Vec, - #[smb_buffer(offset(inner(start = 10, num_type = "u16")), length(inner(start = 12, num_type = "u32")), order = 1)] + #[smb_buffer( + offset(inner(start = 10, num_type = "u16")), + length(inner(start = 12, num_type = "u32")), + order = 1 + )] pub data: Vec, } -pub(crate) use create_ctx_smb_to_bytes; -pub(crate) use create_ctx_smb_from_bytes; pub(crate) use create_ctx_smb_byte_size; -pub(crate) use impl_tag_for_ctx; \ No newline at end of file +pub(crate) use create_ctx_smb_from_bytes; +pub(crate) use create_ctx_smb_to_bytes; +pub(crate) use impl_tag_for_ctx; diff --git a/smb/src/protocol/body/create/disposition.rs b/smb/src/protocol/body/create/disposition.rs index c648965..8e1323e 100644 --- a/smb/src/protocol/body/create/disposition.rs +++ b/smb/src/protocol/body/create/disposition.rs @@ -4,7 +4,21 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u8)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, SMBFromBytes, SMBToBytes, SMBByteSize, TryFromPrimitive, Serialize, Deserialize)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + SMBFromBytes, + SMBToBytes, + SMBByteSize, + TryFromPrimitive, + Serialize, + Deserialize, +)] pub enum SMBCreateDisposition { Supersede = 0x0, Open = 0x1, @@ -21,4 +35,4 @@ impl SMBCreateDisposition { Self::Open | Self::Create | Self::OpenIf => true, } } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/create/file_attributes.rs b/smb/src/protocol/body/create/file_attributes.rs index 83b718f..74ac2a9 100644 --- a/smb/src/protocol/body/create/file_attributes.rs +++ b/smb/src/protocol/body/create/file_attributes.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Copy, Clone)] diff --git a/smb/src/protocol/body/create/file_id.rs b/smb/src/protocol/body/create/file_id.rs index 4329d04..85237b2 100644 --- a/smb/src/protocol/body/create/file_id.rs +++ b/smb/src/protocol/body/create/file_id.rs @@ -2,10 +2,12 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct SMBFileId { #[smb_direct(start(fixed = 0))] pub persistent: u64, #[smb_direct(start(fixed = 8))] pub volatile: u64, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/create/flags.rs b/smb/src/protocol/body/create/flags.rs index 01ca0f4..4e9596e 100644 --- a/smb/src/protocol/body/create/flags.rs +++ b/smb/src/protocol/body/create/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] @@ -12,4 +14,4 @@ bitflags! { impl_smb_byte_size_for_bitflag! { SMBCreateFlags } impl_smb_to_bytes_for_bitflag! { SMBCreateFlags } -impl_smb_from_bytes_for_bitflag! { SMBCreateFlags } \ No newline at end of file +impl_smb_from_bytes_for_bitflag! { SMBCreateFlags } diff --git a/smb/src/protocol/body/create/impersonation_level.rs b/smb/src/protocol/body/create/impersonation_level.rs index ebe0f16..47d491d 100644 --- a/smb/src/protocol/body/create/impersonation_level.rs +++ b/smb/src/protocol/body/create/impersonation_level.rs @@ -4,10 +4,24 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u8)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, SMBFromBytes, SMBToBytes, SMBByteSize, TryFromPrimitive, Serialize, Deserialize)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + SMBFromBytes, + SMBToBytes, + SMBByteSize, + TryFromPrimitive, + Serialize, + Deserialize, +)] pub enum SMBImpersonationLevel { Anonymous = 0x0, Identification = 0x1, Impersonation = 0x2, Delegate = 0x3, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/create/mod.rs b/smb/src/protocol/body/create/mod.rs index 6f8b8c8..a7c79ef 100644 --- a/smb/src/protocol/body/create/mod.rs +++ b/smb/src/protocol/body/create/mod.rs @@ -2,9 +2,9 @@ use std::marker::PhantomData; use serde::{Deserialize, Serialize}; +use smb_core::SMBResult; use smb_core::error::SMBError; use smb_core::nt_status::NTStatus; -use smb_core::SMBResult; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; use crate::protocol::body::create::action::SMBCreateAction; @@ -20,35 +20,27 @@ use crate::protocol::body::create::response_context::CreateResponseContext; use crate::protocol::body::create::share_access::SMBShareAccess; use crate::protocol::body::filetime::FileTime; use crate::protocol::body::tree_connect::access_mask::SMBAccessMask; -use crate::server::open::Open; use crate::server::Server; +use crate::server::open::Open; use crate::server::share::{ResourceType, SharedResource}; -pub mod options; -pub mod oplock; -pub mod impersonation_level; -pub mod file_attributes; -pub mod share_access; +mod action; pub mod disposition; -pub mod request_context; +pub mod file_attributes; pub mod file_id; mod flags; -mod action; +pub mod impersonation_level; +pub mod oplock; +pub mod options; +pub mod request_context; mod response_context; +pub mod share_access; #[macro_use] pub(crate) mod context_helper; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 57)] pub struct SMBCreateRequest { @@ -56,7 +48,12 @@ pub struct SMBCreateRequest { oplock_level: SMBOplockLevel, #[smb_direct(start(fixed = 4))] impersonation_level: SMBImpersonationLevel, - #[smb_enum(start(fixed = 24), discriminator(inner(start = 28, num_type = "u32")), modifier(and = 0x10), modifier(right_shift = 4))] + #[smb_enum( + start(fixed = 24), + discriminator(inner(start = 28, num_type = "u32")), + modifier(and = 0x10), + modifier(right_shift = 4) + )] desired_access: SMBAccessMask, #[smb_direct(start(fixed = 28))] attributes: SMBFileAttributes, @@ -66,9 +63,19 @@ pub struct SMBCreateRequest { create_disposition: SMBCreateDisposition, #[smb_direct(start(fixed = 40))] create_options: SMBCreateOptions, - #[smb_string(order = 0, start(inner(start = 44, num_type = "u16", subtract = 68)), length(inner(start = 46, num_type = "u16")), underlying = "u16")] + #[smb_string( + order = 0, + start(inner(start = 44, num_type = "u16", subtract = 68)), + length(inner(start = 46, num_type = "u16")), + underlying = "u16" + )] file_name: String, - #[smb_vector(order = 1, align = 8, length(inner(start = 52, num_type = "u32")), offset(inner(start = 48, num_type = "u32", subtract = 64)))] + #[smb_vector( + order = 1, + align = 8, + length(inner(start = 52, num_type = "u32")), + offset(inner(start = 48, num_type = "u32", subtract = 64)) + )] contexts: Vec, } @@ -86,9 +93,9 @@ impl SMBCreateRequest { } fn validate_print(&self) -> bool { - !self.attributes.contains(SMBFileAttributes::DIRECTORY) && - self.desired_access.validate_print() && - self.create_disposition == SMBCreateDisposition::Create + !self.attributes.contains(SMBFileAttributes::DIRECTORY) + && self.desired_access.validate_print() + && self.create_disposition == SMBCreateDisposition::Create } pub fn desired_access(&self) -> &SMBAccessMask { @@ -103,29 +110,32 @@ impl SMBCreateRequest { self.attributes } - pub fn validate(&self, resource: &R) -> SMBResult<(&str, SMBCreateDisposition, bool)> { + pub fn validate( + &self, + resource: &R, + ) -> SMBResult<(&str, SMBCreateDisposition, bool)> { if resource.resource_type() == ResourceType::PRINT_QUEUE && !self.validate_print() { - return Err(SMBError::response_error(NTStatus::NotSupported)) + return Err(SMBError::response_error(NTStatus::NotSupported)); } - if self.create_options.contains(SMBCreateOptions::DIRECTORY_FILE) && - !self.validate_directory() { + if self + .create_options + .contains(SMBCreateOptions::DIRECTORY_FILE) + && !self.validate_directory() + { // TODO make this the right error code return Err(SMBError::response_error(NTStatus::NotSupported)); } - Ok((self.file_name(), self.disposition(), self.create_options.contains(SMBCreateOptions::DIRECTORY_FILE))) + Ok(( + self.file_name(), + self.disposition(), + self.create_options + .contains(SMBCreateOptions::DIRECTORY_FILE), + )) } } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 89)] pub struct SMBCreateResponse { @@ -181,4 +191,4 @@ impl SMBCreateResponse { contexts: vec![], }) } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/create/oplock.rs b/smb/src/protocol/body/create/oplock.rs index 8bb2809..93f1119 100644 --- a/smb/src/protocol/body/create/oplock.rs +++ b/smb/src/protocol/body/create/oplock.rs @@ -4,11 +4,25 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u8)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, SMBFromBytes, SMBToBytes, SMBByteSize, TryFromPrimitive, Serialize, Deserialize)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + SMBFromBytes, + SMBToBytes, + SMBByteSize, + TryFromPrimitive, + Serialize, + Deserialize, +)] pub enum SMBOplockLevel { None = 0x0, II = 0x1, Exclusive = 0x8, Batch = 0x9, Lease = 0xFF, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/create/options.rs b/smb/src/protocol/body/create/options.rs index c09ace4..2af3f3a 100644 --- a/smb/src/protocol/body/create/options.rs +++ b/smb/src/protocol/body/create/options.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Copy, Clone)] @@ -32,15 +34,15 @@ bitflags! { impl SMBCreateOptions { pub fn validate_directory(&self) -> bool { - let flag = SMBCreateOptions::WRITE_THROUGH | - SMBCreateOptions::OPEN_FOR_BACKUP_INTENT | - SMBCreateOptions::DELETE_ON_CLOSE | - SMBCreateOptions::OPEN_REPARSE_POINT | - SMBCreateOptions::DIRECTORY_FILE; + let flag = SMBCreateOptions::WRITE_THROUGH + | SMBCreateOptions::OPEN_FOR_BACKUP_INTENT + | SMBCreateOptions::DELETE_ON_CLOSE + | SMBCreateOptions::OPEN_REPARSE_POINT + | SMBCreateOptions::DIRECTORY_FILE; (flag.complement() & *self) == SMBCreateOptions::empty() } } impl_smb_byte_size_for_bitflag! { SMBCreateOptions } impl_smb_to_bytes_for_bitflag! { SMBCreateOptions } -impl_smb_from_bytes_for_bitflag! { SMBCreateOptions } \ No newline at end of file +impl_smb_from_bytes_for_bitflag! { SMBCreateOptions } diff --git a/smb/src/protocol/body/create/request_context.rs b/smb/src/protocol/body/create/request_context.rs index 4637074..4964cd5 100644 --- a/smb/src/protocol/body/create/request_context.rs +++ b/smb/src/protocol/body/create/request_context.rs @@ -5,15 +5,20 @@ use num_enum::TryFromPrimitive; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use smb_core::error::SMBError; use smb_core::logging::trace; +use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; -use crate::protocol::body::create::context_helper::{create_ctx_smb_byte_size, create_ctx_smb_from_bytes, create_ctx_smb_to_bytes, CreateContextWrapper, impl_tag_for_ctx}; +use crate::protocol::body::create::context_helper::{ + CreateContextWrapper, create_ctx_smb_byte_size, create_ctx_smb_from_bytes, + create_ctx_smb_to_bytes, impl_tag_for_ctx, +}; use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::filetime::FileTime; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; pub const EA_BUFFER_TAG: &[u8] = "ExtA".as_bytes(); pub const SD_BUFFER_TAG: &[u8] = "SecD".as_bytes(); @@ -28,16 +33,16 @@ pub const REQUEST_LEASE_TAG: &[u8] = "RqLs".as_bytes(); pub const DURABLE_HANDLE_REQUEST_V2_TAG: &[u8] = "DH2Q".as_bytes(); pub const DURABLE_HANDLE_RECONNECT_V2_TAG: &[u8] = "DH2C".as_bytes(); pub const APP_INSTANCE_ID_TAG: &[u8] = &[ - 0x45, 0xBC, 0xA6, 0x6A, 0xEF, 0xA7, 0xF7, 0x4A, 0x90, 0x08, 0xFA, 0x46, 0x2E, 0x14, 0x4D, 0x74 + 0x45, 0xBC, 0xA6, 0x6A, 0xEF, 0xA7, 0xF7, 0x4A, 0x90, 0x08, 0xFA, 0x46, 0x2E, 0x14, 0x4D, 0x74, ]; pub const APP_INSTANCE_VERSION_TAG: &[u8] = &[ - 0xB9, 0x82, 0xD0, 0xB7, 0x3B, 0x56, 0x07, 0x4F, 0xA0, 0x7B, 0x52, 0x4A, 0x81, 0x16, 0xA0, 0x10 + 0xB9, 0x82, 0xD0, 0xB7, 0x3B, 0x56, 0x07, 0x4F, 0xA0, 0x7B, 0x52, 0x4A, 0x81, 0x16, 0xA0, 0x10, ]; pub const SVHDX_OPEN_DEVICE_CONTEXT_TAG: &[u8] = &[ - 0x9C, 0xCB, 0xCF, 0x9E, 0x04, 0xC1, 0xE6, 0x43, 0x98, 0x0E, 0x15, 0x8D, 0xA1, 0xF6, 0xEC, 0x83 + 0x9C, 0xCB, 0xCF, 0x9E, 0x04, 0xC1, 0xE6, 0x43, 0x98, 0x0E, 0x15, 0x8D, 0xA1, 0xF6, 0xEC, 0x83, ]; pub const RESERVED: &[u8] = &[ - 0x93, 0xAD, 0x25, 0x50, 0x9C, 0xB4, 0x11, 0xE7, 0xB4, 0x23, 0x83, 0xDE, 0x96, 0x8B, 0xCD, 0x7C + 0x93, 0xAD, 0x25, 0x50, 0x9C, 0xB4, 0x11, 0xE7, 0xB4, 0x23, 0x83, 0xDE, 0x96, 0x8B, 0xCD, 0x7C, ]; #[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone)] @@ -82,7 +87,10 @@ impl SMBByteSize for CreateRequestContext { } impl SMBFromBytes for CreateRequestContext { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { trace!("parsing create request context wrapper"); let (remaining, wrapper) = CreateContextWrapper::smb_from_bytes(input)?; trace!(?wrapper, "parsed create request context wrapper"); @@ -113,17 +121,19 @@ impl SMBFromBytes for CreateRequestContext { AllocationSize::smb_from_bytes, wrapper.data.as_slice() ), - QUERY_MAXIMAL_ACCESS_REQUEST_TAG if !wrapper.data.is_empty() => create_ctx_smb_from_bytes!( - Self::QueryMaximalAccessRequest, - QueryMaximalAccessRequest::smb_from_bytes, - wrapper.data.as_slice() - ), + QUERY_MAXIMAL_ACCESS_REQUEST_TAG if !wrapper.data.is_empty() => { + create_ctx_smb_from_bytes!( + Self::QueryMaximalAccessRequest, + QueryMaximalAccessRequest::smb_from_bytes, + wrapper.data.as_slice() + ) + } // TODO investigate -- MacOS seems to send an empty payload here... - QUERY_MAXIMAL_ACCESS_REQUEST_TAG if wrapper.data.is_empty() => Ok(Self::QueryMaximalAccessRequest( - QueryMaximalAccessRequest { - timestamp: FileTime::zero() - } - )), + QUERY_MAXIMAL_ACCESS_REQUEST_TAG if wrapper.data.is_empty() => { + Ok(Self::QueryMaximalAccessRequest(QueryMaximalAccessRequest { + timestamp: FileTime::zero(), + })) + } TIMEWARP_TOKEN_TAG => create_ctx_smb_from_bytes!( Self::TimewarpToken, TimewarpToken::smb_from_bytes, @@ -169,7 +179,7 @@ impl SMBFromBytes for CreateRequestContext { SVHDXOpenDeviceContext::smb_from_bytes, wrapper.data.as_slice() ), - _ => Err(SMBError::parse_error("Invalid context tag")) + _ => Err(SMBError::parse_error("Invalid context tag")), }?; Ok((remaining, context)) @@ -181,28 +191,58 @@ impl SMBToBytes for CreateRequestContext { match self { CreateRequestContext::EABuffer(x) => create_ctx_smb_to_bytes!(x, EA_BUFFER_TAG), CreateRequestContext::SDBuffer(x) => create_ctx_smb_to_bytes!(x, SD_BUFFER_TAG), - CreateRequestContext::DurableHandleRequest(x) => create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_TAG), - CreateRequestContext::DurableHandleReconnect(x) => create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_RECONNECT_TAG), - CreateRequestContext::AllocationSize(x) => create_ctx_smb_to_bytes!(x, ALLOCATION_SIZE_TAG), - CreateRequestContext::QueryMaximalAccessRequest(x) => create_ctx_smb_to_bytes!(x, QUERY_MAXIMAL_ACCESS_REQUEST_TAG), - CreateRequestContext::TimewarpToken(x) => create_ctx_smb_to_bytes!(x, TIMEWARP_TOKEN_TAG), - CreateRequestContext::QueryOnDiskID(x) => create_ctx_smb_to_bytes!(x, QUERY_ON_DISK_ID_TAG), + CreateRequestContext::DurableHandleRequest(x) => { + create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_TAG) + } + CreateRequestContext::DurableHandleReconnect(x) => { + create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_RECONNECT_TAG) + } + CreateRequestContext::AllocationSize(x) => { + create_ctx_smb_to_bytes!(x, ALLOCATION_SIZE_TAG) + } + CreateRequestContext::QueryMaximalAccessRequest(x) => { + create_ctx_smb_to_bytes!(x, QUERY_MAXIMAL_ACCESS_REQUEST_TAG) + } + CreateRequestContext::TimewarpToken(x) => { + create_ctx_smb_to_bytes!(x, TIMEWARP_TOKEN_TAG) + } + CreateRequestContext::QueryOnDiskID(x) => { + create_ctx_smb_to_bytes!(x, QUERY_ON_DISK_ID_TAG) + } CreateRequestContext::RequestLease(x) => create_ctx_smb_to_bytes!(x, REQUEST_LEASE_TAG), - CreateRequestContext::RequestLeaseV2(x) => create_ctx_smb_to_bytes!(x, REQUEST_LEASE_TAG), - CreateRequestContext::DurableHandleRequestV2(x) => create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_V2_TAG), - CreateRequestContext::DurableHandleReconnectV2(x) => create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_RECONNECT_V2_TAG), - CreateRequestContext::AppInstanceID(x) => create_ctx_smb_to_bytes!(x, APP_INSTANCE_ID_TAG), - CreateRequestContext::AppInstanceVersion(x) => create_ctx_smb_to_bytes!(x, APP_INSTANCE_VERSION_TAG), - CreateRequestContext::SVHDXOpenDeviceContext(x) => create_ctx_smb_to_bytes!(x, SVHDX_OPEN_DEVICE_CONTEXT_TAG), + CreateRequestContext::RequestLeaseV2(x) => { + create_ctx_smb_to_bytes!(x, REQUEST_LEASE_TAG) + } + CreateRequestContext::DurableHandleRequestV2(x) => { + create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_V2_TAG) + } + CreateRequestContext::DurableHandleReconnectV2(x) => { + create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_RECONNECT_V2_TAG) + } + CreateRequestContext::AppInstanceID(x) => { + create_ctx_smb_to_bytes!(x, APP_INSTANCE_ID_TAG) + } + CreateRequestContext::AppInstanceVersion(x) => { + create_ctx_smb_to_bytes!(x, APP_INSTANCE_VERSION_TAG) + } + CreateRequestContext::SVHDXOpenDeviceContext(x) => { + create_ctx_smb_to_bytes!(x, SVHDX_OPEN_DEVICE_CONTEXT_TAG) + } } } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct EABuffer { #[smb_direct(start(fixed = 4))] flags: EABufferFlags, - #[smb_string(order = 0, length(inner(start = 5, num_type = "u8")), underlying = "u8")] + #[smb_string( + order = 0, + length(inner(start = 5, num_type = "u8")), + underlying = "u8" + )] name: String, #[smb_buffer(order = 1, length(inner(start = 6, num_type = "u16")))] value: Vec, @@ -210,7 +250,19 @@ pub struct EABuffer { #[repr(u8)] #[derive( -Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Clone, Ord, PartialOrd, Copy, SMBFromBytes, SMBByteSize, SMBToBytes + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Clone, + Ord, + PartialOrd, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, )] pub enum EABufferFlags { None = 0x0, @@ -218,14 +270,18 @@ pub enum EABufferFlags { } // TODO -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct SDBuffer { // revision: u8, // sbz1: u8, // control: u16, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct DurableHandleRequest { #[smb_skip(start = 0, length = 16)] reserved_1: PhantomData>, @@ -233,31 +289,41 @@ pub struct DurableHandleRequest { reserved_2: PhantomData>, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct DurableHandleReconnect { #[smb_direct(start(fixed = 0))] file_id: SMBFileId, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct QueryMaximalAccessRequest { #[smb_direct(start(fixed = 0))] timestamp: FileTime, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct AllocationSize { #[smb_direct(start(fixed = 0))] size: u64, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct TimewarpToken { #[smb_direct(start(fixed = 0))] timestamp: FileTime, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct RequestLease { #[smb_direct(start(fixed = 0))] lease_key: [u8; 16], @@ -279,10 +345,14 @@ bitflags! { } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct QueryOnDiskID {} -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct RequestLeaseV2 { #[smb_direct(start(fixed = 0))] lease_key: [u8; 16], @@ -307,7 +377,9 @@ bitflags! { } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct DurableHandleRequestV2 { #[smb_direct(start(fixed = 0))] timeout: u32, @@ -326,7 +398,9 @@ bitflags! { } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct DurableHandleReconnectV2 { #[smb_direct(start(fixed = 0))] file_id: SMBFileId, @@ -336,7 +410,9 @@ pub struct DurableHandleReconnectV2 { flags: DurableHandleV2Flags, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] #[smb_byte_tag(value = 20)] pub struct AppInstanceID { #[smb_skip(start = 0, length = 4)] @@ -345,7 +421,9 @@ pub struct AppInstanceID { app_instance_id: [u8; 16], } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] #[smb_byte_tag(20)] pub struct AppInstanceVersion { #[smb_skip(start = 0, length = 4)] @@ -358,7 +436,9 @@ pub struct AppInstanceVersion { app_instance_version_low: u64, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct SVHDXOpenDeviceContext {} impl_tag_for_ctx!(EABuffer, EA_BUFFER_TAG); @@ -380,4 +460,3 @@ impl_tag_for_ctx!(SVHDXOpenDeviceContext, SVHDX_OPEN_DEVICE_CONTEXT_TAG); impl_smb_from_bytes_for_bitflag! {RequestLeaseState RequestLeaseFlags DurableHandleV2Flags} impl_smb_to_bytes_for_bitflag! {RequestLeaseState RequestLeaseFlags DurableHandleV2Flags} impl_smb_byte_size_for_bitflag! {RequestLeaseState RequestLeaseFlags DurableHandleV2Flags} - diff --git a/smb/src/protocol/body/create/response_context.rs b/smb/src/protocol/body/create/response_context.rs index f15a145..1972762 100644 --- a/smb/src/protocol/body/create/response_context.rs +++ b/smb/src/protocol/body/create/response_context.rs @@ -3,16 +3,25 @@ use std::marker::PhantomData; use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use smb_core::error::SMBError; use smb_core::logging::trace; use smb_core::nt_status::NTStatus; +use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; -use crate::protocol::body::create::context_helper::{create_ctx_smb_byte_size, create_ctx_smb_from_bytes, create_ctx_smb_to_bytes, CreateContextWrapper, impl_tag_for_ctx}; -use crate::protocol::body::create::request_context::{DURABLE_HANDLE_REQUEST_TAG, DURABLE_HANDLE_REQUEST_V2_TAG, DurableHandleV2Flags, QUERY_MAXIMAL_ACCESS_REQUEST_TAG, QUERY_ON_DISK_ID_TAG, REQUEST_LEASE_TAG, RequestLeaseState, SVHDX_OPEN_DEVICE_CONTEXT_TAG}; +use crate::protocol::body::create::context_helper::{ + CreateContextWrapper, create_ctx_smb_byte_size, create_ctx_smb_from_bytes, + create_ctx_smb_to_bytes, impl_tag_for_ctx, +}; +use crate::protocol::body::create::request_context::{ + DURABLE_HANDLE_REQUEST_TAG, DURABLE_HANDLE_REQUEST_V2_TAG, DurableHandleV2Flags, + QUERY_MAXIMAL_ACCESS_REQUEST_TAG, QUERY_ON_DISK_ID_TAG, REQUEST_LEASE_TAG, RequestLeaseState, + SVHDX_OPEN_DEVICE_CONTEXT_TAG, +}; use crate::protocol::body::tree_connect::access_mask::SMBFilePipePrinterAccessMask; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; const DURABLE_HANDLE_RESPONSE_TAG: &[u8] = DURABLE_HANDLE_REQUEST_TAG; const QUERY_MAXIMAL_ACCESS_RESPONSE_TAG: &[u8] = QUERY_MAXIMAL_ACCESS_REQUEST_TAG; @@ -49,7 +58,10 @@ impl SMBByteSize for CreateResponseContext { } impl SMBFromBytes for CreateResponseContext { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { trace!("parsing create response context wrapper"); let (remaining, wrapper) = CreateContextWrapper::smb_from_bytes(input)?; @@ -90,7 +102,7 @@ impl SMBFromBytes for CreateResponseContext { SVHDXOpenDeviceContext::smb_from_bytes, wrapper.data.as_slice() ), - _ => Err(SMBError::parse_error("Invalid context tag")) + _ => Err(SMBError::parse_error("Invalid context tag")), }?; Ok((remaining, context)) @@ -100,18 +112,34 @@ impl SMBFromBytes for CreateResponseContext { impl SMBToBytes for CreateResponseContext { fn smb_to_bytes(&self) -> Vec { match self { - CreateResponseContext::DurableHandleResponse(x) => create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_TAG), - CreateResponseContext::QueryMaximalAccessResponse(x) => create_ctx_smb_to_bytes!(x, QUERY_MAXIMAL_ACCESS_REQUEST_TAG), - CreateResponseContext::QueryOnDiskIDResponse(x) => create_ctx_smb_to_bytes!(x, QUERY_ON_DISK_ID_RESPONSE_TAG), - CreateResponseContext::ResponseLease(x) => create_ctx_smb_to_bytes!(x, REQUEST_LEASE_TAG), - CreateResponseContext::ResponseLeaseV2(x) => create_ctx_smb_to_bytes!(x, REQUEST_LEASE_TAG), - CreateResponseContext::DurableHandleResponseV2(x) => create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_V2_TAG), - CreateResponseContext::SVHDXOpenDeviceContext(x) => create_ctx_smb_to_bytes!(x, SVHDX_OPEN_DEVICE_CONTEXT_TAG), + CreateResponseContext::DurableHandleResponse(x) => { + create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_TAG) + } + CreateResponseContext::QueryMaximalAccessResponse(x) => { + create_ctx_smb_to_bytes!(x, QUERY_MAXIMAL_ACCESS_REQUEST_TAG) + } + CreateResponseContext::QueryOnDiskIDResponse(x) => { + create_ctx_smb_to_bytes!(x, QUERY_ON_DISK_ID_RESPONSE_TAG) + } + CreateResponseContext::ResponseLease(x) => { + create_ctx_smb_to_bytes!(x, REQUEST_LEASE_TAG) + } + CreateResponseContext::ResponseLeaseV2(x) => { + create_ctx_smb_to_bytes!(x, REQUEST_LEASE_TAG) + } + CreateResponseContext::DurableHandleResponseV2(x) => { + create_ctx_smb_to_bytes!(x, DURABLE_HANDLE_REQUEST_V2_TAG) + } + CreateResponseContext::SVHDXOpenDeviceContext(x) => { + create_ctx_smb_to_bytes!(x, SVHDX_OPEN_DEVICE_CONTEXT_TAG) + } } } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct DurableHandleResponse { #[smb_skip(start = 0, length = 8)] reserved: PhantomData>, @@ -119,7 +147,9 @@ pub struct DurableHandleResponse { reserved2: PhantomData>, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct QueryMaximalAccessResponse { #[smb_direct(start(fixed = 0))] status: NTStatus, @@ -127,7 +157,9 @@ pub struct QueryMaximalAccessResponse { maximal_access: SMBFilePipePrinterAccessMask, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct QueryOnDiskIDResponse { #[smb_direct(start(fixed = 0))] disk_file_id: u64, @@ -137,7 +169,9 @@ pub struct QueryOnDiskIDResponse { reserved: PhantomData>, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct ResponseLease { #[smb_direct(start(fixed = 0))] lease_key: [u8; 16], @@ -151,7 +185,9 @@ pub struct ResponseLease { pub type ResponseLeaseState = RequestLeaseState; -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct ResponseLeaseV2 { #[smb_direct(start(fixed = 0))] lease_key: [u8; 16], @@ -176,7 +212,9 @@ bitflags! { const PARENT_LEASE_KEY_SET = 0x4; } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct DurableHandleResponseV2 { #[smb_direct(start(fixed = 0))] timeout: u32, @@ -184,7 +222,9 @@ pub struct DurableHandleResponseV2 { flags: DurableHandleV2Flags, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct SVHDXOpenDeviceContext {} impl_smb_byte_size_for_bitflag!(ResponseLeaseFlags); @@ -192,9 +232,15 @@ impl_smb_to_bytes_for_bitflag!(ResponseLeaseFlags); impl_smb_from_bytes_for_bitflag!(ResponseLeaseFlags); impl_tag_for_ctx!(DurableHandleResponse, DURABLE_HANDLE_RESPONSE_TAG); -impl_tag_for_ctx!(QueryMaximalAccessResponse, QUERY_MAXIMAL_ACCESS_RESPONSE_TAG); +impl_tag_for_ctx!( + QueryMaximalAccessResponse, + QUERY_MAXIMAL_ACCESS_RESPONSE_TAG +); impl_tag_for_ctx!(QueryOnDiskIDResponse, QUERY_ON_DISK_ID_RESPONSE_TAG); impl_tag_for_ctx!(ResponseLease, RESPONSE_LEASE_TAG); impl_tag_for_ctx!(ResponseLeaseV2, RESPONSE_LEASE_TAG); impl_tag_for_ctx!(DurableHandleResponseV2, DURABLE_HANDLE_RESPONSE_V2_TAG); -impl_tag_for_ctx!(SVHDXOpenDeviceContext, SVHDX_OPEN_DEVICE_CONTEXT_RESPONSE_TAG); +impl_tag_for_ctx!( + SVHDXOpenDeviceContext, + SVHDX_OPEN_DEVICE_CONTEXT_RESPONSE_TAG +); diff --git a/smb/src/protocol/body/create/share_access.rs b/smb/src/protocol/body/create/share_access.rs index 8e79cfe..b8358a6 100644 --- a/smb/src/protocol/body/create/share_access.rs +++ b/smb/src/protocol/body/create/share_access.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] @@ -14,4 +16,4 @@ bitflags! { impl_smb_byte_size_for_bitflag! { SMBShareAccess } impl_smb_to_bytes_for_bitflag! { SMBShareAccess } -impl_smb_from_bytes_for_bitflag! { SMBShareAccess } \ No newline at end of file +impl_smb_from_bytes_for_bitflag! { SMBShareAccess } diff --git a/smb/src/protocol/body/dialect.rs b/smb/src/protocol/body/dialect.rs index e2ffdd5..e633902 100644 --- a/smb/src/protocol/body/dialect.rs +++ b/smb/src/protocol/body/dialect.rs @@ -4,7 +4,22 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u16)] -#[derive(Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Copy, Clone, Ord, PartialOrd, SMBFromBytes, SMBByteSize, SMBToBytes, Default)] +#[derive( + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Copy, + Clone, + Ord, + PartialOrd, + SMBFromBytes, + SMBByteSize, + SMBToBytes, + Default, +)] #[allow(non_camel_case_types)] pub enum SMBDialect { V2_0_2 = 0x202, @@ -13,7 +28,7 @@ pub enum SMBDialect { V3_0_2 = 0x302, V3_1_1 = 0x311, #[default] - V2_X_X = 0x2FF + V2_X_X = 0x2FF, } impl SMBDialect { @@ -63,4 +78,4 @@ mod tests { let (_, parsed) = SMBDialect::smb_from_bytes(&bytes).unwrap(); assert_eq!(parsed, dialect); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/echo/mod.rs b/smb/src/protocol/body/echo/mod.rs index 87ca0b4..d2f780a 100644 --- a/smb/src/protocol/body/echo/mod.rs +++ b/smb/src/protocol/body/echo/mod.rs @@ -1,4 +1,4 @@ use crate::protocol::body::empty::SMBEmpty; pub type SMBEchoRequest = SMBEmpty; -pub type SMBEchoResponse = SMBEmpty; \ No newline at end of file +pub type SMBEchoResponse = SMBEmpty; diff --git a/smb/src/protocol/body/empty.rs b/smb/src/protocol/body/empty.rs index 37624e1..7c10527 100644 --- a/smb/src/protocol/body/empty.rs +++ b/smb/src/protocol/body/empty.rs @@ -3,15 +3,7 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBFromBytes, - SMBToBytes, - SMBByteSize, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBFromBytes, SMBToBytes, SMBByteSize, Clone, )] #[smb_byte_tag(value = 4)] #[smb_skip(start = 0, length = 4)] @@ -46,4 +38,4 @@ mod tests { assert!(remaining.is_empty()); assert_eq!(parsed, empty); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/error/mod.rs b/smb/src/protocol/body/error/mod.rs index 7e9a02b..c1e8c95 100644 --- a/smb/src/protocol/body/error/mod.rs +++ b/smb/src/protocol/body/error/mod.rs @@ -14,15 +14,7 @@ use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; /// NTStatus is carried in the SMB2 header (channel_sequence field), /// not in the error body. #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 9)] pub struct SMBErrorResponse { diff --git a/smb/src/protocol/body/filetime.rs b/smb/src/protocol/body/filetime.rs index ac51b51..9ec9e4f 100644 --- a/smb/src/protocol/body/filetime.rs +++ b/smb/src/protocol/body/filetime.rs @@ -6,7 +6,18 @@ use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; use crate::byte_helper::{bytes_to_u32, bytes_to_u64, u32_to_bytes, u64_to_bytes}; -#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone, SMBFromBytes, SMBToBytes, SMBByteSize, Default)] +#[derive( + Serialize, + Deserialize, + PartialEq, + Eq, + Debug, + Clone, + SMBFromBytes, + SMBToBytes, + SMBByteSize, + Default, +)] pub struct FileTime { #[smb_direct(start(fixed = 0))] low_date_time: u32, @@ -82,7 +93,8 @@ mod tests { assert!( (back as i64 - unix_ts as i64).abs() < 2, "Unix timestamp should round-trip: got {} expected {}", - back, unix_ts + back, + unix_ts ); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/flush/mod.rs b/smb/src/protocol/body/flush/mod.rs index 01b946d..d2c982e 100644 --- a/smb/src/protocol/body/flush/mod.rs +++ b/smb/src/protocol/body/flush/mod.rs @@ -8,15 +8,7 @@ use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::empty::SMBEmpty; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 24)] pub struct SMBFlushRequest { @@ -28,4 +20,4 @@ pub struct SMBFlushRequest { file_id: SMBFileId, } -pub type SMBFlushResponse = SMBEmpty; \ No newline at end of file +pub type SMBFlushResponse = SMBEmpty; diff --git a/smb/src/protocol/body/ioctl/flags.rs b/smb/src/protocol/body/ioctl/flags.rs index 26c0fcd..9ee0a08 100644 --- a/smb/src/protocol/body/ioctl/flags.rs +++ b/smb/src/protocol/body/ioctl/flags.rs @@ -4,9 +4,21 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u32)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, TryFromPrimitive, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + Serialize, + Deserialize, + TryFromPrimitive, + SMBFromBytes, + SMBByteSize, + SMBToBytes, +)] #[allow(clippy::upper_case_acronyms)] pub enum SMBIoCtlRequestFlags { IOCTL = 0x0, FSCTL = 0x1, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/ioctl/method.rs b/smb/src/protocol/body/ioctl/method.rs index 6db0ba0..22838af 100644 --- a/smb/src/protocol/body/ioctl/method.rs +++ b/smb/src/protocol/body/ioctl/method.rs @@ -2,7 +2,18 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBEnumFromBytes, SMBFromBytes, SMBToBytes}; -#[derive(SMBEnumFromBytes, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes)] +#[derive( + SMBEnumFromBytes, + Debug, + PartialEq, + Eq, + Clone, + Copy, + Serialize, + Deserialize, + SMBByteSize, + SMBToBytes, +)] pub enum SMBIoCtlMethod { #[smb_discriminator(value = 0x00060194)] #[smb_direct(start(fixed = 0))] @@ -51,48 +62,77 @@ pub enum SMBIoCtlMethod { ValidateNegotiateInfo(ValidateNegotiateInfo), } -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct DfsGetReferrals {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct PipePeek {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct PipeWait {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct PipeTransceive {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct SrvCopyChunk {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct SrvEnumerateSnapshots {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct SrvRequestResumeKey {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct SrvReadHash {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct SrvCopyChunkWrite {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct LmrRequestResiliency {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct NetworkInterfaceInfo {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct SetReparsePoint {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct DfsGetReferralsEx {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct FileLevelTrip {} -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes)] +#[derive( + Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, SMBByteSize, SMBToBytes, SMBFromBytes, +)] pub struct ValidateNegotiateInfo {} - diff --git a/smb/src/protocol/body/ioctl/mod.rs b/smb/src/protocol/body/ioctl/mod.rs index bdd2c6f..e147769 100644 --- a/smb/src/protocol/body/ioctl/mod.rs +++ b/smb/src/protocol/body/ioctl/mod.rs @@ -12,15 +12,7 @@ mod flags; mod method; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 57)] pub struct SMBIoCtlRequest { @@ -38,20 +30,15 @@ pub struct SMBIoCtlRequest { flags: SMBIoCtlRequestFlags, #[smb_skip(start = 52, length = 4)] reserved2: PhantomData>, - #[smb_enum(start(inner(start = 24, num_type = "u32")), discriminator(inner(start = 4, num_type = "u32")))] + #[smb_enum( + start(inner(start = 24, num_type = "u32")), + discriminator(inner(start = 4, num_type = "u32")) + )] input_method: SMBIoCtlMethod, } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 49)] pub struct SMBIoCtlResponse { @@ -65,6 +52,9 @@ pub struct SMBIoCtlResponse { flags: PhantomData>, #[smb_skip(start = 44, length = 4)] reserved2: PhantomData>, - #[smb_enum(start(inner(start = 30, num_type = "u32")), discriminator(inner(start = 4, num_type = "u32")))] + #[smb_enum( + start(inner(start = 30, num_type = "u32")), + discriminator(inner(start = 4, num_type = "u32")) + )] input_method: SMBIoCtlMethod, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/lock/flags.rs b/smb/src/protocol/body/lock/flags.rs index 8462312..eb0b402 100644 --- a/smb/src/protocol/body/lock/flags.rs +++ b/smb/src/protocol/body/lock/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] @@ -15,4 +17,4 @@ bitflags! { impl_smb_from_bytes_for_bitflag!(SMBLockFlags); impl_smb_to_bytes_for_bitflag!(SMBLockFlags); -impl_smb_byte_size_for_bitflag!(SMBLockFlags); \ No newline at end of file +impl_smb_byte_size_for_bitflag!(SMBLockFlags); diff --git a/smb/src/protocol/body/lock/info.rs b/smb/src/protocol/body/lock/info.rs index e9cac3b..50b53ed 100644 --- a/smb/src/protocol/body/lock/info.rs +++ b/smb/src/protocol/body/lock/info.rs @@ -7,15 +7,7 @@ use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; use crate::protocol::body::lock::flags::SMBLockFlags; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] pub struct SMBLockInfo { #[smb_direct(start(fixed = 0))] @@ -26,4 +18,4 @@ pub struct SMBLockInfo { flags: SMBLockFlags, #[smb_skip(start = 20, length = 4)] reserved: PhantomData>, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/lock/mod.rs b/smb/src/protocol/body/lock/mod.rs index 8a0b32a..806c27f 100644 --- a/smb/src/protocol/body/lock/mod.rs +++ b/smb/src/protocol/body/lock/mod.rs @@ -6,19 +6,11 @@ use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::empty::SMBEmpty; use crate::protocol::body::lock::info::SMBLockInfo; -mod info; mod flags; +mod info; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 48)] pub struct SMBLockRequest { @@ -30,4 +22,4 @@ pub struct SMBLockRequest { locks: Vec, } -pub type SMBLockResponse = SMBEmpty; \ No newline at end of file +pub type SMBLockResponse = SMBEmpty; diff --git a/smb/src/protocol/body/logoff/mod.rs b/smb/src/protocol/body/logoff/mod.rs index 33e5fc3..71b7d8a 100644 --- a/smb/src/protocol/body/logoff/mod.rs +++ b/smb/src/protocol/body/logoff/mod.rs @@ -1,4 +1,4 @@ use crate::protocol::body::empty::SMBEmpty; pub type SMBLogoffRequest = SMBEmpty; -pub type SMBLogoffResponse = SMBEmpty; \ No newline at end of file +pub type SMBLogoffResponse = SMBEmpty; diff --git a/smb/src/protocol/body/mod.rs b/smb/src/protocol/body/mod.rs index f49f401..7db66f5 100644 --- a/smb/src/protocol/body/mod.rs +++ b/smb/src/protocol/body/mod.rs @@ -14,16 +14,16 @@ use nom::multi::many1; use nom::number::complete::le_u8; use serde::{Deserialize, Serialize}; -use smb_core::{SMBByteSize, SMBEnumFromBytes, SMBParseResult, SMBToBytes}; use smb_core::error::SMBError; +use smb_core::{SMBByteSize, SMBEnumFromBytes, SMBParseResult, SMBToBytes}; use smb_derive::{SMBByteSize, SMBEnumFromBytes, SMBToBytes}; use crate::protocol::body::cancel::SMBCancelRequest; -use crate::protocol::body::error::SMBErrorResponse; use crate::protocol::body::change_notify::{SMBChangeNotifyRequest, SMBChangeNotifyResponse}; use crate::protocol::body::close::{SMBCloseRequest, SMBCloseResponse}; use crate::protocol::body::create::{SMBCreateRequest, SMBCreateResponse}; use crate::protocol::body::echo::{SMBEchoRequest, SMBEchoResponse}; +use crate::protocol::body::error::SMBErrorResponse; use crate::protocol::body::flush::{SMBFlushRequest, SMBFlushResponse}; use crate::protocol::body::ioctl::{SMBIoCtlRequest, SMBIoCtlResponse}; use crate::protocol::body::lock::{SMBLockRequest, SMBLockResponse}; @@ -38,10 +38,10 @@ use crate::protocol::body::set_info::{SMBSetInfoRequest, SMBSetInfoResponse}; use crate::protocol::body::tree_connect::{SMBTreeConnectRequest, SMBTreeConnectResponse}; use crate::protocol::body::tree_disconnect::{SMBTreeDisconnectRequest, SMBTreeDisconnectResponse}; use crate::protocol::body::write::{SMBWriteRequest, SMBWriteResponse}; -use crate::protocol::header::command_code::{LegacySMBCommandCode, SMBCommandCode}; use crate::protocol::header::Header; use crate::protocol::header::LegacySMBHeader; use crate::protocol::header::SMBSyncHeader; +use crate::protocol::header::command_code::{LegacySMBCommandCode, SMBCommandCode}; pub mod capabilities; pub mod dialect; @@ -49,41 +49,35 @@ pub mod filetime; pub mod negotiate; pub mod session_setup; -pub mod logoff; -pub mod tree_connect; -pub mod tree_disconnect; -pub mod empty; +pub mod cancel; +pub mod change_notify; +pub mod close; pub mod create; +pub mod echo; +pub mod empty; pub mod error; -pub mod close; pub mod flush; -pub mod read; -pub mod write; +pub mod ioctl; pub mod lock; -pub mod echo; -pub mod cancel; +pub mod logoff; +pub mod oplock_break; pub mod query_directory; -pub mod change_notify; pub mod query_info; -pub mod ioctl; +pub mod read; pub mod set_info; -pub mod oplock_break; +pub mod tree_connect; +pub mod tree_disconnect; +pub mod write; pub trait Body: SMBEnumFromBytes + SMBToBytes { - fn parse_with_cc(bytes: &[u8], command_code: S::CommandCode) -> SMBParseResult<&[u8], Self> where Self: Sized; + fn parse_with_cc(bytes: &[u8], command_code: S::CommandCode) -> SMBParseResult<&[u8], Self> + where + Self: Sized; fn as_bytes(&self) -> Vec; } #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBEnumFromBytes, - SMBToBytes, - SMBByteSize, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBEnumFromBytes, SMBToBytes, SMBByteSize, Clone, )] pub enum SMBBody { #[smb_discriminator(value = 0x0)] @@ -241,26 +235,43 @@ pub enum LegacySMBBody { } impl smb_core::SMBEnumFromBytes for LegacySMBBody { - fn smb_enum_from_bytes(input: &[u8], discriminator: u64) -> SMBParseResult<&[u8], Self> where Self: Sized { - match LegacySMBCommandCode::try_from(discriminator as u8).map(|x| x == LegacySMBCommandCode::Negotiate) { + fn smb_enum_from_bytes(input: &[u8], discriminator: u64) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { + match LegacySMBCommandCode::try_from(discriminator as u8) + .map(|x| x == LegacySMBCommandCode::Negotiate) + { Ok(true) => { - let (remaining, cnt) = le_u8(input) - .map_err(|_: nom::Err>| SMBError::parse_error("Invalid count"))?; - let (_, protocol_vecs) = many1(take_till(|n: u8| n == 0x02))(remaining) - .map_err(|_: nom::Err>| SMBError::parse_error("No valid payload"))?; + let (remaining, cnt) = + le_u8(input).map_err(|_: nom::Err>| { + SMBError::parse_error("Invalid count") + })?; + let (_, protocol_vecs) = many1(take_till(|n: u8| n == 0x02))(remaining).map_err( + |_: nom::Err>| { + SMBError::parse_error("No valid payload") + }, + )?; let mut protocol_strs = Vec::new(); for slice in protocol_vecs { let mut vec = slice.to_vec(); vec.retain(|x| *x != 0); - protocol_strs.push(String::from_utf8(vec).map_err( - |_| SMBError::parse_error("Could not map protocol to string"))? + protocol_strs.push( + String::from_utf8(vec).map_err(|_| { + SMBError::parse_error("Could not map protocol to string") + })?, ); } - let (remaining, _) = take(cnt as usize)(input) - .map_err(|_: nom::Err>| SMBError::parse_error("Size too small for parse length"))?; + let (remaining, _) = take(cnt as usize)(input).map_err( + |_: nom::Err>| { + SMBError::parse_error("Size too small for parse length") + }, + )?; Ok((remaining, LegacySMBBody::Negotiate(protocol_strs))) - }, - _ => Err(SMBError::parse_error("Unknown parse error for LegacySMBBody")), + } + _ => Err(SMBError::parse_error( + "Unknown parse error for LegacySMBBody", + )), } } } @@ -275,17 +286,23 @@ impl SMBByteSize for LegacySMBBody { fn smb_byte_size(&self) -> usize { match self { LegacySMBBody::None => 0, - LegacySMBBody::Negotiate(x) => x.len() * 2 + LegacySMBBody::Negotiate(x) => x.len() * 2, } } } impl Body for LegacySMBBody { - fn parse_with_cc(bytes: &[u8], command_code: LegacySMBCommandCode) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn parse_with_cc( + bytes: &[u8], + command_code: LegacySMBCommandCode, + ) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { LegacySMBBody::smb_enum_from_bytes(bytes, command_code as u64) } fn as_bytes(&self) -> Vec { self.smb_to_bytes() } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/negotiate/context.rs b/smb/src/protocol/body/negotiate/context.rs index a780194..133dac7 100644 --- a/smb/src/protocol/body/negotiate/context.rs +++ b/smb/src/protocol/body/negotiate/context.rs @@ -6,17 +6,19 @@ use num_enum::TryFromPrimitive; use rand::RngCore; use serde::{Deserialize, Serialize}; -use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBResult, SMBToBytes}; use smb_core::error::SMBError; use smb_core::logging::trace; use smb_core::nt_status::NTStatus; +use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBResult, SMBToBytes}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; use crate::byte_helper::u16_to_bytes; -use crate::server::connection::{Connection, SMBConnection, SMBConnectionUpdate}; use crate::server::Server; +use crate::server::connection::{Connection, SMBConnection, SMBConnectionUpdate}; use crate::socket::message_stream::{SMBReadStream, SMBWriteStream}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; const PRE_AUTH_INTEGRITY_CAPABILITIES_TAG: u16 = 0x01; const ENCRYPTION_CAPABILITIES_TAG: u16 = 0x02; @@ -60,7 +62,7 @@ pub enum NegotiateContext { TransportCapabilities(TransportCapabilities), RDMATransformCapabilities(RDMATransformCapabilities), SigningCapabilities(SigningCapabilities), - PosixExtensions(PosixExtensions) + PosixExtensions(PosixExtensions), } impl SMBByteSize for NegotiateContext { @@ -73,19 +75,24 @@ impl SMBByteSize for NegotiateContext { NegotiateContext::TransportCapabilities(x) => x.smb_byte_size(), NegotiateContext::RDMATransformCapabilities(x) => x.smb_byte_size(), NegotiateContext::SigningCapabilities(x) => x.smb_byte_size(), - NegotiateContext::PosixExtensions(x) => x.smb_byte_size() + NegotiateContext::PosixExtensions(x) => x.smb_byte_size(), }) + 2 } } impl SMBFromBytes for NegotiateContext { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { - if input.len() < 4 { return Err(SMBError::parse_error("Input too small")) } + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { + if input.len() < 4 { + return Err(SMBError::parse_error("Input too small")); + } let (remaining, ctx_type) = u16::smb_from_bytes(input)?; let (_, ctx_len) = u16::smb_from_bytes(remaining)?; trace!(ctx_type, ctx_len, "parsing negotiate context"); - + match ctx_type { PRE_AUTH_INTEGRITY_CAPABILITIES_TAG => ctx_smb_from_bytes_enumify!( Self::PreAuthIntegrityCapabilities, @@ -135,7 +142,7 @@ impl SMBFromBytes for NegotiateContext { remaining, ctx_len ), - _ => Err(SMBError::parse_error("Invalid negotiate context type")) + _ => Err(SMBError::parse_error("Invalid negotiate context type")), } } } @@ -168,47 +175,76 @@ impl NegotiateContext { NegotiateContext::PosixExtensions(x) => x.byte_code(), } } - pub fn from_connection_state(connection: &SMBConnection, request_contexts: HashSet) -> Vec { + pub fn from_connection_state( + connection: &SMBConnection, + request_contexts: HashSet, + ) -> Vec { let mut response_contexts = Vec::new(); if request_contexts.contains(&PRE_AUTH_INTEGRITY_CAPABILITIES_TAG) { - response_contexts.push(Self::PreAuthIntegrityCapabilities(PreAuthIntegrityCapabilities::from_connection_state(connection))); + response_contexts.push(Self::PreAuthIntegrityCapabilities( + PreAuthIntegrityCapabilities::from_connection_state(connection), + )); } if request_contexts.contains(&ENCRYPTION_CAPABILITIES_TAG) { - response_contexts.push(Self::EncryptionCapabilities(EncryptionCapabilities::from_connection_state(connection))); + response_contexts.push(Self::EncryptionCapabilities( + EncryptionCapabilities::from_connection_state(connection), + )); } if request_contexts.contains(&COMPRESSION_CAPABILITIES_TAG) { - response_contexts.push(Self::CompressionCapabilities(CompressionCapabilities::from_connection_state(connection))); + response_contexts.push(Self::CompressionCapabilities( + CompressionCapabilities::from_connection_state(connection), + )); } if request_contexts.contains(&RDMA_TRANSFORM_CAPABILITIES_TAG) { - response_contexts.push(Self::RDMATransformCapabilities(RDMATransformCapabilities::from_connection_state(connection))); + response_contexts.push(Self::RDMATransformCapabilities( + RDMATransformCapabilities::from_connection_state(connection), + )); } if request_contexts.contains(&SIGNING_CAPABILITIES_TAG) { - response_contexts.push(Self::SigningCapabilities(SigningCapabilities::from_connection_state(connection))); + response_contexts.push(Self::SigningCapabilities( + SigningCapabilities::from_connection_state(connection), + )); } if request_contexts.contains(&TRANSPORT_CAPABILITIES_TAG) { - response_contexts.push(Self::TransportCapabilities(TransportCapabilities::from_connection_state(connection))); + response_contexts.push(Self::TransportCapabilities( + TransportCapabilities::from_connection_state(connection), + )); } if request_contexts.contains(&POSIX_EXTENSIONS_TAG) { - response_contexts.push(Self::PosixExtensions(PosixExtensions::from_connection_state(connection))); + response_contexts.push(Self::PosixExtensions( + PosixExtensions::from_connection_state(connection), + )); } response_contexts } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate, server: &S) -> SMBResult<(SMBConnectionUpdate, bool)> { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + server: &S, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { match self { - NegotiateContext::PreAuthIntegrityCapabilities(x) => x.validate_and_set_state(connection), + NegotiateContext::PreAuthIntegrityCapabilities(x) => { + x.validate_and_set_state(connection) + } NegotiateContext::EncryptionCapabilities(x) => x.validate_and_set_state(connection), - NegotiateContext::CompressionCapabilities(x) => x.validate_and_set_state(connection, server), + NegotiateContext::CompressionCapabilities(x) => { + x.validate_and_set_state(connection, server) + } NegotiateContext::NetnameNegotiateContextID(_x) => Ok((connection, false)), NegotiateContext::TransportCapabilities(x) => x.validate_and_set_state(connection), - NegotiateContext::RDMATransformCapabilities(x) => x.validate_and_set_state(connection, server), + NegotiateContext::RDMATransformCapabilities(x) => { + x.validate_and_set_state(connection, server) + } NegotiateContext::SigningCapabilities(x) => x.validate_and_set_state(connection), NegotiateContext::PosixExtensions(x) => x.validate_and_set_state(connection), } } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct PreAuthIntegrityCapabilities { #[smb_skip(start = 0, length = 10)] reserved: PhantomData>, @@ -220,7 +256,19 @@ pub struct PreAuthIntegrityCapabilities { #[repr(u16)] #[derive( -Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Copy, Clone, Ord, PartialOrd, SMBFromBytes, SMBByteSize, SMBToBytes + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Copy, + Clone, + Ord, + PartialOrd, + SMBFromBytes, + SMBByteSize, + SMBToBytes, )] pub enum HashAlgorithm { SHA512 = 0x01, @@ -231,7 +279,9 @@ impl PreAuthIntegrityCapabilities { PRE_AUTH_INTEGRITY_CAPABILITIES_TAG } - fn from_connection_state(connection: &SMBConnection) -> Self { + fn from_connection_state( + connection: &SMBConnection, + ) -> Self { let mut salt = vec![0_u8; 32]; rand::rngs::ThreadRng::default().fill_bytes(&mut salt); Self { @@ -241,18 +291,26 @@ impl PreAuthIntegrityCapabilities { } } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate) -> SMBResult<(SMBConnectionUpdate, bool)> { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { if let Some(algorithm) = self.hash_algorithms.first() { - Ok((connection + Ok(( + connection .preauth_integrity_hash_id(*algorithm) - .preauth_integrity_hash_value(Vec::new()), true)) + .preauth_integrity_hash_value(Vec::new()), + true, + )) } else { Err(SMBError::response_error(NTStatus::InvalidParameter)) } } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct EncryptionCapabilities { #[smb_skip(start = 0, length = 8)] reserved: PhantomData>, @@ -262,7 +320,19 @@ pub struct EncryptionCapabilities { #[repr(u16)] #[derive( -Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Clone, Ord, PartialOrd, Copy, SMBFromBytes, SMBByteSize, SMBToBytes + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Clone, + Ord, + PartialOrd, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, )] pub enum EncryptionCipher { None = 0x0, @@ -277,13 +347,18 @@ impl EncryptionCapabilities { ENCRYPTION_CAPABILITIES_TAG } - fn from_connection_state(connection: &SMBConnection) -> Self { + fn from_connection_state( + connection: &SMBConnection, + ) -> Self { Self { reserved: Default::default(), ciphers: vec![connection.cipher_id()], } } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate) -> SMBResult<(SMBConnectionUpdate, bool)> { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { let mut ciphers = self.ciphers.clone(); ciphers.sort(); ciphers.reverse(); @@ -295,7 +370,9 @@ impl EncryptionCapabilities { } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct CompressionCapabilities { #[smb_direct(start(fixed = 10))] pub(crate) flags: CompressionCapabilitiesFlags, @@ -305,7 +382,19 @@ pub struct CompressionCapabilities { #[repr(u32)] #[derive( -Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Clone, Ord, PartialOrd, Copy, SMBFromBytes, SMBByteSize, SMBToBytes + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Clone, + Ord, + PartialOrd, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, )] pub enum CompressionCapabilitiesFlags { None = 0x0, @@ -314,7 +403,19 @@ pub enum CompressionCapabilitiesFlags { #[repr(u16)] #[derive( -Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Clone, Ord, PartialOrd, Copy, SMBFromBytes, SMBByteSize, SMBToBytes + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Clone, + Ord, + PartialOrd, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, )] pub enum CompressionAlgorithm { None = 0x0, @@ -329,7 +430,9 @@ impl CompressionCapabilities { COMPRESSION_CAPABILITIES_TAG } - fn from_connection_state(connection: &SMBConnection) -> Self { + fn from_connection_state( + connection: &SMBConnection, + ) -> Self { let flags = if connection.supports_chained_compression() { CompressionCapabilitiesFlags::Chained } else { @@ -340,22 +443,35 @@ impl CompressionCapabilities { compression_algorithms: connection.compression_ids().clone(), } } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate, server: &S) -> SMBResult<(SMBConnectionUpdate, bool)> { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + server: &S, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { if !server.compression_supported() { - return Ok((connection, false)) + return Ok((connection, false)); } if self.compression_algorithms.is_empty() { return Err(SMBError::response_error(NTStatus::InvalidParameter)); } - Ok((connection.compression_ids(self.compression_algorithms.clone()), true)) + Ok(( + connection.compression_ids(self.compression_algorithms.clone()), + true, + )) } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct NetnameNegotiateContextID { #[smb_skip(start = 0, length = 6)] reserved: PhantomData>, - #[smb_string(order = 1, length(inner(start = 0, num_type = "u16")), underlying = "u16")] + #[smb_string( + order = 1, + length(inner(start = 0, num_type = "u16")), + underlying = "u16" + )] pub(crate) netname: String, } @@ -365,7 +481,21 @@ impl NetnameNegotiateContextID { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Deserialize, + Serialize, + SMBFromBytes, + SMBByteSize, + SMBToBytes, +)] pub struct TransportCapabilities { #[smb_direct(start = 0, length = 4)] pub(crate) flags: TransportCapabilitiesFlags, @@ -387,18 +517,24 @@ impl TransportCapabilities { TRANSPORT_CAPABILITIES_TAG } - fn from_connection_state(connection: &SMBConnection) -> Self { + fn from_connection_state( + connection: &SMBConnection, + ) -> Self { let flags = if connection.supports_chained_compression() { TransportCapabilitiesFlags::ACCEPT_TRANSPORT_LEVEL_SECURITY } else { TransportCapabilitiesFlags::empty() }; - Self { - flags, - } + Self { flags } } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate) -> SMBResult<(SMBConnectionUpdate, bool)> { - if self.flags.contains(TransportCapabilitiesFlags::ACCEPT_TRANSPORT_LEVEL_SECURITY) { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { + if self + .flags + .contains(TransportCapabilitiesFlags::ACCEPT_TRANSPORT_LEVEL_SECURITY) + { Ok((connection.accept_transport_security(true), true)) } else { Ok((connection, true)) @@ -406,7 +542,9 @@ impl TransportCapabilities { } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct RDMATransformCapabilities { #[smb_skip(start = 0, length = 14)] reserved: PhantomData>, @@ -415,7 +553,19 @@ pub struct RDMATransformCapabilities { } #[repr(u16)] -#[derive(Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Clone, Copy, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Clone, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, +)] pub enum RDMATransformID { None = 0x0, Encryption, @@ -426,7 +576,9 @@ impl RDMATransformCapabilities { fn byte_code(&self) -> u16 { RDMA_TRANSFORM_CAPABILITIES_TAG } - fn from_connection_state(connection: &SMBConnection) -> Self { + fn from_connection_state( + connection: &SMBConnection, + ) -> Self { let transform_ids = if connection.rdma_transform_ids().is_empty() { vec![RDMATransformID::None] } else { @@ -437,18 +589,27 @@ impl RDMATransformCapabilities { transform_ids, } } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate, server: &S) -> SMBResult<(SMBConnectionUpdate, bool)> { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + server: &S, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { if !server.rdma_transform_supported() { - return Ok((connection, false)) + return Ok((connection, false)); } if self.transform_ids.is_empty() { return Err(SMBError::response_error(NTStatus::InvalidParameter)); } - Ok((connection.rdma_transform_ids(self.transform_ids.clone()), true)) + Ok(( + connection.rdma_transform_ids(self.transform_ids.clone()), + true, + )) } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct SigningCapabilities { #[smb_skip(start = 0, length = 8)] reserved: PhantomData>, @@ -458,7 +619,19 @@ pub struct SigningCapabilities { #[repr(u16)] #[derive( -Debug, Eq, PartialEq, TryFromPrimitive, Serialize, Deserialize, Clone, Ord, PartialOrd, Copy, SMBFromBytes, SMBByteSize, SMBToBytes + Debug, + Eq, + PartialEq, + TryFromPrimitive, + Serialize, + Deserialize, + Clone, + Ord, + PartialOrd, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, )] pub enum SigningAlgorithm { HmacSha256 = 0x0, @@ -471,23 +644,33 @@ impl SigningCapabilities { SIGNING_CAPABILITIES_TAG } - fn from_connection_state(connection: &SMBConnection) -> Self { + fn from_connection_state( + connection: &SMBConnection, + ) -> Self { Self { reserved: PhantomData, signing_algorithms: vec![connection.signing_algorithm_id()], } } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate) -> SMBResult<(SMBConnectionUpdate, bool)> { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { if self.signing_algorithms.is_empty() { return Err(SMBError::response_error(NTStatus::InvalidParameter)); } let mut algorithms = self.signing_algorithms.clone(); algorithms.sort(); - Ok((connection.signing_algorithm_id(*algorithms.first().unwrap()), true)) + Ok(( + connection.signing_algorithm_id(*algorithms.first().unwrap()), + true, + )) } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBToBytes, SMBByteSize)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBToBytes, SMBByteSize, +)] pub struct PosixExtensions { #[smb_skip(start = 0, length = 6)] reserved: PhantomData>, @@ -496,19 +679,26 @@ pub struct PosixExtensions { } impl PosixExtensions { - fn byte_code(&self) -> u16 { POSIX_EXTENSIONS_TAG } + fn byte_code(&self) -> u16 { + POSIX_EXTENSIONS_TAG + } - pub fn validate_and_set_state(&self, connection: SMBConnectionUpdate) -> SMBResult<(SMBConnectionUpdate, bool)> { + pub fn validate_and_set_state( + &self, + connection: SMBConnectionUpdate, + ) -> SMBResult<(SMBConnectionUpdate, bool)> { if self.posix_reserved.is_empty() { return Err(SMBError::response_error(NTStatus::InvalidParameter)); } let connection = connection.posix_extension_payload(self.posix_reserved.clone()); Ok((connection, true)) } - fn from_connection_state(connection: &SMBConnection) -> Self { + fn from_connection_state( + connection: &SMBConnection, + ) -> Self { Self { reserved: PhantomData, posix_reserved: connection.posix_extension_payload().to_vec(), } } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/negotiate/mod.rs b/smb/src/protocol/body/negotiate/mod.rs index e0ad207..c1ceef8 100644 --- a/smb/src/protocol/body/negotiate/mod.rs +++ b/smb/src/protocol/body/negotiate/mod.rs @@ -6,9 +6,9 @@ use serde::{Deserialize, Serialize}; use sha2::Sha512; use uuid::Uuid; -use smb_core::{SMBResult, SMBToBytes}; use smb_core::error::SMBError; use smb_core::nt_status::NTStatus; +use smb_core::{SMBResult, SMBToBytes}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; use crate::protocol::body::capabilities::Capabilities; @@ -16,8 +16,8 @@ use crate::protocol::body::dialect::SMBDialect; use crate::protocol::body::filetime::FileTime; use crate::protocol::body::negotiate::context::NegotiateContext; use crate::protocol::body::negotiate::security_mode::NegotiateSecurityMode; -use crate::server::connection::{Connection, SMBConnection, SMBConnectionUpdate}; use crate::server::Server; +use crate::server::connection::{Connection, SMBConnection, SMBConnectionUpdate}; use crate::socket::message_stream::{SMBReadStream, SMBWriteStream}; use crate::util::auth::AuthProvider; use crate::util::auth::spnego::{SPNEGOToken, SPNEGOTokenInitBody}; @@ -26,15 +26,7 @@ pub mod context; pub mod security_mode; #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBFromBytes, - SMBByteSize, - SMBToBytes, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBFromBytes, SMBByteSize, SMBToBytes, Clone, )] #[smb_byte_tag(value = 36)] pub struct SMBNegotiateRequest { @@ -48,12 +40,21 @@ pub struct SMBNegotiateRequest { reserved: PhantomData>, #[smb_vector(order = 1, count(inner(start = 2, num_type = "u16")))] pub(crate) dialects: Vec, - #[smb_vector(order = 2, align = 8, count(inner(start = 32, num_type = "u16")), offset(inner(start = 28, num_type = "u32", subtract = 64)))] + #[smb_vector( + order = 2, + align = 8, + count(inner(start = 32, num_type = "u16")), + offset(inner(start = 28, num_type = "u32", subtract = 64)) + )] negotiate_contexts: Vec, } impl SMBNegotiateRequest { - pub fn validate_and_set_state(&self, connection: &SMBConnection, server: &S) -> SMBResult<(SMBConnectionUpdate, HashSet)> { + pub fn validate_and_set_state( + &self, + connection: &SMBConnection, + server: &S, + ) -> SMBResult<(SMBConnectionUpdate, HashSet)> { if connection.negotiate_dialect() != SMBDialect::default() { return Err(SMBError::response_error(NTStatus::AccessDenied)); } @@ -94,7 +95,10 @@ impl SMBNegotiateRequest { if self.capabilities.contains(Capabilities::PERSISTENT_HANDLES) { capabilities |= Capabilities::PERSISTENT_HANDLES; } - if connection.dialect() != SMBDialect::V3_1_1 && server.encryption_supported() && capabilities.contains(Capabilities::ENCRYPTION) { + if connection.dialect() != SMBDialect::V3_1_1 + && server.encryption_supported() + && capabilities.contains(Capabilities::ENCRYPTION) + { capabilities |= Capabilities::ENCRYPTION; } } @@ -114,7 +118,10 @@ impl SMBNegotiateRequest { .client_dialects(dialects) .client_capabilities(self.capabilities) .client_guid(self.client_uuid) - .should_sign(self.security_mode.contains(NegotiateSecurityMode::NEGOTIATE_SIGNING_REQUIRED)) + .should_sign( + self.security_mode + .contains(NegotiateSecurityMode::NEGOTIATE_SIGNING_REQUIRED), + ) .server_capabilites(capabilities) .max_read_size(8388608) .max_write_size(8388608) @@ -126,15 +133,7 @@ impl SMBNegotiateRequest { } #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBToBytes, - SMBByteSize, - SMBFromBytes, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBToBytes, SMBByteSize, SMBFromBytes, Clone, )] #[smb_byte_tag(value = 65)] pub struct SMBNegotiateResponse { @@ -159,10 +158,15 @@ pub struct SMBNegotiateResponse { #[smb_buffer( offset(inner(start = 56, num_type = "u16", subtract = 64, min_val = 128)), length(inner(start = 58, num_type = "u16")), - order = 1) - ] + order = 1 + )] buffer: Vec, - #[smb_vector(order = 2, align = 8, count(inner(start = 6, num_type = "u16")), offset(inner(start = 60, num_type = "u32", subtract = 64)))] + #[smb_vector( + order = 2, + align = 8, + count(inner(start = 6, num_type = "u16")), + offset(inner(start = 60, num_type = "u32", subtract = 64)) + )] negotiate_contexts: Vec, } @@ -183,9 +187,19 @@ impl SMBNegotiateResponse { } } - pub fn from_connection_state(connection: &SMBConnection, server: &S, negotiate_contexts: HashSet) -> Self { + pub fn from_connection_state< + A: AuthProvider, + R: SMBReadStream, + W: SMBWriteStream, + S: Server, + >( + connection: &SMBConnection, + server: &S, + negotiate_contexts: HashSet, + ) -> Self { let buffer = SPNEGOToken::Init(SPNEGOTokenInitBody::::new()).as_bytes(true); - let negotiate_contexts = NegotiateContext::from_connection_state(connection, negotiate_contexts); + let negotiate_contexts = + NegotiateContext::from_connection_state(connection, negotiate_contexts); Self { security_mode: connection.server_security_mode(), dialect: connection.dialect(), @@ -201,4 +215,4 @@ impl SMBNegotiateResponse { negotiate_contexts, } } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/negotiate/security_mode.rs b/smb/src/protocol/body/negotiate/security_mode.rs index cdd4919..435c2ad 100644 --- a/smb/src/protocol/body/negotiate/security_mode.rs +++ b/smb/src/protocol/body/negotiate/security_mode.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] @@ -22,8 +24,14 @@ mod tests { #[test] fn security_mode_values() { - assert_eq!(NegotiateSecurityMode::NEGOTIATE_SIGNING_ENABLED.bits(), 0x0001); - assert_eq!(NegotiateSecurityMode::NEGOTIATE_SIGNING_REQUIRED.bits(), 0x0002); + assert_eq!( + NegotiateSecurityMode::NEGOTIATE_SIGNING_ENABLED.bits(), + 0x0001 + ); + assert_eq!( + NegotiateSecurityMode::NEGOTIATE_SIGNING_REQUIRED.bits(), + 0x0002 + ); } #[test] @@ -35,4 +43,4 @@ mod tests { let (_, parsed) = NegotiateSecurityMode::smb_from_bytes(&bytes).unwrap(); assert_eq!(parsed, mode); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/oplock_break/mod.rs b/smb/src/protocol/body/oplock_break/mod.rs index f6f4bb9..ee30efb 100644 --- a/smb/src/protocol/body/oplock_break/mod.rs +++ b/smb/src/protocol/body/oplock_break/mod.rs @@ -10,15 +10,7 @@ use crate::protocol::body::oplock_break::oplock_level::SMBOplockLevel; mod oplock_level; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 24)] pub struct SMBOplockBreakContent { @@ -32,4 +24,4 @@ pub struct SMBOplockBreakContent { file_id: SMBFileId, } -pub type SMBOplockBreakAcknowledgement = SMBOplockBreakContent; \ No newline at end of file +pub type SMBOplockBreakAcknowledgement = SMBOplockBreakContent; diff --git a/smb/src/protocol/body/oplock_break/oplock_level.rs b/smb/src/protocol/body/oplock_break/oplock_level.rs index 965765d..9974730 100644 --- a/smb/src/protocol/body/oplock_break/oplock_level.rs +++ b/smb/src/protocol/body/oplock_break/oplock_level.rs @@ -4,9 +4,21 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, TryFromPrimitive, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + Serialize, + Deserialize, + TryFromPrimitive, + SMBFromBytes, + SMBByteSize, + SMBToBytes, +)] pub enum SMBOplockLevel { None = 0x0, II = 0x1, Exclusive = 0x8, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/query_directory/flags.rs b/smb/src/protocol/body/query_directory/flags.rs index 7d62481..b8011e3 100644 --- a/smb/src/protocol/body/query_directory/flags.rs +++ b/smb/src/protocol/body/query_directory/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] @@ -15,4 +17,4 @@ bitflags! { impl_smb_from_bytes_for_bitflag!(SMBQueryDirectoryFlags); impl_smb_to_bytes_for_bitflag!(SMBQueryDirectoryFlags); -impl_smb_byte_size_for_bitflag!(SMBQueryDirectoryFlags); \ No newline at end of file +impl_smb_byte_size_for_bitflag!(SMBQueryDirectoryFlags); diff --git a/smb/src/protocol/body/query_directory/information_class.rs b/smb/src/protocol/body/query_directory/information_class.rs index 2805817..c11ded4 100644 --- a/smb/src/protocol/body/query_directory/information_class.rs +++ b/smb/src/protocol/body/query_directory/information_class.rs @@ -5,7 +5,21 @@ use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; // TODO this needs to be a discrim for an enum based type for here and for QueryInfo #[repr(u8)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, SMBFromBytes, SMBToBytes, SMBByteSize, TryFromPrimitive, Serialize, Deserialize)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + SMBFromBytes, + SMBToBytes, + SMBByteSize, + TryFromPrimitive, + Serialize, + Deserialize, +)] pub enum SMBInformationClass { FileDirectoryInformation = 0x1, FullFileificateInformation = 0x2, @@ -16,4 +30,4 @@ pub enum SMBInformationClass { FileIdExtdDirectoryInformation = 0x3C, // Must never be used and ignored on receipt FileInformationClassReserved = 0x64, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/query_directory/mod.rs b/smb/src/protocol/body/query_directory/mod.rs index db60085..9281aac 100644 --- a/smb/src/protocol/body/query_directory/mod.rs +++ b/smb/src/protocol/body/query_directory/mod.rs @@ -8,19 +8,11 @@ use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::query_directory::flags::SMBQueryDirectoryFlags; use crate::protocol::body::query_directory::information_class::SMBInformationClass; -mod information_class; mod flags; +mod information_class; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 33)] pub struct SMBQueryDirectoryRequest { @@ -34,26 +26,26 @@ pub struct SMBQueryDirectoryRequest { file_id: SMBFileId, #[smb_direct(start(fixed = 28))] max_output_len: u32, - #[smb_string(order = 0, start(inner(start = 24, num_type = "u16", subtract = 64)), length(inner(start = 26, num_type = "u16")), underlying = "u16")] + #[smb_string( + order = 0, + start(inner(start = 24, num_type = "u16", subtract = 64)), + length(inner(start = 26, num_type = "u16")), + underlying = "u16" + )] search_pattern: String, } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 9)] pub struct SMBQueryDirectoryResponse { #[smb_skip(start = 0, length = 8)] output_info: PhantomData>, // TODO make this a file directory class https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/4f75351b-048c-4a0c-9ea3-addd55a71956 - #[smb_buffer(offset(inner(start = 2, num_type = "u16", subtract = 64)), length(inner(start = 4, num_type = "u32")))] + #[smb_buffer( + offset(inner(start = 2, num_type = "u16", subtract = 64)), + length(inner(start = 4, num_type = "u32")) + )] buffer: Vec, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/query_info/flags.rs b/smb/src/protocol/body/query_info/flags.rs index e74beb6..f8effec 100644 --- a/smb/src/protocol/body/query_info/flags.rs +++ b/smb/src/protocol/body/query_info/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] diff --git a/smb/src/protocol/body/query_info/info_type.rs b/smb/src/protocol/body/query_info/info_type.rs index 3550893..a23d556 100644 --- a/smb/src/protocol/body/query_info/info_type.rs +++ b/smb/src/protocol/body/query_info/info_type.rs @@ -4,10 +4,24 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u8)] -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, TryFromPrimitive, SMBToBytes, SMBFromBytes, SMBByteSize, Serialize, Deserialize)] +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + PartialOrd, + Ord, + TryFromPrimitive, + SMBToBytes, + SMBFromBytes, + SMBByteSize, + Serialize, + Deserialize, +)] pub enum SMBInfoType { File, Filesystem, Security, Quota, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/query_info/mod.rs b/smb/src/protocol/body/query_info/mod.rs index b4dd070..b66b2af 100644 --- a/smb/src/protocol/body/query_info/mod.rs +++ b/smb/src/protocol/body/query_info/mod.rs @@ -14,15 +14,7 @@ mod info_type; mod security_information; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 41)] pub struct SMBQueryInfoRequest { @@ -40,26 +32,25 @@ pub struct SMBQueryInfoRequest { flags: SMBQueryInfoFlags, #[smb_direct(start(fixed = 24))] file_id: SMBFileId, - #[smb_buffer(offset(inner(start = 8, num_type = "u16", subtract = 64)), length(inner(start = 12, num_type = "u32")))] + #[smb_buffer( + offset(inner(start = 8, num_type = "u16", subtract = 64)), + length(inner(start = 12, num_type = "u32")) + )] buffer: Vec, } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 17)] pub struct SMBQueryInfoResponse { #[smb_skip(start = 2, length = 6)] reserved: PhantomData>, // TODO make this a struct: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/3b1b3598-a898-44ca-bfac-2dcae065247f - #[smb_buffer(order = 0, offset(inner(start = 2, num_type = "u16", subtract = 64)), length(inner(start = 4, num_type = "u32")))] + #[smb_buffer( + order = 0, + offset(inner(start = 2, num_type = "u16", subtract = 64)), + length(inner(start = 4, num_type = "u32")) + )] data: Vec, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/query_info/security_information.rs b/smb/src/protocol/body/query_info/security_information.rs index 1f2a26e..a1b92a4 100644 --- a/smb/src/protocol/body/query_info/security_information.rs +++ b/smb/src/protocol/body/query_info/security_information.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] diff --git a/smb/src/protocol/body/read/channel.rs b/smb/src/protocol/body/read/channel.rs index 12939ac..1ac000f 100644 --- a/smb/src/protocol/body/read/channel.rs +++ b/smb/src/protocol/body/read/channel.rs @@ -4,9 +4,21 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u32)] -#[derive(Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, TryFromPrimitive, Copy, Clone)] +#[derive( + Debug, + PartialEq, + Eq, + SMBByteSize, + SMBToBytes, + SMBFromBytes, + Serialize, + Deserialize, + TryFromPrimitive, + Copy, + Clone, +)] pub enum SMBRWChannel { None = 0x0, RdmaV1 = 0x1, RdmaV1Invalidate = 0x2, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/read/flags.rs b/smb/src/protocol/body/read/flags.rs index 16a994a..d2c1c25 100644 --- a/smb/src/protocol/body/read/flags.rs +++ b/smb/src/protocol/body/read/flags.rs @@ -4,7 +4,9 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] @@ -15,7 +17,19 @@ bitflags! { } #[repr(u32)] -#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, SMBToBytes, SMBFromBytes, SMBByteSize, TryFromPrimitive, Copy, Clone)] +#[derive( + Debug, + PartialEq, + Eq, + Serialize, + Deserialize, + SMBToBytes, + SMBFromBytes, + SMBByteSize, + TryFromPrimitive, + Copy, + Clone, +)] pub enum SMBReadResponseFlags { None = 0x0, RdmaTransform = 0x01, diff --git a/smb/src/protocol/body/read/mod.rs b/smb/src/protocol/body/read/mod.rs index 2e9a64d..14ea887 100644 --- a/smb/src/protocol/body/read/mod.rs +++ b/smb/src/protocol/body/read/mod.rs @@ -8,19 +8,11 @@ use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::read::channel::SMBRWChannel; use crate::protocol::body::read::flags::{SMBReadRequestFlags, SMBReadResponseFlags}; -mod flags; pub mod channel; +mod flags; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 49)] pub struct SMBReadRequest { @@ -38,20 +30,15 @@ pub struct SMBReadRequest { channel: SMBRWChannel, #[smb_direct(start(fixed = 40))] remaining_bytes: u32, - #[smb_buffer(offset(inner(start = 44, num_type = "u16", subtract = 64)), length(inner(start = 46, num_type = "u16")))] + #[smb_buffer( + offset(inner(start = 44, num_type = "u16", subtract = 64)), + length(inner(start = 46, num_type = "u16")) + )] channel_information: Vec, } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 17)] pub struct SMBReadResponse { @@ -61,6 +48,10 @@ pub struct SMBReadResponse { data_remaining: u32, #[smb_direct(start(fixed = 12))] flags: SMBReadResponseFlags, - #[smb_buffer(order = 0, offset(inner(start = 2, num_type = "u8", subtract = 64)), length(inner(start = 4, num_type = "u32")))] + #[smb_buffer( + order = 0, + offset(inner(start = 2, num_type = "u8", subtract = 64)), + length(inner(start = 4, num_type = "u32")) + )] data: Vec, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/session_setup/flags.rs b/smb/src/protocol/body/session_setup/flags.rs index 8aa084b..fd49b56 100644 --- a/smb/src/protocol/body/session_setup/flags.rs +++ b/smb/src/protocol/body/session_setup/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] @@ -21,4 +23,4 @@ bitflags! { impl_smb_byte_size_for_bitflag! {SMBSessionSetupFlags SMBSessionFlags} impl_smb_from_bytes_for_bitflag! {SMBSessionSetupFlags SMBSessionFlags} -impl_smb_to_bytes_for_bitflag! {SMBSessionSetupFlags SMBSessionFlags} \ No newline at end of file +impl_smb_to_bytes_for_bitflag! {SMBSessionSetupFlags SMBSessionFlags} diff --git a/smb/src/protocol/body/session_setup/mod.rs b/smb/src/protocol/body/session_setup/mod.rs index 253858b..1832107 100644 --- a/smb/src/protocol/body/session_setup/mod.rs +++ b/smb/src/protocol/body/session_setup/mod.rs @@ -4,36 +4,28 @@ use digest::Digest; use serde::{Deserialize, Serialize}; use sha2::Sha512; -use smb_core::{SMBResult, SMBToBytes}; use smb_core::error::SMBError; use smb_core::nt_status::NTStatus; +use smb_core::{SMBResult, SMBToBytes}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; use crate::protocol::body::capabilities::Capabilities; use crate::protocol::body::dialect::SMBDialect; use crate::protocol::body::session_setup::flags::{SMBSessionFlags, SMBSessionSetupFlags}; use crate::protocol::body::session_setup::security_mode::SessionSetupSecurityMode; -use crate::protocol::header::flags::SMBFlags; use crate::protocol::header::SMBSyncHeader; +use crate::protocol::header::flags::SMBFlags; +use crate::server::Server; use crate::server::connection::{Connection, SMBConnection, SMBConnectionUpdate}; use crate::server::preauth_session::SMBPreauthSession; -use crate::server::Server; use crate::server::session::{Session, SessionState}; use crate::socket::message_stream::{SMBReadStream, SMBWriteStream}; -pub mod security_mode; pub mod flags; +pub mod security_mode; #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBFromBytes, - SMBByteSize, - SMBToBytes, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBFromBytes, SMBByteSize, SMBToBytes, Clone, )] #[smb_byte_tag(value = 25)] pub struct SMBSessionSetupRequest { @@ -45,12 +37,21 @@ pub struct SMBSessionSetupRequest { capabilities: Capabilities, #[smb_direct(start(fixed = 16))] previous_session_id: u64, - #[smb_buffer(offset(inner(start = 12, num_type = "u16", subtract = 64)), length(inner(start = 14, num_type = "u16")))] + #[smb_buffer( + offset(inner(start = 12, num_type = "u16", subtract = 64)), + length(inner(start = 14, num_type = "u16")) + )] buffer: Vec, } impl SMBSessionSetupRequest { - pub fn new(flags: SMBSessionSetupFlags, security_mode: SessionSetupSecurityMode, capabilities: Capabilities, previous_session_id: u64, buffer: Vec) -> Self { + pub fn new( + flags: SMBSessionSetupFlags, + security_mode: SessionSetupSecurityMode, + capabilities: Capabilities, + previous_session_id: u64, + buffer: Vec, + ) -> Self { Self { flags, security_mode, @@ -65,19 +66,37 @@ impl SMBSessionSetupRequest { pub fn flags(&self) -> SMBSessionSetupFlags { self.flags } - pub async fn validate_and_set_state>>(&self, connection: &SMBConnection, server: &S, session: &S::Session, header: &SMBSyncHeader) -> SMBResult> { + pub async fn validate_and_set_state< + R: SMBReadStream, + W: SMBWriteStream, + S: Server>, + >( + &self, + connection: &SMBConnection, + server: &S, + session: &S::Session, + header: &SMBSyncHeader, + ) -> SMBResult> { let mut update = SMBConnectionUpdate::default(); - if server.encrypt_data() && (!server.unencrypted_access() - && (connection.dialect().is_smb3() - || !connection.client_capabilities().contains(Capabilities::ENCRYPTION))) { + if server.encrypt_data() + && (!server.unencrypted_access() + && (connection.dialect().is_smb3() + || !connection + .client_capabilities() + .contains(Capabilities::ENCRYPTION))) + { return Err(SMBError::response_error(NTStatus::AccessDenied)); } - if connection.dialect().is_smb3() && server.multi_channel_capable() && self.flags.contains(SMBSessionSetupFlags::BINDING) { + if connection.dialect().is_smb3() + && server.multi_channel_capable() + && self.flags.contains(SMBSessionSetupFlags::BINDING) + { let locked_conn = session.connection_res()?; let session_conn = locked_conn.read().await; - if session_conn.dialect() != connection.dialect() || - header.flags.contains(SMBFlags::SIGNED) { + if session_conn.dialect() != connection.dialect() + || header.flags.contains(SMBFlags::SIGNED) + { return Err(SMBError::response_error(NTStatus::InvalidParameter)); } if session_conn.client_guid() != connection.client_guid() { @@ -88,19 +107,22 @@ impl SMBSessionSetupRequest { } if session.state() == SessionState::Expired { - return Err(SMBError::response_error(NTStatus::NetworkSessionExpired)) + return Err(SMBError::response_error(NTStatus::NetworkSessionExpired)); } if session.anonymous() || session.guest() { return Err(SMBError::response_error(NTStatus::NotSupported)); } - if connection.dialect() == SMBDialect::V3_1_1 && !connection.preauth_sessions().contains_key(&session.id()) { + if connection.dialect() == SMBDialect::V3_1_1 + && !connection.preauth_sessions().contains_key(&session.id()) + { let mut sha = Sha512::default(); sha.update(connection.preauth_integtiry_hash_value()); sha.update(self.smb_to_bytes()); let bytes = sha.finalize().to_vec(); let preauth_session = SMBPreauthSession::new(session.id(), bytes); - update = update.preauth_session_table(HashMap::from([(session.id(), preauth_session)])); + update = + update.preauth_session_table(HashMap::from([(session.id(), preauth_session)])); } } Ok(update) @@ -108,21 +130,16 @@ impl SMBSessionSetupRequest { } #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBToBytes, - SMBFromBytes, - SMBByteSize, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBToBytes, SMBFromBytes, SMBByteSize, Clone, )] #[smb_byte_tag(value = 9)] pub struct SMBSessionSetupResponse { #[smb_direct(start(fixed = 2))] session_flags: SMBSessionFlags, - #[smb_buffer(offset(inner(start = 4, num_type = "u16", subtract = 64, min_val = 72)), length(inner(start = 6, num_type = "u16")))] + #[smb_buffer( + offset(inner(start = 4, num_type = "u16", subtract = 64, min_val = 72)), + length(inner(start = 6, num_type = "u16")) + )] buffer: Vec, } @@ -157,4 +174,4 @@ impl SMBSessionSetupResponse { buffer: token, }) } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/session_setup/security_mode.rs b/smb/src/protocol/body/session_setup/security_mode.rs index 8b265ca..1558bd9 100644 --- a/smb/src/protocol/body/session_setup/security_mode.rs +++ b/smb/src/protocol/body/session_setup/security_mode.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] diff --git a/smb/src/protocol/body/set_info/info_type.rs b/smb/src/protocol/body/set_info/info_type.rs index f51ce79..79738e7 100644 --- a/smb/src/protocol/body/set_info/info_type.rs +++ b/smb/src/protocol/body/set_info/info_type.rs @@ -4,10 +4,22 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, TryFromPrimitive, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, + Clone, + Copy, + PartialEq, + Eq, + Serialize, + Deserialize, + TryFromPrimitive, + SMBFromBytes, + SMBByteSize, + SMBToBytes, +)] pub enum SMBInfoType { File = 0x1, FileSystem = 0x2, Security = 0x3, Quota = 0x4, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/set_info/mod.rs b/smb/src/protocol/body/set_info/mod.rs index 7b09a3f..5d6f147 100644 --- a/smb/src/protocol/body/set_info/mod.rs +++ b/smb/src/protocol/body/set_info/mod.rs @@ -10,15 +10,7 @@ use crate::protocol::body::set_info::info_type::SMBInfoType; mod info_type; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 33)] pub struct SMBSetInfoRequest { @@ -30,20 +22,15 @@ pub struct SMBSetInfoRequest { additional_information: u32, #[smb_direct(start(fixed = 16))] file_id: SMBFileId, - #[smb_buffer(offset(inner(start = 8, num_type = "u16", subtract = 64)), length(inner(start = 4, num_type = "u32")))] + #[smb_buffer( + offset(inner(start = 8, num_type = "u16", subtract = 64)), + length(inner(start = 4, num_type = "u32")) + )] buffer: Vec, } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 2)] pub struct SMBSetInfoResponse { @@ -51,4 +38,4 @@ pub struct SMBSetInfoResponse { reserved: PhantomData>, #[smb_skip(start = 1, length = 1)] reserved2: PhantomData>, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/tree_connect/access_mask.rs b/smb/src/protocol/body/tree_connect/access_mask.rs index 1c09d99..b1a5b2a 100644 --- a/smb/src/protocol/body/tree_connect/access_mask.rs +++ b/smb/src/protocol/body/tree_connect/access_mask.rs @@ -3,7 +3,9 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBEnumFromBytes, SMBToBytes}; -#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, SMBEnumFromBytes, SMBByteSize, SMBToBytes, Clone)] +#[derive( + Serialize, Deserialize, PartialEq, Eq, Debug, SMBEnumFromBytes, SMBByteSize, SMBToBytes, Clone, +)] pub enum SMBAccessMask { #[smb_discriminator(value = 0x2, value = 0x3, value = 0x0)] #[smb_direct(start(fixed = 0))] @@ -17,7 +19,7 @@ impl SMBAccessMask { pub fn raw(&self) -> u32 { match self { SMBAccessMask::FilePipePrinter(x) => x.bits(), - SMBAccessMask::Directory(x) => x.bits() + SMBAccessMask::Directory(x) => x.bits(), } } pub fn validate_print(&self) -> bool { @@ -29,21 +31,29 @@ impl SMBAccessMask { pub fn includes_maximum_allowed(&self) -> bool { match self { - SMBAccessMask::FilePipePrinter(x) => x.contains(SMBFilePipePrinterAccessMask::MAXIMUM_ALLOWED), - SMBAccessMask::Directory(x) => x.contains(SMBDirectoryAccessMask::MAXIMUM_ALLOWED) + SMBAccessMask::FilePipePrinter(x) => { + x.contains(SMBFilePipePrinterAccessMask::MAXIMUM_ALLOWED) + } + SMBAccessMask::Directory(x) => x.contains(SMBDirectoryAccessMask::MAXIMUM_ALLOWED), } } pub fn includes_access_system_security(&self) -> bool { match self { - SMBAccessMask::FilePipePrinter(x) => x.contains(SMBFilePipePrinterAccessMask::ACCESS_SYSTEM_SECURITY), - SMBAccessMask::Directory(x) => x.contains(SMBDirectoryAccessMask::ACCESS_SYSTEM_SECURITY) + SMBAccessMask::FilePipePrinter(x) => { + x.contains(SMBFilePipePrinterAccessMask::ACCESS_SYSTEM_SECURITY) + } + SMBAccessMask::Directory(x) => { + x.contains(SMBDirectoryAccessMask::ACCESS_SYSTEM_SECURITY) + } } } pub fn access_no_connect_security(is_directory: bool) -> Self { match is_directory { - true => Self::FilePipePrinter(SMBFilePipePrinterAccessMask::access_no_connect_security()), + true => { + Self::FilePipePrinter(SMBFilePipePrinterAccessMask::access_no_connect_security()) + } false => Self::Directory(SMBDirectoryAccessMask::access_no_connect_security()), } } @@ -52,15 +62,21 @@ impl SMBAccessMask { let mask = desired.clone(); if mask.includes_maximum_allowed() { match mask { - SMBAccessMask::FilePipePrinter(mut x) => x |= SMBFilePipePrinterAccessMask::GENERIC_ALL, - SMBAccessMask::Directory(mut x) => x |= SMBDirectoryAccessMask::GENERIC_ALL + SMBAccessMask::FilePipePrinter(mut x) => { + x |= SMBFilePipePrinterAccessMask::GENERIC_ALL + } + SMBAccessMask::Directory(mut x) => x |= SMBDirectoryAccessMask::GENERIC_ALL, }; } if mask.includes_access_system_security() { match mask { - SMBAccessMask::FilePipePrinter(mut x) => x |= SMBFilePipePrinterAccessMask::ACCESS_SYSTEM_SECURITY, - SMBAccessMask::Directory(mut x) => x |= SMBDirectoryAccessMask::ACCESS_SYSTEM_SECURITY + SMBAccessMask::FilePipePrinter(mut x) => { + x |= SMBFilePipePrinterAccessMask::ACCESS_SYSTEM_SECURITY + } + SMBAccessMask::Directory(mut x) => { + x |= SMBDirectoryAccessMask::ACCESS_SYSTEM_SECURITY + } }; } mask @@ -101,9 +117,19 @@ impl SMBFilePipePrinterAccessMask { } pub fn access_no_connect_security() -> Self { - Self::FILE_READ_DATA | Self::FILE_WRITE_DATA | Self::FILE_APPEND_DATA | Self::FILE_READ_EA - | Self::FILE_WRITE_EA | Self::FILE_DELETE_CHILD | Self::FILE_EXECUTE | Self::FILE_READ_ATTRIBUTES - | Self::FILE_WRITE_ATTRIBUTES | Self::DELETE | Self::READ_CONTROL | Self::WRITE_DAC | Self::WRITE_OWNER + Self::FILE_READ_DATA + | Self::FILE_WRITE_DATA + | Self::FILE_APPEND_DATA + | Self::FILE_READ_EA + | Self::FILE_WRITE_EA + | Self::FILE_DELETE_CHILD + | Self::FILE_EXECUTE + | Self::FILE_READ_ATTRIBUTES + | Self::FILE_WRITE_ATTRIBUTES + | Self::DELETE + | Self::READ_CONTROL + | Self::WRITE_DAC + | Self::WRITE_OWNER | Self::SYNCHRONIZE } } @@ -136,9 +162,19 @@ bitflags! { impl SMBDirectoryAccessMask { pub fn access_no_connect_security() -> Self { - Self::FILE_LIST_DIRECTORY | Self::FILE_ADD_FILE | Self::FILE_ADD_SUBDIRECTORY | Self::FILE_READ_EA - | Self::FILE_WRITE_EA | Self::FILE_DELETE_CHILD | Self::FILE_TRAVERSE | Self::FILE_READ_ATTRIBUTES - | Self::FILE_WRITE_ATTRIBUTES | Self::DELETE | Self::READ_CONTROL | Self::WRITE_DAC | Self::WRITE_OWNER + Self::FILE_LIST_DIRECTORY + | Self::FILE_ADD_FILE + | Self::FILE_ADD_SUBDIRECTORY + | Self::FILE_READ_EA + | Self::FILE_WRITE_EA + | Self::FILE_DELETE_CHILD + | Self::FILE_TRAVERSE + | Self::FILE_READ_ATTRIBUTES + | Self::FILE_WRITE_ATTRIBUTES + | Self::DELETE + | Self::READ_CONTROL + | Self::WRITE_DAC + | Self::WRITE_OWNER | Self::SYNCHRONIZE } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/tree_connect/buffer.rs b/smb/src/protocol/body/tree_connect/buffer.rs index 3357868..2c96cb6 100644 --- a/smb/src/protocol/body/tree_connect/buffer.rs +++ b/smb/src/protocol/body/tree_connect/buffer.rs @@ -8,19 +8,16 @@ use smb_derive::{SMBByteSize, SMBEnumFromBytes, SMBFromBytes, SMBToBytes}; use crate::protocol::body::tree_connect::context::SMBTreeConnectContext; #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBEnumFromBytes, - SMBByteSize, - SMBToBytes, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBEnumFromBytes, SMBByteSize, SMBToBytes, Clone, )] pub enum SMBTreeConnectBuffer { #[smb_discriminator(value = 0x0)] - #[smb_string(order = 0, start(inner(start = 0, num_type = "u16", subtract = 68)), length(inner(start = 2, num_type = "u16")), underlying = "u16")] + #[smb_string( + order = 0, + start(inner(start = 0, num_type = "u16", subtract = 68)), + length(inner(start = 2, num_type = "u16")), + underlying = "u16" + )] Path(String), #[smb_direct(start(fixed = 0))] #[smb_discriminator(value = 0x1)] @@ -31,7 +28,7 @@ impl SMBTreeConnectBuffer { pub fn share(&self) -> &str { let path_str = match self { SMBTreeConnectBuffer::Path(x) => x, - SMBTreeConnectBuffer::Extension(x) => &x.path_name + SMBTreeConnectBuffer::Extension(x) => &x.path_name, }; let idx = path_str.rfind('\\'); trace!(?idx, "parsing share name from path"); @@ -44,21 +41,22 @@ impl SMBTreeConnectBuffer { } #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBByteSize, - SMBFromBytes, - SMBToBytes, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBByteSize, SMBFromBytes, SMBToBytes, Clone, )] pub struct SMBTreeConnectExtension { #[smb_skip(start = 12, length = 2)] reserved: PhantomData>, - #[smb_string(order = 1, start(inner(start = 2, num_type = "u16", subtract = 64)), length = "null_terminated", underlying = "u16")] + #[smb_string( + order = 1, + start(inner(start = 2, num_type = "u16", subtract = 64)), + length = "null_terminated", + underlying = "u16" + )] path_name: String, - #[smb_vector(order = 2, count(inner(start = 10, num_type = "u16")), offset(inner(start = 6, num_type = "u32", subtract = 64)))] + #[smb_vector( + order = 2, + count(inner(start = 10, num_type = "u16")), + offset(inner(start = 6, num_type = "u32", subtract = 64)) + )] tree_connect_contexts: Vec, -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/tree_connect/capabilities.rs b/smb/src/protocol/body/tree_connect/capabilities.rs index 95e55f9..68ccf1e 100644 --- a/smb/src/protocol/body/tree_connect/capabilities.rs +++ b/smb/src/protocol/body/tree_connect/capabilities.rs @@ -11,4 +11,4 @@ bitflags! { const ASYMMETRIC = 0x080; const REDIRECT_TO_OWNER = 0x100; } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/tree_connect/context.rs b/smb/src/protocol/body/tree_connect/context.rs index e1f5881..deaeb54 100644 --- a/smb/src/protocol/body/tree_connect/context.rs +++ b/smb/src/protocol/body/tree_connect/context.rs @@ -3,8 +3,8 @@ use std::marker::PhantomData; use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use smb_core::error::SMBError; +use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone)] @@ -15,20 +15,25 @@ pub enum SMBTreeConnectContext { impl SMBByteSize for SMBTreeConnectContext { fn smb_byte_size(&self) -> usize { match self { - Self::RemotedIdentity(identity) => identity.smb_byte_size() + Self::RemotedIdentity(identity) => identity.smb_byte_size(), } } } impl SMBFromBytes for SMBTreeConnectContext { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { let (_remaining, ctx_type) = u16::smb_from_bytes(input)?; match ctx_type { 0x01 => { let (remaining, identity) = RemotedIdentity::smb_from_bytes(input)?; Ok((remaining, Self::RemotedIdentity(identity))) - }, - _ => Err(SMBError::parse_error("Invalid context type for tree connect context")) + } + _ => Err(SMBError::parse_error( + "Invalid context type for tree connect context", + )), } } } @@ -39,14 +44,13 @@ impl SMBToBytes for SMBTreeConnectContext { Self::RemotedIdentity(x) => (0x01_u16, x.smb_to_bytes()), }; let ctx_bytes = ctx_type.smb_to_bytes(); - [ - ctx_bytes, - bytes - ].concat() + [ctx_bytes, bytes].concat() } } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBToBytes, SMBByteSize)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBToBytes, SMBByteSize, +)] pub struct RemotedIdentity { #[smb_direct(start(inner(start = 4, num_type = "u16")))] user: SidAttrData, @@ -74,7 +78,9 @@ pub struct RemotedIdentity { device_claims: BlobData, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct BlobData { #[smb_skip(start = 0, length = 2)] reserved: PhantomData>, @@ -82,7 +88,9 @@ pub struct BlobData { data: Vec, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct SidAttrData { #[smb_direct(start(fixed = 0))] sid_data: BlobData, @@ -90,7 +98,9 @@ pub struct SidAttrData { attr: SidAttr, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct SidArrayData { #[smb_skip(start = 0, length = 2)] reserved: PhantomData>, @@ -98,7 +108,9 @@ pub struct SidArrayData { array: Vec, } -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct LuidAttrData { #[smb_direct(start(fixed = 0))] luid: [u8; 8], @@ -108,7 +120,9 @@ pub struct LuidAttrData { pub type PrivilegeData = BlobData; -#[derive(Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, Eq, PartialEq, Serialize, Deserialize, Clone, SMBFromBytes, SMBByteSize, SMBToBytes, +)] pub struct PrivilegeArrayData { #[smb_skip(start = 0, length = 2)] reserved: PhantomData>, @@ -137,4 +151,4 @@ bitflags! { const E = 0x1E; const D = 0x1F; } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/tree_connect/flags.rs b/smb/src/protocol/body/tree_connect/flags.rs index b6d77cd..215a949 100644 --- a/smb/src/protocol/body/tree_connect/flags.rs +++ b/smb/src/protocol/body/tree_connect/flags.rs @@ -38,4 +38,4 @@ impl Default for SMBShareFlags { fn default() -> Self { Self::MANUAL_CACHING } -} \ No newline at end of file +} diff --git a/smb/src/protocol/body/tree_connect/mod.rs b/smb/src/protocol/body/tree_connect/mod.rs index 752b01d..6b9e9fb 100644 --- a/smb/src/protocol/body/tree_connect/mod.rs +++ b/smb/src/protocol/body/tree_connect/mod.rs @@ -5,30 +5,26 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; -use crate::protocol::body::tree_connect::access_mask::{SMBAccessMask, SMBDirectoryAccessMask, SMBFilePipePrinterAccessMask}; +use crate::protocol::body::tree_connect::access_mask::{ + SMBAccessMask, SMBDirectoryAccessMask, SMBFilePipePrinterAccessMask, +}; use crate::protocol::body::tree_connect::buffer::SMBTreeConnectBuffer; use crate::protocol::body::tree_connect::capabilities::SMBTreeConnectCapabilities; use crate::protocol::body::tree_connect::context::{LuidAttr, SidAttr}; use crate::protocol::body::tree_connect::flags::{SMBShareFlags, SMBTreeConnectFlags}; use crate::server::share::{ResourceType, SharedResource}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; -pub mod context; -pub mod buffer; pub mod access_mask; -pub mod flags; +pub mod buffer; pub mod capabilities; +pub mod context; +pub mod flags; #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBByteSize, - SMBFromBytes, - SMBToBytes, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBByteSize, SMBFromBytes, SMBToBytes, Clone, )] #[smb_byte_tag(value = 09)] pub struct SMBTreeConnectRequest { @@ -45,15 +41,7 @@ impl SMBTreeConnectRequest { } #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBByteSize, - SMBFromBytes, - SMBToBytes, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBByteSize, SMBFromBytes, SMBToBytes, Clone, )] #[smb_byte_tag(value = 16)] pub struct SMBTreeConnectResponse { @@ -84,7 +72,9 @@ impl Default for SMBTreeConnectResponse { impl SMBTreeConnectResponse { pub fn ipc() -> Self { Self { - maximal_access: SMBAccessMask::FilePipePrinter(SMBFilePipePrinterAccessMask::from_bits_truncate(2032127)), + maximal_access: SMBAccessMask::FilePipePrinter( + SMBFilePipePrinterAccessMask::from_bits_truncate(2032127), + ), share_type: SMBShareType::Pipe, reserved: PhantomData, share_flags: SMBShareFlags::NO_CACHING, @@ -103,7 +93,9 @@ impl SMBTreeConnectResponse { reserved: Default::default(), share_flags, capabilities: SMBTreeConnectCapabilities::empty(), - maximal_access: SMBAccessMask::FilePipePrinter(SMBFilePipePrinterAccessMask::from_bits_truncate(0x001f01ff)), + maximal_access: SMBAccessMask::FilePipePrinter( + SMBFilePipePrinterAccessMask::from_bits_truncate(0x001f01ff), + ), } } @@ -113,7 +105,20 @@ impl SMBTreeConnectResponse { } #[repr(u8)] -#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Copy, Clone, SMBByteSize, SMBFromBytes, SMBToBytes, TryFromPrimitive, Default)] +#[derive( + Serialize, + Deserialize, + PartialEq, + Eq, + Debug, + Copy, + Clone, + SMBByteSize, + SMBFromBytes, + SMBToBytes, + TryFromPrimitive, + Default, +)] pub enum SMBShareType { #[default] Disk = 0x01, diff --git a/smb/src/protocol/body/write/flags.rs b/smb/src/protocol/body/write/flags.rs index c70ae18..ae99e04 100644 --- a/smb/src/protocol/body/write/flags.rs +++ b/smb/src/protocol/body/write/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Copy)] diff --git a/smb/src/protocol/body/write/mod.rs b/smb/src/protocol/body/write/mod.rs index b280e3d..f4ff936 100644 --- a/smb/src/protocol/body/write/mod.rs +++ b/smb/src/protocol/body/write/mod.rs @@ -11,15 +11,7 @@ use crate::protocol::body::write::flags::SMBWriteFlags; mod flags; #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 49)] pub struct SMBWriteRequest { @@ -35,22 +27,20 @@ pub struct SMBWriteRequest { remaining_bytes: u32, #[smb_direct(start(fixed = 44))] flags: SMBWriteFlags, - #[smb_buffer(offset(inner(start = 40, num_type = "u16", subtract = 64)), length(inner(start = 42, num_type = "u16")))] + #[smb_buffer( + offset(inner(start = 40, num_type = "u16", subtract = 64)), + length(inner(start = 42, num_type = "u16")) + )] channel_information: Vec, - #[smb_buffer(offset(inner(start = 2, num_type = "u16", subtract = 64)), length(inner(start = 4, num_type = "u32")))] + #[smb_buffer( + offset(inner(start = 2, num_type = "u16", subtract = 64)), + length(inner(start = 4, num_type = "u32")) + )] data_to_write: Vec, } #[derive( - Debug, - PartialEq, - Eq, - SMBByteSize, - SMBToBytes, - SMBFromBytes, - Serialize, - Deserialize, - Clone + Debug, PartialEq, Eq, SMBByteSize, SMBToBytes, SMBFromBytes, Serialize, Deserialize, Clone, )] #[smb_byte_tag(value = 17)] pub struct SMBWriteResponse { @@ -64,4 +54,4 @@ pub struct SMBWriteResponse { write_channel_info_offset: PhantomData>, #[smb_skip(start = 14, length = 2)] write_channel_info_len: PhantomData>, -} \ No newline at end of file +} diff --git a/smb/src/protocol/header/command_code.rs b/smb/src/protocol/header/command_code.rs index 8eb68e8..c06ae19 100644 --- a/smb/src/protocol/header/command_code.rs +++ b/smb/src/protocol/header/command_code.rs @@ -4,7 +4,20 @@ use serde::{Deserialize, Serialize}; use smb_derive::{SMBByteSize, SMBFromBytes, SMBToBytes}; #[repr(u16)] -#[derive(Debug, Eq, PartialEq, TryFromPrimitive, IntoPrimitive, Serialize, Deserialize, Clone, Copy, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, + Eq, + PartialEq, + TryFromPrimitive, + IntoPrimitive, + Serialize, + Deserialize, + Clone, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, +)] pub enum SMBCommandCode { Negotiate = 0x0, SessionSetup, @@ -25,7 +38,7 @@ pub enum SMBCommandCode { QueryInfo, SetInfo, OplockBreak, - LegacyNegotiate + LegacyNegotiate, } impl From for u64 { @@ -35,7 +48,20 @@ impl From for u64 { } #[repr(u8)] -#[derive(Debug, Eq, PartialEq, TryFromPrimitive, IntoPrimitive, Serialize, Deserialize, Clone, Copy, SMBFromBytes, SMBByteSize, SMBToBytes)] +#[derive( + Debug, + Eq, + PartialEq, + TryFromPrimitive, + IntoPrimitive, + Serialize, + Deserialize, + Clone, + Copy, + SMBFromBytes, + SMBByteSize, + SMBToBytes, +)] pub enum LegacySMBCommandCode { CreateDirectory, DeleteDirectory, @@ -107,7 +133,7 @@ pub enum LegacySMBCommandCode { ClosePrintFile, GetPrintQueue, ReadBulk = 0xD9, - WriteBulkData + WriteBulkData, } impl From for u64 { @@ -143,4 +169,4 @@ mod tests { assert_eq!(SMBCommandCode::SetInfo as u16, 0x0011); assert_eq!(SMBCommandCode::OplockBreak as u16, 0x0012); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/header/extra.rs b/smb/src/protocol/header/extra.rs index 5467905..6947747 100644 --- a/smb/src/protocol/header/extra.rs +++ b/smb/src/protocol/header/extra.rs @@ -1,6 +1,6 @@ +use nom::IResult; use nom::bytes::complete::take; use nom::combinator::map; -use nom::IResult; use nom::number::complete::{le_u16, le_u64}; use nom::sequence::tuple; use serde::{Deserialize, Serialize}; @@ -38,4 +38,3 @@ impl SMBExtra { .concat() } } - diff --git a/smb/src/protocol/header/flags.rs b/smb/src/protocol/header/flags.rs index 61534a5..1b7eda0 100644 --- a/smb/src/protocol/header/flags.rs +++ b/smb/src/protocol/header/flags.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] @@ -96,4 +98,4 @@ mod tests { let (_, parsed) = SMBFlags::smb_from_bytes(&bytes).unwrap(); assert_eq!(parsed, flags); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/header/flags2.rs b/smb/src/protocol/header/flags2.rs index 83458cb..609ada1 100644 --- a/smb/src/protocol/header/flags2.rs +++ b/smb/src/protocol/header/flags2.rs @@ -1,7 +1,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; -use crate::util::flags_helper::{impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag}; +use crate::util::flags_helper::{ + impl_smb_byte_size_for_bitflag, impl_smb_from_bytes_for_bitflag, impl_smb_to_bytes_for_bitflag, +}; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deserialize, Serialize)] @@ -26,4 +28,4 @@ bitflags! { impl_smb_byte_size_for_bitflag! { LegacySMBFlags2 } impl_smb_from_bytes_for_bitflag! { LegacySMBFlags2 } -impl_smb_to_bytes_for_bitflag! { LegacySMBFlags2 } \ No newline at end of file +impl_smb_to_bytes_for_bitflag! { LegacySMBFlags2 } diff --git a/smb/src/protocol/header/mod.rs b/smb/src/protocol/header/mod.rs index eab268e..6a2e71e 100644 --- a/smb/src/protocol/header/mod.rs +++ b/smb/src/protocol/header/mod.rs @@ -10,8 +10,8 @@ use std::cmp::min; use std::marker::PhantomData; -use nom::error::ErrorKind; use nom::IResult; +use nom::error::ErrorKind; use serde::{Deserialize, Serialize}; use smb_core::{SMBFromBytes, SMBToBytes}; @@ -25,14 +25,14 @@ use crate::protocol::header::status::SMBStatus; /// SMB2 command codes ([\[MS-SMB2\] 2.2.1]). pub mod command_code; -/// NT Status codes and legacy DOS error codes. -pub mod status; +/// Legacy SMB1 header extra fields. +pub mod extra; /// SMB2 header flags ([\[MS-SMB2\] 2.2.1]: `Flags` field). pub mod flags; /// Legacy SMB1 Flags2 field. pub mod flags2; -/// Legacy SMB1 header extra fields. -pub mod extra; +/// NT Status codes and legacy DOS error codes. +pub mod status; /// Indicates the direction of an SMB message. /// @@ -57,9 +57,16 @@ pub trait Header: SMBFromBytes + SMBToBytes { fn command_code(&self) -> Self::CommandCode; /// Parse a header from raw bytes, returning the header and its command code. - fn parse(bytes: &[u8]) -> IResult<&[u8], (Self, Self::CommandCode)> where Self: Sized + SMBFromBytes { - let (remaining, message) = Self::smb_from_bytes(bytes) - .map_err(|_e| nom::Err::Error(nom::error::ParseError::from_error_kind(bytes, ErrorKind::MapRes)))?; + fn parse(bytes: &[u8]) -> IResult<&[u8], (Self, Self::CommandCode)> + where + Self: Sized + SMBFromBytes, + { + let (remaining, message) = Self::smb_from_bytes(bytes).map_err(|_e| { + nom::Err::Error(nom::error::ParseError::from_error_kind( + bytes, + ErrorKind::MapRes, + )) + })?; let command = message.command_code(); // .map_err(|_e| ); Ok((remaining, (message, command))) @@ -93,15 +100,7 @@ pub trait Header: SMBFromBytes + SMBToBytes { /// | 40 | 8 | SessionId | /// | 48 | 16 | Signature | #[derive( - Serialize, - Deserialize, - PartialEq, - Eq, - Debug, - SMBFromBytes, - SMBToBytes, - SMBByteSize, - Clone + Serialize, Deserialize, PartialEq, Eq, Debug, SMBFromBytes, SMBToBytes, SMBByteSize, Clone, )] #[allow(clippy::duplicated_attributes)] #[smb_byte_tag(value = 0xFE, order = 0)] @@ -240,7 +239,12 @@ impl SMBSyncHeader { /// /// Sets `SMB2_FLAGS_SERVER_TO_REDIR`, copies the command and message ID, /// and zeroes the signature (to be filled in later if signing is required). - pub fn create_response_header(&self, channel_sequence: u32, session_id: u64, tree_id: u32) -> Self { + pub fn create_response_header( + &self, + channel_sequence: u32, + session_id: u64, + tree_id: u32, + ) -> Self { Self { command: self.command, flags: SMBFlags::SERVER_TO_REDIR, @@ -273,7 +277,13 @@ mod tests { #[test] fn sync_header_protocol_id_and_structure_size() { let header = SMBSyncHeader::new( - SMBCommandCode::Negotiate, SMBFlags::empty(), 0, 0, 0, 0, [0; 16], + SMBCommandCode::Negotiate, + SMBFlags::empty(), + 0, + 0, + 0, + 0, + [0; 16], ); let bytes = header.smb_to_bytes(); assert_eq!(bytes[0], 0xFE); @@ -286,16 +296,21 @@ mod tests { #[test] fn sync_header_is_64_bytes() { - let header = SMBSyncHeader::new( - SMBCommandCode::Echo, SMBFlags::empty(), 0, 0, 0, 0, [0; 16], - ); + let header = + SMBSyncHeader::new(SMBCommandCode::Echo, SMBFlags::empty(), 0, 0, 0, 0, [0; 16]); assert_eq!(header.smb_to_bytes().len(), 64); } #[test] fn sync_header_command_field_offset() { let header = SMBSyncHeader::new( - SMBCommandCode::SessionSetup, SMBFlags::empty(), 0, 0, 0, 0, [0; 16], + SMBCommandCode::SessionSetup, + SMBFlags::empty(), + 0, + 0, + 0, + 0, + [0; 16], ); let bytes = header.smb_to_bytes(); let cmd = u16::from_le_bytes([bytes[12], bytes[13]]); @@ -307,7 +322,11 @@ mod tests { let header = SMBSyncHeader::new( SMBCommandCode::Negotiate, SMBFlags::SERVER_TO_REDIR | SMBFlags::SIGNED, - 0, 0, 0, 0, [0; 16], + 0, + 0, + 0, + 0, + [0; 16], ); let bytes = header.smb_to_bytes(); let flags = u32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]); @@ -318,12 +337,17 @@ mod tests { #[test] fn sync_header_message_id_offset() { let header = SMBSyncHeader::new( - SMBCommandCode::Echo, SMBFlags::empty(), 0, 42, 0, 0, [0; 16], + SMBCommandCode::Echo, + SMBFlags::empty(), + 0, + 42, + 0, + 0, + [0; 16], ); let bytes = header.smb_to_bytes(); let msg_id = u64::from_le_bytes([ - bytes[24], bytes[25], bytes[26], bytes[27], - bytes[28], bytes[29], bytes[30], bytes[31], + bytes[24], bytes[25], bytes[26], bytes[27], bytes[28], bytes[29], bytes[30], bytes[31], ]); assert_eq!(msg_id, 42); } @@ -331,13 +355,18 @@ mod tests { #[test] fn sync_header_tree_id_and_session_id() { let header = SMBSyncHeader::new( - SMBCommandCode::Create, SMBFlags::empty(), 0, 0, 0x1234, 0xABCD, [0; 16], + SMBCommandCode::Create, + SMBFlags::empty(), + 0, + 0, + 0x1234, + 0xABCD, + [0; 16], ); let bytes = header.smb_to_bytes(); let tree_id = u32::from_le_bytes([bytes[36], bytes[37], bytes[38], bytes[39]]); let session_id = u64::from_le_bytes([ - bytes[40], bytes[41], bytes[42], bytes[43], - bytes[44], bytes[45], bytes[46], bytes[47], + bytes[40], bytes[41], bytes[42], bytes[43], bytes[44], bytes[45], bytes[46], bytes[47], ]); assert_eq!(tree_id, 0x1234); assert_eq!(session_id, 0xABCD); @@ -346,9 +375,7 @@ mod tests { #[test] fn sync_header_signature_offset() { let sig = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; - let header = SMBSyncHeader::new( - SMBCommandCode::Echo, SMBFlags::empty(), 0, 0, 0, 0, sig, - ); + let header = SMBSyncHeader::new(SMBCommandCode::Echo, SMBFlags::empty(), 0, 0, 0, 0, sig); let bytes = header.smb_to_bytes(); assert_eq!(&bytes[48..64], &sig); } @@ -356,7 +383,13 @@ mod tests { #[test] fn sync_header_round_trip() { let header = SMBSyncHeader::new( - SMBCommandCode::TreeConnect, SMBFlags::SERVER_TO_REDIR, 0, 7, 3, 99, [0xAA; 16], + SMBCommandCode::TreeConnect, + SMBFlags::SERVER_TO_REDIR, + 0, + 7, + 3, + 99, + [0xAA; 16], ); let bytes = header.smb_to_bytes(); let (remaining, parsed) = SMBSyncHeader::smb_from_bytes(&bytes).unwrap(); @@ -372,7 +405,13 @@ mod tests { #[test] fn create_response_header_sets_server_flag() { let request = SMBSyncHeader::new( - SMBCommandCode::Negotiate, SMBFlags::empty(), 0, 1, 0, 0, [0; 16], + SMBCommandCode::Negotiate, + SMBFlags::empty(), + 0, + 1, + 0, + 0, + [0; 16], ); let response = request.create_response_header(0, 0, 0); assert!(response.flags.contains(SMBFlags::SERVER_TO_REDIR)); @@ -382,9 +421,8 @@ mod tests { #[test] fn set_signature_enables_signed_flag() { - let mut header = SMBSyncHeader::new( - SMBCommandCode::Echo, SMBFlags::empty(), 0, 0, 0, 0, [0; 16], - ); + let mut header = + SMBSyncHeader::new(SMBCommandCode::Echo, SMBFlags::empty(), 0, 0, 0, 0, [0; 16]); assert!(!header.flags.contains(SMBFlags::SIGNED)); let sig = [0xDE; 16]; header.set_signature(&sig); diff --git a/smb/src/protocol/header/status.rs b/smb/src/protocol/header/status.rs index 3db90c6..a587615 100644 --- a/smb/src/protocol/header/status.rs +++ b/smb/src/protocol/header/status.rs @@ -1,16 +1,16 @@ -use nom::{bits, IResult}; +use nom::Err::Error; use nom::bits::streaming::take; use nom::combinator::map; -use nom::Err::Error; use nom::error::ErrorKind; use nom::number::complete::le_u8; use nom::number::streaming::le_u16; use nom::sequence::tuple; +use nom::{IResult, bits}; use num_enum::TryFromPrimitive; use serde::{Deserialize, Serialize}; -use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use smb_core::error::SMBError; +use smb_core::{SMBByteSize, SMBFromBytes, SMBParseResult, SMBToBytes}; use crate::byte_helper::u16_to_bytes; @@ -38,24 +38,33 @@ pub enum NTStatusLevel { impl SMBStatus { pub(crate) fn parse(bytes: &[u8]) -> IResult<&[u8], Self> { - bits::<_, _, nom::error::Error<(&[u8], usize)>, _, _>(tuple((take(4_usize), take(4_usize))))(bytes) - .and_then(|(_, (_, nibble)): (&[u8], (u8, u8))| { - if nibble == 0x0 || nibble == 0x4 || nibble == 0x8 || nibble == 0xC { - let level = NTStatusLevel::try_from(nibble >> 2).map_err(|_e| Error(nom::error::Error::new(bytes, ErrorKind::Fail)))?; - let (remaining, facility) = map(nom::bytes::complete::take(2_usize), |s: &[u8]| [s[0] << 4, s[1]])(bytes)?; - let (remaining, error_code) = le_u16(remaining)?; - Ok((remaining, Self::NTStatus(NTStatusCode { + bits::<_, _, nom::error::Error<(&[u8], usize)>, _, _>(tuple(( + take(4_usize), + take(4_usize), + )))(bytes) + .and_then(|(_, (_, nibble)): (&[u8], (u8, u8))| { + if nibble == 0x0 || nibble == 0x4 || nibble == 0x8 || nibble == 0xC { + let level = NTStatusLevel::try_from(nibble >> 2) + .map_err(|_e| Error(nom::error::Error::new(bytes, ErrorKind::Fail)))?; + let (remaining, facility) = + map(nom::bytes::complete::take(2_usize), |s: &[u8]| { + [s[0] << 4, s[1]] + })(bytes)?; + let (remaining, error_code) = le_u16(remaining)?; + Ok(( + remaining, + Self::NTStatus(NTStatusCode { level, facility, - error_code - }))) - } else { - map( - tuple((le_u8, le_u8, le_u16)), - |(first, second, third)| Self::DosError(first.into(), second.into(), third), - )(bytes) - } - })?; + error_code, + }), + )) + } else { + map(tuple((le_u8, le_u8, le_u16)), |(first, second, third)| { + Self::DosError(first.into(), second.into(), third) + })(bytes) + } + })?; todo!() } } @@ -67,7 +76,10 @@ impl SMBByteSize for SMBStatus { } impl SMBFromBytes for SMBStatus { - fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized { + fn smb_from_bytes(input: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized, + { Self::parse(input).map_err(|_e| SMBError::parse_error("Invalid format")) } } @@ -86,7 +98,7 @@ impl SMBStatus { &[x.facility[1]][0..], &u16_to_bytes(x.error_code)[0..], ] - .concat(), + .concat(), SMBStatus::DosError(c1, c2, code) => [ &[*c1 as u8][0..], &[*c2 as u8][0..], @@ -96,4 +108,3 @@ impl SMBStatus { } } } - diff --git a/smb/src/protocol/message.rs b/smb/src/protocol/message.rs index 8ee92ba..a259434 100644 --- a/smb/src/protocol/message.rs +++ b/smb/src/protocol/message.rs @@ -18,13 +18,13 @@ use hmac::Hmac; use serde::{Deserialize, Serialize}; use sha2::Sha256; -use smb_core::{SMBParseResult, SMBResult}; use smb_core::error::SMBError; use smb_core::logging::trace; +use smb_core::{SMBParseResult, SMBResult}; use crate::byte_helper::u16_to_bytes; -use crate::protocol::body::{Body, LegacySMBBody, SMBBody}; use crate::protocol::body::negotiate::context::SigningAlgorithm; +use crate::protocol::body::{Body, LegacySMBBody, SMBBody}; use crate::protocol::header::{Header, LegacySMBHeader, SMBSyncHeader}; /// Convenience alias for a synchronous SMB2/3 message. @@ -46,10 +46,7 @@ pub struct SMBMessage> { impl> SMBMessage { pub fn new(header: S, body: T) -> Self { - SMBMessage { - header, - body - } + SMBMessage { header, body } } } @@ -66,11 +63,18 @@ pub trait Message { fn as_bytes(&self) -> Vec; /// Parse a message from raw bytes (starting at the SMB2 ProtocolId, **without** /// the 4-byte NetBIOS header). - fn parse(bytes: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized; + fn parse(bytes: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized; /// Compute the cryptographic signature for this message using the given /// signing key and algorithm ([\[MS-SMB2\] 3.1.5.1]). - fn signature(&self, nonce: &[u8], key: &[u8], algorithm: SigningAlgorithm) -> SMBResult>; + fn signature( + &self, + nonce: &[u8], + key: &[u8], + algorithm: SigningAlgorithm, + ) -> SMBResult>; } impl SMBMessage { @@ -95,29 +99,34 @@ impl> Message for SMBMessage { fn parse(bytes: &[u8]) -> SMBParseResult<&[u8], Self> { let (remaining, header) = S::smb_from_bytes(bytes)?; - trace!(?header, remaining_len = remaining.len(), "parsed message header"); + trace!( + ?header, + remaining_len = remaining.len(), + "parsed message header" + ); let discriminator_code = (header.command_code().into()) | ((header.sender() as u64) << 16); let (remaining, body) = T::smb_enum_from_bytes(remaining, discriminator_code)?; Ok((remaining, Self { header, body })) } - fn signature(&self, _nonce: &[u8], key: &[u8], algorithm: SigningAlgorithm) -> SMBResult> { + fn signature( + &self, + _nonce: &[u8], + key: &[u8], + algorithm: SigningAlgorithm, + ) -> SMBResult> { let res = match algorithm { SigningAlgorithm::HmacSha256 => { - let mut hmac = Hmac::::new_from_slice(key) - .map_err(SMBError::crypto_error)?; + let mut hmac = + Hmac::::new_from_slice(key).map_err(SMBError::crypto_error)?; hmac.update(&self.as_bytes()); - hmac.finalize() - .into_bytes() - .to_vec() + hmac.finalize().into_bytes().to_vec() } SigningAlgorithm::AesCmac => { - let mut cmac = Cmac::::new_from_slice(key) - .map_err(SMBError::crypto_error)?; + let mut cmac = + Cmac::::new_from_slice(key).map_err(SMBError::crypto_error)?; cmac.update(&self.as_bytes()); - cmac.finalize() - .into_bytes() - .to_vec() + cmac.finalize().into_bytes().to_vec() } SigningAlgorithm::AesGmac => { todo!(); @@ -140,11 +149,8 @@ mod tests { use crate::protocol::header::flags::SMBFlags; fn echo_request_message() -> SMBSyncMessage { - let header = SMBSyncHeader::new( - SMBCommandCode::Echo, - SMBFlags::empty(), - 0, 1, 0, 0, [0; 16], - ); + let header = + SMBSyncHeader::new(SMBCommandCode::Echo, SMBFlags::empty(), 0, 1, 0, 0, [0; 16]); let body = SMBBody::EchoRequest(SMBEmpty); SMBMessage::new(header, body) } @@ -196,7 +202,9 @@ mod tests { fn hmac_sha256_signature_is_nonempty() { let msg = echo_request_message(); let key = [0xAB; 16]; - let sig = msg.signature(&[], &key, SigningAlgorithm::HmacSha256).unwrap(); + let sig = msg + .signature(&[], &key, SigningAlgorithm::HmacSha256) + .unwrap(); assert!(!sig.is_empty(), "HMAC-SHA256 signature should not be empty"); assert_eq!(sig.len(), 32, "HMAC-SHA256 produces 32 bytes"); } @@ -209,4 +217,4 @@ mod tests { let sig = msg.signature(&[], &key, SigningAlgorithm::AesCmac).unwrap(); assert_eq!(sig.len(), 16, "AES-CMAC produces 16 bytes"); } -} \ No newline at end of file +} diff --git a/smb/src/protocol/mod.rs b/smb/src/protocol/mod.rs index d7dfb48..0b4f963 100644 --- a/smb/src/protocol/mod.rs +++ b/smb/src/protocol/mod.rs @@ -10,4 +10,4 @@ pub mod body; pub mod header; -pub mod message; \ No newline at end of file +pub mod message; diff --git a/smb/src/server/channel.rs b/smb/src/server/channel.rs index ad0d0c9..58d9c25 100644 --- a/smb/src/server/channel.rs +++ b/smb/src/server/channel.rs @@ -1,8 +1,7 @@ -use crate::server::{Server, SMBConnection}; +use crate::server::{SMBConnection, Server}; use crate::socket::message_stream::{SMBReadStream, SMBWriteStream}; pub struct SMBChannel { signing_key: [u8; 16], - connection: SMBConnection + connection: SMBConnection, } - diff --git a/smb/src/server/client.rs b/smb/src/server/client.rs index 303a7fb..af1d48f 100644 --- a/smb/src/server/client.rs +++ b/smb/src/server/client.rs @@ -5,5 +5,5 @@ use crate::protocol::body::dialect::SMBDialect; #[derive(Debug)] pub struct SMBClient { client_guid: Uuid, - dialect: SMBDialect -} \ No newline at end of file + dialect: SMBDialect, +} diff --git a/smb/src/server/connection.rs b/smb/src/server/connection.rs index 6d7616c..42c1797 100644 --- a/smb/src/server/connection.rs +++ b/smb/src/server/connection.rs @@ -5,36 +5,41 @@ use std::sync::{Arc, Weak}; use derive_builder::Builder; use digest::Digest; use sha2::Sha512; -use tokio::sync::{Mutex, RwLock}; use tokio::sync::mpsc::Sender; +use tokio::sync::{Mutex, RwLock}; use tokio_stream::StreamExt; use uuid::Uuid; -use smb_core::{SMBResult, SMBToBytes}; use smb_core::error::SMBError; -use smb_core::logging::{trace, debug, info, warn, error}; +use smb_core::logging::{debug, error, info, trace, warn}; use smb_core::nt_status::NTStatus; +use smb_core::{SMBResult, SMBToBytes}; +use crate::protocol::body::SMBBody; use crate::protocol::body::capabilities::Capabilities; use crate::protocol::body::create::SMBCreateRequest; use crate::protocol::body::dialect::SMBDialect; +use crate::protocol::body::error::SMBErrorResponse; use crate::protocol::body::filetime::FileTime; -use crate::protocol::body::negotiate::{SMBNegotiateRequest, SMBNegotiateResponse}; -use crate::protocol::body::negotiate::context::{CompressionAlgorithm, EncryptionCipher, HashAlgorithm, RDMATransformID, SigningAlgorithm}; +use crate::protocol::body::negotiate::context::{ + CompressionAlgorithm, EncryptionCipher, HashAlgorithm, RDMATransformID, SigningAlgorithm, +}; use crate::protocol::body::negotiate::security_mode::NegotiateSecurityMode; -use crate::protocol::body::session_setup::flags::SMBSessionSetupFlags; +use crate::protocol::body::negotiate::{SMBNegotiateRequest, SMBNegotiateResponse}; use crate::protocol::body::session_setup::SMBSessionSetupRequest; -use crate::protocol::body::error::SMBErrorResponse; -use crate::protocol::body::SMBBody; +use crate::protocol::body::session_setup::flags::SMBSessionSetupFlags; use crate::protocol::header::SMBSyncHeader; use crate::protocol::message::{Message, SMBMessage}; -use crate::server::{Server, SMBServerDiagnosticsUpdate}; -use crate::server::message_handler::{NonEndingHandler, SMBHandlerState, SMBLockedMessageHandler, SMBLockedMessageHandlerBase, SMBMessageType}; +use crate::server::message_handler::{ + NonEndingHandler, SMBHandlerState, SMBLockedMessageHandler, SMBLockedMessageHandlerBase, + SMBMessageType, +}; use crate::server::open::Open; use crate::server::preauth_session::SMBPreauthSession; use crate::server::request::Request; use crate::server::safe_locked_getter::{InnerGetter, SafeLockedGetter}; use crate::server::session::Session; +use crate::server::{SMBServerDiagnosticsUpdate, Server}; use crate::socket::message_stream::{SMBReadStream, SMBSocketConnection, SMBWriteStream}; use crate::util::auth::AuthProvider; @@ -118,7 +123,7 @@ pub struct SMBConnection { signing_algorithm_id: SigningAlgorithm, accept_transport_security: bool, underlying_stream: Arc>>, - server: Weak> + server: Weak>, } // Getters @@ -218,9 +223,15 @@ impl Connection for SMBConnectio } } -impl> SMBConnection - where Arc>: SMBLockedMessageHandler { - pub async fn start_message_handler(stream: &mut SMBSocketConnection, mut connection: Arc>>, update_channel: Sender) -> SMBResult<()> { +impl> SMBConnection +where + Arc>: SMBLockedMessageHandler, +{ + pub async fn start_message_handler( + stream: &mut SMBSocketConnection, + mut connection: Arc>>, + update_channel: Sender, + ) -> SMBResult<()> { let (read, write) = stream.streams(); info!("starting message handler"); let mut messages = read.messages(); @@ -241,7 +252,11 @@ impl> SMBConnect if let Some(session_lock) = session { let session_rd = session_lock.read().await; let key = session_rd.signing_key().to_vec(); - if !key.is_empty() { Some((key, conn_rd.dialect)) } else { None } + if !key.is_empty() { + Some((key, conn_rd.dialect)) + } else { + None + } } else { None } @@ -254,7 +269,9 @@ impl> SMBConnect debug!(command = ?response.header.command, status = ?status, "sending error response"); trace!(?response, "full error response"); let sent = write.write_message(&response).await?; - let _ = update_channel.send(SMBServerDiagnosticsUpdate::default().bytes_sent(sent as u64)).await; + let _ = update_channel + .send(SMBServerDiagnosticsUpdate::default().bytes_sent(sent as u64)) + .await; continue; } let result = connection.handle_message(&incoming).await; @@ -279,7 +296,11 @@ impl> SMBConnect if let Some(session_lock) = session { let session_rd = session_lock.read().await; let key = session_rd.signing_key().to_vec(); - if !key.is_empty() { Some((key, conn_rd.dialect)) } else { None } + if !key.is_empty() { + Some((key, conn_rd.dialect)) + } else { + None + } } else { None } @@ -292,7 +313,9 @@ impl> SMBConnect debug!(command = ?response.header.command, mid = response.header.message_id, "sending response"); trace!(?response, "full outgoing response"); let sent = write.write_message(&response).await?; - let _ = update_channel.send(SMBServerDiagnosticsUpdate::default().bytes_sent(sent as u64)).await; + let _ = update_channel + .send(SMBServerDiagnosticsUpdate::default().bytes_sent(sent as u64)) + .await; } // Close streams on message parse finish (logoff) @@ -320,9 +343,18 @@ impl> SMBConnect // The SMB2 message starts at offset 4 (after the 4-byte NetBIOS header) let smb2_offset = 4; let smb2_len = bytes.len() - smb2_offset; - trace!(?dialect, key_len = signing_key.len(), smb2_len, "signing message"); + trace!( + ?dialect, + key_len = signing_key.len(), + smb2_len, + "signing message" + ); if let Ok(sig) = crate::util::crypto::smb2::calculate_signature( - signing_key, dialect, &bytes, smb2_offset, smb2_len + signing_key, + dialect, + &bytes, + smb2_offset, + smb2_len, ) { trace!(signature = ?&sig[..std::cmp::min(16, sig.len())], "computed signature"); let len = std::cmp::min(16, sig.len()); @@ -331,7 +363,7 @@ impl> SMBConnect } } -impl> SMBConnection { +impl> SMBConnection { pub fn underlying_socket(&self) -> Arc>> { self.underlying_stream.clone() } @@ -438,7 +470,9 @@ impl> SMBConnect type LockedSMBConnection = Arc>>; pub type WeakLockedSMBConnection = Weak>>; -impl> InnerGetter for SMBConnection { +impl> InnerGetter + for SMBConnection +{ type Upper = S; fn upper(&self) -> Option>> { @@ -446,47 +480,87 @@ impl> InnerGette } } - -impl> SMBConnection { - fn handle_negotiate(&mut self, server: &S, header: &SMBSyncHeader, request: &SMBNegotiateRequest) -> SMBResult { +impl> SMBConnection { + fn handle_negotiate( + &mut self, + server: &S, + header: &SMBSyncHeader, + request: &SMBNegotiateRequest, + ) -> SMBResult { let (update, contexts) = request.validate_and_set_state(self, server)?; self.apply_update(update); let resp_header = header.create_response_header(0x0, 0, 0); - let resp_body = SMBNegotiateResponse::from_connection_state::(self, server, contexts); - Ok(SMBMessage::new(resp_header, SMBBody::NegotiateResponse(resp_body))) - } - - async fn handle_session_setup Arc>>(&mut self, server: &S, header: &SMBSyncHeader, request: &SMBSessionSetupRequest, get_locked: F) -> SMBResult>> { + let resp_body = + SMBNegotiateResponse::from_connection_state::(self, server, contexts); + Ok(SMBMessage::new( + resp_header, + SMBBody::NegotiateResponse(resp_body), + )) + } + + async fn handle_session_setup Arc>>( + &mut self, + server: &S, + header: &SMBSyncHeader, + request: &SMBSessionSetupRequest, + get_locked: F, + ) -> SMBResult>> { let locked_conn = get_locked(); let mut sha = Sha512::default(); sha.update(self.preauth_integtiry_hash_value()); sha.update(request.smb_to_bytes()); let preauth_val = sha.finalize().to_vec(); - let session = S::Session::init(1, server.encrypt_data(), preauth_val, Arc::downgrade(&locked_conn), server.auth_provider().clone()); + let session = S::Session::init( + 1, + server.encrypt_data(), + preauth_val, + Arc::downgrade(&locked_conn), + server.auth_provider().clone(), + ); let id = session.id(); let wrapped_session = Arc::new(RwLock::new(session)); self.session_table.insert(id, wrapped_session.clone()); let unlocked = wrapped_session.read().await; - let update = request.validate_and_set_state(self, server, &unlocked, header).await?; + let update = request + .validate_and_set_state(self, server, &unlocked, header) + .await?; drop(unlocked); self.apply_update(update); Ok(wrapped_session) } - fn get_session(&self, server: &S, header: &SMBSyncHeader, flags: SMBSessionSetupFlags) -> SMBResult>> { - if self.dialect.is_smb3() && server.multi_channel_capable() && flags.contains(SMBSessionSetupFlags::BINDING) { + fn get_session( + &self, + server: &S, + header: &SMBSyncHeader, + flags: SMBSessionSetupFlags, + ) -> SMBResult>> { + if self.dialect.is_smb3() + && server.multi_channel_capable() + && flags.contains(SMBSessionSetupFlags::BINDING) + { server.sessions().get(&header.session_id) - } else if !self.dialect.is_smb3() && !server.multi_channel_capable() && flags.contains(SMBSessionSetupFlags::BINDING) { + } else if !self.dialect.is_smb3() + && !server.multi_channel_capable() + && flags.contains(SMBSessionSetupFlags::BINDING) + { None } else { self.sessions().get(&header.session_id) - }.map(Arc::clone).ok_or(SMBError::response_error(NTStatus::UserSessionDeleted)) + } + .map(Arc::clone) + .ok_or(SMBError::response_error(NTStatus::UserSessionDeleted)) } } -impl>> NonEndingHandler for LockedSMBConnection {} +impl>> + NonEndingHandler for LockedSMBConnection +{ +} -impl>> SMBLockedMessageHandlerBase for LockedSMBConnection { +impl>> + SMBLockedMessageHandlerBase for LockedSMBConnection +{ type Inner = Arc>; async fn inner(&self, message: &SMBMessageType) -> Option>> { @@ -500,34 +574,53 @@ impl SMBResult> { + async fn handle_negotiate( + &mut self, + header: &SMBSyncHeader, + message: &SMBNegotiateRequest, + ) -> SMBResult> { let server = self.upper().await?; let unlocked = server.read().await; - let message = self.write().await.handle_negotiate::(&unlocked, header, message)?; + let message = self + .write() + .await + .handle_negotiate::(&unlocked, header, message)?; Ok(SMBHandlerState::Finished(message)) } - async fn handle_session_setup(&mut self, header: &SMBSyncHeader, message: &SMBSessionSetupRequest) -> SMBResult>>> { + async fn handle_session_setup( + &mut self, + header: &SMBSyncHeader, + message: &SMBSessionSetupRequest, + ) -> SMBResult>>> { let server = self.upper().await?; let unlocked = server.read().await; let cloned_arc = self.clone(); - let get_locked = || { - cloned_arc - }; + let get_locked = || cloned_arc; if header.session_id == 0 { - let session = self.write().await.handle_session_setup(&unlocked, header, message, get_locked).await?; + let session = self + .write() + .await + .handle_session_setup(&unlocked, header, message, get_locked) + .await?; Ok(SMBHandlerState::Next(Some(session))) } else { Ok(SMBHandlerState::Next(None)) } } - async fn handle_create(&mut self, _header: &SMBSyncHeader, message: &SMBCreateRequest) -> SMBResult> { + async fn handle_create( + &mut self, + _header: &SMBSyncHeader, + message: &SMBCreateRequest, + ) -> SMBResult> { let server = self.upper().await?; let server_rd = server.read().await; let conn = self.read().await; @@ -544,7 +637,9 @@ impl TryFrom<(SMBSocketConnection, Weak>)> for SMBConnection { +impl + TryFrom<(SMBSocketConnection, Weak>)> for SMBConnection +{ type Error = SMBError; fn try_from(value: (SMBSocketConnection, Weak>)) -> Result { @@ -582,7 +677,7 @@ impl TryFrom<(SMBSocketConnectio signing_algorithm_id: SigningAlgorithm::HmacSha256, accept_transport_security: false, underlying_stream: Arc::new(Mutex::new(value.0)), - server: value.1 + server: value.1, }) } -} \ No newline at end of file +} diff --git a/smb/src/server/lease.rs b/smb/src/server/lease.rs index c06538d..43a7b4a 100644 --- a/smb/src/server/lease.rs +++ b/smb/src/server/lease.rs @@ -4,16 +4,15 @@ use std::fmt::{Debug, Formatter}; use bitflags::bitflags; use uuid::Uuid; -use crate::server::open::SMBOpen; use crate::server::Server; +use crate::server::open::SMBOpen; pub trait Lease: Send + Sync {} - #[derive(Debug)] pub struct SMBLeaseTable { client_guid: Uuid, - lease_list: HashMap + lease_list: HashMap, } pub struct SMBLease { @@ -30,10 +29,16 @@ pub struct SMBLease { file_delete_on_close: bool, epoch: u64, parent_lease_key: u128, - version: u8 + version: u8, } -impl Debug for SMBLease where S::Handle: Debug, S: Debug, S::Session: Debug, S::Share: Debug { +impl Debug for SMBLease +where + S::Handle: Debug, + S: Debug, + S::Session: Debug, + S::Share: Debug, +{ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("SMBLease") .field("lease_key", &self.lease_key) @@ -80,4 +85,4 @@ bitflags! { pub struct SMBLeaseBreakNotificationFlags: u32 { const NOTIFY_BREAK_LEASE_FLAG_ACK_REQUIRED = 0x01; } -} \ No newline at end of file +} diff --git a/smb/src/server/message_handler.rs b/smb/src/server/message_handler.rs index 50f0a2b..e73bbb0 100644 --- a/smb/src/server/message_handler.rs +++ b/smb/src/server/message_handler.rs @@ -1,10 +1,11 @@ use std::future::Future; +use smb_core::SMBResult; use smb_core::error::SMBError; -use smb_core::logging::{trace, debug}; +use smb_core::logging::{debug, trace}; use smb_core::nt_status::NTStatus; -use smb_core::SMBResult; +use crate::protocol::body::SMBBody; use crate::protocol::body::cancel::SMBCancelRequest; use crate::protocol::body::change_notify::SMBChangeNotifyRequest; use crate::protocol::body::close::SMBCloseRequest; @@ -21,7 +22,6 @@ use crate::protocol::body::query_info::SMBQueryInfoRequest; use crate::protocol::body::read::SMBReadRequest; use crate::protocol::body::session_setup::SMBSessionSetupRequest; use crate::protocol::body::set_info::SMBSetInfoRequest; -use crate::protocol::body::SMBBody; use crate::protocol::body::tree_connect::SMBTreeConnectRequest; use crate::protocol::body::tree_disconnect::SMBTreeDisconnectRequest; use crate::protocol::body::write::SMBWriteRequest; @@ -48,16 +48,25 @@ impl SMBHandlerState { pub trait SMBLockedMessageHandlerBase { type Inner; - fn inner(&self, message: &SMBMessageType) -> impl Future>; - fn handle_message_inner(&mut self, message: &SMBMessageType) -> impl Future>> { + fn inner(&self, message: &SMBMessageType) -> impl Future>; + fn handle_message_inner( + &mut self, + message: &SMBMessageType, + ) -> impl Future>> { debug!(command = ?message.header.command, mid = message.header.message_id, "dispatching to handler"); async { match &message.body { SMBBody::NegotiateRequest(req) => self.handle_negotiate(&message.header, req).await, - SMBBody::SessionSetupRequest(req) => self.handle_session_setup(&message.header, req).await, + SMBBody::SessionSetupRequest(req) => { + self.handle_session_setup(&message.header, req).await + } SMBBody::LogoffRequest(req) => self.handle_logoff(&message.header, req).await, - SMBBody::TreeConnectRequest(req) => self.handle_tree_connect(&message.header, req).await, - SMBBody::TreeDisconnectRequest(req) => self.handle_tree_disconnect(&message.header, req).await, + SMBBody::TreeConnectRequest(req) => { + self.handle_tree_connect(&message.header, req).await + } + SMBBody::TreeDisconnectRequest(req) => { + self.handle_tree_disconnect(&message.header, req).await + } SMBBody::CreateRequest(req) => self.handle_create(&message.header, req).await, SMBBody::CloseRequest(req) => self.handle_close(&message.header, req).await, SMBBody::FlushRequest(req) => self.handle_flush(&message.header, req).await, @@ -67,109 +76,196 @@ pub trait SMBLockedMessageHandlerBase { SMBBody::IoCtlRequest(req) => self.handle_ioctl(&message.header, req).await, SMBBody::CancelRequest(req) => self.handle_cancel(&message.header, req).await, SMBBody::EchoRequest(req) => self.handle_echo(&message.header, req).await, - SMBBody::QueryDirectoryRequest(req) => self.handle_query_directory(&message.header, req).await, - SMBBody::ChangeNotifyRequest(req) => self.handle_change_notify(&message.header, req).await, - SMBBody::QueryInfoRequest(req) => self.handle_query_info(&message.header, req).await, + SMBBody::QueryDirectoryRequest(req) => { + self.handle_query_directory(&message.header, req).await + } + SMBBody::ChangeNotifyRequest(req) => { + self.handle_change_notify(&message.header, req).await + } + SMBBody::QueryInfoRequest(req) => { + self.handle_query_info(&message.header, req).await + } SMBBody::SetInfoRequest(req) => self.handle_set_info(&message.header, req).await, - SMBBody::OplockBreakAcknowledgement(req) => self.handle_oplock_break(&message.header, req).await, + SMBBody::OplockBreakAcknowledgement(req) => { + self.handle_oplock_break(&message.header, req).await + } SMBBody::ErrorResponse(_) => { let status = NTStatus::try_from(message.header.channel_sequence) .unwrap_or(NTStatus::NotSupported); Err(SMBError::response_error(status)) - }, + } _ => Err(SMBError::server_error("Command not implemented")), } } } - fn handle_negotiate(&mut self, _header: &SMBSyncHeader, _message: &SMBNegotiateRequest) -> impl Future>> { + fn handle_negotiate( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBNegotiateRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_session_setup(&mut self, _header: &SMBSyncHeader, _message: &SMBSessionSetupRequest) -> impl Future>> { + fn handle_session_setup( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBSessionSetupRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_logoff(&mut self, _header: &SMBSyncHeader, _message: &SMBLogoffRequest) -> impl Future>> { + fn handle_logoff( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBLogoffRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_tree_connect(&mut self, _header: &SMBSyncHeader, _message: &SMBTreeConnectRequest) -> impl Future>> { + fn handle_tree_connect( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBTreeConnectRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_tree_disconnect(&mut self, _header: &SMBSyncHeader, _message: &SMBTreeDisconnectRequest) -> impl Future>> { + fn handle_tree_disconnect( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBTreeDisconnectRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_create(&mut self, _header: &SMBSyncHeader, _message: &SMBCreateRequest) -> impl Future>> { + fn handle_create( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBCreateRequest, + ) -> impl Future>> { debug!("create request, passing to next handler"); async { Ok(SMBHandlerState::Next(None)) } } - fn handle_close(&mut self, _header: &SMBSyncHeader, _message: &SMBCloseRequest) -> impl Future>> { + fn handle_close( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBCloseRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_flush(&mut self, _header: &SMBSyncHeader, _message: &SMBFlushRequest) -> impl Future>> { + fn handle_flush( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBFlushRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_read(&mut self, _header: &SMBSyncHeader, _message: &SMBReadRequest) -> impl Future>> { + fn handle_read( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBReadRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_write(&mut self, _header: &SMBSyncHeader, _message: &SMBWriteRequest) -> impl Future>> { + fn handle_write( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBWriteRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_lock(&mut self, _header: &SMBSyncHeader, _message: &SMBLockRequest) -> impl Future>> { + fn handle_lock( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBLockRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_ioctl(&mut self, _header: &SMBSyncHeader, _message: &SMBIoCtlRequest) -> impl Future>> { + fn handle_ioctl( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBIoCtlRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_cancel(&mut self, _header: &SMBSyncHeader, _message: &SMBCancelRequest) -> impl Future>> { + fn handle_cancel( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBCancelRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_echo(&mut self, _header: &SMBSyncHeader, _message: &SMBEchoRequest) -> impl Future>> { + fn handle_echo( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBEchoRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_query_directory(&mut self, _header: &SMBSyncHeader, _message: &SMBQueryDirectoryRequest) -> impl Future>> { + fn handle_query_directory( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBQueryDirectoryRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_change_notify(&mut self, _header: &SMBSyncHeader, _message: &SMBChangeNotifyRequest) -> impl Future>> { + fn handle_change_notify( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBChangeNotifyRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_query_info(&mut self, _header: &SMBSyncHeader, _message: &SMBQueryInfoRequest) -> impl Future>> { + fn handle_query_info( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBQueryInfoRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_set_info(&mut self, _header: &SMBSyncHeader, _message: &SMBSetInfoRequest) -> impl Future>> { + fn handle_set_info( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBSetInfoRequest, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } - fn handle_oplock_break(&mut self, _header: &SMBSyncHeader, _message: &SMBOplockBreakAcknowledgement) -> impl Future>> { + fn handle_oplock_break( + &mut self, + _header: &SMBSyncHeader, + _message: &SMBOplockBreakAcknowledgement, + ) -> impl Future>> { async { Ok(SMBHandlerState::Next(None)) } } } pub trait SMBLockedMessageHandler: SMBLockedMessageHandlerBase { - fn handle_message(&mut self, message: &SMBMessageType) -> impl Future> { - async { - self.handle_message_inner(message).await? - .get_message() - } + fn handle_message( + &mut self, + message: &SMBMessageType, + ) -> impl Future> { + async { self.handle_message_inner(message).await?.get_message() } } } -impl SMBLockedMessageHandler for H where H::Inner: SMBLockedMessageHandler { +impl SMBLockedMessageHandler for H +where + H::Inner: SMBLockedMessageHandler, +{ async fn handle_message(&mut self, message: &SMBMessageType) -> SMBResult { debug!(command = ?message.header.command, mid = message.header.message_id, "handling message in chain"); let state = self.handle_message_inner(message).await?; @@ -178,20 +274,21 @@ impl SMBLockedMessageHandler SMBHandlerState::Finished(msg) => { trace!("handler produced final response"); Ok(msg) - }, + } SMBHandlerState::Next(Some(mut handler)) => { trace!("delegating to explicit next handler"); handler.handle_message(message).await - }, + } SMBHandlerState::Next(None) => { trace!("delegating to inner handler"); - self.inner(message).await + self.inner(message) + .await .ok_or(SMBError::server_error("Invalid handler defined"))? .handle_message(message) .await - }, + } } } } -pub trait NonEndingHandler {} \ No newline at end of file +pub trait NonEndingHandler {} diff --git a/smb/src/server/mod.rs b/smb/src/server/mod.rs index 0dd240a..18a3d08 100644 --- a/smb/src/server/mod.rs +++ b/smb/src/server/mod.rs @@ -1,18 +1,18 @@ -use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::collections::hash_map::Entry; use std::fmt::Debug; use std::future::Future; use std::sync::Arc; use derive_builder::Builder; use tokio::net::TcpListener; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{Mutex, RwLock, mpsc}; use tokio_stream::StreamExt; use uuid::Uuid; -use smb_core::error::SMBError; -use smb_core::logging::{info, debug, warn}; use smb_core::SMBResult; +use smb_core::error::SMBError; +use smb_core::logging::{debug, info, warn}; use crate::protocol::body::dialect::SMBDialect; use crate::protocol::body::filetime::FileTime; @@ -21,38 +21,42 @@ use crate::server::connection::{Connection, SMBConnection, WeakLockedSMBConnecti use crate::server::lease::{Lease, SMBLease, SMBLeaseTable}; use crate::server::open::{LockedSMBOpen, Open, SMBOpen}; use crate::server::safe_locked_getter::InnerGetter; -use crate::server::session::{LockedSMBSession, Session, SMBSession}; -use crate::server::share::{ConnectAllowed, FilePerms, ResourceHandle, SharedResource}; +use crate::server::session::{LockedSMBSession, SMBSession, Session}; use crate::server::share::file_system::{SMBFileSystemHandle, SMBFileSystemShare}; use crate::server::share::ipc::{SMBIPCHandle, SMBIPCShare}; +use crate::server::share::{ConnectAllowed, FilePerms, ResourceHandle, SharedResource}; use crate::socket::listener::{SMBListener, SMBSocket}; -use crate::util::auth::{AuthContext, AuthProvider}; use crate::util::auth::ntlm::NTLMAuthProvider; +use crate::util::auth::{AuthContext, AuthProvider}; -pub mod client; pub mod channel; +pub mod client; pub mod connection; pub mod lease; +mod message_handler; pub mod open; pub mod preauth_session; pub mod request; +mod safe_locked_getter; pub mod session; pub mod share; pub mod tree_connect; -mod message_handler; -mod safe_locked_getter; pub trait Server: Send + Sync { - type Connection: Connection + InnerGetter; - type Session: Session + InnerGetter; - type Share: SharedResource::Context as AuthContext>::UserName, Handle=Self::Handle>; - type Open: Open; + type Connection: Connection + InnerGetter; + type Session: Session + + InnerGetter; + type Share: SharedResource< + UserName = <::Context as AuthContext>::UserName, + Handle = Self::Handle, + >; + type Open: Open; type Lease: Lease; type AuthProvider: AuthProvider; type Handle: ResourceHandle; fn shares(&self) -> &HashMap>; fn opens(&self) -> &HashMap>>; - fn add_open(&mut self, open: Arc>) -> impl Future; + fn add_open(&mut self, open: Arc>) -> impl Future; fn sessions(&self) -> &HashMap>>; fn sessions_mut(&mut self) -> &mut HashMap>>; fn guid(&self) -> Uuid; @@ -78,17 +82,32 @@ pub trait Server: Send + Sync { } pub trait StartSMBServer { - fn start(&self) -> impl Future> + Send; + fn start(&self) -> impl Future> + Send; } -type SMBConnectionType = SMBConnection<>::ReadStream, >::WriteStream, SMBServer>; +type SMBConnectionType = SMBConnection< + >::ReadStream, + >::WriteStream, + SMBServer, +>; type UserName = <::Context as AuthContext>::UserName; -pub type DefaultShare = Box::Context as AuthContext>::UserName, Handle=DefaultHandle>>; +pub type DefaultShare = Box< + dyn SharedResource< + UserName = <::Context as AuthContext>::UserName, + Handle = DefaultHandle, + >, +>; type DefaultHandle = Box; #[derive(Debug, Builder)] #[builder(pattern = "owned")] #[builder(build_fn(name = "build_inner", private))] -pub struct SMBServer = TcpListener, Auth: AuthProvider = NTLMAuthProvider, Share: SharedResource, Handle=Handle> = DefaultShare, Handle: ResourceHandle = DefaultHandle> { +pub struct SMBServer< + Addrs: Send + Sync, + Listener: SMBSocket = TcpListener, + Auth: AuthProvider = NTLMAuthProvider, + Share: SharedResource, Handle = Handle> = DefaultShare, + Handle: ResourceHandle = DefaultHandle, +> { #[builder(default = "Default::default()")] statistics: Arc>, #[builder(default = "false")] @@ -106,7 +125,14 @@ pub struct SMBServer = TcpListene #[builder(field( type = "HashMap>::ReadStream, >::WriteStream, SMBServer>>" ))] - connection_list: HashMap>::ReadStream, >::WriteStream, SMBServer>>, + connection_list: HashMap< + String, + WeakLockedSMBConnection< + >::ReadStream, + >::WriteStream, + SMBServer, + >, + >, #[builder(default = "Uuid::new_v4()")] guid: Uuid, #[builder(default = "FileTime::default()")] @@ -124,7 +150,8 @@ pub struct SMBServer = TcpListene #[builder(field( type = "HashMap>>>" ))] - lease_table_list: HashMap>>>, + lease_table_list: + HashMap>>>, #[builder(default = "5000")] max_resiliency_timeout: u64, #[builder(default = "5000")] @@ -164,14 +191,21 @@ pub struct SMBServer = TcpListene auth_provider: Arc, } -impl, Auth: AuthProvider, Share: SharedResource, Handle=Handle>, Handle: ResourceHandle> Server for SMBServer { +impl< + Addrs: Send + Sync, + Listener: SMBSocket, + Auth: AuthProvider, + Share: SharedResource, Handle = Handle>, + Handle: ResourceHandle, +> Server for SMBServer +{ type Connection = SMBConnectionType; type Session = SMBSession; type Share = Share; type Open = SMBOpen; type Lease = SMBLease; type AuthProvider = Auth; - type Handle = Handle; + type Handle = Handle; fn shares(&self) -> &HashMap> { &self.share_list @@ -283,7 +317,14 @@ impl, Auth: AuthProvider, Share: } } -impl, Auth: AuthProvider, Share: SharedResource, Handle=Handle>, Handle: ResourceHandle> SMBServerBuilder { +impl< + Addrs: Send + Sync, + Listener: SMBSocket, + Auth: AuthProvider, + Share: SharedResource, Handle = Handle>, + Handle: ResourceHandle, +> SMBServerBuilder +{ #[cfg(not(feature = "async"))] pub fn listener_address(self, addr: Addrs) -> SMBResult { Ok(self.local_listener(SMBListener::new(addr)?)) @@ -318,7 +359,14 @@ pub enum HashLevel { EnableShare, } -impl, Auth: AuthProvider + 'static, Share: SharedResource, Handle=Handle>, Handle: ResourceHandle + 'static> SMBServer { +impl< + Addrs: Send + Sync, + Listener: SMBSocket, + Auth: AuthProvider + 'static, + Share: SharedResource, Handle = Handle>, + Handle: ResourceHandle + 'static, +> SMBServer +{ pub fn initialize(&mut self) { self.statistics = Default::default(); self.guid = Uuid::new_v4(); @@ -338,10 +386,18 @@ impl< Addrs: Send + Sync, Listener: SMBSocket, Auth: AuthProvider + 'static, - Share: SharedResource, Handle=Handle> + From, Handle>>, - Handle: ResourceHandle + 'static + From + TryInto -> SMBServerBuilder { - pub fn add_fs_share(self, name: String, path: String, connect_allowed: ConnectAllowed>, file_perms: FilePerms>) -> Self { + Share: SharedResource, Handle = Handle> + + From, Handle>>, + Handle: ResourceHandle + 'static + From + TryInto, +> SMBServerBuilder +{ + pub fn add_fs_share( + self, + name: String, + path: String, + connect_allowed: ConnectAllowed>, + file_perms: FilePerms>, + ) -> Self { let share = SMBFileSystemShare::path(name.clone(), path, connect_allowed, file_perms); self.add_share(name, share.into()) } @@ -351,29 +407,34 @@ impl< Addrs: Send + Sync, Listener: SMBSocket, Auth: AuthProvider + 'static, - Share: SharedResource, Handle=Handle> + From, Handle>>, - Handle: ResourceHandle + 'static + From -> SMBServerBuilder { + Share: SharedResource, Handle = Handle> + + From, Handle>>, + Handle: ResourceHandle + 'static + From, +> SMBServerBuilder +{ pub fn add_ipc_share(self) -> Self { let share: SMBIPCShare, Handle> = SMBIPCShare::new(); self.add_share("ipc$", share.into()) } } -impl + 'static, Auth: AuthProvider + 'static, Share: SharedResource, Handle=Handle> + 'static, Handle: ResourceHandle + 'static> StartSMBServer for Arc>> { +impl< + Addrs: Send + Sync + 'static, + Listener: SMBSocket + 'static, + Auth: AuthProvider + 'static, + Share: SharedResource, Handle = Handle> + 'static, + Handle: ResourceHandle + 'static, +> StartSMBServer for Arc>> +{ async fn start(&self) -> SMBResult<()> { let (rx, mut tx) = mpsc::channel(10); - let diagnostics = { - self.read().await.statistics.clone() - }; + let diagnostics = { self.read().await.statistics.clone() }; tokio::spawn(async move { while let Some(update) = tx.recv().await { diagnostics.write().await.update(update); } }); - let listener = { - self.read().await.local_listener.clone() - }; + let listener = { self.read().await.local_listener.clone() }; info!("SMB server accepting connections"); while let Some(connection) = listener.lock().await.connections().next().await { let smb_connection = SMBConnection::try_from((connection, Arc::downgrade(self)))?; @@ -382,13 +443,22 @@ impl + 'static, Auth: A let socket = smb_connection.underlying_socket(); let wrapped_connection = Arc::new(RwLock::new(smb_connection)); { - self.write().await.connection_list.insert(name.clone(), Arc::downgrade(&wrapped_connection)); + self.write() + .await + .connection_list + .insert(name.clone(), Arc::downgrade(&wrapped_connection)); } let update_channel = rx.clone(); tokio::spawn(async move { debug!(client = %name, "starting message handler"); let mut stream = socket.lock().await; - match SMBConnection::start_message_handler::(&mut stream, wrapped_connection, update_channel).await { + match SMBConnection::start_message_handler::( + &mut stream, + wrapped_connection, + update_channel, + ) + .await + { Ok(()) => debug!("message handler completed"), Err(_e) => warn!(?e, "message handler exited with error"), } @@ -479,4 +549,4 @@ impl SMBServerDiagnostics { self.big_buffer_need += big_buffer_need; } } -} \ No newline at end of file +} diff --git a/smb/src/server/open.rs b/smb/src/server/open.rs index c75e766..4b32141 100644 --- a/smb/src/server/open.rs +++ b/smb/src/server/open.rs @@ -6,15 +6,17 @@ use uuid::Uuid; use smb_core::SMBResult; +use crate::protocol::body::create::SMBCreateRequest; use crate::protocol::body::create::file_attributes::SMBFileAttributes; use crate::protocol::body::create::file_id::SMBFileId; use crate::protocol::body::create::oplock::SMBOplockLevel; use crate::protocol::body::create::options::SMBCreateOptions; -use crate::protocol::body::create::SMBCreateRequest; use crate::protocol::body::tree_connect::access_mask::SMBAccessMask; -use crate::server::lease::SMBLease; -use crate::server::message_handler::{SMBLockedMessageHandler, SMBLockedMessageHandlerBase, SMBMessageType}; use crate::server::Server; +use crate::server::lease::SMBLease; +use crate::server::message_handler::{ + SMBLockedMessageHandler, SMBLockedMessageHandlerBase, SMBMessageType, +}; use crate::server::share::{ResourceHandle, SMBFileMetadata}; use crate::server::tree_connect::SMBTreeConnect; @@ -157,7 +159,7 @@ struct FileAttributes; pub enum SMBOplockState { Held, Breaking, - None + None, } #[derive(Debug)] @@ -166,7 +168,13 @@ pub struct LockSequence { valid: bool, } -impl Debug for SMBOpen where S: Debug, S::Session: Debug, S::Handle: Debug, S::Share: Debug { +impl Debug for SMBOpen +where + S: Debug, + S::Session: Debug, + S::Handle: Debug, + S::Share: Debug, +{ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("SMBOpen") .field("file_share_id", &self.file_share_id) @@ -180,7 +188,10 @@ impl Debug for SMBOpen where S: Debug, S::Session: Debug, S::Handl .field("oplock_timeout", &self.oplock_timeout) .field("is_durable", &self.is_durable) .field("durable_open_timeout", &self.durable_open_timeout) - .field("durable_open_scavenger_timeout", &self.durable_open_scavenger_timeout) + .field( + "durable_open_scavenger_timeout", + &self.durable_open_scavenger_timeout, + ) .field("durable_owner", &self.durable_owner) .field("underlying", &self.underlying) .field("current_ea_index", &self.current_ea_index) @@ -202,10 +213,19 @@ impl Debug for SMBOpen where S: Debug, S::Session: Debug, S::Handl .field("is_persistent", &self.is_persistent) .field("channel_sequence", &self.channel_sequence) .field("outstanding_request_count", &self.outstanding_request_count) - .field("outstanding_pre_request_count", &self.outstanding_pre_request_count) + .field( + "outstanding_pre_request_count", + &self.outstanding_pre_request_count, + ) .field("is_shared_vhdx", &self.is_shared_vhdx) - .field("application_instance_version_high", &self.application_instance_version_high) - .field("application_instance_version_low", &self.application_instance_version_low) + .field( + "application_instance_version_high", + &self.application_instance_version_high, + ) + .field( + "application_instance_version_low", + &self.application_instance_version_low, + ) .finish() } } diff --git a/smb/src/server/preauth_session.rs b/smb/src/server/preauth_session.rs index c9f1aea..0ebaadc 100644 --- a/smb/src/server/preauth_session.rs +++ b/smb/src/server/preauth_session.rs @@ -1,7 +1,7 @@ #[derive(Debug, Clone)] pub struct SMBPreauthSession { session_id: u64, - preauth_integrity_hash_value: Vec + preauth_integrity_hash_value: Vec, } impl SMBPreauthSession { @@ -11,4 +11,4 @@ impl SMBPreauthSession { preauth_integrity_hash_value, } } -} \ No newline at end of file +} diff --git a/smb/src/server/request.rs b/smb/src/server/request.rs index f73a70d..13ef3d3 100644 --- a/smb/src/server/request.rs +++ b/smb/src/server/request.rs @@ -1,6 +1,5 @@ - -use crate::server::open::SMBOpen; use crate::server::Server; +use crate::server::open::SMBOpen; pub trait Request: Send + Sync {} @@ -11,5 +10,5 @@ pub struct SMBRequest { open: SMBOpen, is_encrypted: bool, transform_session_id: u64, - compress_reply: bool -} \ No newline at end of file + compress_reply: bool, +} diff --git a/smb/src/server/safe_locked_getter.rs b/smb/src/server/safe_locked_getter.rs index 531294c..b653f19 100644 --- a/smb/src/server/safe_locked_getter.rs +++ b/smb/src/server/safe_locked_getter.rs @@ -3,12 +3,12 @@ use std::sync::Arc; use tokio::sync::RwLock; -use smb_core::error::SMBError; use smb_core::SMBResult; +use smb_core::error::SMBError; pub trait SafeLockedGetter { type Upper; - fn upper(&self) -> impl Future>>>; + fn upper(&self) -> impl Future>>>; } pub trait InnerGetter { @@ -20,7 +20,9 @@ impl SafeLockedGetter for Arc> { type Upper = Inner::Upper; async fn upper(&self) -> SMBResult>> { - self.read().await.upper() + self.read() + .await + .upper() .ok_or(SMBError::server_error("No server available")) } -} \ No newline at end of file +} diff --git a/smb/src/server/session.rs b/smb/src/server/session.rs index 719d90c..5bf7865 100644 --- a/smb/src/server/session.rs +++ b/smb/src/server/session.rs @@ -11,28 +11,30 @@ use hmac::Hmac; use sha2::Sha256; use tokio::sync::RwLock; +use crate::protocol::body::create::file_id::SMBFileId; +use smb_core::SMBResult; use smb_core::error::SMBError; -use smb_core::logging::{trace, debug, info}; +use smb_core::logging::{debug, info, trace}; use smb_core::nt_status::NTStatus; -use smb_core::SMBResult; -use crate::protocol::body::create::file_id::SMBFileId; +use crate::protocol::body::SMBBody; use crate::protocol::body::dialect::SMBDialect; use crate::protocol::body::negotiate::context::EncryptionCipher; use crate::protocol::body::negotiate::context::EncryptionCipher::AES256CCM; use crate::protocol::body::session_setup::{SMBSessionSetupRequest, SMBSessionSetupResponse}; -use crate::protocol::body::SMBBody; use crate::protocol::body::tree_connect::{SMBTreeConnectRequest, SMBTreeConnectResponse}; use crate::protocol::header::SMBSyncHeader; use crate::protocol::message::SMBMessage; +use crate::server::Server; use crate::server::connection::Connection; -use crate::server::message_handler::{NonEndingHandler, SMBHandlerState, SMBLockedMessageHandlerBase}; +use crate::server::message_handler::{ + NonEndingHandler, SMBHandlerState, SMBLockedMessageHandlerBase, +}; use crate::server::open::Open; use crate::server::safe_locked_getter::InnerGetter; -use crate::server::Server; use crate::server::tree_connect::SMBTreeConnect; -use crate::util::auth::{AuthContext, AuthProvider}; use crate::util::auth::spnego::{SPNEGOToken, SPNEGOTokenResponseBody}; +use crate::util::auth::{AuthContext, AuthProvider}; use crate::util::crypto::sp800_108::derive_key; use crate::util::num_limits::{MaxVal, MinVal, One, Zero}; @@ -42,9 +44,14 @@ type SMBMessageType = SMBMessage; const _OUTPUT_SIZE_128: usize = 128; const _OUTPUT_SIZE_256: usize = 256; - pub trait Session: Send + Sync { - fn init(id: u64, encrypt_data: bool, preauth_integrity_hash_value: Vec, conn: Weak>, provider: Arc) -> Self; + fn init( + id: u64, + encrypt_data: bool, + preauth_integrity_hash_value: Vec, + conn: Weak>, + provider: Arc, + ) -> Self; fn id(&self) -> u64; fn connection(&self) -> Weak>; fn connection_res(&self) -> SMBResult>>; @@ -56,7 +63,7 @@ pub trait Session: Send + Sync { fn provider(&self) -> &Arc; fn encrypt_data(&self) -> bool; fn open_table(&self) -> &HashMap>>; - fn add_open(&mut self, open: Arc>) -> impl Future; + fn add_open(&mut self, open: Arc>) -> impl Future; fn set_previous_file_id(&mut self, file_id: SMBFileId); fn signing_key(&self) -> &[u8]; } @@ -89,7 +96,7 @@ pub struct SMBSession { application_key: Vec, preauth_integrity_hash_value: Vec, full_session_key: Vec, - prev_command_id: Option + prev_command_id: Option, } // impl InnerGetter for SMBSession { @@ -102,13 +109,14 @@ pub struct SMBSession { impl Debug for SMBSession { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "SMBSession {{}}", ) + write!(f, "SMBSession {{}}",) } } impl SMBSession { fn get_connection(&self) -> SMBResult>> { - self.connection.upgrade() + self.connection + .upgrade() .ok_or(SMBError::server_error("Connection not found for session")) } async fn handle_successful_setup(&mut self, session_key: Vec) -> SMBResult<()> { @@ -133,22 +141,24 @@ impl SMBSession { debug!(?dialect, ?cipher_id, "generating session keys"); let key_length = match cipher_id { EncryptionCipher::AES256GCM | AES256CCM => 32, - _ => 16 + _ => 16, }; let signing_key_len = 16; - let smb_sign_bytes = [ - "SmbSign".as_bytes(), - &[0], - ].concat(); + let smb_sign_bytes = ["SmbSign".as_bytes(), &[0]].concat(); let (signing_key_label, signing_key_context): (&str, &[u8]) = match dialect { SMBDialect::V3_1_1 => ("SMBSigningKey", &self.preauth_integrity_hash_value), _ => ("SMB2AESCMAC", &smb_sign_bytes), }; self.signing_key = match dialect { - SMBDialect::V3_0_0 | SMBDialect::V3_0_2 | SMBDialect::V3_1_1 => generate_key(&self.session_key, signing_key_label, signing_key_context, signing_key_len), + SMBDialect::V3_0_0 | SMBDialect::V3_0_2 | SMBDialect::V3_1_1 => generate_key( + &self.session_key, + signing_key_label, + signing_key_context, + signing_key_len, + ), _ => self.session_key.clone().to_vec(), }; @@ -156,18 +166,32 @@ impl SMBSession { let (application_key_label, application_key_context): (&str, &[u8]) = match dialect { SMBDialect::V3_1_1 => ("SMBAppKey", &self.preauth_integrity_hash_value), - _ => ("SMB2APP", "SmbRpc".as_bytes()) + _ => ("SMB2APP", "SmbRpc".as_bytes()), }; - self.application_key = generate_key(&self.session_key, application_key_label, application_key_context, key_length); - trace!(signing_key_len = self.signing_key.len(), application_key_len = self.application_key.len(), "key generation complete"); + self.application_key = generate_key( + &self.session_key, + application_key_label, + application_key_context, + key_length, + ); + trace!( + signing_key_len = self.signing_key.len(), + application_key_len = self.application_key.len(), + "key generation complete" + ); } - fn get_next_map_id(map: &HashMap) -> K { + fn get_next_map_id< + K: MaxVal + MinVal + One + Zero + Eq + PartialEq + AddAssign + PartialOrd + std::hash::Hash, + V, + >( + map: &HashMap, + ) -> K { let mut i = K::min_val(); while i < K::max_val() { if map.get(&i).is_none() { return i; } - + i += K::one(); } K::zero() @@ -175,28 +199,40 @@ impl SMBSession { } fn generate_key(secure_key: &[u8], label: &str, context: &[u8], output_len: usize) -> Vec { - trace!(key_len = secure_key.len(), label, context_len = context.len(), output_len, "deriving key via KDF"); + trace!( + key_len = secure_key.len(), + label, + context_len = context.len(), + output_len, + "deriving key via KDF" + ); let mac = >::new_from_slice(secure_key) - .map_err(|_| SMBError::crypto_error("Invalid Key Length")).unwrap(); - let label_bytes = [ - label.as_bytes(), - &[0] - ].concat(); + .map_err(|_| SMBError::crypto_error("Invalid Key Length")) + .unwrap(); + let label_bytes = [label.as_bytes(), &[0]].concat(); derive_key(mac, &label_bytes, context, (output_len * 8) as u32) } -impl>> NonEndingHandler for Arc>> {} +impl>> NonEndingHandler for Arc>> {} -impl>> SMBLockedMessageHandlerBase for Arc>> { +impl>> SMBLockedMessageHandlerBase + for Arc>> +{ type Inner = Arc>; async fn inner(&self, message: &SMBMessageType) -> Option { let write = self.write().await; debug!(tid = message.header.tree_id, command = ?message.header.command, "looking up tree connect"); - write.tree_connect_table.get(&message.header.tree_id) + write + .tree_connect_table + .get(&message.header.tree_id) .map(Arc::clone) } - async fn handle_session_setup(&mut self, header: &SMBSyncHeader, request: &SMBSessionSetupRequest) -> SMBResult> { + async fn handle_session_setup( + &mut self, + header: &SMBSyncHeader, + request: &SMBSessionSetupRequest, + ) -> SMBResult> { let buffer = request.buffer(); let (_, token) = SPNEGOToken::::parse(buffer)?; let mut session_write = self.write().await; @@ -213,7 +249,10 @@ impl>> SMBLockedMessageHandlerBase for Arc::new(status, msg); let (id, session_setup) = { let session_read = self.read().await; - let resp = SMBSessionSetupResponse::from_session_state::(&session_read, response.as_bytes()); + let resp = SMBSessionSetupResponse::from_session_state::( + &session_read, + response.as_bytes(), + ); (session_read.id(), resp) }; let header = header.create_response_header(status as u32, id, 0); @@ -221,7 +260,11 @@ impl>> SMBLockedMessageHandlerBase for Arc SMBResult> { + async fn handle_tree_connect( + &mut self, + header: &SMBSyncHeader, + request: &SMBTreeConnectRequest, + ) -> SMBResult> { let self_rd = self.read().await; let conn = self_rd.get_connection()?; drop(self_rd); @@ -229,22 +272,29 @@ impl>> SMBLockedMessageHandlerBase for Arc::get_next_map_id(&self_rd.tree_connect_table); - let tree_connect = SMBTreeConnect::init(tree_id, Arc::downgrade(self), share.clone(), response.access_mask().clone()); + let tree_connect = SMBTreeConnect::init( + tree_id, + Arc::downgrade(self), + share.clone(), + response.access_mask().clone(), + ); let header = SMBSyncHeader::create_response_header(header, 0, self_rd.id(), 1); drop(self_rd); let mut self_wr = self.write().await; - self_wr.tree_connect_table.insert(tree_id, Arc::new(tree_connect)); + self_wr + .tree_connect_table + .insert(tree_id, Arc::new(tree_connect)); let message = SMBMessage::new(header, SMBBody::TreeConnectResponse(response)); Ok(SMBHandlerState::Finished(message)) } @@ -262,12 +312,17 @@ impl InnerGetter for SMBSession { pub enum SessionState { InProgress, Valid, - Expired + Expired, } -impl> Session for SMBSession { - fn init(id: u64, encrypt_data: bool, preauth_integrity_hash_value: Vec, conn: Weak>, provider: Arc) -> Self { - +impl> Session for SMBSession { + fn init( + id: u64, + encrypt_data: bool, + preauth_integrity_hash_value: Vec, + conn: Weak>, + provider: Arc, + ) -> Self { Self { session_id: id, state: SessionState::InProgress, @@ -305,7 +360,8 @@ impl> Session f } fn connection_res(&self) -> SMBResult>> { - self.connection.upgrade() + self.connection + .upgrade() .ok_or(SMBError::server_error("Connection not found for session")) } @@ -350,7 +406,7 @@ impl> Session f } fn set_previous_file_id(&mut self, file_id: SMBFileId) { - self.prev_command_id = Some(file_id); + self.prev_command_id = Some(file_id); } fn signing_key(&self) -> &[u8] { @@ -373,7 +429,11 @@ mod tests { fn generate_key_is_not_raw_session_key() { let session_key = [0xBB; 16]; let key = generate_key(&session_key, "SMB2AESCMAC", b"SmbSign\0", 16); - assert_ne!(key, session_key.to_vec(), "KDF output must differ from raw session key"); + assert_ne!( + key, + session_key.to_vec(), + "KDF output must differ from raw session key" + ); } /// Regression test for B3: V3_0_2 must take the KDF branch, not the raw @@ -392,8 +452,15 @@ mod tests { let key_v300 = generate_key(&session_key, label, context, 16); // Both 3.0 and 3.0.2 use the same KDF path, so keys must match - assert_eq!(key_v302, key_v300, "V3_0_2 and V3_0_0 should derive identical signing keys"); + assert_eq!( + key_v302, key_v300, + "V3_0_2 and V3_0_0 should derive identical signing keys" + ); // And neither should be the raw session key - assert_ne!(key_v302, session_key.to_vec(), "V3_0_2 signing key must not be the raw session key"); + assert_ne!( + key_v302, + session_key.to_vec(), + "V3_0_2 signing key must not be the raw session key" + ); } -} \ No newline at end of file +} diff --git a/smb/src/server/share/file_system.rs b/smb/src/server/share/file_system.rs index 9441b49..ab0a9a1 100644 --- a/smb/src/server/share/file_system.rs +++ b/smb/src/server/share/file_system.rs @@ -5,15 +5,17 @@ use std::fs::{File, OpenOptions, ReadDir}; use std::marker::PhantomData; use std::time::{SystemTime, UNIX_EPOCH}; +use smb_core::SMBResult; use smb_core::error::SMBError; use smb_core::logging::debug; -use smb_core::SMBResult; use crate::protocol::body::create::disposition::SMBCreateDisposition; use crate::protocol::body::filetime::FileTime; use crate::protocol::body::tree_connect::access_mask::SMBAccessMask; use crate::protocol::body::tree_connect::flags::SMBShareFlags; -use crate::server::share::{ConnectAllowed, FilePerms, ResourceHandle, ResourceType, SharedResource, SMBFileMetadata}; +use crate::server::share::{ + ConnectAllowed, FilePerms, ResourceHandle, ResourceType, SMBFileMetadata, SharedResource, +}; #[derive(Debug)] pub struct SMBFileSystemHandle { @@ -24,7 +26,7 @@ pub struct SMBFileSystemHandle { #[derive(Debug)] pub enum SMBFileSystemResourceHandle { File(File), - Directory(ReadDir) + Directory(ReadDir), } impl From for Box { @@ -37,13 +39,21 @@ impl TryFrom> for SMBFileSystemHandle { type Error = SMBError; fn try_from(value: Box) -> Result { - value.into_any().downcast::() - .ok().ok_or(SMBError::server_error("Invalid resource handle")) + value + .into_any() + .downcast::() + .ok() + .ok_or(SMBError::server_error("Invalid resource handle")) .map(|val| *val) } } -impl + TryInto + ResourceHandle + 'static> From> for Box> { +impl< + UserName: Send + Sync + 'static, + Handle: From + TryInto + ResourceHandle + 'static, +> From> + for Box> +{ fn from(value: SMBFileSystemShare) -> Self { Box::new(value) } @@ -61,7 +71,7 @@ impl ResourceHandle for SMBFileSystemHandle { fn is_directory(&self) -> bool { match &self.resource { SMBFileSystemResourceHandle::File(_) => false, - SMBFileSystemResourceHandle::Directory(_) => true + SMBFileSystemResourceHandle::Directory(_) => true, } } @@ -70,18 +80,25 @@ impl ResourceHandle for SMBFileSystemHandle { } fn metadata(&self) -> SMBResult { - let metadata = fs::metadata(self.path()) - .map_err(|err| SMBError::server_error(format!("Failed to get metadata for path: {}, error: {}", self.path(), err)))?; - let time_transform = |time: SystemTime| { - time.duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - }; + let metadata = fs::metadata(self.path()).map_err(|err| { + SMBError::server_error(format!( + "Failed to get metadata for path: {}, error: {}", + self.path(), + err + )) + })?; + let time_transform = |time: SystemTime| time.duration_since(UNIX_EPOCH).unwrap().as_secs(); Ok(SMBFileMetadata { creation_time: FileTime::from_unix(metadata.created().map(time_transform).unwrap_or(0)), - last_access_time: FileTime::from_unix(metadata.accessed().map(time_transform).unwrap_or(0)), - last_write_time: FileTime::from_unix(metadata.modified().map(time_transform).unwrap_or(0)), - last_modification_time: FileTime::from_unix(metadata.modified().map(time_transform).unwrap_or(0)), + last_access_time: FileTime::from_unix( + metadata.accessed().map(time_transform).unwrap_or(0), + ), + last_write_time: FileTime::from_unix( + metadata.modified().map(time_transform).unwrap_or(0), + ), + last_modification_time: FileTime::from_unix( + metadata.modified().map(time_transform).unwrap_or(0), + ), allocated_size: metadata.len(), actual_size: metadata.len(), }) @@ -91,33 +108,21 @@ impl ResourceHandle for SMBFileSystemHandle { impl SMBFileSystemResourceHandle { fn file(path: &str, disposition: SMBCreateDisposition) -> SMBResult { let mut options = OpenOptions::new(); - options.read(true) - .write(true); + options.read(true).write(true); match disposition { - SMBCreateDisposition::Supersede => options - .truncate(true) - .create(true), - SMBCreateDisposition::Open => options - .create(false), - SMBCreateDisposition::Create => options - .create_new(true), - SMBCreateDisposition::OpenIf => options - .truncate(false) - .create(true), - SMBCreateDisposition::Overwrite => options - .truncate(true) - .create(false), - SMBCreateDisposition::OverwriteIf => options - .truncate(false) - .create(true) + SMBCreateDisposition::Supersede => options.truncate(true).create(true), + SMBCreateDisposition::Open => options.create(false), + SMBCreateDisposition::Create => options.create_new(true), + SMBCreateDisposition::OpenIf => options.truncate(false).create(true), + SMBCreateDisposition::Overwrite => options.truncate(true).create(false), + SMBCreateDisposition::OverwriteIf => options.truncate(false).create(true), }; let file = options.open(path).map_err(SMBError::io_error)?; Ok(Self::File(file)) } fn directory(path: &str) -> SMBResult { - let res = std::fs::read_dir(path) - .map_err(SMBError::io_error)?; + let res = std::fs::read_dir(path).map_err(SMBError::io_error)?; Ok(Self::Directory(res)) } } @@ -149,7 +154,11 @@ pub struct SMBFileSystemShare, } -impl + ResourceHandle + TryInto> SharedResource for SMBFileSystemShare { +impl< + UserName: Send + Sync, + Handle: From + ResourceHandle + TryInto, +> SharedResource for SMBFileSystemShare +{ type UserName = UserName; type Handle = Handle; @@ -165,16 +174,18 @@ impl + ResourceHandle + self.csc_flags } - fn handle_create(&self, path: &str, disposition: SMBCreateDisposition, directory: bool) -> SMBResult { + fn handle_create( + &self, + path: &str, + disposition: SMBCreateDisposition, + directory: bool, + ) -> SMBResult { let path = format!("{}/{}", self.local_path, path); let resource = match directory { true => SMBFileSystemResourceHandle::directory(&path), - false => SMBFileSystemResourceHandle::file(&path, disposition) + false => SMBFileSystemResourceHandle::file(&path, disposition), }?; - let handle = SMBFileSystemHandle { - resource, - path, - }; + let handle = SMBFileSystemHandle { resource, path }; debug!(?handle, "created filesystem handle"); Ok(handle.into()) } @@ -188,11 +199,22 @@ impl + ResourceHandle + } } -impl> SMBFileSystemShare { - pub fn root(name: String, connect_security: ConnectAllowed, file_security: FilePerms) -> Self { +impl> + SMBFileSystemShare +{ + pub fn root( + name: String, + connect_security: ConnectAllowed, + file_security: FilePerms, + ) -> Self { Self::path(name, "".into(), connect_security, file_security) } - pub fn path(name: String, path: String, connect_security: ConnectAllowed, file_security: FilePerms) -> Self { + pub fn path( + name: String, + path: String, + connect_security: ConnectAllowed, + file_security: FilePerms, + ) -> Self { Self { name, server_name: "localhost".into(), @@ -217,12 +239,14 @@ impl> SMBFileSystemS supports_identity_remoting: true, compress_data: false, user_name_type: PhantomData, - handle_phantom: PhantomData + handle_phantom: PhantomData, } } } -impl> Debug for SMBFileSystemShare { +impl> Debug + for SMBFileSystemShare +{ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("SMBServer") .field("name", &self.name) @@ -230,10 +254,16 @@ impl> Debug for SMBF .field("local_path", &self.local_path) .field("csc_flags", &self.csc_flags) .field("dfs_enabled", &self.dfs_enabled) - .field("do_access_based_directory_enumeration", &self.do_access_based_directory_enumeration) + .field( + "do_access_based_directory_enumeration", + &self.do_access_based_directory_enumeration, + ) .field("allow_namespace_caching", &self.allow_namespace_caching) .field("force_shared_delete", &self.force_shared_delete) - .field("restrict_exclusive_options", &self.restrict_exclusive_options) + .field( + "restrict_exclusive_options", + &self.restrict_exclusive_options, + ) .field("remark", &self.remark) .field("max_uses", &self.max_uses) .field("current_uses", &self.current_uses) @@ -243,8 +273,11 @@ impl> Debug for SMBF .field("ca_timeout", &self.ca_timeout) .field("continuously_available", &self.continuously_available) .field("encrypt_data", &self.encrypt_data) - .field("supports_identity_remoting", &self.supports_identity_remoting) + .field( + "supports_identity_remoting", + &self.supports_identity_remoting, + ) .field("compress_data", &self.compress_data) .finish() } -} \ No newline at end of file +} diff --git a/smb/src/server/share/ipc.rs b/smb/src/server/share/ipc.rs index 04e6152..9579e8e 100644 --- a/smb/src/server/share/ipc.rs +++ b/smb/src/server/share/ipc.rs @@ -6,9 +6,11 @@ use smb_core::SMBResult; use crate::protocol::body::create::disposition::SMBCreateDisposition; use crate::protocol::body::filetime::FileTime; -use crate::protocol::body::tree_connect::access_mask::{SMBAccessMask, SMBFilePipePrinterAccessMask}; +use crate::protocol::body::tree_connect::access_mask::{ + SMBAccessMask, SMBFilePipePrinterAccessMask, +}; use crate::protocol::body::tree_connect::flags::SMBShareFlags; -use crate::server::share::{ResourceHandle, ResourceType, SharedResource, SMBFileMetadata}; +use crate::server::share::{ResourceHandle, ResourceType, SMBFileMetadata, SharedResource}; /// A minimal IPC$ named pipe share handle #[derive(Debug)] @@ -57,7 +59,9 @@ pub struct SMBIPCShare + Resou _handle: PhantomData, } -impl + ResourceHandle> Default for SMBIPCShare { +impl + ResourceHandle> Default + for SMBIPCShare +{ fn default() -> Self { Self { _user_name: PhantomData, @@ -66,25 +70,34 @@ impl + ResourceHandle> Default } } -impl + ResourceHandle> SMBIPCShare { +impl + ResourceHandle> + SMBIPCShare +{ pub fn new() -> Self { Self::default() } } -impl + ResourceHandle> Debug for SMBIPCShare { +impl + ResourceHandle> Debug + for SMBIPCShare +{ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("SMBIPCShare").finish() } } -impl + ResourceHandle + 'static> From> for Box> { +impl + ResourceHandle + 'static> + From> + for Box> +{ fn from(value: SMBIPCShare) -> Self { Box::new(value) } } -impl + ResourceHandle> SharedResource for SMBIPCShare { +impl + ResourceHandle> SharedResource + for SMBIPCShare +{ type UserName = UserName; type Handle = Handle; @@ -100,7 +113,12 @@ impl + ResourceHandle> SharedR SMBShareFlags::default() } - fn handle_create(&self, path: &str, _disposition: SMBCreateDisposition, _directory: bool) -> SMBResult { + fn handle_create( + &self, + path: &str, + _disposition: SMBCreateDisposition, + _directory: bool, + ) -> SMBResult { let handle = SMBIPCHandle { path: path.to_string(), }; diff --git a/smb/src/server/share/mod.rs b/smb/src/server/share/mod.rs index 8856db6..17d7af4 100644 --- a/smb/src/server/share/mod.rs +++ b/smb/src/server/share/mod.rs @@ -8,9 +8,9 @@ use smb_core::SMBResult; use crate::protocol::body::create::disposition::SMBCreateDisposition; use crate::protocol::body::filetime::FileTime; +use crate::protocol::body::tree_connect::SMBShareType; use crate::protocol::body::tree_connect::access_mask::SMBAccessMask; use crate::protocol::body::tree_connect::flags::SMBShareFlags; -use crate::protocol::body::tree_connect::SMBShareType; pub mod file_system; pub mod ipc; @@ -63,7 +63,12 @@ pub trait SharedResource: Send + Sync { fn name(&self) -> &str; fn resource_type(&self) -> ResourceType; fn flags(&self) -> SMBShareFlags; - fn handle_create(&self, path: &str, disposition: SMBCreateDisposition, directory: bool) -> SMBResult; + fn handle_create( + &self, + path: &str, + disposition: SMBCreateDisposition, + directory: bool, + ) -> SMBResult; fn close(&self, handle: Self::Handle) -> SMBResult<()> { Box::new(handle).close() } @@ -88,7 +93,12 @@ impl SharedResource for Box { T::flags(self) } - fn handle_create(&self, path: &str, disposition: SMBCreateDisposition, directory: bool) -> SMBResult { + fn handle_create( + &self, + path: &str, + disposition: SMBCreateDisposition, + directory: bool, + ) -> SMBResult { T::handle_create(self, path, disposition, directory) } @@ -125,7 +135,7 @@ impl From for ResourceType { match value { SMBShareType::Disk => ResourceType::DISK, SMBShareType::Pipe => ResourceType::IPC, - SMBShareType::Print => ResourceType::PRINT_QUEUE + SMBShareType::Print => ResourceType::PRINT_QUEUE, } } -} \ No newline at end of file +} diff --git a/smb/src/server/tree_connect.rs b/smb/src/server/tree_connect.rs index ae9c769..992e352 100644 --- a/smb/src/server/tree_connect.rs +++ b/smb/src/server/tree_connect.rs @@ -8,16 +8,18 @@ use smb_core::SMBResult; use smb_core::error::SMBError; use smb_core::logging::{debug, trace}; +use crate::protocol::body::SMBBody; use crate::protocol::body::create::{SMBCreateRequest, SMBCreateResponse}; use crate::protocol::body::filetime::FileTime; -use crate::protocol::body::SMBBody; use crate::protocol::body::tree_connect::access_mask::SMBAccessMask; use crate::protocol::header::SMBSyncHeader; use crate::protocol::message::SMBMessage; -use crate::server::message_handler::{SMBHandlerState, SMBLockedMessageHandler, SMBLockedMessageHandlerBase, SMBMessageType}; +use crate::server::Server; +use crate::server::message_handler::{ + SMBHandlerState, SMBLockedMessageHandler, SMBLockedMessageHandlerBase, SMBMessageType, +}; use crate::server::open::{Open, SMBOpen}; use crate::server::safe_locked_getter::SafeLockedGetter; -use crate::server::Server; use crate::server::session::Session; use crate::server::share::SharedResource; @@ -34,7 +36,12 @@ pub struct SMBTreeConnect { } impl SMBTreeConnect { - pub fn init(tree_id: u32, session: Weak>, share: Arc, maximal_access: SMBAccessMask) -> SMBTreeConnect { + pub fn init( + tree_id: u32, + session: Weak>, + share: Arc, + maximal_access: SMBAccessMask, + ) -> SMBTreeConnect { Self { tree_id, session, @@ -54,17 +61,22 @@ impl SMBLockedMessageHandlerBase for Arc> { None } - async fn handle_create(&mut self, header: &SMBSyncHeader, message: &SMBCreateRequest) -> SMBResult> { + async fn handle_create( + &mut self, + header: &SMBSyncHeader, + message: &SMBCreateRequest, + ) -> SMBResult> { let (path, disposition, directory) = message.validate(self.share.deref())?; let handle = self.share.handle_create(path, disposition, directory)?; let open_raw = Open::init(handle, message); let response = SMBBody::CreateResponse(SMBCreateResponse::for_open::(&open_raw)?); let open = Arc::new(RwLock::new(open_raw)); - let session = self.session.upgrade() + let session = self + .session + .upgrade() .ok_or(SMBError::server_error("No Session Found"))?; session.write().await.add_open(open.clone()).await; - let server = session.upper().await? - .upper().await?; + let server = session.upper().await?.upper().await?; { server.write().await.add_open(open.clone()).await; } @@ -73,10 +85,17 @@ impl SMBLockedMessageHandlerBase for Arc> { session.write().await.set_previous_file_id(file_id); } debug!("tree connect create handled"); - let header = header.create_response_header(header.channel_sequence, header.session_id, header.tree_id); - trace!(response_size = response.smb_byte_size(), "create response built"); + let header = header.create_response_header( + header.channel_sequence, + header.session_id, + header.tree_id, + ); + trace!( + response_size = response.smb_byte_size(), + "create response built" + ); Ok(SMBHandlerState::Finished(SMBMessage::new(header, response))) } } -impl SMBLockedMessageHandler for Arc> {} \ No newline at end of file +impl SMBLockedMessageHandler for Arc> {} diff --git a/smb/src/socket/listener/listener_async.rs b/smb/src/socket/listener/listener_async.rs index e542210..18168b0 100644 --- a/smb/src/socket/listener/listener_async.rs +++ b/smb/src/socket/listener/listener_async.rs @@ -5,39 +5,50 @@ use std::task::{Context, Poll, ready}; use tokio::io; use tokio::io::AsyncWriteExt; -use tokio::net::{TcpListener, ToSocketAddrs}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; +use tokio::net::{TcpListener, ToSocketAddrs}; use tokio_stream::Stream; use tokio_util::sync::ReusableBoxFuture; -use smb_core::error::SMBError; use smb_core::SMBResult; +use smb_core::error::SMBError; use crate::socket::listener::{SMBConnectionIterator, SMBListener, SMBSocket}; use crate::socket::message_stream::{SMBSocketConnection, SMBStream}; -impl SMBSocket for TcpListener where T: ToSocketAddrs + Send + Sync { +impl SMBSocket for TcpListener +where + T: ToSocketAddrs + Send + Sync, +{ type ReadStream = OwnedReadHalf; type WriteStream = OwnedWriteHalf; - async fn new_connection(&self) -> SMBResult> { + async fn new_connection( + &self, + ) -> SMBResult> { match self.accept().await { Ok((stream, addr)) => { let (read, write) = stream.into_split(); Ok(SMBSocketConnection::new(addr.to_string(), read, write)) } - Err(e) => Err(SMBError::io_error(e)) + Err(e) => Err(SMBError::io_error(e)), } } - async fn new_socket(addr: T) -> SMBResult where Self: Sized { + async fn new_socket(addr: T) -> SMBResult + where + Self: Sized, + { Self::bind(addr).await.map_err(SMBError::io_error) } } impl SMBStream for OwnedReadHalf { async fn close_stream(&mut self) -> SMBResult<()> { - Err(SMBError::io_error(io::Error::new(ErrorKind::Unsupported, "Invalid operation"))) + Err(SMBError::io_error(io::Error::new( + ErrorKind::Unsupported, + "Invalid operation", + ))) } } @@ -49,13 +60,21 @@ impl SMBStream for OwnedWriteHalf { type SMBConnectionResult = SMBResult>; -type SMBConnectionStreamResult<'a, Addrs, Socket> = (SMBConnectionResult<>::ReadStream, >::WriteStream>, SMBConnectionIterator<'a, Addrs, Socket>); +type SMBConnectionStreamResult<'a, Addrs, Socket> = ( + SMBConnectionResult< + >::ReadStream, + >::WriteStream, + >, + SMBConnectionIterator<'a, Addrs, Socket>, +); pub struct SMBConnectionStream<'a, Addrs: Send + Sync, Socket: SMBSocket> { inner: ReusableBoxFuture<'a, SMBConnectionStreamResult<'a, Addrs, Socket>>, } -async fn make_future<'a, Addrs: Send + Sync, Socket: SMBSocket>(iterator: SMBConnectionIterator<'a, Addrs, Socket>) -> SMBConnectionStreamResult<'a, Addrs, Socket> { +async fn make_future<'a, Addrs: Send + Sync, Socket: SMBSocket>( + iterator: SMBConnectionIterator<'a, Addrs, Socket>, +) -> SMBConnectionStreamResult<'a, Addrs, Socket> { let res = iterator.server.new_connection().await; (res, iterator) } @@ -64,13 +83,13 @@ impl<'a, Addrs: Send + Sync, Socket: SMBSocket> SMBConnectionStream<'a, A pub fn new(listener: &'a SMBListener) -> Self { let iterator = SMBConnectionIterator::new(listener); let inner = ReusableBoxFuture::new(make_future(iterator)); - Self { - inner - } + Self { inner } } } -impl> Stream for SMBConnectionStream<'_, Addrs, Socket> { +impl> Stream + for SMBConnectionStream<'_, Addrs, Socket> +{ type Item = SMBSocketConnection; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -86,7 +105,10 @@ impl> Stream for SMBConnectionStrea impl> SMBListener { pub async fn new(addr: Addrs) -> SMBResult { let socket = Socket::new_socket(addr).await?; - Ok(SMBListener { socket, addrs_phantom: PhantomData }) + Ok(SMBListener { + socket, + addrs_phantom: PhantomData, + }) } } @@ -94,4 +116,4 @@ impl> SMBListener { pub fn connections(&self) -> SMBConnectionStream<'_, Addrs, Socket> { SMBConnectionStream::new(self) } -} \ No newline at end of file +} diff --git a/smb/src/socket/listener/listener_sync.rs b/smb/src/socket/listener/listener_sync.rs index 1f72685..bdfe234 100644 --- a/smb/src/socket/listener/listener_sync.rs +++ b/smb/src/socket/listener/listener_sync.rs @@ -1,8 +1,8 @@ use std::marker::PhantomData; use std::net::{TcpListener, TcpStream, ToSocketAddrs}; -use smb_core::error::SMBError; use smb_core::SMBResult; +use smb_core::error::SMBError; use crate::socket::listener::{SMBConnectionIterator, SMBListener, SMBSocket}; use crate::socket::message_stream::SMBSocketConnection; @@ -10,29 +10,42 @@ use crate::socket::message_stream::SMBSocketConnection; impl> SMBListener { pub fn new(addr: Addrs) -> SMBResult { let socket = Socket::new_socket(addr)?; - Ok(SMBListener { socket, addrs_phantom: PhantomData }) + Ok(SMBListener { + socket, + addrs_phantom: PhantomData, + }) } } -impl SMBSocket for TcpListener where T: ToSocketAddrs + Send + Sync { +impl SMBSocket for TcpListener +where + T: ToSocketAddrs + Send + Sync, +{ type ReadStream = TcpStream; type WriteStream = TcpStream; - fn new_connection(&self) -> SMBResult> { + fn new_connection( + &self, + ) -> SMBResult> { match self.accept() { Ok((read, addr)) => { let write = read.try_clone().map_err(SMBError::io_error)?; Ok(SMBSocketConnection::new(addr.to_string(), read, write)) } - Err(e) => Err(SMBError::io_error(e)) + Err(e) => Err(SMBError::io_error(e)), } } - fn new_socket(addr: T) -> SMBResult where Self: Sized { + fn new_socket(addr: T) -> SMBResult + where + Self: Sized, + { Self::bind(addr).map_err(SMBError::io_error) } } -impl> Iterator for SMBConnectionIterator<'_, Addrs, Socket> { +impl> Iterator + for SMBConnectionIterator<'_, Addrs, Socket> +{ type Item = SMBSocketConnection; fn next(&mut self) -> Option { @@ -44,4 +57,4 @@ impl> SMBListener { pub fn connections(&self) -> SMBConnectionIterator { SMBConnectionIterator { server: self } } -} \ No newline at end of file +} diff --git a/smb/src/socket/listener/mod.rs b/smb/src/socket/listener/mod.rs index ce1512b..e6d9487 100644 --- a/smb/src/socket/listener/mod.rs +++ b/smb/src/socket/listener/mod.rs @@ -3,34 +3,41 @@ use std::future::Future; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; -use smb_core::error::SMBError; use smb_core::SMBResult; +use smb_core::error::SMBError; use crate::socket::message_stream::{SMBReadStream, SMBSocketConnection, SMBWriteStream}; -#[cfg(not(feature = "async"))] -mod listener_sync; #[cfg(feature = "async")] mod listener_async; +#[cfg(not(feature = "async"))] +mod listener_sync; pub trait SMBSocket: Send + Sync { type ReadStream: SMBReadStream + Send + Sync + Debug + 'static; type WriteStream: SMBWriteStream + Send + Sync + Debug + 'static; #[cfg(not(feature = "async"))] - fn new_connection(&self) -> SMBResult>; + fn new_connection(&self) + -> SMBResult>; #[cfg(feature = "async")] - fn new_connection(&self) -> impl Future>> + Send; + fn new_connection( + &self, + ) -> impl Future>> + Send; #[cfg(not(feature = "async"))] - fn new_socket(addr: T) -> SMBResult where Self: Sized { + fn new_socket(addr: T) -> SMBResult + where + Self: Sized, + { Err(SMBError::precondition_failed("Invalid socket address type")) } #[cfg(feature = "async")] - fn new_socket(_addr: T) -> impl Future> + Send where Self: Sized { - async { - Err(SMBError::precondition_failed("Invalid socket address type")) - } + fn new_socket(_addr: T) -> impl Future> + Send + where + Self: Sized, + { + async { Err(SMBError::precondition_failed("Invalid socket address type")) } } } @@ -62,4 +69,4 @@ impl> DerefMut for SMBListener &mut Self::Target { &mut self.socket } -} \ No newline at end of file +} diff --git a/smb/src/socket/message_stream/mod.rs b/smb/src/socket/message_stream/mod.rs index da4149d..dc6a5b6 100644 --- a/smb/src/socket/message_stream/mod.rs +++ b/smb/src/socket/message_stream/mod.rs @@ -1,66 +1,91 @@ use tokio_util::sync::ReusableBoxFuture; -use smb_core::{SMBFromBytes, SMBParseResult, SMBResult}; use smb_core::error::SMBError; use smb_core::logging::{trace, warn}; +use smb_core::{SMBFromBytes, SMBParseResult, SMBResult}; -use crate::protocol::body::{LegacySMBBody, SMBBody}; use crate::protocol::body::error::SMBErrorResponse; +use crate::protocol::body::{LegacySMBBody, SMBBody}; use crate::protocol::header::{LegacySMBHeader, SMBSyncHeader}; use crate::protocol::message::{Message, SMBMessage}; -#[cfg(not(feature = "async"))] -mod stream_sync; #[cfg(feature = "async")] mod stream_async; +#[cfg(not(feature = "async"))] +mod stream_sync; pub trait SMBReadStream: SMBStream { #[cfg(feature = "async")] - fn read_message<'a>(&'a mut self, existing: &'a mut Vec) -> impl Future>> + Send; + fn read_message<'a>( + &'a mut self, + existing: &'a mut Vec, + ) -> impl Future>> + Send; #[cfg(not(feature = "async"))] - fn read_message<'a>(&'a mut self, existing: &'a mut Vec) -> SMBParseResult<&[u8], SMBMessage>; + fn read_message<'a>( + &'a mut self, + existing: &'a mut Vec, + ) -> SMBParseResult<&[u8], SMBMessage>; #[cfg(not(feature = "async"))] - fn messages(&mut self) -> SMBMessageIterator where Self: Sized; + fn messages(&mut self) -> SMBMessageIterator + where + Self: Sized; #[cfg(feature = "async")] - fn messages(&mut self) -> SMBMessageStream<'_, Self> where Self: Sized; - fn read_message_inner(buffer: &[u8]) -> SMBParseResult<&[u8], SMBMessage> { + fn messages(&mut self) -> SMBMessageStream<'_, Self> + where + Self: Sized; + fn read_message_inner( + buffer: &[u8], + ) -> SMBParseResult<&[u8], SMBMessage> { trace!(buf_len = buffer.len(), "parsing message from buffer"); if let Some(pos) = buffer.iter().position(|x| *x == b'S') && buffer[pos..].starts_with(b"SMB") { - trace!(smb_offset = pos - 1, "found SMB header"); - let smb_start = pos - 1; - let result = SMBMessage::::parse(&buffer[smb_start..]); - return match result { - Ok(r) => Ok(r), - Err(_) => { - // Try legacy parse first - if let Ok((remaining, legacy_msg)) = SMBMessage::::parse(&buffer[smb_start..]) { - return Ok((remaining, SMBMessage::::from_legacy(legacy_msg) - .ok_or(SMBError::parse_error("Invalid legacy body"))?)); - } - // Body parse failed — try header-only parse and return an ErrorResponse - // so the connection handler can send a proper error back to the client - if let Ok((_remaining, mut header)) = SMBSyncHeader::smb_from_bytes(&buffer[smb_start..]) { - warn!(command = ?header.command, "body parse failed, returning ErrorResponse"); - header.channel_sequence = smb_core::nt_status::NTStatus::NotSupported as u32; - let body = SMBBody::ErrorResponse(SMBErrorResponse::new()); - Ok((&buffer[buffer.len()..], SMBMessage::new(header, body))) - } else { - Err(SMBError::parse_error("Failed to parse header")) - } + trace!(smb_offset = pos - 1, "found SMB header"); + let smb_start = pos - 1; + let result = SMBMessage::::parse(&buffer[smb_start..]); + return match result { + Ok(r) => Ok(r), + Err(_) => { + // Try legacy parse first + if let Ok((remaining, legacy_msg)) = + SMBMessage::::parse(&buffer[smb_start..]) + { + return Ok(( + remaining, + SMBMessage::::from_legacy(legacy_msg) + .ok_or(SMBError::parse_error("Invalid legacy body"))?, + )); } - }; + // Body parse failed — try header-only parse and return an ErrorResponse + // so the connection handler can send a proper error back to the client + if let Ok((_remaining, mut header)) = + SMBSyncHeader::smb_from_bytes(&buffer[smb_start..]) + { + warn!(command = ?header.command, "body parse failed, returning ErrorResponse"); + header.channel_sequence = + smb_core::nt_status::NTStatus::NotSupported as u32; + let body = SMBBody::ErrorResponse(SMBErrorResponse::new()); + Ok((&buffer[buffer.len()..], SMBMessage::new(header, body))) + } else { + Err(SMBError::parse_error("Failed to parse header")) + } + } + }; } - Err(SMBError::parse_error("Unknown error occurred while parsing message")) + Err(SMBError::parse_error( + "Unknown error occurred while parsing message", + )) } } pub trait SMBWriteStream: SMBStream { #[cfg(feature = "async")] - fn write_message(&mut self, message: &T) -> impl Future> + Send; + fn write_message( + &mut self, + message: &T, + ) -> impl Future> + Send; #[cfg(not(feature = "async"))] fn write_message(&mut self, message: &T) -> SMBResult; @@ -68,7 +93,7 @@ pub trait SMBWriteStream: SMBStream { pub trait SMBStream: Send + Sync { #[cfg(feature = "async")] - fn close_stream(&mut self) -> impl Future> + Send; + fn close_stream(&mut self) -> impl Future> + Send; #[cfg(not(feature = "async"))] fn close_stream(&mut self) -> SMBResult<()>; } @@ -92,7 +117,10 @@ impl<'a, R: SMBReadStream> SMBMessageIterator<'a, R> { } #[cfg(feature = "async")] -type SMBMessageStreamResult<'a, T> = (SMBResult>, SMBMessageIterator<'a, T>); +type SMBMessageStreamResult<'a, T> = ( + SMBResult>, + SMBMessageIterator<'a, T>, +); #[cfg(feature = "async")] pub struct SMBMessageStream<'a, T: SMBReadStream> { @@ -138,4 +166,4 @@ impl SMBSocketConnection { pub fn into_streams(self) -> (R, W) { (self.read_stream, self.write_stream) } -} \ No newline at end of file +} diff --git a/smb/src/socket/message_stream/stream_async.rs b/smb/src/socket/message_stream/stream_async.rs index d269067..a092791 100644 --- a/smb/src/socket/message_stream/stream_async.rs +++ b/smb/src/socket/message_stream/stream_async.rs @@ -5,21 +5,32 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio_stream::Stream; use tokio_util::sync::ReusableBoxFuture; -use smb_core::{SMBParseResult, SMBResult}; use smb_core::error::SMBError; -use smb_core::logging::{trace, debug, warn}; +use smb_core::logging::{debug, trace, warn}; +use smb_core::{SMBParseResult, SMBResult}; use crate::protocol::body::SMBBody; use crate::protocol::header::SMBSyncHeader; use crate::protocol::message::{Message, SMBMessage}; -use crate::socket::message_stream::{SMBMessageIterator, SMBMessageStream, SMBReadStream, SMBSocketConnection, SMBStream, SMBWriteStream}; +use crate::socket::message_stream::{ + SMBMessageIterator, SMBMessageStream, SMBReadStream, SMBSocketConnection, SMBStream, + SMBWriteStream, +}; -async fn make_future(mut iterator: SMBMessageIterator<'_, T>) -> (SMBResult>, SMBMessageIterator<'_, T>) { +async fn make_future( + mut iterator: SMBMessageIterator<'_, T>, +) -> ( + SMBResult>, + SMBMessageIterator<'_, T>, +) { let res = loop { match iterator.reader.read_message(&mut iterator.buffer).await { Ok(msg) => break Ok(msg), Err(SMBError::PayloadTooSmall(_x)) => { - trace!(buf_len = iterator.buffer.len(), "buffer too small, reading more data"); + trace!( + buf_len = iterator.buffer.len(), + "buffer too small, reading more data" + ); } Err(e) => { warn!(?e, "message read error"); @@ -33,7 +44,11 @@ async fn make_future(mut iterator: SMBMessageIterator<'_, T>) } else { Err(res.err().unwrap()) }; - debug!(ok = msg_res.is_ok(), remaining = iterator.buffer.len(), "message read complete"); + debug!( + ok = msg_res.is_ok(), + remaining = iterator.buffer.len(), + "message read complete" + ); trace!(?msg_res, "parsed message result"); (msg_res, iterator) } @@ -42,13 +57,14 @@ impl<'a, T: SMBReadStream> SMBMessageStream<'a, T> { pub fn new(reader: &'a mut T) -> Self { let iterator = SMBMessageIterator::new(reader); let future = ReusableBoxFuture::new(make_future(iterator)); - Self { - inner: future, - } + Self { inner: future } } } -impl SMBWriteStream for Writer where Writer: AsyncWriteExt + Unpin + Send + Sync + SMBStream { +impl SMBWriteStream for Writer +where + Writer: AsyncWriteExt + Unpin + Send + Sync + SMBStream, +{ async fn write_message(&mut self, message: &T) -> SMBResult { let bytes = message.as_bytes(); self.write_all(&bytes).await.map_err(SMBError::io_error)?; @@ -56,8 +72,14 @@ impl SMBWriteStream for Writer where Writer: AsyncWriteExt + Unpin + Sen } } -impl SMBReadStream for Reader where Reader: AsyncReadExt + Unpin + Send + Sync + SMBStream { - async fn read_message<'a>(&'a mut self, existing: &'a mut Vec) -> SMBParseResult<&'a [u8], SMBMessage> { +impl SMBReadStream for Reader +where + Reader: AsyncReadExt + Unpin + Send + Sync + SMBStream, +{ + async fn read_message<'a>( + &'a mut self, + existing: &'a mut Vec, + ) -> SMBParseResult<&'a [u8], SMBMessage> { trace!(buf_len = existing.len(), "read_message called"); if let Ok((remaining, res)) = Self::read_message_inner(existing) { return Ok((&existing[(existing.len() - remaining.len())..], res)); @@ -69,7 +91,10 @@ impl SMBReadStream for Reader where Reader: AsyncReadExt + Unpin + Send Self::read_message_inner(existing) } - fn messages(&mut self) -> SMBMessageStream<'_, Self> where Self: Sized { + fn messages(&mut self) -> SMBMessageStream<'_, Self> + where + Self: Sized, + { SMBMessageStream::new(self) } } @@ -91,4 +116,4 @@ impl<'a, R: SMBReadStream> Stream for SMBMessageStream<'a, R> { Err(_) => Poll::Ready(None), } } -} \ No newline at end of file +} diff --git a/smb/src/socket/message_stream/stream_sync.rs b/smb/src/socket/message_stream/stream_sync.rs index 01dbf8e..cc2cc22 100644 --- a/smb/src/socket/message_stream/stream_sync.rs +++ b/smb/src/socket/message_stream/stream_sync.rs @@ -1,15 +1,23 @@ use std::io::{Read, Write}; -use smb_core::{SMBParseResult, SMBResult}; use smb_core::error::SMBError; +use smb_core::{SMBParseResult, SMBResult}; use crate::protocol::body::SMBBody; use crate::protocol::header::SMBSyncHeader; use crate::protocol::message::{Message, SMBMessage}; -use crate::socket::message_stream::{SMBMessageIterator, SMBReadStream, SMBSocketConnection, SMBWriteStream}; - -impl SMBReadStream for Reader where Reader: Read + Send + Sync { - fn read_message<'a>(&'a mut self, existing: &'a mut Vec) -> SMBParseResult<&[u8], SMBMessage> { +use crate::socket::message_stream::{ + SMBMessageIterator, SMBReadStream, SMBSocketConnection, SMBWriteStream, +}; + +impl SMBReadStream for Reader +where + Reader: Read + Send + Sync, +{ + fn read_message<'a>( + &'a mut self, + existing: &'a mut Vec, + ) -> SMBParseResult<&[u8], SMBMessage> { let mut buffer = [0_u8; 512]; if let Ok(read) = self.read(&mut buffer) { @@ -19,12 +27,18 @@ impl SMBReadStream for Reader where Reader: Read + Send + Sync { Self::read_message_inner(existing) } - fn messages(&mut self) -> SMBMessageIterator where Self: Sized { + fn messages(&mut self) -> SMBMessageIterator + where + Self: Sized, + { SMBMessageIterator::new(self) } } -impl SMBWriteStream for Writer where Writer: Write { +impl SMBWriteStream for Writer +where + Writer: Write, +{ fn write_message(&mut self, message: &T) -> SMBResult { let bytes = message.as_bytes(); self.write_all(&bytes).map_err(SMBError::io_error)?; @@ -46,4 +60,4 @@ impl Iterator for SMBMessageIterator<'_, R> { self.buffer = remaining.to_vec(); Some(message) } -} \ No newline at end of file +} diff --git a/smb/src/socket/mod.rs b/smb/src/socket/mod.rs index a51f2b9..f74c32d 100644 --- a/smb/src/socket/mod.rs +++ b/smb/src/socket/mod.rs @@ -1,2 +1,2 @@ +pub mod listener; pub mod message_stream; -pub mod listener; \ No newline at end of file diff --git a/smb/src/util/as_bytes.rs b/smb/src/util/as_bytes.rs index a54c146..def3515 100644 --- a/smb/src/util/as_bytes.rs +++ b/smb/src/util/as_bytes.rs @@ -1,3 +1,3 @@ pub trait AsByteVec { fn as_byte_vec(&self) -> Vec; -} \ No newline at end of file +} diff --git a/smb/src/util/auth/auth_context.rs b/smb/src/util/auth/auth_context.rs index ffe8977..98e2ab0 100644 --- a/smb/src/util/auth/auth_context.rs +++ b/smb/src/util/auth/auth_context.rs @@ -5,4 +5,4 @@ pub struct GenericAuthContext { work_station: String, version: String, guest: bool, -} \ No newline at end of file +} diff --git a/smb/src/util/auth/mod.rs b/smb/src/util/auth/mod.rs index a8aefdd..345373a 100644 --- a/smb/src/util/auth/mod.rs +++ b/smb/src/util/auth/mod.rs @@ -1,12 +1,11 @@ - pub use auth_context::*; -use smb_core::{SMBParseResult, SMBResult}; use smb_core::nt_status::NTStatus; +use smb_core::{SMBParseResult, SMBResult}; pub use user::*; +mod auth_context; pub mod ntlm; pub mod spnego; -mod auth_context; mod user; pub trait AuthProvider: Send + Sync { type Message: AuthMessage + Send + Sync + 'static; @@ -14,11 +13,17 @@ pub trait AuthProvider: Send + Sync { fn get_oid() -> Vec; - fn accept_security_context(&self, input_token: &Self::Message, context: &mut Self::Context) -> (NTStatus, Self::Message); + fn accept_security_context( + &self, + input_token: &Self::Message, + context: &mut Self::Context, + ) -> (NTStatus, Self::Message); } pub trait AuthMessage { - fn parse(data: &[u8]) -> SMBParseResult<&[u8], Self> where Self: Sized; + fn parse(data: &[u8]) -> SMBParseResult<&[u8], Self> + where + Self: Sized; fn as_bytes(&self) -> Vec; @@ -31,4 +36,3 @@ pub trait AuthContext { fn session_key(&self) -> &[u8]; fn user_name(&self) -> SMBResult<&Self::UserName>; } - diff --git a/smb/src/util/auth/ntlm/mod.rs b/smb/src/util/auth/ntlm/mod.rs index b696845..ded322d 100644 --- a/smb/src/util/auth/ntlm/mod.rs +++ b/smb/src/util/auth/ntlm/mod.rs @@ -5,8 +5,7 @@ pub use ntlm_message::*; pub use ntlm_negotiate_message::*; mod ntlm_auth_provider; +mod ntlm_authenticate_message; +mod ntlm_challenge_message; mod ntlm_message; mod ntlm_negotiate_message; -mod ntlm_challenge_message; -mod ntlm_authenticate_message; - diff --git a/smb/src/util/auth/ntlm/ntlm_auth_provider.rs b/smb/src/util/auth/ntlm/ntlm_auth_provider.rs index c41e588..5702d66 100644 --- a/smb/src/util/auth/ntlm/ntlm_auth_provider.rs +++ b/smb/src/util/auth/ntlm/ntlm_auth_provider.rs @@ -1,24 +1,24 @@ use serde::{Deserialize, Serialize}; +use smb_core::SMBResult; use smb_core::error::SMBError; use smb_core::nt_status::NTStatus; -use smb_core::SMBResult; -use crate::util::auth::{AuthContext, AuthProvider}; use crate::util::auth::ntlm::ntlm_message::NTLMMessage; use crate::util::auth::user::User; +use crate::util::auth::{AuthContext, AuthProvider}; #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] pub struct NTLMAuthProvider { accepted_users: Vec, - guest_supported: bool + guest_supported: bool, } impl NTLMAuthProvider { pub fn new(accepted_users: Vec, guest_supported: bool) -> Self { Self { accepted_users, - guest_supported + guest_supported, } } } @@ -31,27 +31,28 @@ impl AuthProvider for NTLMAuthProvider { vec![0x2b, 0x06, 0x01, 0x04, 0x01, 0x82, 0x37, 0x02, 0x02, 0x0a] } - fn accept_security_context(&self, input_message: &NTLMMessage, context: &mut NTLMAuthContext) -> (NTStatus, NTLMMessage) { + fn accept_security_context( + &self, + input_message: &NTLMMessage, + context: &mut NTLMAuthContext, + ) -> (NTStatus, NTLMMessage) { match input_message { NTLMMessage::Negotiate(x) => { let (status, challenge) = x.get_challenge_response(); context.server_challenge = (*challenge.server_challenge()).into(); (status, NTLMMessage::Challenge(challenge)) - }, - NTLMMessage::Challenge(_x) => { - (NTStatus::StatusSuccess, NTLMMessage::Dummy) - }, + } + NTLMMessage::Challenge(_x) => (NTStatus::StatusSuccess, NTLMMessage::Dummy), NTLMMessage::Authenticate(x) => { - let auth_status = x.authenticate(context, &self.accepted_users, self.guest_supported); + let auth_status = + x.authenticate(context, &self.accepted_users, self.guest_supported); if auth_status == 0 { (NTStatus::StatusSuccess, NTLMMessage::Dummy) } else { (NTStatus::LogonFailure, NTLMMessage::Dummy) } - }, - NTLMMessage::Dummy => { - (NTStatus::StatusSuccess, NTLMMessage::Dummy) } + NTLMMessage::Dummy => (NTStatus::StatusSuccess, NTLMMessage::Dummy), } } } @@ -99,6 +100,8 @@ impl AuthContext for NTLMAuthContext { } fn user_name(&self) -> SMBResult<&Self::UserName> { - self.user_name.as_ref().ok_or(SMBError::server_error("No user name")) + self.user_name + .as_ref() + .ok_or(SMBError::server_error("No user name")) } -} \ No newline at end of file +} diff --git a/smb/src/util/auth/ntlm/ntlm_authenticate_message.rs b/smb/src/util/auth/ntlm/ntlm_authenticate_message.rs index 9b041c5..e337e13 100644 --- a/smb/src/util/auth/ntlm/ntlm_authenticate_message.rs +++ b/smb/src/util/auth/ntlm/ntlm_authenticate_message.rs @@ -1,11 +1,11 @@ use des::cipher::KeyInit; +use nom::IResult; use nom::bytes::complete::take; use nom::combinator::{map, map_res}; -use nom::IResult; use nom::number::complete::le_u32; use nom::sequence::tuple; -use rc4::{Key, Rc4, StreamCipher}; use rc4::consts::U16; +use rc4::{Key, Rc4, StreamCipher}; use serde::{Deserialize, Serialize}; use smb_core::logging::trace; @@ -58,25 +58,34 @@ impl NTLMAuthenticateMessageBody { encrypted_session_key_info, negotiate_flags, _, - mic + mic, ), )| { let (_, lm_challenge_response) = get_buffer(lm_challenge_info.0, lm_challenge_info.1, bytes)?; let (_, nt_challenge_response) = get_buffer(nt_challenge_info.0, nt_challenge_info.1, bytes)?; - let (_, domain_name) = map(map_res( - |bytes| get_buffer(domain_name_into.0, domain_name_into.1, bytes), - String::from_utf8, - ), |s| s.replace('\0', ""))(bytes)?; - let (_, user_name) = map(map_res( - |bytes| get_buffer(user_name_info.0, user_name_info.1, bytes), - String::from_utf8, - ), |s| s.replace('\0', ""))(bytes)?; - let (_, work_station) = map(map_res( - |bytes| get_buffer(work_station_info.0, work_station_info.1, bytes), - String::from_utf8, - ), |s| s.replace('\0', ""))(bytes)?; + let (_, domain_name) = map( + map_res( + |bytes| get_buffer(domain_name_into.0, domain_name_into.1, bytes), + String::from_utf8, + ), + |s| s.replace('\0', ""), + )(bytes)?; + let (_, user_name) = map( + map_res( + |bytes| get_buffer(user_name_info.0, user_name_info.1, bytes), + String::from_utf8, + ), + |s| s.replace('\0', ""), + )(bytes)?; + let (_, work_station) = map( + map_res( + |bytes| get_buffer(work_station_info.0, work_station_info.1, bytes), + String::from_utf8, + ), + |s| s.replace('\0', ""), + )(bytes)?; let (remaining, encrypted_session_key) = get_buffer( encrypted_session_key_info.0, encrypted_session_key_info.1, @@ -124,7 +133,7 @@ impl NTLMAuthenticateMessageBody { 0 } else { 1 // TODO failure - } + }; } // TODO check if remaining attempts are allowed @@ -147,39 +156,59 @@ impl NTLMAuthenticateMessageBody { .negotiate_flags .contains(NTLMNegotiateFlags::EXTENDED_SESSION_SECURITY) { - if self.lm_challenge_response.len() == 24 && self.lm_challenge_response[0..8] != [0; 8] { + if self.lm_challenge_response.len() == 24 && self.lm_challenge_response[0..8] != [0; 8] + { // ntlm v1 extended // TODO: Use authenticate_v1_extended result to validate credentials and derive session base key - let _response = authenticate_v1_extended(&matched_user.password, server_challenge, &self.lm_challenge_response, &self.nt_challenge_response); + let _response = authenticate_v1_extended( + &matched_user.password, + server_challenge, + &self.lm_challenge_response, + &self.nt_challenge_response, + ); Vec::new() } else { // ntlm v2 - let (_, session_base_key) = authenticate_v2(&self.domain_name, &self.user_name, &matched_user.password, server_challenge, &self.lm_challenge_response, &self.nt_challenge_response).unwrap(); + let (_, session_base_key) = authenticate_v2( + &self.domain_name, + &self.user_name, + &matched_user.password, + server_challenge, + &self.lm_challenge_response, + &self.nt_challenge_response, + ) + .unwrap(); session_base_key } } else { Vec::new() }; - if !response_key.is_empty() && self.negotiate_flags.contains(NTLMNegotiateFlags::KEY_EXCHANGE) { - let session_key = if self.negotiate_flags.contains(NTLMNegotiateFlags::SEAL) || self.negotiate_flags.contains(NTLMNegotiateFlags::SIGN) { + if !response_key.is_empty() + && self + .negotiate_flags + .contains(NTLMNegotiateFlags::KEY_EXCHANGE) + { + let session_key = if self.negotiate_flags.contains(NTLMNegotiateFlags::SEAL) + || self.negotiate_flags.contains(NTLMNegotiateFlags::SIGN) + { let mut rc4 = Rc4::new(Key::::from_slice(&response_key)); let mut output = vec![0; self.encrypted_session_key.len()]; - rc4.apply_keystream_b2b(&self.encrypted_session_key, &mut output).unwrap(); - output + rc4.apply_keystream_b2b(&self.encrypted_session_key, &mut output) + .unwrap(); + output } else { response_key }; context.session_key = session_key; 0 - } else { 1 } - + } else { + 1 + } } } - fn get_buffer(length: u16, offset: u32, buffer: &[u8]) -> IResult<&[u8], Vec> { let (remaining, slice) = take(offset as usize)(buffer) .and_then(|(remaining, _)| take(length as usize)(remaining))?; Ok((remaining, slice.to_vec())) } - diff --git a/smb/src/util/auth/ntlm/ntlm_challenge_message.rs b/smb/src/util/auth/ntlm/ntlm_challenge_message.rs index bb84171..3782ef5 100644 --- a/smb/src/util/auth/ntlm/ntlm_challenge_message.rs +++ b/smb/src/util/auth/ntlm/ntlm_challenge_message.rs @@ -40,17 +40,19 @@ impl NTLMChallengeMessageBody { } [ self.signature.as_bytes(), // 0 - 8 - &u32_to_bytes(0x02), // 8 - 12 - &u16_to_bytes(20), &u16_to_bytes(20), // 12 - 16 - &u32_to_bytes(56), // 16 - 20 + &u32_to_bytes(0x02), // 8 - 12 + &u16_to_bytes(20), + &u16_to_bytes(20), // 12 - 16 + &u32_to_bytes(56), // 16 - 20 &u32_to_bytes(self.negotiate_flags.bits()), // 20 - 24 - &self.server_challenge, // 24 - 32 - &[0; 8], // 32 - 40 - &u16_to_bytes(52), &u16_to_bytes(52), // 40-44 - &u32_to_bytes(76), // 44 - 48 - &[6, 1], // NTLM major minor + &self.server_challenge, // 24 - 32 + &[0; 8], // 32 - 40 + &u16_to_bytes(52), + &u16_to_bytes(52), // 40-44 + &u32_to_bytes(76), // 44 - 48 + &[6, 1], // NTLM major minor &u16_to_bytes(7600), // NTLM build - &[0, 0, 0, 15], // NTLM current revision + &[0, 0, 0, 15], // NTLM current revision &name, &u16_to_bytes(1), &u16_to_bytes(20), @@ -59,7 +61,8 @@ impl NTLMChallengeMessageBody { &u16_to_bytes(20), &name, &[0; 4], - ].concat() + ] + .concat() } } @@ -75,4 +78,4 @@ impl NTLMChallengeMessageBody { pub fn server_challenge(&self) -> &[u8; 8] { &self.server_challenge } -} \ No newline at end of file +} diff --git a/smb/src/util/auth/ntlm/ntlm_message.rs b/smb/src/util/auth/ntlm/ntlm_message.rs index 4e01a41..254e5b1 100644 --- a/smb/src/util/auth/ntlm/ntlm_message.rs +++ b/smb/src/util/auth/ntlm/ntlm_message.rs @@ -1,11 +1,11 @@ use bitflags::bitflags; -use nom::bytes::complete::take; use nom::IResult; +use nom::bytes::complete::take; use nom::number::complete::{le_u16, le_u32}; use serde::{Deserialize, Serialize}; -use smb_core::error::SMBError; use smb_core::SMBParseResult; +use smb_core::error::SMBError; use crate::util::auth::AuthMessage; use crate::util::auth::ntlm::ntlm_authenticate_message::NTLMAuthenticateMessageBody; @@ -17,39 +17,31 @@ pub enum NTLMMessage { Negotiate(NTLMNegotiateMessageBody), Challenge(NTLMChallengeMessageBody), Authenticate(NTLMAuthenticateMessageBody), - Dummy + Dummy, } impl AuthMessage for NTLMMessage { fn parse(bytes: &[u8]) -> SMBParseResult<&[u8], Self> { let (_, msg_type) = take::>(8_usize)(bytes) .and_then(|(remaining, _)| le_u32(remaining)) - .map_err(|e| { - SMBError::parse_error(e.to_owned()) - })?; + .map_err(|e| SMBError::parse_error(e.to_owned()))?; match msg_type { 0x01 => { let (remaining, body) = NTLMNegotiateMessageBody::parse(bytes) - .map_err(|e| { - SMBError::parse_error(e.to_owned()) - })?; + .map_err(|e| SMBError::parse_error(e.to_owned()))?; Ok((remaining, NTLMMessage::Negotiate(body))) - }, + } 0x02 => { let (remaining, body) = NTLMChallengeMessageBody::parse(bytes) - .map_err(|e| { - SMBError::parse_error(e.to_owned()) - })?; + .map_err(|e| SMBError::parse_error(e.to_owned()))?; Ok((remaining, NTLMMessage::Challenge(body))) - }, + } 0x03 => { let (remaining, body) = NTLMAuthenticateMessageBody::parse(bytes) - .map_err(|e| { - SMBError::parse_error(e.to_owned()) - })?; + .map_err(|e| SMBError::parse_error(e.to_owned()))?; Ok((remaining, NTLMMessage::Authenticate(body))) - }, - _ => Err(SMBError::parse_error("Invalid message type")) + } + _ => Err(SMBError::parse_error("Invalid message type")), } } @@ -97,6 +89,7 @@ bitflags! { pub(crate) fn parse_ntlm_buffer_fields(bytes: &[u8]) -> IResult<&[u8], (u16, u32)> { let (remaining, length) = le_u16(bytes)?; - let (remaining, buffer_offset) = take(2_usize)(remaining).and_then(|(remaining, _)| le_u32(remaining))?; + let (remaining, buffer_offset) = + take(2_usize)(remaining).and_then(|(remaining, _)| le_u32(remaining))?; Ok((remaining, (length, buffer_offset))) } diff --git a/smb/src/util/auth/ntlm/ntlm_negotiate_message.rs b/smb/src/util/auth/ntlm/ntlm_negotiate_message.rs index 462ceef..9c6c641 100644 --- a/smb/src/util/auth/ntlm/ntlm_negotiate_message.rs +++ b/smb/src/util/auth/ntlm/ntlm_negotiate_message.rs @@ -1,6 +1,6 @@ +use nom::IResult; use nom::bytes::complete::take; use nom::combinator::{map, map_res}; -use nom::IResult; use nom::number::complete::le_u32; use nom::sequence::tuple; use serde::{Deserialize, Serialize}; @@ -123,7 +123,9 @@ impl NTLMNegotiateMessageBody { let target_name = "fakeserver"; - (NTStatus::MoreProcessingRequired, NTLMChallengeMessageBody::new(target_name.into(), negotiate_flags)) + ( + NTStatus::MoreProcessingRequired, + NTLMChallengeMessageBody::new(target_name.into(), negotiate_flags), + ) } } - diff --git a/smb/src/util/auth/spnego/der_utils.rs b/smb/src/util/auth/spnego/der_utils.rs index 56bab8a..dadd51a 100644 --- a/smb/src/util/auth/spnego/der_utils.rs +++ b/smb/src/util/auth/spnego/der_utils.rs @@ -1,8 +1,8 @@ +use nom::Err::Error; +use nom::IResult; use nom::bytes::complete::take; use nom::combinator::map; -use nom::Err::Error; use nom::error::ErrorKind; -use nom::IResult; use nom::multi::fold_many_m_n; use nom::number::complete::le_u8; @@ -40,21 +40,27 @@ impl AsDerBytes for Vec { impl AsDerBytes for Vec> { fn der_bytes(&self, item_tag: u8) -> Vec { - self.iter().flat_map(|inner_arr| { - [ - &[item_tag][0..], - &*get_length(inner_arr.len()), - inner_arr - ].concat() - }).collect::>() + self.iter() + .flat_map(|inner_arr| { + [&[item_tag][0..], &*get_length(inner_arr.len()), inner_arr].concat() + }) + .collect::>() } } pub fn parse_length(buffer: &[u8]) -> IResult<&[u8], usize> { let (remaining, len) = le_u8(buffer)?; - if len < 0x80 { return Ok((remaining, len as usize)); } + if len < 0x80 { + return Ok((remaining, len as usize)); + } let field_size = (len & 0x7f) as usize; - fold_many_m_n(field_size, field_size, le_u8, || 0_usize, |len, item| len * 256 + item as usize)(remaining) + fold_many_m_n( + field_size, + field_size, + le_u8, + || 0_usize, + |len, item| len * 256 + item as usize, + )(remaining) } pub fn parse_field_with_len(buffer: &[u8]) -> IResult<&[u8], &[u8]> { @@ -72,14 +78,20 @@ pub fn get_array_field_len(array: &T) -> usize { 1 + bytes_construction_len_field_size + bytes_construction_len } -pub fn encode_der_bytes(bytes: &T, type_tag: u8, encoding_tag: u8, item_tag: u8) -> Vec { +pub fn encode_der_bytes( + bytes: &T, + type_tag: u8, + encoding_tag: u8, + item_tag: u8, +) -> Vec { [ &[type_tag][0..], &*get_length(1 + get_field_size(bytes.der_length()) + bytes.der_length()), &[encoding_tag], &*get_length(bytes.der_length()), - &*bytes.der_bytes(item_tag) - ].concat() + &*bytes.der_bytes(item_tag), + ] + .concat() } pub fn get_field_size(len: usize) -> usize { @@ -122,7 +134,9 @@ pub fn parse_der_byte_array(buffer: &[u8]) -> IResult<&[u8], Vec> { pub fn parse_der_multibyte(buffer: &[u8], tag: u8) -> IResult<&[u8], Vec> { let (remaining, b_tag) = le_u8(buffer)?; - if tag != b_tag { return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); } + if tag != b_tag { + return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); + } map(parse_field_with_len, |buf| buf.to_vec())(remaining) } @@ -144,4 +158,4 @@ impl WithDerLength for Vec> { prev + entry_len }) } -} \ No newline at end of file +} diff --git a/smb/src/util/auth/spnego/mod.rs b/smb/src/util/auth/spnego/mod.rs index b5ea4dc..93781d2 100644 --- a/smb/src/util/auth/spnego/mod.rs +++ b/smb/src/util/auth/spnego/mod.rs @@ -8,4 +8,4 @@ pub(crate) mod der_utils; pub type SPNEGOToken = spnego_token::SPNEGOToken; pub type SPNEGOTokenInitBody = spnego_token_init::SPNEGOTokenInitBody; pub type SPNEGOTokenInit2Body = spnego_token_init_2::SPNEGOTokenInit2Body; -pub type SPNEGOTokenResponseBody = spnego_token_response::SPNEGOTokenResponseBody; \ No newline at end of file +pub type SPNEGOTokenResponseBody = spnego_token_response::SPNEGOTokenResponseBody; diff --git a/smb/src/util/auth/spnego/spnego_token.rs b/smb/src/util/auth/spnego/spnego_token.rs index a357a88..86959f3 100644 --- a/smb/src/util/auth/spnego/spnego_token.rs +++ b/smb/src/util/auth/spnego/spnego_token.rs @@ -1,18 +1,23 @@ -use nom::bytes::complete::take; use nom::Err::Error; -use nom::error::ErrorKind; use nom::IResult; +use nom::bytes::complete::take; +use nom::error::ErrorKind; use nom::number::complete::le_u8; use serde::{Deserialize, Serialize}; -use smb_core::{SMBParseResult, SMBResult}; use smb_core::error::SMBError; use smb_core::logging::trace; use smb_core::nt_status::NTStatus; +use smb_core::{SMBParseResult, SMBResult}; +use crate::util::auth::spnego::der_utils::{ + APPLICATION_TAG, DER_ENCODING_OID_TAG, NEG_TOKEN_INIT_TAG, NEG_TOKEN_RESP_TAG, SPNEGO_ID, + get_field_size, get_length, parse_field_with_len, +}; +use crate::util::auth::spnego::{ + SPNEGOTokenInit2Body, SPNEGOTokenInitBody, SPNEGOTokenResponseBody, +}; use crate::util::auth::{AuthMessage, AuthProvider}; -use crate::util::auth::spnego::{SPNEGOTokenInit2Body, SPNEGOTokenInitBody, SPNEGOTokenResponseBody}; -use crate::util::auth::spnego::der_utils::{APPLICATION_TAG, DER_ENCODING_OID_TAG, get_field_size, get_length, NEG_TOKEN_INIT_TAG, NEG_TOKEN_RESP_TAG, parse_field_with_len, SPNEGO_ID}; #[derive(Debug, Deserialize, Serialize)] pub enum SPNEGOToken { @@ -22,21 +27,33 @@ pub enum SPNEGOToken { } impl SPNEGOToken { - pub fn get_message(&self, auth_provider: &A, ctx: &mut A::Context) -> SMBResult<(NTStatus, A::Message)> { + pub fn get_message( + &self, + auth_provider: &A, + ctx: &mut A::Context, + ) -> SMBResult<(NTStatus, A::Message)> { let result = match self { SPNEGOToken::Init(init_msg) => { - let mech_token = init_msg.mech_token.as_ref().ok_or(SMBError::parse_error("Parse failure"))?; - let ntlm_msg = - A::Message::parse(mech_token).map_err(|_e| SMBError::parse_error("Parse failure"))?.1; + let mech_token = init_msg + .mech_token + .as_ref() + .ok_or(SMBError::parse_error("Parse failure"))?; + let ntlm_msg = A::Message::parse(mech_token) + .map_err(|_e| SMBError::parse_error("Parse failure"))? + .1; auth_provider.accept_security_context(&ntlm_msg, ctx) } SPNEGOToken::Response(resp_msg) => { - let response_token = resp_msg.response_token.as_ref().ok_or(SMBError::parse_error("Parse failure"))?; - let ntlm_msg = - A::Message::parse(response_token).map_err(|_e| SMBError::parse_error("Parse failure"))?.1; + let response_token = resp_msg + .response_token + .as_ref() + .ok_or(SMBError::parse_error("Parse failure"))?; + let ntlm_msg = A::Message::parse(response_token) + .map_err(|_e| SMBError::parse_error("Parse failure"))? + .1; auth_provider.accept_security_context(&ntlm_msg, ctx) } - _ => { (NTStatus::StatusSuccess, A::Message::empty()) } + _ => (NTStatus::StatusSuccess, A::Message::empty()), }; Ok(result) @@ -48,37 +65,34 @@ impl SPNEGOToken { trace!(buf_len = bytes.len(), "parsing SPNEGO token"); let (remaining, tag) = le_u8(bytes)?; match tag { - APPLICATION_TAG => { - take(1_usize)(remaining) - .and_then(|(remaining, _)| { - let (remaining, tag) = le_u8(remaining)?; - if tag != DER_ENCODING_OID_TAG { - return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); - } - let (remaining, oid) = parse_field_with_len(remaining)?; - if oid.len() != SPNEGO_ID.len() || *oid != SPNEGO_ID { - return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); - } - let (remaining, tag) = le_u8(remaining)?; - trace!(tag, "SPNEGO inner tag"); - match tag { - NEG_TOKEN_INIT_TAG => { - let (remaining, body) = SPNEGOTokenInitBody::parse(remaining)?; - Ok((remaining, SPNEGOToken::Init(body))) - }, - NEG_TOKEN_RESP_TAG => { - let (remaining, body) = SPNEGOTokenResponseBody::parse(remaining)?; - Ok((remaining, SPNEGOToken::Response(body))) - }, - _ => Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))) - } - }) - }, + APPLICATION_TAG => take(1_usize)(remaining).and_then(|(remaining, _)| { + let (remaining, tag) = le_u8(remaining)?; + if tag != DER_ENCODING_OID_TAG { + return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); + } + let (remaining, oid) = parse_field_with_len(remaining)?; + if oid.len() != SPNEGO_ID.len() || *oid != SPNEGO_ID { + return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); + } + let (remaining, tag) = le_u8(remaining)?; + trace!(tag, "SPNEGO inner tag"); + match tag { + NEG_TOKEN_INIT_TAG => { + let (remaining, body) = SPNEGOTokenInitBody::parse(remaining)?; + Ok((remaining, SPNEGOToken::Init(body))) + } + NEG_TOKEN_RESP_TAG => { + let (remaining, body) = SPNEGOTokenResponseBody::parse(remaining)?; + Ok((remaining, SPNEGOToken::Response(body))) + } + _ => Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))), + } + }), NEG_TOKEN_RESP_TAG => { let (remaining, body) = SPNEGOTokenResponseBody::parse(remaining)?; Ok((remaining, SPNEGOToken::Response(body))) - }, - _ => Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))) + } + _ => Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))), } } @@ -97,10 +111,11 @@ impl SPNEGOToken { &[DER_ENCODING_OID_TAG], &get_length(SPNEGO_ID.len()), &SPNEGO_ID, - &bytes - ].concat() + &bytes, + ] + .concat() } else { bytes.to_vec() } } -} \ No newline at end of file +} diff --git a/smb/src/util/auth/spnego/spnego_token_init.rs b/smb/src/util/auth/spnego/spnego_token_init.rs index 9444967..80f714b 100644 --- a/smb/src/util/auth/spnego/spnego_token_init.rs +++ b/smb/src/util/auth/spnego/spnego_token_init.rs @@ -1,12 +1,17 @@ use nom::Err::Error; -use nom::error::ErrorKind; use nom::IResult; +use nom::error::ErrorKind; use nom::multi::many0; use nom::number::complete::le_u8; use serde::{Deserialize, Serialize}; use crate::util::auth::AuthProvider; -use crate::util::auth::spnego::der_utils::{DER_ENCODING_BYTE_ARRAY_TAG, DER_ENCODING_OID_TAG, DER_ENCODING_SEQUENCE_TAG, encode_der_bytes, get_array_field_len, get_field_size, get_length, MECH_LIST_MIC_TAG, MECH_TOKEN_TAG, MECH_TYPE_LIST_TAG, NEG_TOKEN_INIT_TAG, parse_der_byte_array, parse_der_multibyte, parse_field_with_len, parse_length}; +use crate::util::auth::spnego::der_utils::{ + DER_ENCODING_BYTE_ARRAY_TAG, DER_ENCODING_OID_TAG, DER_ENCODING_SEQUENCE_TAG, + MECH_LIST_MIC_TAG, MECH_TOKEN_TAG, MECH_TYPE_LIST_TAG, NEG_TOKEN_INIT_TAG, encode_der_bytes, + get_array_field_len, get_field_size, get_length, parse_der_byte_array, parse_der_multibyte, + parse_field_with_len, parse_length, +}; #[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)] pub struct SPNEGOTokenInitBody { @@ -36,7 +41,9 @@ impl SPNEGOTokenInitBody { pub fn parse(bytes: &[u8]) -> IResult<&[u8], Self> { let (remaining, _) = parse_length(bytes)?; let (remaining, mut tag) = le_u8(remaining)?; - if tag != DER_ENCODING_SEQUENCE_TAG { return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))) } + if tag != DER_ENCODING_SEQUENCE_TAG { + return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); + } let (remaining, mut sequence) = parse_field_with_len(remaining)?; let mut mech_type_list = None; let mut mech_token = None; @@ -48,21 +55,29 @@ impl SPNEGOTokenInitBody { let (s, list) = Self::parse_mech_type_list(sequence)?; sequence = s; mech_type_list = Some(list); - }, + } MECH_TOKEN_TAG => { let (s, token) = Self::parse_mech_token(sequence)?; sequence = s; mech_token = Some(token); - }, + } MECH_LIST_MIC_TAG => { let (s, mic) = Self::parse_mech_list_mic(sequence)?; sequence = s; mech_list_mic = Some(mic); - }, + } _ => return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))), } } - Ok((remaining, SPNEGOTokenInitBody { mechanism: None, mech_type_list, mech_token, mech_list_mic })) + Ok(( + remaining, + SPNEGOTokenInitBody { + mechanism: None, + mech_type_list, + mech_token, + mech_list_mic, + }, + )) } pub fn as_bytes(&self) -> Vec { @@ -80,15 +95,30 @@ impl SPNEGOTokenInitBody { // Write mechanism type list if it's not null if let Some(mech_type_list) = &self.mech_type_list { - bytes.append(&mut encode_der_bytes(mech_type_list, MECH_TYPE_LIST_TAG, DER_ENCODING_SEQUENCE_TAG, DER_ENCODING_OID_TAG)); + bytes.append(&mut encode_der_bytes( + mech_type_list, + MECH_TYPE_LIST_TAG, + DER_ENCODING_SEQUENCE_TAG, + DER_ENCODING_OID_TAG, + )); } // Write mechanism token if it's not null if let Some(mech_token) = &self.mech_token { - bytes.append(&mut encode_der_bytes(mech_token, MECH_TOKEN_TAG, DER_ENCODING_BYTE_ARRAY_TAG, 0)); + bytes.append(&mut encode_der_bytes( + mech_token, + MECH_TOKEN_TAG, + DER_ENCODING_BYTE_ARRAY_TAG, + 0, + )); } // Write mechanism list mic if it's not null if let Some(mech_list_mic) = &self.mech_list_mic { - bytes.append(&mut encode_der_bytes(mech_list_mic, MECH_LIST_MIC_TAG, DER_ENCODING_BYTE_ARRAY_TAG, 0)); + bytes.append(&mut encode_der_bytes( + mech_list_mic, + MECH_LIST_MIC_TAG, + DER_ENCODING_BYTE_ARRAY_TAG, + 0, + )); } bytes } @@ -99,7 +129,9 @@ impl SPNEGOTokenInitBody { fn parse_mech_type_list(buffer: &[u8]) -> IResult<&[u8], Vec>> { let (remaining, _) = parse_length(buffer)?; let (remaining, tag) = le_u8(remaining)?; - if tag != DER_ENCODING_SEQUENCE_TAG { return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))) } + if tag != DER_ENCODING_SEQUENCE_TAG { + return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); + } let (remaining, sequence) = parse_field_with_len(remaining)?; let (_, list) = many0(|buf| parse_der_multibyte(buf, DER_ENCODING_OID_TAG))(sequence)?; Ok((remaining, list)) diff --git a/smb/src/util/auth/spnego/spnego_token_init_2.rs b/smb/src/util/auth/spnego/spnego_token_init_2.rs index 30d58b4..4711766 100644 --- a/smb/src/util/auth/spnego/spnego_token_init_2.rs +++ b/smb/src/util/auth/spnego/spnego_token_init_2.rs @@ -3,9 +3,11 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize)] pub struct SPNEGOTokenInit2Body { - mechanism: Option + mechanism: Option, } impl SPNEGOTokenInit2Body { - pub(crate) fn as_bytes(&self) -> Vec {Vec::new()} + pub(crate) fn as_bytes(&self) -> Vec { + Vec::new() + } } diff --git a/smb/src/util/auth/spnego/spnego_token_response.rs b/smb/src/util/auth/spnego/spnego_token_response.rs index d942241..5e1828b 100644 --- a/smb/src/util/auth/spnego/spnego_token_response.rs +++ b/smb/src/util/auth/spnego/spnego_token_response.rs @@ -1,7 +1,7 @@ -use nom::combinator::map_res; use nom::Err::Error; -use nom::error::ErrorKind; use nom::IResult; +use nom::combinator::map_res; +use nom::error::ErrorKind; use nom::number::complete::le_u8; use num_enum::TryFromPrimitive; use serde::{Deserialize, Serialize}; @@ -9,8 +9,13 @@ use serde::{Deserialize, Serialize}; use smb_core::logging::trace; use smb_core::nt_status::NTStatus; +use crate::util::auth::spnego::der_utils::{ + DER_ENCODING_BYTE_ARRAY_TAG, DER_ENCODING_ENUM_TAG, DER_ENCODING_OID_TAG, + DER_ENCODING_SEQUENCE_TAG, MECH_LIST_MIC_TAG, NEG_STATE_TAG, NEG_TOKEN_RESP_TAG, + RESPONSE_TOKEN_TAG, SUPPORTED_MECH_TAG, encode_der_bytes, get_array_field_len, get_field_size, + get_length, parse_der_byte_array, parse_der_oid, parse_field_with_len, parse_length, +}; use crate::util::auth::{AuthMessage, AuthProvider}; -use crate::util::auth::spnego::der_utils::{DER_ENCODING_BYTE_ARRAY_TAG, DER_ENCODING_ENUM_TAG, DER_ENCODING_OID_TAG, DER_ENCODING_SEQUENCE_TAG, encode_der_bytes, get_array_field_len, get_field_size, get_length, MECH_LIST_MIC_TAG, NEG_STATE_TAG, NEG_TOKEN_RESP_TAG, parse_der_byte_array, parse_der_oid, parse_field_with_len, parse_length, RESPONSE_TOKEN_TAG, SUPPORTED_MECH_TAG}; #[repr(u8)] #[derive(Copy, Clone, Debug, PartialEq, Eq, TryFromPrimitive, Deserialize, Serialize)] @@ -35,7 +40,7 @@ impl SPNEGOTokenResponseBody { let state = Some(match status { NTStatus::StatusSuccess => NegotiateState::AcceptCompleted, NTStatus::MoreProcessingRequired => NegotiateState::AcceptIncomplete, - _ => NegotiateState::Reject + _ => NegotiateState::Reject, }); let (response_token, supported_mech) = if token_content.as_bytes().is_empty() { (None, None) @@ -70,15 +75,30 @@ impl SPNEGOTokenResponseBody { } if let Some(supported_mech) = &self.supported_mech { - bytes.append(&mut encode_der_bytes(supported_mech, SUPPORTED_MECH_TAG, DER_ENCODING_OID_TAG, 0)); + bytes.append(&mut encode_der_bytes( + supported_mech, + SUPPORTED_MECH_TAG, + DER_ENCODING_OID_TAG, + 0, + )); } if let Some(response_token) = &self.response_token { - bytes.append(&mut encode_der_bytes(response_token, RESPONSE_TOKEN_TAG, DER_ENCODING_BYTE_ARRAY_TAG, 0)); + bytes.append(&mut encode_der_bytes( + response_token, + RESPONSE_TOKEN_TAG, + DER_ENCODING_BYTE_ARRAY_TAG, + 0, + )); } if let Some(mech_list_mic) = &self.mech_list_mic { - bytes.append(&mut encode_der_bytes(mech_list_mic, MECH_LIST_MIC_TAG, DER_ENCODING_BYTE_ARRAY_TAG, 0)); + bytes.append(&mut encode_der_bytes( + mech_list_mic, + MECH_LIST_MIC_TAG, + DER_ENCODING_BYTE_ARRAY_TAG, + 0, + )); } bytes @@ -87,7 +107,9 @@ impl SPNEGOTokenResponseBody { pub fn parse(bytes: &[u8]) -> IResult<&[u8], Self> { let (remaining, _) = parse_length(bytes)?; let (remaining, mut tag) = le_u8(remaining)?; - if tag != DER_ENCODING_SEQUENCE_TAG { return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))) } + if tag != DER_ENCODING_SEQUENCE_TAG { + return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); + } let (remaining, mut sequence) = parse_field_with_len(remaining)?; let mut state = None; let mut supported_mech = None; @@ -95,34 +117,46 @@ impl SPNEGOTokenResponseBody { let mut mech_list_mic = None; while !sequence.is_empty() { - trace!(remaining_len = sequence.len(), "parsing SPNEGO response sequence field"); + trace!( + remaining_len = sequence.len(), + "parsing SPNEGO response sequence field" + ); (sequence, tag) = le_u8(sequence)?; match tag { NEG_STATE_TAG => { let (s, neg_state) = Self::parse_negotiate_state(sequence)?; sequence = s; state = Some(neg_state); - }, + } SUPPORTED_MECH_TAG => { let (s, mech) = Self::parse_supported_mech(sequence)?; sequence = s; supported_mech = Some(mech); - }, + } RESPONSE_TOKEN_TAG => { let (s, resp) = Self::parse_response_token(sequence)?; sequence = s; response_token = Some(resp); - }, + } MECH_LIST_MIC_TAG => { let (s, mic) = Self::parse_mech_list_mic(sequence)?; sequence = s; mech_list_mic = Some(mic); - }, + } _ => return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))), } } - Ok((remaining, SPNEGOTokenResponseBody { mechanism: None, state, supported_mech, response_token, mech_list_mic })) + Ok(( + remaining, + SPNEGOTokenResponseBody { + mechanism: None, + state, + supported_mech, + response_token, + mech_list_mic, + }, + )) } } @@ -131,7 +165,9 @@ impl SPNEGOTokenResponseBody { fn parse_negotiate_state(buffer: &[u8]) -> IResult<&[u8], NegotiateState> { let (remaining, _) = parse_length(buffer)?; let (remaining, tag) = le_u8(remaining)?; - if tag != DER_ENCODING_ENUM_TAG { return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))) } + if tag != DER_ENCODING_ENUM_TAG { + return Err(Error(nom::error::Error::new(remaining, ErrorKind::Fail))); + } map_res(le_u8, NegotiateState::try_from)(remaining) } @@ -156,8 +192,9 @@ impl SPNEGOTokenResponseBody { &*get_length(3), &[DER_ENCODING_ENUM_TAG], &*get_length(1), - &[*state as u8] - ].concat() + &[*state as u8], + ] + .concat() } fn supported_mech_bytes(&self, supported_mech: &[u8]) -> Vec { @@ -167,8 +204,9 @@ impl SPNEGOTokenResponseBody { &*get_length(construction_len), &[DER_ENCODING_OID_TAG], &*get_length(supported_mech.len()), - supported_mech - ].concat() + supported_mech, + ] + .concat() } fn response_token_bytes(&self, response_token: &[u8]) -> Vec { @@ -178,11 +216,11 @@ impl SPNEGOTokenResponseBody { &*get_length(construction_len), &[DER_ENCODING_BYTE_ARRAY_TAG], &*get_length(response_token.len()), - response_token - ].concat() + response_token, + ] + .concat() } - fn token_fields_len(&self) -> usize { let mut len = 0; if self.state.is_some() { @@ -200,4 +238,4 @@ impl SPNEGOTokenResponseBody { } len } -} \ No newline at end of file +} diff --git a/smb/src/util/auth/user.rs b/smb/src/util/auth/user.rs index 5f68cf2..35495ec 100644 --- a/smb/src/util/auth/user.rs +++ b/smb/src/util/auth/user.rs @@ -8,6 +8,9 @@ pub struct User { impl User { pub fn new, P: Into>(username: U, password: P) -> Self { - Self { username: username.into(), password: password.into() } + Self { + username: username.into(), + password: password.into(), + } } -} \ No newline at end of file +} diff --git a/smb/src/util/crypto/des.rs b/smb/src/util/crypto/des.rs index 005f2f8..4a7af60 100644 --- a/smb/src/util/crypto/des.rs +++ b/smb/src/util/crypto/des.rs @@ -1,12 +1,14 @@ -use des::cipher::BlockEncrypt; use des::Des; +use des::cipher::BlockEncrypt; use digest::KeyInit; -use smb_core::error::SMBError; use smb_core::SMBResult; +use smb_core::error::SMBError; pub fn des_long_encrypt(key: &[u8], plaintext: &[u8]) -> SMBResult> { - if key.len() != 16 || plaintext.len() != 8 { return Err(SMBError::crypto_error("Invalid key length")); } + if key.len() != 16 || plaintext.len() != 8 { + return Err(SMBError::crypto_error("Invalid key length")); + } let padded = [key, &*vec![0; 21 - key.len()]].concat(); let k1 = &padded[0..7]; @@ -40,8 +42,7 @@ fn extend_des_key(key: &[u8]) -> Vec { } fn des_encrypt(key: &[u8], plaintext: &[u8]) -> SMBResult> { - let des = Des::new_from_slice(key) - .map_err(|_| SMBError::crypto_error("Invalid key length"))?; + let des = Des::new_from_slice(key).map_err(|_| SMBError::crypto_error("Invalid key length"))?; let mut result = vec![0_u8; plaintext.len()]; des.encrypt_block_b2b(plaintext.into(), (&mut *result).into()); Ok(result) diff --git a/smb/src/util/crypto/mod.rs b/smb/src/util/crypto/mod.rs index c454fa4..da4bd67 100644 --- a/smb/src/util/crypto/mod.rs +++ b/smb/src/util/crypto/mod.rs @@ -2,4 +2,4 @@ pub mod des; pub mod ntlm_v1_extended; pub mod ntlm_v2; pub mod smb2; -pub mod sp800_108; \ No newline at end of file +pub mod sp800_108; diff --git a/smb/src/util/crypto/ntlm_v1_extended.rs b/smb/src/util/crypto/ntlm_v1_extended.rs index 05ed4d8..771e6c1 100644 --- a/smb/src/util/crypto/ntlm_v1_extended.rs +++ b/smb/src/util/crypto/ntlm_v1_extended.rs @@ -6,14 +6,24 @@ use smb_core::SMBResult; use crate::byte_helper::u16_to_bytes; use crate::util::crypto::des::des_long_encrypt; -pub fn authenticate_v1_extended(password: &str, server_challenge: &[u8], lm_response: &[u8], nt_respobse: &[u8]) -> SMBResult { +pub fn authenticate_v1_extended( + password: &str, + server_challenge: &[u8], + lm_response: &[u8], + nt_respobse: &[u8], +) -> SMBResult { let client_challenge = &lm_response[0..8]; - let expected_v1_response = compute_ntlmv1_extended_response(server_challenge, client_challenge, password)?; + let expected_v1_response = + compute_ntlmv1_extended_response(server_challenge, client_challenge, password)?; Ok(nt_respobse == expected_v1_response) } -fn compute_ntlmv1_extended_response(server_challenge: &[u8], client_challenge: &[u8], password: &str) -> SMBResult> { +fn compute_ntlmv1_extended_response( + server_challenge: &[u8], + client_challenge: &[u8], + password: &str, +) -> SMBResult> { let challenge_hash = Md4::new() .chain_update(server_challenge) .chain_update(client_challenge) @@ -24,7 +34,11 @@ fn compute_ntlmv1_extended_response(server_challenge: &[u8], client_challenge: & } fn ntowf_v1(password: &str) -> Vec { - let password = password.encode_utf16().map(u16_to_bytes).collect::>().concat(); + let password = password + .encode_utf16() + .map(u16_to_bytes) + .collect::>() + .concat(); let mut pass_hash = Md4::new(); pass_hash.update(password); pass_hash.finalize().as_slice().into() diff --git a/smb/src/util/crypto/ntlm_v2.rs b/smb/src/util/crypto/ntlm_v2.rs index 2ba3e7a..6b35a82 100644 --- a/smb/src/util/crypto/ntlm_v2.rs +++ b/smb/src/util/crypto/ntlm_v2.rs @@ -3,30 +3,60 @@ use hmac::{Hmac, Mac}; use md4::Md4; use md5::Md5; -use smb_core::error::SMBError; use smb_core::SMBResult; +use smb_core::error::SMBError; use crate::byte_helper::u16_to_bytes; -pub fn authenticate_v2(domain: &str, account: &str, password: &str, server_challenge: &[u8], lm_response: &[u8], nt_response: &[u8]) -> SMBResult<(bool, Vec)> { +pub fn authenticate_v2( + domain: &str, + account: &str, + password: &str, + server_challenge: &[u8], + lm_response: &[u8], + nt_response: &[u8], +) -> SMBResult<(bool, Vec)> { // AV-pairs structure let server_name = &nt_response[44..(nt_response.len() - 4)]; - let (nt_exp, lm_exp, _nt_proof) = compute_ntlm_v2_response(server_challenge, &nt_response[16..], server_name, password, account, domain)?; + let (nt_exp, lm_exp, _nt_proof) = compute_ntlm_v2_response( + server_challenge, + &nt_response[16..], + server_name, + password, + account, + domain, + )?; let resp = nt_exp == nt_response || lm_exp == lm_response; let resp = if !resp { let lm_client_challenge = &lm_response[16..24]; - let expected_resp = compute_lmv2_response(server_challenge, lm_client_challenge, password, account, domain)?; + let expected_resp = compute_lmv2_response( + server_challenge, + lm_client_challenge, + password, + account, + domain, + )?; expected_resp == lm_response - } else { resp }; + } else { + resp + }; let resp = if !resp && nt_response.len() >= 16 { let client_nt_proof = &nt_response[0..16]; let client_structure_padded = &nt_response[16..]; - let expected_nt_proof = compute_ntlmv2_proof(server_challenge, client_structure_padded, password, account, domain)?; + let expected_nt_proof = compute_ntlmv2_proof( + server_challenge, + client_structure_padded, + password, + account, + domain, + )?; client_nt_proof == expected_nt_proof - } else { resp }; + } else { + resp + }; if resp { let response_key_nt = ntowf_v2(password, account, domain)?; @@ -34,14 +64,25 @@ pub fn authenticate_v2(domain: &str, account: &str, password: &str, server_chall // for session_base_key, since our reconstructed nt_proof may differ let client_nt_proof = &nt_response[0..16]; let session_base_key = new_hmac_from_slice(&response_key_nt)? - .chain_update(client_nt_proof).finalize().into_bytes().as_slice().into(); + .chain_update(client_nt_proof) + .finalize() + .into_bytes() + .as_slice() + .into(); Ok((resp, session_base_key)) } else { Ok((resp, Vec::new())) } } -fn compute_ntlm_v2_response(server_challenge: &[u8], client_challenge: &[u8], server_name: &[u8], password: &str, account: &str, domain: &str) -> SMBResult<(Vec, Vec, Vec)> { +fn compute_ntlm_v2_response( + server_challenge: &[u8], + client_challenge: &[u8], + server_name: &[u8], + password: &str, + account: &str, + domain: &str, +) -> SMBResult<(Vec, Vec, Vec)> { let time = &client_challenge[8..16]; let client_challenge = &client_challenge[16..24]; let temp = [ @@ -54,30 +95,33 @@ fn compute_ntlm_v2_response(server_challenge: &[u8], client_challenge: &[u8], se &[0; 4], server_name, &[0; 4], - ].concat(); + ] + .concat(); let key = lmowf_v2(password, account, domain)?; let proof_hmac = new_hmac_from_slice(&key)? .chain_update(server_challenge) .chain_update(&temp); - let nt_proof_str = hmac::Mac::finalize(proof_hmac) - .into_bytes(); - let nt_challenge_response = [ - nt_proof_str.as_slice(), - &temp, - ].concat(); + let nt_proof_str = hmac::Mac::finalize(proof_hmac).into_bytes(); + let nt_challenge_response = [nt_proof_str.as_slice(), &temp].concat(); let lm_challenge_hmac = new_hmac_from_slice(&key)? .chain_update(server_challenge) .chain_update(client_challenge); - let lm_challenge_response_1 = hmac::Mac::finalize(lm_challenge_hmac) - .into_bytes(); - let lm_challenge_response = [ - lm_challenge_response_1.as_slice(), - client_challenge, - ].concat(); - Ok((nt_challenge_response.to_vec(), lm_challenge_response.to_vec(), nt_proof_str.to_vec())) + let lm_challenge_response_1 = hmac::Mac::finalize(lm_challenge_hmac).into_bytes(); + let lm_challenge_response = [lm_challenge_response_1.as_slice(), client_challenge].concat(); + Ok(( + nt_challenge_response.to_vec(), + lm_challenge_response.to_vec(), + nt_proof_str.to_vec(), + )) } -fn compute_lmv2_response(server_challenge: &[u8], lm_client_challenge: &[u8], password: &str, account: &str, domain: &str) -> SMBResult> { +fn compute_lmv2_response( + server_challenge: &[u8], + lm_client_challenge: &[u8], + password: &str, + account: &str, + domain: &str, +) -> SMBResult> { let key = lmowf_v2(password, account, domain)?; let bytes_hmac = new_hmac_from_slice(&key)?; let bytes_hmac = bytes_hmac @@ -87,13 +131,17 @@ fn compute_lmv2_response(server_challenge: &[u8], lm_client_challenge: &[u8], pa Ok([result.into_bytes().as_slice(), lm_client_challenge].concat()) } -fn compute_ntlmv2_proof(server_challenge: &[u8], client_structure_padded: &[u8], password: &str, account: &str, domain: &str) -> SMBResult> { +fn compute_ntlmv2_proof( + server_challenge: &[u8], + client_structure_padded: &[u8], + password: &str, + account: &str, + domain: &str, +) -> SMBResult> { let key = ntowf_v2(password, account, domain)?; let temp = client_structure_padded; let bytes_hmac = new_hmac_from_slice(&key)?; - let bytes_hmac = bytes_hmac - .chain_update(server_challenge) - .chain_update(temp); + let bytes_hmac = bytes_hmac.chain_update(server_challenge).chain_update(temp); let result = hmac::Mac::finalize(bytes_hmac); Ok(result.into_bytes().as_slice().into()) } @@ -103,10 +151,18 @@ fn lmowf_v2(password: &str, user: &str, domain: &str) -> SMBResult> { } fn ntowf_v2(password: &str, user: &str, domain: &str) -> SMBResult> { - let password = password.encode_utf16().map(u16_to_bytes).collect::>().concat(); + let password = password + .encode_utf16() + .map(u16_to_bytes) + .collect::>() + .concat(); let password_hash = Md4::digest(password); let text = user.to_uppercase() + domain; - let bytes = text.encode_utf16().map(u16_to_bytes).collect::>().concat(); + let bytes = text + .encode_utf16() + .map(u16_to_bytes) + .collect::>() + .concat(); let mut hmac_md5 = new_hmac_from_slice(password_hash.as_slice())?; hmac_md5.update(&bytes); let result = hmac::Mac::finalize(hmac_md5); @@ -115,4 +171,4 @@ fn ntowf_v2(password: &str, user: &str, domain: &str) -> SMBResult> { fn new_hmac_from_slice(slice: &[u8]) -> SMBResult> { >::new_from_slice(slice).map_err(|_| SMBError::crypto_error("Invalid length for key")) -} \ No newline at end of file +} diff --git a/smb/src/util/crypto/smb2.rs b/smb/src/util/crypto/smb2.rs index 49ca084..57e2da1 100644 --- a/smb/src/util/crypto/smb2.rs +++ b/smb/src/util/crypto/smb2.rs @@ -3,13 +3,19 @@ use cmac::Cmac; use hmac::{Hmac, Mac}; use sha2::Sha256; -use smb_core::error::SMBError; use smb_core::SMBResult; +use smb_core::error::SMBError; use crate::protocol::body::dialect::SMBDialect; use crate::util::crypto::sp800_108; -pub fn calculate_signature(signing_key: &[u8], dialect: SMBDialect, buffer: &[u8], offset: usize, padded_len: usize) -> SMBResult> { +pub fn calculate_signature( + signing_key: &[u8], + dialect: SMBDialect, + buffer: &[u8], + offset: usize, + padded_len: usize, +) -> SMBResult> { let buffer = &buffer[offset..(offset + padded_len)]; let output = if dialect == SMBDialect::V2_0_2 || dialect == SMBDialect::V2_1_0 { new_sha256_from_slice(signing_key)? @@ -28,13 +34,19 @@ pub fn calculate_signature(signing_key: &[u8], dialect: SMBDialect, buffer: &[u8 Ok(output) } -pub fn generate_signing_key(session_key: &[u8], dialect: SMBDialect, preauth_integrity_hash_value: &[u8]) -> SMBResult> { +pub fn generate_signing_key( + session_key: &[u8], + dialect: SMBDialect, + preauth_integrity_hash_value: &[u8], +) -> SMBResult> { if dialect == SMBDialect::V2_0_2 || dialect == SMBDialect::V2_1_0 { return Ok(session_key.into()); } if dialect == SMBDialect::V3_1_1 && preauth_integrity_hash_value.is_empty() { - return Err(SMBError::PreconditionFailed("No preauth_integrity_hash_value with SMB 3.1.1".into())); + return Err(SMBError::PreconditionFailed( + "No preauth_integrity_hash_value with SMB 3.1.1".into(), + )); } let label: &[u8] = if dialect == SMBDialect::V3_1_1 { @@ -53,6 +65,5 @@ pub fn generate_signing_key(session_key: &[u8], dialect: SMBDialect, preauth_int } fn new_sha256_from_slice(slice: &[u8]) -> SMBResult> { - >::new_from_slice(slice) - .map_err(|_| SMBError::crypto_error("Invalid Key Length")) -} \ No newline at end of file + >::new_from_slice(slice).map_err(|_| SMBError::crypto_error("Invalid Key Length")) +} diff --git a/smb/src/util/crypto/sp800_108.rs b/smb/src/util/crypto/sp800_108.rs index c29c16f..4d324cf 100644 --- a/smb/src/util/crypto/sp800_108.rs +++ b/smb/src/util/crypto/sp800_108.rs @@ -2,7 +2,12 @@ use std::cmp::min; use digest::Mac; -pub fn derive_key(mac: T, label: &[u8], context: &[u8], key_len_bits: u32) -> Vec { +pub fn derive_key( + mac: T, + label: &[u8], + context: &[u8], + key_len_bits: u32, +) -> Vec { let mut buffer = vec![0_u8; 4 + label.len() + 1 + context.len() + 4]; buffer[4..(label.len() + 4)].copy_from_slice(label); @@ -25,10 +30,7 @@ pub fn derive_key(mac: T, label: &[u8], context: &[u8], key_len_ buffer[..bytes.len()].copy_from_slice(&bytes[..]); - let k_i = mac.clone() - .chain_update(&*buffer) - .finalize() - .into_bytes(); + let k_i = mac.clone().chain_update(&*buffer).finalize().into_bytes(); let num_to_copy = min(num_remaining, k_i.len() as u32); output[(num_written as usize)..(num_written + num_to_copy) as usize] @@ -40,4 +42,4 @@ pub fn derive_key(mac: T, label: &[u8], context: &[u8], key_len_ } output -} \ No newline at end of file +} diff --git a/smb/src/util/flags_helper.rs b/smb/src/util/flags_helper.rs index f5f4745..ab3dfba 100644 --- a/smb/src/util/flags_helper.rs +++ b/smb/src/util/flags_helper.rs @@ -43,4 +43,4 @@ macro_rules! impl_smb_to_bytes_for_bitflag {( pub(crate) use impl_smb_byte_size_for_bitflag; pub(crate) use impl_smb_from_bytes_for_bitflag; -pub(crate) use impl_smb_to_bytes_for_bitflag; \ No newline at end of file +pub(crate) use impl_smb_to_bytes_for_bitflag; diff --git a/smb/src/util/mod.rs b/smb/src/util/mod.rs index 6a49ee1..2923265 100644 --- a/smb/src/util/mod.rs +++ b/smb/src/util/mod.rs @@ -1,5 +1,5 @@ -pub mod auth; pub(crate) mod as_bytes; +pub mod auth; pub(crate) mod crypto; pub(crate) mod flags_helper; -pub(crate) mod num_limits; \ No newline at end of file +pub(crate) mod num_limits; diff --git a/smb/src/util/num_limits.rs b/smb/src/util/num_limits.rs index 215eadb..f7be516 100644 --- a/smb/src/util/num_limits.rs +++ b/smb/src/util/num_limits.rs @@ -72,21 +72,18 @@ impl One for u16 { fn one() -> Self { 1 } - } impl One for u32 { fn one() -> Self { 1 } - } impl One for u64 { fn one() -> Self { 1 } - } impl Zero for u8 { @@ -111,4 +108,4 @@ impl Zero for u64 { fn zero() -> Self { 0 } -} \ No newline at end of file +} diff --git a/smb/tests/message.rs b/smb/tests/message.rs index 0d1cd13..76fbc68 100644 --- a/smb/tests/message.rs +++ b/smb/tests/message.rs @@ -1,5 +1,5 @@ -use smb_reader::protocol::body::echo::SMBEchoRequest; use smb_reader::protocol::body::SMBBody; +use smb_reader::protocol::body::echo::SMBEchoRequest; use smb_reader::protocol::header::SMBSyncHeader; use smb_reader::protocol::message::{Message, SMBMessage}; @@ -8,16 +8,24 @@ fn test_as_bytes_empty_body() { use smb_reader::protocol::header::command_code::SMBCommandCode; use smb_reader::protocol::header::flags::SMBFlags; - let header = SMBSyncHeader::new(SMBCommandCode::Negotiate, SMBFlags::empty(), 0, 0, 0, 0, [0; 16]); + let header = SMBSyncHeader::new( + SMBCommandCode::Negotiate, + SMBFlags::empty(), + 0, + 0, + 0, + 0, + [0; 16], + ); let body = SMBBody::EchoRequest(SMBEchoRequest {}); let message = SMBMessage::new(header, body); let bytes = message.as_bytes(); assert!(!bytes.is_empty()); - assert_eq!(bytes[0..4], [0, 0, 0, 68]); // First 4 bytes should be [0, 0, (len..)] --> + assert_eq!(bytes[0..4], [0, 0, 0, 68]); // First 4 bytes should be [0, 0, (len..)] --> let expected_len = u16::from_be_bytes([bytes[2], bytes[3]]); - assert_eq!(bytes.len(), expected_len as usize + 4); // Total length should match + assert_eq!(bytes.len(), expected_len as usize + 4); // Total length should match } #[test] @@ -25,7 +33,15 @@ fn test_as_bytes_consistency() { use smb_reader::protocol::header::command_code::SMBCommandCode; use smb_reader::protocol::header::flags::SMBFlags; - let header = SMBSyncHeader::new(SMBCommandCode::Negotiate, SMBFlags::empty(), 0, 0, 0, 0, [0; 16]); + let header = SMBSyncHeader::new( + SMBCommandCode::Negotiate, + SMBFlags::empty(), + 0, + 0, + 0, + 0, + [0; 16], + ); let body = SMBBody::EchoRequest(SMBEchoRequest {}); let message1 = SMBMessage::new(header.clone(), body.clone()); let message2 = SMBMessage::new(header, body); @@ -33,8 +49,14 @@ fn test_as_bytes_consistency() { let bytes1 = message1.as_bytes(); let bytes2 = message2.as_bytes(); - assert_eq!(bytes1, bytes2, "Byte representations should be identical for the same message content"); - assert!(!bytes1.is_empty(), "Byte representation should not be empty"); + assert_eq!( + bytes1, bytes2, + "Byte representations should be identical for the same message content" + ); + assert!( + !bytes1.is_empty(), + "Byte representation should not be empty" + ); } #[test] @@ -47,7 +69,8 @@ fn test_as_bytes_serialization_deserialization() { let original_message = SMBMessage::new(header, body); let serialized = original_message.as_bytes(); - let (_, deserialized_message) = SMBMessage::::parse(&serialized[4..]).unwrap(); + let (_, deserialized_message) = + SMBMessage::::parse(&serialized[4..]).unwrap(); assert_eq!(original_message, deserialized_message); -} \ No newline at end of file +} diff --git a/smb/tests/smbclient.rs b/smb/tests/smbclient.rs index e55b3e7..0f70bfb 100644 --- a/smb/tests/smbclient.rs +++ b/smb/tests/smbclient.rs @@ -42,7 +42,10 @@ fn spawn_server(port: u16) -> Child { } std::thread::sleep(Duration::from_millis(100)); } - panic!("Server did not start listening on {} within 5 seconds", addr); + panic!( + "Server did not start listening on {} within 5 seconds", + addr + ); } /// Run an smbclient command and return (exit_status, stdout, stderr). @@ -75,10 +78,13 @@ fn negotiate_completes() { let port_str = port.to_string(); let (_success, _stdout, stderr) = run_smbclient(&[ "//127.0.0.1/share", - "-p", &port_str, - "-N", // no password - "-m", "SMB2", - "-c", "exit", + "-p", + &port_str, + "-N", // no password + "-m", + "SMB2", + "-c", + "exit", ]); // smbclient may fail auth but should get past negotiate. @@ -106,10 +112,13 @@ fn server_does_not_crash_on_smb1_only() { let port_str = port.to_string(); let (_success, _stdout, _stderr) = run_smbclient(&[ "//127.0.0.1/share", - "-p", &port_str, + "-p", + &port_str, "-N", - "-m", "NT1", - "-c", "exit", + "-m", + "NT1", + "-c", + "exit", ]); // Server should still be running (not crashed) @@ -142,10 +151,14 @@ fn session_setup_with_credentials() { let port_str = port.to_string(); let (_success, _stdout, stderr) = run_smbclient(&[ "//127.0.0.1/share", - "-p", &port_str, - "-U", "testuser%testpass", - "-m", "SMB2", - "-c", "exit", + "-p", + &port_str, + "-U", + "testuser%testpass", + "-m", + "SMB2", + "-c", + "exit", ]); // Server should not crash @@ -170,10 +183,13 @@ fn session_setup_anonymous() { let port_str = port.to_string(); let (_success, _stdout, stderr) = run_smbclient(&[ "//127.0.0.1/share", - "-p", &port_str, + "-p", + &port_str, "-N", - "-m", "SMB2", - "-c", "exit", + "-m", + "SMB2", + "-c", + "exit", ]); // Server should not crash @@ -206,10 +222,14 @@ fn tree_connect_to_share() { let port_str = port.to_string(); let (_success, _stdout, stderr) = run_smbclient(&[ "//127.0.0.1/share", - "-p", &port_str, - "-U", "testuser%testpass", - "-m", "SMB2", - "-c", "ls", + "-p", + &port_str, + "-U", + "testuser%testpass", + "-m", + "SMB2", + "-c", + "ls", ]); // Server should not crash @@ -234,10 +254,14 @@ fn tree_connect_nonexistent_share() { let port_str = port.to_string(); let (success, _stdout, stderr) = run_smbclient(&[ "//127.0.0.1/nonexistent_share_xyz", - "-p", &port_str, - "-U", "testuser%testpass", - "-m", "SMB2", - "-c", "ls", + "-p", + &port_str, + "-U", + "testuser%testpass", + "-m", + "SMB2", + "-c", + "ls", ]); // Should fail (share doesn't exist) @@ -269,10 +293,13 @@ fn server_survives_multiple_connections() { let port_str = port.to_string(); let (_success, _stdout, _stderr) = run_smbclient(&[ "//127.0.0.1/share", - "-p", &port_str, + "-p", + &port_str, "-N", - "-m", "SMB2", - "-c", "exit", + "-m", + "SMB2", + "-c", + "exit", ]); }