Fixed misc mixed precision issues #805
Conversation
keras_core/layers/layer.py
Outdated
-        if self.autocast and self.compute_dtype != self.variable_dtype:
         # For mixed precision, we automatically cast layer variables
         # (float ones only) to the compute dtype upon access.
+        if self.autocast and backend.is_float_dtype(self.compute_dtype):
I am unsure about this! Let me know what you think.
The issue with the previous logic is that it would not handle nesting well. Say you have a layer with `dtype="mixed_float16"` which contains a dense layer with `dtype="float32"`. The autocast scope is currently never cleared, so the inner layer would see all variables cast to a lower precision than the inner compute dtype, and you would just get an error when trying to run the dense call graph (see the sketch below).
However, this logic will end up stacking more autocast scopes (doesn't seem like a big deal), and more aggressively applying the autocast logic (maybe a big deal?). It seems clean to always autocast if `self.autocast = True`, but I'm not sure of all the implications.
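For illustration, a minimal sketch of that nesting scenario (the layer class and shapes here are hypothetical, just to make the failure mode concrete):

```python
from keras_core import layers

class Outer(layers.Layer):
    def __init__(self):
        # Outer layer uses mixed precision: float32 variables, float16 compute.
        super().__init__(dtype="mixed_float16")
        # Inner layer explicitly stays in full float32.
        self.dense = layers.Dense(4, dtype="float32")

    def call(self, x):
        # With the old logic, the autocast scope opened for the outer layer
        # was never cleared, so the inner Dense saw its float32 kernel cast
        # down to float16 and its call failed with a dtype mismatch.
        return self.dense(x)
```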
We should seek to avoid creating an `AutocastScope` if unnecessary because of the eager performance overhead (which matters specifically for torch).
To handle nesting I guess we can further complicate the logic, something like:
if self.autocast and backend.is_float_dtype(self.compute_dtype):
    autocast_scope = get_autocast_scope()
    if (autocast_scope and autocast_scope.dtype != self.compute_dtype) or (
        not autocast_scope and self.compute_dtype != self.variable_dtype
    ):
        with backend.AutocastScope(self.compute_dtype):
            ...
WDYT? Too much? Is the performance cost of many AutocastScopes low enough we can go with your proposal?
Sounds good. If we want to be fully correct, I think we actually need to add the concept of an empty scope. E.g. if a layer with `autocast = False` is inside a model, we actually need to "clear the scope" which will have been set.
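Roughly, the idea would be something like the following (just a sketch, assuming `AutocastScope` accepts `dtype=None` to mean "no casting"):

```python
# Sketch of the idea inside Layer.__call__: when the layer should not
# autocast, enter a scope with dtype=None to "clear" any autocast scope
# set by an outer layer, so nested variables are read at their true dtype.
if self.autocast and backend.is_float_dtype(self.compute_dtype):
    scope_dtype = self.compute_dtype
else:
    scope_dtype = None  # an "empty" scope that disables inherited casting
with backend.AutocastScope(scope_dtype):
    outputs = self.call(*args, **kwargs)
```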
Interesting, looks like a lot of missing support on torch for float16 that I wasn't seeing locally. Will need to take a look.
Thanks for the PR!
Another question here. Looks like torch + half + cpu just doesn't really work due to missing ops. Should we skip those tests?
Let's skip the broken tests with torch.
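For example, skipping such a test only on the torch backend could look roughly like this (a sketch; the test name is hypothetical):

```python
import pytest
from keras_core import backend

@pytest.mark.skipif(
    backend.backend() == "torch",
    reason="float16 ops are missing for torch on CPU",
)
def test_layer_in_mixed_float16():
    ...
```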
Ready, I think!
LGTM, thank you!
* Add a test harness based on keras-core's `run_layer_test`

  Eventually, I want to add some dtype changes similar to keras-team/keras-core#805. But the nice thing about that PR on keras-core was that I could add a dtype test to a common harness and test all layers. So I think it's finally time to bite the bullet and add these for keras-nlp. This ports and simplifies the `run_layer_test` code from keras-core and applies it to our modeling layers. I am ditching the saved model tests for our individual layers, the idea being that saved model tests are slow, and we now get fairly robust serialization tests without saved model. If this is good enough for keras-core layers, I think we can follow suit here. We still test saving end to end through our modeling tests.

* Fix tests
* Address comments
The main goal of this is to add dtype support for the MHA layer, e.g. to get `MultiHeadAttention(dtype="mixed_float16")` to work correctly (it does not today). But I tried out adding a mixed precision test for all layers, which caught some other bugs.
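For reference, the kind of usage this aims to enable (a minimal sketch; shapes and arguments are illustrative):

```python
import numpy as np
from keras_core import layers

# Mixed precision: variables are kept in float32, compute runs in float16.
mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, dtype="mixed_float16")

x = np.random.random((4, 8, 16)).astype("float32")
y = mha(x, x)  # self-attention
print(y.dtype)  # the compute dtype, i.e. float16
```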