Cannot access MY_CONTEXT_KEY.get from a Scala Future

43 views
Skip to first unread message

don...@gmail.com

unread,
Apr 5, 2017, 6:06:06 AM4/5/17
to grpc.io
gRPC Version:  "1.2.0"
Java Version: "1.8.0_65"

I'm using gRPC with the ScalaPB SBT plugin (https://github.com/scalapb/ScalaPB).

It seems that inside a Scala Future, `MY_CONTEXT_KEY.get` always returns null. This is probably because the context is still bound to the original thread. Is there a way I can propagate it to the other threads used by Scala futures?

Here is my test code:

  def ping(req: PingReq) = {
   
val keyVal = MY_CONTEXT_KEY.get
   
for {
      accountInfo
<- process(keyVal)
      res1
<- bar1.ping(req)
      res2
<- bar2.ping(req)
   
} yield {
      println
("=====FOO: " + accountInfo)
     
PingRes("AAA")
   
}
 
}

The above code works fine, but the following doesn't:

  /*
   * Failing variant from the question, left as posted so the surrounding text still applies.
   * - As pasted this is not valid Scala: a for-comprehension must begin with a generator
   *   (`x <- ...`), not a value definition (`x = ...`).
   * - `Future {MY_CONTEXT_KEY.get}` evaluates the lookup on an ExecutionContext thread.
   *   `Key.get` reads io.grpc.Context.current(), which is bound to the thread, and the gRPC
   *   handler's Context is not attached on that thread, so the lookup yields null (the
   *   Future(null) reported below). Reading the key before the for-comprehension, on the
   *   handler thread, avoids this.
   * - keyVal is a Future here, whereas the working variant passes the plain key value to
   *   process; NOTE(review): confirm which type process expects.
   * - NOTE(review): the line breaks inside statements (e.g. between `println` and its
   *   argument list) look like paste artifacts.
   */
  def ping(req: PingReq) = {
   
for {
      keyVal
= Future {MY_CONTEXT_KEY.get} // runs on an ExecutionContext thread: the handler's Context is not current there
      accountInfo
<- process(keyVal)
      res1
<- bar1.ping(req) // started from a Future callback, so client interceptors see no caller Context
      res2
<- bar2.ping(req)
   
} yield {
      println
("=====FOO: " + accountInfo)
     
PingRes("AAA")
   
}
 
}


In the above example, keyVal is always Future(null).

In addition, a client interceptor I wrote, JwtClientInterceptor, doesn't work either. It is defined as:

/**
 * Client interceptor that copies the value stored under `CTX_AUTH_JWT` in the current gRPC
 * Context into the outgoing `AUTH_HEADER` request header. The header is added only when that
 * value is non-null and non-empty.
 *
 * The key is read in `start`, on whatever thread starts the call. `Key.get` reads
 * `Context.current()`, which is bound to the thread. A call started from a Future callback
 * running on an ExecutionContext thread therefore sees no JWT unless the caller's Context is
 * attached around that call.
 */
class JwtClientInterceptor extends ClientInterceptor {
  override def interceptCall[ReqT, RespT](
      methodDescriptor: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      channel: Channel
  ): ClientCall[ReqT, RespT] =
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](
      channel.newCall(methodDescriptor, callOptions)
    ) {
      override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        // Single lookup: null (key unset in this Context) and empty values add no header.
        Option(CTX_AUTH_JWT.get).filter(_.nonEmpty).foreach { jwt =>
          headers.put(AUTH_HEADER, jwt)
        }
        super.start(responseListener, headers)
      }
    }
}
In the above example, `val jwt` is always null because a downstream RPC call is invoked from inside a Scala Future — specifically, from a for-comprehension like the following:




  /*
   * Excerpt from the question showing where JwtClientInterceptor runs. Not compilable as
   * pasted: `...` stands for bar1's elided definition, and the braces place someRpcMethod
   * inside bar2's initializer block (so bar2's block ends with a def rather than a stub).
   * In someRpcMethod the for-comprehension starts bar1.ping(req) immediately on the calling
   * thread, but bar2.ping(req) is started inside bar1's Future callback, on an
   * ExecutionContext thread. The interceptor's start() reads CTX_AUTH_JWT via
   * io.grpc.Context.current(), which is bound to the thread, so for the bar2 call it sees
   * null unless the caller's Context is attached around that call.
   */
  val bar1
= ... // placeholder in the post: bar1's definition was elided
  val bar2
