From 3dcd3258ea4281ba8ad1a04f9d95dafa1700a143 Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Wed, 28 Jan 2026 15:40:39 +1100 Subject: [PATCH] simplify Swift concurrency model - Applications are now actors - Participants are @unchecked Sendable final classes; any functions that modify mutable state first assert they are running on the application's actor - Because NonisolatedNonsendingByDefault is set, nonisolated Participant methods will run on application actor (except for Timer callbacks, where the isolation is explicitly passed) --- .../Applications/MMRP/MMRPApplication.swift | 57 ++-- .../Applications/MSRP/MSRPApplication.swift | 270 +++++++++--------- .../MRP/Applications/MSRP/MSRPHandler.swift | 76 ++--- .../Applications/MVRP/MVRPApplication.swift | 67 ++--- Sources/MRP/Base/MRPController.swift | 24 +- Sources/MRP/Base/Utility.swift | 25 +- Sources/MRP/Model/Application.swift | 74 +++-- Sources/MRP/Model/BaseApplication.swift | 78 +++-- Sources/MRP/Model/Participant.swift | 258 ++++++++++------- Sources/MRP/RestApi/RestApiHandler.swift | 6 +- 10 files changed, 474 insertions(+), 461 deletions(-) diff --git a/Sources/MRP/Applications/MMRP/MMRPApplication.swift b/Sources/MRP/Applications/MMRP/MMRPApplication.swift index 97287b40..4c0c4abc 100644 --- a/Sources/MRP/Applications/MMRP/MMRPApplication.swift +++ b/Sources/MRP/Applications/MMRP/MMRPApplication.swift @@ -38,34 +38,32 @@ protocol MMRPAwareBridge

