Efficient algorithm to determine whether some subset of an integer array has a given XOR value?


I have an array of positive integers, {1, 5, 8, 2, 10}, and a target value 7. I need to determine whether some subset of the array has an XOR (over all its elements) equal to 7. In this case one such subset is {5, 2}, because 5 XOR 2 = 7. A naive solution is to enumerate all subsets and check each one; I want an algorithm that does better than that. NOTE: I only need to know whether a solution exists. I don't need to find the subset itself.
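For reference, the naive approach described above can be sketched as a brute-force check. This is an illustration, not part of the original post: the function name `subset_xor_exists_naive` is hypothetical, and it treats the empty subset (whose XOR is 0) as a valid subset. With n values there are 2^n subsets, so it is only practical for small arrays, but it is handy for testing a faster solution.

```python
from functools import reduce
from itertools import combinations
from operator import xor


def subset_xor_exists_naive(values, target):
    """Try every subset (including the empty one, whose XOR is 0).

    There are 2**n subsets of n values, so this takes O(2**n * n) time.
    """
    for size in range(len(values) + 1):
        for subset in combinations(values, size):
            if reduce(xor, subset, 0) == target:
                return True
    return False


print(subset_xor_exists_naive([1, 5, 8, 2, 10], 7))  # True, e.g. {5, 2}
```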



Answer by David Eisenstat

This boils down to solving a system of linear equations over GF(2), the finite field with two elements. Treat each integer as a vector of its bits; bitwise XOR is then exactly vector addition (component-wise addition mod 2). The sample inputs correspond to these vectors:

 1: 0001
 5: 0101
 8: 1000
 2: 0010
10: 1010
 7: 0111
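To make the correspondence concrete, here is a small sketch that prints each value as a bit vector (most significant bit first, as in the table) and checks that XOR of integers matches component-wise addition mod 2. The helpers `to_bits` and `add_mod2` are hypothetical names, and the width of 4 bits is an assumption that only works because every sample value is below 16.

```python
def to_bits(x, width=4):
    """Bits of x, most significant first (width=4 is an assumption for the sample)."""
    return [(x >> i) & 1 for i in reversed(range(width))]


def add_mod2(u, v):
    """Component-wise vector addition over GF(2)."""
    return [(a + b) % 2 for a, b in zip(u, v)]


# Prints the same table as above.
for v in [1, 5, 8, 2, 10, 7]:
    print(f"{v:2}: {''.join(map(str, to_bits(v)))}")

# XOR of the integers equals GF(2) addition of their bit vectors.
assert add_mod2(to_bits(5), to_bits(2)) == to_bits(5 ^ 2) == to_bits(7)
```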

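Each column of the matrix is one input value (1, 5, 8, 2, 10), each row is one bit position (most significant at the top), and each unknown a through e is 0 or 1 according to whether the corresponding element is in the subset. The right-hand side is the target 7, and all arithmetic is mod 2. The system looks like this: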

[0  0  1  0  1] [a]   [0]
[0  1  0  0  0] [b]   [1]
[0  0  0  1  1] [c] = [1]
[1  1  0  0  0] [d]   [1]
                [e]
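The matrix can be built mechanically from the values. The sketch below uses hypothetical helpers `build_system` and `mat_vec_mod2` and an assumed width of 4 bits; it reproduces the rows above and confirms that b = d = 1 (the subset {5, 2}), with every other unknown 0, solves the system.

```python
def build_system(values, target, width):
    """Row r holds bit (width - 1 - r) of every value; rhs holds that bit of target."""
    A = [[(v >> bit) & 1 for v in values] for bit in reversed(range(width))]
    rhs = [(target >> bit) & 1 for bit in reversed(range(width))]
    return A, rhs


def mat_vec_mod2(A, x):
    """Matrix-vector product with arithmetic mod 2."""
    return [sum(a * xi for a, xi in zip(row, x)) % 2 for row in A]


A, rhs = build_system([1, 5, 8, 2, 10], 7, width=4)
for row in A:
    print(row)

# Unknowns (a, b, c, d, e): pick b (value 5) and d (value 2).
assert mat_vec_mod2(A, [0, 1, 0, 1, 0]) == rhs
```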

The following Python code performs Gaussian elimination using bitwise operations. For fixed-width integers it runs in linear time: O(n * w) for n values of w bits each, since there is at most one pivot per bit position. Forgive me for not re-explaining Gaussian elimination; there are a million better treatments on the Internet already.

#!/usr/bin/env python3
def least_bit_set(x):
    # Isolate the lowest set bit of x (two's-complement trick), e.g. 10 -> 2.
    return x & (-x)


def delete_zeros_from(values, start):
    # Compact values[start:] in place, dropping zeros (zero rows carry no information).
    i = start
    for j in range(start, len(values)):
        if values[j] != 0:
            values[i] = values[j]
            i += 1
    del values[i:]


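def eliminate(values):
    # Gaussian elimination over GF(2). Returns nonzero values that span the same
    # space as the input and whose lowest set bits (the pivots) strictly increase.
    # The loop runs at most once per bit position, so this is O(n * w).
    values = list(values)
    i = 0
    while True:
        delete_zeros_from(values, i)
        if i >= len(values):
            return values
        # Move the remaining value with the smallest pivot to position i.
        j = i
        for k in range(i + 1, len(values)):
            if least_bit_set(values[k]) < least_bit_set(values[j]):
                j = k
        values[i], values[j] = values[j], values[i]
        # Clear that pivot bit from every later value that has it as its lowest bit.
        for k in range(i + 1, len(values)):
            if least_bit_set(values[k]) == least_bit_set(values[i]):
                values[k] ^= values[i]
        i += 1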


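def in_span(x, eliminated_values):
    # Reduce x by each pivot row in increasing pivot order. x is the XOR of some
    # subset (possibly empty, which gives 0) iff it reduces to zero.
    for y in eliminated_values:
        if (least_bit_set(y) & x) != 0:
            x ^= y
    return x == 0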


def main():
    values = [1, 5, 8, 2, 10]
    eliminated_values = eliminate(values)
    print(eliminated_values)  # prints [1, 2, 4, 8] for these values
    x = int(input())  # target value, e.g. 7
    print(in_span(x, eliminated_values))


if __name__ == '__main__':
    main()
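As a cross-check, here is a sketch of the same Gaussian elimination in another commonly used form, an "XOR basis" keyed by each basis vector's highest set bit instead of its lowest. This is a swapped-in formulation, not the answer's code; all names (`build_basis`, `in_span_basis`, `nonempty_subset_xor_exists`, `naive`) are hypothetical, and the random test parameters (values below 64, at most 7 elements, seed 0) are arbitrary assumptions. It also covers a detail the question leaves open: `in_span` returns True for 0 because the empty subset XORs to 0. If the subset must be non-empty, a target of 0 is reachable exactly when the values are linearly dependent, i.e. when there are more values than basis vectors.

```python
import random
from functools import reduce
from itertools import combinations
from operator import xor


def build_basis(values):
    """Gaussian elimination over GF(2), keyed by highest set bit.

    Returns a dict mapping bit position b to a basis vector whose highest
    set bit is b. Its size is the rank of the input values.
    """
    basis = {}
    for v in values:
        while v:
            top = v.bit_length() - 1
            if top not in basis:
                basis[top] = v  # new pivot: v is independent of the basis so far
                break
            v ^= basis[top]  # clear bit `top`; v's highest bit strictly drops
    return basis


def in_span_basis(x, basis):
    """True iff x is the XOR of some subset of the values (empty subset -> 0)."""
    while x:
        top = x.bit_length() - 1
        if top not in basis:
            return False  # no basis vector can clear this bit
        x ^= basis[top]
    return True


def nonempty_subset_xor_exists(values, target):
    """Same question, but the empty subset does not count."""
    basis = build_basis(values)
    if target != 0:
        # Any subset with a nonzero XOR is necessarily non-empty.
        return in_span_basis(target, basis)
    # A non-empty subset XORs to 0 iff the values are linearly dependent.
    return len(values) > len(basis)


def naive(values, target, allow_empty=True):
    """Brute force over all subsets, used only to check the fast version."""
    start = 0 if allow_empty else 1
    return any(
        reduce(xor, subset, 0) == target
        for size in range(start, len(values) + 1)
        for subset in combinations(values, size)
    )


rng = random.Random(0)
for _ in range(500):
    values = [rng.randrange(1, 64) for _ in range(rng.randrange(0, 8))]
    target = rng.randrange(0, 64)
    assert in_span_basis(target, build_basis(values)) == naive(values, target)
    assert nonempty_subset_xor_exists(values, target) == naive(
        values, target, allow_empty=False
    )

print(in_span_basis(7, build_basis([1, 5, 8, 2, 10])))  # True
```

Both formulations process each value at most w times for w-bit integers, so building the basis is O(n * w); once it is built, each additional target can be tested in O(w), which helps if many targets are queried against the same array.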