1616 */
1717package org .apache .spark .sql .connect .client
1818
19- import java .util .UUID
19+ import java .nio .charset .StandardCharsets .UTF_8
20+ import java .util .{Base64 , UUID }
2021import java .util .concurrent .TimeUnit
2122
2223import scala .collection .mutable
2324import scala .jdk .CollectionConverters ._
2425
25- import io .grpc .{CallOptions , Channel , ClientCall , ClientInterceptor , MethodDescriptor , Server , Status , StatusRuntimeException }
26+ import io .grpc .{CallOptions , Channel , ClientCall , ClientInterceptor , Metadata , MethodDescriptor , Server , ServerCall , ServerCallHandler , ServerInterceptor , Status , StatusRuntimeException }
2627import io .grpc .netty .NettyServerBuilder
2728import io .grpc .stub .StreamObserver
2829import org .scalatest .concurrent .Eventually
@@ -42,12 +43,13 @@ class SparkConnectClientSuite extends ConnectFunSuite {
4243 private var service : DummySparkConnectService = _
4344 private var server : Server = _
4445
45- private def startDummyServer (port : Int ): Unit = {
46+ private def startDummyServer (port : Int , interceptors : Seq [ ServerInterceptor ] = Seq () ): Unit = {
4647 service = new DummySparkConnectService
47- server = NettyServerBuilder
48+ val serverBuilder = NettyServerBuilder
4849 .forPort(port)
4950 .addService(service)
50- .build()
51+ interceptors.foreach(serverBuilder.intercept)
52+ server = serverBuilder.build()
5153 server.start()
5254 }
5355
@@ -622,6 +624,72 @@ class SparkConnectClientSuite extends ConnectFunSuite {
622624 // The client should try to fetch the config only once.
623625 assert(service.getAndClearLatestConfigRequests().size == 1 )
624626 }
627+
628+ test(" SPARK-55243: Binary headers use the correct marshaller" ) {
629+ class HeadersInterceptor extends ServerInterceptor {
630+ var headers : Option [Metadata ] = None
631+
632+ override def interceptCall [ReqT , RespT ](
633+ call : ServerCall [ReqT , RespT ],
634+ headers : Metadata ,
635+ next : ServerCallHandler [ReqT , RespT ]): ServerCall .Listener [ReqT ] = {
636+ this .headers = Some (headers)
637+ next.startCall(call, headers)
638+ }
639+ }
640+
641+ def buildClientWithHeader (key : String , value : String ): SparkConnectClient = {
642+ SparkConnectClient
643+ .builder()
644+ .connectionString(s " sc://localhost: ${server.getPort}" )
645+ .option(key, value)
646+ .build()
647+ }
648+
649+ val headerInterceptor = new HeadersInterceptor ()
650+ startDummyServer(0 , Seq (headerInterceptor))
651+
652+ val keyName = " test-bin"
653+ val key = Metadata .Key .of(keyName, Metadata .BINARY_BYTE_MARSHALLER )
654+ val binaryData = " test-binary-data"
655+ val base64EncodedValue = Base64 .getEncoder.encodeToString(binaryData.getBytes(UTF_8 ))
656+
657+ val plan = buildPlan(" select * from range(10)" )
658+
659+ // Successfully set and use base64-encoded -bin key.
660+ client = buildClientWithHeader(keyName, base64EncodedValue)
661+ client.execute(plan)
662+
663+ Eventually .eventually(timeout(5 .seconds)) {
664+ assert(headerInterceptor.headers.exists(_.containsKey(key)))
665+ val bytes = headerInterceptor.headers.get.get(key)
666+ assert(new String (bytes, UTF_8 ) == binaryData)
667+ }
668+
669+ // Non base64-encoded -bin header throws IllegalArgumentException.
670+ client = buildClientWithHeader(keyName, binaryData)
671+
672+ assertThrows[IllegalArgumentException ] {
673+ client.execute(plan)
674+ }
675+
676+ // Non -bin headers keep using the ASCII marshaller.
677+ val asciiKeyName = " test"
678+ val asciiKey = Metadata .Key .of(asciiKeyName, Metadata .ASCII_STRING_MARSHALLER )
679+
680+ headerInterceptor.headers = None // Reset captured headers.
681+
682+ client = buildClientWithHeader(asciiKeyName, base64EncodedValue)
683+ client.execute(plan)
684+
685+ Eventually .eventually(timeout(5 .seconds)) {
686+ assert(headerInterceptor.headers.exists(_.containsKey(asciiKey)))
687+ val value = headerInterceptor.headers.get.get(asciiKey)
688+ assert(value == base64EncodedValue)
689+ // No BINARY_BYTE_MARSHALLER header.
690+ assert(! headerInterceptor.headers.exists(_.containsKey(key)))
691+ }
692+ }
625693}
626694
627695class DummySparkConnectService () extends SparkConnectServiceGrpc .SparkConnectServiceImplBase {
0 commit comments