Incrementing and removing elements of ConcurrentHashMap


There is a class Counter that holds a set of keys and allows incrementing the value of each key and getting all values. So the task I'm trying to solve is the same as in Atomically incrementing counters stored in ConcurrentHashMap. The difference is that the set of keys is unbounded, so new keys are added frequently.

To reduce memory consumption, I clear values after they are read; this happens in Counter.getAndClear(). The keys are removed as well, and this seems to break things.

One thread increments random keys and another thread gets snapshots of all values and clears them.

The code is below:

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.Map;
import java.util.HashMap;
import java.lang.Thread;

class HashMapTest {
    private final static int hashMapInitSize = 170;
    private final static int maxKeys = 100;
    private final static int nIterations = 10_000_000;
    private final static int sleepMs = 100;

    private static class Counter {
        private ConcurrentMap<String, Long> map;

        public Counter() {
            map = new ConcurrentHashMap<String, Long>(hashMapInitSize);
        }

        public void increment(String key) {
            Long value;
            do {
                value = map.computeIfAbsent(key, k -> 0L);
            } while (!map.replace(key, value, value + 1L));
        }

        public Map<String, Long> getAndClear() {
            Map<String, Long> mapCopy = new HashMap<String, Long>();
            for (String key : map.keySet()) {
                Long removedValue = map.remove(key);
                if (removedValue != null)
                    mapCopy.put(key, removedValue);
            }
            return mapCopy;
        }
    }

    // The code below is used for testing
    public static void main(String[] args) throws InterruptedException {
        Counter counter = new Counter();
        Thread thread = new Thread(new Runnable() {
            public void run() {
                for (int j = 0; j < nIterations; j++) {
                    int index = ThreadLocalRandom.current().nextInt(maxKeys);
                    counter.increment(Integer.toString(index));
                }
            }
        }, "incrementThread");
        Thread readerThread = new Thread(new Runnable() {
            public void run() {
                long sum = 0;
                boolean isDone = false;
                while (!isDone) {
                    try {
                        Thread.sleep(sleepMs);
                    }
                    catch (InterruptedException e) {
                        isDone = true;
                    }
                    Map<String, Long> map = counter.getAndClear();
                    for (Map.Entry<String, Long> entry : map.entrySet()) {
                        Long value = entry.getValue();
                        sum += value;
                    }
                    System.out.println("mapSize: " + map.size());
                }
                System.out.println("sum: " + sum);
                System.out.println("expected: " + nIterations);
            }
        }, "readerThread");
        thread.start();
        readerThread.start();
        thread.join();
        readerThread.interrupt();
        readerThread.join();
        // Ensure that counter is empty
        System.out.println("elements left in map: " + counter.getAndClear().size());
    }
}

While testing I have noticed that some increments are lost. I get the following results:

sum: 9993354
expected: 10000000
elements left in map: 0

If you can't reproduce this error (sum being less than expected), try increasing maxKeys by a few orders of magnitude, decreasing hashMapInitSize, or increasing nIterations (the latter also increases run time). I have included the testing code (the main method) as well, in case the error is there.

I suspect that the error is happening when capacity of ConcurrentHashMap is increased during runtime. On my computer the code appears to work correctly when hashMapInitSize is 170, but fails when hashMapInitSize is 171. I believe that size of 171 triggers increasing of capacity (128 / 0.75 == 170.66, where 0.75 is the default load factor of hash map).
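The 170/171 boundary can be checked with a little arithmetic. The sketch below is a hedged illustration, not library code: it assumes the sizing rule used by the single-argument constructor in OpenJDK 8 and later, which rounds initialCapacity + initialCapacity / 2 + 1 up to a power of two. That rule is an implementation detail that the Javadoc does not promise, and the class and method names here are made up for the example.

```java
public class TableSizeSketch {
    // Smallest power of two that is >= c (for c >= 1).
    static int tableSizeFor(int c) {
        return c <= 1 ? 1 : Integer.highestOneBit(c - 1) << 1;
    }

    // ASSUMPTION: mirrors the OpenJDK 8+ ConcurrentHashMap(int) constructor,
    // which sizes the table from initialCapacity * 1.5 + 1. Not a documented guarantee.
    static int assumedTableSize(int initialCapacity) {
        return tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1);
    }

    public static void main(String[] args) {
        System.out.println("initialCapacity 170 -> table size " + assumedTableSize(170));
        System.out.println("initialCapacity 171 -> table size " + assumedTableSize(171));
    }
}
```

If that assumption holds, 170 and 171 give different initial table sizes (256 and 512). With only 100 keys, the two settings would then differ in how the table is laid out from the start rather than in whether it grows at runtime.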

So, the question is: am I using remove, replace and computeIfAbsent operations correctly? I assume that they are atomic operations on ConcurrentHashMap based on answers to Use of ConcurrentHashMap eliminates data-visibility troubles?. If so, why are some increments lost?

EDIT:

I left out an important detail: increment() is supposed to be called much more frequently than getAndClear(), so I am trying to avoid any explicit locking in increment(). However, I'm going to test the performance of different versions later to see whether it is really an issue.

There is 1 answer

Answer by forty-two:

I guess the problem is the use of remove while iterating over the keySet. This is what the JavaDoc says for Map#keySet() (my emphasis):

Returns a Set view of the keys contained in this map. The set is backed by the map, so changes to the map are reflected in the set, and vice-versa. If the map is modified while an iteration over the set is in progress (except through the iterator's own remove operation), the results of the iteration are undefined.

The JavaDoc for ConcurrentHashMap gives further clues:

Similarly, Iterators, Spliterators and Enumerations return elements reflecting the state of the hash table at some point at or since the creation of the iterator/enumeration.

The conclusion is that mutating the map while iterating over the keys is not predictable.
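To make the two JavaDoc passages a bit more concrete, here is a minimal single-threaded sketch (the class and helper names are invented for the illustration). It removes keys through the map itself while iterating over keySet(), once with a HashMap and once with a ConcurrentHashMap. It only shows whether the iteration fails fast. It says nothing about which keys a ConcurrentHashMap iterator observes when another thread is modifying the map at the same time.

```java
import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class IterationSketch {
    // Removes keys via map.remove (not via the iterator) during keySet() iteration.
    // Returns true if the iteration threw ConcurrentModificationException.
    static boolean removeWhileIterating(Map<String, Long> map) {
        try {
            for (String key : map.keySet()) {
                map.remove(key);
            }
            return false;
        } catch (ConcurrentModificationException e) {
            return true;
        }
    }

    // Puts ten entries into the given map and returns it.
    static Map<String, Long> fill(Map<String, Long> map) {
        for (int i = 0; i < 10; i++) {
            map.put(Integer.toString(i), 1L);
        }
        return map;
    }

    public static void main(String[] args) {
        System.out.println("HashMap threw CME: "
                + removeWhileIterating(fill(new HashMap<>())));
        System.out.println("ConcurrentHashMap threw CME: "
                + removeWhileIterating(fill(new ConcurrentHashMap<>())));
    }
}
```

A ConcurrentHashMap iterator is documented never to throw ConcurrentModificationException, so the absence of an exception is not evidence that a concurrent traversal saw a consistent snapshot.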

One solution is to create a new map for the getAndClear() operation and just return the old map. The switch has to be protected, and in the example below I used a ReentrantReadWriteLock:

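import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock.ReadLock;
import java.util.concurrent.locks.ReentrantReadWriteLock.WriteLock;

class HashMapTest {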
    private final static int hashMapInitSize = 170;
    private final static int maxKeys = 100;
    private final static int nIterations = 10_000_000;
    private final static int sleepMs = 100;

    private static class Counter {
        private ConcurrentMap<String, Long> map;
        ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        ReadLock readLock = lock.readLock();
        WriteLock writeLock = lock.writeLock();

        public Counter() {
            map = new ConcurrentHashMap<>(hashMapInitSize);
        }

        public void increment(String key) {
            readLock.lock();
            try {
                map.merge(key, 1L, Long::sum);
            } finally {
                readLock.unlock();
            }
        }

        public Map<String, Long> getAndClear() {
            ConcurrentMap<String, Long> oldMap;
            writeLock.lock();
            try {
                oldMap = map;
                map = new ConcurrentHashMap<>(hashMapInitSize);
            } finally {
                writeLock.unlock();
            }

            return oldMap;
        }
    }

    // The code below is used for testing
    public static void main(String[] args) throws InterruptedException {
        final AtomicBoolean ready = new AtomicBoolean(false);

        Counter counter = new Counter();
        Thread thread = new Thread(new Runnable() {
            public void run() {
                for (int j = 0; j < nIterations; j++) {
                    int index = ThreadLocalRandom.current().nextInt(maxKeys);
                    counter.increment(Integer.toString(index));
                }
            }
        }, "incrementThread");

        Thread readerThread = new Thread(new Runnable() {
            public void run() {
                long sum = 0;
                while (!ready.get()) {
                    try {
                        Thread.sleep(sleepMs);
                    } catch (InterruptedException e) {
                        //
                    }
                    Map<String, Long> map = counter.getAndClear();
                    for (Map.Entry<String, Long> entry : map.entrySet()) {
                        Long value = entry.getValue();
                        sum += value;
                    }
                    System.out.println("mapSize: " + map.size());
                }
                System.out.println("sum: " + sum);
                System.out.println("expected: " + nIterations);
            }
        }, "readerThread");
        thread.start();
        readerThread.start();
        thread.join();
        ready.set(true);
        readerThread.join();
        // Ensure that counter is empty
        System.out.println("elements left in map: " + counter.getAndClear().size());
    }
}
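
A note on increment(): it replaces the computeIfAbsent/replace retry loop with a single map.merge(key, 1L, Long::sum) call, which ConcurrentHashMap documents as being performed atomically. The following is a small, hedged sketch of that one property in isolation (the class name, method name and thread counts are arbitrary choices for the example). Several threads increment the same key without any extra lock, and the final value is compared with the number of increments.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

public class MergeSketch {
    // Starts nThreads threads that each add 1 to the same key perThread times,
    // then returns the final value stored under that key.
    static long countWithMerge(int nThreads, int perThread) {
        ConcurrentMap<String, Long> map = new ConcurrentHashMap<>();
        Thread[] threads = new Thread[nThreads];
        for (int t = 0; t < nThreads; t++) {
            threads[t] = new Thread(() -> {
                for (int i = 0; i < perThread; i++) {
                    // Inserts 1 if the key is absent, otherwise stores oldValue + 1.
                    map.merge("key", 1L, Long::sum);
                }
            });
            threads[t].start();
        }
        try {
            for (Thread thread : threads) {
                thread.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting", e);
        }
        return map.get("key");
    }

    public static void main(String[] args) {
        System.out.println("sum: " + countWithMerge(4, 100_000));
        System.out.println("expected: " + 4 * 100_000);
    }
}
```

The read lock in Counter.increment() is not there to make merge atomic. It only keeps an increment from landing in the old map after getAndClear() has swapped it out under the write lock. Because a read lock is shared, many incrementing threads can still run concurrently.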