`nan` gradient when `tf.where` is used #38349
Comments
---
I have tried on Colab with TF versions 2.1.0 and 2.2.0-rc2 and was able to reproduce the issue. Please find the gist here. Thanks!
---
This is due to a limitation in how gradients are calculated. Unfortunately, it is unlikely to be fixed in the foreseeable future. You can find more detail here, along with a recipe for how to avoid it: https://stackoverflow.com/questions/33712178/tensorflow-nan-bug/42497444#42497444 In short, if the input to a `tf.where` contains NaNs, the gradient will always be NaN, regardless of whether the input is actually used or not, and the workaround is to prevent the inputs from ever containing NaNs.
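For illustration, a minimal sketch of the failure mode described above (the example values are mine, not from the thread):

```python
import tensorflow as tf

x = tf.Variable(0.)
with tf.GradientTape() as tape:
    # the log branch is never selected at x = 0, but log(0.) is still
    # evaluated, and its gradient enters the chain rule as 0 * inf = nan
    out = tf.where(x >= 0., x, tf.math.log(x))
grads = tape.gradient(out, x)
print(grads)  # tf.Tensor(nan, shape=(), dtype=float32), expected 1.0
```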
---
Shouldn't this be documented with a big warning in the `tf.where` documentation?
---
Indeed it should. |
---
@mdanatg Hello, this is my first time contributing to the TensorFlow lib. From the thread I gather you would require a warning paragraph added to the `tf.where` docstring - should it include the workaround as well?
---
Hello @0x0badc0de, @mdanatg
---
Sorry for the delay. Feel free to send a PR - it's only a matter of adding a paragraph to the docstring. The text should be more along the lines of a warning. Something like: "Important: if any of the inputs contain NaN values, etc." And yes, it should include the workaround as well, which is something along the lines of: instead of passing values that may be NaN directly to `tf.where`, replace them with safe values first so that both branches stay differentiable.
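A minimal sketch of that recipe (my example, following the Stack Overflow answer linked above; the inner `tf.where` supplies a harmless stand-in value):

```python
import tensorflow as tf

x = tf.Variable(0.)
with tf.GradientTape() as tape:
    # the inner tf.where substitutes 1. wherever log would blow up, so the
    # untaken branch has a finite gradient that is safely zeroed out
    safe_x = tf.where(x > 0., x, tf.ones_like(x))
    out = tf.where(x > 0., tf.math.log(safe_x), x)
grads = tape.gradient(out, x)
print(grads)  # tf.Tensor(1.0, shape=(), dtype=float32)
```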
---
@mdanatg Thanks for your reply. However, I would like to mention that this behavior also happens when the generated value in the inactive branch is not finite (i.e. `inf`):

```python
import tensorflow as tf

a = tf.Variable(10.)
with tf.GradientTape() as tape:
    out = tf.where(a < 15., a, tf.math.pow(10.0, tf.math.exp(a)))
grads = tape.gradient(out, a)
print(grads)
# tf.Tensor(nan, shape=(), dtype=float32)
```

And also if we reverse the condition such that the branch with the infinite value is selected, the gradient would be infinite (which is a bit surprising, in that it does not generate `nan`):

```python
with tf.GradientTape() as tape:
    out = tf.where(a > 15., a, tf.math.pow(10.0, tf.math.exp(a)))
grads = tape.gradient(out, a)
print(grads)
# tf.Tensor(inf, shape=(), dtype=float32)
```

So this behavior happens for both `nan` and `inf` values.

CC: @anorak-k for potential consideration in your PR after @mdanatg confirms this.
---
@mkaze that's true - nan, inf and any other special FP value will disrupt the gradient calculation. What happens internally is that the gradients are aggregated in this fashion: the gradient of the branch that wasn't taken is still computed and then multiplied by zero, and since `0 * inf` and `0 * nan` are both `nan`, a single bad value in either branch poisons the result. Moreover, the forward calculation doesn't need to result in a nan or inf. You can also get weird results if the gradient alone is nan or inf. For example, the cube root function is defined and well-behaved everywhere, but its derivative at zero is infinite. So this will give you a nan gradient too:
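For instance (a minimal sketch; the guard value `1.` is arbitrary):

```python
import tensorflow as tf

x = tf.Variable(0.)
with tf.GradientTape() as tape:
    # cube root is finite at 0, but its derivative there is infinite, so the
    # untaken branch contributes 0 * inf = nan to the aggregated gradient
    out = tf.where(x < 1., x, tf.math.pow(x, 1.0 / 3.0))
grads = tape.gradient(out, x)
print(grads)  # tf.Tensor(nan, shape=(), dtype=float32)
```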
I think the tf.where workaround is useful with infinite values as well, so long as the branch not taken is forced to take a gradient that can be safely multiplied by 0. For your example, it would be something like this:
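A minimal sketch of that fix (my reconstruction; `safe_a` and the clamp value `0.` are arbitrary choices):

```python
import tensorflow as tf

a = tf.Variable(10.)
with tf.GradientTape() as tape:
    # clamp a so the pow/exp branch never sees a value that overflows;
    # any substitute for which the branch stays finite works here
    safe_a = tf.where(a < 15., a, tf.zeros_like(a))
    out = tf.where(a < 15., a, tf.math.pow(10.0, tf.math.exp(safe_a)))
grads = tape.gradient(out, a)
print(grads)  # tf.Tensor(1.0, shape=(), dtype=float32)
```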
I agree that it can sometimes be impractical to do, but in principle it should always be possible as long as you control the inputs to the sensitive functions - all they have to do is force finite values in all the elements that are dropped.
---
I want to fix issue #38349.
You could simply have it raise a ValueError if it's getting NaN inputs. Or does it not work like that?


Please make sure that this is a bug. As per our
GitHub Policy,
we only address code/doc bugs, performance issues, feature requests and
build/installation issues on GitHub.
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Debian GNU/Linux 10 (buster)
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): binary
You can collect some of this information using our environment capture script.
You can also obtain the TensorFlow version with:
1. TF 1.0: `python -c "import tensorflow as tf; print(tf.GIT_VERSION, tf.VERSION)"`
2. TF 2.0: `python -c "import tensorflow as tf; print(tf.version.GIT_VERSION, tf.version.VERSION)"`

Describe the current behavior
A well-defined function built with `tf.where` has `nan` gradients at points where the `tf.where` inactive branch is undefined.

Describe the expected behavior
The inactive branch should be ignored in gradient calculations.
Standalone code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate
the problem. If possible, please share a link to Colab/Jupyter/any notebook.
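A hypothetical reconstruction of such a test case (functions `f1`-`f3` are illustrative stand-ins, not the reporter's exact code):

```python
import tensorflow as tf

# Each function equals x (true gradient 1.) on the active branch, while the
# inactive branch is undefined, or has an undefined derivative, at x = 1.
def f1(x):
    return tf.where(x >= 1., x, tf.math.log(x - 1.))   # log(0) at x = 1

def f2(x):
    return tf.where(x >= 1., x, tf.math.sqrt(x - 1.))  # sqrt'(0) is infinite

def f3(x):
    return tf.where(x >= 1., x, 1. / (x - 1.))         # division by zero at x = 1

for f in (f1, f2, f3):
    x = tf.Variable(1.)
    with tf.GradientTape() as tape:
        y = f(x)
    print(tape.gradient(y, x))  # tf.Tensor(nan, shape=(), dtype=float32) each time
```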
All 3 functions above are well defined for the positive values used for testing. Still, they show a `nan` gradient at the point `1.`, while it has to be equal to `1.`

Other info / logs
Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.