I can confirm your findings that looped version is about twice as fast as the batched matmul. Matmul on its own is not essentially the one taking long runtime in your code snippet, it is the other operation of summing up along third dimension after abs which is costly. It is due to the following reasons.
sum(abs(result))
- abs is again not issue here. Sum is reduction algorithm, which are usually quite fast along the fast moving dimension. However, reduction along higher dimension the element stride is size of the matrix for successive elements. This expensive compared to reduction along continuous locations.looped abs additions
- This version is however is accessing elements that continuous in memory because, we are basically adding respective elements of 6 matrices. On top of this, the entire loop (along with abs OP) will be converted into a single JIT kernel that does the following which is very efficient.
res = res + ptr0[i] + ptr1[i] + ptr2[i] + ptr0[i] + ptr1[i]
Above line is just for illustration, that is not the exact JIT kernel.
Hence, the batched version is faster than looped version in this specific case because of the reduction operation that is being done on the result of matmul.
My test GPU: GTX 1060
The matmul itself for a single [10k x 6] * [6 x 1k]
is about half a millisecond on GTX 1060. Six such matmuls can't be done under millisecond on my GTX 1060 at least I would think. What is your target runtime ?
EDITED (Jan 10, 2020):
You can try looking into our latest entry into gemm category in master branch of ArrayFire. However, you would have to build arrayfire from source until our next feature release 3.7. You can look at the documentation at the following page.
https://github.com/arrayfire/arrayfire/blob/master/include/af/blas.h#L230
It follows the principle of Carray
from- cuBLAS gemm API.Actually, This won't work because of abs
operation on result of each matmul.
You can try looking into our latest entry into gemm category in master branch of ArrayFire. However, you would have to build arrayfire from source until our next feature release 3.7. You can look at the documentation at the following page.
https://github.com/arrayfire/arrayfire/blob/master/include/af/blas.h#L230
It follows the principle of Carray
from cuBLAS gemm API.