diff --git a/Cargo.toml b/Cargo.toml index 04984d7..4f39ae9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,6 @@ http-body-util = "0.1.2" tower-reqwest = "0.4.0" color-eyre = "0.6" tracing-error = "0.2.1" - asn1_codecs_derive = "0.7.0" asn1-codecs = "0.7.0" smart-default = "0.7.1" @@ -68,14 +67,13 @@ tokio-sctp = "0.2.0" valuable = "0.1.0" statig = {version = "0.3.0", features = ["async"]} non-empty-string = "0.2.6" - -# Depenedency of tokio-sctp -socket2 = "0.4" derive-new = "0.7" faster-hex = "0.10.0" nonempty = { version = "0.8.1", features = ["serialize"] } bitvec = "1.0.1" ascii = "1.1.0" +atomic_enum = "0.3.0" +socket2 = "0.4" # Depenedency of tokio-sctp oasbi = { git = "https://github.com/UnifyAir/open-api.git/", package = "oasbi", branch = "master" } openapi-smf = { git = "https://github.com/UnifyAir/open-api.git/", package = "openapi-smf", features = [ @@ -96,4 +94,3 @@ openapi-nrf = { git = "https://github.com/UnifyAir/open-api.git/", package = "op ngap-models = { git = "https://github.com/UnifyAir/asn-models.git/", package = "ngap", branch = "master" } asn1-per = { git = "https://github.com/UnifyAir/asn-models.git/", package = "asn1-per", branch = "master" } nas-models = { git = "https://github.com/UnifyAir/nas-models.git/", package = "nas-models", branch = "master" } - diff --git a/lightning-nf/omnipath/app/Cargo.toml b/lightning-nf/omnipath/app/Cargo.toml index 923fd2e..433d865 100644 --- a/lightning-nf/omnipath/app/Cargo.toml +++ b/lightning-nf/omnipath/app/Cargo.toml @@ -57,8 +57,10 @@ valuable.workspace = true ascii.workspace = true non-empty-string.workspace = true statig.workspace = true +atomic_enum.workspace = true counter = { path = "../../../utils/counter" } client = { path = "../../../utils/client" } nf-base = { path = "../../../utils/nf-base" } +atomic-handle = {path = "../../../utils/atomic-handle"} diff --git a/lightning-nf/omnipath/app/src/context/gnb_context.rs b/lightning-nf/omnipath/app/src/context/gnb_context.rs index 9d6adde..1f923fa 100644 --- a/lightning-nf/omnipath/app/src/context/gnb_context.rs +++ b/lightning-nf/omnipath/app/src/context/gnb_context.rs @@ -2,13 +2,16 @@ use std::sync::Arc; use counter::CounterU64; use derive_new::new; -use ngap_models::{GlobalRanNodeId, PagingDrx}; +use ngap_models::{FiveGSTmsi, GlobalRanNodeId, PagingDrx, RanUeNgapId}; use nonempty::NonEmpty; use oasbi::common::{Snssai, Tai}; +use scc::hash_map::HashMap as SccHashMap; use tokio_util::sync::CancellationToken; +use atomic_handle::AtomicHandle; + use crate::{ - context::ue_context::UeContext, + context::ue_context::{GmmStateField, UeContext}, ngap::{manager::ContextManager, network::TnlaAssociation}, }; @@ -22,6 +25,13 @@ pub struct GnbContext { #[new(value = "ContextManager::new()")] pub ue_context_manager: ContextManager, + // List of registered ues who might be paged later + #[new(default)] + pub idle_ues: SccHashMap, + + #[new(default)] + pub ue_states: SccHashMap>, + #[new(default)] pub name: String, diff --git a/lightning-nf/omnipath/app/src/context/mod.rs b/lightning-nf/omnipath/app/src/context/mod.rs index f08ba20..cb43c32 100644 --- a/lightning-nf/omnipath/app/src/context/mod.rs +++ b/lightning-nf/omnipath/app/src/context/mod.rs @@ -1,9 +1,13 @@ pub mod app_context; mod gnb_context; mod ngap_context; +mod state; mod ue_context; +mod nas_context; pub use app_context::AppContext; pub use gnb_context::{GnbContext, SupportedTai}; pub use ngap_context::NgapContext; +pub use state::{AtomicGmmState, GmmState}; pub use ue_context::UeContext; +pub use nas_context::NasContext; diff --git a/lightning-nf/omnipath/app/src/context/nas_context.rs b/lightning-nf/omnipath/app/src/context/nas_context.rs new file mode 100644 index 0000000..b8ae8c0 --- /dev/null +++ b/lightning-nf/omnipath/app/src/context/nas_context.rs @@ -0,0 +1,25 @@ +use std::{cell::UnsafeCell, num::NonZeroU32}; + +use derive_new::new; +use ngap_models::RrcEstablishmentCause; +use non_empty_string::NonEmptyString; + +use crate::utils::models::FiveGSTmsi; + +#[derive(new)] +pub struct NasContext { + pub rrc_establishment_cause: RrcEstablishmentCause, + pub five_g_s_tmsi: Option, + #[new(default)] + pub tmsi: Option, + #[new(default)] + pub guti: Option, + #[new(default)] + pub suci: Option, + #[new(default)] + pub pei: Option, + #[new(default)] + pub mac_addr: Option, + #[new(default)] + pub plmn_id: Option, +} diff --git a/lightning-nf/omnipath/app/src/context/state/atomic_gmm_state.rs b/lightning-nf/omnipath/app/src/context/state/atomic_gmm_state.rs new file mode 100644 index 0000000..dccbf8e --- /dev/null +++ b/lightning-nf/omnipath/app/src/context/state/atomic_gmm_state.rs @@ -0,0 +1,218 @@ +use std::{ + fmt, ops::Deref, sync::atomic::{AtomicUsize, Ordering} +}; + +use super::GmmState; + +/// A custom atomic wrapper for GmmState that performs operations directly on +/// the enum's discriminant to avoid the potential overhead of `match` +/// statements. +/// +/// The safety of the `unsafe` transmutation relies on the invariant that this +/// struct only ever stores values that are valid `GmmState` discriminants. +#[derive(Debug)] +pub struct AtomicGmmState { + state: AtomicUsize, +} + +impl AtomicGmmState { + /// Creates a new AtomicGmmState with the given initial state. + pub const fn new(initial_state: GmmState) -> Self { + Self { + state: AtomicUsize::new(initial_state as usize), + } + } + + /// Loads the current state using `std::mem::transmute`. + #[inline] + pub fn load( + &self, + ordering: Ordering, + ) -> GmmState { + let discriminant = self.state.load(ordering); + // This assertion ensures that, in debug builds, we panic if the state + // ever holds an invalid discriminant, which would be undefined behavior. + debug_assert!( + discriminant <= GmmState::MAX_DISCRIMINANT as usize, + "Invalid GmmState discriminant: {}", + discriminant + ); + // Safety: The `debug_assert` and disciplined use of the `store` methods + // ensure that `state` only contains valid discriminants for `GmmState`. + // `GmmState` is `#[repr(u8)]`, so transmutation from its discriminant is sound. + unsafe { std::mem::transmute(discriminant as u8) } + } + + /// Stores a new state. + #[inline] + pub fn store( + &self, + new_state: GmmState, + ordering: Ordering, + ) { + self.state.store(new_state as usize, ordering); + } + + /// Atomically swaps the state and returns the previous state. + #[inline] + pub fn swap( + &self, + new_state: GmmState, + ordering: Ordering, + ) -> GmmState { + let old_discriminant = self.state.swap(new_state as usize, ordering); + debug_assert!( + old_discriminant <= GmmState::MAX_DISCRIMINANT as usize, + "Invalid GmmState discriminant: {}", + old_discriminant + ); + // Safety: Same justification as `load`. + unsafe { std::mem::transmute(old_discriminant as u8) } + } + + /// A helper function to reduce code duplication in `compare_exchange` + /// methods. + #[inline] + fn transmute_result(result: Result) -> Result { + match result { + Ok(prev) => { + debug_assert!(prev <= GmmState::MAX_DISCRIMINANT as usize); + Ok(unsafe { std::mem::transmute(prev as u8) }) + } + Err(actual) => { + debug_assert!(actual <= GmmState::MAX_DISCRIMINANT as usize); + Err(unsafe { std::mem::transmute(actual as u8) }) + } + } + } + + /// Atomically compares the current state with `current` and, if they match, + /// replaces it with `new`. + #[inline] + pub fn compare_exchange( + &self, + current: GmmState, + new: GmmState, + success: Ordering, + failure: Ordering, + ) -> Result { + let result = self + .state + .compare_exchange(current as usize, new as usize, success, failure); + Self::transmute_result(result) + } + + /// Performs a weak compare-and-exchange operation. + #[inline] + pub fn compare_exchange_weak( + &self, + current: GmmState, + new: GmmState, + success: Ordering, + failure: Ordering, + ) -> Result { + let result = + self.state + .compare_exchange_weak(current as usize, new as usize, success, failure); + Self::transmute_result(result) + } + + /// Atomically modifies the state with a given function. + #[inline] + pub fn fetch_update( + &self, + set_order: Ordering, + fetch_order: Ordering, + mut f: F, + ) -> Result + where + F: FnMut(GmmState) -> Option, + { + let result = self + .state + .fetch_update(set_order, fetch_order, |discriminant| { + debug_assert!(discriminant <= GmmState::MAX_DISCRIMINANT as usize); + // Safety: Same justification as `load`. + let current_state = unsafe { std::mem::transmute(discriminant as u8) }; + f(current_state).map(|new_state| new_state as usize) + }); + Self::transmute_result(result) + } + + // --- Convenience Methods --- + + /// Gets the current state using `Acquire` ordering. + #[inline] + pub fn get(&self) -> GmmState { + self.load(Ordering::Acquire) + } + + /// Sets the current state using `Release` ordering. + #[inline] + pub fn set( + &self, + state: GmmState, + ) { + self.store(state, Ordering::Release); + } + + /// Checks if the current state matches the given state. + #[inline] + pub fn is( + &self, + expected: GmmState, + ) -> bool { + self.get() == expected + } + + /// Checks if the UE can access services in the current state. + #[inline] + pub fn can_access_services(&self) -> bool { + self.get().can_access_services() + } + + /// Checks if the current state is transitional. + #[inline] + pub fn is_transitional(&self) -> bool { + self.get().is_transitional() + } + +} + +impl Default for AtomicGmmState { + fn default() -> Self { + Self::new(GmmState::default()) + } +} + +impl Deref for AtomicGmmState { + type Target = AtomicUsize; + + fn deref(&self) -> &Self::Target { + &self.state + } +} + +impl Clone for AtomicGmmState { + /// Clones the `AtomicGmmState` by creating a new atomic variable + /// initialized with the current state's value. + fn clone(&self) -> Self { + Self::new(self.get()) + } +} + +impl From for AtomicGmmState { + fn from(state: GmmState) -> Self { + Self::new(state) + } +} + +impl fmt::Display for AtomicGmmState { + fn fmt( + &self, + f: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + // Delegate formatting to the underlying GmmState's Debug or Display impl. + write!(f, "{:?}", self.get()) + } +} diff --git a/lightning-nf/omnipath/app/src/context/state/gmm_state.rs b/lightning-nf/omnipath/app/src/context/state/gmm_state.rs new file mode 100644 index 0000000..9ced94f --- /dev/null +++ b/lightning-nf/omnipath/app/src/context/state/gmm_state.rs @@ -0,0 +1,194 @@ +use std::{ + fmt::{Debug, Display}, +}; + +/// 5G Mobility Management (GMM) State Machine - Top Level States Only +#[repr(u8)] +#[derive(Eq, PartialEq, Hash)] +pub enum GmmState { + /// UE is not registered with the network + Deregistered = 0, + + /// Deregistration procedure has been initiated + DeregistrationInitiated = 1, + + /// Registration procedure has been initiated + RegistrationInitiated = 2, + + /// UE is in authentication phase + Unauthenticated = 3, + + /// UE has been successfully authenticated + Authenticated = 4, + + /// Security mode procedures have been completed + SecurityModeDone = 5, + + /// UE is fully registered and can access services + Registered = 6, + + /// Common procedure initiated (service mode, other procedures) + CommonProcedureInitiated = 7, + + /// An irrecoverable error has occurred, requiring cleanup. + Irrecoverable = 8, +} + +impl GmmState { + pub const MAX_DISCRIMINANT: u8 = GmmState::Irrecoverable as u8; + + /// Convert to integer value + pub const fn to_u8(self) -> u8 { + self as u8 + } + + /// Convert from integer value + pub const fn from_u8(value: u8) -> Option { + match value { + 0 => Some(Self::Deregistered), + 1 => Some(Self::DeregistrationInitiated), + 2 => Some(Self::RegistrationInitiated), + 3 => Some(Self::Unauthenticated), + 4 => Some(Self::Authenticated), + 5 => Some(Self::SecurityModeDone), + 6 => Some(Self::Registered), + 7 => Some(Self::CommonProcedureInitiated), + 8 => Some(Self::Irrecoverable), + _ => None, + } + } + + /// Get a human-readable description of the state + pub fn description(&self) -> &'static str { + match self { + GmmState::Deregistered => "UE is deregistered from the 5G network", + GmmState::DeregistrationInitiated => "Deregistration procedure initiated", + GmmState::RegistrationInitiated => "Registration procedure initiated", + GmmState::Unauthenticated => "UE is undergoing authentication procedures", + GmmState::Authenticated => "UE has been successfully authenticated", + GmmState::SecurityModeDone => "Security mode procedures completed", + GmmState::Registered => "UE is registered and can access 5G services", + GmmState::CommonProcedureInitiated => "Common procedure (service mode, etc.) initiated", + GmmState::Irrecoverable => "An irrecoverable error occurred; cleanup required", + } + } + + /// Check if the UE can initiate services in this state + pub fn can_access_services(&self) -> bool { + matches!(self, GmmState::Registered) + } + + /// Check if this is a transitional state (temporary during procedures) + pub fn is_transitional(&self) -> bool { + matches!( + self, + GmmState::DeregistrationInitiated + | GmmState::RegistrationInitiated + | GmmState::Unauthenticated + | GmmState::Authenticated + | GmmState::SecurityModeDone + | GmmState::CommonProcedureInitiated + ) + } + + /// Get the next expected state in the registration flow + pub fn next_state(&self) -> Option { + match self { + GmmState::Deregistered => Some(GmmState::RegistrationInitiated), + GmmState::DeregistrationInitiated => Some(GmmState::Deregistered), + GmmState::RegistrationInitiated => Some(GmmState::Unauthenticated), + GmmState::Unauthenticated => Some(GmmState::Authenticated), + GmmState::Authenticated => Some(GmmState::SecurityModeDone), + GmmState::SecurityModeDone => Some(GmmState::Registered), + GmmState::Registered => None, // Final state or can go to CommonProcedureInitiated + GmmState::CommonProcedureInitiated => Some(GmmState::Registered), // Usually returns to Registered */ + GmmState::Irrecoverable => None, + } + } + + /// Check if transition to target state is valid + pub fn can_transition_to( + &self, + target: Self, + ) -> bool { + match (self, target) { + // Any state can transition to Irrecoverable + (_, GmmState::Irrecoverable) => true, + // Cannot transition from Irrecoverable + (GmmState::Irrecoverable, _) => false, + + // Forward registration transitions + (GmmState::Deregistered, GmmState::RegistrationInitiated) => true, + (GmmState::RegistrationInitiated, GmmState::Unauthenticated) => true, + (GmmState::Unauthenticated, GmmState::Authenticated) => true, + (GmmState::Authenticated, GmmState::SecurityModeDone) => true, + (GmmState::SecurityModeDone, GmmState::Registered) => true, + + // Deregistration flow + (GmmState::Registered, GmmState::DeregistrationInitiated) => true, + (GmmState::DeregistrationInitiated, GmmState::Deregistered) => true, + + // Common procedure transitions + (GmmState::Registered, GmmState::CommonProcedureInitiated) => true, + (GmmState::CommonProcedureInitiated, GmmState::Registered) => true, + + // Failure transitions - can always go back to deregistered + (_, GmmState::Deregistered) => true, + + // Stay in same state + (state, target) if state == &target => true, + + // Invalid transitions + _ => false, + } + } +} + +impl Default for GmmState { + fn default() -> Self { + GmmState::Deregistered + } +} + +impl From for u8 { + fn from(state: GmmState) -> Self { + state.to_u8() + } +} + +impl TryFrom for GmmState { + type Error = (); + + fn try_from(value: u8) -> Result { + Self::from_u8(value).ok_or(()) + } +} + +impl Display for GmmState { + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + let name = match self { + GmmState::Deregistered => "GMM-DEREGISTERED", + GmmState::DeregistrationInitiated => "GMM-DEREGISTRATION-INITIATED", + GmmState::RegistrationInitiated => "GMM-REGISTRATION-INITIATED", + GmmState::Unauthenticated => "GMM-UNAUTHENTICATED", + GmmState::Authenticated => "GMM-AUTHENTICATED", + GmmState::SecurityModeDone => "GMM-SECURITY-MODE-DONE", + GmmState::Registered => "GMM-REGISTERED", + GmmState::CommonProcedureInitiated => "GMM-COMMON-PROCEDURE-INITIATED", + GmmState::Irrecoverable => "GMM-IRRECOVERABLE", + }; + write!(f, "{}", name) + } +} + +impl Debug for GmmState { + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + Display::fmt(self, f) + } +} diff --git a/lightning-nf/omnipath/app/src/context/state/mod.rs b/lightning-nf/omnipath/app/src/context/state/mod.rs new file mode 100644 index 0000000..7011dd1 --- /dev/null +++ b/lightning-nf/omnipath/app/src/context/state/mod.rs @@ -0,0 +1,4 @@ +mod gmm_state; +mod atomic_gmm_state; +pub use gmm_state::GmmState; +pub use atomic_gmm_state::AtomicGmmState; \ No newline at end of file diff --git a/lightning-nf/omnipath/app/src/context/ue_context.rs b/lightning-nf/omnipath/app/src/context/ue_context.rs index 6b256f1..71fbf33 100644 --- a/lightning-nf/omnipath/app/src/context/ue_context.rs +++ b/lightning-nf/omnipath/app/src/context/ue_context.rs @@ -1,34 +1,29 @@ -use std::{num::NonZeroU32, sync::Arc}; +use std::sync::{Arc, atomic::AtomicUsize}; +use atomic_handle::AtomicOperation; use derive_new::new; -use ngap_models::{AmfUeNgapId, RanUeNgapId, RrcEstablishmentCause}; -use non_empty_string::NonEmptyString; -use statig::awaitable::StateMachine; +use ngap_models::{AmfUeNgapId, RanUeNgapId}; -use super::GnbContext; -use crate::{nas::nas_context::NasContext, ngap::manager::Identifiable, utils::models::FiveGSTmsi}; +use crate::{ + context::{AtomicGmmState, GnbContext, NasContext}, + ngap::manager::Identifiable, +}; #[derive(new)] pub struct UeContext { pub ran_ue_ngap_id: RanUeNgapId, pub amf_ue_ngap_id: AmfUeNgapId, - pub rrc_establishment_cause: RrcEstablishmentCause, pub gnb_context: Arc, - pub five_g_s_tmsi: Option, - - pub gmm: Arc>, - #[new(default)] - pub tmsi: Option, - #[new(default)] - pub guti: Option, - #[new(default)] - pub suci: Option, - #[new(default)] - pub pei: Option, - #[new(default)] - pub mac_addr: Option, - #[new(default)] - pub plmn_id: Option, + pub state: AtomicGmmState, + pub nas_context: NasContext, +} + +pub struct GmmStateField; + +impl AtomicOperation for UeContext { + fn get_atomic(&self) -> &AtomicUsize { + &self.state + } } impl std::fmt::Debug for UeContext { @@ -39,15 +34,7 @@ impl std::fmt::Debug for UeContext { f.debug_struct("UeContext") .field("ran_ue_ngap_id", &self.ran_ue_ngap_id) .field("amf_ue_ngap_id", &self.amf_ue_ngap_id) - .field("rrc_establishment_cause", &self.rrc_establishment_cause) - .field("gnb_context", &self.gnb_context) - .field("five_g_s_tmsi", &self.five_g_s_tmsi) - .field("tmsi", &self.tmsi) - .field("guti", &self.guti) - .field("suci", &self.suci) - .field("pei", &self.pei) - .field("mac_addr", &self.mac_addr) - .field("plmn_id", &self.plmn_id) + .field("gmm", &self.state.get()) .finish() } } diff --git a/lightning-nf/omnipath/app/src/lib.rs b/lightning-nf/omnipath/app/src/lib.rs index 371b898..3998061 100644 --- a/lightning-nf/omnipath/app/src/lib.rs +++ b/lightning-nf/omnipath/app/src/lib.rs @@ -1,8 +1,11 @@ #![feature(error_generic_member_access)] +#![feature(adt_const_params)] +#![feature(unsized_const_params)] pub mod builder; pub(crate) mod config; pub(crate) mod context; +pub mod nas_old; pub mod nas; pub mod ngap; pub mod utils; diff --git a/lightning-nf/omnipath/app/src/nas/interface.rs b/lightning-nf/omnipath/app/src/nas/interface.rs new file mode 100644 index 0000000..fe773b0 --- /dev/null +++ b/lightning-nf/omnipath/app/src/nas/interface.rs @@ -0,0 +1,227 @@ +use std::{fmt::Debug, future::Future}; + +use bytes::Bytes; +use nas_models::{TlvDecode, TlvError}; +use ngap_models::NgapPdu; +use thiserror::Error; + +use crate::context::GmmState; + +// TODO: Convert this State to associated type with ConstParamTy_. +// Rust Lang issue: https://github.com/rust-lang/rust/issues/98210 +type State = GmmState; + +/// A new, more detailed error for invalid state transitions. +/// It is generic over the request type to provide better context. +#[derive(Error, Debug)] +#[error("Invalid state transition from {state:?} with request: {request:?}")] +pub struct InvalidStateTransition { + pub state: State, + pub request: R, +} + +impl InvalidStateTransition { + pub fn new( + state: State, + request: R, + ) -> Self { + Self { state, request } + } +} + +/// Defines the contract for handling a NAS message. +/// This version manually specifies `-> impl Future` to avoid the `async-trait` +/// macro and ensures all returned futures are Send and Sync. Their lifetime is +/// correctly tied to the `&mut self` borrow. +pub trait NasMessageHandle: Send + Sync { + /// The request message type, which must be decodable, debuggable, and safe + /// to send and share across threads. + type Request: TlvDecode + Debug + Send + Sync + 'static; + + /// The handler's error type. + type Error: From> + From + Debug; + + /// Decodes raw bytes into the request message. + fn decode(bytes: Bytes) -> Result { + let mut bytes = bytes; + TlvDecode::decode(bytes.len(), &mut bytes) + } + + /// An optional hook to run before the main state transition logic. + fn pre_comp_state_change( + &mut self, + state: State, + _req: &mut Self::Request, + ) -> Result { + Ok(state) + } + + /// The main state transition dispatcher. + fn state_transition( + &mut self, + from_state: State, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { + let mut req = req; + let from_state = self.pre_comp_state_change(from_state, &mut req)?; + match from_state { + GmmState::Deregistered => self.state_transition_deregistered(req).await, + GmmState::DeregistrationInitiated => { + self.state_transition_deregistration_initiated(req).await + } + GmmState::RegistrationInitiated => { + self.state_transition_registration_initiated(req).await + } + GmmState::Unauthenticated => self.state_transition_unauthenticated(req).await, + GmmState::Authenticated => self.state_transition_authenticated(req).await, + GmmState::SecurityModeDone => self.state_transition_security_mode_done(req).await, + GmmState::Registered => self.state_transition_registered(req).await, + GmmState::CommonProcedureInitiated => { + self.state_transition_common_procedure_initiated(req).await + } + GmmState::Irrecoverable => self.state_transition_irrecoverable(req).await, + } + } + } + + // --- Default Implementations for each state --- + + fn state_transition_deregistered( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::Deregistered, req).into()) } + } + + fn state_transition_deregistration_initiated( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::DeregistrationInitiated, req).into()) } + } + + fn state_transition_registration_initiated( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::RegistrationInitiated, req).into()) } + } + + fn state_transition_unauthenticated( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::Unauthenticated, req).into()) } + } + + fn state_transition_authenticated( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::Authenticated, req).into()) } + } + + fn state_transition_security_mode_done( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::SecurityModeDone, req).into()) } + } + fn state_transition_registered( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::Registered, req).into()) } + } + + fn state_transition_common_procedure_initiated( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::CommonProcedureInitiated, req).into()) } + } + + fn state_transition_irrecoverable( + &mut self, + req: Self::Request, + ) -> impl Future), Self::Error>> + Send + Sync { + async move { Err(InvalidStateTransition::new(GmmState::CommonProcedureInitiated, req).into()) } + } +} + +// --- The Tokio Test Case --- +#[cfg(test)] +mod tests { + use std::{future::Future, sync::Arc}; + + use ngap_models::NgapPdu; + + use super::{GmmState, InvalidStateTransition, NasMessageHandle}; + use crate::ngap::engine::EmptyResponse; + + // --- Minimal Handler Implementation --- + + #[derive(Debug, Clone)] + struct DummyRequest; + impl nas_models::TlvDecode for DummyRequest { + fn decode( + _len: usize, + _bytes: &mut bytes::Bytes, + ) -> Result { + Ok(DummyRequest) + } + } + + #[derive(Debug, thiserror::Error)] + enum DummyError { + #[error(transparent)] + InvalidState(#[from] InvalidStateTransition), + #[error(transparent)] + Tlv(#[from] nas_models::TlvError), + } + + #[derive(Debug)] + struct DummyHandler; + + impl NasMessageHandle for DummyHandler { + type Request = DummyRequest; + type Error = DummyError; + + // We only need to override one method for the test. + fn state_transition_registered( + &mut self, + _req: Self::Request, + ) -> impl std::future::Future), Self::Error>> + + Send + + Sync { + async move { Ok((GmmState::CommonProcedureInitiated, None)) } + } + } + + /// Helper to enforce `tokio::spawn` bounds at compile time. + fn require_spawn_bounds(_future: F) + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + // This function's body is intentionally empty. + // Successful compilation is the test. + } + + #[tokio::test] + async fn handler_future_satisfies_spawn_bounds() { + // Setup the handler and request. + let mut handler = DummyHandler; + let request = DummyRequest; + let from_state = GmmState::Registered; + + // Create the future that will be tested. + // It is 'static because it owns all its data (the Arc clone and the request). + let future_to_test = async move { handler.state_transition(from_state, request).await }; + + // This line is the entire test. Successful compilation proves the bounds are + // met. + require_spawn_bounds(future_to_test); + } +} diff --git a/lightning-nf/omnipath/app/src/nas/mod.rs b/lightning-nf/omnipath/app/src/nas/mod.rs index 6a5e243..a1b0d17 100644 --- a/lightning-nf/omnipath/app/src/nas/mod.rs +++ b/lightning-nf/omnipath/app/src/nas/mod.rs @@ -1,33 +1,29 @@ -use std::{future::Future, sync::Arc}; - +mod interface; +mod registration_req; +pub use interface::{InvalidStateTransition, NasMessageHandle}; use crate::context::UeContext; -use nas_context::NasContext; -use error::NasHandlerError; - - -pub mod nas_context; -mod handlers; -mod error; -mod gmm; -mod builders; -mod ue_actions; -pub trait NasHandler { - fn handle( - &self, - nas_context: &mut NasContext, - ue_context: &mut UeContext, - ) -> impl Future> + Send; +impl UeContext { + pub async fn handle_nas( + &mut self, + nas_pdu: Vec, + ) { + + // * Need some thought here about how to handle this + + // // Todo: fix this to have a single Bytes for Ngap and Nas + // let mut bytes = Bytes::from(nas_pdu); + + // let mut gmm = self.gmm.clone(); + // // Safety: unwrap over Arc::get_mut will succeed because + // // no one will get a mutable reference to the NasContext + // // and that will only be mutated through the StateMachine + // // Todo:: make nas_context internal field private by mod __private + // if let Ok(gmm_message) = GmmMessage::try_from(&mut bytes) { + // Arc::get_mut(&mut gmm).unwrap().handle_with_context(&gmm_message, + // self); } else { + // trace!("Invalid NAS PDU: {:?}", bytes); + // } + } } - - -pub trait NasBuilder: Sized { - fn build( - nas_context: &NasContext, - ue_context: &UeContext, - ) -> Result; -} - - - diff --git a/lightning-nf/omnipath/app/src/nas/registration_req/error.rs b/lightning-nf/omnipath/app/src/nas/registration_req/error.rs new file mode 100644 index 0000000..4849378 --- /dev/null +++ b/lightning-nf/omnipath/app/src/nas/registration_req/error.rs @@ -0,0 +1,14 @@ +use thiserror::Error; +use nas_models::TlvError; +use nas_models::message::NasRegistrationRequest; +use super::super::InvalidStateTransition; + + +#[derive(Error, Debug)] +pub enum RegistrationReqError { + #[error(transparent)] + InvalidState(#[from] InvalidStateTransition), + #[error(transparent)] + Tlv(#[from] TlvError), + +} diff --git a/lightning-nf/omnipath/app/src/nas/registration_req/mod.rs b/lightning-nf/omnipath/app/src/nas/registration_req/mod.rs new file mode 100644 index 0000000..9422021 --- /dev/null +++ b/lightning-nf/omnipath/app/src/nas/registration_req/mod.rs @@ -0,0 +1,64 @@ +mod error; + +use std::num::NonZeroU32; + +pub use error::RegistrationReqError; +use nas_models::{message::NasRegistrationRequest, types::MobileIdentity}; +use ngap_models::NgapPdu; +use non_empty_string::NonEmptyString; + +use super::{InvalidStateTransition, NasMessageHandle}; +use crate::context::{GmmState, UeContext}; + +impl NasMessageHandle for UeContext { + type Request = NasRegistrationRequest; + type Error = RegistrationReqError; + + fn pre_comp_state_change( + &mut self, + state: GmmState, + _req: &mut Self::Request, + ) -> Result { + + Ok(GmmState::RegistrationInitiated) + } + + async fn state_transition_deregistered( + &mut self, + req: Self::Request, + ) -> Result<(GmmState, Option), Self::Error> { + + match req + .nas_5gs_mobile_identity + .get_mobile_identity() + { + MobileIdentity::NoIdentity(_no_identity) => { + // Todo push some logging here + } + MobileIdentity::Suci(suci) => { + self.nas_context.suci = NonEmptyString::new(suci.to_string()).ok(); + } + MobileIdentity::FiveGGuti(five_gguti) => { + self.nas_context.guti = NonEmptyString::new(five_gguti.get_guti_string()).ok(); + } + MobileIdentity::Imei(imei_or_imei_sv) => { + self.nas_context.pei = NonEmptyString::new(imei_or_imei_sv.to_string()).ok(); + } + MobileIdentity::FiveGSTmsi(five_gtmsi) => { + self.nas_context.tmsi = NonZeroU32::new(five_gtmsi.get_5g_tmsi()); + } + MobileIdentity::Imeisv(imei_or_imei_sv) => { + self.nas_context.pei = NonEmptyString::new(imei_or_imei_sv.to_string()).ok(); + } + MobileIdentity::MacAddress(mac_address) => { + self.nas_context.mac_addr = NonEmptyString::new(mac_address.to_string()).ok(); + } + MobileIdentity::Eui64(eui64) => todo!(), + } + + + + + Ok((GmmState::Authenticated, None)) + } +} diff --git a/lightning-nf/omnipath/app/src/nas/ue_actions.rs b/lightning-nf/omnipath/app/src/nas/ue_actions.rs deleted file mode 100644 index 866d99c..0000000 --- a/lightning-nf/omnipath/app/src/nas/ue_actions.rs +++ /dev/null @@ -1,25 +0,0 @@ -use crate::context::UeContext; - -impl UeContext { - pub async fn handle_nas( - &mut self, - nas_pdu: Vec, - ) { - - // * Need some thought here about how to handle this - - // // Todo: fix this to have a single Bytes for Ngap and Nas - // let mut bytes = Bytes::from(nas_pdu); - - // let mut gmm = self.gmm.clone(); - // // Safety: unwrap over Arc::get_mut will succeed because - // // no one will get a mutable reference to the NasContext - // // and that will only be mutated through the StateMachine - // // Todo:: make nas_context internal field private by mod __private - // if let Ok(gmm_message) = GmmMessage::try_from(&mut bytes) { - // Arc::get_mut(&mut gmm).unwrap().handle_with_context(&gmm_message, - // self); } else { - // trace!("Invalid NAS PDU: {:?}", bytes); - // } - } -} diff --git a/lightning-nf/omnipath/app/src/nas/builders/authentication_request.rs b/lightning-nf/omnipath/app/src/nas_old/builders/authentication_request.rs similarity index 72% rename from lightning-nf/omnipath/app/src/nas/builders/authentication_request.rs rename to lightning-nf/omnipath/app/src/nas_old/builders/authentication_request.rs index 7b109be..603a6ff 100644 --- a/lightning-nf/omnipath/app/src/nas/builders/authentication_request.rs +++ b/lightning-nf/omnipath/app/src/nas_old/builders/authentication_request.rs @@ -1,6 +1,6 @@ use nas_models::message::*; -use crate::nas::{NasContext, NasBuilder, UeContext, NasHandlerError}; +use crate::nas_old::{NasContext, NasBuilder, UeContext, NasHandlerError}; diff --git a/lightning-nf/omnipath/app/src/nas/builders/gmm_status.rs b/lightning-nf/omnipath/app/src/nas_old/builders/gmm_status.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/builders/gmm_status.rs rename to lightning-nf/omnipath/app/src/nas_old/builders/gmm_status.rs diff --git a/lightning-nf/omnipath/app/src/nas/builders/mod.rs b/lightning-nf/omnipath/app/src/nas_old/builders/mod.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/builders/mod.rs rename to lightning-nf/omnipath/app/src/nas_old/builders/mod.rs diff --git a/lightning-nf/omnipath/app/src/nas/builders/registration_response.rs b/lightning-nf/omnipath/app/src/nas_old/builders/registration_response.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/builders/registration_response.rs rename to lightning-nf/omnipath/app/src/nas_old/builders/registration_response.rs diff --git a/lightning-nf/omnipath/app/src/nas/error.rs b/lightning-nf/omnipath/app/src/nas_old/error.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/error.rs rename to lightning-nf/omnipath/app/src/nas_old/error.rs diff --git a/lightning-nf/omnipath/app/src/nas/gmm.rs b/lightning-nf/omnipath/app/src/nas_old/gmm.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/gmm.rs rename to lightning-nf/omnipath/app/src/nas_old/gmm.rs diff --git a/lightning-nf/omnipath/app/src/nas/gmm_state_machine.png b/lightning-nf/omnipath/app/src/nas_old/gmm_state_machine.png similarity index 100% rename from lightning-nf/omnipath/app/src/nas/gmm_state_machine.png rename to lightning-nf/omnipath/app/src/nas_old/gmm_state_machine.png diff --git a/lightning-nf/omnipath/app/src/nas/handlers/authentication_response.rs b/lightning-nf/omnipath/app/src/nas_old/handlers/authentication_response.rs similarity index 79% rename from lightning-nf/omnipath/app/src/nas/handlers/authentication_response.rs rename to lightning-nf/omnipath/app/src/nas_old/handlers/authentication_response.rs index d205fcd..f02e960 100644 --- a/lightning-nf/omnipath/app/src/nas/handlers/authentication_response.rs +++ b/lightning-nf/omnipath/app/src/nas_old/handlers/authentication_response.rs @@ -1,6 +1,6 @@ use nas_models::message as nas_message; -use crate::nas::{NasContext, NasHandler, UeContext, NasHandlerError}; +use crate::nas_old::{NasContext, NasHandler, UeContext, NasHandlerError}; impl NasHandler for nas_message::NasAuthenticationResponse { diff --git a/lightning-nf/omnipath/app/src/nas/handlers/gmm_status.rs b/lightning-nf/omnipath/app/src/nas_old/handlers/gmm_status.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/handlers/gmm_status.rs rename to lightning-nf/omnipath/app/src/nas_old/handlers/gmm_status.rs diff --git a/lightning-nf/omnipath/app/src/nas/handlers/mod.rs b/lightning-nf/omnipath/app/src/nas_old/handlers/mod.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/handlers/mod.rs rename to lightning-nf/omnipath/app/src/nas_old/handlers/mod.rs diff --git a/lightning-nf/omnipath/app/src/nas/handlers/registration_request.rs b/lightning-nf/omnipath/app/src/nas_old/handlers/registration_request.rs similarity index 95% rename from lightning-nf/omnipath/app/src/nas/handlers/registration_request.rs rename to lightning-nf/omnipath/app/src/nas_old/handlers/registration_request.rs index 2a61c9d..fbed7ae 100644 --- a/lightning-nf/omnipath/app/src/nas/handlers/registration_request.rs +++ b/lightning-nf/omnipath/app/src/nas_old/handlers/registration_request.rs @@ -4,10 +4,10 @@ use nas_models::message as nas_message; use nas_models::types as nas_types; use non_empty_string::NonEmptyString; -use crate::nas::error::NasHandlerError; -use crate::nas::NasContext; -use crate::nas::UeContext; -use crate::nas::NasHandler; +use crate::nas_old::error::NasHandlerError; +use crate::nas_old::NasContext; +use crate::nas_old::UeContext; +use crate::nas_old::NasHandler; fn initial_registration_handler(nas_registration_request: &nas_message::NasRegistrationRequest, nas_context:&mut NasContext, ue_context: &mut UeContext) -> Result<(), NasHandlerError>{ diff --git a/lightning-nf/omnipath/app/src/nas_old/mod.rs b/lightning-nf/omnipath/app/src/nas_old/mod.rs new file mode 100644 index 0000000..6a5e243 --- /dev/null +++ b/lightning-nf/omnipath/app/src/nas_old/mod.rs @@ -0,0 +1,33 @@ +use std::{future::Future, sync::Arc}; + +use crate::context::UeContext; +use nas_context::NasContext; +use error::NasHandlerError; + + +pub mod nas_context; +mod handlers; +mod error; +mod gmm; +mod builders; +mod ue_actions; + + +pub trait NasHandler { + fn handle( + &self, + nas_context: &mut NasContext, + ue_context: &mut UeContext, + ) -> impl Future> + Send; +} + + +pub trait NasBuilder: Sized { + fn build( + nas_context: &NasContext, + ue_context: &UeContext, + ) -> Result; +} + + + diff --git a/lightning-nf/omnipath/app/src/nas/nas_context.rs b/lightning-nf/omnipath/app/src/nas_old/nas_context.rs similarity index 100% rename from lightning-nf/omnipath/app/src/nas/nas_context.rs rename to lightning-nf/omnipath/app/src/nas_old/nas_context.rs diff --git a/lightning-nf/omnipath/app/src/nas_old/ue_actions.rs b/lightning-nf/omnipath/app/src/nas_old/ue_actions.rs new file mode 100644 index 0000000..e019269 --- /dev/null +++ b/lightning-nf/omnipath/app/src/nas_old/ue_actions.rs @@ -0,0 +1,25 @@ +use crate::context::UeContext; + +// impl UeContext { +// pub async fn handle_nas( +// &mut self, +// nas_pdu: Vec, +// ) { + +// // * Need some thought here about how to handle this + +// // // Todo: fix this to have a single Bytes for Ngap and Nas +// // let mut bytes = Bytes::from(nas_pdu); + +// // let mut gmm = self.gmm.clone(); +// // // Safety: unwrap over Arc::get_mut will succeed because +// // // no one will get a mutable reference to the NasContext +// // // and that will only be mutated through the StateMachine +// // // Todo:: make nas_context internal field private by mod __private +// // if let Ok(gmm_message) = GmmMessage::try_from(&mut bytes) { +// // Arc::get_mut(&mut gmm).unwrap().handle_with_context(&gmm_message, +// // self); } else { +// // trace!("Invalid NAS PDU: {:?}", bytes); +// // } +// } +// } diff --git a/lightning-nf/omnipath/app/src/ngap/core/initial_ue_message.rs b/lightning-nf/omnipath/app/src/ngap/core/initial_ue_message.rs index 8974654..6a90b2a 100644 --- a/lightning-nf/omnipath/app/src/ngap/core/initial_ue_message.rs +++ b/lightning-nf/omnipath/app/src/ngap/core/initial_ue_message.rs @@ -1,20 +1,32 @@ use std::sync::Arc; use ngap_models::{AmfUeNgapId, InitialUeMessage, RanUeNgapId}; -use statig::awaitable::IntoStateMachineExt; use thiserror::Error; -use tokio::sync::OwnedRwLockWriteGuard; +use tokio::sync::{OwnedRwLockWriteGuard, RwLock}; use crate::{ - context::{GnbContext, NgapContext, UeContext}, - nas::nas_context::NasContext, + context::{AtomicGmmState, GmmState, GnbContext, NasContext, NgapContext, UeContext}, ngap::{ engine::{EmptyResponse, NgapRequestHandler, NgapResponseError}, manager::{ContextError, PinnedSendSyncFuture}, }, - utils::models::FiveGSTmsi, + utils::{SeqLock, models::FiveGSTmsi}, }; +async fn create_new_context( + ue_context: UeContext, + gnb: &GnbContext, +) -> Result<(), UeContext> { + match gnb.ue_context_manager.add_context(ue_context).await { + Err(ContextError::ContextAlreadyExists(_, inner)) => { + return Err(inner); + } + Err(_) => unreachable!(), + Ok(_) => (), + }; + Ok(()) +} + impl NgapRequestHandler> for NgapContext { type Success = EmptyResponse; type Failure = EmptyResponse; @@ -46,29 +58,39 @@ impl NgapRequestHandler> for NgapContext { .. } = request; - let ue_context = UeContext::new( - ran_ue_ngap_id, - AmfUeNgapId(state.amf_ue_id_generator.increment()), - rrc_establishment_cause, - state.clone(), - five_g_s_tmsi.map(FiveGSTmsi::from), - Arc::new(NasContext::new().state_machine()), - ); + // Add appropriate context to context manager + match five_g_s_tmsi { + Some(tmsi) => { + // Already registered user. Service Request + // TODO: Fetch the ue context from the idle ue and move it into + // context manager. + } + None => { + let nas_context = + NasContext::new(rrc_establishment_cause, five_g_s_tmsi.map(FiveGSTmsi::from)); + let ue_context = UeContext::new( + ran_ue_ngap_id, + AmfUeNgapId(state.amf_ue_id_generator.increment()), + state.clone(), + AtomicGmmState::new(GmmState::Deregistered), + nas_context + ); - match state.ue_context_manager.add_context(ue_context).await { - Err(ContextError::ContextAlreadyExists(_, inner)) => { - return Err(NgapResponseError::new_empty_failure_error( - UeContextAlreadyExistsError::UeContext(inner), - )); + match create_new_context(ue_context, &state).await { + // Another registration has started, so schedule this event onto that. + // TODO: Add handling of different RRC Establishment Causes + Err(_) => (), + Ok(()) => (), + }; } - Err(_) => unreachable!(), - Ok(_) => (), }; - let future_closure = move |mut ue_context: OwnedRwLockWriteGuard| { + let future_closure = move |mut ue_context: Arc>| { let nas_pdu = nas_pdu.0; Box::pin(async move { - ue_context.handle_nas(nas_pdu).await; + // SAFETY: Context Manager Ensures that futures are executed one by one. + let ue_context_mut = unsafe { ue_context.get_mut() }; + ue_context_mut.handle_nas(nas_pdu).await; }) as PinnedSendSyncFuture<()> }; diff --git a/lightning-nf/omnipath/app/src/ngap/engine/controller.rs b/lightning-nf/omnipath/app/src/ngap/engine/controller.rs index b6ab244..0253c83 100644 --- a/lightning-nf/omnipath/app/src/ngap/engine/controller.rs +++ b/lightning-nf/omnipath/app/src/ngap/engine/controller.rs @@ -9,7 +9,7 @@ use valuable::Valuable; use super::{ decode_ngap_pdu, - interfaces::{NgapRequestHandler, NgapResponseError}, + interface::{NgapRequestHandler, NgapResponseError}, utils::codec_to_bytes, }; use crate::{ diff --git a/lightning-nf/omnipath/app/src/ngap/engine/interfaces.rs b/lightning-nf/omnipath/app/src/ngap/engine/interface.rs similarity index 100% rename from lightning-nf/omnipath/app/src/ngap/engine/interfaces.rs rename to lightning-nf/omnipath/app/src/ngap/engine/interface.rs diff --git a/lightning-nf/omnipath/app/src/ngap/engine/mod.rs b/lightning-nf/omnipath/app/src/ngap/engine/mod.rs index b04314c..fd64b23 100644 --- a/lightning-nf/omnipath/app/src/ngap/engine/mod.rs +++ b/lightning-nf/omnipath/app/src/ngap/engine/mod.rs @@ -1,7 +1,7 @@ pub mod controller; -mod interfaces; +mod interface; mod ue_actions; mod utils; -pub use interfaces::*; +pub use interface::*; pub use utils::decode_ngap_pdu; diff --git a/lightning-nf/omnipath/app/src/ngap/manager/context_manager.rs b/lightning-nf/omnipath/app/src/ngap/manager/context_manager.rs index a6697ea..ecd2e46 100644 --- a/lightning-nf/omnipath/app/src/ngap/manager/context_manager.rs +++ b/lightning-nf/omnipath/app/src/ngap/manager/context_manager.rs @@ -6,6 +6,7 @@ use thiserror::Error; use tokio::sync::OwnedRwLockWriteGuard; use super::context_queue::ContextQueue; +use crate::utils::SeqLock; /// A trait for types that can be identified by a unique ID. /// @@ -192,10 +193,7 @@ impl ContextManager { closure: F, ) -> Result> where - F: FnOnce(OwnedRwLockWriteGuard) -> PinnedSendSyncFuture - + Send - + Sync - + 'static, + F: FnOnce(Arc>) -> PinnedSendSyncFuture + Send + Sync + 'static, O: Send + Sync + 'static, { let element = self.queues.read_async(&id, |_, queue| queue.clone()).await; diff --git a/lightning-nf/omnipath/app/src/ngap/manager/context_queue.rs b/lightning-nf/omnipath/app/src/ngap/manager/context_queue.rs index f4b2d61..53722f9 100644 --- a/lightning-nf/omnipath/app/src/ngap/manager/context_queue.rs +++ b/lightning-nf/omnipath/app/src/ngap/manager/context_queue.rs @@ -8,10 +8,11 @@ use std::{ }, }; -use tokio::sync::{Mutex, OwnedRwLockWriteGuard, RwLock, oneshot}; +use tokio::sync::{Mutex, oneshot}; use tracing::Instrument; use super::context_manager::Identifiable; +use crate::utils::SeqLock; /// `ContextQueue` is designed to manage and execute asynchronous operations on /// a shared context (`T`) in a sequential manner. It combines an internal, @@ -33,7 +34,7 @@ use super::context_manager::Identifiable; /// empty. pub(crate) struct ContextQueue { - inner: Arc>, + inner: Arc>, queue: Arc + Send + 'static>>>>>, processor_active: AtomicBool, } @@ -41,7 +42,7 @@ pub(crate) struct ContextQueue { impl ContextQueue { pub fn new(context: T) -> Self { ContextQueue { - inner: Arc::new(RwLock::new(context)), + inner: Arc::new(SeqLock::new(context)), queue: Arc::new(Mutex::new(VecDeque::new())), processor_active: AtomicBool::new(false), } @@ -52,7 +53,7 @@ impl ContextQueue { /// operations. pub unsafe fn into_inner(self) -> Option { let t = Arc::into_inner(self.inner); - t.map(|t| t.into_inner()) + t.map(|t| unsafe { t.into_inner() }) } } @@ -132,9 +133,7 @@ where tx: oneshot::Sender, ) -> Pin + Send + Sync + 'static>> where - F: FnOnce( - OwnedRwLockWriteGuard, - ) -> Pin + Send + Sync + 'static>> + F: FnOnce(Arc>) -> Pin + Send + Sync + 'static>> + Send + Sync + 'static, @@ -142,8 +141,10 @@ where { let context = self.inner.clone(); Box::pin(async move { - let context = context.write_owned().await; - let id = *context.id(); + let context = context; + // SAFETY: Context Manager Ensures that the future is executed one by one. + let id = *unsafe { context.get() }.id(); + let future = closure(context).instrument(tracing::info_span!("ContextQueue", id = ?id)); let output = future.await; // The receiver will be waiting for this result, so we ignore the @@ -188,9 +189,7 @@ where closure: F, ) -> O where - F: FnOnce( - OwnedRwLockWriteGuard, - ) -> Pin + Send + Sync + 'static>> + F: FnOnce(Arc>) -> Pin + Send + Sync + 'static>> + Send + Sync + 'static, diff --git a/lightning-nf/omnipath/app/src/utils/mod.rs b/lightning-nf/omnipath/app/src/utils/mod.rs index bc23349..47505a2 100644 --- a/lightning-nf/omnipath/app/src/utils/mod.rs +++ b/lightning-nf/omnipath/app/src/utils/mod.rs @@ -1,4 +1,6 @@ mod convert; +mod seq_lock; pub use convert::{convert, try_convert}; -pub mod models; \ No newline at end of file +pub mod models; +pub use seq_lock::SeqLock; \ No newline at end of file diff --git a/lightning-nf/omnipath/app/src/utils/seq_lock.rs b/lightning-nf/omnipath/app/src/utils/seq_lock.rs new file mode 100644 index 0000000..74f76a3 --- /dev/null +++ b/lightning-nf/omnipath/app/src/utils/seq_lock.rs @@ -0,0 +1,32 @@ +use std::{cell::UnsafeCell, ops::Deref, sync::Arc}; + +pub struct SeqLock { + data: UnsafeCell, +} + +unsafe impl Send for SeqLock {} +unsafe impl Sync for SeqLock {} + +impl SeqLock { + pub fn new(data: T) -> Self { + Self { + data: UnsafeCell::new(data), + } + } + + pub unsafe fn get_mut(&self) -> &mut T { + // SAFETY: Context manager ensures only one future accesses this at a time + unsafe { &mut *self.data.get() } + } + + pub unsafe fn get(&self) -> &T { + // SAFETY: Same as above + unsafe { &*self.data.get() } + } + + pub unsafe fn into_inner(self) -> T { + self.data.into_inner() + } + +} + diff --git a/utils/atomic-handle/Cargo.toml b/utils/atomic-handle/Cargo.toml new file mode 100644 index 0000000..79d8f53 --- /dev/null +++ b/utils/atomic-handle/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "atomic-handle" +version.workspace = true +edition.workspace = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +description.workspace = true +publish.workspace = true +readme.workspace = true + +[dependencies] diff --git a/utils/atomic-handle/src/lib.rs b/utils/atomic-handle/src/lib.rs new file mode 100644 index 0000000..21ff8dd --- /dev/null +++ b/utils/atomic-handle/src/lib.rs @@ -0,0 +1,273 @@ +use std::{ + marker::PhantomData, ops::Deref, sync::{ + atomic::{AtomicUsize, Ordering}, Arc + } +}; + + +/// Trait that types must implement to be used with AtomicHandle. +/// The generic parameter `Field` allows the same type to implement +/// this trait multiple times for different atomic fields. +pub trait AtomicOperation { + /// Returns a reference to the atomic field identified by the Field type + /// parameter + fn get_atomic(&self) -> &AtomicUsize; +} + +/// A handle that provides safe shared access to an AtomicUsize field within an +/// Arc-managed struct. +/// +/// The `Field` type parameter is used to distinguish between different atomic +/// fields in the same struct using zero-sized marker types. +/// +/// # Type Parameters +/// * `S` - The struct type that contains the atomic field +/// * `Field` - A marker type to identify which atomic field this handle +/// accesses +/// +/// # Safety +/// This struct maintains a raw pointer to the atomic field but keeps the +/// original Arc alive, ensuring the memory remains valid for the lifetime of +/// the handle. +pub struct AtomicHandle +where + S: AtomicOperation + Send + Sync , +{ + /// Keeps the original Arc alive to prevent memory deallocation + _owner: Arc, + /// Raw pointer to the specific atomic field within the struct + atomic_ptr: *const AtomicUsize, + /// Zero-sized marker to track which field this handle represents + _phantom: PhantomData, +} + +// Safety: AtomicHandle is Send because: +// - Arc is Send when S: Send + Sync +// - The raw pointer points to memory kept alive by the Arc +// - AtomicUsize operations are thread-safe +unsafe impl Send for AtomicHandle where S: AtomicOperation + Send + Sync {} + +// Safety: AtomicHandle is Sync because: +// - Arc is Sync when S: Send + Sync +// - The atomic operations are inherently thread-safe +// - Multiple threads can safely share references to this handle +unsafe impl Sync for AtomicHandle where S: AtomicOperation + Send + Sync {} + +impl AtomicHandle +where + S: AtomicOperation + Send + Sync, +{ + /// Creates a new AtomicHandle for the specified field in the given Arc. + /// + /// # Arguments + /// * `owner` - The Arc containing the atomic field + /// + /// # Returns + /// A new AtomicHandle that provides access to the atomic field + /// + /// # Example + /// ``` + /// # use std::sync::{Arc, atomic::AtomicUsize}; + /// # use std::marker::PhantomData; + /// # use std::ops::Deref; + /// # pub trait AtomicOperation { + /// # fn get_atomic(&self) -> &AtomicUsize; + /// # } + /// # pub struct AtomicHandle where S: AtomicOperation + Send + Sync { + /// # _owner: Arc, + /// # atomic_ptr: *const AtomicUsize, + /// # _phantom: PhantomData, + /// # } + /// # impl AtomicHandle where S: AtomicOperation + Send + Sync { + /// # pub fn new(owner: Arc) -> Self { + /// # let atomic_ptr = owner.get_atomic() as *const AtomicUsize; + /// # Self { _owner: owner, atomic_ptr, _phantom: PhantomData } + /// # } + /// # } + /// # struct Counter; + /// # struct MyStruct { counter: AtomicUsize } + /// # impl AtomicOperation for MyStruct { + /// # fn get_atomic(&self) -> &AtomicUsize { &self.counter } + /// # } + /// let my_struct = Arc::new(MyStruct { counter: AtomicUsize::new(0) }); + /// let handle: AtomicHandle = AtomicHandle::new(my_struct); + /// ``` + pub fn new(owner: Arc) -> Self { + let atomic_ptr = owner.get_atomic() as *const AtomicUsize; + Self { + _owner: owner, + atomic_ptr, + _phantom: PhantomData, + } + } + + /// Returns a reference to the underlying AtomicUsize. + /// + /// This is safe because the Arc is kept alive for the lifetime of this + /// handle. + pub fn get(&self) -> &AtomicUsize { + unsafe { &*self.atomic_ptr } + } + +} + +impl Deref for AtomicHandle +where + S: AtomicOperation + Send + Sync, +{ + type Target = AtomicUsize; + + fn deref(&self) -> &Self::Target { + self.get() + } +} + +// Manual implementation of Clone. +impl Clone for AtomicHandle +where + S: AtomicOperation + Send + Sync, +{ + fn clone(&self) -> Self { + Self { + _owner: Arc::clone(&self._owner), + atomic_ptr: self.atomic_ptr, + _phantom: PhantomData, + } + } +} + +impl std::fmt::Debug for AtomicHandle +where + S: AtomicOperation + Send + Sync, +{ + fn fmt( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + f.debug_struct("AtomicHandle") + .field("current_value", &self.load(Ordering::SeqCst)) + .field("field_type", &std::any::type_name::()) + .field("struct_type", &std::any::type_name::()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Type alias for the default (unnamed) atomic field handle + pub type DefaultAtomicHandle = AtomicHandle; + + // Test marker types + struct Counter; + struct Requests; + struct Errors; + + // Test struct with multiple atomics + struct TestStruct { + name: String, + counter: AtomicUsize, + requests: AtomicUsize, + errors: AtomicUsize, + } + + // Implement AtomicOperation for each field + impl AtomicOperation for TestStruct { + fn get_atomic(&self) -> &AtomicUsize { + &self.counter + } + } + + impl AtomicOperation for TestStruct { + fn get_atomic(&self) -> &AtomicUsize { + &self.requests + } + } + + impl AtomicOperation for TestStruct { + fn get_atomic(&self) -> &AtomicUsize { + &self.errors + } + } + + // Default implementation (could be any field you choose as default) + impl AtomicOperation for TestStruct { + fn get_atomic(&self) -> &AtomicUsize { + &self.counter // Default to counter field + } + } + + type CounterHandle = AtomicHandle; + type RequestsHandle = AtomicHandle; + type ErrorsHandle = AtomicHandle; + + #[test] + fn test_multiple_field_handles() { + let test_struct = Arc::new(TestStruct { + name: "test".to_string(), + counter: AtomicUsize::new(10), + requests: AtomicUsize::new(20), + errors: AtomicUsize::new(30), + }); + + let counter_handle: CounterHandle = AtomicHandle::new(Arc::clone(&test_struct)); + let requests_handle: RequestsHandle = AtomicHandle::new(Arc::clone(&test_struct)); + let errors_handle: ErrorsHandle = AtomicHandle::new(Arc::clone(&test_struct)); + + // Test initial values + assert_eq!(counter_handle.load(Ordering::SeqCst), 10); + assert_eq!(requests_handle.load(Ordering::SeqCst), 20); + assert_eq!(errors_handle.load(Ordering::SeqCst), 30); + + // Test modifications + counter_handle.store(100, Ordering::SeqCst); + requests_handle.fetch_add(5, Ordering::SeqCst); + errors_handle.fetch_sub(10, Ordering::SeqCst); + + // Verify changes + assert_eq!(counter_handle.load(Ordering::SeqCst), 100); + assert_eq!(requests_handle.load(Ordering::SeqCst), 25); + assert_eq!(errors_handle.load(Ordering::SeqCst), 20); + + // Verify original struct is also modified + assert_eq!(test_struct.counter.load(Ordering::SeqCst), 100); + assert_eq!(test_struct.requests.load(Ordering::SeqCst), 25); + assert_eq!(test_struct.errors.load(Ordering::SeqCst), 20); + } + + #[test] + fn test_handle_cloning() { + let test_struct = Arc::new(TestStruct { + name: "test".to_string(), + counter: AtomicUsize::new(0), + requests: AtomicUsize::new(0), + errors: AtomicUsize::new(0), + }); + + let handle1: CounterHandle = AtomicHandle::new(test_struct); + let handle2 = handle1.clone(); + + handle1.fetch_add(10, Ordering::SeqCst); + handle2.fetch_add(5, Ordering::SeqCst); + + assert_eq!(handle1.load(Ordering::SeqCst), 15); + assert_eq!(handle2.load(Ordering::SeqCst), 15); + } + + #[test] + fn test_atomic_address() { + let test_struct = Arc::new(TestStruct { + name: "test".to_string(), + counter: AtomicUsize::new(42), + requests: AtomicUsize::new(0), + errors: AtomicUsize::new(0), + }); + + let handle: CounterHandle = AtomicHandle::new(test_struct.clone()); + let counter_ref = &test_struct.counter as *const AtomicUsize; + let counter_handle_ref = handle.get() as *const AtomicUsize; + + assert_eq!(counter_ref, counter_handle_ref); + } +} \ No newline at end of file diff --git a/utils/client/src/lib.rs b/utils/client/src/lib.rs index 0d335f5..1085aa3 100644 --- a/utils/client/src/lib.rs +++ b/utils/client/src/lib.rs @@ -1,5 +1,6 @@ #![feature(error_generic_member_access)] #![feature(adt_const_params)] +#![feature(generic_const_exprs)] #![feature(async_closure)] use std::{backtrace::Backtrace, error::Error, fmt::Debug}; diff --git a/utils/client/src/nf_clients/mod.rs b/utils/client/src/nf_clients/mod.rs index aea3506..0c0c7f8 100644 --- a/utils/client/src/nf_clients/mod.rs +++ b/utils/client/src/nf_clients/mod.rs @@ -1,4 +1,11 @@ -use std::{convert::Infallible, iter::once, net::SocketAddr, ops::AsyncFn, sync::Arc}; +use std::{ + convert::Infallible, + iter::once, + marker::PhantomData, + net::SocketAddr, + ops::AsyncFn, + sync::Arc, +}; use bytes::Bytes; use http::{ @@ -11,11 +18,13 @@ use http::{ request::Builder as HttpReqBuilder, }; use http_body_util::BodyExt; +use hyper_util::service; use oasbi::{DeserResponse, common::NfType, nrf::types::NfProfile}; use openapi_nrf::models::{ SearchNfInstancesHeaderParams, SearchNfInstancesQueryParams, SearchResult, + ServiceName, }; use reqwest::{Body, Client, ClientBuilder, Request, Response}; use serde::Serialize; @@ -42,9 +51,14 @@ use tower_reqwest::{HttpClientLayer, HttpClientService}; use url::Url; pub mod amf; +mod oauth_service; +mod service_discovery; +use oauth_service::OAuthTokenLayer; +use service_discovery::ServiceDiscoveryLayer; use crate::{ GenericClientError, + nf_clients::{oauth_service::OAuthTokenService, service_discovery::ServiceDiscovery}, nrf_client::{NrfClient, NrfDiscoveryError}, to_headers, }; @@ -59,6 +73,7 @@ pub trait NfClientController { &self, search_result: SearchResult, ) -> NfProfile; + fn get_search_params( &self, requester_nf_type: NfType, @@ -71,39 +86,67 @@ pub trait NfClientController { } } -type TowerReqwestClient = SetSensitiveRequestHeaders< - Trace, SharedClassifier>, +type TowerReqwestClient = ServiceDiscovery< + OAuthTokenService< + SetSensitiveRequestHeaders< + Trace, +SharedClassifier>, >, + TARGET_TYPE, + >, + T, + TARGET_TYPE, >; -pub struct NFClient { - nrf_client: Arc, - controller: T, - nf_profile: NfProfile, - req_client: TowerReqwestClient, +// type TowerReqwestClient = ServiceDiscovery< +// SetSensitiveRequestHeaders< +// Trace, SharedClassifier>, +// >, +// T, +// TARGET_TYPE, +// >; + +// type TowerReqwestClient = OAuthTokenService< +// SetSensitiveRequestHeaders< +// Trace, SharedClassifier>, +// >, +// TARGET_TYPE, +// >; + +pub struct NFClient +where + T: NfClientController + Send + Sync + 'static, +{ + req_client: TowerReqwestClient, + controller: PhantomData, } -impl NFClient +impl NFClient where - T: NfClientController + ApiBaseUrl, + T: NfClientController + Send + Sync + 'static, { pub async fn new( nrf_client: Arc, controller: T, + services: Vec, ) -> Result { // let url = controller.base_url(); - let search_params = controller.get_search_params(APP_TYPE); - let header_params = SearchNfInstancesHeaderParams { - ..Default::default() - }; - let search_result = nrf_client - .search_nf_instance(search_params, header_params) - .await?; - let nf_profile = controller.profile_selection(search_result); let builder = ClientBuilder::new(); let client = builder.build()?; + // let oauth_layer = OAuthTokenLayer::new(nrf_client.clone(), services.clone()); + let arc_services = services.into(); + let arc_controller = Arc::new(controller); + + let service_discovery_layer = ServiceDiscoveryLayer::::new( + nrf_client, + arc_controller, + T::CLIENT_TYPE, + arc_services, + ); let service = ServiceBuilder::new() // Mark the `Authorization` request header as sensitive so it doesn't show in logs + .layer(service_discovery_layer) + .layer(oauth_layer) .layer(SetSensitiveRequestHeadersLayer::new(once(AUTHORIZATION))) // High level logging of requests and responses .layer(TraceLayer::new_for_http()) @@ -111,9 +154,7 @@ where .service(client); Ok(NFClient { - nrf_client, - controller, - nf_profile, + controller: PhantomData, req_client: service, }) } diff --git a/utils/client/src/nf_clients/oauth_service.rs b/utils/client/src/nf_clients/oauth_service.rs new file mode 100644 index 0000000..487190b --- /dev/null +++ b/utils/client/src/nf_clients/oauth_service.rs @@ -0,0 +1,180 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use http::{Request as HttpRequest, Response as HttpResponse}; +use oasbi::common::NfType; +use openapi_nrf::models::ServiceName; +use reqwest::Body; +use tower::{BoxError, Layer, Service}; + +use crate::nrf_client::NrfClient; + +/// OAuth Token Service that automatically adds authentication headers to +/// requests +#[derive(Clone)] +pub struct OAuthTokenService { + inner: S, + nrf_client: Arc, + service_names: Vec, +} + +impl OAuthTokenService { + /// Create a new OAuth Token Service + pub fn new( + inner: S, + nrf_client: Arc, + service_names: Vec, + ) -> Self { + Self { + inner, + nrf_client, + service_names, + } + } + + /// Get a reference to the inner service + pub fn inner(&self) -> &S { + &self.inner + } + + /// Get a mutable reference to the inner service + pub fn inner_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consume this service and return the inner service + pub fn into_inner(self) -> S { + self.inner + } + + /// Get the service names this service will request tokens for + pub fn service_names(&self) -> &[ServiceName] { + &self.service_names + } + + /// Get a reference to the NRF client + pub fn nrf_client(&self) -> &Arc { + &self.nrf_client + } +} + +impl Service> + for OAuthTokenService +where + S: Service, Response = HttpResponse, Error = BoxError> + + Clone + + Send + + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = BoxError; + type Future = Pin> + Send>>; + + fn poll_ready( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call( + &mut self, + mut req: HttpRequest, + ) -> Self::Future { + let mut inner = self.inner.clone(); + let nrf_client = self.nrf_client.clone(); + let service_names = self.service_names.clone(); + + Box::pin(async move { + // Try to add OAuth token if available and OAuth is enabled + if let Err(e) = nrf_client + .set_auth_token::(req.headers_mut(), service_names) + .await + { + // Convert NrfAuthorizationError to BoxError + return Err(Box::new(e) as BoxError); + } + + // Token added successfully (or OAuth not enabled), proceed with request + inner.call(req).await + }) + } +} + +/// Layer for creating the OAuth Token Service +#[derive(Clone)] +pub struct OAuthTokenLayer { + nrf_client: Arc, + service_names: Vec, +} + +impl OAuthTokenLayer { + /// Create a new OAuth Token Layer + pub fn new( + nrf_client: Arc, + service_names: Vec, + ) -> Self { + Self { + nrf_client, + service_names, + } + } + + /// Create a new OAuth Token Layer with a single service name + pub fn with_service( + nrf_client: Arc, + service_name: ServiceName, + ) -> Self { + Self { + nrf_client, + service_names: vec![service_name], + } + } + + /// Get the service names this layer will request tokens for + pub fn service_names(&self) -> &[ServiceName] { + &self.service_names + } + + /// Get a reference to the NRF client + pub fn nrf_client(&self) -> &Arc { + &self.nrf_client + } +} + +impl Layer for OAuthTokenLayer { + type Service = OAuthTokenService; + + fn layer( + &self, + inner: S, + ) -> Self::Service { + OAuthTokenService::new(inner, self.nrf_client.clone(), self.service_names.clone()) + } +} + +#[cfg(test)] +mod tests { + use std::convert::Infallible; + + use tower::service_fn; + + use super::*; + + // Mock service for testing + fn mock_service() + -> impl Service, Response = HttpResponse, Error = BoxError> + Clone { + service_fn(|_req| async { Ok(HttpResponse::new(Body::from("test response"))) }) + } + + #[tokio::test] + async fn test_oauth_layer_creation() { + // This test would need a mock NrfClient to run properly + // Left as a structure example + } +} diff --git a/utils/client/src/nf_clients/service_discovery/future.rs b/utils/client/src/nf_clients/service_discovery/future.rs new file mode 100644 index 0000000..095a80a --- /dev/null +++ b/utils/client/src/nf_clients/service_discovery/future.rs @@ -0,0 +1,68 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use http::Response as HttpResponse; +use tower::BoxError; +use pin_project_lite::pin_project; +use reqwest::Body; + + +pin_project! { + /// Future for the `ServiceDiscovery` service. + pub struct ResponseFuture { + #[pin] + state: State, + } +} + +pin_project! { + #[project = StateProj] + enum State { + // The future returned by the inner service. + Inner { #[pin] fut: F }, + // An error occurred before the inner service was called. + // The Option is used to allow moving the error out on the first poll. + Error { error: Option }, + } +} + +impl ResponseFuture { + /// Creates a new `ResponseFuture` in the `Inner` state. + pub(crate) fn new(fut: F) -> Self { + Self { + state: State::Inner { fut }, + } + } + + /// Creates a new `ResponseFuture` in the `Error` state. + pub(crate) fn error(error: BoxError) -> Self { + Self { + state: State::Error { error: Some(error) }, + } + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, + E: Into, +{ + type Output = Result, BoxError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().state.project() { + StateProj::Inner { fut } => { + // Poll the inner future and map its error type + fut.poll(cx).map(|res| res.map_err(Into::into)) + } + StateProj::Error { error } => { + // The error is ready immediately. + // We take it from the Option to ensure it's only returned once. + let e = error.take().expect("ResponseFuture polled after completion"); + Poll::Ready(Err(e)) + } + } + } +} \ No newline at end of file diff --git a/utils/client/src/nf_clients/service_discovery/mod.rs b/utils/client/src/nf_clients/service_discovery/mod.rs new file mode 100644 index 0000000..53c6b19 --- /dev/null +++ b/utils/client/src/nf_clients/service_discovery/mod.rs @@ -0,0 +1,416 @@ +mod future; +use std::{ + collections::HashMap, + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use arc_swap::ArcSwap; +use future::ResponseFuture; +use http::{Request as HttpRequest, Response as HttpResponse, Uri, uri::InvalidUri}; +use oasbi::{ + common::NfType, + nrf::types::{NfProfile, ServiceName}, +}; +use openapi_nrf::models::SearchNfInstancesHeaderParams; +use reqwest::Body; +use tower::{BoxError, Layer, Service}; +use url::Url; + +use super::NfClientController; +use crate::nrf_client::{NrfClient, NrfDiscoveryError}; + +/// Create an empty cached profile +fn empty_cached_profile() -> CachedNfProfile { + CachedNfProfile { + nf_profile: None, + service_urls: HashMap::new(), + validity_period: None, + cached_at: UNIX_EPOCH, + } +} + +/// Cached NF profile with validity period +#[derive(Debug, Clone)] +pub struct CachedNfProfile { + pub nf_profile: Option, + pub service_urls: HashMap>, + pub validity_period: Option, + pub cached_at: SystemTime, +} + +impl CachedNfProfile { + pub fn new( + nf_profile: Option, + validity_period: Option, + ) -> Self { + let service_urls = nf_profile + .as_ref() + .map(|profile| { + let mut map = HashMap::new(); + for service in &profile.nf_services { + let urls: Vec = service + .ip_end_points + .iter() + .filter_map(|endpoint| { + let scheme = service.scheme.to_string(); + let host = endpoint.ipv4_address.as_ref()?; + let port = endpoint.port?; + + let url_str = format!("{}://{:?}:{}", scheme, host, port); + Url::parse(&url_str).ok() + }) + .collect(); + + if !urls.is_empty() { + map.insert(service.service_name.clone(), urls); + } + } + map + }) + .unwrap_or_default(); + + Self { + nf_profile, + service_urls, + validity_period, + cached_at: SystemTime::now(), + } + } + + pub fn is_valid(&self) -> bool { + if let Some(validity_seconds) = self.validity_period { + SystemTime::now() + .duration_since(self.cached_at) + .unwrap_or(Duration::MAX) + .as_secs() < validity_seconds as u64 + } else { + self.nf_profile.is_some() // Valid if we have a profile and no expiry + } + } + + pub fn get_base_url(&self) -> Option<&Url> { + // Return the first URL from the first service + self.service_urls.values().flatten().next() + } + + pub fn get_service_urls( + &self, + service_name: &ServiceName, + ) -> Option<&Vec> { + self.service_urls.get(service_name) + } + + pub fn get_first_service_url( + &self, + service_name: &ServiceName, + ) -> Option<&Url> { + self.service_urls + .get(service_name) + .and_then(|urls| urls.first()) + } +} + +/// Service discovery layer with atomic cache +pub struct ServiceDiscoveryLayer { + // Use ArcSwap for lock-free reads, only one writer at a time + cached_profile: Arc>, + nrf_client: Arc, + controller: Arc, + service_names: Arc<[ServiceName]>, + app_type: NfType, + // Prevent multiple concurrent discoveries + discovery_semaphore: Arc, +} + +impl Clone for ServiceDiscoveryLayer { + fn clone(&self) -> Self { + Self { + cached_profile: self.cached_profile.clone(), + nrf_client: self.nrf_client.clone(), + controller: self.controller.clone(), + app_type: self.app_type, + discovery_semaphore: self.discovery_semaphore.clone(), + service_names: self.service_names.clone(), + } + } +} + +impl ServiceDiscoveryLayer +where + T: NfClientController, +{ + pub fn new( + nrf_client: Arc, + controller: Arc, + app_type: NfType, + service_names: Arc<[ServiceName]>, + ) -> Self { + Self { + cached_profile: Arc::new(ArcSwap::from_pointee(empty_cached_profile())), + nrf_client, + controller, + app_type, + discovery_semaphore: Arc::new(tokio::sync::Semaphore::new(1)), + service_names, + } + } + + /// Check if the cache is valid (fast path) + pub fn is_cache_valid(&self) -> bool { + self.cached_profile.load().is_valid() + } + + /// Get the cached profile if valid + pub fn get_cached_profile(&self) -> Option> { + let cached = self.cached_profile.load(); + if cached.is_valid() { + Some(cached.clone()) + } else { + None + } + } + + /// Private helper to handle the "slow path" where the cache is stale. + /// This function manages the semaphore to ensure only one discovery happens + /// at a time. + pub async fn discover_and_update_cache( + self: Self + ) -> Result, ServiceDiscoveryError> { + // Acquire a permit. This will pause if another task is already updating. + let _permit = self.discovery_semaphore.acquire().await.unwrap(); + + // Perform the crucial "double-check". The cache might have been + // updated by the task we were waiting for. + let cached_after_wait = self.cached_profile.load(); + if cached_after_wait.is_valid() { + return Ok(cached_after_wait.clone()); + } + + // --- We are the designated "worker" --- + // The cache is still stale, so we must perform the discovery. + let search_params = self.controller.get_search_params(self.app_type); + let header_params = SearchNfInstancesHeaderParams { + ..Default::default() + }; + + match self + .nrf_client + .search_nf_instance(search_params, header_params) + .await + { + Ok(search_result) => { + let validity_period = search_result.validity_period.clone(); + let nf_profile = self.controller.profile_selection(search_result); + let new_cached_arc = + Arc::new(CachedNfProfile::new(Some(nf_profile), validity_period)); + self.cached_profile.store(new_cached_arc.clone()); + Ok(new_cached_arc) + } + Err(e) => { + // On failure, clear the cache to ensure the next request can retry. + self.cached_profile.store(Arc::new(empty_cached_profile())); + Err(e.into()) + } + } + } + + pub fn invalidate_cache(&self) { + // Atomic invalidation using empty profile + self.cached_profile.store(Arc::new(empty_cached_profile())); + } +} + +impl Layer for ServiceDiscoveryLayer +where + T: NfClientController, +{ + type Service = ServiceDiscovery; + + fn layer( + &self, + inner: S, + ) -> Self::Service { + ServiceDiscovery { + inner, + layer: self.clone(), + discovery_future: None, + } + } +} + +/// Applies service discovery to requests. +pub struct ServiceDiscovery { + inner: S, + layer: ServiceDiscoveryLayer, + // Store the discovery future in the service to poll it in poll_ready + discovery_future: Option< + Pin, ServiceDiscoveryError>> + Send>>, + >, +} + +impl Clone for ServiceDiscovery +where + S: Clone, +{ + fn clone(&self) -> Self { + ServiceDiscovery { + inner: self.inner.clone(), + layer: self.layer.clone(), + discovery_future: None, // Don't clone the discovery future + } + } +} + +impl ServiceDiscovery { + /// Creates a new [`ServiceDiscovery`] + pub fn new( + inner: S, + layer: ServiceDiscoveryLayer, + ) -> Self { + ServiceDiscovery { + inner, + layer, + discovery_future: None, + } + } + + /// Get a reference to the inner service + pub fn get_ref(&self) -> &S { + &self.inner + } + + /// Get a mutable reference to the inner service + pub fn get_mut(&mut self) -> &mut S { + &mut self.inner + } + + /// Consume `self`, returning the inner service + pub fn into_inner(self) -> S { + self.inner + } +} + +impl Service> + for ServiceDiscovery +where + S: Service, Response = HttpResponse>, + S::Error: Into, + T: NfClientController + Send + Sync + 'static, +{ + type Response = S::Response; + type Error = BoxError; + type Future = ResponseFuture; + + fn poll_ready( + &mut self, + cx: &mut Context<'_>, + ) -> Poll> { + // First check if we have a valid cache (fast path) + if self.layer.is_cache_valid() { + // Cache is valid, just check if inner service is ready + return self.inner.poll_ready(cx).map_err(Into::into); + } + + // Cache is stale, we need to discover + loop { + if let Some(ref mut discovery_fut) = self.discovery_future { + // We have an ongoing discovery, poll it + match discovery_fut.as_mut().poll(cx) { + Poll::Ready(Ok(_cached_profile)) => { + // Discovery completed successfully + self.discovery_future = None; + // Now check if inner service is ready + return self.inner.poll_ready(cx).map_err(Into::into); + } + Poll::Ready(Err(e)) => { + // Discovery failed + self.discovery_future = None; + return Poll::Ready(Err(Box::new(e))); + } + Poll::Pending => { + // Discovery is still in progress + return Poll::Pending; + } + } + } else { + // No ongoing discovery, start one + self.discovery_future = Some(Box::pin(ServiceDiscoveryLayer::discover_and_update_cache(self.layer.clone()))); + // Continue the loop to poll the new future + } + } + } + + fn call( + &mut self, + mut request: HttpRequest, + ) -> Self::Future { + // At this point, poll_ready has ensured we have a valid cache + match self.layer.get_cached_profile() { + Some(cached_profile) => { + // Update the request URL with the discovered service URL + match Self::update_request_url(&mut request, &cached_profile) { + Ok(()) => { + // URL updated successfully, create the success future + let inner_future = self.inner.call(request); + ResponseFuture::new(inner_future) + } + Err(e) => { + // URL update failed, create the error future + ResponseFuture::error(e.into()) + } + } + } + None => { + // This should not happen if poll_ready worked correctly, + // but handle it gracefully by creating an error future. + let error = Box::new(ServiceDiscoveryError::NoServiceUrl) as BoxError; + ResponseFuture::error(error) + } + } + } +} + +impl ServiceDiscovery { + fn update_request_url( + request: &mut HttpRequest, + cached_profile: &CachedNfProfile, + ) -> Result<(), ServiceDiscoveryError> { + match cached_profile.get_base_url() { + Some(base_url) => { + let original_path = request + .uri() + .path_and_query() + .map(|pq| pq.as_str()) + .unwrap_or("/"); + + let new_url = base_url.join(original_path)?; + let new_uri: Uri = new_url.as_str().parse()?; + *request.uri_mut() = new_uri; + Ok(()) + } + None => Err(ServiceDiscoveryError::NoServiceUrl), + } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum ServiceDiscoveryError { + #[error("NRF discovery failed: {0}")] + NrfError(#[from] NrfDiscoveryError), + + #[error("No service URL found")] + NoServiceUrl, + + #[error(transparent)] + UrlParseError(#[from] url::ParseError), + + #[error(transparent)] + InvalidUri(#[from] InvalidUri), + + #[error(transparent)] + BoxError(#[from] BoxError), +} diff --git a/utils/client/src/nrf_client.rs b/utils/client/src/nrf_client.rs index 14386f0..17a0062 100644 --- a/utils/client/src/nrf_client.rs +++ b/utils/client/src/nrf_client.rs @@ -1,8 +1,17 @@ -use std::{backtrace::Backtrace, str::FromStr, sync::Arc}; +use std::{ + backtrace::Backtrace, + str::FromStr, + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; use arc_swap::ArcSwap; use formatx::formatx; -use http::header::{self, AUTHORIZATION}; +use http::{ + HeaderMap, + Request as HttpRequest, + header::{self, AUTHORIZATION}, +}; use oasbi::{ DeserResponse, common::{ @@ -286,7 +295,7 @@ impl NrfClient { Option::<&TraitSatisfier>::None, ContentType::AppJson, )?; - self.set_auth_token::<{ NfType::Nrf }>(&mut request, vec![ServiceName::NnrfNfm]) + self.set_auth_token::<{ NfType::Nrf }>(request.headers_mut(), vec![ServiceName::NnrfNfm]) .await?; let response = self .client @@ -393,33 +402,32 @@ impl NrfClient { ) -> Result, NrfAuthorizationError> { let token_entry = self.nf_token_store.get(&target_service_name).await?; match token_entry { - Some(entry) => Ok(entry), - None => { - let resp = self - .nf_token_store - .set( - target_service_name.clone(), - self.authenticaion_request( - self.nf_config.load().nf_instance_id, - self.init_config.source, - T, - target_service_name, - ), - ) - .await?; - Ok(resp) - } - } + Some(entry) if !is_token_expired(entry.get()) => return Ok(entry), + _ => (), + }; + let resp = self + .nf_token_store + .set( + target_service_name.clone(), + self.authenticaion_request( + self.nf_config.load().nf_instance_id, + self.init_config.source, + T, + target_service_name, + ), + ) + .await?; + Ok(resp) } pub async fn set_auth_token( &self, - req: &mut Request, + headers_mut: &mut HeaderMap, service_name: Vec, ) -> Result<(), NrfAuthorizationError> { if self.nf_config.load().oauth_enabled { let token_entry = self.get_token::(service_name).await?; - set_auth_token(req, token_entry)?; + set_auth_token(headers_mut, token_entry)?; } Ok(()) } @@ -486,24 +494,54 @@ pub enum NrfAuthorizationError { TokenParsingError(#[from] header::InvalidHeaderValue), } -pub(crate) fn set_auth_token( - req: &mut Request, - token_entry: TokenEntry, -) -> Result<(), header::InvalidHeaderValue> { +pub(crate) fn create_token_from_access_token(token_entry: TokenEntry) -> String { let token: &str = &token_entry.get().access_token; let token_type = token_entry.get().token_type; - let token: String = match token_type { + match token_type { AccessTokenRspTokenType::Bearer => { - let mut string = "Bearer ".to_owned(); + let mut string = "bearer ".to_owned(); string.push_str(token); string } - }; - let headers_mut = req.headers_mut(); + } +} + +pub(crate) fn set_auth_token( + headers_mut: &mut HeaderMap, + token_entry: TokenEntry, +) -> Result<(), header::InvalidHeaderValue> { + let token = create_token_from_access_token(token_entry); headers_mut.insert(AUTHORIZATION, token.try_into()?); Ok(()) } +/// Check if a token is expired based on its expiry time +/// +/// Returns `true` if the token is expired or will expire within the buffer time +fn is_token_expired(token: &AccessTokenRsp) -> bool { + // Get current time + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + // SAFETY: As now system time should not be before UNIX_EPOCH + .unwrap_or_default() + .as_secs(); + // Check if token has expires_in field + if let Some(expires_in) = token.expires_in { + // Calculate expiry time (assuming the token was issued recently) + // In a real implementation, you might want to store the issue time + // For now, we assume the token was just issued + let buffer_seconds = 30; // 30 second buffer before expiry + let expires_at = now + expires_in as u64; + let expires_with_buffer = expires_at.saturating_sub(buffer_seconds); + + now >= expires_with_buffer + } else { + // If no expiry time is provided, assume token is still valid + // This is a conservative approach + false + } +} + #[derive(Debug)] pub struct Scope(String);