= {
    val channel2
= ManagedChannelBuilder
     
.forAddress("localhost", 12333)
     
.usePlaintext(true)
     
.intercept(new JwtClientInterceptor())
     
.intercept(new TraceIdClientInterceptor())
     
.build()

 
def someRpcMethod(req: MethodReq) = {
   
implicit val sub = getJwtSubject // not referenced in the visible code
   
for {
      res1
<- bar1.ping(req) // started synchronously on the calling thread
      res2
<- bar2.ping(req) // started from bar1's Future callback: the caller's Context is not current here
   
} yield {
     
MethodResp(res1 + res2)
   
}
 
}
}





How can I access the same context from Scala Futures? I hope someone can help me out; this is now the main blocker for my project.

don...@gmail.com

unread,
Apr 5, 2017, 11:07:58 PM4/5/17
to grpc.io, don...@gmail.com
Below I posted a related, complete example in which I created TraceIdServerInterceptor and TraceIdClientInterceptor. TraceIdServerInterceptor creates a random ID, and I hope that ID flows through TraceIdClientInterceptor so that the whole chain of RPC invocations shares the same ID.

In several places in the code I added the comment "// NOTE: SOMEHOW THIS IS ALWAYS NULL" to draw your attention.

package io.dong.commons.grpc


import io.grpc._
package object utils {
  /** gRPC Context key under which TraceIdServerInterceptor stores the request's trace id. */
  val CTX_TRACE_ID: Context.Key[String] = Context.key[String]("traceId")

  /** Request header ("trace-id", ASCII string) carrying the trace id between services. */
  val TRACEID_HEADER: Metadata.Key[String] =
    Metadata.Key.of("trace-id", Metadata.ASCII_STRING_MARSHALLER)
}




-----




package io.dong.commons.grpc.utils


import io.grpc._


/**
 * Client interceptor that sends a trace id to the downstream service in the `TRACEID_HEADER`
 * request header.
 *
 * The id is taken from `CTX_TRACE_ID` in the Context that is current when the call is started.
 * When that value is null or empty, a fresh random id prefixed with "T" is generated instead.
 */
class TraceIdClientInterceptor extends ClientInterceptor {
  val rand = new scala.util.Random

  override def interceptCall[ReqT, RespT](
      methodDescriptor: MethodDescriptor[ReqT, RespT],
      callOptions: CallOptions,
      channel: Channel
  ) = {
    val underlying = channel.newCall(methodDescriptor, callOptions)
    new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](underlying) {
      override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
        val fromContext = CTX_TRACE_ID.get
        // NOTE: Key.get reads Context.current(), which is bound to the thread. This is null when
        // start() runs on a thread without the server handler's Context attached, e.g. when the
        // call is started from a Future callback on an ExecutionContext thread.
        println("~~~~~~~~~ clientInterceptor: " + fromContext)
        val outgoingId = Option(fromContext)
          .filter(_.nonEmpty)
          .getOrElse("T" + rand.nextLong.toString)
        headers.put(TRACEID_HEADER, outgoingId)
        super.start(responseListener, headers)
      }
    }
  }
}




-----




package io.dong.commons.grpc.utils


import io.grpc._
import org.slf4s.Logging


/**
 * Server interceptor that makes the caller's trace id available through `CTX_TRACE_ID`.
 *
 * The id comes from the `TRACEID_HEADER` request header. When the header is missing or
 * empty, a fresh random id prefixed with "T" is generated. The id is stored in a child of the
 * current Context, and `Contexts.interceptCall` makes that child current for `next` and, per
 * the grpc-java javadoc, for the returned listener's callbacks.
 */
class TraceIdServerInterceptor extends ServerInterceptor with Logging {
  val rand = new scala.util.Random

  override def interceptCall[ReqT, RespT](
      call: ServerCall[ReqT, RespT],
      metadata: Metadata,
      next: ServerCallHandler[ReqT, RespT]
  ) = {
    val resolvedId = Option(metadata.get(TRACEID_HEADER))
      .filter(_.nonEmpty)
      .getOrElse("T" + rand.nextLong.toString)
    Contexts.interceptCall(Context.current().withValue(CTX_TRACE_ID, resolvedId), call, metadata, next)
  }
}


-----






/**
 * Foo service: forwards each `ping` to the Bar service and replies with `PingRes("AAA")`.
 *
 * `CTX_TRACE_ID.get` reads gRPC's `Context.current()`, which is bound to the calling thread.
 * Code run inside `Future(...)` or a Future callback executes on an ExecutionContext thread
 * where the handler's Context is not attached. There the trace id reads as null, and a
 * downstream call started from such a callback makes TraceIdClientInterceptor generate a new
 * id instead of forwarding this one. Both the read and the downstream call therefore happen
 * synchronously in `ping`.
 */
