/**
 * Reads MY_CONTEXT_KEY, looks up account info via `process`, then calls bar1 and bar2 one
 * after the other and replies "AAA".
 *
 * gRPC's Context is thread-local, so values stored in it are only visible on the thread the
 * request Context is attached to. The key is therefore read up front, and each downstream
 * call is started with the request Context re-attached: those calls run inside Future
 * callbacks, and client interceptors (e.g. JwtClientInterceptor) read the Context when the
 * call starts.
 */
def ping(req: PingReq) = {
  val keyVal = MY_CONTEXT_KEY.get
  val requestContext = io.grpc.Context.current()

  // Evaluates `call` with the request's Context attached, restoring the previous one after.
  def inRequestContext[A](call: => A): A = {
    val previous = requestContext.attach()
    try call
    finally requestContext.detach(previous)
  }

  for {
    accountInfo <- process(keyVal)
    res1 <- inRequestContext(bar1.ping(req))
    res2 <- inRequestContext(bar2.ping(req))
  } yield {
    println("=====FOO: " + accountInfo)
    PingRes("AAA")
  }
}

/** Client interceptor that forwards the JWT held in CTX_AUTH_JWT as the AUTH_HEADER header. */
class JwtClientInterceptor extends ClientInterceptor {
  override def interceptCall[ReqT, RespT](
    methodDescriptor: MethodDescriptor[ReqT, RespT],
    callOptions: CallOptions,
    channel: Channel
  ): ClientCall[ReqT, RespT] =
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
      channel.newCall(methodDescriptor, callOptions)
    ) {
      override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        // Null when the call is started on a thread without the caller's Context attached.
        val jwt = CTX_AUTH_JWT.get
        if (jwt != null && jwt.nonEmpty) {
          // Reuse the value just checked instead of reading the Context a second time.
          headers.put(AUTH_HEADER, jwt)
        }
        super.start(responseListener, headers)
      }
    }
}
// NOTE(review): fragment as pasted. `...` below is a placeholder, not valid Scala; presumably
// bar1 is a stub built the same way as bar2 — TODO fill in.
val bar1 = ...
// Builds a plaintext channel to localhost:12333 with the JWT and trace-id client interceptors.
// NOTE(review): the last statement of this block is the `def` below, so `bar2` evaluates to
// Unit rather than a stub; presumably it was meant to end with something like
// `BarServiceGrpc.stub(channel2)` — confirm. someRpcMethod is also nested inside bar2's own
// initializer and references bar2, which will not compile as-is.
val bar2 = {
val channel2 = ManagedChannelBuilder
.forAddress("localhost", 12333)
.usePlaintext(true)
.intercept(new JwtClientInterceptor())
.intercept(new TraceIdClientInterceptor())
.build()
// Calls bar1.ping, then bar2.ping, and combines both results into a MethodResp.
// NOTE(review): `sub` is declared implicit but not referenced directly; presumably it feeds an
// implicit parameter elsewhere — confirm. bar2.ping is started from inside a Future callback,
// where the caller's gRPC Context is not attached, so Context-reading client interceptors may
// see null for that call.
def someRpcMethod(req: MethodReq) = {
implicit val sub = getJwtSubject
for {
res1 <- bar1.ping(req)
res2 <- bar2.ping(req)
} yield {
MethodResp(res1 + res2)
}
}
}
package io.dong.commons.grpc
import io.grpc._
/** Keys shared by the trace-id client and server interceptors. */
package object utils {
  /** gRPC Context key carrying the current request's trace id (thread-local, like all Context values). */
  val CTX_TRACE_ID: Context.Key[String] = Context.key[String]("traceId")

  /** ASCII metadata header ("trace-id") used to pass the trace id between services. */
  val TRACEID_HEADER: Metadata.Key[String] =
    Metadata.Key.of("trace-id", Metadata.ASCII_STRING_MARSHALLER)
}
-----
package io.dong.commons.grpc.utils
import io.grpc._
/**
 * Client interceptor that stamps each outgoing call with a TRACEID_HEADER header.
 *
 * The value comes from CTX_TRACE_ID in the current gRPC Context; when that is null or empty,
 * a new id made of "T" followed by a random long is used instead.
 */
class TraceIdClientInterceptor extends ClientInterceptor {
  val rand = new scala.util.Random

  override def interceptCall[ReqT, RespT](
    methodDescriptor: MethodDescriptor[ReqT, RespT],
    callOptions: CallOptions,
    channel: Channel
  ) = {
    val underlying = channel.newCall(methodDescriptor, callOptions)
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](underlying) {
      override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata) = {
        // Context is thread-local: this is null when the RPC is started on a thread that does
        // not have the caller's Context attached (for example inside a Future callback).
        val fromContext = CTX_TRACE_ID.get
        println("~~~~~~~~~ clientInterceptor: " + fromContext)
        val traceId = Option(fromContext)
          .filter(_.nonEmpty)
          .getOrElse("T" + rand.nextLong.toString)
        headers.put(TRACEID_HEADER, traceId)
        super.start(responseListener, headers)
      }
    }
  }
}
-----
package io.dong.commons.grpc.utils
import io.grpc._
import org.slf4s.Logging
/**
 * Server interceptor that reads the incoming TRACEID_HEADER (or mints "T" + a random long when
 * it is missing or empty) and exposes it as CTX_TRACE_ID in the Context used for the call.
 */
