I'm currently focusing on the implementation of Kyber (ML-KEM).
I noticed that the AVX2 version of the compress operation seems to use a fast division algorithm.

In the compress operation, we need to compute x * (2^10 / q), where x is an int16_t input (a polynomial coefficient in the prime field Z_q) and q is a prime (3329 in Kyber, 7681 in the code I pasted below). The result should be rounded to the nearest integer: values from 2.0 up to (but not including) 2.5 round down to 2, while values of 2.5 and above round up to 3. The rounded value is then reduced modulo 2^10.
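Written out, this is the Compress_d function of the ML-KEM specification (FIPS 203) with d = 10. The second identity is the usual integer way to round to nearest (take a = 2^10 x). It holds because q is odd, so no exact ties can occur, and it is what the reference formula ((x << 10) + q/2) / q relies on when it adds q/2 (integer division) before dividing.

```latex
\mathrm{Compress}_{10}(x) = \left\lceil \frac{2^{10}}{q}\, x \right\rfloor \bmod 2^{10},
\qquad
\left\lceil \frac{a}{q} \right\rfloor = \left\lfloor \frac{a + \lfloor q/2 \rfloor}{q} \right\rfloor
\quad \text{for odd } q \text{ and integer } a \ge 0 .
```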

Reference version: https://github.com/pq-crystals/kyber/blob/main/ref/polyvec.c#L51

AVX2 version: https://github.com/pq-crystals/kyber/blob/main/avx2/polyvec.c#L11

I am confused about how the AVX2 version works and what algorithm is behind it.

Besides that, there is another AVX2 implementation of compress10:

```c
// input "a" is a polyvec
// Here q is 7681, not the same as in current Kyber (ML-KEM)
const __m256i fdiv = _mm256_set1_epi16(8737); // fast division: floor(2^26 / q + 0.5)
const __m256i hfq  = _mm256_set1_epi16(15);
for (int i = 0; i < K; i++) {
  for (int j = 0; j < 256 / 16; j++) {  // 16 int16_t coefficients per __m256i
    d0.vec = _mm256_loadu_si256((__m256i *)&a->vec[i].coeffs[16 * j]);

    d0.vec = _mm256_slli_epi16(d0.vec, 2);     // 4x
    d0.vec = _mm256_add_epi16(d0.vec, hfq);    // 4x + 15
    d0.vec = _mm256_mulhi_epi16(d0.vec, fdiv); // ((4x + 15) * 8737) >> 16
    d0.vec = _mm256_srli_epi16(d0.vec, 2);     // ... >> 2

    ......
  }
}
```

I summarized the formula above as: (((x*4 + 15) * 8737) >> 16) >> 2.

Its reference version looks like this:

```c
....
uint16_t t[4];
for (i = 0; i < K; i++)
{
    for (j = 0; j < 256 / 4; j++)
    {
        for (k = 0; k < 4; k++) {
            // PARAM_Q here is 7681
            t[k] = ((((uint32_t)a->vec[i].coeffs[4 * j + k] << 10) + PARAM_Q / 2) / PARAM_Q) & 0x3ff;
        }
// other loop operations
....
```

I summarized the formula above as: ((x << 10) + q/2) / q.

The reference code is simple and easy to understand, but the optimized one is really confusing.

I also found that when the polynomial coefficient is 5772, the reference version produces the compressed value 769, while the AVX2 version produces 770. I ran 20 groups of test vectors through the AVX2 and reference versions of the code above (Kyber's own two versions always agree), and this is the only input on which they differ. It looks like a precision issue.

I don't understand why this happens, nor how the optimized algorithms (both the Kyber version and the other version above) work, or fail to work.
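For the q = 7681 variant only (Kyber's q = 3329 code uses different constants and is not covered here), the mismatch can be explained with a short calculation. Let A(x) and R(x) be the real-valued quantities whose floors the AVX2 sequence and the reference formula return. The constant 8737 equals (2^26 + 33)/7681, slightly more than 2^26/7681, so A(x) always exceeds R(x) by a tiny, x-dependent amount. The two results can therefore differ only when R(x) lies just below an integer, and x = 5772 is such a case.

```latex
8737 = \frac{2^{26} + 33}{7681}, \qquad
A(x) = \frac{(4x + 15)\cdot 8737}{2^{18}} = \frac{34948\,x + 131055}{2^{18}}, \qquad
R(x) = \frac{2^{10}x + 3840}{7681}

A(x) - R(x) = \frac{132\,x + 495}{2^{18}\cdot 7681} > 0 \quad (x \ge 0)

x = 5772:\quad R(x) = 770 - \frac{2}{7681}, \qquad
A(x) - R(x) = \frac{762399}{2^{18}\cdot 7681} \approx \frac{2.91}{7681}
\;\Longrightarrow\; \lfloor A(x) \rfloor = 770,\quad \lfloor R(x) \rfloor = 769
```

Because the gap stays below 4/7681 for every x <= 7680, a mismatch requires R(x) to be within 3/7681 of the next integer and x to be large enough for the gap to close that distance. Checking the three residues where this can happen leaves x = 5772 as the only input in [0, 7681) with differing results.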

  • (((x*4 + 15) * 8737) >> 16) >> 2 calculates (34948 x + 131055) / 262144 = 0.13331604 x + 0.4999351501, rounded down. For q = 7681, the behavior of (x << 10 + q/2) / q is undefined since the shift amount is 10 + q/2, which far exceeds the width of whatever type you are using, but presumably you meant ((x << 10) + q/2) / q. That computes (1024 x + 3840½) / 7681 = 0.1333159745 x + ½, rounded down. So it is clear the formulas compute similar but different values. Commented Oct 28, 2024 at 11:09
  • I might guess somebody decided to approximate the arbitrary division p/q as k/P, where P was a power of two (chosen because then division can be implemented as a shift) and k was to be found to match p/q (so k ≈ p/q•P), then tweaked the amount added to make the first failure as high as possible. You might be able to increase it by increasing the 2 in >> 2, giving a larger denominator with more wiggle room for adjusting the numerator. How large you can make the denominator depends on what range of numbers you want to support, to avoid overflow in the numerator. Commented Oct 28, 2024 at 11:13
  • But looking at the code, with the kludges to get the work done in 16-bit arithmetic, there may not be much wiggle room. The _mm256_slli_epi16(d0.vec, 2) limits inputs to [−8192, 8192), so you cannot increase that shift if you want to cover the entire range to 7680 unless you go to wider arithmetic. Maybe some unsigned arithmetic could give you one more bit, but I am not sure that would be enough. Commented Oct 28, 2024 at 11:19
  • Yes, I did mean ((x << 10) + q/2) / q, thanks for the correction! Your idea is basically right: the optimized method mainly approximates the division in x * (2**10 / q), where q is a constant (3329 in Kyber, 7681 in the other implementation). There is a response on Crypto Stack Exchange: crypto.stackexchange.com/questions/113315/…, but I'm still a bit confused by that answer, so maybe we could discuss it further there. Commented Oct 29, 2024 at 3:08
  • I’m voting to close this question because it has a cross-site duplicate. Commented Oct 31, 2024 at 12:13
