How to use code that relies on ThreadLocal with Kotlin coroutines


Some JVM frameworks use ThreadLocal to store the call context of an application, like the SLF4J MDC, transaction managers, security managers, and others.

However, Kotlin coroutines are dispatched on different threads, so how can this be made to work?

(The question is inspired by a GitHub issue.)


There are 2 answers

Answer by Roman Elizarov (accepted answer)

The coroutine analog of ThreadLocal is CoroutineContext.
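For data that your own code controls, you can skip the ThreadLocal entirely and keep the value in a context element. Here is a minimal sketch, assuming kotlinx.coroutines is on the classpath; RequestId, currentRequestId and requestIdsSeen are hypothetical names used only for illustration:

```kotlin
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.*

// Hypothetical context element: the coroutine-native counterpart of a ThreadLocal value
class RequestId(val value: String) : AbstractCoroutineContextElement(Key) {
    companion object Key : CoroutineContext.Key<RequestId>
}

// Reads the value from the current coroutine's context instead of from the current thread
suspend fun currentRequestId(): String? = currentCoroutineContext()[RequestId]?.value

fun requestIdsSeen(): List<String?> = runBlocking(RequestId("req-42")) {
    val seen = mutableListOf<String?>()
    seen += currentRequestId()                // read on the runBlocking thread
    withContext(Dispatchers.Default) {
        seen += currentRequestId()            // inherited by the child, even on another thread
    }
    seen += currentRequestId()                // still there after switching back
    seen
}

fun main() {
    println(requestIdsSeen())
}
```

The element travels with the coroutine rather than with the thread, so no interception is needed; the extra work is only required for libraries that insist on reading a ThreadLocal.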

To interoperate with ThreadLocal-using libraries you need to implement a custom ContinuationInterceptor that supports framework-specific thread-locals.

Here is an example. Let us assume that we use some framework that relies on a specific ThreadLocal to store some application-specific data (MyData in this example):

val myThreadLocal = ThreadLocal<MyData>()

To use it with coroutines, you'll need to implement a context that keeps the current value of MyData and puts it into the corresponding ThreadLocal every time the coroutine is resumed on a thread. The code should look like this:

// Updated for Kotlin 1.3+: Continuation now has a single resumeWith(Result<T>) method
// instead of the separate resume/resumeWithException members of experimental coroutines.
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.Continuation
import kotlin.coroutines.ContinuationInterceptor
import kotlin.coroutines.CoroutineContext

class MyContext(
    private var myData: MyData,
    private val dispatcher: ContinuationInterceptor
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(Wrapper(continuation))

    inner class Wrapper<T>(private val continuation: Continuation<T>) : Continuation<T> {
        override val context: CoroutineContext get() = continuation.context

        override fun resumeWith(result: Result<T>) {
            try {
                // install the coroutine's value on the thread that is about to run it
                myThreadLocal.set(myData)
                continuation.resumeWith(result)
            } finally {
                // remember any change the coroutine made through the ThreadLocal
                myData = myThreadLocal.get()
            }
        }
    }
}

To use it in your coroutines, you wrap the dispatcher that you want to use with MyContext and give it the initial value of your data. This value will be put into the thread-local on the thread where the coroutine is resumed.

launch(MyContext(MyData(), Dispatchers.Default)) { // Dispatchers.Default replaces the old CommonPool
    // do something...
}

The implementation above also tracks any changes made to the thread-local and stores them in this context, so multiple resumptions of the coroutine can share "thread-local" data via the context.
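A self-contained, runnable version of this idea, as a sketch assuming Kotlin 1.3+ and a current kotlinx.coroutines: MyData (with a user field) and myDataSeen are hypothetical, and Dispatchers.Default stands in for the old CommonPool. It checks that a value written through the ThreadLocal before a suspension is visible again after the coroutine resumes, possibly on another thread:

```kotlin
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.Continuation
import kotlin.coroutines.ContinuationInterceptor
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.*

// Hypothetical application data kept by a ThreadLocal-based framework
data class MyData(val user: String)

val myThreadLocal = ThreadLocal<MyData>()

class MyContext(
    private var myData: MyData,
    private val dispatcher: ContinuationInterceptor
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(Wrapper(continuation))

    inner class Wrapper<T>(private val continuation: Continuation<T>) : Continuation<T> {
        override val context: CoroutineContext get() = continuation.context

        override fun resumeWith(result: Result<T>) {
            try {
                myThreadLocal.set(myData)         // install before the coroutine code runs
                continuation.resumeWith(result)
            } finally {
                myData = myThreadLocal.get()      // save changes when the coroutine suspends
            }
        }
    }
}

fun myDataSeen(): List<String> = runBlocking(MyContext(MyData("alice"), Dispatchers.Default)) {
    val seen = mutableListOf<String>()
    seen += myThreadLocal.get().user              // installed on the first resumption
    myThreadLocal.set(MyData("bob"))              // changed through the framework's ThreadLocal
    delay(10)                                     // may resume on a different Default thread
    seen += myThreadLocal.get().user              // restored by the wrapper
    seen
}

fun main() {
    println(myDataSeen())
}
```

Note that the wrapper never clears the thread-local, so the last value stays on the pool thread after the coroutine leaves it; saving the previous value before set(...) and putting it back in the finally block would avoid that.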

UPDATE: Starting with kotlinx.coroutines version 0.25.0, there is direct support for representing Java ThreadLocal instances as coroutine context elements (the ThreadLocal.asContextElement extension). See this documentation for details. There is also out-of-the-box support for the SLF4J MDC via the kotlinx-coroutines-slf4j integration module (MDCContext).
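A minimal sketch of that built-in support, assuming a current kotlinx.coroutines version; myThreadLocal and threadLocalSeen are hypothetical names. asContextElement(value) installs the value on whichever thread runs the coroutine and puts the thread's previous value back when the coroutine leaves it:

```kotlin
import kotlinx.coroutines.*

val myThreadLocal = ThreadLocal<String?>()

fun threadLocalSeen(): List<String?> = runBlocking {
    val seen = mutableListOf<String?>()
    myThreadLocal.set("main")
    launch(Dispatchers.Default + myThreadLocal.asContextElement(value = "launch")) {
        seen += myThreadLocal.get()   // installed on the Default thread running this coroutine
        yield()                       // suspend and resume, possibly on another thread
        seen += myThreadLocal.get()   // installed again on resumption
    }.join()
    seen += myThreadLocal.get()       // the runBlocking thread keeps its own value
    seen
}

fun main() {
    println(threadLocalSeen())
}
```

Unlike MyContext, this element does not capture writes made with myThreadLocal.set(...) inside the coroutine; to change the value for a block of code, use withContext(myThreadLocal.asContextElement(newValue)) { ... }.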

Answer by Alex

Although this question is quite old, I would like to add another possible approach to Roman's answer: CopyableThreadContextElement. Maybe it will be helpful for somebody else.

// Snippet from the KDoc of CopyableThreadContextElement in the kotlinx.coroutines sources.
// TraceData and traceThreadLocal are not defined there; the two declarations below are
// hypothetical stand-ins so that the snippet compiles.
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CopyableThreadContextElement

data class TraceData(val traceId: String)               // hypothetical
val traceThreadLocal = ThreadLocal<TraceData?>()         // hypothetical

class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
    companion object Key : CoroutineContext.Key<TraceContextElement>

    override val key: CoroutineContext.Key<TraceContextElement> = Key

    override fun updateThreadContext(context: CoroutineContext): TraceData? {
        val oldState = traceThreadLocal.get()
        traceThreadLocal.set(traceData)
        return oldState
    }

    override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) {
        traceThreadLocal.set(oldState)
    }

    override fun copyForChild(): TraceContextElement {
        // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
        // ThreadLocal writes between resumption of the parent coroutine and the launch of the
        // child coroutine visible to the child.
        return TraceContextElement(traceThreadLocal.get()?.copy())
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        // Merge operation defines how to handle situations when both
        // the parent coroutine has an element in the context and
        // an element with the same key was also
        // explicitly passed to the child coroutine.
        // If merging does not require special behavior,
        // the copy of the element can be returned.
        return TraceContextElement(traceThreadLocal.get()?.copy())
    }
}

Note that the copyForChild method lets you propagate the thread-local data from the parent coroutine's latest resumption phase into the context of the child coroutine (this is what "Copyable" in CopyableThreadContextElement refers to). This works because copyForChild is invoked on the parent coroutine's thread, during the resumption phase in which the child coroutine is created.
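To see this copyForChild behavior in isolation, here is a small sketch modelled on TraceContextElement, with a String instead of TraceData. It assumes a kotlinx.coroutines version that provides CopyableThreadContextElement (1.6 or later, as far as I know); requestThreadLocal, RequestElement and childSees are hypothetical names. The parent writes to the thread-local after it has resumed, and the child, launched on another dispatcher, still sees that write:

```kotlin
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.*

// Hypothetical thread-local used only for this sketch
val requestThreadLocal: ThreadLocal<String?> = ThreadLocal()

// Hypothetical element, modelled on TraceContextElement
class RequestElement(private val value: String?) : CopyableThreadContextElement<String?> {
    companion object Key : CoroutineContext.Key<RequestElement>

    override val key: CoroutineContext.Key<RequestElement> = Key

    override fun updateThreadContext(context: CoroutineContext): String? {
        val oldState = requestThreadLocal.get()
        requestThreadLocal.set(value)
        return oldState
    }

    override fun restoreThreadContext(context: CoroutineContext, oldState: String?) {
        requestThreadLocal.set(oldState)
    }

    // Called on the parent's thread when a child is created, so it sees the parent's latest write
    override fun copyForChild(): RequestElement = RequestElement(requestThreadLocal.get())

    // If a child is given its own RequestElement explicitly, let it win (a simple choice for this sketch)
    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext =
        overwritingElement
}

fun childSees(): String? = runBlocking(Dispatchers.Default + RequestElement(null)) {
    requestThreadLocal.set("changed-in-parent")   // written after the parent coroutine resumed
    var seen: String? = null
    launch(Dispatchers.IO) {
        seen = requestThreadLocal.get()           // installed from the child's copied element
    }.join()
    seen
}

fun main() {
    println(childSees())
}
```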

Simply adding a TraceContextElement to the root coroutine's context is enough for it to be propagated to all child coroutines:

runBlocking(Dispatchers.IO + TraceContextElement(someTraceDataInstance)) { ... }

By contrast, with the ContinuationInterceptor approach you have to wrap the dispatcher again in a child coroutine's builder whenever you specify a different dispatcher for that child. For example:

import kotlinx.coroutines.*

fun main() {
    runBlocking(WrappedDispatcher(Dispatchers.IO)) {
        delay(100)
        println("It is wrapped!")
        delay(100)
        println("It is also wrapped!")
        // NOTE: we deliberately do not wrap with WrappedDispatcher here.
        // Passing a dispatcher replaces the ContinuationInterceptor element of the context,
        // so our custom interceptor is swapped out for Dispatchers.Default.
        withContext(Dispatchers.Default) {
            delay(100)
            println("It is nested coroutine, and it isn't wrapped!")
            delay(100)
            println("It is nested coroutine, and it isn't wrapped!")
        }
        delay(100)
        println("It is also wrapped!")
    }
}

Here WrappedDispatcher is a wrapper implementing the ContinuationInterceptor interface:

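import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.Continuation
import kotlin.coroutines.ContinuationInterceptor

class WrappedDispatcher(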
    private val dispatcher: ContinuationInterceptor
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {

    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    private class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {

        override fun resumeWith(result: Result<T>) {
            println("------WRAPPED START-----")
            base.resumeWith(result)
            println("------WRAPPED END-------")
        }
    }
}

output:

------WRAPPED START-----
------WRAPPED END-------
------WRAPPED START-----
It is wrapped!
------WRAPPED END-------
------WRAPPED START-----
It is also wrapped!
------WRAPPED END-------
It is nested coroutine, and it isn't wrapped!
It is nested coroutine, and it isn't wrapped!
------WRAPPED START-----
------WRAPPED END-------
------WRAPPED START-----
It is also wrapped!
------WRAPPED END-------

As you can see, our wrapper wasn't applied to the child (nested) coroutine, because passing another dispatcher as a parameter replaced our ContinuationInterceptor. This is error-prone: it is easy to forget to wrap a child coroutine's dispatcher.


As a side note, if you decide to use the ContinuationInterceptor approach, consider adding an extension like this:

fun ContinuationInterceptor.withMyProjectWrappers() = WrappedDispatcher(this)

which wraps your dispatcher with all the wrappers your project needs. It can easily be extended to take specific wrappers (beans) from an IoC container such as Spring.
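A sketch of how such an extension helps in the nested case, assuming the printing wrapper is replaced by a hypothetical insideWrapper flag so the effect can be checked (nestedWrapping is also a hypothetical demo name): a withContext block started with a plain Dispatchers.Default does not run inside the wrapper, while one started with Dispatchers.Default.withMyProjectWrappers() does:

```kotlin
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.Continuation
import kotlin.coroutines.ContinuationInterceptor
import kotlinx.coroutines.*

// Hypothetical flag: true only while a wrapped continuation is running on this thread
val insideWrapper: ThreadLocal<Boolean> = ThreadLocal.withInitial { false }

class WrappedDispatcher(
    private val dispatcher: ContinuationInterceptor
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    private class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {
        override fun resumeWith(result: Result<T>) {
            insideWrapper.set(true)
            try {
                base.resumeWith(result)
            } finally {
                insideWrapper.set(false)
            }
        }
    }
}

fun ContinuationInterceptor.withMyProjectWrappers() = WrappedDispatcher(this)

// First: nested block with a plain dispatcher; second: nested block with a re-wrapped dispatcher
fun nestedWrapping(): Pair<Boolean, Boolean> =
    runBlocking(Dispatchers.Default.withMyProjectWrappers()) {
        val plain = withContext(Dispatchers.Default) { insideWrapper.get() }
        val rewrapped = withContext(Dispatchers.Default.withMyProjectWrappers()) { insideWrapper.get() }
        plain to rewrapped
    }

fun main() {
    println(nestedWrapping())
}
```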


As an extra example, here is a CopyableThreadContextElement in which thread-local changes are preserved across all resumption phases.

Executors.newFixedThreadPool(..).asCoroutineDispatcher() is used to make it clear that different threads can run the coroutine between resumption phases.

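import java.util.concurrent.Executors
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.*

val counterThreadLocal: ThreadLocal<Int> = ThreadLocal.withInitial { 1 }

fun showCounter() {
    println("-------------------------------------------------")
    println("Thread: ${Thread.currentThread().name}\n Counter value: ${counterThreadLocal.get()}")
}

fun main() {
    // Two separate pools, so the parent (pool-1) and the nested coroutine (pool-2) run on different threads.
    // Both are closed at the end; otherwise their non-daemon threads keep the JVM running.
    val parentDispatcher = Executors.newFixedThreadPool(10).asCoroutineDispatcher()
    val nestedDispatcher = Executors.newFixedThreadPool(10).asCoroutineDispatcher()
    try {
        runBlocking(parentDispatcher + CounterPropagator(1)) {
            showCounter()
            delay(100)
            showCounter()
            counterThreadLocal.set(2)
            delay(100)
            showCounter()
            counterThreadLocal.set(3)
            val nested = async(nestedDispatcher) {
                println("-----------NESTED START---------")
                showCounter()
                delay(100)
                counterThreadLocal.set(4)
                showCounter()
                println("------------NESTED END-----------")
            }
            nested.await()
            showCounter()
            println("---------------END------------")
        }
    } finally {
        parentDispatcher.close()
        nestedDispatcher.close()
    }
}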

class CounterPropagator(private var counterFromParentCoroutine: Int) : CopyableThreadContextElement<Int> {
    companion object Key : CoroutineContext.Key<CounterPropagator>

    override val key: CoroutineContext.Key<CounterPropagator> = Key

    override fun updateThreadContext(context: CoroutineContext): Int {
        // initialize the thread-local on each resumption
        counterThreadLocal.set(counterFromParentCoroutine)
        return 0
    }

    override fun restoreThreadContext(context: CoroutineContext, oldState: Int) {
        // propagate thread-local changes between resumption phases of the same coroutine
        // (the thread's previous value is not restored; oldState is unused here)
        counterFromParentCoroutine = counterThreadLocal.get()
    }

    override fun copyForChild(): CounterPropagator {
        // propagate thread local changes to children
        return CounterPropagator(counterThreadLocal.get())
    }

    override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
        return CounterPropagator(counterThreadLocal.get())
    }
}

output:

-------------------------------------------------
Thread: pool-1-thread-1
 Counter value: 1
-------------------------------------------------
Thread: pool-1-thread-2
 Counter value: 1
-------------------------------------------------
Thread: pool-1-thread-3
 Counter value: 2
-----------NESTED START---------
-------------------------------------------------
Thread: pool-2-thread-1
 Counter value: 3
-------------------------------------------------
Thread: pool-2-thread-2
 Counter value: 4
------------NESTED END-----------
-------------------------------------------------
Thread: pool-1-thread-4
 Counter value: 3
---------------END------------

You can achieve similar behavior with a ContinuationInterceptor (but, as mentioned above, don't forget to re-wrap the dispatchers of child (nested) coroutines in their coroutine builders):

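import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.Continuation
import kotlin.coroutines.ContinuationInterceptor

val counterThreadLocal: ThreadLocal<Int> = ThreadLocal()

// Note: every coroutine that inherits this interceptor instance (e.g. a child started without a new
// dispatcher) shares the same savedCounter, so concurrent children can overwrite each other's value.
class WrappedDispatcher(
    private val dispatcher: ContinuationInterceptor,
    // default: the value on the current thread at the moment the wrapper is created
    private var savedCounter: Int = counterThreadLocal.get() ?: 0
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {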
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    private inner class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {

        override fun resumeWith(result: Result<T>) {
            counterThreadLocal.set(savedCounter)
            try {
                base.resumeWith(result)
            } finally {
                savedCounter = counterThreadLocal.get()
            }
        }
    }
}
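A usage sketch for this wrapper, assuming the parent runs on Dispatchers.Default and the child on a re-wrapped Dispatchers.IO; counterSeen is a hypothetical demo function. Because savedCounter defaults to the thread-local's value when the WrappedDispatcher is constructed, wrapping the child's dispatcher inside the parent hands the parent's current value to the child, while the child's later changes stay in its own wrapper:

```kotlin
import kotlin.coroutines.AbstractCoroutineContextElement
import kotlin.coroutines.Continuation
import kotlin.coroutines.ContinuationInterceptor
import kotlinx.coroutines.*

val counterThreadLocal: ThreadLocal<Int> = ThreadLocal()

class WrappedDispatcher(
    private val dispatcher: ContinuationInterceptor,
    private var savedCounter: Int = counterThreadLocal.get() ?: 0
) : AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
    override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> =
        dispatcher.interceptContinuation(ContinuationWrapper(continuation))

    private inner class ContinuationWrapper<T>(val base: Continuation<T>) : Continuation<T> by base {
        override fun resumeWith(result: Result<T>) {
            counterThreadLocal.set(savedCounter)
            try {
                base.resumeWith(result)
            } finally {
                savedCounter = counterThreadLocal.get()
            }
        }
    }
}

fun counterSeen(): List<Int> = runBlocking(WrappedDispatcher(Dispatchers.Default, savedCounter = 1)) {
    val seen = mutableListOf<Int>()
    seen += counterThreadLocal.get()          // installed by the wrapper
    counterThreadLocal.set(2)
    delay(10)                                 // may resume on another thread
    seen += counterThreadLocal.get()          // saved on suspension, restored on resumption
    // Re-wrap: the child's savedCounter default is read here, on the parent's thread
    withContext(WrappedDispatcher(Dispatchers.IO)) {
        seen += counterThreadLocal.get()      // starts from the parent's current value
        counterThreadLocal.set(3)             // kept in the child's own wrapper
    }
    seen += counterThreadLocal.get()          // the child's change does not flow back
    seen
}

fun main() {
    println(counterSeen())
}
```

This mirrors the CounterPropagator output above, where the nested coroutine's change to 4 is not visible to the parent afterwards.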