A crazy fast bit-counting technique
Let's say you've got a lot of numbers that represent bitmasks of some kind, and you want to count how many times each bit is on or off across the entire set. Maybe you're analyzing game positions represented as bitboards for an AI, or hunting for certain kinds of weaknesses in random-number generators, like those found in Forge (a successor to crypto-js) and Cryptocat (read the great write-up at Sophos).
So, you write some straightforward code to count the bits. It grabs one bitmask. If the lowest-order bit is set, it increments the counter for that bit position. Then it right-shifts the bitmask and moves on to the counter for the next bit. Repeat that for each bit in the mask, then repeat that for each bitmask:
const int N = 1000000;
unsigned long x[N];   // Assuming sizeof(unsigned long) == 8, or 64 bits.
int counts[64] = {0};

void count_simple(void) {
    for(int i = 0; i < N; i++) {
        int j = 0;
        while(x[i] != 0) {
            counts[j] += x[i] & 1;
            x[i] >>= 1;
            ++j;
        }
    }
}
You run your program, and it works correctly, but it's too slow. I'll show you how to speed this up. The technique, which applies to languages like Python or JavaScript as well as to C, is both crazy and crazy fast!
First, I'll recompile the program with profiling enabled, run it, and look at where the hotspots are:
g++ -g -pg -o bit_count bit_count.cc
./bit_count
gprof bit_count
Here's the profile on the original listing, with a total runtime of 0.91 seconds:
     |  void count_simple(void) {
 1%  |      for(int i = 0; i < N; i++) {
     |          int j = 0;
26%  |          while(x[i] != 0) {
36%  |              counts[j] += x[i] & 1;
36%  |              x[i] >>= 1;
     |              ++j;
     |          }
     |      }
     |  }
The hotspot is definitely the count loop, accounting for 99% of the runtime (of this simple demo program, anyway), far surpassing either generating all the x[i] or displaying the results.
It doesn't seem like there's much to be done here. We're already hitting the smaller data structure (counts) hardest in the inner loop, so we're playing nice with the cache. We could right-shift x[i] past runs of zeros until we get a set lowest-order bit, and only then bother incrementing counts[j], but that doesn't actually help in practice.
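(If you want to try that variant anyway, here's a sketch of it. This is my illustration, not code from the measured sources, and it assumes a GCC-compatible compiler for the __builtin_ctzl builtin, which counts trailing zero bits:)

void count_skip_zeros(void) {
    for(int i = 0; i < N; i++) {
        int j = 0;
        while(x[i] != 0) {
            int skip = __builtin_ctzl(x[i]);  // trailing zeros; x[i] != 0, so well-defined
            j += skip;                        // jump straight to the next set bit
            x[i] >>= skip;
            counts[j]++;                      // the lowest-order bit is now known to be set
            x[i] >>= 1;
            ++j;
        }
    }
}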
The answer is to operate on all of the bits in a single mask in parallel. If you're in this position, you're probably already doing that somewhere else, but with AND and OR operations, not addition. Addition is trickier: a carry out of one bit position spills into its higher-order neighbor and corrupts that bit's count.
So, let's provide those higher-order bits and handle the overflow ourselves. This requires coding logic for our own k-bit accumulator, which holds a k-bit result and accumulates either 0 or 1 each cycle.
Let's build a 3-bit accumulator, with bits b1, b2, and b3 from lowest-order to highest, all already holding some value. We're ready to accumulate a 0 or a 1, as held in x. It goes like this:
t1 = b1 ^ x;
c  = b1 & x;
t1 holds the result of (b1 + x) % 2, and c is 1 for a carry and 0 otherwise. The next bits all chain off of the carry bit but follow the same logic:
t2 = b2 ^ c;
c  = b2 & c;
t3 = b3 ^ c;
c  = b3 & c;
And so on; we can chain as many bits together as we want to get larger accumulators. Finally, move the result from the temporaries:
b1 = t1;
b2 = t2;
b3 = t3;
Because each of b1, b2, and b3 holds 64 bits (or however wide your machine word is), we actually have 64 accumulators all running in parallel!
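To make that data flow easier to see in one place, here's the same chain packaged as a helper function. (This is just a restatement of the snippets above, with hypothetical names; the final listing below unrolls everything inline instead.)

// Add one mask's worth of 0/1 values into 64 parallel 3-bit accumulators,
// stored as bit-planes b1 (lowest order), b2, and b3.
void accumulate3(unsigned long &b1, unsigned long &b2, unsigned long &b3,
                 unsigned long x) {
    unsigned long t1 = b1 ^ x;
    unsigned long c  = b1 & x;   // per-lane carry out of bit 1
    unsigned long t2 = b2 ^ c;
    c = b2 & c;                  // per-lane carry out of bit 2
    unsigned long t3 = b3 ^ c;   // a carry out of bit 3 would be lost, so the
                                 // caller must flush before the count exceeds 7
    b1 = t1; b2 = t2; b3 = t3;
}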
After we finish accumulating, we add the value of the accumulator into our counts array:
for(int j = 0; j < 64; j++) {
    counts[j] += 4 * (b3 & 1) + 2 * (b2 & 1) + 1 * (b1 & 1);
    b1 >>= 1;
    b2 >>= 1;
    b3 >>= 1;
}
With our 3-bit accumulator, we can count up to 7 items before needing to flush it, so our counting loop will need to unroll that many times. We need to jump past the rest of the accumulator and into the count loop if we run out of items to accumulate, and we'll need to be careful to initialize the accumulator properly so that we can do that. Also, for the first few unrollings of the accumulator, we know that the total fits in fewer bits, so we can omit some of the bit operations. We end up with the following code:
void count_3bit(void) {
    for(int i = 0; i < N;) {
        int n = 0;
        unsigned long b1 = 0, b2 = 0, b3 = 0;
        unsigned long t1, t2, t3, c;

        // Item 1: the total fits in b1 alone.
        if(i + n >= N) goto end;
        b1 = x[i + n];
        n = 1;

        // Item 2: b2 was zero, so the carry becomes the new b2.
        if(i + n >= N) goto end;
        t1 = b1 ^ x[i + n]; c = b1 & x[i + n];
        t2 = c;
        b1 = t1; b2 = t2;
        n = 2;

        // Item 3: the total is at most 3, so b2 and the carry can't
        // both be set, and OR suffices.
        if(i + n >= N) goto end;
        t1 = b1 ^ x[i + n]; c = b1 & x[i + n];
        t2 = b2 | c;
        b1 = t1; b2 = t2;
        n = 3;

        // Items 4 through 7: the full chain, except that the total stays
        // at most 7, so b3 and the carry into it can't both be set.
        if(i + n >= N) goto end;
        t1 = b1 ^ x[i + n]; c = b1 & x[i + n];
        t2 = b2 ^ c; c = b2 & c;
        t3 = b3 | c;
        b1 = t1; b2 = t2; b3 = t3;
        n = 4;

        if(i + n >= N) goto end;
        t1 = b1 ^ x[i + n]; c = b1 & x[i + n];
        t2 = b2 ^ c; c = b2 & c;
        t3 = b3 | c;
        b1 = t1; b2 = t2; b3 = t3;
        n = 5;

        if(i + n >= N) goto end;
        t1 = b1 ^ x[i + n]; c = b1 & x[i + n];
        t2 = b2 ^ c; c = b2 & c;
        t3 = b3 | c;
        b1 = t1; b2 = t2; b3 = t3;
        n = 6;

        if(i + n >= N) goto end;
        t1 = b1 ^ x[i + n]; c = b1 & x[i + n];
        t2 = b2 ^ c; c = b2 & c;
        t3 = b3 | c;
        b1 = t1; b2 = t2; b3 = t3;
        n = 7;

    end:
        // Flush the 64 parallel accumulators into counts[].
        for(int j = 0; j < 64; j++) {
            counts[j] += 4 * (b3 & 1) + 2 * (b2 & 1) + 1 * (b1 & 1);
            b1 >>= 1;
            b2 >>= 1;
            b3 >>= 1;
        }
        i += n;
    }
}
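If you'd like to sanity-check the unrolled version, here's a minimal test harness. (This is my sketch, not part of the downloadable sources; it assumes it's appended to the listing above, so that N, x, counts, and count_3bit are in scope.)

#include <cstdio>
#include <cstdlib>

int main() {
    srand(12345);  // fixed seed, so runs are repeatable
    for(int i = 0; i < N; i++)
        x[i] = ((unsigned long)rand() << 32) ^ (unsigned long)rand();

    // Reference counts, computed bit by bit without modifying x[].
    int expected[64] = {0};
    for(int i = 0; i < N; i++)
        for(int j = 0; j < 64; j++)
            expected[j] += (x[i] >> j) & 1;

    count_3bit();

    for(int j = 0; j < 64; j++)
        if(counts[j] != expected[j]) {
            printf("mismatch at bit %d: %d vs %d\n", j, counts[j], expected[j]);
            return 1;
        }
    printf("all 64 bit counts match\n");
    return 0;
}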
Writing your own accumulator to run on a machine that has its own accumulators is somewhat crazy. But, how well does it perform?
Version     | g++ -O0 -g -pg | g++ -O2
------------|----------------|---------
simple      | 1.042          | 0.225
3bit        | 0.158          | 0.055
Improvement | ×6.59          | ×4.09

Times are in seconds, with a warm cache.
Four to six times faster!
Well, okay, that's great if you're bit-twiddling in a systems language like C. Surely a higher-level dynamic language running in a byte-code interpreter would be immune to this trick, right? Wrong! Here are the timings for a similar version written in Python:
Version     | python2 | python3
------------|---------|--------
simple      | 0.782   | 1.006
3bit        | 0.148   | 0.215
Improvement | ×5.28   | ×4.68

Times are in seconds, with a warm cache.
Right in line with the C version!
Want to play with this yourself? Download the C and Python sources.