Skip to content

Fix bugged gradients when combiner == 'MEAN' #2505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tensorflow_addons/layers/embedding_bag.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,13 @@ def _embedding_bag(
Returns:
A `Tensor` of the format specified by `data_format`.
"""
if weights is None:
if weights is None and combiner == "sum":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use reduction instead of combiner

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The combiner is combining part of the layer and applies to the inputs/weights, not the output/loss! Is reduction still the right approach there?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add combiner document?

weights = tf.ones_like(indices, dtype=params.dtype)
elif weights is None and combiner == "mean":
weights = tf.ones_like(indices, dtype=params.dtype) / tf.cast(
tf.shape(indices)[1], params.dtype
)
combiner = "sum"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we have this combiner overriding?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was a quick workaround, but the output and grads are correct, and performance is fine! Since combiner == "mean" does not support weights, we can get the right results by creating a dummy weight array (we do that anyway), then scaling it so that we get an unweighted mean instead of an unweighted sum.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you should call tf.reduce_mean after custom op instead of integrating with it or combiner overriding.

elif combiner != "sum":
raise RuntimeError(
"Combiner mode must be 'sum' when weights are supplied to EmbeddingBag!"
Expand Down
23 changes: 13 additions & 10 deletions tensorflow_addons/layers/tests/embedding_bag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def test_forward(input_shape, input_dim, dtype, indices_dtype, combiner):
indices,
weights,
)
test_utils.assert_allclose_according_to_type(expected, output)
test_utils.assert_allclose_according_to_type(
expected, output, half_rtol=1e-2, half_atol=1e-2
)


@pytest.mark.with_device(["cpu", "gpu"])
Expand All @@ -69,8 +71,8 @@ def test_forward(input_shape, input_dim, dtype, indices_dtype, combiner):
@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
@pytest.mark.parametrize("indices_dtype", [np.int32, np.int64])
@pytest.mark.parametrize("combiner", ["sum", "mean"])
@pytest.mark.parametrize("graph_mode", [True, False])
def test_backward(input_shape, input_dim, dtype, indices_dtype, combiner, graph_mode):
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_backward(input_shape, input_dim, dtype, indices_dtype, combiner):
indices = np.random.randint(low=0, high=input_dim, size=input_shape).astype(
indices_dtype
)
Expand All @@ -85,10 +87,7 @@ def test_backward(input_shape, input_dim, dtype, indices_dtype, combiner, graph_
if weights is not None:
weights = tf.convert_to_tensor(weights)

if graph_mode:
embedding_bag_fn = tf.function(_embedding_bag)
else:
embedding_bag_fn = _embedding_bag
embedding_bag_fn = tf.function(_embedding_bag)

if combiner == "sum":
with tf.GradientTape(persistent=True) as tape:
Expand All @@ -102,10 +101,11 @@ def test_backward(input_shape, input_dim, dtype, indices_dtype, combiner, graph_
test_utils.assert_allclose_according_to_type(
tf.convert_to_tensor(expected_grads[0]),
tf.convert_to_tensor(grads[0]),
half_rtol=1e-2,
half_atol=1e-2,
)
test_utils.assert_allclose_according_to_type(
expected_grads[1],
grads[1],
expected_grads[1], grads[1], half_rtol=1e-2, half_atol=1e-2
)
else:
with tf.GradientTape(persistent=True) as tape:
Expand All @@ -117,5 +117,8 @@ def test_backward(input_shape, input_dim, dtype, indices_dtype, combiner, graph_
expected_grads = tape.gradient(expected, [params])
# Gather returns sparse IndexedSlices so we have to sum them together.
test_utils.assert_allclose_according_to_type(
tf.convert_to_tensor(expected_grads[0]), tf.convert_to_tensor(grads[0])
tf.convert_to_tensor(expected_grads[0]),
tf.convert_to_tensor(grads[0]),
half_rtol=1e-2,
half_atol=1e-2,
)