diff --git a/Sources/Core/Services/CaptureService+Internals.swift b/Sources/Core/Services/CaptureService+Internals.swift index 7636117..2e2542b 100644 --- a/Sources/Core/Services/CaptureService+Internals.swift +++ b/Sources/Core/Services/CaptureService+Internals.swift @@ -66,7 +66,7 @@ extension CaptureService { } func recordingOutputDidFail(_ error: Error) { - let (continuations, fallbackHandler) = withStateLock { () -> (CaptureContinuations, ((Error) -> Void)?) in + let (continuations, fallbackHandler) = withStateLock { () -> (Continuations, ((Error) -> Void)?) in let continuations = drainContinuationsLocked() let handler = continuations.hasActive ? nil : errorHandler phase = .idle @@ -87,7 +87,7 @@ extension CaptureService { } func handleRecordingOutputDidFinish() { - let action: CaptureCompletionAction = withStateLock { + let action: CompletionAction = withStateLock { finalizeActiveSegmentLocked() if let pauseContinuation { @@ -122,7 +122,7 @@ extension CaptureService { } func handleStreamDidStop(with error: Error) { - let (continuations, fallbackHandler) = withStateLock { () -> (CaptureContinuations, ((Error) -> Void)?) in + let (continuations, fallbackHandler) = withStateLock { () -> (Continuations, ((Error) -> Void)?) in let continuations = drainContinuationsLocked() let handler = continuations.hasActive ? nil : errorHandler stream = nil @@ -200,8 +200,8 @@ extension CaptureService { activeSegmentURL = nil } - func drainContinuationsLocked() -> CaptureContinuations { - let continuations = CaptureContinuations( + func drainContinuationsLocked() -> Continuations { + let continuations = Continuations( pause: pauseContinuation, finish: finishContinuation, discard: discardContinuation @@ -212,7 +212,7 @@ extension CaptureService { return continuations } - func resume(continuations: CaptureContinuations, throwing error: Error) { + func resume(continuations: Continuations, throwing error: Error) { continuations.pause?.resume(throwing: error) continuations.finish?.resume(throwing: error) continuations.discard?.resume(throwing: error) diff --git a/Sources/Core/Services/CaptureService+Recording.swift b/Sources/Core/Services/CaptureService+Recording.swift index f8885f9..f7808e3 100644 --- a/Sources/Core/Services/CaptureService+Recording.swift +++ b/Sources/Core/Services/CaptureService+Recording.swift @@ -121,6 +121,11 @@ extension CaptureService: CaptureServicing { } public func stopRecording() async throws -> URL { + guard await waitForOutputTransitionToSettle() else { + try? await stopAndReset(clearSegmentState: false) + throw AppError.captureFailed + } + let mode = try resolveStopMode() try await finalizeStopIfNeeded(mode) @@ -140,6 +145,11 @@ extension CaptureService: CaptureServicing { } public func discardRecording() async { + guard await waitForOutputTransitionToSettle() else { + try? await stopAndReset() + return + } + let mode = resolveDiscardMode() guard case .noOp = mode else { await finalizeDiscardIfNeeded(mode) @@ -194,7 +204,26 @@ extension CaptureService: CaptureServicing { } extension CaptureService { - func resolveStopMode() throws -> CaptureStopMode { + func waitForOutputTransitionToSettle( + maxPollCount: Int = 100, + pollIntervalNanoseconds: UInt64 = 10_000_000 + ) async -> Bool { + for _ in 0 ..< maxPollCount { + let inTransition = withStateLock { + phase == .pausing || phase == .resuming + } + if !inTransition { + return true + } + try? await Task.sleep(nanoseconds: pollIntervalNanoseconds) + } + + return withStateLock { + phase != .pausing && phase != .resuming + } + } + + func resolveStopMode() throws -> StopMode { try withStateLock { switch phase { case .recording: @@ -217,7 +246,7 @@ extension CaptureService { } } - func finalizeStopIfNeeded(_ mode: CaptureStopMode) async throws { + func finalizeStopIfNeeded(_ mode: StopMode) async throws { guard case let .finalize(activeStream, activeOutput) = mode else { return } @@ -239,7 +268,7 @@ extension CaptureService { } } - func stopStitchInput() throws -> CaptureStitchInput { + func stopStitchInput() throws -> StitchInput { try withStateLock { guard let baseOutputURL, let outputFileType = resolvedOutputFileType @@ -247,7 +276,7 @@ extension CaptureService { throw AppError.captureFailed } - return CaptureStitchInput( + return StitchInput( baseOutputURL: baseOutputURL, outputFileType: outputFileType, segments: segmentURLs @@ -255,7 +284,7 @@ extension CaptureService { } } - func resolveDiscardMode() -> CaptureDiscardMode { + func resolveDiscardMode() -> DiscardMode { withStateLock { switch phase { case .recording: @@ -278,7 +307,7 @@ extension CaptureService { } } - func finalizeDiscardIfNeeded(_ mode: CaptureDiscardMode) async { + func finalizeDiscardIfNeeded(_ mode: DiscardMode) async { guard case let .finalize(activeStream, activeOutput) = mode else { return } diff --git a/Sources/Core/Services/CaptureService+TestSupport.swift b/Sources/Core/Services/CaptureService+TestSupport.swift index 112a2ef..ac6c7c0 100644 --- a/Sources/Core/Services/CaptureService+TestSupport.swift +++ b/Sources/Core/Services/CaptureService+TestSupport.swift @@ -11,6 +11,7 @@ import Foundation videoCodec: AVVideoCodecType = .h264, segmentURLs: [URL] = [], segmentIndex: Int = 0, + activeSegmentURL: URL? = nil, paused: Bool ) { withStateLock { @@ -22,7 +23,7 @@ import Foundation self.segmentURLs = segmentURLs self.segmentIndex = segmentIndex phase = paused ? .paused : .recording - activeSegmentURL = nil + self.activeSegmentURL = activeSegmentURL } } diff --git a/Sources/Core/Services/CaptureService.swift b/Sources/Core/Services/CaptureService.swift index c60395f..ae3d37e 100644 --- a/Sources/Core/Services/CaptureService.swift +++ b/Sources/Core/Services/CaptureService.swift @@ -2,51 +2,51 @@ import AVFoundation import Foundation import ScreenCaptureKit -enum CaptureSessionPhase: String { - case idle - case recording - case pausing - case paused - case resuming - case stopping - case discarding -} +public final class CaptureService: NSObject { + enum SessionPhase: String { + case idle + case recording + case pausing + case paused + case resuming + case stopping + case discarding + } -struct CaptureContinuations { - var pause: CheckedContinuation? - var finish: CheckedContinuation? - var discard: CheckedContinuation? + struct Continuations { + var pause: CheckedContinuation? + var finish: CheckedContinuation? + var discard: CheckedContinuation? - var hasActive: Bool { - pause != nil || finish != nil || discard != nil + var hasActive: Bool { + pause != nil || finish != nil || discard != nil + } } -} -enum CaptureCompletionAction { - case pause(CheckedContinuation) - case stop(CheckedContinuation) - case discard(CheckedContinuation) - case none -} + enum CompletionAction { + case pause(CheckedContinuation) + case stop(CheckedContinuation) + case discard(CheckedContinuation) + case none + } -enum CaptureStopMode { - case finalize(any CaptureStreamControlling, any CaptureRecordingOutputControlling) - case stitchOnly -} + enum StopMode { + case finalize(any CaptureStreamControlling, any CaptureRecordingOutputControlling) + case stitchOnly + } -enum CaptureDiscardMode { - case finalize(any CaptureStreamControlling, any CaptureRecordingOutputControlling) - case stopOnly - case noOp -} + enum DiscardMode { + case finalize(any CaptureStreamControlling, any CaptureRecordingOutputControlling) + case stopOnly + case noOp + } -struct CaptureStitchInput { - let baseOutputURL: URL - let outputFileType: AVFileType - let segments: [URL] -} + struct StitchInput { + let baseOutputURL: URL + let outputFileType: AVFileType + let segments: [URL] + } -public final class CaptureService: NSObject { var stream: (any CaptureStreamControlling)? var recordingOutput: (any CaptureRecordingOutputControlling)? var activeSegmentURL: URL? @@ -58,7 +58,7 @@ public final class CaptureService: NSObject { var resolvedOutputFileType: AVFileType? var resolvedVideoCodec: AVVideoCodecType? - var phase: CaptureSessionPhase = .idle + var phase: SessionPhase = .idle var finishContinuation: CheckedContinuation? var pauseContinuation: CheckedContinuation? var discardContinuation: CheckedContinuation? diff --git a/Tests/CoreTests/CaptureServiceRefactorTests.swift b/Tests/CoreTests/CaptureServiceRefactorTests.swift index 861e4e5..be1c8ff 100644 --- a/Tests/CoreTests/CaptureServiceRefactorTests.swift +++ b/Tests/CoreTests/CaptureServiceRefactorTests.swift @@ -39,6 +39,88 @@ final class CaptureServiceRefactorTests: XCTestCase { XCTAssertEqual(service.phaseForTesting(), "paused") } + func testStopRecordingWaitsForPauseTransitionThenStops() async throws { + let stream = MockCaptureStream() + let service = makeService(stream: stream) + let folder = makeTemporaryDirectory() + let finalURL = folder.appendingPathComponent("final-stop.mp4") + let activeSegmentURL = folder.appendingPathComponent("segment-stop.mp4") + try Data(repeating: 7, count: 16).write(to: activeSegmentURL) + + service.installTestSession( + stream: stream, + recordingOutput: MockRecordingOutput(), + baseOutputURL: finalURL, + outputFileType: .mp4, + activeSegmentURL: activeSegmentURL, + paused: false + ) + + let pauseTask = Task { + try await service.pauseRecording() + } + + await waitUntil { stream.removeRecordingOutputCallCount == 1 } + + let stopTask = Task { + try await service.stopRecording() + } + + try await Task.sleep(nanoseconds: 50_000_000) + XCTAssertEqual(stream.stopCaptureCallCount, 0) + XCTAssertEqual(service.phaseForTesting(), "pausing") + + service.handleRecordingOutputDidFinish() + + try await pauseTask.value + let stoppedURL = try await stopTask.value + XCTAssertEqual(stoppedURL.standardizedFileURL, finalURL.standardizedFileURL) + XCTAssertEqual(stream.stopCaptureCallCount, 1) + XCTAssertEqual(service.phaseForTesting(), "idle") + XCTAssertTrue(FileManager.default.fileExists(atPath: finalURL.path)) + } + + func testDiscardRecordingWaitsForPauseTransitionThenTearsDown() async throws { + let stream = MockCaptureStream() + let service = makeService(stream: stream) + let folder = makeTemporaryDirectory() + let finalURL = folder.appendingPathComponent("final-discard.mp4") + let activeSegmentURL = folder.appendingPathComponent("segment-discard.mp4") + try Data(repeating: 3, count: 32).write(to: activeSegmentURL) + + service.installTestSession( + stream: stream, + recordingOutput: MockRecordingOutput(), + baseOutputURL: finalURL, + outputFileType: .mp4, + activeSegmentURL: activeSegmentURL, + paused: false + ) + + let pauseTask = Task { + try await service.pauseRecording() + } + + await waitUntil { stream.removeRecordingOutputCallCount == 1 } + + let discardTask = Task { + await service.discardRecording() + } + + try await Task.sleep(nanoseconds: 50_000_000) + XCTAssertEqual(stream.stopCaptureCallCount, 0) + XCTAssertEqual(service.phaseForTesting(), "pausing") + + service.handleRecordingOutputDidFinish() + + try await pauseTask.value + await discardTask.value + + XCTAssertEqual(stream.stopCaptureCallCount, 1) + XCTAssertEqual(service.phaseForTesting(), "idle") + XCTAssertFalse(FileManager.default.fileExists(atPath: activeSegmentURL.path)) + } + func testResumeRecordingFailureRollsBackPhaseAndSegmentIndex() { let stream = MockCaptureStream() stream.addRecordingOutputError = AppError.captureFailed