Google Groups no longer supports new Usenet posts or subscriptions. Historical content remains viewable.
Dismiss

parallel merge-sort

5 views
Skip to first unread message

Bonita Montero

unread,
Nov 14, 2019, 12:24:45 AM11/14/19
to
There's stable_sort( execution::parallel_policy(), ... ) for a parallel
stable sort, which is usually implemented as a merge sort. I wrote my
own parallel merge-sort implementation to check whether I could get
better performance. On Windows with VC++ 2019, my merge sort takes
about 35% less time. I'd like to see results with libstdc++ as well
as with LLVM's libc++. You can choose between my merge sort and the
standard library's merge sort by defining BUILTIN_PAR (uncomment the
first line). Please define NDEBUG when compiling to strip the assertions.

//#define BUILTIN_PAR
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <algorithm>
#include <chrono>
#include <exception>
#include <functional>
#include <iostream>
#include <iterator>
#include <limits>
#include <mutex>
#include <random>
#include <thread>
#include <utility>
#include <vector>
#if defined(BUILTIN_PAR)
#include <execution>
#endif

// Scope guard: calls the referenced callable when the guard is destroyed,
// unless invoke_and_disable() has already run it successfully.
// T must be callable without arguments. The guard only stores a
// reference, so the callable must outlive the guard.
template<typename T>
struct invoke_on_destruct
{
private:
    T &m_t;
    bool m_enabled;

public:
    // explicit: a callable must never silently convert into a guard
    explicit invoke_on_destruct( T &t ) :
        m_t( t ), m_enabled( true )
    {
    }

    // A copy would invoke the callable a second time on destruction.
    invoke_on_destruct( invoke_on_destruct const & ) = delete;
    invoke_on_destruct &operator =( invoke_on_destruct const & ) = delete;

    ~invoke_on_destruct()
    {
        if( m_enabled )
            m_t();
    }

    // Invokes the callable now and disarms the guard. The flag is cleared
    // only after the call returns, so if the call throws, the destructor
    // will invoke the callable again.
    void invoke_and_disable()
    {
        m_t();
        m_enabled = false;
    }
};

// Parallel merge sort of the random-access range [start, end).
// The whole sort runs inside the constructor.
//   It  - random-access iterator of the range to sort
//   Cmp - comparison predicate on the iterator's value_type
template<typename It, typename Cmp = std::less<typename
std::iterator_traits<It>::value_type>>
class merge_sort
{
public:
    // Sorts [start, end) with cmp.
    //   nThreads        - number of threads to use; 0 means
    //                     std::thread::hardware_concurrency(). The value is
    //                     capped at hardware_concurrency() (at least 1).
    //   threadThreshold - ranges with fewer elements than this are sorted
    //                     on the current thread without further splitting.
    // Failures are reported either by rethrowing a single exception or by
    // throwing par_exception when several exceptions were collected.
    merge_sort( It start, It end, Cmp const &cmp = Cmp(), unsigned
    nThreads = 1, std::size_t threadThreshold = 1'000'000 );

    // Collects several exceptions raised while sorting in parallel.
    // Iterate over it to inspect the individual std::exception_ptr values.
    struct par_exception : public std::exception
    {
        using iterator = std::vector<std::exception_ptr>::iterator;
        // The implicit copy constructor copies the whole vector of
        // exception_ptrs, so catch par_exception by reference.
        ~par_exception();
        iterator begin();
        iterator end();
        // Generic description; the individual exceptions are reachable
        // through begin()/end().
        const char *what() const noexcept override
        {
            return "merge_sort: multiple exceptions during parallel sort";
        }
    private:
        friend class merge_sort;
        explicit par_exception( std::vector<std::exception_ptr> &&exceptions );
        std::vector<std::exception_ptr> m_exceptions;
    };

private:
    using exceptions_vector = std::vector<std::exception_ptr>;
    Cmp m_cmp;
    // ranges with at least this many elements are split across threads
    std::size_t m_threadThreshold;
    // protects m_exceptions while worker threads are running
    std::mutex m_excMtx;
    // exceptions recorded by worker threads
    exceptions_vector m_exceptions;
    // sequential sort of [start, end) using buf as scratch space
    template<typename UpIt, typename BufIt>
    void recursion( UpIt start, UpIt end, BufIt buf );
    // merges [leftBuf, rightBuf) and [rightBuf, bufEnd) into start
    template<typename UpIt, typename BufIt>
    void merge( UpIt start, BufIt leftBuf, BufIt rightBuf, BufIt bufEnd );
    // sorts [start, end) with up to nThreads threads
    template<typename UpIt>
    void parRecursion( UpIt start, UpIt end, unsigned nThreads );
};