: Bridge where P: Port { ) async throws } -public final class MMRPApplication: BaseApplication, BaseApplicationEventObserver, - CustomStringConvertible, - Sendable where P == P +public actor MMRPApplication: BaseApplication, BaseApplicationEventObserver, Sendable, + CustomStringConvertible where P == P { // for now, we only operate in the Base Spanning Tree Context - public var nonBaseContextsSupported: Bool { false } + public nonisolated var nonBaseContextsSupported: Bool { false } - public var validAttributeTypes: ClosedRange { + public nonisolated var validAttributeTypes: ClosedRange { MMRPAttributeType.validAttributeTypes } // 10.12.1.3 MMRP application address - public var groupAddress: EUI48 { CustomerBridgeMRPGroupAddress } + public nonisolated var groupAddress: EUI48 { CustomerBridgeMRPGroupAddress } // 10.12.1.4 MMRP application EtherType - public var etherType: UInt16 { MMRPEtherType } + public nonisolated var etherType: UInt16 { MMRPEtherType } // 10.12.1.5 MMRP ProtocolVersion - public var protocolVersion: ProtocolVersion { 0 } + public nonisolated var protocolVersion: ProtocolVersion { 0 } - public var hasAttributeListLength: Bool { false } + public nonisolated var hasAttributeListLength: Bool { false } let _controller: Weak> - public var controller: MRPController

? { _controller.object } + public nonisolated var controller: MRPController

? { _controller.object } - let _participants = - Mutex<[MAPContextIdentifier: Set>>]>([:]) + var _participants: [MAPContextIdentifier: Set>>] = [:] let _logger: Logger public init(controller: MRPController

) async throws { @@ -74,14 +72,13 @@ public final class MMRPApplication: BaseApplication, BaseApplicationEve try await controller.register(application: self) } - public var description: String { - let participants: String = _participants.withLock { String(describing: $0) } - return "MMRPApplication(controller: \(controller!), participants: \(participants))" + public nonisolated var description: String { + "MMRPApplication(controller: \(controller!))" } - public var name: String { "MMRP" } + public nonisolated var name: String { "MMRP" } - public func deserialize( + public nonisolated func deserialize( attributeOfType attributeType: AttributeType, from input: inout ParserSpan ) throws -> any Value { @@ -95,7 +92,7 @@ public final class MMRPApplication: BaseApplication, BaseApplicationEve } } - public func makeNullValue(for attributeType: AttributeType) throws -> any Value { + public nonisolated func makeNullValue(for attributeType: AttributeType) throws -> any Value { guard let attributeType = MMRPAttributeType(rawValue: attributeType) else { throw MRPError.unknownAttributeType } switch attributeType { @@ -106,18 +103,18 @@ public final class MMRPApplication: BaseApplication, BaseApplicationEve } } - public func hasAttributeSubtype(for: AttributeType) -> Bool { + public nonisolated func hasAttributeSubtype(for: AttributeType) -> Bool { false } - public func administrativeControl(for attributeType: AttributeType) throws + public nonisolated func administrativeControl(for attributeType: AttributeType) throws -> AdministrativeControl { .normalParticipant } - public func register(macAddress: EUI48) async throws { - try await join( + public func register(macAddress: EUI48) throws { + try join( attributeType: MMRPAttributeType.mac.rawValue, attributeValue: MMRPMACValue(macAddress: macAddress), isNew: false, @@ -125,8 +122,8 @@ public final class MMRPApplication: BaseApplication, BaseApplicationEve ) } - public func deregister(macAddress: EUI48) async throws { - try await leave( + public func deregister(macAddress: EUI48) throws { + try leave( attributeType: MMRPAttributeType.mac.rawValue, attributeValue: MMRPMACValue(macAddress: macAddress), for: MAPBaseSpanningTreeContext @@ -135,8 +132,8 @@ public final class MMRPApplication: BaseApplication, BaseApplicationEve public func register( serviceRequirement requirementSpecification: MMRPServiceRequirementValue - ) async throws { - try await join( + ) throws { + try join( attributeType: MMRPAttributeType.serviceRequirement.rawValue, attributeValue: requirementSpecification, isNew: false, @@ -146,8 +143,8 @@ public final class MMRPApplication: BaseApplication, BaseApplicationEve public func deregister( serviceRequirement requirementSpecification: MMRPServiceRequirementValue - ) async throws { - try await leave( + ) throws { + try leave( attributeType: MMRPAttributeType.serviceRequirement.rawValue, attributeValue: requirementSpecification, for: MAPBaseSpanningTreeContext @@ -155,8 +152,8 @@ public final class MMRPApplication: BaseApplication, BaseApplicationEve } public func periodic(for contextIdentifier: MAPContextIdentifier? = nil) async throws { - try await apply(for: contextIdentifier) { participant in - try await participant.periodic() + try apply(for: contextIdentifier) { participant in + try participant.periodic() } } } diff --git a/Sources/MRP/Applications/MSRP/MSRPApplication.swift b/Sources/MRP/Applications/MSRP/MSRPApplication.swift index b7263865..2460c576 100644 --- a/Sources/MRP/Applications/MSRP/MSRPApplication.swift +++ b/Sources/MRP/Applications/MSRP/MSRPApplication.swift @@ -137,32 +137,31 @@ struct MSRPPortState: Sendable { } } -public final class MSRPApplication: BaseApplication, BaseApplicationEventObserver, - BaseApplicationContextObserver, CustomStringConvertible, @unchecked Sendable where P == P +public actor MSRPApplication: BaseApplication, BaseApplicationEventObserver, Sendable, + BaseApplicationContextObserver, CustomStringConvertible where P == P { private typealias TalkerRegistration = (Participant, any MSRPTalkerValue) // for now, we only operate in the Base Spanning Tree Context - public var nonBaseContextsSupported: Bool { false } + public nonisolated var nonBaseContextsSupported: Bool { false } - public var validAttributeTypes: ClosedRange { + public nonisolated var validAttributeTypes: ClosedRange { MSRPAttributeType.validAttributeTypes } - public var groupAddress: EUI48 { IndividualLANScopeGroupAddress } + public nonisolated var groupAddress: EUI48 { IndividualLANScopeGroupAddress } - public var etherType: UInt16 { MSRPEtherType } + public nonisolated var etherType: UInt16 { MSRPEtherType } - public var protocolVersion: ProtocolVersion { MSRPProtocolVersion.v0.rawValue } + public nonisolated var protocolVersion: ProtocolVersion { MSRPProtocolVersion.v0.rawValue } - public var hasAttributeListLength: Bool { true } + public nonisolated var hasAttributeListLength: Bool { true } let _controller: Weak> - public var controller: MRPController

? { _controller.object } + public nonisolated var controller: MRPController

? { _controller.object } - let _participants = - Mutex<[MAPContextIdentifier: Set>>]>([:]) + var _participants: [MAPContextIdentifier: Set>>] = [:] let _logger: Logger let _latencyMaxFrameSize: UInt16 let _queues: [SRclassID: UInt] @@ -174,15 +173,15 @@ public final class MSRPApplication: BaseApplication, BaseApplication fileprivate let _maxFanInPorts: Int fileprivate let _maxSRClass: SRclassID - fileprivate let _portStates = Mutex<[P.ID: MSRPPortState

]>([:]) + fileprivate var _portStates: [P.ID: MSRPPortState

] = [:] fileprivate let _mmrp: MMRPApplication

? fileprivate var _priorityMapNotificationTask: Task<(), Error>? // Convenience accessors for flags - fileprivate var _forceAvbCapable: Bool { _flags.contains(.forceAvbCapable) } - fileprivate var _configureQueues: Bool { _flags.contains(.configureQueues) } - var _ignoreAsCapable: Bool { _flags.contains(.ignoreAsCapable) } - fileprivate var _talkerPruning: Bool { _flags.contains(.talkerPruning) } + fileprivate nonisolated var _forceAvbCapable: Bool { _flags.contains(.forceAvbCapable) } + fileprivate nonisolated var _configureQueues: Bool { _flags.contains(.configureQueues) } + nonisolated var _ignoreAsCapable: Bool { _flags.contains(.ignoreAsCapable) } + fileprivate nonisolated var _talkerPruning: Bool { _flags.contains(.talkerPruning) } public init( controller: MRPController

, @@ -228,12 +227,10 @@ public final class MSRPApplication: BaseApplication, BaseApplication port: P, _ body: (_: inout MSRPPortState

) throws -> T ) throws -> T { - try _portStates.withLock { - if let index = $0.index(forKey: port.id) { - return try body(&$0.values[index]) - } else { - throw MRPError.portNotFound - } + if let index = _portStates.index(forKey: port.id) { + return try body(&_portStates.values[index]) + } else { + throw MRPError.portNotFound } } @@ -276,85 +273,73 @@ public final class MSRPApplication: BaseApplication, BaseApplication } } - try _portStates.withLock { - for port in context { - var portState = try MSRPPortState(msrp: self, port: port) - if let srClassPriorityMap = srClassPriorityMap[port.id] { - portState.srClassPriorityMap = srClassPriorityMap - } - $0[port.id] = portState + for port in context { + var portState = try MSRPPortState(msrp: self, port: port) + if let srClassPriorityMap = srClassPriorityMap[port.id] { + portState.srClassPriorityMap = srClassPriorityMap } + _portStates[port.id] = portState } for port in context { _logger.debug("MSRP: declaring domains for port \(port)") - try await _declareDomains(port: port) + try _declareDomains(port: port) } } func onContextUpdated( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws { + ) async throws { guard contextIdentifier == MAPBaseSpanningTreeContext else { return } if !_forceAvbCapable { - _portStates.withLock { - for port in context { - guard let index = $0.index(forKey: port.id) else { continue } - if $0.values[index].msrpPortEnabledStatus != port.isAvbCapable { - _logger.info("MSRP: port \(port) changed isAvbCapable, now \(port.isAvbCapable)") - } - $0.values[index].msrpPortEnabledStatus = port.isAvbCapable + for port in context { + guard let index = _portStates.index(forKey: port.id) else { continue } + if _portStates.values[index].msrpPortEnabledStatus != port.isAvbCapable { + _logger.info("MSRP: port \(port) changed isAvbCapable, now \(port.isAvbCapable)") } + _portStates.values[index].msrpPortEnabledStatus = port.isAvbCapable } } - Task { - for port in context { - _logger.debug("MSRP: re-declaring domains for port \(port)") - try await _declareDomains(port: port) - } + for port in context { + _logger.debug("MSRP: re-declaring domains for port \(port)") + try _declareDomains(port: port) } } func onContextRemoved( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws { + ) async throws { guard contextIdentifier == MAPBaseSpanningTreeContext else { return } if _configureQueues { - Task { - for port in context { - guard port.isAvbCapable, - let bridge = (controller?.bridge as? any MSRPAwareBridge

) else { continue } - do { - try await bridge.unconfigureQueues(port: port) - } catch { - _logger.error("MSRP: failed to unconfigure queues for port \(port): \(error)") - } + for port in context { + guard port.isAvbCapable, + let bridge = (controller?.bridge as? any MSRPAwareBridge

) else { continue } + do { + try await bridge.unconfigureQueues(port: port) + } catch { + _logger.error("MSRP: failed to unconfigure queues for port \(port): \(error)") } } } - _portStates.withLock { - for port in context { - _logger.debug("MSRP: port \(port) disappeared, removing") - $0.removeValue(forKey: port.id) - } + for port in context { + _logger.debug("MSRP: port \(port) disappeared, removing") + _portStates.removeValue(forKey: port.id) } } - public var description: String { - let participants: String = _participants.withLock { String(describing: $0) } - let portStates: String = _portStates.withLock { String(describing: $0) } - return "MSRPApplication(controller: \(controller?.description ?? ""), participants: \(participants), portStates: \(portStates)" + public nonisolated var description: String { + "MSRPApplication(controller: \(controller!))" } - public var name: String { "MSRP" } + public nonisolated var name: String { "MSRP" } - public func deserialize( + public nonisolated func deserialize( attributeOfType attributeType: AttributeType, from input: inout ParserSpan ) throws -> any Value { @@ -372,7 +357,7 @@ public final class MSRPApplication: BaseApplication, BaseApplication } } - public func makeNullValue(for attributeType: AttributeType) throws -> any Value { + public nonisolated func makeNullValue(for attributeType: AttributeType) throws -> any Value { guard let attributeType = MSRPAttributeType(rawValue: attributeType) else { throw MRPError.unknownAttributeType } switch attributeType { @@ -387,11 +372,11 @@ public final class MSRPApplication: BaseApplication, BaseApplication } } - public func hasAttributeSubtype(for attributeType: AttributeType) -> Bool { + public nonisolated func hasAttributeSubtype(for attributeType: AttributeType) -> Bool { attributeType == MSRPAttributeType.listener.rawValue } - public func administrativeControl(for attributeType: AttributeType) throws + public nonisolated func administrativeControl(for attributeType: AttributeType) throws -> AdministrativeControl { .normalParticipant @@ -411,7 +396,7 @@ public final class MSRPApplication: BaseApplication, BaseApplication priorityAndRank: MSRPPriorityAndRank, accumulatedLatency: UInt32, failureInformation: MSRPFailure? = nil - ) async throws { + ) throws { let attributeValue: any Value switch declarationType { @@ -447,7 +432,7 @@ public final class MSRPApplication: BaseApplication, BaseApplication throw MRPError.invalidMSRPDeclarationType } - try await join( + try join( attributeType: ( failureInformation != nil ? MSRPAttributeType.talkerFailed : MSRPAttributeType .talkerAdvertise @@ -465,8 +450,8 @@ public final class MSRPApplication: BaseApplication, BaseApplication // were in the associated REGISTER_STREAM.request primitive. public func deregisterStream( streamID: MSRPStreamID - ) async throws { - guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else { + ) throws { + guard let talkerRegistration = _findTalkerRegistration(for: streamID) else { throw MRPError.participantNotFound } let declarationType: MSRPDeclarationType = if talkerRegistration.1 is MSRPTalkerAdvertiseValue { @@ -474,7 +459,7 @@ public final class MSRPApplication: BaseApplication, BaseApplication } else { .talkerFailed } - try await leave( + try leave( attributeType: declarationType.attributeType.rawValue, attributeValue: MSRPListenerValue(streamID: streamID), for: MAPBaseSpanningTreeContext @@ -490,10 +475,10 @@ public final class MSRPApplication: BaseApplication, BaseApplication streamID: MSRPStreamID, declarationType: MSRPDeclarationType, on port: P? = nil - ) async throws { - try await apply { participant in + ) throws { + try apply { participant in if let port, port != participant.port { return } - try await join( + try join( attributeType: MSRPAttributeType.listener.rawValue, attributeSubtype: declarationType.attributeSubtype?.rawValue, attributeValue: MSRPListenerValue(streamID: streamID), @@ -511,16 +496,16 @@ public final class MSRPApplication: BaseApplication, BaseApplication public func deregisterAttach( streamID: MSRPStreamID, on port: P? = nil - ) async throws { - try await apply { participant in + ) throws { + try apply { participant in if let port, port != participant.port { return } - guard let listenerRegistration = await _findListenerRegistration( + guard let listenerRegistration = _findListenerRegistration( for: streamID, participant: participant ) else { return } - try await leave( + try leave( attributeType: MSRPAttributeType.listener.rawValue, attributeSubtype: listenerRegistration.1.rawValue, attributeValue: listenerRegistration.0, @@ -542,24 +527,24 @@ extension MSRPApplication { declarationType: MSRPDeclarationType, streamID: MSRPStreamID, eventSource: EventSource - ) async throws { + ) throws { let oppositeType: MSRPAttributeType = declarationType == .talkerAdvertise ? .talkerFailed : .talkerAdvertise - let oppositeAttributes = await participant.findAttributes( + let oppositeAttributes = participant.findAttributes( attributeType: oppositeType.rawValue, matching: .matchAnyIndex(streamID.index) ) for (_, attributeValue) in oppositeAttributes { if eventSource == .map { - try? await participant.leave( + try? participant.leave( attributeType: oppositeType.rawValue, attributeValue: attributeValue, eventSource: eventSource ) } else { - try? await participant.deregister( + try? participant.deregister( attributeType: oppositeType.rawValue, attributeValue: attributeValue, eventSource: eventSource @@ -582,8 +567,8 @@ extension MSRPApplication { guard let portState = try? withPortState(port: port, { $0 }) else { return true } if _talkerPruning || portState.talkerPruning { - if let mmrpParticipant = try? _mmrp?.findParticipant(port: port), - await mmrpParticipant.findAttribute( + if let mmrpParticipant = try? await _mmrp?.findParticipant(port: port), + mmrpParticipant.findAttribute( attributeType: MMRPAttributeType.mac.rawValue, matching: .matchEqual(MMRPMACValue(macAddress: dataFrameParameters.destinationAddress)) ) == nil @@ -599,7 +584,7 @@ extension MSRPApplication { return false } - private func _isFanInPortLimitReached() async -> Bool { + private func _isFanInPortLimitReached() -> Bool { if _maxFanInPorts == 0 { return false } @@ -607,8 +592,8 @@ extension MSRPApplication { var fanInCount = 0 // calculate total number of ports with inbound reservations - await apply { participant in - if await participant.findAttribute( + apply { participant in + if participant.findAttribute( attributeType: MSRPAttributeType.listener.rawValue, matching: .matchAny ) != nil { @@ -707,11 +692,11 @@ extension MSRPApplication { participant: Participant, portState: MSRPPortState

, provisionalTalker: MSRPTalkerAdvertiseValue? = nil - ) async throws -> [SRclassID: Int] { + ) throws -> [SRclassID: Int] { var bandwidthUsed = [SRclassID: Int]() // Find all active talkers (those with listeners in ready or readyFailed state) - var talkers = await _findActiveTalkers(participant: participant) + var talkers = _findActiveTalkers(participant: participant) // Add provisional talker if provided (for bandwidth admission control check) if let provisionalTalker { talkers.insert(provisionalTalker) } @@ -744,7 +729,7 @@ extension MSRPApplication { dataFrameParameters: MSRPDataFrameParameters, tSpec: MSRPTSpec, priorityAndRank: MSRPPriorityAndRank - ) async throws -> Bool { + ) throws -> Bool { let port = participant.port let provisionalTalker = MSRPTalkerAdvertiseValue( streamID: streamID, @@ -754,7 +739,7 @@ extension MSRPApplication { accumulatedLatency: 0 // or this ) - let bandwidthUsed = try await _calculateBandwidthUsed( + let bandwidthUsed = try _calculateBandwidthUsed( participant: participant, portState: portState, provisionalTalker: provisionalTalker @@ -789,7 +774,7 @@ extension MSRPApplication { accumulatedLatency: UInt32, isNew: Bool, eventSource: EventSource - ) async throws { + ) throws { let port = participant.port do { @@ -802,7 +787,7 @@ extension MSRPApplication { throw MSRPFailure(systemID: port.systemID, failureCode: .egressPortIsNotAvbCapable) } - if let existingTalkerRegistration = await _findTalkerRegistration( + if let existingTalkerRegistration = _findTalkerRegistration( for: streamID, participant: participant ), existingTalkerRegistration.dataFrameParameters != dataFrameParameters { @@ -822,7 +807,7 @@ extension MSRPApplication { throw MSRPFailure(systemID: port.systemID, failureCode: .egressPortIsNotAvbCapable) } - guard await !_isFanInPortLimitReached() else { + guard !_isFanInPortLimitReached() else { _logger.error("MSRP: fan in port limit reached") throw MSRPFailure(systemID: port.systemID, failureCode: .fanInPortLimitReached) } @@ -839,7 +824,7 @@ extension MSRPApplication { throw MSRPFailure(systemID: port.systemID, failureCode: .maxFrameSizeTooLargeForMedia) } - guard try await _checkAvailableBandwidth( + guard try _checkAvailableBandwidth( participant: participant, portState: portState, streamID: streamID, @@ -890,7 +875,7 @@ extension MSRPApplication { // Deregister the opposite talker type from the peer to ensure mutual exclusion if eventSource == .peer { let sourceParticipant = try findParticipant(for: contextIdentifier, port: port) - try await _enforceTalkerMutualExclusion( + try _enforceTalkerMutualExclusion( participant: sourceParticipant, declarationType: declarationType, streamID: talkerValue.streamID, @@ -938,7 +923,7 @@ extension MSRPApplication { // Leave the opposite talker declaration type to ensure mutual exclusion // (per spec, only one talker declaration type should exist per stream) - try await _enforceTalkerMutualExclusion( + try _enforceTalkerMutualExclusion( participant: participant, declarationType: declarationType, streamID: talkerValue.streamID, @@ -947,7 +932,7 @@ extension MSRPApplication { if declarationType == .talkerAdvertise { do { - try await _canBridgeTalker( + try _canBridgeTalker( participant: participant, port: port, streamID: talkerValue.streamID, @@ -970,7 +955,7 @@ extension MSRPApplication { .debug( "MSRP: propagating talker advertise \(talkerAdvertise) to port \(port)" ) - try await participant.join( + try participant.join( attributeType: MSRPAttributeType.talkerAdvertise.rawValue, attributeValue: talkerAdvertise, isNew: false, @@ -990,7 +975,7 @@ extension MSRPApplication { .debug( "MSRP: propagating talker failed \(talkerFailed) on port \(port), error \(error)" ) - try await participant.join( + try participant.join( attributeType: MSRPAttributeType.talkerFailed.rawValue, attributeValue: talkerFailed, isNew: true, @@ -1012,7 +997,7 @@ extension MSRPApplication { .debug( "MSRP: propagating talker failed \(talkerFailed) to port \(port), transitive" ) - try await participant.join( + try participant.join( attributeType: MSRPAttributeType.talkerFailed.rawValue, attributeValue: talkerFailed, isNew: false, @@ -1042,7 +1027,7 @@ extension MSRPApplication { // _propagateListenerDeclarationToTalker() will examine all listeners and // return the merged declaration type, so there is no need to do this // within the apply() loop - guard let mergedDeclarationType = try? await _propagateListenerDeclarationToTalker( + guard let mergedDeclarationType = try? _propagateListenerDeclarationToTalker( contextIdentifier: contextIdentifier, listenerPort: nil, declarationType: nil, @@ -1056,7 +1041,7 @@ extension MSRPApplication { // that didn't exist previously, we do need to update port parameters on // each talker port that matches the listener stream ID await apply(for: contextIdentifier) { participant in - guard let listenerRegistration = await _findListenerRegistration( + guard let listenerRegistration = _findListenerRegistration( for: talkerValue.streamID, participant: participant ) else { @@ -1064,7 +1049,7 @@ extension MSRPApplication { } // verify talker still exists (guard against race with talker departure) - guard let currentTalker = await _findTalkerRegistration( + guard let currentTalker = _findTalkerRegistration( for: talkerValue.streamID, participant: talkerParticipant ), currentTalker.streamID == talkerValue.streamID else { @@ -1135,8 +1120,8 @@ extension MSRPApplication { private func _findListenerRegistration( for streamID: MSRPStreamID, participant: Participant - ) async -> (MSRPListenerValue, MSRPAttributeSubtype)? { - guard let listenerAttribute = await participant.findAttribute( + ) -> (MSRPListenerValue, MSRPAttributeSubtype)? { + guard let listenerAttribute = participant.findAttribute( attributeType: MSRPAttributeType.listener.rawValue, matching: .matchAnyIndex(streamID.index) ) else { return nil } @@ -1154,14 +1139,14 @@ extension MSRPApplication { private func _findTalkerRegistration( for streamID: MSRPStreamID, participant: Participant - ) async -> (any MSRPTalkerValue)? { + ) -> (any MSRPTalkerValue)? { // TalkerFailed takes precedence over TalkerAdvertise per spec - if let value = await participant.findAttribute( + if let value = participant.findAttribute( attributeType: MSRPAttributeType.talkerFailed.rawValue, matching: .matchAnyIndex(streamID.index) ) { value.1 as? (any MSRPTalkerValue) - } else if let value = await participant.findAttribute( + } else if let value = participant.findAttribute( attributeType: MSRPAttributeType.talkerAdvertise.rawValue, matching: .matchAnyIndex(streamID.index) ) { @@ -1173,11 +1158,11 @@ extension MSRPApplication { private func _findTalkerRegistration( for streamID: MSRPStreamID - ) async -> TalkerRegistration? { + ) -> TalkerRegistration? { var talkerRegistration: TalkerRegistration? - await apply { participant in - guard let participantTalker = await _findTalkerRegistration( + apply { participant in + guard let participantTalker = _findTalkerRegistration( for: streamID, participant: participant ), talkerRegistration == nil else { @@ -1195,22 +1180,23 @@ extension MSRPApplication { declarationType: MSRPDeclarationType?, talkerRegistration: TalkerRegistration, isJoin: Bool - ) async throws -> MSRPDeclarationType? { + ) throws -> MSRPDeclarationType? { var mergedDeclarationType = isJoin ? declarationType : nil let streamID = talkerRegistration.1.streamID var listenerCount = mergedDeclarationType != nil ? 1 : 0 // collect listener declarations from all other ports and merge declaration type - await apply(for: contextIdentifier) { participant in + apply(for: contextIdentifier) { participant in // exclude registering or leaving port guard participant.port != port else { return } // exclude talker port guard participant.port != talkerRegistration.0.port else { return } - for listenerAttribute in await participant.findAllAttributes( + for listenerAttribute in participant.findAllAttributesUnchecked( attributeType: MSRPAttributeType.listener.rawValue, - matching: .matchAnyIndex(streamID.id) + matching: .matchAnyIndex(streamID.id), + isolation: self ) { guard let declarationType = try? MSRPDeclarationType(attributeSubtype: listenerAttribute .attributeSubtype) @@ -1275,19 +1261,19 @@ extension MSRPApplication { private func _findActiveTalkers( participant: Participant> - ) async -> Set { + ) -> Set { // Find all active talkers by querying listeners on this port and finding their corresponding // talkers - await Set(participant.findAttributes( + Set(participant.findAttributes( attributeType: MSRPAttributeType.listener.rawValue, matching: .matchAny - ).asyncCompactMap { + ).compactMap { guard let attributeSubtype = $0.0, let attributeSubtype = MSRPAttributeSubtype(rawValue: attributeSubtype), attributeSubtype == .ready || attributeSubtype == .readyFailed else { return nil } let listener = $0.1 as! MSRPListenerValue - guard let talkerRegistration = await _findTalkerRegistration(for: listener.streamID), + guard let talkerRegistration = _findTalkerRegistration(for: listener.streamID), let talkerAdvertise = talkerRegistration.1 as? MSRPTalkerAdvertiseValue else { return nil } return talkerAdvertise @@ -1310,7 +1296,7 @@ extension MSRPApplication { return } - var talkers = await _findActiveTalkers(participant: participant) + var talkers = _findActiveTalkers(participant: participant) // Remove the specific talker stream that is the subject of this // registration or deregistration; we will add it back conditionally @@ -1417,7 +1403,7 @@ extension MSRPApplication { isNew: Bool, eventSource: EventSource, talkerRegistration: TalkerRegistration - ) async throws -> MSRPDeclarationType? { + ) throws -> MSRPDeclarationType? { // point-to-point talker registrations should not come from the same port as the listener if let port, port.isPointToPoint, talkerRegistration.0.port == port { _logger @@ -1428,7 +1414,7 @@ extension MSRPApplication { } // TL;DR: propagate merged Listener declarations to _talker_ port - guard let mergedDeclarationType = try await _mergeListenerDeclarations( + guard let mergedDeclarationType = try _mergeListenerDeclarations( contextIdentifier: contextIdentifier, port: port, declarationType: declarationType, @@ -1443,7 +1429,7 @@ extension MSRPApplication { "MSRP: propagating listener declaration streamID \(streamID) declarationType \(declarationType != nil ? String(describing: declarationType!) : "") -> \(mergedDeclarationType) to participant \(talkerRegistration.0)" ) - try await talkerRegistration.0.join( + try talkerRegistration.0.join( attributeType: MSRPAttributeType.listener.rawValue, attributeSubtype: mergedDeclarationType.attributeSubtype!.rawValue, attributeValue: MSRPListenerValue(streamID: streamID), @@ -1467,7 +1453,7 @@ extension MSRPApplication { isNew: Bool, eventSource: EventSource ) async throws { - guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else { + guard let talkerRegistration = _findTalkerRegistration(for: streamID) else { // no listener attribute propagation if no talker (35.2.4.4.1) // this is an expected race condition - listener arrives before talker // when talker arrives, _updateExistingListeners() will process it @@ -1478,7 +1464,7 @@ extension MSRPApplication { return } - guard let mergedDeclarationType = try await _propagateListenerDeclarationToTalker( + guard let mergedDeclarationType = try _propagateListenerDeclarationToTalker( contextIdentifier: contextIdentifier, listenerPort: port, declarationType: declarationType, @@ -1606,11 +1592,11 @@ extension MSRPApplication { guard participant.port != port else { return } // don't propagate to source port // If this participant has active listeners, propagate a leave back to the talker - if let listenerRegistration = await _findListenerRegistration( + if let listenerRegistration = _findListenerRegistration( for: streamID, participant: participant ) { - try await talkerParticipant.leave( + try talkerParticipant.leave( attributeType: MSRPAttributeType.listener.rawValue, attributeSubtype: listenerRegistration.1.rawValue, attributeValue: listenerRegistration.0, @@ -1628,7 +1614,7 @@ extension MSRPApplication { // 35.2.4.3: If no Talker attributes are registered for a StreamID then // no Talker attributes for that StreamID will be declared on any other // port of the Bridge. i.e. implement as ordinary attribute propagation - try await participant.leave( + try participant.leave( attributeType: talkerValue.declarationType!.attributeType.rawValue, attributeSubtype: nil, attributeValue: talkerValue, @@ -1651,12 +1637,12 @@ extension MSRPApplication { // StreamID of the Declaration matches a Stream that the Talker is // transmitting, then the Talker shall stop the transmission for this // Stream, if it is transmitting. - guard let talkerRegistration = await _findTalkerRegistration(for: streamID) else { + guard let talkerRegistration = _findTalkerRegistration(for: streamID) else { return } // TL;DR: propagate merged Listener declarations to _talker_ port - let mergedDeclarationType = try await _mergeListenerDeclarations( + let mergedDeclarationType = try _mergeListenerDeclarations( contextIdentifier: contextIdentifier, port: port, declarationType: declarationType, @@ -1670,7 +1656,7 @@ extension MSRPApplication { ) if let mergedDeclarationType { - try await talkerRegistration.0.join( + try talkerRegistration.0.join( attributeType: MSRPAttributeType.listener.rawValue, attributeSubtype: mergedDeclarationType.attributeSubtype!.rawValue, attributeValue: MSRPListenerValue(streamID: streamID), @@ -1678,7 +1664,7 @@ extension MSRPApplication { eventSource: .map ) } else { - try await talkerRegistration.0.leave( + try talkerRegistration.0.leave( attributeType: MSRPAttributeType.listener.rawValue, attributeValue: MSRPListenerValue(streamID: streamID), eventSource: .map @@ -1756,7 +1742,7 @@ extension MSRPApplication { private func _declareDomain( srClassID: SRclassID, on participant: Participant - ) async throws { + ) throws { var domain: MSRPDomainValue? domain = try withPortState(port: participant.port) { portState in @@ -1765,7 +1751,7 @@ extension MSRPApplication { if let domain { _logger.info("MSRP: declaring domain \(domain)") - try await participant.join( + try participant.join( attributeType: MSRPAttributeType.domain.rawValue, attributeValue: domain, isNew: true, @@ -1779,14 +1765,14 @@ extension MSRPApplication { } } - fileprivate var _allSRClassIDs: [SRclassID] { + fileprivate nonisolated var _allSRClassIDs: [SRclassID] { Array((_maxSRClass.rawValue...SRclassID.A.rawValue).map { SRclassID(rawValue: $0)! }) } - private func _declareDomains(port: P) async throws { + private func _declareDomains(port: P) throws { let participant = try findParticipant(port: port) for srClassID in _allSRClassIDs { - try await _declareDomain(srClassID: srClassID, on: participant) + try _declareDomain(srClassID: srClassID, on: participant) } } @@ -1794,13 +1780,13 @@ extension MSRPApplication { get async { var numberOfTalkerAttributes = 0 - await apply { participant in - numberOfTalkerAttributes += await participant.findAttributes( + apply { participant in + numberOfTalkerAttributes += participant.findAttributes( attributeType: MSRPAttributeType.talkerAdvertise.rawValue, matching: .matchAny ).count - numberOfTalkerAttributes += await participant.findAttributes( + numberOfTalkerAttributes += participant.findAttributes( attributeType: MSRPAttributeType.talkerFailed.rawValue, matching: .matchAny ).count diff --git a/Sources/MRP/Applications/MSRP/MSRPHandler.swift b/Sources/MRP/Applications/MSRP/MSRPHandler.swift index c214217e..0a81f14b 100644 --- a/Sources/MRP/Applications/MSRP/MSRPHandler.swift +++ b/Sources/MRP/Applications/MSRP/MSRPHandler.swift @@ -133,9 +133,9 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { participant: Participant, srClassID: SRclassID, streams: [Stream] - ) { + ) async { deltaBandwidth = application._deltaBandwidths[srClassID] ?? 0 - guard let portState = try? application.withPortState(port: participant.port, { $0 }) + guard let portState = try? await application.withPortState(port: participant.port, { $0 }) else { return nil } guard let domain = portState.getDomain(for: srClassID, defaultSRPVid: application._srPVid) else { return nil } @@ -176,7 +176,10 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { let type: String let streamAge: UInt32 - fileprivate init(participant: Participant, attributeValue: AttributeValue) { + fileprivate init( + participant: Participant, + attributeValue: AttributeValue + ) async { portNumber = participant.port.id portName = participant.port.name @@ -200,7 +203,7 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { let streamID = (attributeValue.attributeValue as! MSRPListenerValue).streamID streamAge = if let application = participant.application { - (try? application + await (try? application .withPortState(port: participant.port) { $0.getStreamAge(for: streamID) }) ?? 0 } else { 0 @@ -223,13 +226,13 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { attributeValue: AttributeValue ) async throws { let talker = attributeValue.attributeValue as! any MSRPTalkerValue - let portState = try application.withPortState(port: participant.port) { $0 } + let portState = try await application.withPortState(port: participant.port) { $0 } streamID = talker.streamID.streamIDString vid = talker.dataFrameParameters.vlanIdentifier.vid priority = talker.priorityAndRank.dataFramePriority.rawValue if let talker = talker as? MSRPTalkerAdvertiseValue { - bandwidth = try application._calculateBandwidthUsed( + bandwidth = try await application._calculateBandwidthUsed( portState: portState, talker: talker, nominalBandwidth: false @@ -262,15 +265,15 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { let port = participant.port let streams = await application._getStreams() - enabled = (try? application + enabled = await (try? application .withPortState(port: participant.port) { $0.msrpPortEnabledStatus }) ?? false listener = await participant._getListeners() talker = await participant._getTalkers() talkerFailed = await participant._getTalkersFailed() let activeStreamIDs = Set(listener.map(\.streamID)) let activeStreams = streams.filter { activeStreamIDs.contains($0.streamID) } - srClass = SRclassID.allCases.reduce(into: [:]) { dict, classID in - if let srClassInstance = SRClass( + srClass = await SRclassID.allCases.asyncReduce(into: [:]) { dict, classID in + if let srClassInstance = await SRClass( application: application, participant: participant, srClassID: classID, @@ -364,7 +367,7 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { func getListenerByStreamID(_ request: HTTPRequest) async throws -> Listener { guard let application, let (port, streamID) = await application._getPortAndStream(request), - let participant = try? application.findParticipant(port: port) + let participant = try? await application.findParticipant(port: port) else { throw HTTPUnhandledError() } @@ -531,7 +534,7 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { func getTalker(_ request: HTTPRequest) async throws -> Array { guard let application, let port = await controller?.getPort(request), - let participant = try? application.findParticipant(port: port.0) + let participant = try? await application.findParticipant(port: port.0) else { throw HTTPUnhandledError() } @@ -546,7 +549,7 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { func getTalkerByStreamID(_ request: HTTPRequest) async throws -> Talker { guard let application, let (port, streamID) = await application._getPortAndStream(request), - let participant = try? application.findParticipant(port: port) + let participant = try? await application.findParticipant(port: port) else { throw HTTPUnhandledError() } @@ -565,7 +568,7 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { func getTalkerFailed(_ request: HTTPRequest) async throws -> Array { guard let application, let port = await controller?.getPort(request), - let participant = try? application.findParticipant(port: port.0) + let participant = try? await application.findParticipant(port: port.0) else { throw HTTPUnhandledError() } @@ -580,7 +583,7 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { func getTalkerFailedByStreamID(_ request: HTTPRequest) async throws -> TalkerFailed { guard let application, let (port, streamID) = await application._getPortAndStream(request), - let participant = try? application.findParticipant(port: port) + let participant = try? await application.findParticipant(port: port) else { throw HTTPUnhandledError() } @@ -816,7 +819,7 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { func getStreamListenerByStreamID(_ request: HTTPRequest) async throws -> StreamListener { guard let application, let (streamID, port) = await application._getStreamAndPort(request), - let participant = try? application.findParticipant(port: port) + let participant = try? await application.findParticipant(port: port) else { throw HTTPUnhandledError() } @@ -830,9 +833,9 @@ struct MSRPHandler: Sendable, RestApiApplicationHandler { } fileprivate extension MSRPApplication { - func _getTransmitRate(for participant: Participant) async throws -> Int { + func _getTransmitRate(for participant: Participant) throws -> Int { let portState = try withPortState(port: participant.port) { $0 } - let bandwidthUsed = try await _calculateBandwidthUsed( + let bandwidthUsed = try _calculateBandwidthUsed( participant: participant, portState: portState ) @@ -878,15 +881,15 @@ fileprivate extension MSRPApplication { } fileprivate extension Participant where A.P: AVBPort { - func _getListeners() -> [MSRPHandler.Listener] { - findAllAttributes( + func _getListeners() async -> [MSRPHandler.Listener] { + await findAllAttributes( attributeType: MSRPAttributeType.listener.rawValue, matching: .matchAny ).map { MSRPHandler.Listener(attributeValue: $0) } } - func _getListener(streamID: MSRPStreamID) -> MSRPHandler.Listener? { - findAllAttributes( + func _getListener(streamID: MSRPStreamID) async -> MSRPHandler.Listener? { + await findAllAttributes( attributeType: MSRPAttributeType.listener.rawValue, matching: .matchIndex(streamID) ).map { MSRPHandler.Listener(attributeValue: $0) } @@ -895,15 +898,15 @@ fileprivate extension Participant where A.P: AVBPort { } fileprivate extension Participant where A.P: AVBPort { - func _getTalkers() -> [MSRPHandler.Talker] { - findAllAttributes( + func _getTalkers() async -> [MSRPHandler.Talker] { + await findAllAttributes( attributeType: MSRPAttributeType.talkerAdvertise.rawValue, matching: .matchAny ).map { MSRPHandler.Talker(attributeValue: $0) } } - func _getTalker(streamID: MSRPStreamID) -> MSRPHandler.Talker? { - findAllAttributes( + func _getTalker(streamID: MSRPStreamID) async -> MSRPHandler.Talker? { + await findAllAttributes( attributeType: MSRPAttributeType.talkerAdvertise.rawValue, matching: .matchIndex(streamID) ) @@ -913,15 +916,15 @@ fileprivate extension Participant where A.P: AVBPort { } fileprivate extension Participant where A.P: AVBPort { - func _getTalkersFailed() -> [MSRPHandler.TalkerFailed] { - findAllAttributes( + func _getTalkersFailed() async -> [MSRPHandler.TalkerFailed] { + await findAllAttributes( attributeType: MSRPAttributeType.talkerFailed.rawValue, matching: .matchAny ).map { MSRPHandler.TalkerFailed(attributeValue: $0) } } - func _getTalkerFailed(streamID: MSRPStreamID) -> MSRPHandler.TalkerFailed? { - findAllAttributes( + func _getTalkerFailed(streamID: MSRPStreamID) async -> MSRPHandler.TalkerFailed? { + await findAllAttributes( attributeType: MSRPAttributeType.talkerFailed.rawValue, matching: .matchIndex(streamID) ) @@ -931,16 +934,15 @@ fileprivate extension Participant where A.P: AVBPort { } fileprivate extension Participant where A.P: AVBPort { - func _getStreamListener(streamID: MSRPStreamID) -> MSRPHandler.Stream.Listener? { - findAllAttributes( + func _getStreamListener(streamID: MSRPStreamID) async -> MSRPHandler.Stream.Listener? { + await findAllAttributes( attributeType: MSRPAttributeType.listener.rawValue, matching: .matchIndex(streamID) - ) - .filter(\.isRegistered) - .map { MSRPHandler.Stream.Listener( - participant: self as! Participant>, - attributeValue: $0 - ) }.first + ).filter(\.isRegistered) + .asyncMap { await MSRPHandler.Stream.Listener( + participant: self as! Participant>, + attributeValue: $0 + ) }.first } } diff --git a/Sources/MRP/Applications/MVRP/MVRPApplication.swift b/Sources/MRP/Applications/MVRP/MVRPApplication.swift index 78cb74c5..e2dd9e0b 100644 --- a/Sources/MRP/Applications/MVRP/MVRPApplication.swift +++ b/Sources/MRP/Applications/MVRP/MVRPApplication.swift @@ -32,34 +32,32 @@ protocol MVRPAwareBridge

: Bridge where P: Port { func deregister(vlan: VLAN, from port: P) async throws } -public final class MVRPApplication: BaseApplication, BaseApplicationEventObserver, - BaseApplicationContextObserver, CustomStringConvertible, - Sendable where P == P +public actor MVRPApplication: BaseApplication, BaseApplicationEventObserver, Sendable, + BaseApplicationContextObserver, CustomStringConvertible where P == P { // for now, we only operate in the Base Spanning Tree Context - public var nonBaseContextsSupported: Bool { false } + public nonisolated var nonBaseContextsSupported: Bool { false } - public var validAttributeTypes: ClosedRange { + public nonisolated var validAttributeTypes: ClosedRange { MVRPAttributeType.validAttributeTypes } // 10.12.1.3 MVRP application address - public var groupAddress: EUI48 { CustomerBridgeMRPGroupAddress } + public nonisolated var groupAddress: EUI48 { CustomerBridgeMRPGroupAddress } // 10.12.1.4 MVRP application EtherType - public var etherType: UInt16 { MVRPEtherType } + public nonisolated var etherType: UInt16 { MVRPEtherType } // 10.12.1.5 MVRP ProtocolVersion - public var protocolVersion: ProtocolVersion { 0 } + public nonisolated var protocolVersion: ProtocolVersion { 0 } - public var hasAttributeListLength: Bool { false } + public nonisolated var hasAttributeListLength: Bool { false } let _controller: Weak> - public var controller: MRPController

? { _controller.object } + public nonisolated var controller: MRPController

? { _controller.object } - let _participants = - Mutex<[MAPContextIdentifier: Set>>]>([:]) + var _participants: [MAPContextIdentifier: Set>>] = [:] let _logger: Logger let _vlanExclusions: Set @@ -70,14 +68,13 @@ public final class MVRPApplication: BaseApplication, BaseApplicationEve try await controller.register(application: self) } - public var description: String { - let participants: String = _participants.withLock { String(describing: $0) } - return "MVRPApplication(controller: \(controller!), vlanExclusions: \(_vlanExclusions), participants: \(participants))" + public nonisolated var description: String { + "MVRPApplication(controller: \(controller!))" } - public var name: String { "MVRP" } + public nonisolated var name: String { "MVRP" } - public func deserialize( + public nonisolated func deserialize( attributeOfType attributeType: AttributeType, from input: inout ParserSpan ) throws -> any Value { @@ -89,7 +86,7 @@ public final class MVRPApplication: BaseApplication, BaseApplicationEve } } - public func makeNullValue(for attributeType: AttributeType) throws -> any Value { + public nonisolated func makeNullValue(for attributeType: AttributeType) throws -> any Value { guard let attributeType = MVRPAttributeType(rawValue: attributeType) else { throw MRPError.unknownAttributeType } switch attributeType { @@ -98,11 +95,11 @@ public final class MVRPApplication: BaseApplication, BaseApplicationEve } } - public func hasAttributeSubtype(for: AttributeType) -> Bool { + public nonisolated func hasAttributeSubtype(for: AttributeType) -> Bool { false } - public func administrativeControl(for attributeType: AttributeType) throws + public nonisolated func administrativeControl(for attributeType: AttributeType) throws -> AdministrativeControl { .normalParticipant @@ -114,8 +111,8 @@ public final class MVRPApplication: BaseApplication, BaseApplicationEve // Vector Attribute Type (11.2.3.1.6) and the attribute_value parameter // carries the value of the VID parameter carried in the // ES_REGISTER_VLAN_MEMBER primitive. - public func register(vlanMember: VLAN) async throws { - try await join( + public func register(vlanMember: VLAN) throws { + try join( attributeType: MVRPAttributeType.vid.rawValue, attributeValue: vlanMember, isNew: false, @@ -129,8 +126,8 @@ public final class MVRPApplication: BaseApplication, BaseApplicationEve // Vector Attribute Type (11.2.3.1.6) and the attribute_value parameter // carries the value of the VID parameter carried in the // ES_DEREGISTER_VLAN_MEMBER primitive. - public func deregister(vlanMember: VLAN) async throws { - try await leave( + public func deregister(vlanMember: VLAN) throws { + try leave( attributeType: MVRPAttributeType.vid.rawValue, attributeValue: vlanMember, for: MAPBaseSpanningTreeContext @@ -140,8 +137,8 @@ public final class MVRPApplication: BaseApplication, BaseApplicationEve public func periodic(for contextIdentifier: MAPContextIdentifier?) async throws { // 5.4.4 the Periodic Transmission state machine (10.7.10) is specifically // excluded from MSRP - try await apply(for: contextIdentifier) { participant in - try await participant.periodic() + try apply(for: contextIdentifier) { participant in + try participant.periodic() } } } @@ -221,7 +218,7 @@ extension MVRPApplication { ) async throws { guard let bridge = controller?.bridge as? any MVRPAwareBridge

, !bridge.hasLocalMVRPApplicant else { return } - try await join( + try join( attributeType: MVRPAttributeType.vid.rawValue, attributeValue: VLAN(contextIdentifier: contextIdentifier), isNew: true, @@ -232,21 +229,19 @@ extension MVRPApplication { func onContextUpdated( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws {} + ) async throws {} func onContextRemoved( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws { + ) async throws { guard let bridge = controller?.bridge as? any MVRPAwareBridge

, !bridge.hasLocalMVRPApplicant else { return } - Task { - try await leave( - attributeType: MVRPAttributeType.vid.rawValue, - attributeValue: VLAN(contextIdentifier: contextIdentifier), - for: MAPBaseSpanningTreeContext - ) - } + try leave( + attributeType: MVRPAttributeType.vid.rawValue, + attributeValue: VLAN(contextIdentifier: contextIdentifier), + for: MAPBaseSpanningTreeContext + ) } } diff --git a/Sources/MRP/Base/MRPController.swift b/Sources/MRP/Base/MRPController.swift index 595b2ae2..60b5fb06 100644 --- a/Sources/MRP/Base/MRPController.swift +++ b/Sources/MRP/Base/MRPController.swift @@ -116,7 +116,7 @@ public actor MRPController: Service, CustomStringConvertible, Sendable #endif try? await bridge.shutdown(controller: self) for port in ports { - try? _didRemove(port: port) + try? await _didRemove(port: port) } } @@ -206,13 +206,13 @@ public actor MRPController: Service, CustomStringConvertible, Sendable ) for contextIdentifier in removedContextIdentifiers { - try _didRemove(contextIdentifier: contextIdentifier, with: [port]) + try await _didRemove(contextIdentifier: contextIdentifier, with: [port]) } for contextIdentifier in updatedContextIdentifiers .union(isNewPort ? [] : [MAPBaseSpanningTreeContext]) { - try _didUpdate(contextIdentifier: contextIdentifier, with: [port]) + try await _didUpdate(contextIdentifier: contextIdentifier, with: [port]) } for contextIdentifier in addedContextIdentifiers @@ -222,7 +222,7 @@ public actor MRPController: Service, CustomStringConvertible, Sendable } } - private func _applyContextIdentifierChanges(beforeRemoving port: P) throws { + private func _applyContextIdentifierChanges(beforeRemoving port: P) async throws { let removedContextIdentifiers: Set guard let existingPort = ports.first(where: { $0.id == port.id }) else { return } @@ -234,7 +234,7 @@ public actor MRPController: Service, CustomStringConvertible, Sendable ) for contextIdentifier in [MAPBaseSpanningTreeContext] + removedContextIdentifiers { - try _didRemove(contextIdentifier: contextIdentifier, with: [port]) + try await _didRemove(contextIdentifier: contextIdentifier, with: [port]) } } @@ -247,10 +247,10 @@ public actor MRPController: Service, CustomStringConvertible, Sendable _ports[port.id] = port } - private func _didRemove(port: P) throws { + private func _didRemove(port: P) async throws { logger.debug("removed port \(port.id): \(port)") - try _applyContextIdentifierChanges(beforeRemoving: port) + try await _applyContextIdentifierChanges(beforeRemoving: port) _ports[port.id] = nil if timerConfiguration.periodicTime != .zero { _stopPeriodicTimer() } @@ -271,7 +271,7 @@ public actor MRPController: Service, CustomStringConvertible, Sendable case let .added(port): try await ports.contains(port) ? _didUpdate(port: port) : _didAdd(port: port) case let .removed(port): - try _didRemove(port: port) + try await _didRemove(port: port) case let .changed(port): try await _didUpdate(port: port) } @@ -394,18 +394,18 @@ public actor MRPController: Service, CustomStringConvertible, Sendable private func _didUpdate( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws { + ) async throws { for application in _applications.values { - try application.didUpdate(contextIdentifier: contextIdentifier, with: context) + try await application.didUpdate(contextIdentifier: contextIdentifier, with: context) } } private func _didRemove( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws { + ) async throws { for application in _applications.values { - try application.didRemove(contextIdentifier: contextIdentifier, with: context) + try await application.didRemove(contextIdentifier: contextIdentifier, with: context) } } diff --git a/Sources/MRP/Base/Utility.swift b/Sources/MRP/Base/Utility.swift index 64f58780..7c7b67f4 100644 --- a/Sources/MRP/Base/Utility.swift +++ b/Sources/MRP/Base/Utility.swift @@ -30,14 +30,6 @@ extension Weak: Equatable where T: Equatable { } } -// https://stackoverflow.com/questions/25329186/safe-bounds-checked-array-lookup-in-swift-through-optional-bindings -extension Collection { - /// Returns the element at the specified index if it is within bounds, otherwise nil. - subscript(safe index: Index) -> Element? { - indices.contains(index) ? self[index] : nil - } -} - extension Array { /// Creates an array from a collection, padding to a multiple of the specified value. init(_ collection: some Collection, multiple: Int, with element: Element) { @@ -61,18 +53,15 @@ public extension Sequence { return values } - func asyncCompactMap( - _ transform: (Element) async throws -> T? - ) async rethrows -> [T] { - var values = [T]() - + func asyncReduce( + into initialResult: Result, + _ updateAccumulatingResult: (inout Result, Element) async throws -> () + ) async rethrows -> Result { + var result = initialResult for element in self { - if let transformed = try await transform(element) { - values.append(transformed) - } + try await updateAccumulatingResult(&result, element) } - - return values + return result } } diff --git a/Sources/MRP/Model/Application.swift b/Sources/MRP/Model/Application.swift index e1866896..b674cb55 100644 --- a/Sources/MRP/Model/Application.swift +++ b/Sources/MRP/Model/Application.swift @@ -1,5 +1,5 @@ // -// Copyright (c) 2024 PADL Software Pty Ltd +// Copyright (c) 2024-2026 PADL Software Pty Ltd // // Licensed under the Apache License, Version 2.0 (the License); // you may not use this file except in compliance with the License. @@ -26,27 +26,27 @@ public enum AdministrativeControl { public typealias AttributeSubtype = UInt8 -public protocol Application

: AnyObject, Equatable, Hashable, Sendable { +public protocol Application

: Actor, Equatable, Hashable, Sendable { associatedtype P: Port typealias ApplyFunction = (Participant) throws -> T typealias AsyncApplyFunction = (Participant) async throws -> T - var validAttributeTypes: ClosedRange { get } - var groupAddress: EUI48 { get } - var etherType: UInt16 { get } - var protocolVersion: ProtocolVersion { get } - var hasAttributeListLength: Bool { get } - var name: String { get } + nonisolated var validAttributeTypes: ClosedRange { get } + nonisolated var groupAddress: EUI48 { get } + nonisolated var etherType: UInt16 { get } + nonisolated var protocolVersion: ProtocolVersion { get } + nonisolated var hasAttributeListLength: Bool { get } + nonisolated var name: String { get } - var controller: MRPController

? { get } + nonisolated var controller: MRPController

? { get } // notifications from controller when a port is added, didUpdated or removed // if contextIdentifier is MAPBaseSpanningTreeContext, the ports are physical // ports on the bridge; otherwise, they are virtual ports managed by MVRP. func didAdd(contextIdentifier: MAPContextIdentifier, with context: MAPContext

) async throws - func didUpdate(contextIdentifier: MAPContextIdentifier, with context: MAPContext

) throws - func didRemove(contextIdentifier: MAPContextIdentifier, with context: MAPContext

) throws + func didUpdate(contextIdentifier: MAPContextIdentifier, with context: MAPContext

) async throws + func didRemove(contextIdentifier: MAPContextIdentifier, with context: MAPContext

) async throws // apply for all participants. if contextIdentifier is nil, then all participants are called // regardless of contextIdentifier. @@ -62,12 +62,12 @@ public protocol Application

: AnyObject, Equatable, Hashable, Sendable { _ block: AsyncApplyFunction ) async rethrows -> [T] - func hasAttributeSubtype(for: AttributeType) -> Bool - func administrativeControl(for: AttributeType) throws -> AdministrativeControl - var nonBaseContextsSupported: Bool { get } + nonisolated func hasAttributeSubtype(for: AttributeType) -> Bool + nonisolated func administrativeControl(for: AttributeType) throws -> AdministrativeControl + nonisolated var nonBaseContextsSupported: Bool { get } - func makeNullValue(for attributeType: AttributeType) throws -> any Value - func deserialize( + nonisolated func makeNullValue(for attributeType: AttributeType) throws -> any Value + nonisolated func deserialize( attributeOfType attributeType: AttributeType, from input: inout ParserSpan ) throws -> any Value @@ -95,11 +95,11 @@ public protocol Application

: AnyObject, Equatable, Hashable, Sendable { } public extension Application { - func hash(into hasher: inout Hasher) { + nonisolated func hash(into hasher: inout Hasher) { etherType.hash(into: &hasher) } - static func == (lhs: Self, rhs: Self) -> Bool { + nonisolated static func == (lhs: Self, rhs: Self) -> Bool { lhs.etherType == rhs.etherType } } @@ -118,16 +118,6 @@ extension Application { } } - private func apply( - for contextIdentifier: MAPContextIdentifier, - with arg: T, - _ block: AsyncParticipantSpecificApplyFunction - ) async throws { - try await apply(for: contextIdentifier) { participant in - try await block(participant)(arg) - } - } - func findParticipants(for contextIdentifier: MAPContextIdentifier? = nil) -> [Participant] { @@ -152,9 +142,9 @@ extension Application { attributeValue: some Value, isNew: Bool, for contextIdentifier: MAPContextIdentifier - ) async throws { - try await apply(for: contextIdentifier) { participant in - try await participant.join( + ) throws { + try apply(for: contextIdentifier) { participant in + try participant.join( attributeType: attributeType, attributeSubtype: attributeSubtype, attributeValue: attributeValue, @@ -169,9 +159,9 @@ extension Application { attributeSubtype: AttributeSubtype? = nil, attributeValue: some Value, for contextIdentifier: MAPContextIdentifier - ) async throws { - try await apply(for: contextIdentifier) { participant in - try await participant.leave( + ) throws { + try apply(for: contextIdentifier) { participant in + try participant.leave( attributeType: attributeType, attributeSubtype: attributeSubtype, attributeValue: attributeValue, @@ -180,11 +170,11 @@ extension Application { } } - func rx(packet: IEEE802Packet, from port: P) async throws { + func rx(packet: IEEE802Packet, from port: P) throws { let pdu = try packet.payload.withParserSpan { input in try MRPDU(parsing: &input, application: self) } - try await rx( + try rx( pdu: pdu, for: MAPContextIdentifier(packet: packet), from: port, @@ -197,16 +187,16 @@ extension Application { for contextIdentifier: MAPContextIdentifier, from port: P, sourceMacAddress: EUI48 - ) async throws { + ) throws { let participant = try findParticipant(for: contextIdentifier, port: port) - try await participant.rx(pdu: pdu, sourceMacAddress: sourceMacAddress) + try participant.rx(pdu: pdu, sourceMacAddress: sourceMacAddress) } - func flush(for contextIdentifier: MAPContextIdentifier) async throws { - try await apply(for: contextIdentifier) { try await $0.flush() } + func flush(for contextIdentifier: MAPContextIdentifier) throws { + try apply(for: contextIdentifier) { try $0.flush() } } - func redeclare(for contextIdentifier: MAPContextIdentifier) async throws { - try await apply(for: contextIdentifier) { try await $0.redeclare() } + func redeclare(for contextIdentifier: MAPContextIdentifier) throws { + try apply(for: contextIdentifier) { try $0.redeclare() } } } diff --git a/Sources/MRP/Model/BaseApplication.swift b/Sources/MRP/Model/BaseApplication.swift index c4ffb290..60c17a83 100644 --- a/Sources/MRP/Model/BaseApplication.swift +++ b/Sources/MRP/Model/BaseApplication.swift @@ -21,7 +21,7 @@ protocol BaseApplication: Application where P == P { typealias MAPParticipantDictionary = [MAPContextIdentifier: Set>] var _controller: Weak> { get } - var _participants: Mutex { get } + var _participants: MAPParticipantDictionary { get set } } protocol BaseApplicationContextObserver

: BaseApplication { @@ -29,8 +29,14 @@ protocol BaseApplicationContextObserver

: BaseApplication { contextIdentifier: MAPContextIdentifier, with context: MAPContext

) async throws - func onContextUpdated(contextIdentifier: MAPContextIdentifier, with context: MAPContext

) throws - func onContextRemoved(contextIdentifier: MAPContextIdentifier, with context: MAPContext

) throws + func onContextUpdated( + contextIdentifier: MAPContextIdentifier, + with context: MAPContext

+ ) async throws + func onContextRemoved( + contextIdentifier: MAPContextIdentifier, + with context: MAPContext

+ ) async throws } protocol BaseApplicationEventObserver

: BaseApplication { @@ -61,12 +67,10 @@ extension BaseApplication { nonBaseContextsSupported || participant .contextIdentifier == MAPBaseSpanningTreeContext ) - _participants.withLock { - if let index = $0.index(forKey: participant.contextIdentifier) { - $0.values[index].insert(participant) - } else { - $0[participant.contextIdentifier] = Set([participant]) - } + if let index = _participants.index(forKey: participant.contextIdentifier) { + _participants.values[index].insert(participant) + } else { + _participants[participant.contextIdentifier] = Set([participant]) } } @@ -77,9 +81,7 @@ extension BaseApplication { nonBaseContextsSupported || participant .contextIdentifier == MAPBaseSpanningTreeContext ) - _ = _participants.withLock { - $0[participant.contextIdentifier]?.remove(participant) - } + _participants[participant.contextIdentifier]?.remove(participant) } @discardableResult @@ -87,13 +89,10 @@ extension BaseApplication { for contextIdentifier: MAPContextIdentifier? = nil, _ block: AsyncApplyFunction ) async rethrows -> [T] { - var participants: Set>? - _participants.withLock { - if let contextIdentifier { - participants = $0[contextIdentifier] - } else { - participants = Set($0.flatMap { Array($1) }) - } + let participants: Set>? = if let contextIdentifier { + _participants[contextIdentifier] + } else { + Set(_participants.flatMap { Array($1) }) } var ret = [T]() if let participants { @@ -109,13 +108,10 @@ extension BaseApplication { for contextIdentifier: MAPContextIdentifier? = nil, _ block: ApplyFunction ) rethrows -> [T] { - var participants: Set>? - _participants.withLock { - if let contextIdentifier { - participants = $0[contextIdentifier] - } else { - participants = Set($0.flatMap { Array($1) }) - } + let participants: Set>? = if let contextIdentifier { + _participants[contextIdentifier] + } else { + Set(_participants.flatMap { Array($1) }) } var ret = [T]() if let participants { @@ -143,7 +139,7 @@ extension BaseApplication { throw MRPError.portAlreadyExists } guard let controller else { throw MRPError.internalError } - let participant = await Participant( + let participant = Participant( controller: controller, application: self, port: port, @@ -164,32 +160,32 @@ extension BaseApplication { public func didUpdate( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws { + ) async throws { if _isParticipantValid(contextIdentifier: contextIdentifier) { for port in context { let participant = try findParticipant( for: contextIdentifier, port: port ) - Task { try await participant.redeclare() } + try participant.redeclare() } } // also call this regardless of the value of nonBaseContextsSupported, so that // MVRP can be advised of VLAN changes on a port if let observer = self as? any BaseApplicationContextObserver

{ - try observer.onContextUpdated(contextIdentifier: contextIdentifier, with: context) + try await observer.onContextUpdated(contextIdentifier: contextIdentifier, with: context) } } public func didRemove( contextIdentifier: MAPContextIdentifier, with context: MAPContext

- ) throws { + ) async throws { // call observer _before_ removing participants so it can do any other cleanup // also call this regardless of the value of nonBaseContextsSupported, so that // MVRP can be advised of VLAN changes on a port if let observer = self as? any BaseApplicationContextObserver

{ - try observer.onContextRemoved(contextIdentifier: contextIdentifier, with: context) + try await observer.onContextRemoved(contextIdentifier: contextIdentifier, with: context) } if _isParticipantValid(contextIdentifier: contextIdentifier) { for port in context { @@ -197,7 +193,7 @@ extension BaseApplication { for: contextIdentifier, port: port ) - Task { try await participant.flush() } + try participant.flush() try remove(participant: participant) } } @@ -236,11 +232,11 @@ extension BaseApplication { attributeValue: some Value, isNew: Bool, eventSource: EventSource - ) async throws { + ) throws { guard shouldPropagate(eventSource: eventSource) else { return } - try await apply(for: contextIdentifier) { participant in + try apply(for: contextIdentifier) { participant in guard participant.port != port else { return } - try await participant.join( + try participant.join( attributeType: attributeType, attributeSubtype: attributeSubtype, attributeValue: attributeValue, @@ -275,7 +271,7 @@ extension BaseApplication { } catch MRPError.doNotPropagateAttribute { return } - try await _propagateJoinIndicated( + try _propagateJoinIndicated( contextIdentifier: contextIdentifier, port: port, attributeType: attributeType, @@ -293,11 +289,11 @@ extension BaseApplication { attributeSubtype: AttributeSubtype?, attributeValue: some Value, eventSource: EventSource - ) async throws { + ) throws { guard shouldPropagate(eventSource: eventSource) else { return } - try await apply(for: contextIdentifier) { participant in + try apply(for: contextIdentifier) { participant in guard participant.port != port else { return } - try await participant.leave( + try participant.leave( attributeType: attributeType, attributeSubtype: attributeSubtype, attributeValue: attributeValue, @@ -329,7 +325,7 @@ extension BaseApplication { } catch MRPError.doNotPropagateAttribute { return } - try await _propagateLeaveIndicated( + try _propagateLeaveIndicated( contextIdentifier: contextIdentifier, port: port, attributeType: attributeType, diff --git a/Sources/MRP/Model/Participant.swift b/Sources/MRP/Model/Participant.swift index 8c51b8eb..1297b5db 100644 --- a/Sources/MRP/Model/Participant.swift +++ b/Sources/MRP/Model/Participant.swift @@ -100,7 +100,9 @@ private enum EnqueuedEvent: Equatable, CustomStringConvertible { } } -public final actor Participant: Equatable, Hashable, CustomStringConvertible { +public final class Participant: Equatable, Hashable, CustomStringConvertible, + @unchecked Sendable +{ public static func == (lhs: Participant, rhs: Participant) -> Bool { lhs.application == rhs.application && lhs.port == rhs.port && lhs.contextIdentifier == rhs .contextIdentifier @@ -116,10 +118,11 @@ public final actor Participant: Equatable, Hashable, CustomStrin private var _attributes = [AttributeType: Set<_AttributeValue>]() private var _enqueuedEvents = EnqueuedEvents() - private var _leaveAll: LeaveAll! - private var _jointimer: Timer! + private nonisolated(unsafe) var _leaveAll: LeaveAll! + private nonisolated(unsafe) var _jointimer: Timer! private var _rxInProgress = false private var _transmissionOpportunityTimestamps: [ContinuousClock.Instant] = [] + private nonisolated let _controller: Weak> private nonisolated let _application: Weak @@ -137,7 +140,7 @@ public final actor Participant: Equatable, Hashable, CustomStrin port: A.P, contextIdentifier: MAPContextIdentifier, type: ParticipantType? = nil - ) async { + ) { _controller = Weak(controller) _application = Weak(application) self.contextIdentifier = contextIdentifier @@ -157,6 +160,10 @@ public final actor Participant: Equatable, Hashable, CustomStrin _logger.trace("\(self): initialized participant type \(_type)") } + private func _assertIsolatedToApplication() { + application?.assertIsolated("MRP Participant must be called from Application isolation") + } + public nonisolated var description: String { "\(application!.name)@\(port.name)" } @@ -167,7 +174,10 @@ public final actor Participant: Equatable, Hashable, CustomStrin // instance of this timer is required on a per-Port, per-MRP Participant // basis. The value of JoinTime used to initialize this timer is determined // in accordance with 10.7.11. - _jointimer = Timer(label: "jointimer", onExpiry: _onTxOpportunity) + _jointimer = Timer(label: "jointimer") { @Sendable [self] in + guard let application else { return } + try await _onTxOpportunity(isolation: application) + } // The Leave All Period Timer, leavealltimer, controls the frequency with // which the LeaveAll state machine generates LeaveAll PDUs. The timer is @@ -176,10 +186,12 @@ public final actor Participant: Equatable, Hashable, CustomStrin // All Period Timer is set to a random value, T, in the range LeaveAllTime // < T < 1.5 × LeaveAllTime when it is started. LeaveAllTime is defined in // Table 10-7. - _leaveAll = LeaveAll( - interval: controller!.timerConfiguration.leaveAllTime, - onLeaveAllTimerExpired: _onLeaveAllTimerExpired - ) + _leaveAll = LeaveAll(interval: controller!.timerConfiguration + .leaveAllTime) + { @Sendable [self] in + guard let application else { return } + try await _onLeaveAllTimerExpired(isolation: application) + } } deinit { @@ -187,20 +199,21 @@ public final actor Participant: Equatable, Hashable, CustomStrin _leaveAll?.stopLeaveAllTimer() } - @Sendable - private func _onLeaveAllTimerExpired() async throws { - try await _handleLeaveAll(protocolEvent: .leavealltimer, eventSource: .leaveAllTimer) + private func _onLeaveAllTimerExpired(isolation: isolated A) throws { + try _handleLeaveAll(protocolEvent: .leavealltimer, eventSource: .leaveAllTimer) } private func _apply( attributeType: AttributeType? = nil, matching filter: AttributeValueFilter = .matchAny, - _ block: AsyncParticipantApplyFunction - ) async rethrows { + _ block: ParticipantApplyFunction + ) rethrows { + _assertIsolatedToApplication() + for attribute in _attributes { for attributeValue in attribute.value { if !attributeValue.matches(attributeType: attributeType, matching: filter) { continue } - try await block(attributeValue) + try block(attributeValue) } } } @@ -209,10 +222,10 @@ public final actor Participant: Equatable, Hashable, CustomStrin attributeType: AttributeType? = nil, protocolEvent event: ProtocolEvent, eventSource: EventSource - ) async throws { + ) throws { _logger.trace("\(self): apply protocolEvent \(event), eventSource: \(eventSource)") - try await _apply(attributeType: attributeType) { attributeValue in - try await _handleAttributeValue( + try _apply(attributeType: attributeType) { attributeValue in + try _handleAttributeValue( attributeValue, protocolEvent: event, eventSource: eventSource @@ -259,8 +272,9 @@ public final actor Participant: Equatable, Hashable, CustomStrin _jointimer.start(interval: interval) } - @Sendable - private func _onTxOpportunity() async throws { + private func _onTxOpportunity(isolation: isolated A) async throws { + _assertIsolatedToApplication() + let eventSource = EventSource.joinTimer // Suppress TX opportunities while RX is processing to prevent interleaving @@ -295,11 +309,11 @@ public final actor Participant: Equatable, Hashable, CustomStrin switch _leaveAll.state { case .Active: // encode attributes first with current registrar states, then process LeaveAll - try await _apply(protocolEvent: .txLA, eventSource: eventSource) + try _apply(protocolEvent: .txLA, eventSource: eventSource) // sets LeaveAll to passive and emits sLA action - try await _handleLeaveAll(protocolEvent: .tx, eventSource: eventSource) + try _handleLeaveAll(protocolEvent: .tx, eventSource: eventSource) case .Passive: - try await _apply(protocolEvent: .tx, eventSource: eventSource) + try _apply(protocolEvent: .tx, eventSource: eventSource) } let didTransmit = try await _tx() @@ -326,6 +340,8 @@ public final actor Participant: Equatable, Hashable, CustomStrin matching filter: AttributeValueFilter, createIfMissing: Bool ) throws -> _AttributeValue { + _assertIsolatedToApplication() + if let attributeValue = _attributes[attributeType]? .first(where: { $0.matches(attributeType: attributeType, matching: filter) }) { @@ -358,13 +374,17 @@ public final actor Participant: Equatable, Hashable, CustomStrin attributeType: AttributeType, matching filter: AttributeValueFilter = .matchAny ) -> [_AttributeValue] { - (_attributes[attributeType] ?? []) + _assertIsolatedToApplication() + + return (_attributes[attributeType] ?? []) .filter { $0.matches(attributeType: attributeType, matching: filter) && $0.isRegistered } } fileprivate func _gcAttributeValue(_ attributeValue: _AttributeValue) { + _assertIsolatedToApplication() + if let index = _attributes.index(forKey: attributeValue.attributeType) { _attributes.values[index].remove(attributeValue) if _attributes.values[index].isEmpty { @@ -377,8 +397,8 @@ public final actor Participant: Equatable, Hashable, CustomStrin _ attributeValue: _AttributeValue, protocolEvent: ProtocolEvent, eventSource: EventSource - ) async throws { - try await attributeValue.handle( + ) throws { + try attributeValue.handle( protocolEvent: protocolEvent, eventSource: eventSource ) @@ -480,6 +500,8 @@ public final actor Participant: Equatable, Hashable, CustomStrin } private func _txEnqueue(_ event: EnqueuedEvent, eventSource: EventSource) { + _assertIsolatedToApplication() + if let index = _enqueuedEvents.index(forKey: event.attributeType) { if let eventIndex = _enqueuedEvents.values[index] .firstIndex(where: { $0 == event }) @@ -518,7 +540,7 @@ public final actor Participant: Equatable, Hashable, CustomStrin private func _handleLeaveAll( protocolEvent event: ProtocolEvent, eventSource: EventSource - ) async throws { + ) throws { let (action, txOpportunity) = _leaveAll.action(for: event) // may update state if txOpportunity { @@ -536,7 +558,7 @@ public final actor Participant: Equatable, Hashable, CustomStrin // the rLA! event is responsible for starting the leave timer on // registered attributes (Table 10-4), as well as requesting the // applicant to redeclare attributes (Table 10-3). - try await _apply(protocolEvent: .rLA, eventSource: eventSource) + try _apply(protocolEvent: .rLA, eventSource: eventSource) try _txEnqueueLeaveAllEvents(eventSource: eventSource) default: break @@ -544,7 +566,10 @@ public final actor Participant: Equatable, Hashable, CustomStrin } private func _txDequeue() throws -> MRPDU? { + _assertIsolatedToApplication() + guard let application else { throw MRPError.internalError } + let enqueuedMessages = try _packMessages(with: _enqueuedEvents) _enqueuedEvents.removeAll() @@ -558,11 +583,11 @@ public final actor Participant: Equatable, Hashable, CustomStrin return pdu } - private func rx(message: Message, eventSource: EventSource, leaveAll: inout Bool) async throws { + private func rx(message: Message, eventSource: EventSource, leaveAll: inout Bool) throws { for vectorAttribute in message.attributeList { // 10.6 Protocol operation: process LeaveAll first. if vectorAttribute.leaveAllEvent == .LeaveAll { - try await _apply( + try _apply( attributeType: message.attributeType, protocolEvent: .rLA, eventSource: eventSource @@ -595,11 +620,11 @@ public final actor Participant: Equatable, Hashable, CustomStrin .debug( "\(self): \(eventSource) declared attribute \(attribute) with new subtype \(attributeSubtype); replacing" ) - try? await attribute.rLvNow(eventSource: eventSource, suppressGC: true) + try? attribute.rLvNow(eventSource: eventSource, suppressGC: true) attribute.attributeSubtype = attributeSubtype } - try await _handleAttributeValue( + try _handleAttributeValue( attribute, protocolEvent: attributeEvent.protocolEvent, eventSource: eventSource @@ -608,8 +633,11 @@ public final actor Participant: Equatable, Hashable, CustomStrin } } - func rx(pdu: MRPDU, sourceMacAddress: EUI48) async throws { + func rx(pdu: MRPDU, sourceMacAddress: EUI48) throws { + _assertIsolatedToApplication() + _debugLogPdu(pdu, direction: .rx) + _rxInProgress = true defer { _rxInProgress = false } @@ -619,10 +647,10 @@ public final actor Participant: Equatable, Hashable, CustomStrin port.macAddress ) ? .local : .peer for message in pdu.messages { - try await rx(message: message, eventSource: eventSource, leaveAll: &leaveAll) + try rx(message: message, eventSource: eventSource, leaveAll: &leaveAll) } if leaveAll { - try await _handleLeaveAll(protocolEvent: .rLA, eventSource: eventSource) + try _handleLeaveAll(protocolEvent: .rLA, eventSource: eventSource) } } @@ -731,9 +759,9 @@ public extension Participant { // associated with a given Port and spanning tree instance, this event is // generated when the Port Role changes from either Root Port or Alternate // Port to Designated Port. - func flush() async throws { - try await _apply(protocolEvent: .Flush, eventSource: .internal) - try await _handleLeaveAll(protocolEvent: .Flush, eventSource: .internal) + func flush() throws { + try _apply(protocolEvent: .Flush, eventSource: .internal) + try _handleLeaveAll(protocolEvent: .Flush, eventSource: .internal) } // A Re-declare! event signals to the Applicant and Registrar state machines @@ -745,8 +773,8 @@ public extension Participant { // Registrar state machines associated with a given Port and spanning tree // instance, this event is generated when the Port Role changes from // Designated Port to either Root Port or Alternate Port. - func redeclare() async throws { - try await _apply(protocolEvent: .ReDeclare, eventSource: .internal) + func redeclare() throws { + try _apply(protocolEvent: .ReDeclare, eventSource: .internal) } func join( @@ -755,7 +783,7 @@ public extension Participant { attributeValue: some Value, isNew: Bool, eventSource: EventSource - ) async throws { + ) throws { let attribute = try _findOrCreateAttribute( attributeType: attributeType, attributeSubtype: attributeSubtype, @@ -768,11 +796,11 @@ public extension Participant { "\(self): \(eventSource) declared attribute \(attribute) with new subtype \(attributeSubtype); replacing" ) - try? await attribute.handle(protocolEvent: .Lv, eventSource: eventSource, suppressGC: true) + try? attribute.handle(protocolEvent: .Lv, eventSource: eventSource, suppressGC: true) attribute.attributeSubtype = attributeSubtype } - try await _handleAttributeValue( + try _handleAttributeValue( attribute, protocolEvent: isNew ? .New : .Join, eventSource: eventSource @@ -784,7 +812,7 @@ public extension Participant { attributeSubtype: AttributeSubtype? = nil, attributeValue: some Value, eventSource: EventSource - ) async throws { + ) throws { let attribute = try _findOrCreateAttribute( attributeType: attributeType, attributeSubtype: attributeSubtype, @@ -792,7 +820,7 @@ public extension Participant { createIfMissing: false ) - try await _handleAttributeValue( + try _handleAttributeValue( attribute, protocolEvent: .Lv, eventSource: eventSource @@ -803,7 +831,7 @@ public extension Participant { attributeType: AttributeType, attributeValue: some Value, eventSource: EventSource - ) async throws { + ) throws { let attribute = try _findOrCreateAttribute( attributeType: attributeType, attributeSubtype: nil, @@ -811,16 +839,16 @@ public extension Participant { createIfMissing: false ) - try await _handleAttributeValue( + try _handleAttributeValue( attribute, protocolEvent: .rLvNow, eventSource: eventSource ) } - func periodic() async throws { + func periodic() throws { _logger.trace("\(self): running periodic") - try await _apply(protocolEvent: .periodic, eventSource: .periodicTimer) + try _apply(protocolEvent: .periodic, eventSource: .periodicTimer) // timer is restarted by the caller } } @@ -828,10 +856,13 @@ public extension Participant { // MARK: - for use by REST APIs extension Participant { - func findAllAttributes( - matching filter: AttributeValueFilter = .matchAny + private func findAllAttributesUnchecked( + matching filter: AttributeValueFilter, + isolation: isolated A ) -> [AttributeValue] { - _attributes.values.flatMap { $0 } + _assertIsolatedToApplication() + + return _attributes.values.flatMap { $0 } .map { AttributeValue( attributeType: $0.attributeType, attributeSubtype: $0.attributeSubtype, @@ -842,10 +873,20 @@ extension Participant { } func findAllAttributes( - attributeType: AttributeType, matching filter: AttributeValueFilter = .matchAny + ) async -> [AttributeValue] { + guard let application else { return [] } + return await findAllAttributesUnchecked(matching: filter, isolation: application) + } + + func findAllAttributesUnchecked( + attributeType: AttributeType, + matching filter: AttributeValueFilter, + isolation: isolated A ) -> [AttributeValue] { - (_attributes[attributeType] ?? []) + _assertIsolatedToApplication() + + return (_attributes[attributeType] ?? []) .filter { $0.matches(attributeType: attributeType, matching: filter) } .map { AttributeValue( attributeType: attributeType, @@ -855,10 +896,19 @@ extension Participant { registrarState: $0.registrarState ) } } -} -private typealias AsyncParticipantApplyFunction = - @Sendable (_AttributeValue) async throws -> () + func findAllAttributes( + attributeType: AttributeType, + matching filter: AttributeValueFilter = .matchAny + ) async -> [AttributeValue] { + guard let application else { return [] } + return await findAllAttributesUnchecked( + attributeType: attributeType, + matching: filter, + isolation: application + ) + } +} private typealias ParticipantApplyFunction = @Sendable (_AttributeValue) throws -> () @@ -958,10 +1008,12 @@ private final class _AttributeValue: Sendable, Hashable, Equatab _attributeSubtype = .init(subtype) self.value = AnyValue(value) if participant._type != .applicantOnly { - registrar = Registrar( - leaveTime: participant.controller!.timerConfiguration.leaveTime, - onLeaveTimerExpired: _onLeaveTimerExpired - ) + registrar = Registrar(leaveTime: participant.controller!.timerConfiguration + .leaveTime) + { @Sendable [self] in + guard let application = participant.application else { return } + try await _onLeaveTimerExpired(isolation: application) + } } } @@ -969,9 +1021,8 @@ private final class _AttributeValue: Sendable, Hashable, Equatab registrar?.stopLeaveTimer() } - @Sendable - private func _onLeaveTimerExpired() async throws { - try await handle( + private func _onLeaveTimerExpired(isolation: isolated A) throws { + try handle( protocolEvent: .leavetimer, eventSource: .leaveTimer ) @@ -1011,7 +1062,7 @@ private final class _AttributeValue: Sendable, Hashable, Equatab private func _getEventContext( for event: ProtocolEvent, eventSource: EventSource, - isolation participant: isolated P + participant: P ) throws -> EventContext { try EventContext( participant: participant, @@ -1030,13 +1081,13 @@ private final class _AttributeValue: Sendable, Hashable, Equatab protocolEvent event: ProtocolEvent, eventSource: EventSource, suppressGC: Bool = false - ) async throws { + ) throws { guard let participant else { throw MRPError.internalError } - try await _handle( + try _handle( protocolEvent: event, eventSource: eventSource, suppressGC: suppressGC, - isolation: participant + participant: participant ) } @@ -1044,12 +1095,16 @@ private final class _AttributeValue: Sendable, Hashable, Equatab protocolEvent event: ProtocolEvent, eventSource: EventSource, suppressGC: Bool = false, - isolation participant: isolated P - ) async throws { - let context = try _getEventContext(for: event, eventSource: eventSource, isolation: participant) + participant: P + ) throws { + let context = try _getEventContext( + for: event, + eventSource: eventSource, + participant: participant + ) - try await _handleRegistrar(context: context, isolation: context.participant) - try await _handleApplicant(context: context, isolation: context.participant) + try _handleRegistrar(context: context, participant: context.participant) + try _handleApplicant(context: context, participant: context.participant) // remove attribute entirely if it is no longer declared or registered if !suppressGC, canGC { participant._gcAttributeValue(self) } @@ -1057,7 +1112,7 @@ private final class _AttributeValue: Sendable, Hashable, Equatab private func _handleApplicant( context: EventContext, - isolation participant: isolated P + participant: P ) throws { participant._logger.trace("\(context.participant): handling applicant \(context)") @@ -1116,8 +1171,8 @@ private final class _AttributeValue: Sendable, Hashable, Equatab private func _handleRegistrar( context: EventContext, - isolation participant: isolated P - ) async throws { + participant: P + ) throws { context.participant._logger.trace("\(context.participant): handling registrar \(context)") guard let registrarAction = context.registrar? @@ -1132,36 +1187,39 @@ private final class _AttributeValue: Sendable, Hashable, Equatab ) guard let application = context.participant.application else { throw MRPError.internalError } - switch registrarAction { - case .New: - fallthrough - case .Join: - try await application.joinIndicated( - contextIdentifier: context.participant.contextIdentifier, - port: context.participant.port, - attributeType: context.attributeType, - attributeSubtype: context.attributeSubtype, - attributeValue: context.attributeValue, - isNew: registrarAction == .New, - eventSource: context.eventSource - ) - case .Lv: - try await application.leaveIndicated( - contextIdentifier: context.participant.contextIdentifier, - port: context.participant.port, - attributeType: context.attributeType, - attributeSubtype: context.attributeSubtype, - attributeValue: context.attributeValue, - eventSource: context.eventSource - ) + + Task { @Sendable in + switch registrarAction { + case .New: + fallthrough + case .Join: + try await application.joinIndicated( + contextIdentifier: context.participant.contextIdentifier, + port: context.participant.port, + attributeType: context.attributeType, + attributeSubtype: context.attributeSubtype, + attributeValue: context.attributeValue, + isNew: registrarAction == .New, + eventSource: context.eventSource + ) + case .Lv: + try await application.leaveIndicated( + contextIdentifier: context.participant.contextIdentifier, + port: context.participant.port, + attributeType: context.attributeType, + attributeSubtype: context.attributeSubtype, + attributeValue: context.attributeValue, + eventSource: context.eventSource + ) + } } } fileprivate func rLvNow( eventSource: EventSource, suppressGC: Bool = false - ) async throws { - try await handle( + ) throws { + try handle( protocolEvent: .rLvNow, eventSource: eventSource, suppressGC: suppressGC diff --git a/Sources/MRP/RestApi/RestApiHandler.swift b/Sources/MRP/RestApi/RestApiHandler.swift index 2372d224..cc581233 100644 --- a/Sources/MRP/RestApi/RestApiHandler.swift +++ b/Sources/MRP/RestApi/RestApiHandler.swift @@ -159,9 +159,9 @@ extension RestApiApplicationHandler { } /// Validates application and finds participant for port - func getParticipant(for port: P) throws -> Participant { + func getParticipant(for port: P) async throws -> Participant { let application = try requireApplication() - return try application.findParticipant(port: port) + return try await application.findParticipant(port: port) } /// Validates and extracts application, port, and participant from request @@ -170,7 +170,7 @@ extension RestApiApplicationHandler { { let application = try requireApplication() let (port, _) = try await getPort(from: request) - let participant = try application.findParticipant(port: port) + let participant = try await application.findParticipant(port: port) return (application, port, participant) } }