Is it possible to define the length of the type to sort in STXXL at run time?


I have an application that requires a built-in sort and I'm hoping to replace the existing sort mechanism with the sort provided by STXXL. I have successfully tested it using STXXL, but my problem is that, although a specific run of the sort needs to operate on fixed length strings, the length is determined at run-time and can be anywhere between 10 bytes and 4000 bytes. Always allowing for 4000 bytes will obviously be grossly inefficient if the actual length is small.
For those not familiar with STXXL, I believe the problem roughly equates to defining a std::vector without knowing the size of the objects at compilation time. However, I'm not a C++ expert - the application is written in C.
In my test this is the type that I am sorting:

struct string80
{
    char x[80];
};

and this is the type definition for the STXXL sorter:

typedef stxxl::sorter<string80, sort_comparator80> stxxl_sorter80;  

The problem is that I don't want to hard-code the array size to '80'.
The only solution I can come up with is to define a number of structures of varying lengths and pick the closest one at run time. Am I missing a trick? Am I thinking in C rather than C++?
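Even with the closest-size approach, the separate structs and comparators do not have to be written by hand. A class template can turn the length into a parameter, although it is still a compile-time constant. The sketch below is hypothetical: FixedString and FixedStringLess are made-up names, and the min_value()/max_value() members are included on the assumption that stxxl::sorter, like stxxl::sort, asks its comparator for sentinel values (check the STXXL documentation for the exact requirements).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical generalization of string80: the length N is a template
// parameter, so it is still fixed at compile time
template <std::size_t N>
struct FixedString
{
    char x[N];
};

// Hypothetical generalization of sort_comparator80. min_value()/max_value()
// assume stxxl::sorter wants sentinel values; std::sort never calls them
template <std::size_t N>
struct FixedStringLess
{
    // Byte-wise order, the same order memcmp gives on the fixed-length field
    bool operator()(const FixedString<N>& a, const FixedString<N>& b) const
    {
        return std::memcmp(a.x, b.x, N) < 0;
    }

    // Assumes no real record consists only of 0x00 bytes
    FixedString<N> min_value() const
    {
        FixedString<N> s;
        std::memset(s.x, 0x00, N);
        return s;
    }

    // Assumes no real record consists only of 0xFF bytes
    FixedString<N> max_value() const
    {
        FixedString<N> s;
        std::memset(s.x, 0xFF, N);
        return s;
    }
};

// One typedef per supported length, e.g. (assuming STXXL is available):
// typedef stxxl::sorter<FixedString<80>, FixedStringLess<80> > stxxl_sorter80;
```

This only removes the duplicated source code; choosing among the instantiated lengths at run time is still necessary.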


There are 2 answers

Answer by Timo Bingmann (accepted answer)

There is no good solution here, at least not with STXXL.

The STXXL sorter is highly optimized, and its code requires the data type's size to be known at compile time through template parameters. I don't expect this to change, nor do I think it should.

Instantiating classes for many different parameters is not elegant, but it is common practice. Just think of all the different std::vector instantiations in a simple C++ program, all of which C would handle with a single set of void* functions.
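For contrast, this is roughly what the "void* functions in C" route looks like: std::qsort receives the element size as a runtime argument, so a single compiled function handles every record length. This is only an in-memory sort, not a replacement for STXXL's external sorter. Because qsort passes no context to its comparator, the length has to live in a global variable, which makes the function non-reentrant. g_recordLength, compareRecords and sortRecords are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>

// qsort passes no context to its comparator, so the runtime record length is
// kept in a global variable; this makes sortRecords non-reentrant
static std::size_t g_recordLength = 0;

// Byte-wise comparison of two records of g_recordLength bytes
extern "C" int compareRecords(const void* a, const void* b)
{
    return std::memcmp(a, b, g_recordLength);
}

// Sorts count records of recordLength bytes stored back to back in data;
// the element size is an ordinary runtime argument
void sortRecords(char* data, std::size_t count, std::size_t recordLength)
{
    g_recordLength = recordLength;
    std::qsort(data, count, recordLength, compareRecords);
}
```

The price of this flexibility is that every element access goes through a runtime size and an indirect comparator call, which is exactly the overhead the templated STXXL sorter avoids.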

Depending on how much code you are willing to generate, try instantiating the sorter for powers of two, and then at a finer granularity for the record lengths you use most often.
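A minimal sketch of the power-of-two scheme, assuming each record is padded with zero bytes up to its bucket size. Record, RecordLess, sortInBucket and sortRecords are hypothetical names, and std::sort on a std::vector stands in for one stxxl::sorter instantiation per bucket so that the sketch runs without STXXL. Because the padding bytes are identical for every record, comparing whole padded records gives the same order as comparing only the real bytes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstring>
#include <stdexcept>
#include <vector>

// Hypothetical fixed-size record type; N is the bucket size in bytes
template <std::size_t N>
struct Record
{
    char bytes[N];
};

template <std::size_t N>
struct RecordLess
{
    bool operator()(const Record<N>& a, const Record<N>& b) const
    {
        return std::memcmp(a.bytes, b.bytes, N) < 0;
    }
};

// Sorts the records of recordLength <= N bytes stored back to back in data.
// Each record is zero-padded to N bytes; std::sort stands in for an
// stxxl::sorter<Record<N>, RecordLess<N> > instantiation
template <std::size_t N>
void sortInBucket(std::vector<char>& data, std::size_t recordLength)
{
    const std::size_t count = data.size() / recordLength;
    // Value-initialization zero-fills every Record, including the padding
    std::vector<Record<N> > records(count);
    for (std::size_t i = 0; i < count; ++i)
        std::memcpy(records[i].bytes, &data[i * recordLength], recordLength);

    std::sort(records.begin(), records.end(), RecordLess<N>());

    // Copy back only the real bytes of each record
    for (std::size_t i = 0; i < count; ++i)
        std::memcpy(&data[i * recordLength], records[i].bytes, recordLength);
}

// Dispatches a runtime record length to the smallest power-of-two bucket that
// fits; each branch instantiates the templates for one bucket size
void sortRecords(std::vector<char>& data, std::size_t recordLength)
{
    if (recordLength == 0 || data.size() % recordLength != 0)
        throw std::invalid_argument("data is not a whole number of records");

    if (recordLength <= 16)        sortInBucket<16>(data, recordLength);
    else if (recordLength <= 32)   sortInBucket<32>(data, recordLength);
    else if (recordLength <= 64)   sortInBucket<64>(data, recordLength);
    else if (recordLength <= 128)  sortInBucket<128>(data, recordLength);
    else if (recordLength <= 256)  sortInBucket<256>(data, recordLength);
    else if (recordLength <= 512)  sortInBucket<512>(data, recordLength);
    else if (recordLength <= 1024) sortInBucket<1024>(data, recordLength);
    else if (recordLength <= 2048) sortInBucket<2048>(data, recordLength);
    else if (recordLength <= 4096) sortInBucket<4096>(data, recordLength);
    else throw std::invalid_argument("record length above 4096 bytes");
}
```

For the 10 to 4000 byte range in the question this yields nine instantiations. A record is padded to just under twice its length in the worst case, which is where extra buckets for the most common lengths help.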

Answer by Ilia Minkin

What if we store objects (records) of size n in a flat stxxl::vector of chars, and then define a custom iterator based on stxxl::vector::iterator that simply skips n bytes on each increment? With a std::vector in place of STXXL's vector, this works with std::sort and even with tbb::sort. However, I see that STXXL's ExtIterator has a lot of additional traits. Is it possible to define them correctly for such an iterator?

#include <vector>
#include <iterator>
#include <cassert>
#include <cstdlib>
#include <stxxl.h>
#include <iostream>
#include <algorithm>

typedef std::vector<char>::iterator It;

class ObjectValue;

//This class defines a reference object that handles assignment operations
//during a sorting
class ObjectReference
{
public:
    ObjectReference() : recordSize_(0) {}
    ObjectReference(It ptr, size_t recordSize) : ptr_(ptr), recordSize_(recordSize) {}

    void operator = (ObjectReference source) const
    {
        std::copy(source.ptr_, source.ptr_ + recordSize_, ptr_);
    }

    void operator = (const ObjectValue & source) const;

    It GetIterator() const
    {
        return ptr_;
    }

    size_t GetRecordSize() const
    {
        return recordSize_;
    }

private:
    It ptr_;
    size_t recordSize_;
};

//This class defines a value object that is used when a temporary value of a
//record is required somewhere
class ObjectValue
{
public:
    ObjectValue() {}
    ObjectValue(ObjectReference prx) : object_(prx.GetIterator(), prx.GetIterator() + prx.GetRecordSize()) {}
    ObjectValue(It ptr, size_t recordSize) : object_(ptr, ptr + recordSize) {}
    std::vector<char>::const_iterator GetIterator() const
    {
        return object_.begin();
    }

private:
    std::vector<char> object_;
};

//We also need to support assigning a temporary value back to the record that a
//reference points to
void ObjectReference::operator = (const ObjectValue & source) const
{
    std::copy(source.GetIterator(), source.GetIterator() + recordSize_, ptr_);
}

//The comparator passed to the sorting algorithm. It receives reference or value
//proxies, extracts char pointers from them, and passes those to the actual
//comparator that handles object comparison
template<class Cmp>
class Comparator
{
public:
    Comparator() {}
    Comparator(Cmp cmp) : cmp_(cmp) {} 

    bool operator () (const ObjectReference & a, const ObjectReference & b) const
    {
        return cmp_(&*a.GetIterator(), &*b.GetIterator());
    }

    bool operator () (const ObjectValue & a, const ObjectReference & b) const
    {
        return cmp_(&*a.GetIterator(), &*b.GetIterator());
    }

    bool operator () (const ObjectReference & a, const ObjectValue & b) const
    {
        return cmp_(&*a.GetIterator(), &*b.GetIterator());
    }

    bool operator () (const ObjectValue & a, const ObjectValue & b) const
    {
        return cmp_(&*a.GetIterator(), &*b.GetIterator());
    }

private:
    Cmp cmp_;
};

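//The iterator that operates on a flat byte area. If the record size is n, it
//just skips n bytes on each increment operation to jump to the next record.
//The iterator traits are declared as member typedefs because the std::iterator
//base class is deprecated since C++17; difference_type must be a signed type
class RecordIterator
{
public:
    typedef std::random_access_iterator_tag iterator_category;
    typedef ObjectValue value_type;
    typedef std::vector<char>::difference_type difference_type;
    typedef void pointer;
    typedef ObjectReference reference;

    RecordIterator() : recordSize_(0) {}
    RecordIterator(It ptr, size_t recordSize) : ptr_(ptr), recordSize_(recordSize) {}
    ObjectReference operator * () const
    {
        return ObjectReference(ptr_, recordSize_);
    }

    ObjectReference operator [] (difference_type diff) const
    {
        return *(*this + diff);
    }

    It GetIterator() const
    {
        return ptr_;
    }

    size_t GetRecordSize() const
    {
        return recordSize_;
    }

    RecordIterator& operator ++()
    {
        ptr_ += recordSize_;
        return *this;
    }

    RecordIterator& operator --()
    {
        ptr_ -= recordSize_;
        return *this;
    }

    RecordIterator operator ++(int)
    {
        RecordIterator ret = *this;
        ptr_ += recordSize_;
        return ret;
    }

    RecordIterator operator --(int)
    {
        RecordIterator ret = *this;
        ptr_ -= recordSize_;
        return ret;
    }

    //Move by a (possibly negative) number of records
    RecordIterator& operator += (difference_type shift)
    {
        ptr_ += shift * static_cast<difference_type>(recordSize_);
        return *this;
    }

    RecordIterator& operator -= (difference_type shift)
    {
        ptr_ -= shift * static_cast<difference_type>(recordSize_);
        return *this;
    }

    friend bool operator < (RecordIterator it1, RecordIterator it2);
    friend bool operator > (RecordIterator it1, RecordIterator it2);
    friend bool operator <= (RecordIterator it1, RecordIterator it2);
    friend bool operator >= (RecordIterator it1, RecordIterator it2);
    friend bool operator == (RecordIterator it1, RecordIterator it2);
    friend bool operator != (RecordIterator it1, RecordIterator it2);
    friend difference_type operator - (RecordIterator it1, RecordIterator it2);
    friend RecordIterator operator - (RecordIterator it1, difference_type shift);
    friend RecordIterator operator + (RecordIterator it1, difference_type shift);
    friend RecordIterator operator + (difference_type shift, RecordIterator it1);

private:
    It ptr_;
    size_t recordSize_;
};

bool operator < (RecordIterator it1, RecordIterator it2)
{
    return it1.ptr_ < it2.ptr_;
}

bool operator > (RecordIterator it1, RecordIterator it2)
{
    return it1.ptr_ > it2.ptr_;
}

bool operator <= (RecordIterator it1, RecordIterator it2)
{
    return it1.ptr_ <= it2.ptr_;
}

bool operator >= (RecordIterator it1, RecordIterator it2)
{
    return it1.ptr_ >= it2.ptr_;
}

bool operator == (RecordIterator it1, RecordIterator it2)
{
    return it1.ptr_ == it2.ptr_;
}

bool operator != (RecordIterator it1, RecordIterator it2)
{
    return !(it1 == it2);
}

RecordIterator operator - (RecordIterator it1, RecordIterator::difference_type shift)
{
    return it1 -= shift;
}

RecordIterator operator + (RecordIterator it1, RecordIterator::difference_type shift)
{
    return it1 += shift;
}

RecordIterator operator + (RecordIterator::difference_type shift, RecordIterator it1)
{
    return it1 += shift;
}

//Signed distance between two iterators, measured in records
RecordIterator::difference_type operator - (RecordIterator it1, RecordIterator it2)
{
    return (it1.ptr_ - it2.ptr_) / static_cast<RecordIterator::difference_type>(it1.recordSize_);
}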

//std::sort exchanges elements through std::iter_swap, which is specified to
//call an unqualified swap(*a, *b). Dereferencing a RecordIterator yields a
//temporary ObjectReference that std::swap(T&, T&) cannot bind to, so instead of
//specializing std::swap we provide an overload that takes the proxies by value;
//it is found by argument-dependent lookup
void swap(ObjectReference a, ObjectReference b)
{
    ObjectValue buf(a);
    a = b;
    b = buf;
}

//Finally, here is the "user"-defined code. In the example, "records" are
//4-byte integers, although actual size of a record can be changed at runtime
class RecordComparer
{
public:
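    bool operator ()(const char * aRawPtr, const char * bRawPtr) const
    {
        //Copy the bytes into real int objects instead of reading the char
        //buffer through a reinterpret_cast int pointer (strict aliasing)
        int a, b;
        std::copy(aRawPtr, aRawPtr + sizeof(int), reinterpret_cast<char*>(&a));
        std::copy(bRawPtr, bRawPtr + sizeof(int), reinterpret_cast<char*>(&b));
        return a < b;
    }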
};

int main(int, char*[])
{
    size_t size = 100500;
    //Although it is a constant here, the record size could just as well be
    //determined at runtime
    size_t recordSize = sizeof(int);

    std::vector<int> intVector(size);
    std::generate(intVector.begin(), intVector.end(), rand);
    //Copy the raw bytes of all integers into a flat byte vector of records
    const char * source = reinterpret_cast<const char*>(&intVector[0]);
    std::vector<char> recordVector(source, source + size * recordSize);
    RecordIterator begin(recordVector.begin(), recordSize);
    RecordIterator end(recordVector.end(), recordSize);

    //Sort "records" as blocks of bytes
    std::sort(begin, end, Comparator<RecordComparer>());
    //Sort "records" as usual
    std::sort(intVector.begin(), intVector.end());
    //Check that both arrays are now the same
    for (RecordIterator it = begin; it != end; ++it)
    {
        size_t i = static_cast<size_t>(it - begin);
        //Copy the record's bytes into an int rather than reading through a
        //reinterpret_cast pointer
        int value;
        std::copy(it.GetIterator(), it.GetIterator() + sizeof(int), reinterpret_cast<char*>(&value));
        assert(value == intVector[i]);
    }

    return 0;
}
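To apply this approach to the question's fixed-length strings, the user-defined comparer would have to carry the runtime length instead of assuming 4-byte ints. A minimal sketch, with StringRecordComparer as a hypothetical name; it orders records byte-wise with memcmp over the first length bytes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical user comparator for fixed-length strings whose length is only
// known at runtime. Unlike RecordComparer it carries state (the length)
class StringRecordComparer
{
public:
    explicit StringRecordComparer(std::size_t length = 0) : length_(length) {}

    // Compares exactly length_ bytes, so no terminating '\0' is needed
    bool operator ()(const char * a, const char * b) const
    {
        return std::memcmp(a, b, length_) < 0;
    }

private:
    std::size_t length_;
};
```

It would then be handed to the wrapper through its Comparator(Cmp cmp) constructor, for example `std::sort(begin, end, Comparator<StringRecordComparer>(StringRecordComparer(recordSize)));`. The default argument keeps the class default-constructible, like RecordComparer.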