Efficiently compute xor / symmetric difference of many sets (list of sets)


I have an arbitrary number of Python sets, e.g.

>>> a = {1, 2, 3}
>>> b = {3, 4, 5}
>>> c = {5, 6, 7}
>>> d = {7, 8, 1}

I want to compute their "combined" symmetric difference, i.e. I want to xor all of them:

>>> a ^ b ^ c ^ d
{2, 4, 6, 8}
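For more than two sets, "xor all of them" has a precise meaning: an element ends up in the result exactly when it belongs to an odd number of the sets. The brute-force sketch below only illustrates that definition and is not meant to be fast. The helper name `xor_by_parity` is hypothetical.

```python
from typing import List, Set


def xor_by_parity(sets: List[Set[int]]) -> Set[int]:
    # Brute force, for illustration only: an element belongs to the combined
    # symmetric difference iff it is a member of an odd number of the sets.
    universe = set().union(*sets)
    return {x for x in universe if sum(x in s for s in sets) % 2 == 1}


a = {1, 2, 3}
b = {3, 4, 5}
c = {5, 6, 7}
d = {7, 8, 1}
print(sorted(xor_by_parity([a, b, c, d])))  # [2, 4, 6, 8]
```

An element that shows up in three sets would therefore survive the XOR, which is why "in an odd number of sets" and "in exactly one set" only coincide under an extra assumption.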

In my use case, I am actually dealing with lists of sets:

>>> l = [a, b, c, d]
>>> l
[{1, 2, 3}, {3, 4, 5}, {5, 6, 7}, {1, 7, 8}]

Currently, I am iterating across the list in order to achieve what I want:

>>> res = l[0].copy()
>>> for item in l[1:]:
...     res.symmetric_difference_update(item)
>>> res
{2, 4, 6, 8}

I am wondering whether there is a more efficient method, ideally one that avoids a Python-level for-loop. Set operations are actually really fast in Python, but my lists can become rather long, so the for-loop itself ironically becomes the bottleneck.
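One way to keep the in-place `symmetric_difference_update` calls while moving the iteration out of Python bytecode is to drive them with `map()` and let a zero-length `collections.deque` consume the iterator. This is only a sketch; whether it actually beats the explicit loop on a given interpreter would have to be measured. The function name `xor_all_consumed` is hypothetical.

```python
from collections import deque
from typing import List, Set


def xor_all_consumed(sets: List[Set[int]]) -> Set[int]:
    # Start from a fresh set so none of the input sets is mutated.
    res: Set[int] = set()
    # map() lazily calls res.symmetric_difference_update(s) for each set;
    # deque(..., maxlen=0) exhausts that iterator in C and discards the
    # None return values.
    deque(map(res.symmetric_difference_update, sets), maxlen=0)
    return res


sets = [{1, 2, 3}, {3, 4, 5}, {5, 6, 7}, {7, 8, 1}]
print(sorted(xor_all_consumed(sets)))  # [2, 4, 6, 8]
```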


EDIT (1)

I am assuming that no element occurs more than twice across all the sets in my list.
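Under this assumption, the combined symmetric difference is simply the set of elements that occur in exactly one set. A minimal counting sketch, assuming CPython (where `collections.Counter` counts through a C helper, so only the final filter runs as a Python-level loop, and it runs over the distinct elements only). The helper names `xor_by_counting` and `at_most_twice` are hypothetical.

```python
from collections import Counter
from itertools import chain
from typing import List, Set


def xor_by_counting(sets: List[Set[int]]) -> Set[int]:
    # Count how many sets each element occurs in (each set holds it once).
    counts = Counter(chain.from_iterable(sets))
    # Odd count == member of the XOR. If no element occurs more than twice,
    # this is the same as counting exactly once.
    return {x for x, n in counts.items() if n % 2 == 1}


def at_most_twice(sets: List[Set[int]]) -> bool:
    # Check the "no element occurs more than twice" assumption.
    counts = Counter(chain.from_iterable(sets))
    return max(counts.values(), default=0) <= 2


sets = [{1, 2, 3}, {3, 4, 5}, {5, 6, 7}, {7, 8, 1}]
print(sorted(xor_by_counting(sets)))  # [2, 4, 6, 8]
print(at_most_twice(sets))  # True
```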


EDIT (2)

Some benchmarks:

from typing import List, Set
from functools import reduce
from collections import defaultdict

length = 1_000
data = [
    {idx - 1, idx, idx + 1}
    for idx in range(3_000, 3_000 + length * 2, 2)
]
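The `assert len(res) == len(l) + 2` checks in the benchmarks follow from how `data` is built: consecutive sets overlap in exactly one element (`idx + 1` of one set is `idx - 1` of the next), and sets further apart do not overlap at all, so every element occurs at most twice. The XOR therefore keeps each set's middle element plus the two outer endpoints. A small sketch that checks this reasoning on the same data:

```python
length = 1_000
starts = range(3_000, 3_000 + length * 2, 2)
data = [{idx - 1, idx, idx + 1} for idx in starts]

# Neighboring sets share exactly one element.
assert all(len(s & t) == 1 for s, t in zip(data, data[1:]))

# Expected XOR: every middle element, plus the first set's lower end and the
# last set's upper end, which are the only odd numbers seen just once.
expected = set(starts) | {starts[0] - 1, starts[-1] + 1}

res = set()
for s in data:
    res ^= s  # in-place symmetric difference

assert res == expected
assert len(res) == len(data) + 2
```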

def test_loop1(l: List[Set[int]]) -> Set[int]:
    res = l[0].copy()
    for item in l[1:]:
        res.symmetric_difference_update(item)
    assert len(res) == len(l) + 2
    return res

test_loop1: 121 µs ± 321 ns

def test_loop2(l: List[Set[int]]) -> Set[int]:
    res = set()
    for item in l:
        res.symmetric_difference_update(item)
    assert len(res) == len(l) + 2
    return res

test_loop2: 112 µs ± 3.16 µs

def test_reduce1(l: List[Set[int]]) -> Set[int]:
    res = reduce(set.symmetric_difference, l)
    assert len(res) == len(l) + 2
    return res

test_reduce1: 9.89 ms ± 20.6 µs
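A likely reason `test_reduce1` is roughly 80× slower than the loops: `set.symmetric_difference` returns a new set on every call, so each step copies the accumulator, which in this benchmark grows to about 1,000 elements. Over 1,000 steps that is quadratic work. A reduce-based variant that updates one accumulator in place could look like the sketch below; the name `xor_reduce_inplace` is hypothetical and it has not been benchmarked here.

```python
import operator
from functools import reduce
from typing import List, Set


def xor_reduce_inplace(l: List[Set[int]]) -> Set[int]:
    # operator.ixor(acc, s) performs acc ^= s (an in-place
    # symmetric_difference_update) and returns acc. The fresh set() given as
    # the initial value is the accumulator, so the input sets stay unchanged.
    return reduce(operator.ixor, l, set())


sets = [{1, 2, 3}, {3, 4, 5}, {5, 6, 7}, {7, 8, 1}]
print(sorted(xor_reduce_inplace(sets)))  # [2, 4, 6, 8]
```

Passing the initial `set()` matters: without it, `reduce` would use `l[0]` as the accumulator and mutate the caller's first set.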

def test_dict1(l: List[Set[int]]) -> Set[int]:
    """
    A general solution allowing for entries to occur more than twice in the input data
    """
    d = defaultdict(int)
    for item in l:
        for entry in item:
            d[entry] += 1
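    # Keep entries with an odd count; "== 1" would only be correct if no
    # entry occurred more than twice.
    res = {entry for item in l for entry in item if d[entry] % 2 == 1}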
    assert len(res) == len(l) + 2
    return res

test_dict1: 695 µs ± 5.11 µs
