diff --git a/RemoteViewer.slnx b/RemoteViewer.slnx index 028f70a..b9551d8 100644 --- a/RemoteViewer.slnx +++ b/RemoteViewer.slnx @@ -27,5 +27,6 @@ + diff --git a/src/RemoteViewer.Client/RemoteViewer.Client.csproj b/src/RemoteViewer.Client/RemoteViewer.Client.csproj index 0655023..ddf6fd3 100644 --- a/src/RemoteViewer.Client/RemoteViewer.Client.csproj +++ b/src/RemoteViewer.Client/RemoteViewer.Client.csproj @@ -15,7 +15,6 @@ - @@ -54,4 +53,5 @@ runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/src/RemoteViewer.Client/Services/HubClient/Connection.cs b/src/RemoteViewer.Client/Services/HubClient/Connection.cs index b67ff02..9f2f445 100644 --- a/src/RemoteViewer.Client/Services/HubClient/Connection.cs +++ b/src/RemoteViewer.Client/Services/HubClient/Connection.cs @@ -1,4 +1,4 @@ -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using RemoteViewer.Client.Common; using RemoteViewer.Client.Services.FileTransfer; @@ -463,6 +463,10 @@ async void IConnectionImpl.OnMessageReceived(string senderClientId, string messa { var message = ProtocolSerializer.Deserialize(data); ((IViewerServiceImpl)this.ViewerService!).HandleFrame(message.DisplayId, message.FrameNumber, message.Codec, message.Regions); + + if (this.Owner.Options.SuppressAutoFrameAck is false) + await this.Owner.SendAckFrameAsync(this.ConnectionId); + break; } diff --git a/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClient.cs b/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClient.cs index 6f6dc34..ce1027c 100644 --- a/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClient.cs +++ b/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClient.cs @@ -1,4 +1,4 @@ -using System.Collections.Concurrent; +using System.Collections.Concurrent; using Microsoft.AspNetCore.SignalR.Client; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -23,13 +23,15 @@ public ConnectionHubClient( { this._logger = logger; this._serviceProvider = serviceProvider; + this.Options = options.Value; this._connection = new HubConnectionBuilder() - .WithUrl($"{options.Value.BaseUrl}/connection", httpOptions => + .WithUrl($"{this.Options.BaseUrl}/connection", httpOptions => { httpOptions.Headers.Add("X-Client-Version", ThisAssembly.AssemblyInformationalVersion); httpOptions.Headers.Add("X-Display-Name", this.DisplayName); }) + .WithAutomaticReconnect() .AddMessagePackProtocol(Witness.GeneratedTypeShapeProvider) .Build(); @@ -174,6 +176,8 @@ private void CloseAllConnections() this._connections.Clear(); } + public ConnectionHubClientOptions Options { get; } + public string? ClientId { get; private set; } public string? Username { get; private set; } public string? Password { get; private set; } @@ -260,9 +264,9 @@ public async Task ConnectToHub() return error; } - catch (Exception ex) when (!this.IsConnected) + catch (Exception ex) { - this._logger.LogWarning(ex, "Failed to connect to device - hub disconnected"); + this._logger.LogWarning(ex, "Failed to connect to device"); return null; } } @@ -278,9 +282,9 @@ public async Task GenerateNewPassword() await this._connection.InvokeAsync("GenerateNewPassword"); this._logger.LogInformation("New password generated"); } - catch (Exception ex) when (!this.IsConnected) + catch (Exception ex) { - this._logger.LogWarning(ex, "Failed to generate new password - hub disconnected"); + this._logger.LogWarning(ex, "Failed to generate new password"); } } @@ -297,9 +301,9 @@ public async Task SetDisplayName(string displayName) await this._connection.InvokeAsync("SetDisplayName", displayName); this._logger.LogInformation("Display name set successfully"); } - catch (Exception ex) when (!this.IsConnected) + catch (Exception ex) { - this._logger.LogWarning(ex, "Failed to set display name - hub disconnected"); + this._logger.LogWarning(ex, "Failed to set display name"); } } @@ -321,9 +325,9 @@ internal async Task SendMessageAsync(string connectionId, string messageType, Re await this._connection.SendAsync("SendMessage", connectionId, messageType, data, destination, targetClientIds); this._logger.LogDebug("Message sent successfully"); } - catch (Exception ex) when (!this.IsConnected) + catch (Exception ex) { - this._logger.LogWarning(ex, "Failed to send message - hub disconnected"); + this._logger.LogWarning(ex, "Failed to send message"); } } @@ -338,9 +342,9 @@ internal async Task DisconnectAsync(string connectionId) await this._connection.InvokeAsync("Disconnect", connectionId); this._logger.LogInformation("Disconnected from connection: {ConnectionId}", connectionId); } - catch (Exception ex) when (!this.IsConnected) + catch (Exception ex) { - this._logger.LogWarning(ex, "Failed to disconnect - hub disconnected"); + this._logger.LogWarning(ex, "Failed to disconnect"); } } @@ -354,9 +358,9 @@ internal async Task SetConnectionPropertiesAsync(string connectionId, Connection this._logger.LogDebug("Setting connection properties - ConnectionId: {ConnectionId}", connectionId); await this._connection.InvokeAsync("SetConnectionProperties", connectionId, properties); } - catch (Exception ex) when (!this.IsConnected) + catch (Exception ex) { - this._logger.LogWarning(ex, "Failed to set connection properties - hub disconnected"); + this._logger.LogWarning(ex, "Failed to set connection properties"); } } @@ -370,12 +374,28 @@ internal async Task SetConnectionPropertiesAsync(string connectionId, Connection this._logger.LogDebug("Generating IPC auth token for connection: {ConnectionId}", connectionId); return await this._connection.InvokeAsync("GenerateIpcAuthToken", connectionId); } - catch (Exception ex) when (!this.IsConnected) + catch (Exception ex) { - this._logger.LogWarning(ex, "Failed to generate IPC auth token - hub disconnected"); + this._logger.LogWarning(ex, "Failed to generate IPC auth token"); return null; } } + + internal async Task SendAckFrameAsync(string connectionId) + { + if (!this.IsConnected || this.IsReconnecting) + return; + + try + { + await this._connection.SendAsync("AckFrame", connectionId); + } + catch (Exception ex) + { + this._logger.LogWarning(ex, "Failed to send frame ack"); + } + } + } #region EventArgs Classes diff --git a/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClientOptions.cs b/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClientOptions.cs index 2eef82c..bb50be5 100644 --- a/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClientOptions.cs +++ b/src/RemoteViewer.Client/Services/HubClient/ConnectionHubClientOptions.cs @@ -1,6 +1,5 @@ -using Microsoft.AspNetCore.Http.Connections; +namespace RemoteViewer.Client.Services.HubClient; -namespace RemoteViewer.Client.Services.HubClient; public class ConnectionHubClientOptions { @@ -9,4 +8,6 @@ public class ConnectionHubClientOptions #else public string BaseUrl { get; set; } = "https://rdp.xemio.net"; #endif + + public bool SuppressAutoFrameAck { get; set; } } diff --git a/src/RemoteViewer.Server/Hubs/ConnectionHub.cs b/src/RemoteViewer.Server/Hubs/ConnectionHub.cs index 19c0e15..772e0e5 100644 --- a/src/RemoteViewer.Server/Hubs/ConnectionHub.cs +++ b/src/RemoteViewer.Server/Hubs/ConnectionHub.cs @@ -79,6 +79,12 @@ public async Task SendMessage(string connectionId, string messageType, byte[] da await clientsService.SendMessage(this.Context.ConnectionId, connectionId, messageType, data, destination, targetClientIds); } + public Task AckFrame(string connectionId) + { + return clientsService.AckFrame(this.Context.ConnectionId, connectionId); + } + + public async Task SetConnectionProperties(string connectionId, ConnectionProperties properties) { await clientsService.SetConnectionProperties(this.Context.ConnectionId, connectionId, properties); diff --git a/src/RemoteViewer.Server/Orleans/Grains/ClientGrain.cs b/src/RemoteViewer.Server/Orleans/Grains/ClientGrain.cs index cea8276..05a54a0 100644 --- a/src/RemoteViewer.Server/Orleans/Grains/ClientGrain.cs +++ b/src/RemoteViewer.Server/Orleans/Grains/ClientGrain.cs @@ -35,6 +35,8 @@ public sealed partial class ClientGrain(ILogger logger, IHubContext private string _displayName = string.Empty; + private IClientSendGrain? _sendGrain; + private IConnectionGrain? _presenterConnectionGrain; private readonly List _viewerConnectionGrains = []; @@ -67,6 +69,8 @@ public async Task Initialize(string? displayName) this.LogUsernameCollision(attempts); } + this._sendGrain = this.GrainFactory.GetGrain(this.GetPrimaryKeyString()); + this.LogClientInitialized(this._clientId, this._usernameGrain.GetPrimaryKeyString()); await hubContext.Clients @@ -93,6 +97,7 @@ public async Task Deactivate() await connection.Internal_RemoveClient(this.AsReference()); } + await this._sendGrain.Disconnect(); await this._usernameGrain.ReleaseAsync(this.GetPrimaryKeyString()); this.DeactivateOnIdle(); @@ -208,14 +213,15 @@ private static string FormatUsername(string username) } return sb.ToString(); } - [MemberNotNull(nameof(_clientId), nameof(_usernameGrain), nameof(_password))] + [MemberNotNull(nameof(_clientId), nameof(_usernameGrain), nameof(_sendGrain), nameof(_password))] private void EnsureInitialized() { - if (this._clientId is null || this._usernameGrain is null || this._password is null) + if (this._clientId is null || this._usernameGrain is null || this._sendGrain is null || this._password is null) { throw new InvalidOperationException( $"ClientGrain not initialized: clientId={(this._clientId is null ? "null" : "set")}, " + $"usernameGrain={(this._usernameGrain is null ? "null" : "set")}, " + + $"sendGrain={(this._sendGrain is null ? "null" : "set")}, " + $"password={(this._password is null ? "null" : "set")}"); } } diff --git a/src/RemoteViewer.Server/Orleans/Grains/ClientSendGrain.cs b/src/RemoteViewer.Server/Orleans/Grains/ClientSendGrain.cs new file mode 100644 index 0000000..6f4f25d --- /dev/null +++ b/src/RemoteViewer.Server/Orleans/Grains/ClientSendGrain.cs @@ -0,0 +1,246 @@ +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using Microsoft.AspNetCore.SignalR; +using Orleans.Concurrency; +using RemoteViewer.Server.Hubs; +using RemoteViewer.Shared.Protocol; +using System.Threading.Channels; + + +namespace RemoteViewer.Server.Orleans.Grains; + +public interface IClientSendGrain : IGrainWithStringKey +{ + Task Enqueue(string connectionId, string senderClientId, string messageType, byte[] data); + Task AckFrame(string connectionId); + Task Disconnect(); +} + +[Reentrant] +[SuppressMessage("IDisposableAnalyzers", "CA1001", Justification = "Orleans grains don't implement IDisposable; cleanup is in OnDeactivateAsync")] +public sealed partial class ClientSendGrain(ILogger logger, IHubContext hubContext) + : Grain, IClientSendGrain +{ + private readonly Channel _nonFrameChannel = Channel.CreateUnbounded(new() { SingleReader = true, SingleWriter = false }); + private readonly CancellationTokenSource _shutdownCts = new(); + + private readonly ConcurrentDictionary _frameStates = new(StringComparer.Ordinal); + private Task? _processingTask; + + + + public override Task OnActivateAsync(CancellationToken cancellationToken) + { + this._processingTask = Task.Run(() => this.ProcessNonFrameMessagesAsync(this._shutdownCts.Token), cancellationToken); + + return Task.CompletedTask; + } + + public override async Task OnDeactivateAsync(DeactivationReason reason, CancellationToken cancellationToken) + { + this._shutdownCts.Cancel(); + this._nonFrameChannel.Writer.TryComplete(); + + if (this._processingTask is not null) + { + try + { + await this._processingTask; + } + catch (OperationCanceledException) + { + } + } + + this._shutdownCts.Dispose(); + } + + public Task Enqueue(string connectionId, string senderClientId, string messageType, byte[] data) + { + var message = new QueuedMessage(connectionId, senderClientId, messageType, data); + + if (messageType == MessageTypes.Screen.Frame) + { + return this.EnqueueFrame(message); + } + else + { + this._nonFrameChannel.Writer.TryWrite(message); + return Task.CompletedTask; + } + } + + private Task EnqueueFrame(QueuedMessage message) + { + var state = this._frameStates.GetOrAdd(message.ConnectionId, _ => new FrameSendState()); + + var (wasIdle, dropped) = state.TryEnqueueOrSend(message); + + if (wasIdle) + { + return this.SendFrameAsync(message); + } + + if (dropped is { } droppedMessage) + { + this.LogFrameDropped(droppedMessage.MessageType); + } + + return Task.CompletedTask; + } + + public Task AckFrame(string connectionId) + { + if (!this._frameStates.TryGetValue(connectionId, out var state)) + { + return Task.CompletedTask; + } + + var toSend = state.TryGetPendingAndClearInFlight(); + + if (toSend is { } message) + { + return this.SendFrameAsync(message); + } + + this.TryRemoveState(connectionId, state); + return Task.CompletedTask; + } + + private void TryRemoveState(string connectionId, FrameSendState state) + { + if (state.CanRemove()) + { + this._frameStates.TryRemove(new KeyValuePair(connectionId, state)); + } + } + + public async Task Disconnect() + { + + this._shutdownCts.Cancel(); + this._nonFrameChannel.Writer.TryComplete(); + + if (this._processingTask is not null) + { + try + { + await this._processingTask; + } + catch (OperationCanceledException) + { + } + } + + this.DeactivateOnIdle(); + } + + private async Task SendFrameAsync(QueuedMessage frame) + { + try + { + await this.SendAsync(frame); + } + catch + { + // If delivery fails, clear in-flight state so the next frame can be sent + if (this._frameStates.TryGetValue(frame.ConnectionId, out var state)) + { + state.ClearOnError(); + this.TryRemoveState(frame.ConnectionId, state); + } + } + } + + + + private async Task ProcessNonFrameMessagesAsync(CancellationToken ct) + { + try + { + await foreach (var message in this._nonFrameChannel.Reader.ReadAllAsync(ct)) + { + await this.SendAsync(message); + } + } + catch (OperationCanceledException) + { + } + } + + private Task SendAsync(QueuedMessage message) + { + return hubContext.Clients + .Client(this.GetPrimaryKeyString()) + .MessageReceived(message.ConnectionId, message.SenderClientId, message.MessageType, message.Data); + } + + private readonly record struct QueuedMessage( + string ConnectionId, + string SenderClientId, + string MessageType, + byte[] Data); + + private sealed class FrameSendState + { + private readonly Lock _lock = new(); + private bool _inFlight; + private QueuedMessage? _pendingFrame; + + public (bool wasIdle, QueuedMessage? dropped) TryEnqueueOrSend(QueuedMessage message) + { + using (this._lock.EnterScope()) + { + if (this._inFlight is false) + { + this._inFlight = true; + return (wasIdle: true, dropped: null); + } + else + { + var dropped = this._pendingFrame; + this._pendingFrame = message; + return (wasIdle: false, dropped: dropped); + } + } + } + + public QueuedMessage? TryGetPendingAndClearInFlight() + { + using (this._lock.EnterScope()) + { + if (this._pendingFrame is { } pending) + { + this._pendingFrame = null; + // Keep _inFlight = true since we're about to send pending + return pending; + } + else + { + this._inFlight = false; + return null; + } + } + } + + public void ClearOnError() + { + using (this._lock.EnterScope()) + { + this._inFlight = false; + this._pendingFrame = null; + } + } + + public bool CanRemove() + { + using (this._lock.EnterScope()) + { + return this._inFlight is false && this._pendingFrame is null; + } + } + } +} + diff --git a/src/RemoteViewer.Server/Orleans/Grains/ClientSendGrainLogs.cs b/src/RemoteViewer.Server/Orleans/Grains/ClientSendGrainLogs.cs new file mode 100644 index 0000000..4be64eb --- /dev/null +++ b/src/RemoteViewer.Server/Orleans/Grains/ClientSendGrainLogs.cs @@ -0,0 +1,8 @@ +using Microsoft.Extensions.Logging; +namespace RemoteViewer.Server.Orleans.Grains; + +public sealed partial class ClientSendGrain +{ + [LoggerMessage(Level = LogLevel.Debug, Message = "Dropped frame {MessageType}")] + partial void LogFrameDropped(string messageType); +} diff --git a/src/RemoteViewer.Server/Orleans/Grains/ConnectionGrain.cs b/src/RemoteViewer.Server/Orleans/Grains/ConnectionGrain.cs index 452f637..ac42d8b 100644 --- a/src/RemoteViewer.Server/Orleans/Grains/ConnectionGrain.cs +++ b/src/RemoteViewer.Server/Orleans/Grains/ConnectionGrain.cs @@ -1,7 +1,12 @@ -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.CodeAnalysis; using Microsoft.AspNetCore.SignalR; using RemoteViewer.Server.Hubs; using RemoteViewer.Shared; +using RemoteViewer.Shared.Protocol; +using Orleans; +using System.Threading; +using System.Threading.Tasks; +using System.Linq; using ConnectionInfo = RemoteViewer.Shared.ConnectionInfo; @@ -66,42 +71,36 @@ public async Task SendMessage(string senderSignalrConnectionId, string messageTy switch (destination) { case MessageDestination.PresenterOnly: - if (!isSenderPresenter && this._presenter is not null) + if (!isSenderPresenter) { - await hubContext.Clients.Client(this._presenter.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); + await this.SendMessageToPresenterAsync(senderClientId, messageType, data); } break; case MessageDestination.AllViewers: - foreach (var viewer in this._viewers) - { - await hubContext.Clients.Client(viewer.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); - } + await Task.WhenAll(this._viewers.Select(viewer => this.SendMessageToViewerAsync(viewer, senderClientId, messageType, data))); break; case MessageDestination.All: - if (this._presenter is not null) - { - await hubContext.Clients.Client(this._presenter.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); - } - foreach (var viewer in this._viewers) - { - await hubContext.Clients.Client(viewer.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); - } + await Task.WhenAll( + this.SendMessageToPresenterAsync(senderClientId, messageType, data), + Task.WhenAll(this._viewers.Select(viewer => this.SendMessageToViewerAsync(viewer, senderClientId, messageType, data)))); break; case MessageDestination.AllExceptSender: + var broadcastTasks = new List(); if (this._presenter is not null && this._presenter.GetPrimaryKeyString() != senderSignalrConnectionId) { - await hubContext.Clients.Client(this._presenter.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); + broadcastTasks.Add(this.SendMessageToPresenterAsync(senderClientId, messageType, data)); } foreach (var viewer in this._viewers) { if (viewer.GetPrimaryKeyString() != senderSignalrConnectionId) { - await hubContext.Clients.Client(viewer.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); + broadcastTasks.Add(this.SendMessageToViewerAsync(viewer, senderClientId, messageType, data)); } } + await Task.WhenAll(broadcastTasks); break; case MessageDestination.SpecificClients: @@ -109,10 +108,11 @@ public async Task SendMessage(string senderSignalrConnectionId, string messageTy break; var targetClientIdSet = targetClientIds.ToHashSet(StringComparer.Ordinal); + var specificTasks = new List(); if (this._presenter is not null && targetClientIdSet.Contains(await this._presenter.Internal_GetClientId())) { - await hubContext.Clients.Client(this._presenter.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); + specificTasks.Add(this.SendMessageToPresenterAsync(senderClientId, messageType, data)); } foreach (var viewer in this._viewers) @@ -120,9 +120,11 @@ public async Task SendMessage(string senderSignalrConnectionId, string messageTy var viewerClientId = await viewer.Internal_GetClientId(); if (targetClientIdSet.Contains(viewerClientId)) { - await hubContext.Clients.Client(viewer.GetPrimaryKeyString()).MessageReceived(this.GetPrimaryKeyString(), senderClientId, messageType, data); + specificTasks.Add(this.SendMessageToViewerAsync(viewer, senderClientId, messageType, data)); } } + + await Task.WhenAll(specificTasks); break; } @@ -231,6 +233,24 @@ private async Task NotifyConnectionChangedAsync() } } + private Task SendMessageToPresenterAsync(string senderClientId, string messageType, byte[] data) + { + if (this._presenter is null) + { + return Task.CompletedTask; + } + + var presenterSenderGrain = this.GrainFactory.GetGrain(this._presenter.GetPrimaryKeyString()); + return presenterSenderGrain.Enqueue(this.GetPrimaryKeyString(), senderClientId, messageType, data); + } + + private Task SendMessageToViewerAsync(IClientGrain viewer, string senderClientId, string messageType, byte[] data) + { + var viewerSignalrId = viewer.GetPrimaryKeyString(); + var senderGrain = this.GrainFactory.GetGrain(viewerSignalrId); + return senderGrain.Enqueue(this.GetPrimaryKeyString(), senderClientId, messageType, data); + } + [MemberNotNull(nameof(_presenter))] private void EnsureInitialized() { @@ -264,4 +284,10 @@ private void EnsureInitialized() [LoggerMessage(Level = LogLevel.Information, Message = "Viewer disconnected: ConnectionId={ConnectionId}, ViewerCount={ViewerCount}")] private partial void LogViewerDisconnected(string connectionId, int viewerCount); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Queued frame for viewer {ViewerSignalrId} (bytes={PayloadBytes})")] + private partial void LogFrameQueued(string viewerSignalrId, int payloadBytes); + + [LoggerMessage(Level = LogLevel.Debug, Message = "Coalesced frame for viewer {ViewerSignalrId}: {PreviousBytes} bytes -> {NewBytes} bytes")] + private partial void LogFrameCoalesced(string viewerSignalrId, int previousBytes, int newBytes); } diff --git a/src/RemoteViewer.Server/Services/ConnectionsOrleansService.cs b/src/RemoteViewer.Server/Services/ConnectionsOrleansService.cs index 84352b4..c248a83 100644 --- a/src/RemoteViewer.Server/Services/ConnectionsOrleansService.cs +++ b/src/RemoteViewer.Server/Services/ConnectionsOrleansService.cs @@ -1,4 +1,4 @@ -using RemoteViewer.Server.Orleans.Grains; +using RemoteViewer.Server.Orleans.Grains; using RemoteViewer.Shared; namespace RemoteViewer.Server.Services; @@ -78,6 +78,13 @@ public async Task SendMessage(string signalrConnectionId, string connectionId, s await connectionGrain.SendMessage(signalrConnectionId, messageType, data, destination, targetClientIds); } + public Task AckFrame(string signalrConnectionId, string connectionId) + { + var grain = grainFactory.GetGrain(signalrConnectionId); + return grain.AckFrame(connectionId); + } + + public async Task IsPresenterOfConnection(string signalrConnectionId, string connectionId) { var connectionGrain = grainFactory.GetGrain(connectionId); diff --git a/src/RemoteViewer.Server/Services/IConnectionsService.cs b/src/RemoteViewer.Server/Services/IConnectionsService.cs index d91c601..dd242f7 100644 --- a/src/RemoteViewer.Server/Services/IConnectionsService.cs +++ b/src/RemoteViewer.Server/Services/IConnectionsService.cs @@ -1,4 +1,4 @@ -using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR; using RemoteViewer.Server.Common; using RemoteViewer.Server.Hubs; using RemoteViewer.Shared; @@ -21,6 +21,8 @@ public interface IConnectionsService Task DisconnectFromConnection(string signalrConnectionId, string connectionId); Task SetConnectionProperties(string signalrConnectionId, string connectionId, ConnectionProperties properties); Task SendMessage(string signalrConnectionId, string connectionId, string messageType, byte[] data, MessageDestination destination, IReadOnlyList? targetClientIds = null); + Task AckFrame(string signalrConnectionId, string connectionId); + Task IsPresenterOfConnection(string signalrConnectionId, string connectionId); } @@ -324,6 +326,13 @@ public async Task SendMessage(string signalrConnectionId, string connectionId, s this._logger.MessageSendCompleted(senderId, connectionId, messageType); } + public Task AckFrame(string signalrConnectionId, string connectionId) + { + // No-op - we don't have frame-backbuffer mechanisms in this implementation + return Task.CompletedTask; + } + + public Task IsPresenterOfConnection(string signalrConnectionId, string connectionId) { using (this._lock.ReadLock()) diff --git a/tests/RemoteViewer.IntegrationTests/ConnectionHubClientTests.cs b/tests/RemoteViewer.IntegrationTests/ConnectionHubClientTests.cs index 73ad36c..4619970 100644 --- a/tests/RemoteViewer.IntegrationTests/ConnectionHubClientTests.cs +++ b/tests/RemoteViewer.IntegrationTests/ConnectionHubClientTests.cs @@ -2,7 +2,8 @@ using RemoteViewer.Client.Controls.Dialogs; using RemoteViewer.Client.Services.FileTransfer; using RemoteViewer.Client.Services.HubClient; -using RemoteViewer.IntegrationTests.Fixtures; +using RemoteViewer.TestFixtures; +using RemoteViewer.TestFixtures.Fixtures; using RemoteViewer.Shared; using RemoteViewer.Shared.Protocol; @@ -45,7 +46,7 @@ public async Task ChatMessagesAreSentFromViewerToPresenter() var presenterConn = presenter.CurrentConnection!; var viewerConn = viewer.CurrentConnection!; - var receiveTask = TestHelpers.WaitForEventAsync( + var receiveTask = TestHelpers.WaitForEvent( onResult => presenterConn.Chat.MessageReceived += (s, msg) => onResult(msg)); await viewerConn.Chat.SendMessageAsync("Hello from viewer!"); @@ -65,7 +66,7 @@ public async Task ChatMessagesAreSentFromPresenterToViewer() var presenterConn = presenter.CurrentConnection!; var viewerConn = viewer.CurrentConnection!; - var receiveTask = TestHelpers.WaitForEventAsync( + var receiveTask = TestHelpers.WaitForEvent( onResult => viewerConn.Chat.MessageReceived += (s, msg) => onResult(msg)); await presenterConn.Chat.SendMessageAsync("Hello from presenter!"); @@ -89,7 +90,7 @@ public async Task InputBlockingUpdatesConnectionProperties() // Wait for presenter to receive server confirmation // (presenter is now server-authoritative - no local optimistic update) - var presenterPropertyTask = TestHelpers.WaitForEventAsync( + var presenterPropertyTask = TestHelpers.WaitForEvent( onComplete => presenterConn.ConnectionPropertiesChanged += (s, e) => { if (presenterConn.ConnectionProperties.InputBlockedViewerIds.Contains(viewerClientId)) @@ -118,7 +119,7 @@ public async Task ViewerMouseMoveIsSentToPresenter() await viewerConn.RequiredViewerService.SendMouseMoveAsync(0.5f, 0.5f); // Wait for message to be received and processed - await TestHelpers.WaitForReceivedCallAsync(() => + await TestHelpers.WaitForReceivedCall(() => presenter.InputInjectionService.ReceivedCalls() .Any(c => c.GetMethodInfo().Name == "InjectMouseMove")); @@ -143,7 +144,7 @@ public async Task ViewerKeyPressIsSentToPresenter() await viewerConn.RequiredViewerService.SendKeyUpAsync(0x41, KeyModifiers.None); // Wait for both key events to be received - await TestHelpers.WaitForReceivedCallAsync(() => + await TestHelpers.WaitForReceivedCall(() => presenter.InputInjectionService.ReceivedCalls() .Count(c => c.GetMethodInfo().Name == "InjectKey") >= 2); @@ -164,7 +165,7 @@ public async Task PresenterDisconnectClosesViewerConnection() var presenterConn = presenter.CurrentConnection!; var viewerConn = viewer.CurrentConnection!; - var closedTask = TestHelpers.WaitForEventAsync( + var closedTask = TestHelpers.WaitForEvent( onComplete => viewerConn.Closed += (s, e) => onComplete()); await presenterConn.DisconnectAsync(); @@ -206,9 +207,9 @@ public async Task MultipleViewersCanConnectToSamePresenter() var presenterConn = presenter.CurrentConnection!; // Wait for all viewers to be registered (eventual consistency) - await TestHelpers.WaitForConditionAsync( + await TestHelpers.WaitUntil( () => presenterConn.Viewers.Count == 3, - timeoutMessage: $"Expected 3 viewers but got {presenterConn.Viewers.Count}"); + message: $"Expected 3 viewers but got {presenterConn.Viewers.Count}"); await Assert.That(presenterConn.Viewers.Count).IsEqualTo(3); } @@ -226,7 +227,7 @@ public async Task ViewerDisconnectDoesNotAffectOtherViewers() var viewer2Conn = viewer2.CurrentConnection!; // Subscribe BEFORE disconnect to wait for it to be fully processed - var viewer1DisconnectedTask = TestHelpers.WaitForEventAsync( + var viewer1DisconnectedTask = TestHelpers.WaitForEvent( onComplete => presenterConn.ViewersChanged += (s, e) => { if (presenterConn.Viewers.Count == 1) @@ -240,7 +241,7 @@ public async Task ViewerDisconnectDoesNotAffectOtherViewers() await viewer1DisconnectedTask; // Viewer2 should still be connected and able to communicate - var receiveTask = TestHelpers.WaitForEventAsync( + var receiveTask = TestHelpers.WaitForEvent( onResult => presenterConn.Chat.MessageReceived += (s, msg) => onResult(msg)); await viewer2Conn.Chat.SendMessageAsync("I'm still connected!"); @@ -265,7 +266,7 @@ public async Task ViewersChangedEventFiresOnViewerConnect() var presenterConn = await presenterConnTask; // Subscribe BEFORE waiting for viewer connect to avoid race condition - var viewersChangedTask = TestHelpers.WaitForEventAsync( + var viewersChangedTask = TestHelpers.WaitForEvent( onComplete => presenterConn.ViewersChanged += (s, e) => { if (presenterConn.Viewers.Count > 0) @@ -293,7 +294,7 @@ public async Task ConnectionPropertiesChangedEventFires() var viewerClientId = viewer.HubClient.ClientId!; // Subscribe BEFORE triggering the property change to avoid race condition - var propertyChangedTask = TestHelpers.WaitForEventAsync( + var propertyChangedTask = TestHelpers.WaitForEvent( onComplete => viewerConn.ConnectionPropertiesChanged += (s, e) => { if (viewerConn.ConnectionProperties.InputBlockedViewerIds.Contains(viewerClientId)) @@ -322,7 +323,7 @@ public async Task IsClosedReflectsConnectionState() await Assert.That(viewerConn.IsClosed).IsFalse(); // Subscribe first, then disconnect - var closedTask = TestHelpers.WaitForEventAsync( + var closedTask = TestHelpers.WaitForEvent( onComplete => viewerConn.Closed += (s, e) => onComplete()); await viewerConn.DisconnectAsync(); @@ -349,7 +350,7 @@ public async Task ViewerMouseClickIsSentToPresenter() await viewerConn.RequiredViewerService.SendMouseUpAsync(MouseButton.Left, 0.5f, 0.5f); // Wait for both mouse events to be received - await TestHelpers.WaitForReceivedCallAsync(() => + await TestHelpers.WaitForReceivedCall(() => presenter.InputInjectionService.ReceivedCalls() .Count(c => c.GetMethodInfo().Name == "InjectMouseButton") >= 2); @@ -375,7 +376,7 @@ public async Task ViewerMouseWheelIsSentToPresenter() await viewerConn.RequiredViewerService.SendMouseWheelAsync(0f, 120f, 0.5f, 0.5f); // Wait for wheel event to be received - await TestHelpers.WaitForReceivedCallAsync(() => + await TestHelpers.WaitForReceivedCall(() => presenter.InputInjectionService.ReceivedCalls() .Any(c => c.GetMethodInfo().Name == "InjectMouseWheel")); @@ -403,7 +404,7 @@ public async Task InputBlockingPreventsInputInjection() // Wait for PRESENTER to receive server confirmation // (input blocking check happens on presenter, so we must wait for presenter's state) - var presenterPropertyTask = TestHelpers.WaitForEventAsync( + var presenterPropertyTask = TestHelpers.WaitForEvent( onComplete => presenterConn.ConnectionPropertiesChanged += (s, e) => { if (presenterConn.ConnectionProperties.InputBlockedViewerIds.Contains(viewerClientId)) @@ -448,7 +449,7 @@ public async Task MultipleKeyModifiersAreSent() await viewerConn.RequiredViewerService.SendKeyDownAsync(0x41, KeyModifiers.Control | KeyModifiers.Shift); // Wait for key event to be received - await TestHelpers.WaitForReceivedCallAsync(() => + await TestHelpers.WaitForReceivedCall(() => presenter.InputInjectionService.ReceivedCalls() .Any(c => c.GetMethodInfo().Name == "InjectKey")); @@ -473,7 +474,7 @@ public async Task GetMessagesReturnsHistory() var viewerConn = viewer.CurrentConnection!; // Wait for the last message to arrive - var lastMessageTask = TestHelpers.WaitForEventAsync( + var lastMessageTask = TestHelpers.WaitForEvent( onResult => presenterConn.Chat.MessageReceived += (s, msg) => { if (msg.Text == "Message 3") @@ -505,9 +506,9 @@ public async Task MultipleViewersReceiveSameChatMessage() var viewer1Conn = viewer1.CurrentConnection!; var viewer2Conn = viewer2.CurrentConnection!; - var msg1Task = TestHelpers.WaitForEventAsync( + var msg1Task = TestHelpers.WaitForEvent( onResult => viewer1Conn.Chat.MessageReceived += (s, msg) => onResult(msg)); - var msg2Task = TestHelpers.WaitForEventAsync( + var msg2Task = TestHelpers.WaitForEvent( onResult => viewer2Conn.Chat.MessageReceived += (s, msg) => onResult(msg)); await presenterConn.Chat.SendMessageAsync("Broadcast to all!"); @@ -529,7 +530,7 @@ public async Task ChatMessagesContainCorrectSenderInfo() var viewerConn = viewer.CurrentConnection!; // Test 1: Viewer sends to presenter - verify IsFromPresenter is false - var fromViewerTask = TestHelpers.WaitForEventAsync( + var fromViewerTask = TestHelpers.WaitForEvent( onResult => presenterConn.Chat.MessageReceived += (s, msg) => { if (msg.Text == "From viewer") @@ -542,7 +543,7 @@ public async Task ChatMessagesContainCorrectSenderInfo() await Assert.That(fromViewer.IsFromPresenter).IsFalse(); // Test 2: Presenter sends to viewer - verify IsFromPresenter is true - var fromPresenterTask = TestHelpers.WaitForEventAsync( + var fromPresenterTask = TestHelpers.WaitForEvent( onResult => viewerConn.Chat.MessageReceived += (s, msg) => { if (msg.Text == "From presenter") @@ -574,7 +575,7 @@ public async Task ViewerCanBeBlockedWhileOthersAreNot() var viewer1ClientId = viewer1.HubClient.ClientId!; // Wait for property to propagate to viewer2 - var propertyChangedTask = TestHelpers.WaitForEventAsync( + var propertyChangedTask = TestHelpers.WaitForEvent( onComplete => viewer2Conn.ConnectionPropertiesChanged += (s, e) => { if (viewer2Conn.ConnectionProperties.InputBlockedViewerIds.Contains(viewer1ClientId)) @@ -593,7 +594,7 @@ await presenterConn.UpdateConnectionPropertiesAndSend(props => await viewer2Conn.RequiredViewerService.SendMouseMoveAsync(0.5f, 0.5f); // Wait for mouse event to be received - await TestHelpers.WaitForReceivedCallAsync(() => + await TestHelpers.WaitForReceivedCall(() => presenter.InputInjectionService.ReceivedCalls() .Any(c => c.GetMethodInfo().Name == "InjectMouseMove")); @@ -617,9 +618,9 @@ public async Task BroadcastMessagesReachAllViewers() var viewer1Conn = viewer1.CurrentConnection!; var viewer2Conn = viewer2.CurrentConnection!; - var msg1Task = TestHelpers.WaitForEventAsync( + var msg1Task = TestHelpers.WaitForEvent( onResult => viewer1Conn.Chat.MessageReceived += (s, msg) => onResult(msg)); - var msg2Task = TestHelpers.WaitForEventAsync( + var msg2Task = TestHelpers.WaitForEvent( onResult => viewer2Conn.Chat.MessageReceived += (s, msg) => onResult(msg)); await presenterConn.Chat.SendMessageAsync("Broadcast message"); @@ -642,7 +643,7 @@ public async Task AvailableDisplaysChangedEventFires() var viewerConn = viewer.CurrentConnection!; // Subscribe first - var displaysChangedTask = TestHelpers.WaitForEventAsync( + var displaysChangedTask = TestHelpers.WaitForEvent( onComplete => viewerConn.RequiredViewerService.AvailableDisplaysChanged += (s, e) => { if (viewerConn.RequiredViewerService.AvailableDisplays.Count >= 2) @@ -691,7 +692,7 @@ public async Task FileTransferRequestIsSentToPresenter() _ = viewerConn.FileTransfers.SendFileAsync(tempFile); // Wait for the dialog to be called - await TestHelpers.WaitForReceivedCallAsync(() => + await TestHelpers.WaitForReceivedCall(() => presenter.DialogService.ReceivedCalls() .Any(c => c.GetMethodInfo().Name == "ShowFileTransferConfirmationAsync")); @@ -726,7 +727,7 @@ public async Task RejectedFileTransferDoesNotComplete() await File.WriteAllTextAsync(tempFile, "Test file content"); // Subscribe first, then start transfer - var failedTask = TestHelpers.WaitForEventAsync( + var failedTask = TestHelpers.WaitForEvent( onResult => viewerConn.FileTransfers.TransferFailed += (s, e) => onResult(true)); // Start the transfer @@ -867,7 +868,7 @@ public async Task GenerateNewPasswordChangesCredentials() var oldId = presenter.HubClient.ClientId; // Wait for new credentials event - var newCredentialsTask = TestHelpers.WaitForEventAsync( + var newCredentialsTask = TestHelpers.WaitForEvent( onResult => presenter.HubClient.CredentialsAssigned += (s, e) => { // Only trigger on a different password @@ -908,7 +909,7 @@ public async Task SetDisplayNameUpdatesDisplayNameOnServer() var viewerConn = viewer.CurrentConnection!; // Wait for the ViewersChanged event which carries the display name - await TestHelpers.WaitForEventAsync( + await TestHelpers.WaitForEvent( onComplete => viewerConn.ViewersChanged += (s, e) => { if (viewerConn.Presenter?.DisplayName == "CustomPresenterName") @@ -930,7 +931,7 @@ public async Task SetDisplayNameCanBeChangedAfterConnection() var viewerConn = viewer.CurrentConnection!; // Wait for ViewersChanged event with the updated name - var viewersChangedTask = TestHelpers.WaitForEventAsync( + var viewersChangedTask = TestHelpers.WaitForEvent( onComplete => viewerConn.ViewersChanged += (s, e) => { if (viewerConn.Presenter?.DisplayName == "UpdatedName") @@ -996,9 +997,8 @@ public async Task FileTransferSuccessfulTransferCompletesAndFiresEvent() await File.WriteAllTextAsync(tempFile, "Test file content for successful transfer"); // Subscribe to transfer completed event on viewer (sender) side - var completedTask = TestHelpers.WaitForEventAsync( - onResult => viewerConn.FileTransfers.TransferCompleted += (s, e) => onResult(e), - timeout: TimeSpan.FromSeconds(10)); + var completedTask = TestHelpers.WaitForEvent( + onResult => viewerConn.FileTransfers.TransferCompleted += (s, e) => onResult(e)); // Start the transfer _ = viewerConn.FileTransfers.SendFileAsync(tempFile); @@ -1037,9 +1037,8 @@ public async Task FileTransferPresenterCanSendFileToViewer() await File.WriteAllTextAsync(tempFile, "Test file from presenter to viewer"); // Subscribe to transfer completed event on presenter (sender) side - var completedTask = TestHelpers.WaitForEventAsync( - onResult => presenterConn.FileTransfers.TransferCompleted += (s, e) => onResult(e), - timeout: TimeSpan.FromSeconds(10)); + var completedTask = TestHelpers.WaitForEvent( + onResult => presenterConn.FileTransfers.TransferCompleted += (s, e) => onResult(e)); // Start the transfer to specific viewer _ = presenterConn.FileTransfers.SendFileToViewerAsync(tempFile, viewerClientId); @@ -1077,9 +1076,8 @@ public async Task FileTransferCancellationFiresTransferFailedEvent() await File.WriteAllTextAsync(tempFile, "Test file content for cancellation test"); // Subscribe to transfer failed event - var failedTask = TestHelpers.WaitForEventAsync( - onResult => viewerConn.FileTransfers.TransferFailed += (s, e) => onResult(e), - timeout: TimeSpan.FromSeconds(10)); + var failedTask = TestHelpers.WaitForEvent( + onResult => viewerConn.FileTransfers.TransferFailed += (s, e) => onResult(e)); // Start the transfer var operation = await viewerConn.FileTransfers.SendFileAsync(tempFile); diff --git a/tests/RemoteViewer.IntegrationTests/Fixtures/TestHelpers.cs b/tests/RemoteViewer.IntegrationTests/Fixtures/TestHelpers.cs deleted file mode 100644 index d320961..0000000 --- a/tests/RemoteViewer.IntegrationTests/Fixtures/TestHelpers.cs +++ /dev/null @@ -1,52 +0,0 @@ -namespace RemoteViewer.IntegrationTests.Fixtures; - -public static class TestHelpers -{ - public static async Task WaitForEventAsync( - Action> subscribe, - TimeSpan? timeout = null) - { - var tcs = new TaskCompletionSource(); - subscribe(value => tcs.TrySetResult(value)); - - using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(5)); - cts.Token.Register(() => tcs.TrySetCanceled()); - return await tcs.Task; - } - - public static async Task WaitForEventAsync( - Action subscribe, - TimeSpan? timeout = null) - { - var tcs = new TaskCompletionSource(); - subscribe(() => tcs.TrySetResult()); - - using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(5)); - cts.Token.Register(() => tcs.TrySetCanceled()); - await tcs.Task; - } - - public static async Task WaitForReceivedCallAsync( - Func checkReceived, - TimeSpan? timeout = null) - { - await WaitForConditionAsync(checkReceived, timeout, "WaitForReceivedCallAsync timed out"); - } - - public static async Task WaitForConditionAsync( - Func condition, - TimeSpan? timeout = null, - string? timeoutMessage = null) - { - var deadline = DateTime.UtcNow + (timeout ?? TimeSpan.FromSeconds(5)); - - while (DateTime.UtcNow < deadline) - { - if (condition()) - return; - await Task.Delay(50); - } - - throw new TimeoutException(timeoutMessage ?? "WaitForConditionAsync timed out"); - } -} diff --git a/tests/RemoteViewer.IntegrationTests/RemoteViewer.IntegrationTests.csproj b/tests/RemoteViewer.IntegrationTests/RemoteViewer.IntegrationTests.csproj index 1526d78..680c7a0 100644 --- a/tests/RemoteViewer.IntegrationTests/RemoteViewer.IntegrationTests.csproj +++ b/tests/RemoteViewer.IntegrationTests/RemoteViewer.IntegrationTests.csproj @@ -17,8 +17,7 @@ - - + diff --git a/tests/RemoteViewer.IntegrationTests/ViewerSendGrainTests.cs b/tests/RemoteViewer.IntegrationTests/ViewerSendGrainTests.cs new file mode 100644 index 0000000..0cfac70 --- /dev/null +++ b/tests/RemoteViewer.IntegrationTests/ViewerSendGrainTests.cs @@ -0,0 +1,133 @@ +using RemoteViewer.Client.Services.HubClient; +using RemoteViewer.TestFixtures.Fixtures; +using RemoteViewer.Shared.Protocol; +using System.Reflection; +using static RemoteViewer.TestFixtures.TestHelpers; + +namespace RemoteViewer.IntegrationTests; + +[NotInParallel] +public class ViewerSendGrainTests() +{ + private static readonly ulong[] s_expectedFrames_1_3 = [1UL, 3UL]; + + [ClassDataSource(Shared = SharedType.PerTestSession)] + public required ServerFixture Server { get; init; } + + [Test] + public async Task FramesCoalesceLatestWinsPerViewer() + { + await using var presenter = await this.Server.CreateClientAsync("Presenter"); + await using var viewer = await this.Server.CreateClientAsync("Viewer"); + await this.Server.CreateConnectionAsync(presenter, viewer); + + var presenterConn = presenter.CurrentConnection!; + var viewerConn = viewer.CurrentConnection!; + await InvokePresenterSelectDisplayAsync(presenterConn, viewer.HubClient.ClientId!, "DISPLAY1"); + + // Suppress auto-acks so we can control timing manually + viewer.HubClient.Options.SuppressAutoFrameAck = true; + + var receivedFrames = new List(); + + viewerConn.RequiredViewerService.FrameReady += (_, args) => + { + receivedFrames.Add(args.FrameNumber); + }; + + // Send 20 frames rapidly - frame 1 goes immediately, rest coalesce + for (var i = 1; i <= 20; i++) + { + await InvokeSendFrameAsync(presenterConn, (ulong)i); + } + + // Wait for frame 1 to arrive + await WaitUntil( + () => receivedFrames.Contains(1UL), + message: "Frame 1 was not received"); + + // Ack and wait for frame 20 (frames should coalesce to latest) + // Retry acking if frame 20 hasn't arrived yet (frames might still be in transit) + await WaitUntil( + async () => + { + await SendAckFrameAsync(viewer.HubClient, viewerConn.ConnectionId); + await Task.Delay(100); // Give time for frame to arrive + return receivedFrames.Contains(20UL); + }, + message: "Frame 20 was not received after acking"); + + // Verify the last received frame is 20 + await Assert.That(receivedFrames.Count).IsEqualTo(2); + await Assert.That(receivedFrames[^1]).IsEqualTo(20UL); + } + + [Test] + public async Task ViewerSendQueueDropOldestWhenBusy() + { + await using var presenter = await this.Server.CreateClientAsync("Presenter"); + await using var viewer = await this.Server.CreateClientAsync("Viewer"); + await this.Server.CreateConnectionAsync(presenter, viewer); + + var presenterConn = presenter.CurrentConnection!; + var viewerConn = viewer.CurrentConnection!; + var viewerService = viewerConn.RequiredViewerService; + await InvokePresenterSelectDisplayAsync(presenterConn, viewer.HubClient.ClientId!, "DISPLAY1"); + + // Suppress auto-acks so we can control timing manually + viewer.HubClient.Options.SuppressAutoFrameAck = true; + + var receivedFrames = new List(); + + viewerService.FrameReady += (_, args) => + { + receivedFrames.Add(args.FrameNumber); + }; + + // Send frame 1 - should be delivered immediately + await InvokeSendFrameAsync(presenterConn, 1); + + // Wait for frame 1 to arrive + await WaitUntil( + () => receivedFrames.Contains(1UL), + message: "Frame 1 was not received"); + + // DON'T ack frame 1 yet - send frames 2 and 3 + await InvokeSendFrameAsync(presenterConn, 2); + await InvokeSendFrameAsync(presenterConn, 3); + + // Ack and wait for frame 3 (frames should coalesce, dropping frame 2) + // Retry acking if frame 3 hasn't arrived yet (frames might still be in transit) + await WaitUntil( + async () => + { + await SendAckFrameAsync(viewer.HubClient, viewerConn.ConnectionId); + await Task.Delay(100); // Give time for frame to arrive + return receivedFrames.Contains(3UL); + }, + message: "Frame 3 was not received after acking"); + + // Verify we got exactly frames 1 and 3 (frame 2 was dropped) + await Assert.That(receivedFrames).IsEquivalentTo(s_expectedFrames_1_3); + } + + private static Task SendAckFrameAsync(ConnectionHubClient client, string connectionId) + { + var method = client.GetType().GetMethod("SendAckFrameAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + return (Task)method.Invoke(client, [connectionId])!; + } + + + private static Task InvokeSendFrameAsync(Connection connection, ulong frameNumber) + { + var method = connection.GetType().GetMethod("RemoteViewer.Client.Services.HubClient.IConnectionImpl.SendFrameAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + return (Task)method.Invoke(connection, ["DISPLAY1", frameNumber, FrameCodec.Jpeg90, Array.Empty()])!; + } + + private static Task InvokePresenterSelectDisplayAsync(Connection connection, string viewerClientId, string displayId) + { + var service = connection.PresenterService ?? throw new InvalidOperationException("Presenter service not available."); + var method = service.GetType().GetMethod("RemoteViewer.Client.Services.HubClient.IPresenterServiceImpl.SelectViewerDisplayAsync", BindingFlags.Instance | BindingFlags.NonPublic)!; + return (Task)method.Invoke(service, [viewerClientId, displayId, CancellationToken.None])!; + } +} diff --git a/tests/RemoteViewer.Server.Tests/Hubs/SignalRHubTests.cs b/tests/RemoteViewer.Server.Tests/Hubs/SignalRHubTests.cs index 1a6825d..9ce4f4c 100644 --- a/tests/RemoteViewer.Server.Tests/Hubs/SignalRHubTests.cs +++ b/tests/RemoteViewer.Server.Tests/Hubs/SignalRHubTests.cs @@ -1,7 +1,7 @@ using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.SignalR.Client; using Nerdbank.MessagePack.SignalR; -using RemoteViewer.IntegrationTests.Fixtures; +using RemoteViewer.TestFixtures.Fixtures; using RemoteViewer.Shared; namespace RemoteViewer.Server.Tests.Hubs; diff --git a/tests/RemoteViewer.Server.Tests/RemoteViewer.Server.Tests.csproj b/tests/RemoteViewer.Server.Tests/RemoteViewer.Server.Tests.csproj index de402fc..79947d0 100644 --- a/tests/RemoteViewer.Server.Tests/RemoteViewer.Server.Tests.csproj +++ b/tests/RemoteViewer.Server.Tests/RemoteViewer.Server.Tests.csproj @@ -16,7 +16,7 @@ - + diff --git a/tests/RemoteViewer.IntegrationTests/Fixtures/ClientFixture.cs b/tests/RemoteViewer.TestFixtures/Fixtures/ClientFixture.cs similarity index 94% rename from tests/RemoteViewer.IntegrationTests/Fixtures/ClientFixture.cs rename to tests/RemoteViewer.TestFixtures/Fixtures/ClientFixture.cs index 6015344..f842d2d 100644 --- a/tests/RemoteViewer.IntegrationTests/Fixtures/ClientFixture.cs +++ b/tests/RemoteViewer.TestFixtures/Fixtures/ClientFixture.cs @@ -18,12 +18,11 @@ using RemoteViewer.Client.Services.SessionRecorderIpc; using RemoteViewer.Client.Services.WinServiceIpc; using RemoteViewer.Client.Views.Presenter; -using RemoteViewer.IntegrationTests.Mocks; +using RemoteViewer.TestFixtures.Mocks; using RemoteViewer.Shared; using RemoteViewer.Shared.Protocol; -using TUnit.Core; -namespace RemoteViewer.IntegrationTests.Fixtures; +namespace RemoteViewer.TestFixtures.Fixtures; public class ClientFixture : IAsyncDisposable { @@ -177,12 +176,19 @@ private static ILocalInputMonitorService CreateLocalInputMonitorServiceMock() public async Task WaitForConnectionAsync(TimeSpan? timeout = null) { + var effectiveTimeout = timeout ?? TimeSpan.FromSeconds(30); var tcs = new TaskCompletionSource(); this.HubClient.ConnectionStarted += (s, e) => tcs.TrySetResult(e.Connection); - using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(30)); + using var cts = new CancellationTokenSource(effectiveTimeout); cts.Token.Register(() => tcs.TrySetCanceled()); - return await tcs.Task; + var connection = await tcs.Task; + + await TestHelpers.WaitUntil( + () => connection.Presenter is not null && connection.Viewers.Count > 0, + message: "Connection was started but Presenter and Viewer info didn't arrive."); + + return connection; } public async ValueTask DisposeAsync() diff --git a/tests/RemoteViewer.IntegrationTests/Fixtures/ServerFixture.cs b/tests/RemoteViewer.TestFixtures/Fixtures/ServerFixture.cs similarity index 94% rename from tests/RemoteViewer.IntegrationTests/Fixtures/ServerFixture.cs rename to tests/RemoteViewer.TestFixtures/Fixtures/ServerFixture.cs index 146555c..8854884 100644 --- a/tests/RemoteViewer.IntegrationTests/Fixtures/ServerFixture.cs +++ b/tests/RemoteViewer.TestFixtures/Fixtures/ServerFixture.cs @@ -1,10 +1,10 @@ -using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Mvc.Testing; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.Logging; using TUnit.Core.Interfaces; -namespace RemoteViewer.IntegrationTests.Fixtures; +namespace RemoteViewer.TestFixtures.Fixtures; public class ServerFixture : WebApplicationFactory, IAsyncInitializer { diff --git a/tests/RemoteViewer.IntegrationTests/Mocks/TestDispatcher.cs b/tests/RemoteViewer.TestFixtures/Mocks/TestDispatcher.cs similarity index 82% rename from tests/RemoteViewer.IntegrationTests/Mocks/TestDispatcher.cs rename to tests/RemoteViewer.TestFixtures/Mocks/TestDispatcher.cs index 13f6971..ad86475 100644 --- a/tests/RemoteViewer.IntegrationTests/Mocks/TestDispatcher.cs +++ b/tests/RemoteViewer.TestFixtures/Mocks/TestDispatcher.cs @@ -1,6 +1,6 @@ using RemoteViewer.Client.Services.Dispatching; -namespace RemoteViewer.IntegrationTests.Mocks; +namespace RemoteViewer.TestFixtures.Mocks; public class TestDispatcher : IDispatcher { diff --git a/tests/RemoteViewer.TestFixtures/RemoteViewer.TestFixtures.csproj b/tests/RemoteViewer.TestFixtures/RemoteViewer.TestFixtures.csproj new file mode 100644 index 0000000..dc74d97 --- /dev/null +++ b/tests/RemoteViewer.TestFixtures/RemoteViewer.TestFixtures.csproj @@ -0,0 +1,24 @@ + + + + net10.0-windows7.0 + false + + + + + + + + + + + + + + + + + + + diff --git a/tests/RemoteViewer.TestFixtures/TestHelpers.cs b/tests/RemoteViewer.TestFixtures/TestHelpers.cs new file mode 100644 index 0000000..6deb3dd --- /dev/null +++ b/tests/RemoteViewer.TestFixtures/TestHelpers.cs @@ -0,0 +1,62 @@ +namespace RemoteViewer.TestFixtures; + +public static class TestHelpers +{ + private static readonly TimeSpan s_defaultTimeout = TimeSpan.FromSeconds(30); + private static readonly TimeSpan s_defaultPollInterval = TimeSpan.FromMilliseconds(50); + + public static async Task WaitForEvent(Action> subscribe) + { + var tcs = new TaskCompletionSource(); + subscribe(value => tcs.TrySetResult(value)); + + using var cts = new CancellationTokenSource(s_defaultTimeout); + cts.Token.Register(() => tcs.TrySetCanceled()); + return await tcs.Task; + } + + public static async Task WaitForEvent(Action subscribe) + { + var tcs = new TaskCompletionSource(); + subscribe(() => tcs.TrySetResult()); + + using var cts = new CancellationTokenSource(s_defaultTimeout); + cts.Token.Register(() => tcs.TrySetCanceled()); + await tcs.Task; + } + + public static async Task WaitForReceivedCall(Func checkReceived) + { + await WaitUntil(checkReceived, message: "WaitForReceivedCall timed out"); + } + + public static async Task WaitUntil(Func condition, string? message = null) + { + var deadline = DateTime.UtcNow + s_defaultTimeout; + + while (DateTime.UtcNow < deadline) + { + if (condition()) + return; + + await Task.Delay(s_defaultPollInterval); + } + + throw new TimeoutException(message ?? $"Condition was not met within {s_defaultTimeout.TotalSeconds} seconds."); + } + + public static async Task WaitUntil(Func> condition, string? message = null) + { + var deadline = DateTime.UtcNow + s_defaultTimeout; + + while (DateTime.UtcNow < deadline) + { + if (await condition()) + return; + + await Task.Delay(s_defaultPollInterval); + } + + throw new TimeoutException(message ?? $"Condition was not met within {s_defaultTimeout.TotalSeconds} seconds."); + } +}