Recursively finding all partitions of a set of n objects into k non-empty subsets

I want to find all partitions of a set of n elements into k subsets. This is my algorithm, based on the recursive formula for Stirling numbers of the second kind:

fun main(args: Array<String>) {
    val s = mutableSetOf(1, 2, 3, 4, 5)
    val partitions = 3
    val res = mutableSetOf<MutableSet<MutableSet<Int>>>()
    partition(s, partitions, res)
    //println(res)
    println("Second kind stirling number ${res.size}")
}

fun partition(inputSet: MutableSet<Int>, numOfPartitions: Int, result: MutableSet<MutableSet<MutableSet<Int>>>) {
    if (inputSet.size == numOfPartitions) {
        val sets = inputSet.map { mutableSetOf(it) }.toMutableSet()
        result.add(sets)
    }
    else if (numOfPartitions == 1) {
        result.add(mutableSetOf(inputSet))
    }
    else {
        val popped: Int = inputSet.first().also { inputSet.remove(it) }

        val r1 = mutableSetOf<MutableSet<MutableSet<Int>>>()
        partition(inputSet, numOfPartitions, r1) //add popped to each set in solution (all combinations)

        for (solution in r1) {
            for (set in solution) {
                set.add(popped)
                result.add(solution.map { it.toMutableSet() }.toMutableSet()) //deep copy
                set.remove(popped)
            }
        }
        val r2 = mutableSetOf<MutableSet<MutableSet<Int>>>()
        partition(inputSet, numOfPartitions - 1, r2) //popped is single elem set

        r2.map { it.add(mutableSetOf(popped)) }
        r2.map { result.add(it) }
    }
}

The code works for k = 2, but for larger n and k it loses some partitions, and I can't find the mistake. For example, n = 5 and k = 3 outputs "Second kind stirling number 19", but the correct output is 25.
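From reading the code (not from running it), one likely cause is that `inputSet` is shared and mutated in place. `partition` removes `popped` and never puts it back. The first recursive call, `partition(inputSet, numOfPartitions, r1)`, gets the very same set object and removes further elements from it, so `partition(inputSet, numOfPartitions - 1, r2)` then works on a set that is too small. If I trace it correctly, for n = 5, k = 3 only two elements are left at that point. The second branch therefore contributes S(2, 2) = 1 partition instead of S(4, 2) = 7, giving 3 · 6 + 1 = 19 instead of 3 · 6 + 7 = 25. For k = 2 the count happens to come out right, because the k - 1 = 1 branch yields exactly one partition whatever the set size, but those partitions lack elements (this is hidden while `println(res)` is commented out). The `numOfPartitions == 1` branch also stores `inputSet` itself rather than a copy. Passing copies, e.g. `inputSet.toMutableSet()`, to both recursive calls should avoid this.

Here is a minimal sketch of the same recurrence in Python, the language of the answers below. The function name `partitions` is hypothetical. The sketch builds new lists instead of mutating shared ones:

```python
def partitions(elements, k):
    """All partitions of the list `elements` (distinct items) into k non-empty blocks.

    Follows S(n, k) = k * S(n-1, k) + S(n-1, k-1), but builds new lists instead of
    mutating shared ones. Each partition is a list of blocks (lists).
    """
    n = len(elements)
    if k == n:
        return [[[x] for x in elements]]  # every element alone (also covers n == k == 0)
    if k <= 0 or k > n:
        return []
    if k == 1:
        return [[list(elements)]]  # a fresh copy, never the caller's list
    first, rest = elements[0], elements[1:]  # slicing copies; `elements` is untouched
    result = []
    # Case 1: `first` joins one of the k blocks of a partition of `rest` into k blocks.
    for p in partitions(rest, k):
        for i in range(k):
            result.append([b + [first] if j == i else list(b) for j, b in enumerate(p)])
    # Case 2: `first` is a singleton block next to a partition of `rest` into k - 1 blocks.
    for p in partitions(rest, k - 1):
        result.append(p + [[first]])
    return result


if __name__ == "__main__":
    res = partitions([1, 2, 3, 4, 5], 3)
    print("Second kind stirling number", len(res))  # Second kind stirling number 25
```

The two loops correspond to the two terms of the recurrence. Because nothing is ever modified after it has been stored, the two recursive calls cannot interfere with each other.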

There are 3 answers

Answer by MBo (best answer)

If you can read Python, consider the following algorithm, which I quickly adapted from my implementation of partitioning a set into equal-size parts.

The recursive function fills k parts with n values.

The lastfilled parameter avoids duplicates: it ensures that the leading (smallest) elements of the parts form an increasing sequence.

The empty parameter (the number of parts that are still empty) prevents empty parts in the result.
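The lastfilled rule is the standard "restricted growth string" encoding of a set partition. Label each value with the index of its part. The first value gets label 0, and every later label is at most one more than the largest label used so far. Each partition has exactly one such labeling, which is why no duplicates appear. The brute-force sketch below checks this by enumerating all k^n labelings and keeping the canonical ones that use exactly k labels. The helper names `is_restricted_growth` and `count_k_partitions` are hypothetical, and the approach is only practical for small n:

```python
from itertools import product


def is_restricted_growth(labels):
    """True if labels[0] == 0 and each label is at most 1 + the max of the earlier ones."""
    highest = -1  # plays the role of `lastfilled`
    for a in labels:
        if a > highest + 1:
            return False
        highest = max(highest, a)
    return True


def count_k_partitions(n, k):
    """Count partitions of {0, ..., n-1} into k non-empty parts by brute force."""
    return sum(
        1
        for labels in product(range(k), repeat=n)  # label = index of the part
        if is_restricted_growth(labels) and len(set(labels)) == k
    )


print(count_k_partitions(5, 3))  # 25
```

The recursive function does the same thing without generating the rejected labelings: `range(start, min(k, lastfilled + 2))` only ever offers labels that keep the string restricted-growth.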

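def genp(parts: list, empty, n, k, m, lastfilled):
    # parts: the k parts being filled; empty: how many of them are still empty;
    # m: the next value to place (values are 0..n-1);
    # lastfilled: the highest part index used so far (-1 before the first value).
    global c
    if m == n:
        print(parts)
        c += 1
        return
    # If exactly as many values remain as there are empty parts, each remaining
    # value must open a new part, so only the first empty part (index k - empty) is allowed.
    if n - m == empty:
        start = k - empty
    else:
        start = 0
    # Value m may join any part already in use or open the next unused one
    # (index lastfilled + 1); this fixes the order of the parts and prevents duplicates.
    for i in range(start, min(k, lastfilled + 2)):
        parts[i].append(m)
        if len(parts[i]) == 1:
            empty -= 1
        genp(parts, empty, n, k, m + 1, max(i, lastfilled))
        parts[i].pop()  # undo the choice before trying the next part
        if len(parts[i]) == 0:
            empty += 1


def setkparts(n, k):
    parts = [[] for _ in range(k)]
    genp(parts, k, n, k, 0, -1)  # initially all k parts are empty

c = 0
setkparts(5, 3)
#setkparts(7, 5)
print(c)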

[[0, 1, 2], [3], [4]]
[[0, 1, 3], [2], [4]]
[[0, 1], [2, 3], [4]]
[[0, 1, 4], [2], [3]]
[[0, 1], [2, 4], [3]]
[[0, 1], [2], [3, 4]]
[[0, 2, 3], [1], [4]]
[[0, 2], [1, 3], [4]]
[[0, 2, 4], [1], [3]]
[[0, 2], [1, 4], [3]]
[[0, 2], [1], [3, 4]]
[[0, 3], [1, 2], [4]]
[[0], [1, 2, 3], [4]]
[[0, 4], [1, 2], [3]]
[[0], [1, 2, 4], [3]]
[[0], [1, 2], [3, 4]]
[[0, 3, 4], [1], [2]]
[[0, 3], [1, 4], [2]]
[[0, 3], [1], [2, 4]]
[[0, 4], [1, 3], [2]]
[[0], [1, 3, 4], [2]]
[[0], [1, 3], [2, 4]]
[[0, 4], [1], [2, 3]]
[[0], [1, 4], [2, 3]]
[[0], [1], [2, 3, 4]]
25
Answer by Михаил Нафталь

I'm not sure what exactly is wrong in your code, but computing Stirling numbers of the second kind recursively is much simpler:

import java.math.BigInteger

// Memoized S(n, k) using S(n, k) = k * S(n - 1, k) + S(n - 1, k - 1)
private val memo = hashMapOf<Pair<Int, Int>, BigInteger>()

fun stirling2(n: Int, k: Int): BigInteger {
    val key = n to k
    return memo.getOrPut(key) {
        when {
            n == k -> BigInteger.ONE               // checked first so that S(0, 0) = 1
            k == 0 || k > n -> BigInteger.ZERO
            else -> k.toBigInteger() * stirling2(n - 1, k) + stirling2(n - 1, k - 1)
        }
    }
}
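As a sanity check, unrolling this recurrence by hand for the question's example (using S(n, n) = S(n, 1) = 1 and S(3, 2) = 3) gives the expected count:

```latex
\begin{aligned}
S(4,3) &= 3\,S(3,3) + S(3,2) = 3 \cdot 1 + 3 = 6,\\
S(4,2) &= 2\,S(3,2) + S(3,1) = 2 \cdot 3 + 1 = 7,\\
S(5,3) &= 3\,S(4,3) + S(4,2) = 3 \cdot 6 + 7 = 25.
\end{aligned}
```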
Answer by Emil Valeev

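I improved Kornel_S' code. Here is a function that returns a list of all partitions of a list into non-empty subsets, for every number of subsets from 1 to n. Be careful with large inputs: the number of partitions grows very quickly :)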

def Stirling2Iterate(List):
    
    Result = []
    
    def genp(parts:list, empty, n, k, m, lastfilled):
        
        if m == n:
            nonlocal Result
            nonlocal List
            Result += [ [[List[item2] for item2 in item] for item in parts] ]
            return
        
        if n - m == empty: start = k - empty
        else: start = 0
        
        for i in range(start, min(k, lastfilled + 2)):
            
            parts[i].append(m)
            if len(parts[i]) == 1: empty -= 1
            genp(parts, empty, n, k, m + 1, max(i, lastfilled))
            parts[i].pop()
            if len(parts[i]) == 0: empty += 1
    
    def setkparts(n, k):
        parts = [ [] for _ in range(k) ]
        cnts = [0] * k
        genp(parts, k, n, k, 0, -1)
    
    for i in range(len(List)): setkparts(len(List), i + 1)
    
    return Result

Example:

# EXAMPLE

print('\n'.join([f"{x}" for x in Stirling2Iterate(['A', 'B', 'X', 'Z'])]))

# OUTPUT

[['A', 'B', 'X', 'Z']]
[['A', 'B', 'X'], ['Z']]
[['A', 'B', 'Z'], ['X']]
[['A', 'B'], ['X', 'Z']]
[['A', 'X', 'Z'], ['B']]
[['A', 'X'], ['B', 'Z']]
[['A', 'Z'], ['B', 'X']]
[['A'], ['B', 'X', 'Z']]
[['A', 'B'], ['X'], ['Z']]
[['A', 'X'], ['B'], ['Z']]
[['A'], ['B', 'X'], ['Z']]
[['A', 'Z'], ['B'], ['X']]
[['A'], ['B', 'Z'], ['X']]
[['A'], ['B'], ['X', 'Z']]
[['A'], ['B'], ['X'], ['Z']]
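Because `Stirling2Iterate` collects the partitions for every k from 1 to n, its result has B_n entries. This is the Bell number B_n = S(n, 1) + S(n, 2) + ... + S(n, n). For the four-element example, that is 1 + 7 + 6 + 1 = 15, matching the 15 lines above.

The sketch below shows how quickly these "big numbers" arrive. It tabulates S(n, k) with the same recurrence used in the Kotlin answer and sums each row. The function name `bell_numbers` is hypothetical:

```python
def bell_numbers(n_max):
    """Return [B_0, ..., B_n_max], where B_n is the sum over k of S(n, k)."""
    s = [1]      # s[k] holds S(n, k) for the current n; start with S(0, 0) = 1
    bells = [1]  # B_0 = 1
    for n in range(1, n_max + 1):
        s.append(0)  # S(n - 1, n) = 0
        # Update from high k to low k so that s[k - 1] still holds S(n - 1, k - 1).
        for k in range(n, 0, -1):
            s[k] = k * s[k] + s[k - 1]
        s[0] = 0  # S(n, 0) = 0 for n >= 1
        bells.append(sum(s))
    return bells


for n, b in enumerate(bell_numbers(15)):
    print(n, b)
```

B_10 is already 115975, so materializing every partition of a few dozen elements as lists is out of reach. Counting them with the recurrence remains cheap.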