How Fast Can We Approximate Set Jaccard Scores?


I'm trying to compute about 1 trillion different pairwise jaccard scores on a laptop. I'm also not very patient.

Each set tends to contain about 200-800 unique, pseudorandom 32-bit hashes. I've already minimized the size of each set.

Is there an approach that would allow me to trade precision for faster performance? I can probably tolerate an error rate of +/- 0.15.

To be clear, I'm not looking for a marginal performance boost from porting the same equation to something like Numba. I'm looking for an alternative approach (perhaps matrix math?) that will boost performance by at least an order of magnitude, preferably more.

Here's an example that computes an error-free jaccard score:

def jaccard(hash_set_1: set, hash_set_2: set) -> float:
    if not hash_set_1 or not hash_set_2:
        return 0.0
    intersect = len(hash_set_1 & hash_set_2)
    union = len(hash_set_1) + len(hash_set_2) - intersect
    return intersect / union

For some additional context: I have roughly 100 million known sets, and some input sets of interest. For each of the sets of interest, I'm scoring against the 100 million known sets to generate clusters. As matches (scores above a threshold) are found, they are added to the input set. After fully exploring the space, I expect to find roughly 10K matches.

This operation gets repeated periodically and needs to be as fast as possible. Each time it repeats, some percentage (10%-25%) of both the known sets and the input sets of interest will have aged out and been replaced with new data.


There are 2 answers

Jérôme Richard (BEST ANSWER)

This answer gives a possible solution to significantly speed-up the whole algorithm (not just the provided function). However, it is not trivial to implement efficiently.

The main idea is to replace CPython's set data structure (which is essentially a hash table holding references to CPython objects, see here) with a simplified Bloom filter. Bloom filters are a very compact probabilistic data structure. They can report with 100% accuracy that a number is not in a set, or report that it is in the set with a possible error. The probability of an error depends on the size of the filter. Each integer is encoded as a single bit in a large bit set, at a location derived from the hash of the integer. If the integer is already a uniformly-distributed number (e.g. already a hash), then it does not need to be hashed again. The bit to check/set is located with hash(number) % filter_size. If filter_size is a power of two, the modulus can be replaced by a very fast AND instruction.
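As a concrete illustration of just this encoding step, here is a minimal NumPy sketch; the table size of 2**18 slots and the use of np.bool_ (one byte per slot rather than one bit) are illustrative choices, not recommendations:

import numpy as np

FILTER_BITS = 18                      # 2**18 slots; a tuning knob
FILTER_MASK = (1 << FILTER_BITS) - 1  # power-of-two size -> AND instead of %

def to_filter(hashes: np.ndarray) -> np.ndarray:
    # The inputs are already pseudorandom 32-bit hashes, so no extra hashing is applied.
    table = np.zeros(1 << FILTER_BITS, dtype=np.bool_)
    table[hashes & FILTER_MASK] = True
    return table

def might_contain(table: np.ndarray, value: int) -> bool:
    # False means "definitely not in the set"; True may be a collision.
    return bool(table[value & FILTER_MASK])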

The intersection of two Bloom filters consists of computing a bitwise AND of the two memory blocks, assuming the two filters have exactly the same size and use the same hash function. The number of set bits (which can be computed very efficiently thanks to dedicated population-count CPU instructions) is an approximation of the number of items in the intersection of the two sets. However, with this solution the filters need to be pretty big to avoid collision issues, which would make the final result very inaccurate (see the Birthday problem). Big (sparse) filters tend to be quite expensive to traverse.
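A hedged sketch of this whole-filter variant, using np.count_nonzero as a stand-in for a hardware popcount over a genuinely packed bit set (filter_a and filter_b could be built with the to_filter helper sketched above):

def approx_jaccard_filters(filter_a, filter_b, len_a: int, len_b: int) -> float:
    # Both filters must have the same size and use the same hash/mask.
    intersect = int(np.count_nonzero(filter_a & filter_b))
    union = len_a + len_b - intersect
    return intersect / union if union > 0 else 0.0

The exact set sizes are passed in separately, so only the intersection count is approximated.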

An alternative solution is to transform the bigger set into a large Bloom filter, and then iterate over the smaller set to check whether each of its items is found in the filter. This avoids many collisions and mitigates the birthday problem somewhat. It is also less expensive as long as the Bloom filter fits in the L1 CPU cache. Converting a set to a Bloom filter can be quite expensive, but the filter can be reused across many pair comparisons, so this cost should end up being negligible.

Note that the approximation is an (unbounded) over-approximation. The result is guaranteed to be exact if the numbers are smaller than the table size and the hash function is a perfect hash function (e.g. the identity). However, such a filter would be too big to be efficient for 32-bit random integers (i.e. 512 MiB of RAM per filter). You can improve the precision of the method by using multiple hash functions (at the expense of a slower execution -- there is no free lunch). This strategy can also help you estimate the accuracy of the output, based on the variance of the counts retrieved from multiple filters built with different hash functions.
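One hedged way to do that is to build several filters for the same pair, each with a different cheap integer mixer, and compare the resulting counts; the multiplier constants below are arbitrary odd values picked for illustration, not a vetted hash family:

def mix(hashes, constant):
    # Cheap multiplicative mix of 32-bit hashes; returns values in [0, 2**32).
    return (hashes.astype(np.uint64) * np.uint64(constant)) >> np.uint64(32)

def estimate_intersection(a, b, table_bits=18,
                          mixers=(0x9E3779B1, 0x85EBCA77, 0xC2B2AE3D)):
    # Each filter over-approximates the true intersection, so the minimum is the
    # tightest estimate and the spread hints at the remaining collision error.
    mask = np.uint64((1 << table_bits) - 1)
    estimates = []
    for c in mixers:
        table = np.zeros(1 << table_bits, dtype=np.bool_)
        table[mix(a, c) & mask] = True
        estimates.append(int(table[mix(b, c) & mask].sum()))
    return estimates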

Here is an (inefficient) example illustrating how this works:

import numpy as np

# The two input sets
a = np.unique(np.random.randint(0, 1_000_000, 800))
b = np.unique(np.append(np.random.randint(0, 1_000_000, 300), np.random.choice(a, 10)))

print('Exact count: ', len(set(a) & set(b)))

# Build a simple Bloom filter based on `a`.
# NumPy's bool_ takes 1 byte per slot (256 KiB here), but a packed
# implementation would take only 1 bit per slot, i.e. 32 KiB.
# That fits in the L1 cache of nearly all mainstream CPUs,
# and the L1 cache is very fast.
table = np.zeros(256*1024, dtype=np.bool_)
table[a % table.size] = True

# Count matching items (over-approximation)
count = 0
for item in b:
    count += table[item % table.size]
print('Estimation: ', count)

In practice, for the sake of performance, the computation should be implemented in Numba or Cython (or in a native language). The sets should be stored as NumPy arrays, since iterating over a CPython set is more expensive in Numba/Cython. The Bloom filter should be reused as much as possible. The size of the filter must be a compile-time constant (critical to avoid a very slow modulus) and the hash function should be fast (e.g. identity, xorshift). The filter's underlying array should be recycled as much as possible to avoid allocations. Loop unrolling and tiling should also improve performance.
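A rough Numba sketch of that structure (assuming Numba is available; the table size and the identity hash are illustrative, untuned choices):

import numba as nb
import numpy as np

TABLE_BITS = 18                      # compile-time constant: 2**18 slots
TABLE_MASK = (1 << TABLE_BITS) - 1   # power-of-two size -> cheap AND instead of %

@nb.njit(cache=True)
def fill_filter(table, a):
    # Reset and refill the recycled filter from `a`, a NumPy array of hashes.
    table[:] = False
    for v in a:
        table[v & TABLE_MASK] = True

@nb.njit(cache=True)
def approx_jaccard(table, a_len, b):
    # Probe the filter with the other set `b`; the count over-approximates |a & b|.
    intersect = 0
    for v in b:
        if table[v & TABLE_MASK]:
            intersect += 1
    union = a_len + b.size - intersect
    return intersect / union if union > 0 else 0.0

table = np.zeros(1 << TABLE_BITS, dtype=np.bool_)  # recycled across all pairs

fill_filter would be called once per (larger) set and approx_jaccard once per pair, always reusing the same table array.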

The pairs can be computed in parallel using multiple threads (native threads with the GIL released). In fact, the Bloom-filter approach is also GPU-friendly, and GPUs can be significantly faster for such a computation. The Bloom filters must be stored in the GPU shared memory for the operation to be fast though. The size of this kind of memory is very limited (e.g. 0-200 KiB per SM on Nvidia devices), but it should be big enough for your needs in practice.
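For the CPU-thread part, a hedged sketch building on the previous snippet (it reuses fill_filter and TABLE_MASK, and assumes the known sets are stored CSR-style as one flat array of hashes plus an offsets array):

@nb.njit(parallel=True, cache=True)
def score_one_against_all(table, a_len, flat_known, offsets, scores):
    # `table` was filled from one input set; each thread probes it with one known set.
    for i in nb.prange(offsets.size - 1):
        start = offsets[i]
        end = offsets[i + 1]
        intersect = 0
        for j in range(start, end):
            if table[flat_known[j] & TABLE_MASK]:
                intersect += 1
        union = a_len + (end - start) - intersect
        scores[i] = intersect / union if union > 0 else 0.0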

Related post: The most efficient way rather than using np.setdiff1d and np.in1d, to remove common values of 1D arrays with unique values

Bergi

Bloom filters are a very good idea. Yet, there is another way to tackle this:

I have roughly 100 million known sets, and some input sets of interest. For each of the sets of interest, I'm scoring against the 100 million known sets to generate clusters. As matches (scores above a threshold) are found, they are added to the input set. After fully exploring the space, I expect to find roughly 10K matches.

Considering the implementation of set intersection, this comes down to a three-level nested loop:

known_sets: list[set[int]] = …
sets_of_interest: list[set[int]] = …
for a in sets_of_interest:
    for b in known_sets:
        intersect = 0
        for el in a:
            if el in b:
                intersect += 1
        union = len(a) + len(b) - intersect
        score = intersect / union
        if score > threshold:
            sets_of_interest.append(b)

That's n (number of known_sets) × m (number of expected matches) × s (average size of the sets) repetitions of the set lookup operation - not good.

We can bring this down considerably by treating the sets and their elements as a bipartite graph, then inverting its representation: instead of storing for each set which elements it has, we store for each element which sets it is part of. This will subsequently allow us to compute all scores at once by looping over the elements:

from collections import defaultdict
from typing import TypeAlias

known_sets: list[set[int]] = …
SetId: TypeAlias = int
sets_by_element: dict[int, set[SetId]] = defaultdict(set)

for id, s in enumerate(known_sets):
    for el in s:
        sets_by_element[el].add(id)

all_intersects: list[list[int]] = [[0] * len(known_sets) for _ in known_sets]
for el, sets in sets_by_element.items():
    for a_id in sets:
        for b_id in sets:
            all_intersects[a_id][b_id] += 1

To get the jaccard score between two sets, just compute

intersect = all_intersects[a_id][b_id]
union = len(known_sets[a_id]) + len(known_sets[b_id]) - intersect
score = intersect / union

However, that computes all pairwise scores, and it's actually even more expensive, at n×n×s in the worst case (when all sets share all values). The average complexity is much better though: n × s × d (where d is the average number of sets each element occurs in), plus n² (for the allocation and iteration of all_intersects). Given your numbers, that n² is prohibitive :-(

However, that's only for computing all pairwise scores, we have not yet taken into account your sets_of_interest and the threshold. Instead of simply going over all elements as they occur in sets_by_element, we start with the elements from the sets of interest:

from collections import deque

set_ids_of_interest: set[SetId] = … # a subset of known_set ids
worklist: deque[SetId] = deque(set_ids_of_interest)  # lets us grow the set while exploring it
seen: set[int] = set()
most_intersects: dict[SetId, dict[SetId, int]] = defaultdict(lambda: defaultdict(int))
while worklist:
    x_id = worklist.popleft()
    x = known_sets[x_id]
    for el in x:
        if el in seen:
            continue
        seen.add(el)
        sets = sets_by_element[el]
        for a_id in sets:
            for b_id in sets:
                most_intersects[a_id][b_id] += 1
    # now most_intersects[x_id] is complete,
    # we've seen all elements from x (and which sets they occur in)
    for b_id, intersect in most_intersects.pop(x_id).items():
        if b_id == x_id:
            continue # it would score 1.0, but it's already part of set_ids_of_interest…
        b = known_sets[b_id]
        union = len(x) + len(b) - intersect
        score = intersect / union
        if score > threshold and b_id not in set_ids_of_interest:
            set_ids_of_interest.add(b_id)
            worklist.append(b_id)

This may seem worse (with that fourth-level nested loop!), but actually it's not: the part beginning with seen.add(el) no longer executes for every element but just for the ones in the m matching sets. We also no longer allocate n² intersection counts but only as many as needed until we can throw them away.

If I judge the time complexity correctly, we're now at n×s + min(m×s, n×s)×d + m×min(m×s, n). Don't ask me to prove it though :D

This operation gets repeated periodically, and needs to be as fast as possible. When the operation is repeated, some percentage (10%-25%) of both the known sets, and the input sets of interest, will have aged out and be replaced with new sets of data.

At the risk of stating the obvious: do not start over by running the whole algorithm again, but recompute only the results for the sets that have actually changed (or are new altogether).

This trades memory for speed, and is usually well worth it. It applies to the Bloom filters as well - cache the filter if the set didn't change, and don't recompute a pairwise score that you already know.

The structure of sets_by_element lends itself to such an on-line algorithm. I'm not sure in what format you receive the changes to the known sets, but you can trivially remove a set's id from sets_by_element for the elements that were removed, and add it for the elements that were added. This will considerably speed up the first step of the algorithm.
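A minimal sketch of that update, assuming (purely for illustration) that a change arrives as a set id plus the elements removed from and added to that set:

def apply_delta(known_sets, sets_by_element, set_id, removed, added):
    for el in removed:
        ids = sets_by_element.get(el)
        if ids is not None:
            ids.discard(set_id)
            if not ids:
                del sets_by_element[el]   # keep the index free of empty entries
        known_sets[set_id].discard(el)
    for el in added:
        sets_by_element[el].add(set_id)   # sets_by_element is the defaultdict(set) from above
        known_sets[set_id].add(el)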

You might even consider keeping all the most_intersects in memory (by not .pop()ping them), or even using the all_intersects two-dimensional array, if you have enough memory for that. Updating it with every change to sets_by_element is easy, and you do not need to loop over all n² entries during each repetition of the operation - start with the set_ids_of_interest and loop only as long as you still reach your threshold.
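If you do keep the full all_intersects matrix around, the per-element update could look roughly like this sketch (delta is +1 when el is added to the set and -1 when it is removed; call it while sets_by_element[el] still contains set_id, so the diagonal entry keeps tracking the set's size):

def update_all_intersects(all_intersects, sets_by_element, set_id, el, delta):
    for other_id in sets_by_element[el]:
        all_intersects[set_id][other_id] += delta
        if other_id != set_id:
            all_intersects[other_id][set_id] += delta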