[Misc] Make cached tokenizer pickle-compatible #17048


Merged · 13 commits · Apr 27, 2025
14 changes: 8 additions & 6 deletions benchmarks/benchmark_prefix_caching.py
@@ -63,14 +63,16 @@ class Request:
     output_len: int
 
 
-def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> str:
+def sample_tokens(tokenizer: PreTrainedTokenizerBase,
+                  length: int) -> list[int]:
     vocab = tokenizer.get_vocab()
+    all_special_ids = set(tokenizer.all_special_ids)
+
     # Remove the special tokens.
-    vocab = {
-        k: v
-        for k, v in vocab.items() if k not in tokenizer.all_special_ids
-    }
-    return random.choices(list(vocab.values()), k=length)
+    return random.choices(
+        [v for k, v in vocab.items() if k not in all_special_ids],
+        k=length,
+    )
 
 
 def sample_requests_from_dataset(
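For reference, a minimal sketch of how the updated helper behaves after this change; the "gpt2" checkpoint and the length of 32 are illustrative choices, not part of the benchmark itself:

import random

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def sample_tokens(tokenizer: PreTrainedTokenizerBase,
                  length: int) -> list[int]:
    # Same helper as above: draw random non-special token ids from the vocab.
    vocab = tokenizer.get_vocab()
    all_special_ids = set(tokenizer.all_special_ids)
    return random.choices(
        [v for k, v in vocab.items() if k not in all_special_ids],
        k=length,
    )


tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = sample_tokens(tokenizer, length=32)  # now a list[int] rather than a str
prompt = tokenizer.decode(token_ids)             # detokenize when a text prompt is needed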
43 changes: 31 additions & 12 deletions tests/tokenization/test_cached_tokenizer.py
@@ -1,24 +1,43 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import pickle
 from copy import deepcopy
 
+import pytest
 from transformers import AutoTokenizer
 
-from vllm.transformers_utils.tokenizer import get_cached_tokenizer
+from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+                                               get_cached_tokenizer)
 
 
-def test_cached_tokenizer():
-    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
+@pytest.mark.parametrize("model_id", ["gpt2", "THUDM/chatglm3-6b"])
+def test_cached_tokenizer(model_id: str):
+    reference_tokenizer = AutoTokenizer.from_pretrained(model_id,
+                                                        trust_remote_code=True)
     reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
     reference_tokenizer.add_special_tokens(
         {"additional_special_tokens": ["<SEP>"]})
+
     cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))
+    _check_consistency(cached_tokenizer, reference_tokenizer)
+
+    pickled_tokenizer = pickle.dumps(cached_tokenizer)
+    unpickled_tokenizer = pickle.loads(pickled_tokenizer)
+    _check_consistency(unpickled_tokenizer, reference_tokenizer)
+
+
+def _check_consistency(target: AnyTokenizer, expected: AnyTokenizer):
+    assert isinstance(target, type(expected))
+
+    # Cached attributes
+    assert target.all_special_ids == expected.all_special_ids
+    assert target.all_special_tokens == expected.all_special_tokens
+    assert (target.all_special_tokens_extended ==
+            expected.all_special_tokens_extended)
+    assert target.get_vocab() == expected.get_vocab()
+    assert len(target) == len(expected)
+
+    # Other attributes
+    assert getattr(target, "padding_side",
+                   None) == getattr(expected, "padding_side", None)
-
-    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
-        "prompt")
-    assert set(reference_tokenizer.all_special_ids) == set(
-        cached_tokenizer.all_special_ids)
-    assert set(reference_tokenizer.all_special_tokens) == set(
-        cached_tokenizer.all_special_tokens)
-    assert set(reference_tokenizer.all_special_tokens_extended) == set(
-        cached_tokenizer.all_special_tokens_extended)
+    assert target.encode("prompt") == expected.encode("prompt")
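In isolation, the behavior this test exercises looks roughly like the following sketch, reusing the "gpt2" checkpoint from the test:

import pickle

from transformers import AutoTokenizer

from vllm.transformers_utils.tokenizer import get_cached_tokenizer

cached = get_cached_tokenizer(AutoTokenizer.from_pretrained("gpt2"))

# The cached proxy can now round-trip through pickle, e.g. when it is shipped
# to a worker process; the restored object is rebuilt via get_cached_tokenizer.
restored = pickle.loads(pickle.dumps(cached))
assert restored.get_vocab() == cached.get_vocab()
assert restored.all_special_ids == cached.all_special_ids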
35 changes: 19 additions & 16 deletions vllm/transformers_utils/tokenizer.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
+import copy
 import os
 import warnings
 from functools import lru_cache
@@ -70,18 +71,17 @@ def encode_tokens(
 
 
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
-    """Get tokenizer with cached properties.
-
-    This will patch the tokenizer object in place.
-
+    """
     By default, transformers will recompute multiple tokenizer properties
-    each time they are called, leading to a significant slowdown. This
-    function caches these properties for faster access."""
+    each time they are called, leading to a significant slowdown.
+    This proxy caches these properties for faster access.
+    """
+    cached_tokenizer = copy.copy(tokenizer)
 
-    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
+    tokenizer_all_special_ids = tokenizer.all_special_ids
+    tokenizer_all_special_tokens = tokenizer.all_special_tokens
     tokenizer_all_special_tokens_extended = (
         tokenizer.all_special_tokens_extended)
-    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_vocab = tokenizer.get_vocab()
     tokenizer_len = len(tokenizer)
 
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
         @property
-        def all_special_ids(self):
+        def all_special_ids(self) -> list[int]:
             return tokenizer_all_special_ids
 
         @property
-        def all_special_tokens(self):
+        def all_special_tokens(self) -> list[str]:
            return tokenizer_all_special_tokens
 
         @property
-        def all_special_tokens_extended(self):
+        def all_special_tokens_extended(self) -> list[str]:
             return tokenizer_all_special_tokens_extended
 
         @property
-        def max_token_id(self):
+        def max_token_id(self) -> int:
             return max_token_id
 
-        def get_vocab(self):
+        def get_vocab(self) -> dict[str, int]:
             return tokenizer_vocab
 
-        def __len__(self):
+        def __len__(self) -> int:
             return tokenizer_len
 
+        def __reduce__(self):
+            return get_cached_tokenizer, (tokenizer, )
+
     CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
 
-    tokenizer.__class__ = CachedTokenizer
-    return tokenizer
+    cached_tokenizer.__class__ = CachedTokenizer
+    return cached_tokenizer
 
 
 def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
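The key to pickle support is the new __reduce__ hook: the per-instance CachedTokenizer subclass is created at runtime, so pickle cannot locate it by name; instead, the proxy serializes the original tokenizer and rebuilds the proxy on load. Below is a standalone sketch of the same pattern; make_cached and Thing are illustrative names, not vLLM APIs:

import copy
import pickle


def make_cached(obj):
    value = obj.expensive()  # compute once and capture in the closure

    class Cached(obj.__class__):

        def expensive(self):
            return value

        def __reduce__(self):
            # The dynamically created class can't be found by pickle, so
            # serialize the original object and rebuild the proxy on load.
            return make_cached, (obj, )

    proxy = copy.copy(obj)  # leave the original object untouched
    proxy.__class__ = Cached
    return proxy


class Thing:

    def expensive(self):
        return sum(range(1000))


cached = make_cached(Thing())
assert pickle.loads(pickle.dumps(cached)).expensive() == 499500

The copy.copy call also means get_cached_tokenizer no longer patches the passed-in tokenizer in place, which the old docstring documented as a side effect.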
34 changes: 17 additions & 17 deletions vllm/transformers_utils/tokenizer_base.py
@@ -2,7 +2,7 @@
 
 import importlib
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
@@ -12,17 +12,17 @@ class TokenizerBase(ABC):
 
     @property
     @abstractmethod
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         raise NotImplementedError()
 
     @property
     @abstractmethod
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         raise NotImplementedError()
 
     @property
@@ -66,7 +66,7 @@ def __len__(self) -> int:
     @abstractmethod
     def __call__(
         self,
-        text: Union[str, List[str], List[int]],
+        text: Union[str, list[str], list[int]],
         text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
@@ -75,11 +75,11 @@ def __call__(
         raise NotImplementedError()
 
     @abstractmethod
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         raise NotImplementedError()
 
     @abstractmethod
@@ -88,44 +88,44 @@ def encode_one(
         text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
-    ) -> List[int]:
+    ) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def encode(self,
               text: str,
-              add_special_tokens: Optional[bool] = None) -> List[int]:
+              add_special_tokens: Optional[bool] = None) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
     def apply_chat_template(self,
-                            messages: List["ChatCompletionMessageParam"],
-                            tools: Optional[List[Dict[str, Any]]] = None,
-                            **kwargs) -> List[int]:
+                            messages: list["ChatCompletionMessageParam"],
+                            tools: Optional[list[dict[str, Any]]] = None,
+                            **kwargs) -> list[int]:
         raise NotImplementedError()
 
     @abstractmethod
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def decode(self,
-               ids: Union[List[int], int],
+               ids: Union[list[int], int],
               skip_special_tokens: bool = True) -> str:
         raise NotImplementedError()
 
     @abstractmethod
     def convert_ids_to_tokens(
         self,
-        ids: List[int],
+        ids: list[int],
         skip_special_tokens: bool = True,
-    ) -> List[str]:
+    ) -> list[str]:
         raise NotImplementedError()
 
 
 class TokenizerRegistry:
     # Tokenizer name -> (tokenizer module, tokenizer class)
-    REGISTRY: Dict[str, Tuple[str, str]] = {}
+    REGISTRY: dict[str, tuple[str, str]] = {}
 
     @staticmethod
     def register(name: str, module: str, class_name: str) -> None:
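For illustration, registering a custom TokenizerBase implementation with this registry might look like the sketch below; the module path and class name are hypothetical placeholders:

from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

# "my_pkg.tokenizers" and "MyTokenizer" stand in for a real module and class.
TokenizerRegistry.register("my_tokenizer", "my_pkg.tokenizers", "MyTokenizer")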
10 changes: 5 additions & 5 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -257,7 +257,7 @@ def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
     # the following attributes are set to fit vLLM's design and are used
     # by the guided structured output backends.
     @property
-    def all_special_tokens_extended(self) -> List[str]:
+    def all_special_tokens_extended(self) -> list[str]:
         from mistral_common.tokens.tokenizers.base import SpecialTokens
 
         # tekken defines its own extended special tokens list
@@ -271,11 +271,11 @@ def all_special_tokens_extended(self) -> List[str]:
         ]
 
     @property
-    def all_special_tokens(self) -> List[str]:
+    def all_special_tokens(self) -> list[str]:
         return self.all_special_tokens_extended
 
     @property
-    def all_special_ids(self) -> List[int]:
+    def all_special_ids(self) -> list[int]:
         return [
             self.all_special_tokens.index(t) for t in self.all_special_tokens
         ]
@@ -335,12 +335,12 @@ def __call__(
         input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)
 
-    def get_vocab(self) -> Dict[str, int]:
+    def get_vocab(self) -> dict[str, int]:
         # NB: the dictionary form of the vocabulary collapses token ids that map
         # to the same string but have different bytes
         return self._vocab_dict
 
-    def get_added_vocab(self) -> Dict[str, int]:
+    def get_added_vocab(self) -> dict[str, int]:
         # Mistral tokenizers have no added vocabulary
         return {}
 