-
Notifications
You must be signed in to change notification settings - Fork 614
Fix bugged gradients when combiner == 'MEAN' #2505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2d2096f
6bd366e
3aa8fbf
02afcd5
38a6eff
18e4acf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,8 +49,13 @@ def _embedding_bag( | |
Returns: | ||
A `Tensor` of the format specified by `data_format`. | ||
""" | ||
if weights is None: | ||
if weights is None and combiner == "sum": | ||
weights = tf.ones_like(indices, dtype=params.dtype) | ||
elif weights is None and combiner == "mean": | ||
weights = tf.ones_like(indices, dtype=params.dtype) / tf.cast( | ||
tf.shape(indices)[1], params.dtype | ||
) | ||
combiner = "sum" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we have this combiner overriding? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was a quick workaround, but the output and grads are correct, and performance is fine! Since combiner == "mean" does not support weights, we can get the right results by creating a dummy weight array (we do that anyway), then scaling it so that we get an unweighted mean instead of an unweighted sum. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe you should call |
||
elif combiner != "sum": | ||
raise RuntimeError( | ||
"Combiner mode must be 'sum' when weights are supplied to EmbeddingBag!" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use reduction instead of combiner
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The combiner is combining part of the layer and applies to the inputs/weights, not the output/loss! Is reduction still the right approach there?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add
combiner
document?