Ich habe das noch weiter optimiert und speichere nur noch die Bits für
die ungeraden Zahlen.
Grundsätzlich basiert das Ganze auf der Idee des Segmented Sieve, d. h.:
Wenn ich alle Primzahlen bis zur Wurzel des Maximums habe, kann ich mit
wenig Aufwand jeden Bereich bis zum Quadrat dieser Wurzel sieben. Ich
berechne für jeden Thread eine Partition, und innerhalb eines Threads wird
das Ganze noch auf Blöcke aufgeteilt, die in den L2-Cache passen.
Letztlich hängt die Performance fast 1:1 vom Durchsatz des L2-Caches
ab, was sich u. a. daran zeigt, dass die Performance auf SMT-Systemen
mit der halben und der vollen Zahl logischer Kerne nahezu gleich ist.
So, daran dürfte echt nichts mehr zu verbessern sein: Alle Primzahlen,
die in den 32-Bit-Wertebereich passen, berechnet das Ganze in 0,15
Sekunden auf meinem 16-Kern-Zen4-PC. Auf dem 64-kernigen Zen2-PC
unter Linux braucht das Ganze 0,05 Sekunden.
#include <cstdlib>
#include <charconv>
#include <cstring>
#include <vector>
#include <algorithm>
#include <cmath>
#include <bit>
#include <fstream>
#include <string_view>
#include <thread>
#include <utility>
#include <new>
#include <span>
#include <array>
#include <cassert>
#if defined(_MSC_VER)
#include <intrin.h>
#elif defined(__GNUC__) || defined(__clang__)
#include <cpuid.h>
#endif
#if defined(_MSC_VER)
#pragma warning(disable: 26495)
#endif
using namespace std;
// Cache-line size in bytes. main() uses it to align the sieve storage and the
// partition boundaries of the sieve. Falls back to 64 bytes when the library
// does not provide hardware_destructive_interference_size.
#if !defined(__cpp_lib_hardware_interference_size)
constexpr ptrdiff_t CL_SIZE = 64;
#else
constexpr ptrdiff_t CL_SIZE = hardware_destructive_interference_size;
#endif
// CPU queries; their definitions follow main().
bool smt();
size_t getL2Size();
int main( int argc, char **argv )
{
if( argc < 2 )
return EXIT_FAILURE;
auto parse = []( char const *str, auto &value )
{
bool hex = str[0] == '0' && (str[1] == 'x' || str[1] == 'X');
str += 2 * hex;
auto [ptr, ec] = from_chars( str, str + strlen( str ), value, !hex ?
10 : 16 );
return ec == errc() && !*ptr;
};
size_t end;
if( !parse( argv[1], end ) )
return EXIT_FAILURE;
if( end < 2 || (ptrdiff_t)end++ < 0 )
throw bad_alloc();
unsigned
hc = jthread::hardware_concurrency(),
nThreads;
if( argc < 4 || !parse( argv[3], nThreads ) )
nThreads = hc;
if( !nThreads )
return EXIT_FAILURE;
using word_t = size_t;
constexpr size_t
BITS_PER_CL = CL_SIZE * 8,
BITS = sizeof(word_t) * 8;
auto bitEnd = []( size_t end ) { return end / 2 + (end & 1 ^ 1); };
size_t nBits = bitEnd( end );
union alignas(CL_SIZE) ndi_words_cl { word_t words[CL_SIZE /
sizeof(word_t)]; ndi_words_cl() {} };
vector<ndi_words_cl> sieveCachelines( (nBits + BITS_PER_CL - 1) /
BITS_PER_CL );
span<word_t> sieve( &sieveCachelines[0].words[0], (nBits + BITS - 1) /
BITS );
fill( sieve.begin(), sieve.end(), (word_t)-1 );
size_t sqrt = (ptrdiff_t)ceil( ::sqrt( (ptrdiff_t)end ) );
auto punch = [&]( auto, size_t start, size_t end, size_t prime )
{
size_t bit = start / 2;
end = bitEnd( end );
if( bit >= end )
return;
if( prime >= BITS ) [[likely]]
do [[likely]]
sieve[bit / BITS] &= rotl( (word_t)-2, bit % BITS );
while( (bit += prime) < end );
else
{
auto pSieve = sieve.begin() + bit / BITS;
do [[likely]]
{
word_t
word = *pSieve,
mask = rotl( (word_t)-2, bit % BITS ),
oldMask;
do
word &= mask,
oldMask = mask,
mask = rotl( mask, prime % BITS ),
bit += prime;
while( mask < oldMask );
*pSieve++ = word;
} while( bit < end );
}
};
for( size_t bSqrt = bitEnd( sqrt ), bit = 1; bit < bSqrt; ++bit )
[[likely]]
{
auto pSieve = sieve.begin() + bit / BITS;
do [[likely]]
if( word_t w = *pSieve >> bit % BITS; w ) [[likely]]
{
bit += countr_zero( w );
break;
}
while( (bit = bit + BITS & -(ptrdiff_t)BITS) < bSqrt );
if( bit >= bSqrt ) [[unlikely]]
break;
size_t prime = 2 * bit + 1;
punch( false_type(), prime * prime, sqrt, prime );
}
auto scan = [&]( size_t start, size_t end, auto fn )
{
start /= 2;
end = bitEnd( end );
if( start >= end )
return;
auto pSieve = sieve.begin() + start / BITS;
for( size_t bit = start; ; )
{
word_t word = *pSieve++ >> bit % BITS;
bool hasBits = word;
for( unsigned gap; word; word >>= gap, word >>= 1 ) [[likely]]
{
gap = countr_zero( word );
if( (bit += gap) >= end ) [[unlikely]]
return;
fn( bit * 2 + 1, true_type() );
if( ++bit >= end ) [[unlikely]]
return;
}
if( bit >= end ) [[unlikely]]
break;
bit = (bit + BITS - hasBits) & -(ptrdiff_t)BITS;
}
};
static auto dbl = []( ptrdiff_t x ) { return (double)x; };
double threadTange = dbl( end - sqrt ) / nThreads;
using range_t = pair<size_t, size_t>;
vector<pair<size_t, size_t>> ranges;
ranges.reserve( nThreads );
for( size_t t = nThreads, rEnd = end, trStart; t--; rEnd = trStart )
[[likely]]
{
trStart = sqrt + (ptrdiff_t)((ptrdiff_t)t * threadTange + 0.5);
size_t trClStart = trStart & -(2 * CL_SIZE * 8);
trStart = trClStart >= sqrt ? trClStart : sqrt;
if( trStart < rEnd )
ranges.emplace_back( trStart, rEnd );
}
double maxCacheRange = dbl( getL2Size() * (8 * 2) ) / 3 * (smt() &&
nThreads > hc / 2 ? 1 : 2);
vector<jthread> threads;
threads.reserve( ranges.size() - 1 );
for( range_t const &range : ranges )
{
auto cacheAwarePunch = [&]( size_t rStart, size_t rEnd )
{
double thisThreadRange = dbl( rEnd - rStart );
ptrdiff_t nCachePartitions = (ptrdiff_t)ceil( thisThreadRange /
maxCacheRange );
double cacheRange = thisThreadRange / dbl( nCachePartitions );
for( size_t p = nCachePartitions, cacheEnd = rEnd, cacheStart; p--;
cacheEnd = cacheStart ) [[likely]]
{
cacheStart = rStart + (ptrdiff_t)((double)(ptrdiff_t)p * cacheRange
+ 0.5);
cacheStart &= -(2 * CL_SIZE * 8);
cacheStart = cacheStart >= sqrt ? cacheStart : sqrt;
scan( 3, sqrt,
[&]( size_t prime, auto )
{
size_t start = (cacheStart + prime - 1) / prime * prime;
start = start & 1 ? start : start + prime;
punch( true_type(), start, cacheEnd, prime );
} );
}
};
if( &range != &ranges.back() )
threads.emplace_back( cacheAwarePunch, range.first, range.second );
else
cacheAwarePunch( range.first, range.second );
}
threads.resize( 0 );
if( argc >= 3 && !*argv[2] )
return EXIT_SUCCESS;
ofstream ofs;
ofs.exceptions( ofstream::failbit | ofstream::badbit );
ofs.open( argc >= 3 ? argv[2] : "primes.txt", ofstream::binary |
ofstream::trunc );
constexpr size_t
BUF_SIZE = 0x100000,
AHEAD = 32;
union ndi_char { char c; ndi_char() {} };
vector<ndi_char> rawPrintBuf( BUF_SIZE + AHEAD - 1 );
span printBuf( &rawPrintBuf.front().c, &rawPrintBuf.back().c + 1 );
auto
bufBegin = printBuf.begin(),
bufEnd = bufBegin,
bufFlush = bufBegin + BUF_SIZE;
auto print = [&]() { ofs << string_view( bufBegin, bufEnd ); };
auto printPrime = [&]( size_t prime, auto )
{
auto [ptr, ec] = to_chars( &*bufEnd, &bufEnd[AHEAD - 1], prime );
if( ec != errc() ) [[unlikely]]
throw system_error( (int)ec, generic_category(), "converson failed" );
bufEnd = ptr - &*bufBegin + bufBegin; // pointer to iterator - NOP
*bufEnd++ = '\n';
if( bufEnd >= bufFlush ) [[unlikely]]
print(), bufEnd = bufBegin;
};
printPrime( 2, false_type() );
scan( 3, end, printPrime );
print();
}
// Executes CPUID for leaf `code` and returns EAX, EBX, ECX, EDX in that
// order. On compilers other than MSVC, GCC and Clang all four values are 0.
array<unsigned, 4> cpuId( unsigned code )
{
    // Value-initialized so the fallback path returns zeros rather than
    // indeterminate values.
    array<unsigned, 4> regs{};
#if defined(_MSC_VER)
    __cpuid( (int *)regs.data(), (int)code );
#elif defined(__GNUC__) || defined(__clang__)
    __cpuid( code, regs[0], regs[1], regs[2], regs[3] );
#endif
    return regs;
}
// Returns the HTT flag (bit 28 of EDX in CPUID leaf 1), or false when the CPU
// does not implement leaf 1.
// NOTE(review): HTT signals that the logical-processor-count field is valid;
// it may also be set on multi-core CPUs without SMT. Confirm this is the
// intended SMT test.
bool smt()
{
    unsigned const maxBasicLeaf = cpuId( 0 )[0];
    if( maxBasicLeaf == 0 )
        return false;
    unsigned const edx = cpuId( 1 )[3];
    return (edx & (1u << 28)) != 0;
}
// Returns the L2 cache size in bytes, taken from ECX[31:16] (KiB) of the
// extended CPUID leaf 0x80000006. Falls back to 512 KiB when that leaf is
// not available.
size_t getL2Size()
{
    constexpr size_t FALLBACK_BYTES = 512ull * 1024;
    unsigned const maxExtendedLeaf = cpuId( 0x80000000u )[0];
    if( maxExtendedLeaf >= 0x80000006u )
        return size_t( cpuId( 0x80000006u )[2] >> 16 ) * 1024;
    return FALLBACK_BYTES;
}