We can try to optimize the subtraction by looking at the result table and finding a bit level hack that gives us the right result after which a modified bitcount will produce the result.
The table when viewed for each bit pair is:
00 01 10 11
0 1 2 3
00 0 0 1 2 3
01 1 1 0 1 2
10 2 2 1 0 1
11 3 3 2 1 0
a^b is almost the right result, only for a=1 and b=2 does this give the wrong result (3 instead of 1).
So that need to be detected and dealt with.
That means that if a&b == 0 and a^b == 3 and both a and b are not 0 then the result should be 1 which we can do by xoring the high bit of the result.
long result = a^b;
long is3 = result & result << 1; //high bit per pair will contain whether the pair is 3
is3 &= 0xaaaa_aaaa_aaaa_aaaa;0xaaaa_aaaa_aaaa_aaaal; //extract the high bit per pair
long is0 = a&b;
is0 = ~(is0 | is0 << 1);//high bit per pair will contain whether the pair is 0
is0 &= 0xaaaa_aaaa_aaaa_aaaa;0xaaaa_aaaa_aaaa_aaaal; //extract the high bit per pair
long notBoth0 = ~(a|a<<1) & ~(b|b<<1);
notBoth0 &= 0xaaaa_aaaa_aaaa_aaaa;0xaaaa_aaaa_aaaa_aaaal; //extract the high bit per pair
result ^= is3 & is0 & notBoth0; // only invert the bits set in is3, is0 and notBoth0
The modified bitcount would be something like:
result = (result & 0x3333_3333_3333_33330x3333_3333_3333_3333l) + ((result >> 2) & 0x3333_3333_3333_33330x3333_3333_3333_3333l);
result = (result & 0x0f0f_0f0f_0f0f_0f0f0x0f0f_0f0f_0f0f_0f0fl) + ((result >> 4) & 0x0f0f_0f0f_0f0f_0f0f0x0f0f_0f0f_0f0f_0f0fl);
result = (result & 0x00ff_00ff_00ff_00ff0x00ff_00ff_00ff_00ffl) + ((result >> 8) & 0x00ff_00ff_00ff_00ff0x00ff_00ff_00ff_00ffl);
result = (result & 0x0000_ffff_0000_ffff0x0000_ffff_0000_ffffl) + ((result >> 16) & 0x0000_ffff_0000_ffff0x0000_ffff_0000_ffffl);
result = (result & 0x0000_0000_ffff_ffff0x0000_0000_ffff_ffffl) + ((result >> 32) & 0x0000_0000_ffff_ffff0x0000_0000_ffff_ffffl);
note: I use underscores in the hexadecimal literals to make it clearer (they are legal in java 7 I believe) however in java 6 you will need to remove them.