Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add Support for Multimodal Granite Models #10291

Merged
merged 11 commits into from
Nov 21, 2024

Conversation

alex-jw-brooks
Copy link
Contributor

@alex-jw-brooks alex-jw-brooks commented Nov 13, 2024

This PR adds the main architectural change needed to support upcoming multimodal granite models from IBM, which are a variant of llava-next. To this end, we need to allow passing a list of feature_sample_layers, which are indices in the visual encoder whose hidden states will be concatenated and used as the visual encoder output as an alternative to the last layer. This change should not affect any existing llava next models, and is made in all of the supported visual encoders for Llava-based models in vLLM (clip/siglip/pixtral) to keep things as generic as possible.

These models have not been released quite yet, and we have not yet ported the change to transformers, which is why there are no new tests added. However, I have run my own benchmarks to ensure the results in vLLM are identical in quality to our own implementation, and verified that the feature_sample_layers can be passed through each supported visual encoder type as a sanity check. Once the models are out and supported in transformers, I can add it to the corresponding test here!

FYI @njhill


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337
Copy link
Member

DarkLight1337 commented Nov 13, 2024

This looks a bit overcomplicated, and can be easily confused with num_hidden_layers. I prefer having a flag to simply return all of the hidden states (similar to HF) up to num_hidden_layers, and extract them after calling forward method.

@DarkLight1337
Copy link
Member

DarkLight1337 commented Nov 13, 2024

The reason why we set num_hidden_layers at initialization time is to avoid loading the unused layers. However, there is no such advantage to passing feature_sample_layers, so it would be better to pass it dynamically when calling forward on the vision encoder.

@DarkLight1337 DarkLight1337 self-assigned this Nov 13, 2024
@alex-jw-brooks alex-jw-brooks force-pushed the mm_granite branch 2 times, most recently from d9592fe to c77bd1f Compare November 15, 2024 18:39
@alex-jw-brooks
Copy link
Contributor Author

Sounds good, thanks @DarkLight1337! I moved most of the vision feature layer handling out to inference time, added the flag for getting all of the hidden states + selecting from them after, and fixed the initialization behavior so that we only load up to the deepest layer needed if there are multiple feature layers also

self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like an unnecessary formatting change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup! I put it back the way it was, accidentally deleted some trailing commas + auto formatted in some places 😄

Copy link
Member

@DarkLight1337 DarkLight1337 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise LGTM

@DarkLight1337
Copy link
Member

cc @ywang96 in case you're planning to refactor related code in the vision encoders.

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) November 20, 2024 05:15
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 20, 2024
alex-jw-brooks and others added 11 commits November 21, 2024 02:21
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>

Fix post norm handline for multi feature layers

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
auto-merge was automatically disabled November 21, 2024 07:21

Head branch was pushed to by a user without write access

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) November 21, 2024 08:27
@DarkLight1337 DarkLight1337 merged commit 1cfde82 into vllm-project:main Nov 21, 2024
51 checks passed
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 28, 2024
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
anko-intel pushed a commit to HabanaAI/vllm-fork that referenced this pull request Feb 10, 2025
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants