Recursion Error with Guided Permutation Generation


I'm currently stuck trying to figure out how to get around Python's recursion limit for a permutation generator that follows a couple of rules to reduce the total number of permutations. Right now, a simplified version of what I'm working with, which makes sure no symbol is placed directly after the same symbol, is:

def recursive_generate_sequences(symbols_dict, sequence=None, last_symbol=None, start_symbols=None, remaining=0):
    """
    Generate all possible sequences of symbols based on their counts.

    :param start_symbols: symbols, in order, that are the known beginning of the sequence
    :param symbols_dict: dictionary mapping each symbol to the number of times it appears in the sequence
    :param sequence: reserved for the recursive function, do not pass
    :param last_symbol: reserved for the recursive function: the last symbol added to the sequence
    :param remaining: reserved for the recursive function, do not pass
    :return: a generator yielding each valid sequence as a list
    """
    if sequence is None:
        sequence = []
        counts = {symbol: count for symbol, count in symbols_dict.items()}
        # Initialize the remaining count, so that we don't have to recalculate it all the time
        remaining = sum(counts.values())
        # Initialize the sequence with the start symbols if provided
        if start_symbols:
            for symbol in start_symbols:
                if symbol in counts and counts[symbol] > 0:
                    sequence.append(symbol)
                    counts[symbol] -= 1
                    remaining -= 1
                    last_symbol = symbol
                else:
                    raise ValueError(f"Start symbol '{symbol}' is not valid or has insufficient count.")
    else:
        counts = symbols_dict

    # check if the sequence is complete
    if remaining == 0:
        yield sequence
        return

    for symbol, count in counts.items():
        if count > 0 and symbol != last_symbol:
            # explore the rabbit hole
            counts[symbol] -= 1
            next_sequence = sequence + [symbol]
            # pass None to start_symbols for recursive calls
            yield from recursive_generate_sequences(counts, next_sequence, symbol, None, remaining - 1)
            counts[symbol] += 1  # Backtrack

if __name__ == "__main__":
    # works
    for x in recursive_generate_sequences({'a': 1, 'b': 2, 'c': 1}):
        print(x)

    # fails with RecursionError
    for x in recursive_generate_sequences({'a': 1000, 'b': 2000, 'c': 1000}):
        pass
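Before replacing the recursive function, it helps to have an independent reference to compare its output against. The following is a minimal brute-force sketch, assuming inputs small enough for factorial running time; `brute_force_sequences` is a hypothetical helper, not part of the original code. It enumerates every distinct arrangement of the multiset and keeps those with no symbol repeated back to back.

```python
from itertools import permutations
from typing import Dict, List, Tuple


def brute_force_sequences(symbols_dict: Dict[str, int]) -> List[Tuple[str, ...]]:
    """Hypothetical reference for SMALL inputs only (factorial time):
    every distinct arrangement with no symbol repeated back to back."""
    # Expand {'a': 1, 'b': 2} into ['a', 'b', 'b'].
    pool = [symbol for symbol, count in symbols_dict.items() for _ in range(count)]
    # A set removes duplicates caused by swapping identical symbols.
    valid = {
        perm
        for perm in permutations(pool)
        if all(x != y for x, y in zip(perm, perm[1:]))
    }
    return sorted(valid)


print(len(brute_force_sequences({'a': 1, 'b': 2, 'c': 1})))  # 6
```

Any rewrite can then be checked on small inputs with something like `sorted(tuple(s) for s in recursive_generate_sequences(d)) == brute_force_sequences(d)`.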

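But, as you can probably tell, I hit a RecursionError once the total number of symbols (the sum of the symbols_dict values) exceeds Python's recursion limit (1000 by default), because every symbol placed adds another nested call.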
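A minimal model makes the cause visible. The hypothetical `chain` generator below mirrors only the structure of the question's function (one `yield from` per symbol placed) and nothing else; it can be consumed for a shallow chain but raises RecursionError once the depth passes `sys.getrecursionlimit()`.

```python
import sys
from typing import Iterator


def chain(depth: int) -> Iterator[str]:
    """Hypothetical stand-in for the recursive generator: each level delegates
    to the next with `yield from`, so about `depth` generator frames are active at once."""
    if depth == 0:
        yield "done"
        return
    yield from chain(depth - 1)


def fits_in_stack(depth: int) -> bool:
    """True if a delegation chain of this depth can be fully consumed."""
    try:
        list(chain(depth))
        return True
    except RecursionError:
        return False


print(sys.getrecursionlimit())  # 1000 unless something has changed it
print(fits_in_stack(100))       # True
print(fits_in_stack(4000))      # False under the default limit (1000 + 2000 + 1000 symbols)
```

Raising the limit with `sys.setrecursionlimit` only moves the ceiling rather than removing it, which is why an explicit stack is the usual fix.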

I've seen a Stack Overflow answer and an mCoding video suggesting an iterative approach, and another Stack Overflow answer suggesting wrapping the function in another function.

I can't really wrap my head around the iterative approach, and I couldn't get the second answer's suggestion to work outside of the simple code it provided. Ideally, the solution would be a generator that yields each sequence as it is requested, rather than computing everything ahead of time.
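The iterative approach replaces the call stack with a list you manage yourself: each entry records the arguments a recursive call would have received, and a `while` loop pops an entry, yields it if it is complete, or pushes its children otherwise. Because the loop lives inside a generator, nothing is computed ahead of time. Here is a sketch on a deliberately simpler toy problem (every binary string of length n, without the question's adjacency rule); `recursive_bits` and `iterative_bits` are hypothetical names.

```python
from typing import Iterator, List


def recursive_bits(n: int, prefix: str = "") -> Iterator[str]:
    """Recursive version: one nested call (and stack frame) per character placed."""
    if len(prefix) == n:
        yield prefix
        return
    for bit in "01":
        yield from recursive_bits(n, prefix + bit)


def iterative_bits(n: int) -> Iterator[str]:
    """Same search with an explicit stack: each entry is the `prefix`
    argument the recursive call would have received."""
    stack: List[str] = [""]
    while stack:
        prefix = stack.pop()
        if len(prefix) == n:
            yield prefix  # hand back one result lazily, then resume here
            continue
        # Push children in reverse so "0" is popped first, matching
        # the recursive version's output order.
        for bit in reversed("01"):
            stack.append(prefix + bit)


print(list(iterative_bits(2)))          # ['00', '01', '10', '11']
print(next(iterative_bits(5000))[:10])  # 0000000000
```

Pushing children in reverse keeps the output order identical to the recursive version, and depth is now limited by memory rather than by the recursion limit.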

Can anyone help me convert this to an iterative approach, or explain the second Stack Overflow answer a bit better?

Answer by Jacob Dallas

Okay, I got something to work! I finally understood what mCoding was getting at, so I made this class to handle the permutation generation:

from typing import Dict, List, Optional


class IterativeGenerateSequences:
    def __init__(self, symbols_dict: Dict, start_symbols: Optional[List] = None):
        self.symbols_dict = symbols_dict
        self.start_symbols = start_symbols
        self.counts = {symbol: count for symbol, count in symbols_dict.items()}
        self.stack = []
        self._initialize_stack()

    def _initialize_stack(self):
        # initialize starting state based on start_symbols if provided
        initial_sequence = []
        initial_counts = self.counts.copy()
        remaining = sum(initial_counts.values())
        last_symbol = None

        if self.start_symbols:
            for symbol in self.start_symbols:
                if symbol in initial_counts and initial_counts[symbol] > 0:
                    initial_sequence.append(symbol)
                    initial_counts[symbol] -= 1
                    remaining -= 1
                    last_symbol = symbol  # the first generated symbol must differ from this
                else:
                    raise ValueError(f"Start symbol '{symbol}' is not valid or has insufficient count.")

        # initial state: the sequence so far, the counts of remaining symbols, and the last symbol added
        self.stack.append((initial_sequence, initial_counts, last_symbol, remaining))

    def generate(self):
        while self.stack:
            sequence, counts, last_symbol, remaining = self.stack.pop()

            # sequence is complete, yield it
            if remaining == 0:
                yield sequence
                continue

            for symbol, count in counts.items():
                if count > 0 and symbol != last_symbol:
                    # prep the next state to explore
                    next_counts = counts.copy()
                    next_counts[symbol] -= 1
                    next_sequence = sequence + [symbol]
                    next_remaining = remaining - 1

                    # push the new state onto the stack
                    self.stack.append((next_sequence, next_counts, symbol, next_remaining))

The key change is replacing recursion with an explicit stack of states, so the sequence length is no longer bounded by Python's recursion limit. Definitely feel free to criticise or optimise this, since I'm new to iteration versus recursion.
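One possible optimisation, offered as a sketch rather than a drop-in replacement: the class copies `counts` and `sequence` for every state it pushes, and because the stack is LIFO it explores candidates in reverse dictionary order compared to the recursive function. The hypothetical `iter_sequences` below instead keeps one shared `sequence` and `counts`, stores a per-depth iterator of untried candidate symbols on the stack, and undoes a choice when backtracking. It assumes yielding tuples (copies of the shared list) is acceptable, and it treats the last start symbol as the previous symbol.

```python
from typing import Dict, Hashable, Iterator, List, Optional, Sequence, Tuple


def iter_sequences(
    symbols_dict: Dict[Hashable, int],
    start_symbols: Optional[Sequence[Hashable]] = None,
) -> Iterator[Tuple[Hashable, ...]]:
    """Hypothetical sketch: lazily yield each arrangement of the counted symbols
    with no symbol twice in a row, depth-first and without recursion."""
    counts = dict(symbols_dict)
    sequence: List[Hashable] = []
    for symbol in start_symbols or ():
        if counts.get(symbol, 0) <= 0:
            raise ValueError(f"Start symbol '{symbol}' is not valid or has insufficient count.")
        counts[symbol] -= 1
        sequence.append(symbol)
    remaining = sum(counts.values())
    if remaining == 0:
        yield tuple(sequence)
        return

    symbols = list(counts)
    # stack[i] iterates over the candidates not yet tried for the i-th free position.
    stack: List[Iterator[Hashable]] = [iter(symbols)]
    while stack:
        for symbol in stack[-1]:
            if counts[symbol] > 0 and (not sequence or symbol != sequence[-1]):
                counts[symbol] -= 1  # place the symbol in the shared state
                sequence.append(symbol)
                remaining -= 1
                break
        else:
            # Every candidate at this depth is used up: drop the level and undo
            # the symbol placed by the level below (start symbols are never undone).
            stack.pop()
            if stack:
                undone = sequence.pop()
                counts[undone] += 1
                remaining += 1
            continue

        if remaining == 0:
            yield tuple(sequence)  # copy, because `sequence` keeps changing
            undone = sequence.pop()  # backtrack so this level tries its next candidate
            counts[undone] += 1
            remaining += 1
        else:
            stack.append(iter(symbols))  # descend: fresh candidates for the next position


print(sum(1 for _ in iter_sequences({'a': 1, 'b': 2, 'c': 1})))  # 6
print(len(next(iter_sequences({'a': 1000, 'b': 2000, 'c': 1000}))))  # 4000
```

Each step costs time proportional to the number of distinct symbols instead of the sequence length, and the candidates are tried in the same order as the original recursive function.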