class TraceIdServerInterceptor extends ServerInterceptor with Logging {
  val rand = new scala.util.Random

  override def interceptCall[ReqT, RespT](
    call: ServerCall[ReqT, RespT],
    metadata: Metadata,
    next: ServerCallHandler[ReqT, RespT]
  ) = {
    val traceId = Option(metadata.get(TRACEID_HEADER))
      .filter(_.nonEmpty)
      .getOrElse("T" + rand.nextLong.toString)
    // Contexts.interceptCall attaches this Context while the call's listener callbacks run.
    Contexts.interceptCall(Context.current().withValue(CTX_TRACE_ID, traceId), call, metadata, next)
  }
}
-----
/** Foo service: forwards each ping to the Bar service, then replies "AAA". */
private[demo] class FooServiceImpl extends FooServiceGrpc.FooService {
  // Bar stub over a plaintext channel; TraceIdClientInterceptor adds a trace-id header per call.
  val bar = {
    val channel = ManagedChannelBuilder
      .forAddress("localhost", 12333)
      .usePlaintext(true)
      .intercept(new TraceIdClientInterceptor())
      .build()
    BarServiceGrpc.stub(channel)
  }

  /**
   * Forwards `req` to Bar, then logs this request's trace id and replies "AAA".
   *
   * gRPC's Context is thread-local and is attached (by TraceIdServerInterceptor via
   * Contexts.interceptCall) to the thread that invokes this method, not to Future pool threads.
   * So the trace id is read here, and the downstream call is started here too, so that
   * TraceIdClientInterceptor sees this Context and forwards the same trace id.
   */
  def ping(req: PingReq): Future[PingRes] = {
    val traceId = CTX_TRACE_ID.get
    bar.ping(req).map { _ =>
      println("=====traceId in Foo: " + traceId)
      PingRes("AAA")
    }
  }
}
/** Bar service: replies with the request's ping string reversed. */
class BarServiceImpl extends BarServiceGrpc.BarService {
  /**
   * Logs this request's trace id and replies with `req.ping` reversed.
   *
   * The trace id is read before creating the Future: gRPC's Context is thread-local, and the
   * Future body runs on an executor thread where CTX_TRACE_ID is not attached (reads null).
   */
  def ping(req: PingReq): Future[PingRes] = {
    val traceId = CTX_TRACE_ID.get
    Future {
      println("=====traceId in Bar: " + traceId)
      PingRes(req.ping.reverse)
    }
  }
}
-----
/** Hosts FooService and BarService, both behind the trace-id and logging interceptors, on port 12333. */
class Servers { self =>
  // None until start() has built and started the server.
  private[this] var server: Option[Server] = None

  /** Builds and starts the gRPC server and registers a JVM shutdown hook that stops it. */
  def start(): Unit = {
    val context = scala.concurrent.ExecutionContext.global
    val fooService = FooServiceGrpc.bindService(new FooServiceImpl, context)
    val barService = BarServiceGrpc.bindService(new BarServiceImpl, context)
    val port = 12333
    val traceIdInterceptor = new TraceIdServerInterceptor()
    val loggingInterceptor = new LoggingServerInterceptor()
    // ServerInterceptors.intercept invokes the last interceptor in the list first.
    server = Some(
      ServerBuilder
        .forPort(port)
        .addService(ServerInterceptors.intercept(fooService, traceIdInterceptor, loggingInterceptor))
        .addService(ServerInterceptors.intercept(barService, traceIdInterceptor, loggingInterceptor))
        .build()
        .start()
    )
    println("Server started, listening on " + port)
    Runtime.getRuntime.addShutdownHook(new Thread() {
      override def run(): Unit = {
        System.err.println("*** shutting down gRPC server since JVM is shutting down")
        self.stop()
        System.err.println("*** server shut down")
      }
    })
  }

  /** Initiates an orderly shutdown; no-op if start() was never called. */
  def stop(): Unit = server.foreach(_.shutdown())

  /** Blocks until the server terminates; returns immediately if start() was never called. */
  def blockUntilShutdown(): Unit = server.foreach(_.awaitTermination())
}
/** Entry point: starts the servers and blocks until the gRPC server terminates. */
object Server {
  // Explicit main instead of `extends App`, which relies on the deprecated DelayedInit.
  def main(args: Array[String]): Unit = {
    val s = new Servers
    s.start()
    s.blockUntilShutdown()
  }
}
/** Entry point: sends one blocking ping ("HELLO") to FooService on localhost:12333. */
object Client {
  // Explicit main instead of `extends App`, which relies on the deprecated DelayedInit.
  def main(args: Array[String]): Unit = {
    val channel = ManagedChannelBuilder
      .forAddress("localhost", 12333)
      .usePlaintext(true)
      .build()
    try {
      val foo = FooServiceGrpc.blockingStub(channel)
      foo.ping(PingReq("HELLO"))
    } finally {
      // Shut the channel down even when the ping fails, so its resources are released.
      channel.shutdown().awaitTermination(5, TimeUnit.SECONDS)
    }
  }
}