Port mlp_image_classification.py to all backends #663

Merged

guillaumebaquiast
Contributor

Port mlp_image_classification to keras_core.

This implementation works with TensorFlow, Torch, and JAX (at least on my machine).

A few points below:

Reimplemented positional embedding
As in PR #602, the example did not run with the Torch backend. I got the same cryptic error message: RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

To fix that, I had to re-implement the positional embedding logic.
I believe the original implementation of the positional embedding caused Keras to ignore the embedding layer. This didn't fail with TensorFlow or JAX (although it didn't behave as expected), but it did fail with Torch. Compare the FNet model summaries for the two implementations:

  • With the original implementation, the embedding layer is missing:
image
  • With the implementation from this PR, the embedding layer is present, and the code runs without error on a Torch backend:
image

If my hypothesis is correct, then a few other examples should get the same fix (e.g., cct.py, image_captioning.py, token_learner.py, video_transformer.py), which should also resolve issue #566.

Reimplemented Patches
This is to leverage the newly implemented keras_core.ops.image.extract_patches.
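
For readers unfamiliar with patch extraction, here is a minimal NumPy sketch of the idea: splitting a batch of images into flattened, non-overlapping patches. The helper name `extract_patches_np` is hypothetical and this is not the code from the PR. It is also not `keras_core.ops.image.extract_patches`, whose exact arguments and output layout should be checked against its documentation.

```python
import numpy as np


def extract_patches_np(images, patch_size):
    """Split (batch, height, width, channels) images into flattened patches.

    Hypothetical helper for illustration only; patches do not overlap.
    """
    batch, height, width, channels = images.shape
    if height % patch_size or width % patch_size:
        raise ValueError("height and width must be divisible by patch_size")
    rows, cols = height // patch_size, width // patch_size
    # Split each spatial axis into (grid index, offset inside the patch).
    x = images.reshape(batch, rows, patch_size, cols, patch_size, channels)
    # Put the two grid axes first, followed by the inside-patch axes.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # One row per patch, each flattened to patch_size * patch_size * channels.
    return x.reshape(batch, rows * cols, patch_size * patch_size * channels)


images = np.arange(2 * 4 * 4 * 3, dtype="float32").reshape(2, 4, 4, 3)
patches = extract_patches_np(images, patch_size=2)
print(patches.shape)  # (2, 4, 12)
```

Each 4x4 image yields four 2x2 patches, and every patch is flattened to 2 * 2 * 3 = 12 values.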

Fixed some typos
Fixed minor typos and similar small issues, such as function arguments that weren’t used anywhere.

File location
I left the file at the location keras_io/tensorflow/vision/mlp_image_classification for now, so that the diff displays nicely. I can move it under keras_io/vision before merging.

input_dim=num_patches, output_dim=embedding_dim
)(positions)
x = x + position_embedding
x = x + PositionEmbedding(sequence_length=num_patches)(x)
Contributor Author

This change fixes the run on the Torch backend and makes the embedding appear in the model summary.

## Implement position embedding as a layer
"""

class PositionEmbedding(keras.layers.Layer):
Contributor Author

This class was adapted from KerasNLP.
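
As a rough illustration of the arithmetic behind `x = x + PositionEmbedding(sequence_length=num_patches)(x)`, the NumPy sketch below adds one vector per position to every example in a batch. It assumes the layer holds a table of shape (sequence_length, embedding_dim) that is broadcast over the batch dimension. The function name `add_position_embeddings` and the random table are hypothetical stand-ins, not the PR's class, and nothing here is trainable.

```python
import numpy as np


def add_position_embeddings(x, position_table):
    """Add a per-position vector to x of shape (batch, seq_len, dim).

    Hypothetical helper; position_table stands in for a learned weight.
    """
    seq_len = x.shape[1]
    if seq_len > position_table.shape[0]:
        raise ValueError("input is longer than the position table")
    # Keep only the rows for positions present in the input.
    positions = position_table[:seq_len]
    # Add a leading axis so the same table is applied to every example.
    return x + positions[np.newaxis, :, :]


rng = np.random.default_rng(0)
num_patches, embedding_dim = 4, 8
# Random values standing in for a trainable embedding weight.
position_table = rng.normal(size=(num_patches, embedding_dim))
x = np.zeros((2, num_patches, embedding_dim))
out = add_position_embeddings(x, position_table)
print(out.shape)  # (2, 4, 8)
```

In a real Keras layer the table would be a weight created by the layer itself, which is what lets it show up in the model summary and receive gradients.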

Contributor

@fchollet fchollet left a comment

The fix makes sense to me -- thank you for debugging this! LGTM
