
When using numpy I can use np.vectorize to vectorize a function that contains if statements in order for the function to accept array arguments. How can I do the same with torch in order for a function to accept tensor arguments?

For example, the final print statement in the code below will fail. How can I make this work?

import numpy as np
import torch as tc

def numpy_func(x):
    return x if x > 0. else 0.
numpy_func = np.vectorize(numpy_func)

print('numpy function (scalar):', numpy_func(-1.))
print('numpy function (array):', numpy_func(np.array([-1., 0., 1.])))

def torch_func(x):
    return x if x > 0. else 0.

print('torch function (scalar):', torch_func(-1.))
print('torch function (tensor):', torch_func(tc.tensor([-1., 0., 1.])))
  • Can't you write it vectorized yourself? Commented Sep 20, 2022 at 19:18
  • Your vectorized example would be return x.where(x > 0, 0) Commented Sep 20, 2022 at 19:19
  • PyTorch has functorch.vmap, but it doesn't yet support if statements and requires tensor inputs. Commented Sep 20, 2022 at 19:35
  • @MichaelSzczesny: Sorry, yes, I corrected the example code so that the torch function takes a tensor argument. That is what I meant to write initially. Too bad that functorch.vmap doesn't support if yet... Commented Sep 20, 2022 at 20:00
  • @Mead if your function is complicated, you won't get much of a performance boost (especially on GPU) compared to a plain for i in x.size: x[i] = f(x[i]). Be a mindful programmer and think about efficiency before implementing if spaghetti for tensor manipulation Commented Sep 20, 2022 at 20:31

1 Answer


You can use .apply_() for CPU tensors. For CUDA ones, the task is problematic: if statements aren't easy to SIMDify.
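A minimal sketch of the CPU path, reusing the torch_func from the question (note that .apply_() mutates the tensor in place and accepts each element as a plain Python number):

```python
import torch as tc

def torch_func(x):
    return x if x > 0. else 0.

t = tc.tensor([-1., 0., 1.])
t.apply_(torch_func)  # in-place, CPU-only, element by element
print(t)              # tensor([0., 0., 1.])
```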

You can apply the same workaround for functorch.vmap that video drivers used to apply to shaders: evaluate both branches of the condition and select the result with arithmetic.
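For this particular function, the branch-free version is a single torch.where: both branches (x and 0) are computed for every element and the mask picks one. Because no Python-level if remains, it works on CPU and CUDA tensors alike:

```python
import torch as tc

def torch_func_vectorized(x):
    # evaluate both "branches" elementwise, then select by the mask
    return tc.where(x > 0., x, tc.zeros_like(x))

print(torch_func_vectorized(tc.tensor([-1., 0., 1.])))  # tensor([0., 0., 1.])
```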

Otherwise, just use a for loop: that's what np.vectorize() mostly does anyway.

def torch_vectorize(f, inplace=False):
    def wrapper(tensor):
        # work on a copy unless the caller asks for in-place mutation
        out = tensor if inplace else tensor.clone()
        # 1-D view into out (flatten() is a view for contiguous tensors)
        view = out.flatten()
        for i, x in enumerate(view):
            view[i] = f(x)
        return out
    return wrapper
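Used with the question's torch_func, the wrapper behaves like np.vectorize and leaves the input untouched by default:

```python
import torch as tc

def torch_vectorize(f, inplace=False):
    def wrapper(tensor):
        out = tensor if inplace else tensor.clone()
        view = out.flatten()
        for i, x in enumerate(view):
            view[i] = f(x)
        return out
    return wrapper

def torch_func(x):
    return x if x > 0. else 0.

vfunc = torch_vectorize(torch_func)
src = tc.tensor([-1., 0., 1.])
out = vfunc(src)
print(out)  # tensor([0., 0., 1.])
print(src)  # tensor([-1., 0., 1.])  -- original unchanged
```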