Skip to main content
Stack Overflow for Teams is now Stack Internal: See how we’re powering the human intelligence layer of enterprise AI. Read more >
added 656 characters in body
Source Link
ead
  • 34.8k
  • 8
  • 101
  • 170

As suggested by @max9111, eliminating stride by using continuous memory-view, which doesn't improve the performance much:

@cython.boundscheck(False)
@cython.wraparound(False)
def cy_where_cont(double[::1] df):
    cdef int i
    cdef int n = len(df)
    cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
    cdef double[::1] view = output  # view as continuous!
    for i in range(n):
        if df[i]>0.5:
            view[i] = 2.0*df[i]
        else:
            view[i] = df[i]
    return output 

%timeit cy_where_cont(data)   #  165 ms

As suggested by @max9111, eliminating stride by using continuous memory-view, which doesn't improve the performance much:

@cython.boundscheck(False)
@cython.wraparound(False)
def cy_where_cont(double[::1] df):
    cdef int i
    cdef int n = len(df)
    cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
    cdef double[::1] view = output  # view as continuous!
    for i in range(n):
        if df[i]>0.5:
            view[i] = 2.0*df[i]
        else:
            view[i] = df[i]
    return output 

%timeit cy_where_cont(data)   #  165 ms
Source Link
ead
  • 34.8k
  • 8
  • 101
  • 170

Achieving Numba's performance with Cython

Usually I'm able to match Numba's performance when using Cython. However, in this example I have failed to do so - Numba is about 4 times faster than my Cython's version.

Here the Cython-version:

%%cython -c=-march=native -c=-O3
cimport numpy as np
import numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def cy_where(double[::1] df):
    cdef int i
    cdef int n = len(df)
    cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
    for i in range(n):
        if df[i]>0.5:
            output[i] = 2.0*df[i]
        else:
            output[i] = df[i]
    return output 

And here is the Numba-version:

import numba as nb
@nb.njit
def nb_where(df):
    n = len(df)
    output = np.empty(n, dtype=np.float64)
    for i in range(n):
        if df[i]>0.5:
            output[i] = 2.0*df[i]
        else:
            output[i] = df[i]
    return output

When tested, the Cython version is on par with numpy's where, but is clearly inferior to Numba:

#Python3.6 + Cython 0.28.3 + gcc-7.2
import numpy
np.random.seed(0)
n = 10000000
data = np.random.random(n)

assert (cy_where(data)==nb_where(data)).all()
assert (np.where(data>0.5,2*data, data)==nb_where(data)).all()

%timeit cy_where(data)       # 179ms
%timeit nb_where(data)       # 49ms (!!)
%timeit np.where(data>0.5,2*data, data)  # 278 ms

What is the reason for Numba's performance and how can it be matched when using Cython?