diff --git a/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt b/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt
index 9c7ab93f780..40e637795aa 100644
--- a/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt
+++ b/misk-actions/src/main/kotlin/misk/web/actions/WebSocket.kt
@@ -39,14 +39,14 @@ interface WebSocket {
* Returns the size in bytes of all messages enqueued to be transmitted to the server. This
* doesn't include framing overhead. It also doesn't include any bytes buffered by the operating
* system or network intermediaries. This method returns 0 if no messages are waiting
- * in the queue. If may return a nonzero value after the web socket has been canceled; this
+ * in the queue. It may return a nonzero value after the web socket has been canceled; this
* indicates that enqueued messages were not transmitted.
*/
fun queueSize(): Long
/**
- * Attempts to enqueue {@code text} to be UTF-8 encoded and sent as a the data of a text (type
- * {@code 0x1}) message.
+ * Attempts to enqueue {@code bytes} to be sent as the data of a binary (type {@code 0x2})
+ * message.
*
*
This method returns true if the message was enqueued. Messages that would overflow the
* outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of
@@ -58,8 +58,8 @@ interface WebSocket {
fun send(bytes: ByteString): Boolean
/**
- * Attempts to enqueue {@code bytes} to be sent as a the data of a binary (type {@code 0x2})
- * message.
+ * Attempts to enqueue {@code text} to be UTF-8 encoded and sent as the data of a text (type
+ * {@code 0x1}) message.
*
*
This method returns true if the message was enqueued. Messages that would overflow the
* outgoing message buffer will be rejected and trigger a {@linkplain #close graceful shutdown} of
diff --git a/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt b/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt
index 223449f31bc..f6a84054f13 100644
--- a/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt
+++ b/misk-testing/src/main/kotlin/misk/web/FakeWebSocketListener.kt
@@ -1,14 +1,22 @@
package misk.web
+import okhttp3.WebSocket
+import okio.ByteString
import java.util.concurrent.LinkedBlockingDeque
import java.util.concurrent.TimeUnit
class FakeWebSocketListener : okhttp3.WebSocketListener() {
val messages = LinkedBlockingDeque()
+ val binaryMessages = LinkedBlockingDeque()
override fun onMessage(webSocket: okhttp3.WebSocket, text: String) {
messages.add(text)
}
+ override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
+ binaryMessages.add(bytes)
+ }
+
fun takeMessage() = messages.pollFirst(2, TimeUnit.SECONDS)
+ fun takeBinaryMessage() = binaryMessages.pollFirst(2, TimeUnit.SECONDS)
}
diff --git a/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt b/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt
index d40ce203d57..59a2c3206e0 100644
--- a/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt
+++ b/misk/src/main/kotlin/misk/web/jetty/JettyWebSocket.kt
@@ -25,11 +25,21 @@ internal class JettyWebSocket(
val response: JettyServerUpgradeResponse
) : WebSocket {
+ internal sealed interface Message {
+ val byteCount: Long
+ data class Text(val text: String) : Message {
+ override val byteCount = text.utf8Size()
+ }
+ data class Binary(val data: ByteString) : Message {
+ override val byteCount = data.size.toLong()
+ }
+ }
+
/** Total size of messages enqueued and not yet transmitted by Jetty. */
private var outgoingQueueSize = 0L
/** Messages to send when the Web Socket connects. */
- private var queue = ArrayDeque()
+ private var queue = ArrayDeque()
/** Application's listener to notify of incoming messages from the client. */
private var listener: WebSocketListener? = null
@@ -93,14 +103,22 @@ internal class JettyWebSocket(
}
override fun send(text: String): Boolean {
- val byteCount = text.utf8Size()
+ return enqueue(Message.Text(text))
+ }
+
+ override fun send(bytes: ByteString): Boolean {
+ return enqueue(Message.Binary(bytes))
+ }
+
+ private fun enqueue(message: Message): Boolean {
+ val byteCount = message.byteCount
if (outgoingQueueSize + byteCount > MAX_QUEUE_SIZE) {
close(1001, null)
return false
}
outgoingQueueSize += byteCount
- queue.add(text)
+ queue.add(message)
sendQueue()
return true
@@ -108,26 +126,32 @@ internal class JettyWebSocket(
private fun sendQueue() {
while (adapter.isConnected && queue.isNotEmpty()) {
- val text = queue.pop()
- val byteCount = text.utf8Size()
-
- adapter.remote.sendString(
- text,
- object : WriteCallback {
- override fun writeSuccess() {
- outgoingQueueSize -= byteCount
- }
-
- override fun writeFailed(x: Throwable?) {
- outgoingQueueSize -= byteCount
- }
+ val message = queue.pop()
+ val byteCount = message.byteCount
+ val callback = object : WriteCallback {
+ override fun writeSuccess() {
+ outgoingQueueSize -= byteCount
}
- )
- }
- }
- override fun send(bytes: ByteString): Boolean {
- TODO()
+ override fun writeFailed(x: Throwable?) {
+ outgoingQueueSize -= byteCount
+ }
+ }
+ when (message) {
+ is Message.Text -> {
+ adapter.remote.sendString(
+ message.text,
+ callback
+ )
+ }
+ is Message.Binary -> {
+ adapter.remote.sendBytes(
+ message.data.asByteBuffer(),
+ callback
+ )
+ }
+ }
+ }
}
override fun close(code: Int, reason: String?): Boolean {
@@ -167,7 +191,7 @@ internal class JettyWebSocket(
it.match(DispatchMechanism.WEBSOCKET, null, listOf(), httpCall.url)
}
- val bestAction = candidateActions.sorted().firstOrNull() ?: return null
+ val bestAction = candidateActions.minOrNull() ?: return null
bestAction.action.scopeAndHandle(request.httpServletRequest, httpCall, bestAction.pathMatcher)
return realWebSocket.adapter
}
diff --git a/misk/src/test/kotlin/misk/web/WebSocketsTest.kt b/misk/src/test/kotlin/misk/web/WebSocketsTest.kt
index 8cffbf60616..d5379827ac7 100644
--- a/misk/src/test/kotlin/misk/web/WebSocketsTest.kt
+++ b/misk/src/test/kotlin/misk/web/WebSocketsTest.kt
@@ -1,5 +1,6 @@
package misk.web
+import com.squareup.protos.test.parsing.Warehouse
import misk.MiskTestingServiceModule
import misk.inject.KAbstractModule
import misk.logging.LogCollectorModule
@@ -19,6 +20,7 @@ import org.junit.jupiter.api.Test
import misk.logging.LogCollector
import jakarta.inject.Inject
import jakarta.inject.Singleton
+import okio.ByteString
@MiskTest(startService = true)
internal class WebSocketsTest {
@@ -50,6 +52,30 @@ internal class WebSocketsTest {
)
}
+ @Test
+ fun binaryWebSocket() {
+ val client = OkHttpClient()
+
+ val request = Request.Builder()
+ .url(jettyService.httpServerUrl.resolve("/echo")!!)
+ .build()
+
+ val webSocket = client.newWebSocket(request, listener)
+ val warehouse = Warehouse.Builder()
+ .warehouse_token("WH_1")
+ .warehouse_id(42)
+ .build()
+
+ webSocket.send(warehouse.encodeByteString())
+
+ val expected = Warehouse.Builder()
+ .warehouse_token("ACK WH_1")
+ .warehouse_id(43)
+ .build()
+ val actual = listener.takeBinaryMessage()?.let { Warehouse.ADAPTER.decode(it) }
+ assertEquals(actual, expected)
+ }
+
@Test
fun loggingDisabledByEnv() {
val client = OkHttpClient()
@@ -87,6 +113,17 @@ class EchoWebSocket @Inject constructor() : WebAction {
webSocket.send("ACK $text")
}
+ override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
+ val message = Warehouse.ADAPTER.decode(bytes)
+ webSocket.send(
+ message.newBuilder()
+ .warehouse_token("ACK ${message.warehouse_token}")
+ .warehouse_id(message.warehouse_id + 1)
+ .build()
+ .encodeByteString()
+ )
+ }
+
override fun toString() = "EchoListener"
}
}