diff --git a/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt b/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt index 9c7ab93f780..40e637795aa 100644 --- a/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt +++ b/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt @@ -39,14 +39,14 @@ interface WebSocket { * Returns the size in bytes of all messages enqueued to be transmitted to the server. This * doesn't include framing overhead. It also doesn't include any bytes buffered by the operating * system or network intermediaries. This method returns 0 if no messages are waiting - * in the queue. If may return a nonzero value after the web socket has been canceled; this + * in the queue. It may return a nonzero value after the web socket has been canceled; this * indicates that enqueued messages were not transmitted. */ fun queueSize(): Long /** - * Attempts to enqueue {@code text} to be UTF-8 encoded and sent as a the data of a text (type - * {@code 0x1}) message. + * Attempts to enqueue {@code bytes} to be sent as the data of a binary (type {@code 0x2}) + * message. * *

This method returns true if the message was enqueued. Messages that would overflow the * outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of @@ -58,8 +58,8 @@ interface WebSocket { fun send(bytes: ByteString): Boolean /** - * Attempts to enqueue {@code bytes} to be sent as a the data of a binary (type {@code 0x2}) - * message. + * Attempts to enqueue {@code text} to be UTF-8 encoded and sent as the data of a text (type + * {@code 0x1}) message. * *

This method returns true if the message was enqueued. Messages that would overflow the * outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of diff --git a/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt b/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt index 223449f31bc..f6a84054f13 100644 --- a/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt +++ b/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt @@ -1,14 +1,22 @@ package misk.web +import okhttp3.WebSocket +import okio.ByteString import java.util.concurrent.LinkedBlockingDeque import java.util.concurrent.TimeUnit class FakeWebSocketListener : okhttp3.WebSocketListener() { val messages = LinkedBlockingDeque() + val binaryMessages = LinkedBlockingDeque() override fun onMessage(webSocket: okhttp3.WebSocket, text: String) { messages.add(text) } + override fun onMessage(webSocket: WebSocket, bytes: ByteString) { + binaryMessages.add(bytes) + } + fun takeMessage() = messages.pollFirst(2, TimeUnit.SECONDS) + fun takeBinaryMessage() = binaryMessages.pollFirst(2, TimeUnit.SECONDS) } diff --git a/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt b/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt index d40ce203d57..59a2c3206e0 100644 --- a/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt +++ b/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt @@ -25,11 +25,21 @@ internal class JettyWebSocket( val response: JettyServerUpgradeResponse ) : WebSocket { + internal sealed interface Message { + val byteCount: Long + data class Text(val text: String) : Message { + override val byteCount = text.utf8Size() + } + data class Binary(val data: ByteString) : Message { + override val byteCount = data.size.toLong() + } + } + /** Total size of messages enqueued and not yet transmitted by Jetty. */ private var outgoingQueueSize = 0L /** Messages to send when the Web Socket connects. */ - private var queue = ArrayDeque() + private var queue = ArrayDeque() /** Application's listener to notify of incoming messages from the client. */ private var listener: WebSocketListener? = null @@ -93,14 +103,22 @@ internal class JettyWebSocket( } override fun send(text: String): Boolean { - val byteCount = text.utf8Size() + return enqueue(Message.Text(text)) + } + + override fun send(bytes: ByteString): Boolean { + return enqueue(Message.Binary(bytes)) + } + + private fun enqueue(message: Message): Boolean { + val byteCount = message.byteCount if (outgoingQueueSize + byteCount > MAX_QUEUE_SIZE) { close(1001, null) return false } outgoingQueueSize += byteCount - queue.add(text) + queue.add(message) sendQueue() return true @@ -108,26 +126,32 @@ internal class JettyWebSocket( private fun sendQueue() { while (adapter.isConnected && queue.isNotEmpty()) { - val text = queue.pop() - val byteCount = text.utf8Size() - - adapter.remote.sendString( - text, - object : WriteCallback { - override fun writeSuccess() { - outgoingQueueSize -= byteCount - } - - override fun writeFailed(x: Throwable?) { - outgoingQueueSize -= byteCount - } + val message = queue.pop() + val byteCount = message.byteCount + val callback = object : WriteCallback { + override fun writeSuccess() { + outgoingQueueSize -= byteCount } - ) - } - } - override fun send(bytes: ByteString): Boolean { - TODO() + override fun writeFailed(x: Throwable?) { + outgoingQueueSize -= byteCount + } + } + when (message) { + is Message.Text -> { + adapter.remote.sendString( + message.text, + callback + ) + } + is Message.Binary -> { + adapter.remote.sendBytes( + message.data.asByteBuffer(), + callback + ) + } + } + } } override fun close(code: Int, reason: String?): Boolean { @@ -167,7 +191,7 @@ internal class JettyWebSocket( it.match(DispatchMechanism.WEBSOCKET, null, listOf(), httpCall.url) } - val bestAction = candidateActions.sorted().firstOrNull() ?: return null + val bestAction = candidateActions.minOrNull() ?: return null bestAction.action.scopeAndHandle(request.httpServletRequest, httpCall, bestAction.pathMatcher) return realWebSocket.adapter } diff --git a/misk/src/test/kotlin/misk/web/WebSocketsTest.kt b/misk/src/test/kotlin/misk/web/WebSocketsTest.kt index 8cffbf60616..d5379827ac7 100644 --- a/misk/src/test/kotlin/misk/web/WebSocketsTest.kt +++ b/misk/src/test/kotlin/misk/web/WebSocketsTest.kt @@ -1,5 +1,6 @@ package misk.web +import com.squareup.protos.test.parsing.Warehouse import misk.MiskTestingServiceModule import misk.inject.KAbstractModule import misk.logging.LogCollectorModule @@ -19,6 +20,7 @@ import org.junit.jupiter.api.Test import misk.logging.LogCollector import jakarta.inject.Inject import jakarta.inject.Singleton +import okio.ByteString @MiskTest(startService = true) internal class WebSocketsTest { @@ -50,6 +52,30 @@ internal class WebSocketsTest { ) } + @Test + fun binaryWebSocket() { + val client = OkHttpClient() + + val request = Request.Builder() + .url(jettyService.httpServerUrl.resolve("/echo")!!) + .build() + + val webSocket = client.newWebSocket(request, listener) + val warehouse = Warehouse.Builder() + .warehouse_token("WH_1") + .warehouse_id(42) + .build() + + webSocket.send(warehouse.encodeByteString()) + + val expected = Warehouse.Builder() + .warehouse_token("ACK WH_1") + .warehouse_id(43) + .build() + val actual = listener.takeBinaryMessage()?.let { Warehouse.ADAPTER.decode(it) } + assertEquals(actual, expected) + } + @Test fun loggingDisabledByEnv() { val client = OkHttpClient() @@ -87,6 +113,17 @@ class EchoWebSocket @Inject constructor() : WebAction { webSocket.send("ACK $text") } + override fun onMessage(webSocket: WebSocket, bytes: ByteString) { + val message = Warehouse.ADAPTER.decode(bytes) + webSocket.send( + message.newBuilder() + .warehouse_token("ACK ${message.warehouse_token}") + .warehouse_id(message.warehouse_id + 1) + .build() + .encodeByteString() + ) + } + override fun toString() = "EchoListener" } }