// Runs the whole sort: clamps the thread count, sorts via parRecursion()
// and reports failures:
//   - no exception:        returns normally
//   - exactly one:         that exception is rethrown unchanged
//   - several exceptions:  a par_exception holding all of them is thrown
template<typename It, typename Cmp>
inline
merge_sort<It, Cmp>::merge_sort( It start, It end, Cmp const &cmp,
unsigned nThreads, std::size_t threadThreshold ) :
    m_cmp( cmp ),
    m_threadThreshold( threadThreshold )
{
    using namespace std;
    // nThreads == 0 -> use all hardware threads; never use more than the
    // hardware reports (hardware_concurrency() may return 0 = unknown)
    unsigned hwThreads = thread::hardware_concurrency();
    if( hwThreads == 0 )
        hwThreads = 1;
    if( nThreads == 0 || nThreads > hwThreads )
        nThreads = hwThreads;
    // Worker threads record exceptions with emplace_back() inside their
    // catch handlers, where a reallocation failure would escape the thread.
    // parRecursion() starts at most 2 * nThreads - 2 threads (one record
    // each) and the handler below adds one more, so 2 * nThreads slots
    // guarantee that emplace_back() never reallocates.
    m_exceptions.reserve( 2 * static_cast<size_t>( nThreads ) );
    try
    {
        parRecursion( start, end, nThreads );
    }
    catch( ... )
    {
        // All worker threads have been joined while unwinding, so
        // m_exceptions is no longer written concurrently.
        if( m_exceptions.empty() )
            // single exception: rethrow it
            throw;
        // additional exception caught: report all of them together
        m_exceptions.emplace_back( current_exception() );
        throw par_exception( move( m_exceptions ) );
    }
    // Dispatch outside the try block: inside it, the rethrow below would
    // be caught by the handler above and the single exception would be
    // wrapped (twice) into a par_exception.
    if( m_exceptions.size() > 1 )
        throw par_exception( move( m_exceptions ) );
    if( m_exceptions.size() == 1 )
        rethrow_exception( m_exceptions.front() );
}

// Sequential top-down merge sort of [start, end).
// The range is copied into buf, both halves are sorted inside buf (each
// reusing the space behind buf + count as its own scratch area) and the
// two sorted halves are merged back into [start, end).
// buf must therefore hold S(count) elements, where
// S(n) = n + S(ceil(n / 2)) and S(1) = 0.
template<typename It, typename Cmp>
template<typename UpIt, typename BufIt>
void merge_sort<It, Cmp>::recursion( UpIt start, UpIt end, BufIt buf )
{
    auto const count = end - start;
    if( count <= 1 )
        return;
    // std::copy returns buf + count: the scratch area for the halves
    BufIt const scratch = std::copy( start, end, buf );
    BufIt const middle = buf + count / 2;
    recursion( buf, middle, scratch );
    recursion( middle, scratch, scratch );
    merge( start, buf, middle, scratch );
}

