I'm trying to compute about 1 trillion different pairwise Jaccard scores on a laptop. I'm also not very patient.
Each set tends to contain about 200-800 unique, pseudorandom 32-bit hashes. I've already minimized the size of each set.
Is there an approach that would allow me to trade precision for faster performance? I can probably tolerate an error rate of +/- 0.15.
To be clear, I'm not looking to get a marginal performance boost from converting the same equation to something like numba. I'm looking for an alternative approach (perhaps matrix math?) that will boost performance by at least an order of magnitude, preferably more.
Here's an example that computes an error-free Jaccard score:
def jaccard(hash_set_1, hash_set_2) -> float:
    if 0 == min(len(hash_set_1), len(hash_set_2)):
        return 0.0
    intersect = len(hash_set_1 & hash_set_2)
    union = len(hash_set_1) + len(hash_set_2) - intersect
    return intersect / union
For some additional context: I have roughly 100 million known sets, and some input sets of interest. For each of the sets of interest, I'm scoring against the 100 million known sets to generate clusters. As matches (scores above a threshold) are found, they are added to the input set. After fully exploring the space, I expect to find roughly 10K matches.
This operation gets repeated periodically, and needs to be as fast as possible. When the operation is repeated, some percentage (10%-25%) of both the known sets, and the input sets of interest, will have aged out and be replaced with new sets of data.
This answer gives a possible solution to significantly speed up the whole algorithm (not just the provided function). However, it is not trivial to implement efficiently.
The main idea is to replace CPython's set data structure (which is essentially a hash table holding references to CPython objects, see here) with a simplified Bloom filter. Bloom filters are a very compact probabilistic data structure. They can tell you that a number is definitely not in a set, or that it is in the set with some possibility of error; the error probability depends on the size of the filter. Each integer is encoded as a single bit in a large bit array, at a position derived from the hash of the integer. If the integers are already uniformly distributed (e.g. they are already hashes), they do not need to be hashed again. The bit to check/set is located with hash(number) % filter_size. If filter_size is a power of two, the modulus can be replaced by a very fast AND instruction.

The intersection of two Bloom filters consists of computing a bitwise AND of the two memory blocks, assuming the two filters have exactly the same size and use the same hash function. The number of set bits (which can be computed very efficiently thanks to dedicated CPU instructions) is an approximation of the number of items in the intersection of the two sets. However, with this solution, the filters need to be fairly big to avoid collisions, which make the final result very inaccurate (see the birthday problem). Big (sparse) filters tend to be quite expensive to traverse.
An alternative solution is to turn the bigger set into a large Bloom filter, and then iterate over the smaller set to check whether each of its items is found in the filter. This avoids many collisions and mitigates the birthday problem a bit. It is also less expensive as long as the Bloom filter fits in the L1 CPU cache. Converting a set to a Bloom filter can be fairly expensive, but the filter can be reused for many pair comparisons, so this conversion should end up taking a negligible amount of time.
Note that the approximation is an (unbounded) over-approximation. The result is guaranteed to be exact if the numbers are smaller than the table size and the hash function is a perfect hash function (e.g. the identity). However, such a filter would be far too big to be efficient for random 32-bit integers (i.e. 512 MiB of RAM per filter). You can improve the precision of the method by using multiple hash functions (at the expense of a slower execution -- there is no free lunch). This strategy can also help you estimate the accuracy of the output, based on the variance of the counts obtained from multiple filters with different hash functions.
Here is an (inefficient) example illustrating how things work:
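(This is only a pure-Python sketch with names of my own choosing -- FILTER_SIZE, to_bloom, approx_jaccard -- and it rebuilds the filter for every pair, which a real implementation should avoid.)

FILTER_SIZE = 8192                 # must be a power of two
MASK = FILTER_SIZE - 1

def to_bloom(hash_set):
    # One bit per slot; the items are already hashes, so identity hashing is enough.
    filt = bytearray(FILTER_SIZE // 8)
    for h in hash_set:
        b = h & MASK               # power-of-two modulus as a cheap AND
        filt[b >> 3] |= 1 << (b & 7)
    return filt

def approx_jaccard(small_set, big_set):
    if not small_set or not big_set:
        return 0.0
    filt = to_bloom(big_set)       # in practice, build once and reuse for many pairs
    inter = 0
    for h in small_set:
        b = h & MASK
        if filt[b >> 3] & (1 << (b & 7)):
            inter += 1             # "probably present": collisions can over-count
    union = len(small_set) + len(big_set) - inter
    return inter / union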
In practice, for the sake of performance, the computation should be done in Numba or Cython (or in a native language). The sets should be stored as NumPy arrays, since iterating over CPython sets is more expensive in Numba/Cython. The Bloom filter should be reused as much as possible. The size of the Bloom filter must be a compile-time constant (critical to avoid a very slow modulus) and the hash function should be fast (e.g. identity, xorshift). The filter's backing array should be recycled as much as possible to avoid allocations. Loop unrolling and tiling should also improve performance.
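A possible shape for such a kernel (only a sketch under my own naming -- approx_jaccard_nb, FILTER_BITS -- and using one byte per slot instead of one bit to keep the code short):

import numpy as np
import numba as nb

FILTER_BITS = 8192                 # compile-time constant, power of two
MASK = FILTER_BITS - 1

@nb.njit(cache=True)
def approx_jaccard_nb(small, big, filt):
    # small, big: uint32 NumPy arrays of hashes; filt: reusable uint8 scratch
    # array of length FILTER_BITS (one byte per slot for simplicity).
    filt[:] = 0                    # recycle the scratch buffer, no allocation
    for i in range(big.shape[0]):
        filt[big[i] & MASK] = 1    # identity hash + power-of-two modulus
    inter = 0
    for i in range(small.shape[0]):
        inter += filt[small[i] & MASK]
    union = small.shape[0] + big.shape[0] - inter
    if union == 0:
        return 0.0
    return inter / union

The caller allocates the scratch buffer once, e.g. scratch = np.zeros(FILTER_BITS, dtype=np.uint8), and passes it to every call so nothing is allocated in the hot loop.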
The pairs can be computed in parallel using multiple threads (native threads with the GIL released). In fact, the Bloom-filter approach is also GPU-friendly, and GPUs can be significantly faster for this kind of computation. The Bloom filters must be stored in the GPU's shared memory for the operation to be fast, though. The size of this kind of memory is very limited (e.g. 0-200 KiB per SM on Nvidia devices), but it should be big enough for your needs in practice.
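For the threaded part, one possible layout (again a sketch: score_query_against_all and the flattened known_flat/known_starts representation are my own assumptions) is to build the query's filter once and let Numba's prange probe it with every known set; the filter is read-only, fits in L1, and can be shared by all threads without locking:

import numpy as np
import numba as nb

FILTER_BITS = 8192
MASK = FILTER_BITS - 1

@nb.njit(parallel=True, cache=True)
def score_query_against_all(query_filt, query_size, known_flat, known_starts, out):
    # query_filt: uint8 array of length FILTER_BITS built once from the query set.
    # known_flat: all known hashes concatenated into one uint32 array.
    # known_starts: int64 array of length N + 1 with the start offset of each known set.
    # out: float64 array of length N receiving one approximate Jaccard score per known set.
    n_sets = known_starts.shape[0] - 1
    for k in nb.prange(n_sets):
        start = known_starts[k]
        end = known_starts[k + 1]
        inter = 0
        for i in range(start, end):
            inter += query_filt[known_flat[i] & MASK]
        union = query_size + (end - start) - inter
        out[k] = inter / union if union > 0 else 0.0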
Related post: The most efficient way rather than using np.setdiff1d and np.in1d, to remove common values of 1D arrays with unique values