-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Port the ImageClassifier
guide to keras_core
#608
Port the ImageClassifier
guide to keras_core
#608
Conversation
## Multi-Backend Support | ||
|
||
Keras-CV's `ImageClassifier` model supports several backends like Jax, PyTorch, and TensorFlow with the help of `keras_core`. To enable multi-backend support in Keras-CV, set the `KERAS_CV_MULTI_BACKEND` environment variable. We can then switch between different backends by setting the `KERAS_BACKEND` environment variable. Currently, `tensorflow`, `jax`, and `torch` are supported. | ||
KerasCV's `ImageClassifier` model supports several backends like Jax, PyTorch, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
JAX
and TensorFlow with the help of `keras_core`. To enable multi-backend support | ||
in KerasCV, set the `KERAS_CV_MULTI_BACKEND` environment variable. We can | ||
then switch between different backends by setting the `KERAS_BACKEND` | ||
environment variable. Currently, `tensorflow`, `jax`, and `torch` are |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use backticks + double quotes for strings
ImageClassifier
guide to keras_core
"ImageClassifier
guide to keras_core
|
||
"""Now let's apply our final augmenter to the training data:""" | ||
|
||
augmenter = tf.keras.Sequential(augmenters) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should be no tf.keras at all
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use the Augmenter
which we're adding back to KerasCV in keras-team/keras-cv#1978
|
||
|
||
def preprocess_inputs(image, label): | ||
image = tf.cast(image, tf.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the example has quite a bit of TF-only APIs...
- We could reduce or remove the TF dependency and keep it in the backend-agnostic folder.
- We could move it to the TF-only folder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The TF components are just for preprocessing which for now has to live in tf.data. I'm fine with moving it to TF-only, but I would expect this to work fine with all backends since this is constrained to the tf.data parts of the example
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for what @ianstenbit said. I am OK with moving this under TF-only though, if you prefer @fchollet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think at minima we should:
- drop the dependency on
tf.keras.Sequential
and just use a list or Augmenter - drop the dependency on
tf.one_hot
. You can use aCategoryEncoding
layer for that - remove the tf.cast to float (unnecessary)
- remove
import tensorflow
and dofrom tensorflow import data as tf_data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was trying your suggestions out @fchollet and seems like we would still need the tf
import for one hot encoding since CategoryEncoding
doesn't accept float32
labels (so, even if we use it, we'd still need to cast labels to integers).
Other than that, everything else works. Let me know if this is OK from your side!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Can you also open a PR for this on keras-team/keras-io? |
Sure, I still had one change in flight but I'll submit a follow-up! I will also submit a PR on keras.io. Thanks for the reviews @fchollet @ianstenbit! |
Continuing #273.