Here is my small thread-pool engine.
First, a small debug header:
// debug_exceptions.h
//
// try_debug / catch_debug(decl) expand to a real try/catch pair in debug builds
// (NDEBUG not defined). In release builds try_debug expands to nothing and
// catch_debug(decl) expands to `if( false )`: the guarded block runs unprotected,
// and the handler is still compiled but never executed.
// Either macro may be predefined (e.g. on the command line) to override these defaults.
#pragma once
#if defined(NDEBUG)
#	if !defined try_debug
#		define try_debug
#	endif
#	if !defined catch_debug
#		define catch_debug(stmt) if( false )
#	endif
#else
#	if !defined try_debug
#		define try_debug try
#	endif
#	if !defined catch_debug
#		define catch_debug(stmt) catch( stmt )
#	endif
#endif
// thread-pool-engine
#include <algorithm>
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>

#include "debug_exceptions.h"
#if defined(_MSC_VER)
#pragma warning(disable: 26117)
#endif
// Pool of worker threads that run queued fire-and-forget tasks.
// Workers are created on demand by enqueue_task(), up to a configurable maximum.
struct thread_pool
{
	// Creates an empty pool that may grow to maxThreads workers.
	thread_pool( unsigned maxThreads );
	// Asks every worker to stop and waits for the workers to finish.
	~thread_pool();

	// Asks every current worker to retire once the queue is empty; does not wait.
	void stop();
	// Changes the maximum number of workers.
	void resize( unsigned maxThreads );
	// Discards all tasks that no worker has picked up yet.
	void clear_queue();
	// Queues f(args...) to be run on a worker thread.
	template<typename F, typename ... Args>
	void enqueue_task( F &&f, Args &&...args );

private:
	// Tasks are type-erased into a void() packaged_task; the only consumer of its
	// future is the worker that runs it.
	using void_task = std::packaged_task<void()>;
	using task_ptr = std::shared_ptr<void_task>;

	// Upper limit used when deciding whether to spawn another worker.
	unsigned m_maxThreads;
	// Guards all mutable state of the pool.
	std::mutex m_mtx;
	// Number of workers that have been asked to retire but have not done so yet.
	unsigned m_stop = 0;
	// Live workers; a retiring worker detaches itself and erases its own entry.
	std::vector<std::thread> m_threads;
	// Signalled when work is queued and when workers are asked to retire.
	std::condition_variable m_queueCv;
	// Number of workers currently executing a task.
	unsigned m_nThreadsRunning = 0;
	// Tasks waiting for a worker.
	std::deque<task_ptr> m_taskQueue;

	// Body of every worker thread; threadIndex is its position in m_threads at spawn time.
	void theThread( std::size_t threadIndex );
};
thread_pool::thread_pool( unsigned maxThreads ) :
m_maxThreads( maxThreads ),
m_stop(),
m_nThreadsRunning( 0 )
{
m_threads.reserve( maxThreads );
}
// Asks every worker to retire and waits until all of them have left the pool.
// Workers keep taking tasks while the queue is non-empty, so queued work is drained first.
thread_pool::~thread_pool()
{
	std::unique_lock<std::mutex> lock( m_mtx );
	m_stop = static_cast<unsigned>( m_threads.size() );
	m_queueCv.notify_all();
	// A retiring worker detaches itself and erases its own std::thread from m_threads while
	// holding m_mtx. The vector therefore must not be iterated or joined without the lock,
	// and joining a thread that has detached itself would throw. Instead, wait under the lock
	// until every worker has removed itself. Retiring workers signal nobody, hence the timed wait.
	while( !m_threads.empty() )
		m_queueCv.wait_for( lock, std::chrono::milliseconds( 1 ) );
}
// Requests that every current worker retire. A worker only retires once the task queue is
// empty, so queued tasks are still executed. This call returns without waiting.
void thread_pool::stop()
{
	std::lock_guard<std::mutex> guard( m_mtx );
	m_stop = static_cast<unsigned>( m_threads.size() );
	m_queueCv.notify_all();
}
void thread_pool::resize( unsigned maxThreads )
{
using namespace std;
lock_guard<mutex> lock( m_mtx );
unsigned nCurThreads = (unsigned)m_threads.size();
if( maxThreads >= nCurThreads )
m_threads.reserve( maxThreads ),
m_maxThreads = maxThreads;
else
{
unsigned stop = nCurThreads - maxThreads;
m_stop += stop;
for( ; stop; m_queueCv.notify_one(), --stop );
}
}
// Discards every task that has not yet been picked up by a worker.
// Tasks already taken by a worker have left the queue and are unaffected.
void thread_pool::clear_queue()
{
	std::lock_guard<std::mutex> guard( m_mtx );
	m_taskQueue.clear();
}
template<typename F, typename ... Args>
void thread_pool::enqueue_task( F &&f, Args &&...args )
{
using namespace std;
auto threadProxy = []( thread_pool *tp, size_t threadIndex )
{
tp->theThread( threadIndex );
};
lock_guard<mutex> lock( m_mtx );
m_taskQueue.emplace_back( make_shared<void_task>( bind( forward<F>( f
), forward<Args>( args ) ... ) ) );
if( m_nThreadsRunning == m_threads.size() && m_nThreadsRunning <
m_maxThreads )
m_threads.emplace_back( threadProxy, this, m_threads.size() );
m_queueCv.notify_one();
}
// Worker loop: runs queued tasks until the queue is empty and a retirement has been
// requested (m_stop > 0). The worker then detaches itself, erases its own std::thread from
// m_threads and consumes one unit of m_stop.
// A task that throws trips assert(false) in debug builds. In release builds the exception
// escapes the worker, which calls std::terminate.
void thread_pool::theThread( std::size_t threadIndex )
{
	// threadIndex was this worker's position in m_threads when it was spawned. Retiring workers
	// erase themselves from the vector and shift later elements down, so the index goes stale.
	// The worker therefore locates its own entry by thread id.
	static_cast<void>( threadIndex );
	std::unique_lock<std::mutex> lock( m_mtx );
	for( ; ; )
	{
		// Queued work takes precedence over retiring, so the queue is drained first.
		while( m_taskQueue.empty() && !m_stop )
			m_queueCv.wait( lock );
		if( m_taskQueue.empty() )
		{
			std::thread::id const self = std::this_thread::get_id();
			auto const me = std::find_if( m_threads.begin(), m_threads.end(),
				[self]( std::thread const &t ) { return t.get_id() == self; } );
			// enqueue_task() inserts the std::thread before this worker can take the lock.
			assert( me != m_threads.end() );
			me->detach();
			m_threads.erase( me );
			--m_stop;
			return;
		}
		task_ptr task = std::move( m_taskQueue.front() );
		m_taskQueue.pop_front();
		++m_nThreadsRunning;
		lock.unlock();
		std::future<void> fut = task->get_future();
		// Run the task; packaged_task stores any exception it throws in fut.
		( *task )();
		// Destroy the task (and its bound arguments) before re-taking the lock.
		task.reset();
		try_debug
		{
			fut.get();
		}
		catch_debug( ... )
		{
			assert( false );
		}
		lock.lock();
		--m_nThreadsRunning;
	}
}