Set coroutine context from spring webflux WebFilter


How can I set the coroutine context from a Spring WebFlux WebFilter? Is it possible? I know I can use the Reactor context, but I'm not able to set the coroutine context.

MORE DETAILS:

I want to use MDCContext to propagate MDC values to SLF4J. For example, I would like to read MDC values from HTTP headers and have them automatically included in every log I write.

Currently, I:

  • set the Reactor context in a WebFilter
  • in every controller, read the values from the Reactor context and put them into an MDCContext (coroutine)

As you can see, this is not very convenient, because I have to add extra code to every controller.
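The manual workaround from the list above can at least be pulled into a single helper. This is a minimal sketch, assuming kotlinx-coroutines-reactor and kotlinx-coroutines-slf4j are on the classpath and that only String keys with String values should reach the MDC; `reactorEntriesAsMdcMap` and `withReactorMdc` are hypothetical names, not library API:

```kotlin
import kotlinx.coroutines.reactor.ReactorContext
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.MDC
import reactor.util.context.ContextView
import kotlin.coroutines.coroutineContext

// Hypothetical helper: keeps only the String -> String entries of a Reactor context,
// because the MDC can only hold string values.
fun reactorEntriesAsMdcMap(context: ContextView): Map<String, String> =
    context.stream().iterator().asSequence()
        .filter { (key, value) -> key is String && value is String }
        .associate { (key, value) -> key as String to value as String }

// Hypothetical helper: runs [block] with the current MDC plus the String entries
// of the coroutine's ReactorContext (if there is one) installed via MDCContext.
suspend fun <T> withReactorMdc(block: suspend () -> T): T {
    val reactorEntries = coroutineContext[ReactorContext]?.context
        ?.let { reactorEntriesAsMdcMap(it) }
        .orEmpty()
    val merged = (MDC.getCopyOfContextMap() ?: emptyMap()) + reactorEntries
    return withContext(MDCContext(merged)) { block() }
}
```

A controller would then wrap its body in `withReactorMdc { ... }`. It still has to be called in every controller, so it shortens the workaround rather than removing it.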

Is there a way to automatically transform the Reactor context into the coroutine context? I know I can do the opposite (coroutine context into Reactor context) with ContextInjector and ServiceLoader (see https://github.com/Kotlin/kotlinx.coroutines/issues/284#issuecomment-516270570), but there seems to be no such mechanism for the reverse conversion.


There are 2 answers

Answer by walrus03

Unfortunately, this is not possible at the moment. There is an open issue in the Spring Framework to fix this, which you can upvote: https://github.com/spring-projects/spring-framework/issues/26977

Answer by Numichi
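@Component
class AuthorizationFilter : WebFilter {
    override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
        // Put KEY1 -> VALUE1 into the Reactor context seen by everything downstream,
        // including suspend controller functions (via coroutineContext[ReactorContext])
        return chain.filter(exchange).contextWrite { ctx -> ctx.put(KEY1, VALUE1) }
    }
}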

See also: Using ReactiveSecurityContextHolder inside a Kotlin Flow
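Applied to the goal from the question (MDC values coming from HTTP headers), the same kind of filter can copy a request header into the Reactor context. This is only a sketch: the header name `X-Request-Id`, the context key `requestId` and the class name `RequestIdWebFilter` are hypothetical:

```kotlin
import org.springframework.stereotype.Component
import org.springframework.web.server.ServerWebExchange
import org.springframework.web.server.WebFilter
import org.springframework.web.server.WebFilterChain
import reactor.core.publisher.Mono

@Component
class RequestIdWebFilter : WebFilter {
    companion object {
        const val REQUEST_ID_HEADER = "X-Request-Id" // hypothetical header name
        const val REQUEST_ID_KEY = "requestId"       // hypothetical context/MDC key
    }

    override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
        // Read the header once, before the rest of the chain runs
        val requestId = exchange.request.headers.getFirst(REQUEST_ID_HEADER)
        return chain.filter(exchange).contextWrite { ctx ->
            // Only add the key when the header is present; otherwise keep the context unchanged
            if (requestId != null) ctx.put(REQUEST_ID_KEY, requestId) else ctx
        }
    }
}
```

Downstream, a suspend controller can read the value with `coroutineContext[ReactorContext]?.context`, as in the update below.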

UPDATE 3 (25.01.2022)

I have created a library that solves the MDC ThreadLocal problem in a reactive environment. It provides a special Map-based MDC class that travels in the Reactor context.

https://github.com/Numichi/reactive-logger

UPDATE 1

Reading and adding context in a Kotlin coroutine:

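import kotlinx.coroutines.reactor.ReactorContext
import kotlinx.coroutines.reactor.asCoroutineContext
import kotlinx.coroutines.withContext
import reactor.util.context.Context
import kotlin.coroutines.EmptyCoroutineContext
import kotlin.coroutines.coroutineContext

// Inside a suspend function called from a WebFlux handler:
val value1 = coroutineContext[ReactorContext]?.context?.get<String>(KEY1) // VALUE1

//--

// Replacing the ReactorContext with an empty one hides KEY1
withContext(Context.empty().asCoroutineContext()) {
     val x = coroutineContext[ReactorContext]?.context?.get<String>(KEY1) // NoSuchElementException
}

// Re-using the current ReactorContext keeps KEY1 visible
withContext(coroutineContext[ReactorContext] ?: EmptyCoroutineContext) {
     val x = coroutineContext[ReactorContext]?.context?.get<String>(KEY1) // Works
}

// Add a new key-value pair on top of the existing context
val newContext = (coroutineContext[ReactorContext]?.context ?: Context.empty()).put(KEY2, VALUE2)
withContext(newContext.asCoroutineContext()) {
     val x = coroutineContext[ReactorContext]?.context?.get<String>(KEY2) // Works
}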
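The "add a new key-value pair" pattern can be wrapped so that callers never touch `ReactorContext` directly. A sketch, assuming kotlinx-coroutines-reactor is available; `withReactorEntry` and `reactorValueOrNull` are hypothetical helper names, not part of any library:

```kotlin
import kotlinx.coroutines.reactor.ReactorContext
import kotlinx.coroutines.reactor.asCoroutineContext
import kotlinx.coroutines.withContext
import reactor.util.context.Context
import kotlin.coroutines.coroutineContext

// Hypothetical helper: runs [block] with key -> value added on top of the
// Reactor context the current coroutine already carries (or an empty one).
suspend fun <T> withReactorEntry(key: Any, value: Any, block: suspend () -> T): T {
    val current = coroutineContext[ReactorContext]?.context ?: Context.empty()
    return withContext(current.put(key, value).asCoroutineContext()) { block() }
}

// Hypothetical helper: returns the value stored under [key] in the current
// Reactor context, or null instead of throwing NoSuchElementException.
suspend fun reactorValueOrNull(key: Any): Any? {
    val ctx = coroutineContext[ReactorContext]?.context ?: return null
    return if (ctx.hasKey(key)) ctx.get<Any>(key) else null
}
```

Because `Context.put` returns a new context, the outer coroutine's Reactor context is left untouched once `withReactorEntry` returns.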

UPDATE 2 (25.12.2021)

I use Log4j2 with SLF4J, but I think this also works with other implementations (for example, Logback).

build.gradle.kts

configurations {
    // ...
    all {
        exclude("org.springframework.boot", "spring-boot-starter-logging")
    }
    // ...
}

// ...

dependencies {
    // ...
    implementation("org.springframework.boot:spring-boot-starter-log4j2:VERSION")
    implementation("org.jetbrains.kotlinx:kotlinx-coroutines-slf4j:VERSION")
    // ...
}

(Optional) If you use a WebFilter and write values with contextWrite in WebFlux, and you would like every Reactor context entry to be copied into MDCContext, use the code below. MDCContext will then contain all Reactor context entries at the beginning of the controller.

If you use an @ExceptionHandler, note that MDCContext drops all values you added with MDC.put("key", "value") once the controller returns, because execution has left the suspended scopes. MDC values behave like local variables scoped to a code block. So I recommend saving any values you need in the exception and restoring them in the handler from the Throwable instance.
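A minimal sketch of that recommendation, assuming kotlinx-coroutines-slf4j is available and that the exception handling code is (or calls) a suspend function; `MdcAwareException` and `withMdcFrom` are hypothetical names:

```kotlin
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.MDC

// Hypothetical exception that snapshots the MDC at the moment it is created,
// i.e. while the controller's MDCContext is still active.
class MdcAwareException(
    message: String,
    val mdcSnapshot: Map<String, String> = MDC.getCopyOfContextMap() ?: emptyMap(),
) : RuntimeException(message)

// Hypothetical helper for exception handling code: re-installs the saved MDC
// values (if the throwable carries any) for the duration of [block].
suspend fun <T> withMdcFrom(error: Throwable, block: suspend () -> T): T {
    val saved = (error as? MdcAwareException)?.mdcSnapshot ?: emptyMap()
    return withContext(MDCContext(saved)) { block() }
}
```

Inside the handler, the logging call would then be wrapped as `withMdcFrom(ex) { logger.error("Request failed", ex) }`.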

package your.project.package

import org.slf4j.MDC
import reactor.core.CoreSubscriber
import reactor.core.publisher.Hooks
import reactor.core.publisher.Operators
import reactor.util.context.Context
import java.util.stream.Collectors
import javax.annotation.PostConstruct
import javax.annotation.PreDestroy
import org.reactivestreams.Subscription
import org.springframework.context.annotation.Configuration

@Configuration
class MdcContextLifterConfiguration {
    companion object {
        val MDC_CONTEXT_REACTOR_KEY: String = MdcContextLifterConfiguration::class.java.name
    }

    @PostConstruct
    fun contextOperatorHook() {
        Hooks.onEachOperator(MDC_CONTEXT_REACTOR_KEY, Operators.lift { _, subscriber -> MdcContextLifter(subscriber) })
    }

    @PreDestroy
    fun cleanupHook() {
        Hooks.resetOnEachOperator(MDC_CONTEXT_REACTOR_KEY)
    }
}

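// Wraps each Reactor subscriber so that, before every onNext signal, the entries of the
// subscriber's Reactor context are copied into the thread-local MDC (see copyToMdc below).
class MdcContextLifter<T>(private val coreSubscriber: CoreSubscriber<T>) : CoreSubscriber<T> {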

    override fun onNext(t: T) {
        coreSubscriber.currentContext().copyToMdc()
        coreSubscriber.onNext(t)
    }

    override fun onSubscribe(subscription: Subscription) {
        coreSubscriber.onSubscribe(subscription)
    }

    override fun onComplete() {
        coreSubscriber.onComplete()
    }

    override fun onError(throwable: Throwable?) {
        coreSubscriber.onError(throwable)
    }

    override fun currentContext(): Context {
        return coreSubscriber.currentContext()
    }
}

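// Replaces the MDC with a String copy of every Reactor context entry, or clears it if the context is empty.
private fun Context.copyToMdc() {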
    if (!this.isEmpty) {
        val map: Map<String, String> = this.stream()
            .collect(Collectors.toMap({ e -> e.key.toString() }, { e -> e.value.toString() }))

        MDC.setContextMap(map)
    } else {
        MDC.clear()
    }
}

Now you can use MDCContext in the controller (or in any other class). Of course, you don't need to call LoggerFactory.getLogger(javaClass) every time; the logger can also be stored in a property.

import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.LoggerFactory

// ...

suspend fun info() {
    withContext(MDCContext()) {
        LoggerFactory.getLogger(javaClass).info("")
    }
}
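Keeping the logger in a property could look like this, with each suspend function only wrapping its body in `withContext(MDCContext())`. A sketch; `ReportService` and `generate` are hypothetical names:

```kotlin
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import org.slf4j.Logger
import org.slf4j.LoggerFactory

// Hypothetical service class; the logger is created once per class, not on every call.
class ReportService {
    private val logger: Logger = LoggerFactory.getLogger(ReportService::class.java)

    suspend fun generate(reportId: String): String = withContext(MDCContext()) {
        // MDCContext() captures the MDC of the calling thread, so values copied
        // there (e.g. by the MdcContextLifter hook) are visible to this log call
        logger.info("Generating report {}", reportId)
        "report-$reportId"
    }
}
```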

In log4j2.xml you can reference an MDC key directly. For example:

  • <PatternLayout pattern="%mdc{context_map_key}"/>
  • Or write your own output plugin.

Log4J Plugin

Add one more dependency with annotationProcessor:

dependencies {
    // ...
    annotationProcessor("org.apache.logging.log4j:log4j-core:VERSION")
    // ...
}

Write the plugin. Of course, this is a minimal example:

package your.project.package.log4j

import org.apache.logging.log4j.core.Core
import org.apache.logging.log4j.core.Layout
import org.apache.logging.log4j.core.LogEvent
import org.apache.logging.log4j.core.config.plugins.Plugin
import org.apache.logging.log4j.core.config.plugins.PluginFactory
import org.apache.logging.log4j.core.layout.AbstractStringLayout
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets

@Plugin(name = ExampleLog4JPlugin.PLUGIN_NAME, category = Core.CATEGORY_NAME, elementType = Layout.ELEMENT_TYPE)
class ExampleLog4JPlugin private constructor(charset: Charset) : AbstractStringLayout(charset) {
    companion object {
        const val PLUGIN_NAME = "ExampleLog4JPlugin"

        @JvmStatic
        @PluginFactory
        fun factory(): ExampleLog4JPlugin {
            return ExampleLog4JPlugin(StandardCharsets.UTF_8)
        }
    }

    override fun toSerializable(event: LogEvent): String {
        // event.contextData contains the MDC map (the values propagated by MDCContext)
        return "Whatever this method returns is what appears in the log."
    }
}
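The placeholder in `toSerializable` can be replaced with something that actually prints the MDC. A sketch of one possible format using only `LogEvent` getters; the output layout and the function name `formatWithMdc` are assumptions:

```kotlin
import org.apache.logging.log4j.core.LogEvent

// Hypothetical formatter that ExampleLog4JPlugin.toSerializable could delegate to.
// Renders: LEVEL loggerName [key1=value1, key2=value2] message
fun formatWithMdc(event: LogEvent): String {
    // event.contextData holds the MDC entries; sort them by key for stable output
    val mdc = event.contextData.toMap()
        .toSortedMap()
        .entries
        .joinToString(", ") { "${it.key}=${it.value}" }
    return "${event.level} ${event.loggerName} [$mdc] ${event.message.formattedMessage}${System.lineSeparator()}"
}
```

With it, the body of `toSerializable` reduces to `return formatWithMdc(event)`.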

And the log4j2.xml file, located at project/src/main/resources/log4j2.xml:

<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<Configuration packages="your.project.package.log4j">
    <Appenders>
        <Console name="stdout" target="SYSTEM_OUT">
            <ExampleLog4JPlugin/>
        </Console>
    </Appenders>
    <Loggers>
        <Root level="DEBUG">
            <AppenderRef ref="stdout"/>
        </Root>
    </Loggers>
</Configuration>