// Merges the adjacent sorted runs [leftBuf, rightBuf) and
// [rightBuf, bufEnd) into the output range beginning at start.
// Stable: an element of the right run is written only if it compares
// strictly less than the current left element, so on ties the left
// (earlier) element comes first. Either run may be empty.
template<typename It, typename Cmp>
template<typename UpIt, typename BufIt>
void merge_sort<It, Cmp>::merge( UpIt start, BufIt leftBuf, BufIt
rightBuf, BufIt bufEnd )
{
    BufIt const leftBufEnd = rightBuf;
    UpIt wrtBack = start;
    // The loop below dereferences both runs before testing them, so empty
    // runs are handled up front.
    if( leftBuf == leftBufEnd )
    {
        std::copy( rightBuf, bufEnd, wrtBack );
        return;
    }
    if( rightBuf == bufEnd )
    {
        std::copy( leftBuf, leftBufEnd, wrtBack );
        return;
    }
    for( ; ; )
        if( m_cmp( *rightBuf, *leftBuf ) )
        {
            *wrtBack++ = *rightBuf;
            if( ++rightBuf == bufEnd )
            {
                // right run exhausted: append the rest of the left run
                // (the author found this faster than std::copy for few
                // elements)
                do
                    *wrtBack++ = *leftBuf;
                while( ++leftBuf != leftBufEnd );
                break;
            }
        }
        else
        {
            *wrtBack++ = *leftBuf;
            if( ++leftBuf == leftBufEnd )
            {
                // left run exhausted: append the rest of the right run
                do
                    *wrtBack++ = *rightBuf;
                while( ++rightBuf != bufEnd );
                break;
            }
        }
}

// Sorts [start, end) using up to nThreads threads.
// Sequential case (one thread, fewer than m_threadThreshold elements, or
// fewer than two elements): allocate the scratch buffer recursion() needs
// and sort on the current thread.
// Parallel case: copy the range into a split buffer, sort its left part
// on a new thread (with leftThreads threads) and its right part on another
// new thread, or on this thread if only one thread is left for it. Then
// merge both parts back into [start, end). Exceptions from the new threads
// are recorded in m_exceptions; if any exception has been recorded by the
// time both parts are done, the merge is skipped.
template<typename It, typename Cmp>
template<typename UpIt>
void merge_sort<It, Cmp>::parRecursion( UpIt start, UpIt end, unsigned
nThreads )
{
    using namespace std;
    using T = typename iterator_traits<It>::value_type;
    size_t n = end - start;
    // splitting fewer than two elements across threads gains nothing
    if( nThreads <= 1 || n < m_threadThreshold || n < 2 )
    {
        vector<T> buf;
        size_t bs = 0;
        // scratch size S(n) = n + S(ceil(n / 2)), S(1) = 0
        for( size_t split = n; split > 1; bs += split, split -= split / 2 );
        buf.resize( bs );
        recursion( start, end, buf.begin() );
        return;
    }
    // split-buffer: both parts are sorted here and then merged back
    vector<T> buf( start, end );
    vector<thread> threads;
    // at most two threads are started here; reserve so that emplace_back()
    // never reallocates
    threads.reserve( 2 );
    // Join every started thread, also when an exception leaves this scope:
    // the threads keep accessing buf, and destroying a joinable std::thread
    // calls std::terminate().
    auto joinThreads = [&threads]()
    {
        for( thread &thr : threads )
            // retry until join() succeeds
            for( ; ; )
                try
                {
                    thr.join();
                    break;
                }
                catch( ... )
                {
                }
    };
    invoke_on_destruct<decltype(joinThreads)> iodJoin( joinThreads );
    // iterator-type for our split-buffer
    using BufIt = typename vector<T>::iterator;
    // Thread entry point: an exception escaping a thread function would call
    // std::terminate(), so it is recorded in m_exceptions instead.
    auto prProxy = [this]( BufIt first, BufIt last, unsigned threadCount )
    {
        try
        {
            parRecursion( first, last, threadCount );
        }
        catch( ... )
        {
            lock_guard<mutex> excLock( m_excMtx );
            m_exceptions.emplace_back( current_exception() );
        }
    };
    unsigned rightThreads = nThreads / 2,
             leftThreads = nThreads - rightThreads;
    // Split the input proportionally to the thread counts: with an odd
    // thread count the left part gets the extra thread and more input.
    size_t left = static_cast<size_t>( n * ( static_cast<double>( leftThreads ) / nThreads ) );
    BufIt leftBuf = buf.begin(),
          rightBuf = buf.begin() + left,
          bufEnd = buf.end();
    // start left thread
    threads.emplace_back( prProxy, leftBuf, rightBuf, leftThreads );
    if( rightThreads > 1 )
        // start right thread
        threads.emplace_back( prProxy, rightBuf, bufEnd, rightThreads );
    else
        // only one thread for the right part: sort it on this thread
        parRecursion( rightBuf, bufEnd, 1 );
    // join threads
    iodJoin.invoke_and_disable();
    {
        // if any thread failed, stop: the sorted parts are incomplete
        lock_guard<mutex> excLock( m_excMtx );
        if( !m_exceptions.empty() )
            return;
    }
    // merge split-buffer back into input-buffer
    merge( start, leftBuf, rightBuf, bufEnd );
}

