Kafka Streams IN_MEMORY suppressed session window store state not persisted across restarts


From what I have read, Kafka Streams writes store data to a changelog topic so that after a restart it can rebuild its internal state and continue where it left off. This works with the on-disk RocksDB store, but fails with the IN_MEMORY session store. The behavior is the same whether I use SessionWindows or TimeWindows. Turning suppression off emits intermediate results, but after a restart I still only get the second half of the second session plus the complete third session.
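
For what it's worth, changelogging should already be enabled by default for materialized stores; as a sanity check it can also be turned on explicitly. A rough sketch of that wiring, reusing the store name from the test below (the empty map just means broker-default changelog topic configs):

// Sanity-check sketch only: logging is supposed to be on by default already.
Materialized<String, String, SessionStore<Bytes, byte[]>> materialize =
  Materialized.<String, String, SessionStore<Bytes, byte[]>>as(STORE_NAME)
    .withStoreType(Materialized.StoreType.IN_MEMORY)
    .withLoggingEnabled(Collections.emptyMap()) // empty map = broker-default changelog topic configs
    .withCachingDisabled();                     // rule out caching as a source of buffering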

I added the following test to the Kafka Streams integration tests. It creates 3 sessions in total and uses suppress so that a result is only emitted once a session has closed:

// Record values look like "<eventTimeMs>.<payload>"; the prefix is used as the record timestamp.
class TimeExtractor implements TimestampExtractor {
  @Override
  public long extract(ConsumerRecord<Object, Object> record, long partitionTime) {
    String str = (String) record.value();
    return Long.parseLong(str.split("\\.")[0]);
  }
}

@Timeout(600)
@Tag("integration")
public class RestartIntegrationTest {
  private static final int NUM_BROKERS = 1;
  private static String applicationId;
  private static final String STORE_NAME = "le_store";

  public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS);


  private static final String key = "key";


  private static KafkaConsumer<String, String> consumer;
  private static KafkaConsumer<String, String> changeLogConsumer;
  private static KafkaProducer<String, String> producer;
  private static Properties kafkaStreamsProps = new Properties();
  private static Properties producerProps = new Properties();
  private static Properties consumerProps = new Properties();

  @BeforeAll
  static void beforeAll() throws IOException, InterruptedException {
    CLUSTER.start();
    CLUSTER.createTopics("input");
    CLUSTER.createTopics("output");

    applicationId = "integration-test-" + Instant.now().toEpochMilli();

    kafkaStreamsProps.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName());
    kafkaStreamsProps.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName());
    kafkaStreamsProps.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId);
    kafkaStreamsProps.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());

    producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
    producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName());
    producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName());

    consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
    consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
    consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
    consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "integration_test_group_id");
    consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");

    Properties changeLogConsumerProps = new Properties();
    changeLogConsumerProps.putAll(consumerProps);
    changeLogConsumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "changelog_consumer_integration_test_group_id");


    consumer = new KafkaConsumer<>(consumerProps);
    changeLogConsumer = new KafkaConsumer<>(changeLogConsumerProps);
    producer = new KafkaProducer<>(producerProps);
  }

  @AfterAll
  static void afterAll() {
    producer.close();
    consumer.close();
    changeLogConsumer.close();

    CLUSTER.stop();
  }

  StreamsBuilder getTopology() {
    StreamsBuilder builder = new StreamsBuilder();

    // Breaks
    Materialized.StoreType storeType = Materialized.StoreType.IN_MEMORY;
    // Works
    // Materialized.StoreType storeType = Materialized.StoreType.ROCKS_DB;

    // Same behavior for session windows AND time windows...
    // Windows windowBy = TimeWindows.ofSizeWithNoGrace(Duration.ofMinutes(1));
    // Materialized<String, String, WindowStore<Bytes, byte[]>> materialize =
    //   Materialized.<String, String, WindowStore<Bytes, byte[]>>as(STORE_NAME).withStoreType(storeType);

    SessionWindows windowBy = SessionWindows.ofInactivityGapWithNoGrace(Duration.ofMinutes(1));
    Materialized<String, String, SessionStore<Bytes, byte[]>> materialize =
      Materialized.<String, String, SessionStore<Bytes, byte[]>>as(STORE_NAME).withStoreType(storeType);

    builder.stream("input",
        Consumed.with(Serdes.String(), Serdes.String()).withTimestampExtractor(new TimeExtractor()))
      .peek((k, v) -> System.out.println("input k: " + k + " v: " + v))
      .mapValues(v -> v.split("\\.")[1])
      .groupByKey()
      .windowedBy(windowBy)
      .reduce( (agg, curr) -> agg + " " + curr, materialize)
      .suppress(Suppressed.untilWindowCloses(Suppressed.BufferConfig.unbounded()))
      .toStream()
      .peek((k, v) -> System.out.println("output k: " + k + " v: " + v))
      .mapValues((k, v) -> k + ": " + v)
      .to("output");

    return builder;
  }

  // Joins the payload parts of the given input messages with spaces, i.e. the expected reduce() result.
  private String sessionToAggregatedString(List<KeyValue<String, String>> messages) {
    StringBuilder stringBuilder = new StringBuilder();
    for (KeyValue<String, String> message : messages) {
      stringBuilder.append(message.value.split("\\.")[1]).append(" ");
    }
    return stringBuilder.toString().trim();
  }

  // Polls the given topics until listenForSeconds have elapsed and returns everything received.
  private List<ConsumerRecord<String, String>> readUntilTime(
    Consumer<String, String> consumer,
    List<String> topics,
    long listenForSeconds) {
    consumer.subscribe(topics);

    List<ConsumerRecord<String, String>> data = new ArrayList<>();

    Instant stopAt = Instant.now().plusSeconds(listenForSeconds);

    while (Instant.now().isBefore(stopAt)) {
      ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1000));
      records.forEach(data::add);
    }

    return data;
  }

  // Builds and starts the topology (optionally wiping local state first), produces the given
  // messages, collects output for 30 seconds, then closes the Streams instance.
  List<ConsumerRecord<String, String>> startProcessingAndClose(
    List<KeyValue<String, String>> messages,
    boolean shouldClean
  ) throws InterruptedException {
    StreamsBuilder builder = getTopology();
    Topology topology = builder.build();

    Properties props = new Properties();
    props.putAll(kafkaStreamsProps);
    props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L);

    KafkaStreams kafkaStreams = new KafkaStreams(topology, props);

    if (shouldClean) {
      kafkaStreams.cleanUp();
    }

    AtomicBoolean isStreamsStarted = new AtomicBoolean(false);

    kafkaStreams.setStateListener((newState, oldState) -> {
      if (newState == KafkaStreams.State.RUNNING) {
        isStreamsStarted.set(true);
      }
    });

    kafkaStreams.start();

    while (!isStreamsStarted.get()) {
      Thread.sleep(1000);
      System.out.println("waiting");
    }
    System.out.println("started : )");

    for (KeyValue<String, String> msg : messages) {
      producer.send(new ProducerRecord<>("input", msg.key, msg.value));
      Thread.sleep(1);
    }

    List<ConsumerRecord<String, String>> sessions = readUntilTime(
      consumer,
      Arrays.asList("output"),
      30L
    );

    kafkaStreams.close();

    return sessions;
  }

  @Test
  void recoversAfterRestart() throws InterruptedException {
    List<KeyValue<String, String>> session1 = toMsgsList(60, 65, key);

    // session2 is deliberately split across the restart: the first half is sent before it, the second half after.
    List<KeyValue<String, String>> session2FirstHalf = toMsgsList(180, 182, key);
    List<KeyValue<String, String>> session2SecondHalf = toMsgsList(183, 185, key);
    ArrayList<KeyValue<String, String>> session2 = new ArrayList<>();
    session2.addAll(session2FirstHalf);
    session2.addAll(session2SecondHalf);

    List<KeyValue<String, String>> session3 = toMsgsList(300, 305, key);

    // A final record well outside the inactivity gap, so that session3 closes and gets emitted.
    List<KeyValue<String, String>> closer = toMsgsList(420, 420, key);

    List<KeyValue<String, String>> firstBatch = new ArrayList<>(session1);
    firstBatch.addAll(session2FirstHalf);

    List<KeyValue<String, String>> secondBatch = new ArrayList<>(session2SecondHalf);
    secondBatch.addAll(session3);
    secondBatch.addAll(closer);

    // Send the first batch and store the output data
    List<ConsumerRecord<String, String>> beforeRestart = startProcessingAndClose(firstBatch, true);
    List<String> beforeRestartValues = beforeRestart.stream().map(ConsumerRecord::value).collect(Collectors.toList());

    // Send the second batch and store the output data
    List<ConsumerRecord<String, String>> afterRestart = startProcessingAndClose(secondBatch, false);
    List<String> afterRestartValues = afterRestart.stream().map(ConsumerRecord::value).collect(Collectors.toList());

    // The values we expect _before_ the restart
    List<String> expectedBeforeRestart = new ArrayList<>();
    expectedBeforeRestart.add("[" + key + "@60000/65000]: " + sessionToAggregatedString(session1));

    // The values we expect _after_ the restart
    List<String> expectedAfterRestart = new ArrayList<>();
    expectedAfterRestart.add("[" + key + "@180000/185000]: " + sessionToAggregatedString(session2));
    expectedAfterRestart.add("[" + key + "@300000/305000]: " + sessionToAggregatedString(session3));

    // Do the assertions
    // This passes
    assertEquals(expectedBeforeRestart, beforeRestartValues);
    // This fails
    assertEquals(expectedAfterRestart, afterRestartValues);
  }

  // One message per second in [start, end]; value format is "<timestampMs>.<payload>", with both parts equal.
  private List<KeyValue<String, String>> toMsgsList(int start, int end, String key) {
    List<KeyValue<String, String>> msgs = new ArrayList<>();
    for (int i = start; i <= end; i++) {
      int value = i * 1000;
      msgs.add(new KeyValue<>(key, value + "." + value));
    }
    return msgs;
  }
}

The first assertion passes as expected: we get 1 session before the restart, because the start of the second session falls outside the inactivity gap, so the first session closes and is emitted.
The second assertion fails: after the restart we get 3 sessions instead of the expected 2, for 4 sessions in total (1 before the restart, 3 after).

org.opentest4j.AssertionFailedError: expected: <[[key@180000/185000]: 180000 181000 182000 183000 184000 185000, [key@300000/305000]: 300000 301000 302000 303000 304000 305000]> but was: <[[key@180000/182000]: 180000 181000 182000, [key@183000/185000]: 183000 184000 185000, [key@300000/305000]: 300000 301000 302000 303000 304000 305000]>

For some reason Kafka Streams flushes the first half of the unfinished session2 prematurely.
Have I misconfigured something, or is this a bug in Kafka Streams? I would expect session2 to be picked up where it left off, with nothing flushed, since at that point we cannot yet know whether we are outside the inactivity gap.
FWIW, the same issue appears even with the RocksDB store if kafkaStreams.cleanUp() is always run.
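
To check what actually gets logged, the application's internal topics can be listed; I would expect both the session store and the suppression buffer to get their own changelog topic. A rough sketch using the admin client (the topic names in the comments are assumptions based on the default naming scheme):

// Diagnostic sketch: list this application's changelog topics.
void listChangelogTopics() throws Exception {
  Properties adminProps = new Properties();
  adminProps.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
  try (Admin admin = Admin.create(adminProps)) {
    admin.listTopics().names().get().stream()
      .filter(t -> t.startsWith(applicationId) && t.endsWith("-changelog"))
      // Expected (assumption): "<applicationId>-le_store-changelog" and
      // "<applicationId>-KTABLE-SUPPRESS-STATE-STORE-<n>-changelog"
      .forEach(System.out::println);
  }
}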

I can't rely on disk in our prod environment, and applications will of course restart from time to time. I could theoretically handle this downstream when sinking the data into a DB, but that's less than ideal.
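
One variant I could still try is wiring the in-memory store through an explicit supplier with logging turned on, though I haven't verified that it behaves any differently. A sketch, where the 10-minute retention is an arbitrary value that just needs to cover the inactivity gap plus grace:

// Alternative wiring via an explicit in-memory supplier instead of withStoreType.
SessionBytesStoreSupplier supplier = Stores.inMemorySessionStore(STORE_NAME, Duration.ofMinutes(10));
Materialized<String, String, SessionStore<Bytes, byte[]>> materialize =
  Materialized.<String, String, SessionStore<Bytes, byte[]>>as(supplier)
    .withLoggingEnabled(Collections.emptyMap());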

The actual test lives in the following forked Kafka repo on GitHub if you would like to run it yourself: https://github.com/swapsCAPS/kafka/blob/0c5322c86a0e3e04b23d7fa16b322eff301e9bdf/streams/src/test/java/org/apache/kafka/streams/integration/RestartIntegrationTest.java
