Port mlp_image_classification.py to all backends
#663
Merged
Port mlp_image_classification to keras_core.
This implementation works with TensorFlow, Torch, and JAX (at least on my machine).
A few points below:
Reimplemented positional embedding
Just like in PR #602, I ran into the issue that the example wouldn't run with the Torch backend. I got the same cryptic error message:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
To fix that, I had to re-implement the positional embedding logic.
I believe that the original implementation of the positional embedding made Keras ignore the embedding layer, which didn't fail with TensorFlow or JAX (although it didn't behave as expected), but did fail with Torch. Indeed, see the difference in the FNet model summaries between the two implementations:
If my hypothesis is correct, then a few other examples should get the same fix (e.g., cct.py, image_captioning.py, token_learner.py, video_transformer.py), and that should resolve issue #566.
Reimplemented Patches
This is to leverage the newly implemented keras_core.ops.image.extract_patches.
Fixed some typos
Fixed minor issues, such as function arguments that weren't used anywhere.
File location
I left the file at keras_io/tensorflow/vision/mlp_image_classification for now, so that the diff displays nicely. I can move it under keras_io/vision before merging.