// Nothing beyond member-wise destruction is needed: the only owned state
// is the vector of exception_ptrs.
template<typename It, typename Cmp>
inline
merge_sort<It, Cmp>::par_exception::~par_exception() = default;

template<typename It, typename Cmp>
inline
merge_sort<It, Cmp>::par_exception::par_exception(
std::vector<std::exception_ptr> &&exceptions ) :
m_exceptions( std::move( exceptions ) )
{
}

// Iterator to the first collected exception_ptr.
template<typename It, typename Cmp>
inline
auto merge_sort<It, Cmp>::par_exception::begin() -> iterator
{
    iterator const first = m_exceptions.begin();
    return first;
}

// Iterator one past the last collected exception_ptr.
// (Previously returned begin(), which made every range over a
// par_exception empty.)
template<typename It, typename Cmp>
inline
typename merge_sort<It, Cmp>::par_exception::iterator merge_sort<It,
Cmp>::par_exception::end()
{
    return m_exceptions.end();
}

#if defined(_MSC_VER)
#pragma warning(disable: 26444)
#endif

// Benchmark driver. argv[1] = number of elements in millions.
// Fills a vector with random ints and sorts it with either the standard
// library's parallel stable_sort (BUILTIN_PAR defined) or merge_sort on all
// hardware threads. Prints the elapsed time. The order check uses assert(),
// so NDEBUG removes it.
int main( int argc, char **argv )
{
    using namespace std;
    using namespace chrono;
    if( argc < 2 )
        return EXIT_FAILURE;
    // Reject non-numeric or negative counts instead of wrapping a negative
    // value around to a huge unsigned size.
    char *parseEnd = nullptr;
    long millions = strtol( argv[1], &parseEnd, 10 );
    if( parseEnd == argv[1] || millions < 0 )
        return EXIT_FAILURE;
    size_t n = static_cast<size_t>( millions ) * 1'000'000;
    // random_device only seeds the engine: drawing every value from it may
    // be much slower than a pseudo-random engine (implementation-defined)
    random_device rd;
    mt19937_64 gen( rd() );
    uniform_int_distribution<int> uid( numeric_limits<int>::min(),
        numeric_limits<int>::max() );
    vector<int> rv( n );
    cout << "randomizing ..." << endl;
    for( int &r : rv )
        r = uid( gen );
    cout << "sorting ..." << endl;
    // steady_clock: monotonic, suitable for measuring intervals
    steady_clock::time_point start = steady_clock::now();
#if defined(BUILTIN_PAR)
    stable_sort( execution::parallel_policy(), rv.begin(), rv.end(),
        less<int>() );
#else
    using msInt = merge_sort<vector<int>::iterator>;
    msInt( rv.begin(), rv.end(), less<int>(), 0 );
#endif
    double seconds = duration<double>( steady_clock::now() - start ).count();
    cout << "... sorted!" << endl;
    cout << "sorting-time: " << seconds << " seconds" << endl;
    // is_sorted also handles an empty vector (rv.end() - 1 would be UB)
    assert( is_sorted( rv.begin(), rv.end() ) );
    return EXIT_SUCCESS;
}

Bonita Montero

unread,
Nov 14, 2019, 12:27:20 AM11/14/19
to

> template<typename It, typename Cmp>
> inline
> typename merge_sort<It, Cmp>::par_exception::iterator merge_sort<It,
> Cmp>::par_exception::end()
> {
>     return m_exceptions.begin();

Copy-and-paste mistake, but the function isn't used.
It should be:
return m_exceptions.end();
> }

0 new messages