How can I make asserts on outbound HTTP requests?


Here's an example. It's a trimmed and simplified version of org.springframework.cloud.gateway.filter.WebClientHttpRoutingFilter:

package com.example.gatewaydemo.misc;

import java.net.URI;
import java.util.stream.Stream;

import reactor.core.publisher.Mono;

import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.http.HttpMethod;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.server.ServerWebExchange;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;

/**
 * A {@link GlobalFilter} that actually makes an asynchronous call to the proxied server.
 */
public class CallingGlobalFilter implements GlobalFilter {
    private final WebClient webClient;

    public CallingGlobalFilter(WebClient webClient) {
        this.webClient = webClient;
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        URI requestUrl = exchange.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR);

        ServerHttpRequest request = exchange.getRequest();

        HttpMethod method = request.getMethod();

        WebClient.RequestBodySpec bodySpec = this.webClient.method(method)
                .uri(requestUrl)
                .headers(h -> h.addAll(request.getHeaders()));

        WebClient.RequestHeadersSpec<?> headersSpec = requiresBody(method) ?
                bodySpec.body(BodyInserters.fromDataBuffers(request.getBody())) :
                bodySpec;

        return headersSpec.exchangeToMono(Mono::just)
                .flatMap(res -> {
                    ServerHttpResponse response = exchange.getResponse();
                    response.getHeaders().putAll(res.headers().asHttpHeaders());
                    response.setStatusCode(res.statusCode());
                    exchange.getAttributes().put(CLIENT_RESPONSE_ATTR, res);
                    return chain.filter(exchange);
                });
    }

    private boolean requiresBody(HttpMethod method) {
        return Stream.of(HttpMethod.POST, HttpMethod.PUT, HttpMethod.PATCH)
                .anyMatch(m -> method.matches(m.toString()));
    }
}
<!-- if you want my specific example to compile, include these dependencies -->
        <dependency>
            <groupId>org.springframework.cloud</groupId>
            <artifactId>spring-cloud-starter-gateway</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

It's a filter that uses an injected WebClient to make a request to a proxied server and then wraps the result in a MonoFlatMap. I need to make sure the outbound request is correct. For example, the expected behavior includes ignoring the body if the request's method is GET. I need to write a test for that.

I can't use an ArgumentCaptor on the exchange passed to chain.filter(..) and assert on it, because this filter passes the original exchange along. That is, the captured exchange still wraps the incoming request with its body, so the assertion will fail.

How do I actually make asserts on outbound HTTP requests in such cases?

Answer by cyberbrain

There is a lot going on in the filter method, but it is possible to write a complete unit test for it. It's going to be a bit long, though.

Below is an example unit test covering just the GET case you asked about.

Because the decision between sending a body or not lives in a private method, that method cannot be tested on its own: you have to cover the different cases through tests of the public filter method. You don't have to run the full test for each case, though. Just mock, assert, verify and, where needed, capture enough to reach the private method with different inputs and tell the resulting behaviours apart. This might even be a good case for a parameterized test.
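A minimal sketch of that parameterized idea, with one deliberate swap of technique: instead of mocking the WebClient call chain, it builds a real WebClient whose transport is a stub ExchangeFunction that records the outbound ClientRequest, and it uses spring-test's MockServerWebExchange for the inbound side. It assumes Spring Framework 6 / Spring Cloud Gateway 4 (as pulled in by the dependencies in the question), the question's CallingGlobalFilter in the same package, and that a request built without a body keeps ClientRequest's default BodyInserters.empty() inserter. The class name, header and backend.example URL are hypothetical.

```java
package com.example.gatewaydemo.misc;

import java.net.URI;
import java.util.concurrent.atomic.AtomicReference;

import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

import static org.junit.jupiter.api.Assertions.*;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;

// Hypothetical test class name; a sketch, not the answer's mock-based approach.
class CallingGlobalFilterOutboundTest {

    @ParameterizedTest
    @CsvSource({"GET,false", "DELETE,false", "POST,true", "PUT,true", "PATCH,true"})
    void outboundRequestCarriesBodyOnlyForBodyMethods(String methodName, boolean expectBody) {
        HttpMethod method = HttpMethod.valueOf(methodName);
        URI target = URI.create("http://backend.example/resource"); // made-up backend URL

        // Stub transport: records the outbound ClientRequest instead of sending it over the network.
        AtomicReference<ClientRequest> sent = new AtomicReference<>();
        WebClient webClient = WebClient.builder()
                .exchangeFunction(request -> {
                    sent.set(request);
                    return Mono.just(ClientResponse.create(HttpStatus.OK).build());
                })
                .build();

        // Real inbound exchange; every method gets a body so we can see whether it is forwarded.
        MockServerWebExchange exchange = MockServerWebExchange.from(
                MockServerHttpRequest.method(method, "/resource").header("X-Test", "1").body("payload"));
        exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, target);
        GatewayFilterChain chain = ex -> Mono.empty();

        new CallingGlobalFilter(webClient).filter(exchange, chain).block();

        ClientRequest request = sent.get();
        assertNotNull(request);
        assertEquals(method, request.method());
        assertEquals(target, request.url());
        assertEquals("1", request.headers().getFirst("X-Test"));
        // Assumption: without body(..) the request keeps the shared BodyInserters.empty() instance.
        if (expectBody) {
            assertNotSame(BodyInserters.empty(), request.body());
        } else {
            assertSame(BodyInserters.empty(), request.body());
        }
    }
}
```

Because the stub returns a real ClientResponse, the flatMap part of the filter also runs, so the same test could additionally assert on exchange.getResponse().getStatusCode().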

To test the lambdas and method references, I used argument captors to fetch the actual Function or Consumer arguments, then invoked the captured objects as additional tests with extra mocks, verifications and assertions.

Feel free to split this huge test into several smaller ones that don't capture everything in one go but check the lambdas and method references separately, although you will still have to mock a lot of stuff before those calls happen.
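A minimal sketch of one such smaller test, assuming Spring Framework 6 and Mockito 5 as pulled in by spring-boot-starter-test: it captures only the headers Consumer and checks that no body is attached for GET. Two shortcuts keep it short: the stubbed exchangeToMono returns a real Mono.empty(), so the flatMap lambda never runs and Mono needs no mocking, and the inbound exchange is spring-test's MockServerWebExchange rather than a mock. The class name, header and URL are hypothetical.

```java
package com.example.gatewaydemo.misc;

import java.net.URI;
import java.util.function.Consumer;
import java.util.function.Function;

import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
import org.springframework.mock.web.server.MockServerWebExchange;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;

// Hypothetical test class name.
class CallingGlobalFilterGetHeadersTest {

    @Test
    @SuppressWarnings("unchecked")
    void getRequestCopiesHeadersAndSetsNoBody() {
        URI target = URI.create("http://backend.example/resource"); // made-up backend URL

        // Mock only the WebClient builder chain the filter walks through.
        WebClient webClient = mock(WebClient.class);
        WebClient.RequestBodyUriSpec spec = mock(WebClient.RequestBodyUriSpec.class);
        when(webClient.method(HttpMethod.GET)).thenReturn(spec);
        when(spec.uri(target)).thenReturn(spec);
        when(spec.headers(any())).thenReturn(spec);
        // A real empty Mono: the flatMap lambda never runs, so Mono itself needs no mocking.
        when(spec.exchangeToMono(Mockito.<Function<ClientResponse, Mono<ClientResponse>>>any()))
                .thenReturn(Mono.empty());

        // Real inbound exchange instead of a mocked ServerWebExchange.
        MockServerWebExchange exchange = MockServerWebExchange.from(
                MockServerHttpRequest.get("/resource").header("X-Test", "1"));
        exchange.getAttributes().put(GATEWAY_REQUEST_URL_ATTR, target);
        GatewayFilterChain chain = ex -> Mono.empty();

        new CallingGlobalFilter(webClient).filter(exchange, chain).block();

        // Capture just the headers lambda and run it against a real HttpHeaders.
        ArgumentCaptor<Consumer<HttpHeaders>> headersCaptor = ArgumentCaptor.forClass(Consumer.class);
        verify(spec).headers(headersCaptor.capture());
        HttpHeaders outbound = new HttpHeaders();
        headersCaptor.getValue().accept(outbound);
        assertEquals("1", outbound.getFirst("X-Test"));

        // GET must not attach a body to the outbound request.
        verify(spec, never()).body(any());
    }
}
```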

Note that I mixed @Mock annotations with local calls to Mockito.mock. The local calls are just my preference: they make it easier to see where I actually use a mock and help avoid accidentally reusing a mock within one test. I use the @Mock and @Captor variants for mocks and argument captors with type parameters, to avoid lots of unchecked-cast warnings. I didn't take extra steps to suppress warnings in the static-mocking part for Mono::just, because it happens only once and the code was long enough already.

package com.example.gatewaydemo.misc;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.*;
import org.mockito.junit.jupiter.MockitoExtension;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;

import static org.mockito.Mockito.*;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR;
import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR;

@ExtendWith(MockitoExtension.class)
class CallingGlobalFilterTest {

  @Captor
  ArgumentCaptor<Consumer<HttpHeaders>> captorHeaders;

  @Mock
  Mono<ClientResponse> monoClientResponseMock1;

  @Mock
  Mono<ClientResponse> monoClientResponseMock2;

  @Captor
  ArgumentCaptor<Function<ClientResponse, Mono<ClientResponse>>> exchangeToMonoCaptor;

  @Mock
  Mono<Void> expectedResult;

  @Captor
  ArgumentCaptor<Function<ClientResponse, Mono<? extends Void>>> flatMapCaptor;

  @Mock
  Mono<Void> flatMapResultMock;

  @Test
  void filter_get_request() {

    WebClient webClientMock = mock(WebClient.class);
    CallingGlobalFilter underTest = new CallingGlobalFilter(webClientMock);

    ServerWebExchange exchangeMock = mock(ServerWebExchange.class);
    GatewayFilterChain chainMock = mock(GatewayFilterChain.class);

    URI requestUrlMock = mock(URI.class);
    when(exchangeMock.getRequiredAttribute(GATEWAY_REQUEST_URL_ATTR)).thenReturn(requestUrlMock);

    ServerHttpRequest serverHttpRequestMock = mock(ServerHttpRequest.class);
    when(exchangeMock.getRequest()).thenReturn(serverHttpRequestMock);
    when(serverHttpRequestMock.getMethod()).thenReturn(HttpMethod.GET);

    WebClient.RequestBodyUriSpec requestBodyUriSpecMock1 = mock(WebClient.RequestBodyUriSpec.class);
    when(webClientMock.method(HttpMethod.GET)).thenReturn(requestBodyUriSpecMock1);
    when(requestBodyUriSpecMock1.uri(requestUrlMock)).thenReturn(requestBodyUriSpecMock1);
    when(requestBodyUriSpecMock1.headers(any())).thenReturn(requestBodyUriSpecMock1);

    when(requestBodyUriSpecMock1.exchangeToMono(Mockito.<Function<ClientResponse, Mono<ClientResponse>>>any())).thenReturn(monoClientResponseMock1);

    when(monoClientResponseMock1.flatMap(Mockito.<Function<ClientResponse, Mono<? extends Void>>>any())).thenReturn(expectedResult);

    // actual call of outer function
    Mono<Void> actualResult = underTest.filter(exchangeMock, chainMock);

    Assertions.assertSame(expectedResult, actualResult);

    verify(requestBodyUriSpecMock1).headers(captorHeaders.capture());
    Consumer<HttpHeaders> capturedHeadersConsumer = captorHeaders.getValue();
    HttpHeaders httpHeadersFromRequestMock = mock(HttpHeaders.class);
    when(serverHttpRequestMock.getHeaders()).thenReturn(httpHeadersFromRequestMock);

    HttpHeaders httpHeadersMock1 = mock(HttpHeaders.class);

    // actual call of lambda 1
    capturedHeadersConsumer.accept(httpHeadersMock1);

    verify(httpHeadersMock1).addAll(same(httpHeadersFromRequestMock));

    verify(requestBodyUriSpecMock1, never()).body(any());
    verify(serverHttpRequestMock, never()).getBody();

    verify(requestBodyUriSpecMock1).exchangeToMono(exchangeToMonoCaptor.capture());

    Function<ClientResponse, Mono<ClientResponse>> exchangeToMonoFunction = exchangeToMonoCaptor.getValue();
    ClientResponse clientResponseMock1 = mock(ClientResponse.class);

    try (MockedStatic<Mono> mono = Mockito.mockStatic(Mono.class)) {
      mono.when(() -> Mono.just(clientResponseMock1))
          .thenReturn(monoClientResponseMock2);

      // actual call of method reference to Mono::just
      Mono<ClientResponse> actualInnerResult = exchangeToMonoFunction.apply(clientResponseMock1);

      Assertions.assertSame(monoClientResponseMock2, actualInnerResult);
    }

    verify(monoClientResponseMock1).flatMap(flatMapCaptor.capture());

    Function<ClientResponse, Mono<? extends Void>> flatMapFunction = flatMapCaptor.getValue();

    ClientResponse clientResponseMock2 = mock(ClientResponse.class);

    ServerHttpResponse responseMock = mock(ServerHttpResponse.class);
    when(exchangeMock.getResponse()).thenReturn(responseMock);

    HttpHeaders httpHeadersMock2 = mock(HttpHeaders.class);
    when(responseMock.getHeaders()).thenReturn(httpHeadersMock2);

    ClientResponse.Headers clientResponseHeadersMock = mock(ClientResponse.Headers.class);
    when(clientResponseMock2.headers()).thenReturn(clientResponseHeadersMock);

    HttpHeaders httpHeadersFromClientResponseMock = mock(HttpHeaders.class);
    when(clientResponseHeadersMock.asHttpHeaders()).thenReturn(httpHeadersFromClientResponseMock);

    HttpStatusCode httpStatusCodeMock = mock(HttpStatus.class); // have to mock an implementation of the sealed interface
    when(clientResponseMock2.statusCode()).thenReturn(httpStatusCodeMock);

    Map<String, Object> attributesMap = new HashMap<>();
    when(exchangeMock.getAttributes()).thenReturn(attributesMap);

    when(chainMock.filter(exchangeMock)).thenReturn(flatMapResultMock);

    // actual call of lambda 2
    Mono<? extends Void> actualFlatMapFunctionResult = flatMapFunction.apply(clientResponseMock2);

    verify(exchangeMock).getResponse();
    verify(httpHeadersMock2).putAll(same(httpHeadersFromClientResponseMock));

    verify(responseMock).setStatusCode(same(httpStatusCodeMock));

    Assertions.assertSame(clientResponseMock2, attributesMap.get(CLIENT_RESPONSE_ATTR));

    Assertions.assertSame(flatMapResultMock, actualFlatMapFunctionResult);
  }
}