Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ lazy val runtime = Project(id = "runtime", base = file("runtime"))
AutomaticModuleName.settings("pekko.grpc.runtime"),
ReflectiveCodeGen.generatedLanguages := Seq("Scala"),
ReflectiveCodeGen.extraGenerators := Seq("ScalaMarshallersCodeGenerator"),
PB.protocVersion := Dependencies.Versions.googleProtoc)
PB.protocVersion := Dependencies.Versions.googleProtoc,
Test / PB.targets += (scalapb.gen() -> (Test / sourceManaged).value))
.enablePlugins(org.apache.pekko.grpc.build.ReflectiveCodeGen)
.enablePlugins(ReproducibleBuildsPlugin)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,40 +32,77 @@ import scala.jdk.CollectionConverters._
final class ServerReflectionImpl private (fileDescriptors: Map[String, FileDescriptor], services: List[String])
extends ServerReflection {
import ServerReflectionImpl._
import ServerReflectionResponse.{ MessageResponse => Out }

private val protoBytesLocalCache: concurrent.Map[String, ByteString] =
new ConcurrentHashMap[String, ByteString]().asScala

def serverReflectionInfo(in: Source[ServerReflectionRequest, NotUsed]): Source[ServerReflectionResponse, NotUsed] = {
in.map(req => {
import ServerReflectionRequest.{ MessageRequest => In }
import ServerReflectionResponse.{ MessageResponse => Out }

val response = req.messageRequest match {
case In.Empty =>
Out.Empty
case In.FileByFilename(fileName) =>
val list = fileDescriptors.get(fileName).map(getProtoBytes).toList
Out.FileDescriptorResponse(FileDescriptorResponse(list))
case In.FileContainingSymbol(symbol) =>
val list = findFileDescForSymbol(symbol, fileDescriptors).map(getProtoBytes).toList
Out.FileDescriptorResponse(FileDescriptorResponse(list))
case In.FileContainingExtension(ExtensionRequest(container, number, _)) =>
val list = findFileDescForExtension(container, number, fileDescriptors).map(getProtoBytes).toList
Out.FileDescriptorResponse(FileDescriptorResponse(list))
case In.AllExtensionNumbersOfType(container) =>
val list =
findExtensionNumbersForContainingType(
container,
fileDescriptors) // TODO should we throw a NOT_FOUND if we don't know the container type at all?
Out.AllExtensionNumbersResponse(ExtensionNumberResponse(container, list))
case In.ListServices(_) =>
val list = services.map(s => ServiceResponse(s))
Out.ListServicesResponse(ListServiceResponse(list))
// The server reflection spec requires sending transitive dependencies, but allows (and encourages) to only send
// transitive dependencies that haven't yet been sent on this stream. So, we track this with a stateful map.
in.statefulMap(() => Set.empty[String])(
(alreadySent, req) => {

import ServerReflectionRequest.{ MessageRequest => In }

val (newAlreadySent, response) = req.messageRequest match {
case In.Empty =>
(alreadySent, Out.Empty)
case In.FileByFilename(fileName) =>
toFileDescriptorResponse(fileDescriptors.get(fileName), alreadySent)
case In.FileContainingSymbol(symbol) =>
toFileDescriptorResponse(findFileDescForSymbol(symbol, fileDescriptors), alreadySent)
case In.FileContainingExtension(ExtensionRequest(container, number, _)) =>
toFileDescriptorResponse(findFileDescForExtension(container, number, fileDescriptors), alreadySent)
case In.AllExtensionNumbersOfType(container) =>
val list =
findExtensionNumbersForContainingType(
container,
fileDescriptors) // TODO should we throw a NOT_FOUND if we don't know the container type at all?
(alreadySent, Out.AllExtensionNumbersResponse(ExtensionNumberResponse(container, list)))
case In.ListServices(_) =>
val list = services.map(s => ServiceResponse(s))
(alreadySent, Out.ListServicesResponse(ListServiceResponse(list)))
}
// TODO Validate assumptions here
(newAlreadySent, ServerReflectionResponse(req.host, Some(req), response))
},
_ => None)
}

private def toFileDescriptorResponse(
fileDescriptor: Option[FileDescriptor],
alreadySent: Set[String]): (Set[String], Out.FileDescriptorResponse) = {
fileDescriptor match {
case None =>
(alreadySent, Out.FileDescriptorResponse(FileDescriptorResponse()))
case Some(file) =>
val (newAlreadySent, files) = withTransitiveDeps(alreadySent, file)
(newAlreadySent, Out.FileDescriptorResponse(FileDescriptorResponse(files.map(getProtoBytes))))
}
}

private def withTransitiveDeps(
alreadySent: Set[String],
file: FileDescriptor): (Set[String], List[FileDescriptor]) = {
@annotation.tailrec
def iterate(
sent: Set[String],
results: List[FileDescriptor],
toAdd: List[FileDescriptor]): (Set[String], List[FileDescriptor]) = {
toAdd match {
case Nil => (sent, results)
case _ =>
// Need to compute the new set of files sent before working out which dependencies to send, to ensure
// we don't send any dependencies that are being sent in this iteration
val nowSent = sent ++ toAdd.map(_.getName)
val depsOfToAdd =
toAdd.flatMap(_.getDependencies.asScala).distinct.filterNot(dep => nowSent.contains(dep.getName))
iterate(nowSent, toAdd ::: results, depsOfToAdd)
}
// TODO Validate assumptions here
ServerReflectionResponse(req.host, Some(req), response)
})
}

iterate(alreadySent, Nil, List(file))
}

private def getProtoBytes(fileDescriptor: FileDescriptor): ByteString =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

package org.apache.pekko.grpc.internal;

import "org/apache/pekko/grpc/internal/reflection_test_2.proto";
import "org/apache/pekko/grpc/internal/reflection_test_3.proto";

message MyMessage1 {
MyMessage2 field1 = 1;
MyMessage3 field2 = 2;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

package org.apache.pekko.grpc.internal;

import "org/apache/pekko/grpc/internal/reflection_test_3.proto";

message MyMessage2 {
MyMessage3 field1 = 2;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
syntax = "proto3";

package org.apache.pekko.grpc.internal;

import "org/apache/pekko/grpc/internal/reflection_test_4.proto";

message MyMessage3 {
MyMessage4 field1 = 1;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
syntax = "proto3";

package org.apache.pekko.grpc.internal;

message MyMessage4 {
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,27 @@ package org.apache.pekko.grpc.internal

import org.apache.pekko
import pekko.actor.ActorSystem
import pekko.grpc.internal.reflection_test_1.ReflectionTest1Proto
import pekko.grpc.internal.reflection_test_2.ReflectionTest2Proto
import pekko.grpc.internal.reflection_test_3.ReflectionTest3Proto
import pekko.grpc.internal.reflection_test_4.ReflectionTest4Proto
import pekko.stream.scaladsl.{ Sink, Source }
import pekko.testkit.TestKit
import com.google.protobuf.descriptor.FileDescriptorProto
import io.grpc.reflection.v1.reflection.ServerReflectionRequest.MessageRequest
import io.grpc.reflection.v1.reflection.{ ServerReflection, ServerReflectionRequest }
import io.grpc.reflection.v1.reflection.{ ServerReflection, ServerReflectionRequest, ServerReflectionResponse }
import org.scalatest.BeforeAndAfterAll
import org.scalatest.OptionValues
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike

class ServerReflectionImplSpec
extends TestKit(ActorSystem())
extends TestKit(ActorSystem("ServerReflectionImplSpec"))
with AnyWordSpecLike
with Matchers
with ScalaFutures
with BeforeAndAfterAll
with OptionValues {
import ServerReflectionImpl._
"The Server Reflection implementation utilities" should {
Expand All @@ -44,14 +51,22 @@ class ServerReflectionImplSpec
}

"The Server Reflection implementation" should {
val serverReflection =
ServerReflectionImpl(
Seq(
ServerReflection.descriptor,
ReflectionTest1Proto.javaDescriptor,
ReflectionTest2Proto.javaDescriptor,
ReflectionTest3Proto.javaDescriptor,
ReflectionTest4Proto.javaDescriptor),
List.empty[String])

"retrieve server reflection info" in {
val serverReflectionRequest = ServerReflectionRequest(messageRequest =
MessageRequest.FileByFilename("grpc/reflection/v1/reflection.proto"))

val serverReflectionResponse = ServerReflectionImpl(Seq(ServerReflection.descriptor), List.empty[String])
.serverReflectionInfo(Source.single(serverReflectionRequest))
.runWith(Sink.head)
.futureValue
val serverReflectionResponse =
serverReflection.serverReflectionInfo(Source.single(serverReflectionRequest)).runWith(Sink.head).futureValue

serverReflectionResponse.messageResponse.listServicesResponse should be(empty)

Expand All @@ -63,13 +78,69 @@ class ServerReflectionImplSpec
val serverReflectionRequest =
ServerReflectionRequest(messageRequest = MessageRequest.FileByFilename("grpc/reflection/v1/unknown.proto"))

val serverReflectionResponse = ServerReflectionImpl(Seq(ServerReflection.descriptor), List.empty[String])
.serverReflectionInfo(Source.single(serverReflectionRequest))
.runWith(Sink.head)
.futureValue
val serverReflectionResponse =
serverReflection.serverReflectionInfo(Source.single(serverReflectionRequest)).runWith(Sink.head).futureValue

serverReflectionResponse.messageResponse.listServicesResponse should be(empty)
serverReflectionResponse.messageResponse.fileDescriptorResponse.value.fileDescriptorProto should be(empty)
}

"return transitive dependencies" in {
val serverReflectionRequest = ServerReflectionRequest(messageRequest =
MessageRequest.FileByFilename("org/apache/pekko/grpc/internal/reflection_test_1.proto"))

val serverReflectionResponse =
serverReflection.serverReflectionInfo(Source.single(serverReflectionRequest)).runWith(Sink.head).futureValue

val protos = decodeFileResponseToNames(serverReflectionResponse)
protos should have size 4
(protos should contain).allOf(
"org/apache/pekko/grpc/internal/reflection_test_1.proto",
"org/apache/pekko/grpc/internal/reflection_test_2.proto",
"org/apache/pekko/grpc/internal/reflection_test_3.proto",
"org/apache/pekko/grpc/internal/reflection_test_4.proto")
}

"not return transitive dependencies already sent" in {
val req1 = ServerReflectionRequest(messageRequest =
MessageRequest.FileByFilename("org/apache/pekko/grpc/internal/reflection_test_4.proto"))
val req2 = ServerReflectionRequest(messageRequest =
MessageRequest.FileByFilename("org/apache/pekko/grpc/internal/reflection_test_1.proto"))
val req3 = ServerReflectionRequest(messageRequest =
MessageRequest.FileByFilename("org/apache/pekko/grpc/internal/reflection_test_2.proto"))

val responses =
serverReflection.serverReflectionInfo(Source(List(req1, req2, req3))).runWith(Sink.seq).futureValue

(responses should have).length(3)

val protos1 = decodeFileResponseToNames(responses.head)
protos1 should have size 1
protos1.head shouldBe "org/apache/pekko/grpc/internal/reflection_test_4.proto"

val protos2 = decodeFileResponseToNames(responses(1))
// all except 4, because 4 has already been sent
protos2 should have size 3
(protos2 should contain).allOf(
"org/apache/pekko/grpc/internal/reflection_test_1.proto",
"org/apache/pekko/grpc/internal/reflection_test_2.proto",
"org/apache/pekko/grpc/internal/reflection_test_3.proto")

val protos3 = decodeFileResponseToNames(responses(2))
// should still include 2, because 2 was explicitly requested, but should not include anything else
// because everything has already been sent
protos3 should have size 1
protos3.head shouldBe "org/apache/pekko/grpc/internal/reflection_test_2.proto"

}

}

private def decodeFileResponseToNames(response: ServerReflectionResponse): Seq[String] =
response.messageResponse.fileDescriptorResponse.value.fileDescriptorProto.map(bs =>
FileDescriptorProto.parseFrom(bs.newCodedInput()).name.getOrElse(""))

override protected def afterAll(): Unit = {
TestKit.shutdownActorSystem(system)
}
}
Loading