Pythonic way of "inner join" of list of lists with arbitrary lengths?


I have an arbitrary list of lists, for example

x = [
    [5, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
    [5, 10, 11, 12],
]

Only one element, 5, is a member of all three lists in x.

My question: What is the most pythonic method of testing membership of elements in an arbitrary number of lists?

My solution is as follows, but feels like it could be simpler:

y = x[0]
if len(x) > 1:
    for subset in x[1:]:
        # Intersect the running result (y), not the outer list x
        y = list(set(y) & set(subset))

EDIT: The output should be a list containing the common elements. For the example above, it would be [5]. If the example had two common elements (e.g. 5 and "foo"), the output should be [5, "foo"].


There are 3 answers

0
Kraigolas On

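You can use functools.reduce to solve this problem. Note that the initial value here is the set type itself, so the first step evaluates set.intersection(set(x[0])), which simply yields the first sublist as a set:

import functools

functools.reduce(lambda y, z: y.intersection(set(z)), x, set)
# {5}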
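
A variant of the same idea, sketched under the assumption that an empty outer list should give an empty result: convert each sublist to a set with map and fold them together with set.intersection. The name common_elements is hypothetical and not part of the answer above.

```python
from functools import reduce

def common_elements(lists):
    """Return the values present in every sublist, as a set."""
    if not lists:
        # Assumption: no input lists means no common elements.
        return set()
    # Without an initial value, reduce starts from the first set.
    return reduce(set.intersection, map(set, lists))

x = [
    [5, 1, 2, 3, 4],
    [5, 6, 7, 8, 9],
    [5, 10, 11, 12],
]
print(common_elements(x))  # {5}
```

Guarding the empty case explicitly avoids relying on what the initial value happens to be.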
3
KingsDev On

Another, potentially more pythonic way of achieving this is:

results = [val for val in x[0] if all(val in ls for ls in x[1:])]

This builds a new list from the values in the first sublist (x[0]), keeping a value only if it also appears in every other sublist (x[1:]).

This could be considered more pythonic because it uses only built-in syntax rather than, say, functools.reduce (as suggested in another answer).
It also demonstrates Python's ability to build a list from other list(s) based on a condition.
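
One practical difference from the set-based answers, shown as a small sketch with made-up data: the comprehension keeps the order of x[0] and repeats any duplicates in it. Converting the other sublists to sets first is an addition here, not part of the answer; it keeps the same output while making each membership test cheaper, and it assumes the elements are hashable.

```python
x = [
    [5, 3, 5, 1],
    [1, 5, 3],
    [3, 5, 1, 7],
]

# Original form: membership is tested against the lists themselves.
by_lists = [val for val in x[0] if all(val in ls for ls in x[1:])]

# Same logic, but each remaining sublist is converted to a set once.
other_sets = [set(ls) for ls in x[1:]]
by_sets = [val for val in x[0] if all(val in s for s in other_sets)]

print(by_lists)  # [5, 3, 5, 1]
print(by_sets)   # [5, 3, 5, 1]
```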

0
Timus On

I'd say the most pythonic way would be to use set.intersection directly:

common = set(x[0]).intersection(*x[1:])
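
Since the question asks for a list, the set can be turned back into one as in the sketch below. The function name common_as_list is made up for illustration, the empty-input behavior is an assumption, and the result is ordered by first appearance in the first sublist so that mixed types such as 5 and "foo" do not need to be sorted.

```python
def common_as_list(lists):
    """Return the elements shared by all sublists, in first-sublist order."""
    if not lists:
        # Assumption: an empty outer list has no common elements.
        return []
    first, *rest = lists
    common = set(first).intersection(*rest)
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return [val for val in dict.fromkeys(first) if val in common]

x = [
    [5, "foo", 1, 2],
    ["foo", 5, 6, 7],
    [10, 5, "foo"],
]
print(common_as_list(x))  # [5, 'foo']
```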

Regarding performance: it depends, imho, on how likely it is that there are, in fact, common elements.

I've tried the following measurement setup:

from functools import reduce
from random import randint, choices, seed
from timeit import timeit

def sample(num_numbers, num_lists, seed_number):
    numbers = list(range(num_numbers))
    seed(seed_number)
    return [choices(numbers, k=randint(50, 100)) for _ in range(num_lists)]

def test_1(x):
    result = set(x[0])
    for sublist in x[1:]:
        result = result.intersection(sublist)

def test_2(x): [val for val in x[0] if all(val in ls for ls in x[1:])]
def test_3(x): reduce(lambda y, z: y.intersection(set(z)), x, set)
def test_4(x): set(x[0]).intersection(*x[1:])

for num_numbers, num_lists in ((3, 10_000), (30, 10_000)):
    print(f"\nSample configuration: {num_numbers = }, {num_lists = }\n")
    x = sample(num_numbers, num_lists, 123456789)
    for n in range(1, 5):
        t = timeit(f"test_{n}(x)", globals=globals(), number=100)
        print(f"test_{n}(x): {t:.3f} seconds")

Results:

Sample configuration: num_numbers = 3, num_lists = 10000

test_1(x): 1.227 seconds
test_2(x): 8.817 seconds
test_3(x): 1.094 seconds
test_4(x): 1.193 seconds

Sample configuration: num_numbers = 30, num_lists = 10000

test_1(x): 0.748 seconds
test_2(x): 0.360 seconds
test_3(x): 1.428 seconds
test_4(x): 0.679 seconds
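
A plausible reading of why the list comprehension (test_2) swings so much between the two configurations is the short-circuiting of all(): it stops at the first sublist that lacks the value. The sketch below illustrates this with made-up data and a hypothetical helper, count_checks; it is not a re-run of the benchmark. It counts how many sublists are inspected when every value is common versus when none is.

```python
def count_checks(lists):
    """Count how many sublists the comprehension approach inspects."""
    checks = 0
    for val in lists[0]:
        for ls in lists[1:]:
            checks += 1
            if val not in ls:
                # all() would short-circuit at this point.
                break
    return checks

# Every value of the first sublist appears everywhere: no early exit.
dense = [[0, 1, 2]] + [[0, 1, 2]] * 1000
# No value of the first sublist appears in the second: exit after one check.
sparse = [[0, 1, 2]] + [[3, 4, 5]] * 1000

print(count_checks(dense))   # 3000
print(count_checks(sparse))  # 3
```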