Generalizing std::partition to multi_partition


Just as std::partition partitions a container according to a unary predicate, multi_partition should partition a container according to several predicates, UnaryPredicates... preds, in the order in which they are listed: first the elements satisfying the first predicate, then those satisfying the second, and so on, with the elements for which every predicate is false at the end of the container. It should also return a list of all the partition points. But I'm not getting the correct results with this helper function:

template <typename ForwardIterator, typename UnaryPredicate, typename... UnaryPredicates>
std::list<ForwardIterator> multi_partition_helper (std::list<ForwardIterator>& partition_points, 
    ForwardIterator first, ForwardIterator last, UnaryPredicate pred, UnaryPredicates... rest) {
    while (true) {
        while ((first != last) && pred(*first))
            ++first;
        if (first == last--) break;
        while ((first != last) && !pred(*last))
            --last;
        if (first == last) break;
        std::iter_swap (first++, last);
    }
    partition_points.push_back (first);
    multi_partition_helper (partition_points, first, last, rest...);
}

template <typename ForwardIterator, typename UnaryPredicate, typename... UnaryPredicates>
std::list<ForwardIterator> multi_partition_helper (std::list<ForwardIterator>&, ForwardIterator, ForwardIterator) {
// End of recursion.
}

Am I going about it the wrong way?


Accepted answer, by Columbo:

A trivial implementation sorts the range by the tuple of predicate results:

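#include <algorithm>
#include <iterator>
#include <tuple>

// std::sort requires random-access iterators (arrays, std::vector, std::deque).
template <typename RandomIt, typename... Predicates>
void trivial_mul_part( RandomIt first, RandomIt last, Predicates... preds )
{
    using T = typename std::iterator_traits<RandomIt>::value_type;
    // Lexicographic tuple comparison: true sorts before false, first predicate first.
    std::sort( first, last,
      [=] (T const& lhs, T const& rhs)
      {
          return std::make_tuple(preds(lhs)...) > std::make_tuple(preds(rhs)...);
      } );
}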

It is less efficient than partitioning, but it can serve as a reference algorithm for testing.
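The sort-based version does not report partition points. A minimal sketch, assuming C++14 (generic lambdas) and random-access iterators, of recovering them afterwards with std::partition_point: after sorting, the k-th partition point is the first element that satisfies none of the first k predicates. The helper name reference_points is hypothetical, and trivial_mul_part is repeated so the block compiles on its own.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
#include <tuple>
#include <vector>

// Sort-based reference, repeated here (random-access version) so the block is self-contained.
template <typename RandomIt, typename... Predicates>
void trivial_mul_part(RandomIt first, RandomIt last, Predicates... preds)
{
    std::sort(first, last, [=](const auto& lhs, const auto& rhs) {
        return std::make_tuple(preds(lhs)...) > std::make_tuple(preds(rhs)...);
    });
}

// Hypothetical helper: sorts with trivial_mul_part, then recovers the partition
// points. The sorted range is partitioned with respect to "satisfies at least one
// of the first k predicates", so std::partition_point finds the k-th boundary.
template <typename RandomIt, typename... Predicates>
std::vector<RandomIt> reference_points(RandomIt first, RandomIt last, Predicates... preds)
{
    using T = typename std::iterator_traits<RandomIt>::value_type;
    trivial_mul_part(first, last, preds...);

    const std::vector<std::function<bool(const T&)>> ps{preds...};
    std::vector<RandomIt> points;
    for (std::size_t k = 1; k <= ps.size(); ++k) {
        // True for elements that satisfy at least one of the first k predicates.
        auto in_first_k = [&ps, k](const T& x) {
            return std::any_of(ps.begin(), ps.begin() + static_cast<std::ptrdiff_t>(k),
                               [&x](const auto& p) { return p(x); });
        };
        points.push_back(std::partition_point(first, last, in_first_k));
    }
    return points;
}
```

Since neither std::partition nor std::sort preserves the order inside a group, a test should compare the group boundaries (for example, the offsets of the returned iterators from the start of the range) rather than the exact element order.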
The real algorithm can be implemented recursively in terms of std::partition itself: call std::partition with the n-th predicate, which returns an iterator to the beginning of the (n+1)-th range, then make the next call with that iterator as the first iterator and the (n+1)-th predicate as the predicate.

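#include <algorithm>

// Base case: no predicates left, so there is nothing more to partition.
// It must be declared before the recursive overload so that overload can call it.
template <typename BidirIt, typename OutputIterator>
void multi_partition( BidirIt, BidirIt, OutputIterator ) {}

template <typename BidirIt, typename OutputIterator,
          typename Pred, typename... Predicates>
void multi_partition( BidirIt first, BidirIt last,  OutputIterator out,
                      Pred pred, Predicates... preds )
{
    // Elements satisfying pred end up in [first, iter), the rest in [iter, last).
    auto iter = std::partition(first, last, pred);
    *out++ = iter;  // record the partition point
    // Partition the remaining elements with the next predicate.
    multi_partition(iter, last, out, preds...);
}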

This is the actual algorithm; it can be used as follows:

#include <iostream>
#include <iterator>
#include <vector>

int main()
{
    int arr[] {0, 1, 0, 1, 0, 2, 1, 2, 2};
    std::vector<int*> iters;

    multi_partition(std::begin(arr), std::end(arr), std::back_inserter(iters),
                    [] (int i) {return i == 2;},
                    [] (int i) {return i == 1;});

    for (auto i : arr)
        std::cout << i << ", ";   // prints 2, 2, 2, 1, 1, 1, 0, 0, 0,
    std::cout << '\n';
    for (auto it : iters)
        std::cout << "Split at " << it - arr << '\n';  // splits at 3 and 6
}
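With C++17 fold expressions, the same idea can be written without recursion or an empty base-case overload. A minimal sketch, assuming C++17; the name multi_partition_fold is hypothetical. A lambda performs one std::partition step and records the partition point, and a fold over the built-in comma operator applies it to each predicate in the listed order.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>

// Hypothetical C++17 variant of multi_partition without recursion.
// Returns the output iterator one past the last recorded partition point.
template <typename ForwardIt, typename OutputIt, typename... Predicates>
OutputIt multi_partition_fold(ForwardIt first, ForwardIt last, OutputIt out,
                              Predicates... preds)
{
    // One step: partition the remaining range [first, last) with pred,
    // record where its false elements start, and continue from there.
    auto step = [&](auto pred) {
        first = std::partition(first, last, pred);
        *out++ = first;
    };
    // Unary fold over the built-in comma operator: step runs once per
    // predicate, strictly left to right.
    (step(preds), ...);
    return out;
}
```

Returning the output iterator, as std::copy does, lets a caller write the points into a plain array and see how many were written; the recursive version above returns void instead.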

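The hand-rolled two-pointer loop from the question can also be made to work. The following is a minimal corrected sketch that keeps the question's structure, assuming bidirectional iterators (the loop steps backwards with --). It changes three things. First, the scan uses copies lo and hi, so the recursive call still receives the original last; in the question, last has already been decremented when the loop breaks, so the recursion runs on the wrong range. Second, the base case is a separate template whose parameters can all be deduced, declared first; the question's base case has an undeducible UnaryPredicate parameter, so that call can never select it. Third, the helper returns void, because the question's version is declared to return std::list but has no return statement.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <list>

// Base case: no predicates left, nothing to do. Declared first so the
// recursive overload below can find it.
template <typename BidirIt>
void multi_partition_helper(std::list<BidirIt>&, BidirIt, BidirIt) {}

template <typename BidirIt, typename UnaryPredicate, typename... UnaryPredicates>
void multi_partition_helper(std::list<BidirIt>& partition_points,
                            BidirIt first, BidirIt last,
                            UnaryPredicate pred, UnaryPredicates... rest)
{
    // Invariant: [first, lo) satisfies pred and [hi, last) does not.
    BidirIt lo = first, hi = last;  // leave `last` untouched for the recursion
    while (true) {
        while (lo != hi && pred(*lo))   // skip elements already in place
            ++lo;
        if (lo == hi) break;
        --hi;                           // *hi is the last unexamined element
        while (lo != hi && !pred(*hi))
            --hi;
        if (lo == hi) break;
        std::iter_swap(lo++, hi);       // move a true element to the front
    }
    partition_points.push_back(lo);     // first element that fails pred
    multi_partition_helper(partition_points, lo, last, rest...);
}
```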