private[demo] class FooServiceImpl extends FooServiceGrpc.FooService {
  /** Stub for the Bar service on localhost:12333; its calls pass through TraceIdClientInterceptor. */
  val bar = {
    val channel = ManagedChannelBuilder
      .forAddress("localhost", 12333)
      .usePlaintext(true)
      .intercept(new TraceIdClientInterceptor())
      .build()
    BarServiceGrpc.stub(channel)
  }

  def ping(req: PingReq) = {
    // NOTE: both statements run on the thread that invoked ping, where the Context set up by
    // TraceIdServerInterceptor is current (the synchronous read in BarServiceImpl is non-null
    // for the same reason). Starting bar.ping here lets TraceIdClientInterceptor see the id.
    val traceId = CTX_TRACE_ID.get
    val barResponse = bar.ping(req)
    barResponse.map { _ =>
      println("=====traceId in Foo: " + traceId)
      PingRes("AAA")
    }
  }
}


/**
 * Bar service: replies to `ping` with the request's `ping` string reversed, logging the trace id.
 *
 * The trace id is read synchronously in `ping`, on the calling thread where the handler's
 * Context is current. Read inside `Future(...)`, it would be evaluated on an ExecutionContext
 * thread where `Context.current()` is not the handler's Context, and would be null.
 */
class BarServiceImpl extends BarServiceGrpc.BarService {
  def ping(req: PingReq) = {
    // NOTE: read once, here; the value no longer depends on which thread runs later work.
    val traceId = CTX_TRACE_ID.get
    println("=====traceId in Bar: " + traceId)
    // The reply is already computed, so no ExecutionContext hop is needed.
    Future.successful(PingRes(req.ping.reverse))
  }
}






-----


/**
 * Owns the gRPC server that hosts FooServiceImpl and BarServiceImpl on port 12333.
 */
class Servers { self =>
  // The running server; None until start() has been called.
  private[this] var server: Option[Server] = None

  /**
   * Binds both services (their Futures run on the global ExecutionContext), wraps each with
   * TraceIdServerInterceptor and LoggingServerInterceptor, starts serving, and registers a JVM
   * shutdown hook that calls stop().
   */
  def start(): Unit = {
    val context = scala.concurrent.ExecutionContext.global
    val fooService = FooServiceGrpc.bindService(new FooServiceImpl, context)
    val barService = BarServiceGrpc.bindService(new BarServiceImpl, context)

    val port = 12333

    // Per the ServerInterceptors.intercept javadoc, the last interceptor listed has its
    // interceptCall invoked first.
    val traceIdInterceptor = new TraceIdServerInterceptor()
    val loggingInterceptor = new LoggingServerInterceptor()
    server = Some(
      ServerBuilder
        .forPort(port)
        .addService(ServerInterceptors.intercept(fooService, traceIdInterceptor, loggingInterceptor))
        .addService(ServerInterceptors.intercept(barService, traceIdInterceptor, loggingInterceptor))
        .build
        .start
    )

    println("Server started, listening on " + port)

    Runtime.getRuntime.addShutdownHook(new Thread() {
      override def run(): Unit = {
        System.err.println("*** shutting down gRPC server since JVM is shutting down")
        self.stop()
        System.err.println("*** server shut down")
      }
    })
  }

  /** Calls shutdown() on the server, if it was started; does not wait for termination. */
  def stop(): Unit = server.foreach(_.shutdown())

  /** Blocks the calling thread until the server terminates; returns at once if it was never started. */
  def blockUntilShutdown(): Unit = server.foreach(_.awaitTermination())
}


/** Server entry point: starts the Foo and Bar services and blocks until the server terminates. */
object Server extends App {
  val s = new Servers

  s.start()
  // Keep the main thread alive while the server runs.
  s.blockUntilShutdown()
}


/** Client entry point: sends one ping to the Foo service on localhost:12333, then closes the channel. */
object Client extends App {
  val channel = ManagedChannelBuilder
    .forAddress("localhost", 12333)
    .usePlaintext(true)
    .build()

  val foo = FooServiceGrpc.blockingStub(channel)

  // Shut the channel down (waiting up to 5 seconds) even if the RPC throws.
  try foo.ping(PingReq("HELLO"))
  finally channel.shutdown().awaitTermination(5, TimeUnit.SECONDS)
}

Reply all
Reply to author
Forward
Message has been deleted
0 new messages